xinference 0.14.3__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/worker.py +18 -9
- xinference/model/audio/chattts.py +4 -3
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/embedding/core.py +2 -0
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/stable_diffusion/core.py +21 -6
- xinference/model/llm/llm_family.py +5 -6
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/utils.py +3 -0
- xinference/model/llm/vllm/core.py +0 -33
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/tools/api.py +1 -1
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/METADATA +20 -12
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/RECORD +70 -28
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import datetime as dt
|
|
3
|
+
import os
|
|
4
|
+
import warnings
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
import soundfile as sf
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from matcha.hifigan.config import v1
|
|
13
|
+
from matcha.hifigan.denoiser import Denoiser
|
|
14
|
+
from matcha.hifigan.env import AttrDict
|
|
15
|
+
from matcha.hifigan.models import Generator as HiFiGAN
|
|
16
|
+
from matcha.models.matcha_tts import MatchaTTS
|
|
17
|
+
from matcha.text import sequence_to_text, text_to_sequence
|
|
18
|
+
from matcha.utils.utils import assert_model_downloaded, get_user_data_dir, intersperse
|
|
19
|
+
|
|
20
|
+
# Download locations of the pretrained acoustic models. The keys double as the
# accepted values of the ``--model`` command line option.
MATCHA_URLS = {
    "matcha_ljspeech": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_ljspeech.ckpt",
    "matcha_vctk": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/matcha_vctk.ckpt",
}

# Download locations of the pretrained vocoders. The keys double as the
# accepted values of the ``--vocoder`` command line option.
VOCODER_URLS = {
    # Old url: https://drive.google.com/file/d/14NENd4equCBLyyCSke114Mv6YR_j_uFs/view?usp=drive_link
    "hifigan_T2_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/generator_v1",
    # Old url: https://drive.google.com/file/d/1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW/view?usp=drive_link
    "hifigan_univ_v1": "https://github.com/shivammehta25/Matcha-TTS-checkpoints/releases/download/v1.0/g_02500000",
}

# Per-model defaults for pretrained multi-speaker models: the suggested vocoder,
# the default speaking rate, the fallback speaker id and the inclusive range of
# valid speaker ids.
MULTISPEAKER_MODEL = {
    "matcha_vctk": {
        "vocoder": "hifigan_univ_v1",
        "speaking_rate": 0.85,
        "spk": 0,
        "spk_range": (0, 107),
    },
}

# Per-model defaults for pretrained single-speaker models; ``spk`` is None
# because these models take no speaker id.
SINGLESPEAKER_MODEL = {
    "matcha_ljspeech": {
        "vocoder": "hifigan_T2_v1",
        "speaking_rate": 0.95,
        "spk": None,
    },
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def plot_spectrogram_to_numpy(spectrogram, filename):
    """Render a mel-spectrogram as an image and save it to ``filename``.

    Despite the name, nothing is returned: the plot is only written to disk.

    Args:
        spectrogram: 2-D array-like accepted by ``Axes.imshow``
            (presumably channels x frames, matching the axis labels).
        filename: Destination path handed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        # Label the axes object directly instead of relying on pyplot's
        # "current axes" global state.
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        ax.set_title("Synthesised Mel-Spectrogram")
        fig.canvas.draw()
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; this
        # function runs once per utterance, so close it to avoid leaking them.
        plt.close(fig)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def process_text(i: int, text: str, device: torch.device):
    """Turn one input text into the tensors consumed by the synthesiser.

    The text is cleaned with ``english_cleaners2``, mapped to symbol ids and
    interspersed with ``0``. Both the raw text and the phonetised text are
    printed, prefixed with the utterance index ``i``.

    Returns:
        dict with keys ``x_orig`` (the input text), ``x`` (a ``(1, T)`` long
        tensor of ids), ``x_lengths`` (one-element long tensor holding ``T``)
        and ``x_phones`` (the text recovered from the ids).
    """
    print(f"[{i}] - Input text: {text}")

    symbol_ids = text_to_sequence(text, ["english_cleaners2"])[0]
    padded_ids = intersperse(symbol_ids, 0)
    x = torch.tensor(padded_ids, dtype=torch.long, device=device).unsqueeze(0)

    sequence_length = x.shape[-1]
    x_lengths = torch.tensor([sequence_length], dtype=torch.long, device=device)

    x_phones = sequence_to_text(x.squeeze(0).tolist())
    # Only every second symbol is shown, which skips the interspersed entries.
    print(f"[{i}] - Phonetised text: {x_phones[1::2]}")

    return {
        "x_orig": text,
        "x": x,
        "x_lengths": x_lengths,
        "x_phones": x_phones,
    }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_texts(args):
    """Collect the texts to synthesise from the parsed arguments.

    ``args.text`` wins when it is truthy and yields a single-item list.
    Otherwise ``args.file`` is read as UTF-8 and every line (line endings
    included) becomes one text.
    """
    if args.text:
        return [args.text]

    with open(args.file, encoding="utf-8") as handle:
        return handle.readlines()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def assert_required_models_available(args):
    """Resolve the acoustic-model and vocoder paths, fetching pretrained files.

    When ``args.checkpoint_path`` is set, that custom checkpoint is used as-is
    and no pretrained acoustic model is requested. Otherwise the pretrained
    model named by ``args.model`` is located in the user data directory and
    passed to ``assert_model_downloaded`` together with its URL. The vocoder
    named by ``args.vocoder`` is always handled the same way.

    Args:
        args: Namespace with ``model`` and ``vocoder`` attributes and an
            optional ``checkpoint_path`` attribute.

    Returns:
        dict with the keys ``"matcha"`` and ``"vocoder"`` mapping to the
        respective checkpoint paths.

    Raises:
        KeyError: If ``args.vocoder`` (or ``args.model`` when no custom
            checkpoint is given) is not a known pretrained name.
    """
    save_dir = get_user_data_dir()

    # The previous condition (`not hasattr(...) and ... is None`) could never
    # select the custom checkpoint and raised AttributeError when the attribute
    # was absent; getattr handles both a missing and an unset attribute.
    checkpoint_path = getattr(args, "checkpoint_path", None)
    if checkpoint_path is not None:
        model_path = checkpoint_path
    else:
        model_path = save_dir / f"{args.model}.ckpt"
        assert_model_downloaded(model_path, MATCHA_URLS[args.model])

    vocoder_path = save_dir / f"{args.vocoder}"
    assert_model_downloaded(vocoder_path, VOCODER_URLS[args.vocoder])
    return {"matcha": model_path, "vocoder": vocoder_path}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def load_hifigan(checkpoint_path, device):
    """Build a HiFi-GAN generator with the ``v1`` config and load its weights.

    The generator is moved to ``device``, populated from the ``"generator"``
    entry of the checkpoint, switched to eval mode and stripped of weight
    normalisation before being returned.
    """
    config = AttrDict(v1)
    generator = HiFiGAN(config).to(device)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint["generator"])

    generator.eval()
    generator.remove_weight_norm()
    return generator
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def load_vocoder(vocoder_name, checkpoint_path, device):
    """Load the named vocoder and a matching denoiser.

    Only the HiFi-GAN variants ``hifigan_T2_v1`` and ``hifigan_univ_v1`` are
    supported.

    Returns:
        Tuple ``(vocoder, denoiser)``.

    Raises:
        NotImplementedError: For any other ``vocoder_name``.
    """
    print(f"[!] Loading {vocoder_name}!")

    if vocoder_name not in ("hifigan_T2_v1", "hifigan_univ_v1"):
        raise NotImplementedError(
            f"Vocoder {vocoder_name} not implemented! define a load_<<vocoder_name>> method for it"
        )

    vocoder = load_hifigan(checkpoint_path, device)
    denoiser = Denoiser(vocoder, mode="zeros")
    print(f"[+] {vocoder_name} loaded!")
    return vocoder, denoiser
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_matcha(model_name, checkpoint_path, device):
    """Restore a ``MatchaTTS`` model from a checkpoint in eval mode.

    ``model_name`` is only used for the progress messages printed before and
    after loading.
    """
    print(f"[!] Loading {model_name}!")

    matcha = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    matcha.eval()

    print(f"[+] {model_name} loaded!")
    return matcha
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def to_waveform(mel, vocoder, denoiser=None, denoiser_strength=0.00025):
    """Vocode a mel-spectrogram into a waveform on the CPU.

    Args:
        mel: Mel-spectrogram tensor passed straight to ``vocoder``.
        vocoder: Callable mapping ``mel`` to an audio tensor.
        denoiser: Optional callable invoked as
            ``denoiser(audio, strength=...)`` on the squeezed audio.
        denoiser_strength: Strength forwarded to ``denoiser``. Defaults to the
            previously hard-coded ``0.00025`` so existing callers are
            unaffected.

    Returns:
        The audio clamped to ``[-1, 1]`` (and denoised when requested), moved
        to the CPU with all singleton dimensions removed. A batch of one
        therefore comes back 1-D.
    """
    audio = vocoder(mel).clamp(-1, 1)
    if denoiser is not None:
        audio = denoiser(audio.squeeze(), strength=denoiser_strength)

    return audio.cpu().squeeze()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def save_to_folder(filename: str, output: dict, folder: str):
    """Write the spectrogram plot, the mel array and the audio of one utterance.

    Three files are created inside ``folder`` (which is created if needed):
    ``<filename>.png``, ``<filename>.npy`` (numpy appends the extension) and
    ``<filename>.wav``.

    Args:
        filename: Base name, without extension, shared by all three files.
        output: Mapping with a ``"mel"`` tensor and a ``"waveform"`` to write.
        folder: Destination directory.

    Returns:
        Absolute path of the written ``.wav`` file.
    """
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)

    # The plot used to be written relative to the current working directory,
    # so it ended up outside `folder` whenever a different output folder was
    # requested.
    plot_spectrogram_to_numpy(
        np.array(output["mel"].squeeze().float().cpu()),
        folder / f"{filename}.png",
    )
    np.save(folder / f"{filename}", output["mel"].cpu().numpy())
    # 22050 is the sample rate passed to soundfile for the written audio.
    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
    return folder.resolve() / f"{filename}.wav"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def validate_args(args):
    """Check the parsed CLI arguments and fill in model-dependent defaults.

    For pretrained models (no ``checkpoint_path``) the per-model validators
    supply the vocoder, speaking rate and speaker id. For a custom checkpoint
    the vocoder falls back to ``hifigan_univ_v1`` and the speaking rate to
    ``1.0`` when they were not given.

    Args:
        args: Namespace produced by the CLI parser; it is modified in place.

    Returns:
        The same namespace with defaults filled in.

    Raises:
        AssertionError: If no text/file was given, or temperature, steps,
            batch size or speaking rate are out of range.
    """
    # NOTE: these checks are plain asserts (disabled under `python -O`); they
    # are kept so callers catching AssertionError keep working.
    assert (
        args.text or args.file
    ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    assert args.steps > 0, "Number of ODE steps must be greater than 0"

    if args.checkpoint_path is None:
        # When using pretrained models
        if args.model in SINGLESPEAKER_MODEL:
            args = validate_args_for_single_speaker_model(args)

        if args.model in MULTISPEAKER_MODEL:
            args = validate_args_for_multispeaker_model(args)
    else:
        # When using a custom model
        if args.vocoder is None:
            # Without a default the vocoder stayed None and the later lookup of
            # its download URL failed; fall back to the vocoder this branch
            # already recommends for custom checkpoints.
            args.vocoder = "hifigan_univ_v1"
            warn_ = "[!] Using custom model checkpoint without --vocoder! Defaulting to hifigan_univ_v1; pass --vocoder hifigan_T2_v1 if the custom model is trained on LJ Speech."
            warnings.warn(warn_, UserWarning)
        elif args.vocoder != "hifigan_univ_v1":
            warn_ = "[-] Using custom model checkpoint! I would suggest passing --vocoder hifigan_univ_v1, unless the custom model is trained on LJ Speech."
            warnings.warn(warn_, UserWarning)
        if args.speaking_rate is None:
            args.speaking_rate = 1.0

    if args.batched:
        assert args.batch_size > 0, "Batch size must be greater than 0"
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"

    return args
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def validate_args_for_multispeaker_model(args):
    """Apply the defaults of a pretrained multi-speaker model to ``args``.

    Fills in the vocoder, speaking rate and speaker id when they are unset,
    warns when a non-recommended vocoder or a defaulted speaker id is used,
    and asserts that an explicit speaker id lies within the model's range.

    Returns:
        The same ``args`` namespace, updated in place.
    """
    defaults = MULTISPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = defaults["vocoder"]
    elif args.vocoder != defaults["vocoder"]:
        warnings.warn(
            f"[-] Using {args.model} model! I would suggest passing --vocoder {defaults['vocoder']}",
            UserWarning,
        )

    if args.speaking_rate is None:
        args.speaking_rate = defaults["speaking_rate"]

    spk_range = defaults["spk_range"]
    if args.spk is None:
        fallback_spk = defaults["spk"]
        warnings.warn(f"[!] Speaker ID not provided! Using speaker ID {fallback_spk}", UserWarning)
        args.spk = fallback_spk
    else:
        # Both ends of the range are valid speaker ids.
        assert (
            spk_range[0] <= args.spk <= spk_range[-1]
        ), f"Speaker ID must be between {spk_range} for this model."

    return args
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def validate_args_for_single_speaker_model(args):
    """Apply the defaults of a pretrained single-speaker model to ``args``.

    Fills in the vocoder and speaking rate when unset, warns about a
    non-recommended vocoder, and overrides any speaker id that differs from
    the model's own (with a warning).

    Returns:
        The same ``args`` namespace, updated in place.
    """
    defaults = SINGLESPEAKER_MODEL[args.model]

    if args.vocoder is None:
        args.vocoder = defaults["vocoder"]
    elif args.vocoder != defaults["vocoder"]:
        warnings.warn(
            f"[-] Using {args.model} model! I would suggest passing --vocoder {defaults['vocoder']}",
            UserWarning,
        )

    if args.speaking_rate is None:
        args.speaking_rate = defaults["speaking_rate"]

    expected_spk = defaults["spk"]
    if args.spk != expected_spk:
        warnings.warn(f"[-] Ignoring speaker id {args.spk} for {args.model}", UserWarning)
        args.spk = expected_spk

    return args
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _build_arg_parser():
    """Create the argument parser describing every command line option."""
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )

    # Model selection
    parser.add_argument(
        "--model", type=str, default="matcha_ljspeech", help="Model to use", choices=MATCHA_URLS.keys()
    )
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to the custom model checkpoint")
    parser.add_argument(
        "--vocoder",
        type=str,
        default=None,
        help="Vocoder to use (default: will use the one suggested with the pretrained model))",
        choices=VOCODER_URLS.keys(),
    )

    # Input
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")

    # Synthesis settings
    parser.add_argument(
        "--temperature", type=float, default=0.667, help="Variance of the x0 noise (default: 0.667)"
    )
    parser.add_argument(
        "--speaking_rate",
        type=float,
        default=None,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    parser.add_argument("--steps", type=int, default=10, help="Number of ODE steps (default: 10)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
    parser.add_argument(
        "--denoiser_strength",
        type=float,
        default=0.00025,
        help="Strength of the vocoder bias denoiser (default: 0.00025)",
    )

    # Output and batching
    parser.add_argument(
        "--output_folder",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )
    parser.add_argument("--batched", action="store_true", help="Batched inference (default: False)")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size only useful when --batched (default: 32)"
    )
    return parser


@torch.inference_mode()
def cli():
    """Command line entry point: parse options, load the models and synthesise.

    The arguments are validated, the device is chosen, the required
    checkpoints are resolved, and the texts are synthesised either one at a
    time or in batches.
    """
    args = validate_args(_build_arg_parser().parse_args())
    device = get_device(args)
    print_config(args)
    paths = assert_required_models_available(args)

    if args.checkpoint_path is not None:
        print(f"[🍵] Loading custom model from {args.checkpoint_path}")
        paths["matcha"] = args.checkpoint_path
        args.model = "custom_model"

    model = load_matcha(args.model, paths["matcha"], device)
    vocoder, denoiser = load_vocoder(args.vocoder, paths["vocoder"], device)

    texts = get_texts(args)

    spk = None
    if args.spk is not None:
        spk = torch.tensor([args.spk], device=device, dtype=torch.long)

    # Batching is only used when requested and there is more than one text.
    use_batches = len(texts) != 1 and args.batched
    if use_batches:
        batched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
    else:
        unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class BatchedSynthesisDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` wrapper around a sequence of already processed texts."""

    def __init__(self, processed_texts):
        # Kept as given; items are returned unchanged by __getitem__.
        self.processed_texts = processed_texts

    def __len__(self):
        """Return the number of processed texts."""
        items = self.processed_texts
        return len(items)

    def __getitem__(self, idx):
        """Return the processed text stored at position ``idx``."""
        items = self.processed_texts
        return items[idx]
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def batched_collate_fn(batch):
    """Collate processed texts into one zero-padded batch.

    Each item contributes its ``"x"`` tensor (leading dimension squeezed away)
    and its ``"x_lengths"`` tensor.

    Returns:
        dict with ``"x"`` (sequences padded to the longest one, batch first)
        and ``"x_lengths"`` (the per-item lengths concatenated).
    """
    sequences = [item["x"].squeeze(0) for item in batch]
    lengths = [item["x_lengths"] for item in batch]

    return {
        "x": torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True),
        "x_lengths": torch.cat(lengths, dim=0),
    }
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def batched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise ``texts`` in batches and save one set of files per utterance.

    Texts are processed on the CPU, batched through a ``DataLoader`` and moved
    to ``device`` per batch. Per-batch and average real-time factors are
    printed, and every utterance is written with ``save_to_folder``.

    Args:
        args: Parsed CLI namespace (``batch_size``, ``steps``, ``temperature``,
            ``speaking_rate``, ``spk`` and ``output_folder`` are read).
        device: Device the batches are moved to for synthesis.
        model: Object exposing ``synthesise(...)``.
        vocoder: Vocoder passed to ``to_waveform``.
        denoiser: Denoiser passed to ``to_waveform``.
        texts: Texts to synthesise, one utterance each.
        spk: One-element speaker tensor, or ``None``.
    """
    total_rtf = []
    total_rtf_w = []
    # Strip surrounding whitespace (e.g. the newline left by readlines) so the
    # batched path sees the same text as the unbatched one.
    processed_text = [process_text(i, text.strip(), "cpu") for i, text in enumerate(texts)]
    dataloader = torch.utils.data.DataLoader(
        BatchedSynthesisDataset(processed_text),
        batch_size=args.batch_size,
        collate_fn=batched_collate_fn,
        num_workers=8,
    )

    # Running index over all utterances. Naming files by the position within a
    # batch made every batch overwrite the files of the previous one.
    utterance_idx = 0
    for i, batch in enumerate(dataloader, start=1):
        start_t = dt.datetime.now()
        b = batch["x"].shape[0]
        output = model.synthesise(
            batch["x"].to(device),
            batch["x_lengths"].to(device),
            n_timesteps=args.steps,
            temperature=args.temperature,
            spks=spk.expand(b) if spk is not None else spk,
            length_scale=args.speaking_rate,
        )

        waveform = to_waveform(output["mel"], vocoder, denoiser)
        if waveform.dim() == 1:
            # to_waveform squeezes singleton dimensions, so a batch with a
            # single utterance comes back 1-D; restore the batch dimension
            # before indexing per utterance below.
            waveform = waveform.unsqueeze(0)
        output["waveform"] = waveform

        t = (dt.datetime.now() - start_t).total_seconds()
        rtf_w = t * 22050 / (output["waveform"].shape[-1])
        print(f"[🍵-Batch: {i}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-Batch: {i}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        total_rtf.append(output["rtf"])
        total_rtf_w.append(rtf_w)

        for j in range(output["mel"].shape[0]):
            if args.spk is not None:
                base_name = f"utterance_{utterance_idx:03d}_speaker_{args.spk:03d}"
            else:
                base_name = f"utterance_{utterance_idx:03d}"
            length = output["mel_lengths"][j]
            # Cut the padding: `length` mel frames, 256 audio samples per frame.
            new_dict = {"mel": output["mel"][j][:, :length], "waveform": output["waveform"][j][: length * 256]}
            location = save_to_folder(base_name, new_dict, args.output_folder)
            print(f"[🍵-{utterance_idx}] Waveform saved: {location}")
            utterance_idx += 1

    print("".join(["="] * 100))
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(total_rtf):.4f} ± {np.std(total_rtf)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(total_rtf_w):.4f} ± {np.std(total_rtf_w)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def unbatched_synthesis(args, device, model, vocoder, denoiser, texts, spk):
    """Synthesise ``texts`` one utterance at a time and save the results.

    Each text is stripped, processed on ``device``, synthesised, vocoded and
    written with ``save_to_folder``. Utterances are numbered from 1.
    Per-utterance and average real-time factors are printed.
    """
    separator = "=" * 100
    matcha_rtfs = []
    full_rtfs = []

    for number, raw_text in enumerate(texts, start=1):
        if args.spk is not None:
            base_name = f"utterance_{number:03d}_speaker_{args.spk:03d}"
        else:
            base_name = f"utterance_{number:03d}"

        print(separator)
        text_processed = process_text(number, raw_text.strip(), device)

        print(f"[🍵] Whisking Matcha-T(ea)TS for: {number}")
        started = dt.datetime.now()
        output = model.synthesise(
            text_processed["x"],
            text_processed["x_lengths"],
            n_timesteps=args.steps,
            temperature=args.temperature,
            spks=spk,
            length_scale=args.speaking_rate,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # Real-time factor including the vocoder: elapsed time divided by the
        # duration of the audio at 22050 samples per second.
        elapsed = (dt.datetime.now() - started).total_seconds()
        rtf_w = elapsed * 22050 / (output["waveform"].shape[-1])
        print(f"[🍵-{number}] Matcha-TTS RTF: {output['rtf']:.4f}")
        print(f"[🍵-{number}] Matcha-TTS + VOCODER RTF: {rtf_w:.4f}")
        matcha_rtfs.append(output["rtf"])
        full_rtfs.append(rtf_w)

        location = save_to_folder(base_name, output, args.output_folder)
        print(f"[+] Waveform saved: {location}")

    print(separator)
    print(f"[🍵] Average Matcha-TTS RTF: {np.mean(matcha_rtfs):.4f} ± {np.std(matcha_rtfs)}")
    print(f"[🍵] Average Matcha-TTS + VOCODER RTF: {np.mean(full_rtfs):.4f} ± {np.std(full_rtfs)}")
    print("[🍵] Enjoy the freshly whisked 🍵 Matcha-TTS!")
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def print_config(args):
    """Print the effective synthesis settings, one per line, to stdout."""
    print("[!] Configurations: ")

    # (label shown to the user, attribute read from args), in display order.
    settings = (
        ("Model", "model"),
        ("Vocoder", "vocoder"),
        ("Temperature", "temperature"),
        ("Speaking rate", "speaking_rate"),
        ("Number of ODE steps", "steps"),
        ("Speaker", "spk"),
    )
    for label, attribute in settings:
        print(f"\t- {label}: {getattr(args, attribute)}")
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def get_device(args):
    """Choose the torch device for inference.

    CUDA is used when it is available and ``args.cpu`` is not set; otherwise
    the CPU. The choice is announced on stdout.
    """
    use_gpu = torch.cuda.is_available() and not args.cpu
    if not use_gpu:
        print("[-] GPU not available or forced CPU run! Using CPU")
        return torch.device("cpu")

    print("[+] GPU Available! Using GPU")
    return torch.device("cuda")
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
if __name__ == "__main__":
|
|
419
|
+
cli()
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Dict, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torchaudio as ta
|
|
8
|
+
from lightning import LightningDataModule
|
|
9
|
+
from torch.utils.data.dataloader import DataLoader
|
|
10
|
+
|
|
11
|
+
from matcha.text import text_to_sequence
|
|
12
|
+
from matcha.utils.audio import mel_spectrogram
|
|
13
|
+
from matcha.utils.model import fix_len_compatibility, normalize
|
|
14
|
+
from matcha.utils.utils import intersperse
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def parse_filelist(filelist_path, split_char="|"):
    """Read a filelist and split every line into its fields.

    The file is read as UTF-8; each line is stripped of surrounding whitespace
    and split on ``split_char``.

    Returns:
        A list with one list of fields per line of the file.
    """
    entries = []
    with open(filelist_path, encoding="utf-8") as filelist:
        for raw_line in filelist:
            fields = raw_line.strip().split(split_char)
            entries.append(fields)
    return entries
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TextMelDataModule(LightningDataModule):
    """Lightning data module providing train/validation text-mel dataloaders."""

    def __init__(  # pylint: disable=unused-argument
        self,
        name,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cleaners,
        add_blank,
        n_spks,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        seed,
        load_durations,
    ):
        super().__init__()

        # Stores every init argument on `self.hparams` (and in checkpoints);
        # the other methods read their configuration from there.
        self.save_hyperparameters(logger=False)

    def _make_dataset(self, filelist_path):
        """Build a ``TextMelDataset`` for ``filelist_path`` from the hparams."""
        hp = self.hparams
        return TextMelDataset(
            filelist_path,
            hp.n_spks,
            hp.cleaners,
            hp.add_blank,
            hp.n_fft,
            hp.n_feats,
            hp.sample_rate,
            hp.hop_length,
            hp.win_length,
            hp.f_min,
            hp.f_max,
            hp.data_statistics,
            hp.seed,
            hp.load_durations,
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from the hparams."""
        hp = self.hparams
        return DataLoader(
            dataset=dataset,
            batch_size=hp.batch_size,
            num_workers=hp.num_workers,
            pin_memory=hp.pin_memory,
            shuffle=shuffle,
            collate_fn=TextMelBatchCollate(hp.n_spks),
        )

    def setup(self, stage: Optional[str] = None):  # pylint: disable=unused-argument
        """Create the datasets and store them as ``self.trainset`` / ``self.validset``.

        ``stage`` is ignored; both datasets are rebuilt on every call.
        """
        # pylint: disable=attribute-defined-outside-init
        self.trainset = self._make_dataset(self.hparams.train_filelist_path)
        self.validset = self._make_dataset(self.hparams.valid_filelist_path)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training set."""
        return self._make_dataloader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation set."""
        return self._make_dataloader(self.validset, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test (nothing to do)."""

    def state_dict(self):
        """Return the extra state to save in a checkpoint (none)."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore extra state from a checkpoint (nothing to do)."""
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class TextMelDataset(torch.utils.data.Dataset):
    """Dataset of (text, mel-spectrogram) pairs described by a filelist.

    Each filelist line holds ``filepath|text`` or, when ``n_spks > 1``,
    ``filepath|speaker_id|text``. The entries are shuffled once at
    construction time.
    """

    def __init__(
        self,
        filelist_path,
        n_spks,
        cleaners,
        add_blank=True,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_parameters=None,
        seed=None,
        load_durations=False,
    ):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.n_spks = n_spks
        self.cleaners = cleaners
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.load_durations = load_durations

        # Without statistics the mel normalisation uses mean 0 and std 1.
        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            self.data_parameters = {"mel_mean": 0, "mel_std": 1}
        # NOTE: this seeds the module-level `random` generator, not a private one.
        random.seed(seed)
        random.shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
        """Load one filelist entry.

        Returns:
            dict with ``x`` (token ids), ``y`` (normalised mel), ``spk``
            (speaker id or ``None``), ``filepath``, ``x_text`` (cleaned text)
            and ``durations`` (tensor or ``None``).
        """
        if self.n_spks > 1:
            filepath, spk, text = (
                filepath_and_text[0],
                int(filepath_and_text[1]),
                filepath_and_text[2],
            )
        else:
            filepath, text = filepath_and_text[0], filepath_and_text[1]
            spk = None

        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)

        durations = self.get_durations(filepath, text) if self.load_durations else None

        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}

    def get_durations(self, filepath, text):
        """Load the precomputed durations belonging to an audio file.

        The durations are read from ``<data_dir>/durations/<stem>.npy`` where
        ``data_dir`` is two levels above the audio file.

        Raises:
            FileNotFoundError: If the durations file does not exist.
            AssertionError: If durations and ``text`` differ in length.
        """
        filepath = Path(filepath)
        data_dir, name = filepath.parent.parent, filepath.stem

        try:
            dur_loc = data_dir / "durations" / f"{name}.npy"
            durs = torch.from_numpy(np.load(dur_loc).astype(int))

        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generate the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
            ) from e

        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

        return durs

    def get_mel(self, filepath):
        """Load an audio file and return its normalised mel-spectrogram.

        Raises:
            AssertionError: If the file's sample rate differs from
                ``self.sample_rate``.
        """
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate, f"Expected sample rate {self.sample_rate} but {filepath} has {sr}"
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.n_mels,
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.f_min,
            self.f_max,
            center=False,
        ).squeeze()
        mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
        return mel

    def get_text(self, text, add_blank=None):
        """Convert text into a tensor of symbol ids.

        Args:
            text: Raw text to clean and encode with ``self.cleaners``.
            add_blank: Whether to intersperse the ids with ``0``. ``None``
                (the default) falls back to the dataset-wide
                ``self.add_blank``. The argument was previously ignored in
                favour of ``self.add_blank``, so calls that omit it behave as
                before while explicit values are now honoured.

        Returns:
            Tuple ``(ids, cleaned_text)`` where ``ids`` is an ``IntTensor``.
        """
        if add_blank is None:
            add_blank = self.add_blank

        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
        if add_blank:
            text_norm = intersperse(text_norm, 0)
        text_norm = torch.IntTensor(text_norm)
        return text_norm, cleaned_text

    def __getitem__(self, index):
        """Return the datapoint for the filelist entry at ``index``."""
        datapoint = self.get_datapoint(self.filepaths_and_text[index])
        return datapoint

    def __len__(self):
        """Return the number of filelist entries."""
        return len(self.filepaths_and_text)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class TextMelBatchCollate:
    """Collate function turning dataset items into zero-padded batch tensors."""

    def __init__(self, n_spks):
        # Speaker ids are only emitted when there is more than one speaker.
        self.n_spks = n_spks

    def __call__(self, batch):
        """Pad and stack a list of datapoints.

        Returns:
            dict with ``x``/``x_lengths`` (token ids and their lengths),
            ``y``/``y_lengths`` (mels and their lengths), ``spks`` (tensor or
            ``None``), ``filepaths``, ``x_texts`` and ``durations`` (``None``
            when every entry is zero).
        """
        batch_size = len(batch)

        mel_lengths = [item["y"].shape[-1] for item in batch]
        text_lengths = [item["x"].shape[-1] for item in batch]

        # The padded mel length is adjusted by fix_len_compatibility.
        y_max_length = fix_len_compatibility(max(mel_lengths))
        x_max_length = max(text_lengths)
        n_feats = batch[0]["y"].shape[-2]

        y = torch.zeros((batch_size, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((batch_size, x_max_length), dtype=torch.long)
        durations = torch.zeros((batch_size, x_max_length), dtype=torch.long)

        for row, item in enumerate(batch):
            mel, tokens = item["y"], item["x"]
            y[row, :, : mel.shape[-1]] = mel
            x[row, : tokens.shape[-1]] = tokens

            item_durations = item["durations"]
            if item_durations is not None:
                durations[row, : item_durations.shape[-1]] = item_durations

        speaker_ids = [item["spk"] for item in batch]
        if self.n_spks > 1:
            spks = torch.tensor(speaker_ids, dtype=torch.long)
        else:
            spks = None

        # An all-zero tensor means no durations were filled in; report None.
        if torch.eq(durations, 0).all():
            durations = None

        return {
            "x": x,
            "x_lengths": torch.tensor(text_lengths, dtype=torch.long),
            "y": y,
            "y_lengths": torch.tensor(mel_lengths, dtype=torch.long),
            "spks": spks,
            "filepaths": [item["filepath"] for item in batch],
            "x_texts": [item["x_text"] for item in batch],
            "durations": durations,
        }
|
|
File without changes
|