tiny-tts 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tiny_tts/__init__.py +102 -0
- tiny_tts/alignment/__init__.py +16 -0
- tiny_tts/alignment/core.py +46 -0
- tiny_tts/infer.py +191 -0
- tiny_tts/infer_onnx.py +199 -0
- tiny_tts/integrations/__init__.py +6 -0
- tiny_tts/integrations/langchain.py +179 -0
- tiny_tts/models/__init__.py +1 -0
- tiny_tts/models/synthesizer.py +718 -0
- tiny_tts/nn/__init__.py +1 -0
- tiny_tts/nn/attentions.py +424 -0
- tiny_tts/nn/commons.py +151 -0
- tiny_tts/nn/modules.py +578 -0
- tiny_tts/nn/transforms.py +209 -0
- tiny_tts/text/__init__.py +19 -0
- tiny_tts/text/advanced_normalization.py +370 -0
- tiny_tts/text/english.py +173 -0
- tiny_tts/text/english_utils/__init__.py +0 -0
- tiny_tts/text/english_utils/abbreviations.py +35 -0
- tiny_tts/text/english_utils/number_norm.py +97 -0
- tiny_tts/text/english_utils/time_norm.py +47 -0
- tiny_tts/text/symbols.py +293 -0
- tiny_tts/utils/__init__.py +5 -0
- tiny_tts/utils/config.py +42 -0
- tiny_tts-0.2.0.dist-info/METADATA +458 -0
- tiny_tts-0.2.0.dist-info/RECORD +30 -0
- tiny_tts-0.2.0.dist-info/WHEEL +5 -0
- tiny_tts-0.2.0.dist-info/entry_points.txt +3 -0
- tiny_tts-0.2.0.dist-info/licenses/LICENSE +152 -0
- tiny_tts-0.2.0.dist-info/top_level.txt +1 -0
tiny_tts/__init__.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import soundfile as sf
|
|
4
|
+
from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
|
|
5
|
+
from tiny_tts.text.advanced_normalization import advanced_normalize
|
|
6
|
+
from tiny_tts.text import phonemes_to_ids
|
|
7
|
+
from tiny_tts.nn import commons
|
|
8
|
+
from tiny_tts.models.synthesizer import VoiceSynthesizer
|
|
9
|
+
from tiny_tts.text.symbols import symbols
|
|
10
|
+
from tiny_tts.utils.config import (
|
|
11
|
+
SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
|
|
12
|
+
N_SPEAKERS, SPK2ID, MODEL_PARAMS,
|
|
13
|
+
)
|
|
14
|
+
from tiny_tts.infer import load_engine
|
|
15
|
+
|
|
16
|
+
class TinyTTS:
    """High-level English text-to-speech front end.

    Resolves a model checkpoint (explicit path, bundled default, or a
    Hugging Face Hub download) and exposes the full
    text -> phonemes -> waveform pipeline through :meth:`speak`.
    """

    def __init__(self, checkpoint_path=None, device=None):
        # Prefer CUDA automatically unless the caller pins a device.
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        if checkpoint_path is None:
            # Look for default checkpoint in package
            pkg_dir = os.path.dirname(os.path.abspath(__file__))
            default_ckpt = os.path.join(os.path.dirname(pkg_dir), "checkpoints", "G.pth")
            # 2. Check HuggingFace Cache / Download
            if not os.path.exists(default_ckpt):
                try:
                    from huggingface_hub import hf_hub_download
                    print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
                    default_ckpt = hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
                except ImportError:
                    raise ImportError("huggingface_hub is required to auto-download the model. Run: pip install huggingface_hub")
                except Exception as e:
                    raise ValueError(f"Failed to download checkpoint from Hugging Face: {e}")

            checkpoint_path = default_ckpt

        # Builds the synthesizer and loads weights (see tiny_tts.infer.load_engine).
        self.model = load_engine(checkpoint_path, self.device)

    def speak(self, text, output_path="output.wav", speaker="MALE", speed=1.0, use_advanced_normalization=True):
        """Synthesize text to speech and save to output_path.

        Args:
            text: Input text
            output_path: Output audio file path
            speaker: Speaker ID (default: "MALE")
            speed: Speech speed (1.0=normal, 1.5=faster, 0.7=slower)
            use_advanced_normalization: Enable advanced text normalization (URLs, emails, etc.)

        Returns:
            The synthesized waveform as a 1-D numpy float array (also
            written to ``output_path`` at ``SAMPLING_RATE``).
        """
        print(f"Synthesizing: {text}")

        # Normalize text
        if use_advanced_normalization:
            normalized = advanced_normalize(text)
        else:
            normalized = normalize_text(text)

        # Phonemize
        phones, tones, word2ph = grapheme_to_phoneme(normalized)

        # Convert to sequence
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        # Add blanks (interleave blank tokens; all three sequences must stay aligned)
        if ADD_BLANK:
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        # Batch of one: [1, T] id tensors plus the sequence length.
        x = torch.LongTensor(phone_ids).unsqueeze(0).to(self.device)
        x_lengths = torch.LongTensor([len(phone_ids)]).to(self.device)
        tone = torch.LongTensor(tone_ids).unsqueeze(0).to(self.device)
        language = torch.LongTensor(lang_ids).unsqueeze(0).to(self.device)

        # Speaker ID (unknown names fall back to id 0 with a warning)
        if speaker not in SPK2ID:
            print(f"Warning: Speaker '{speaker}' not found, using ID 0. Available: {list(SPK2ID.keys())}")
            sid = torch.LongTensor([0]).to(self.device)
        else:
            sid = torch.LongTensor([SPK2ID[speaker]]).to(self.device)

        # BERT features (disabled - using zero tensors)
        bert = torch.zeros(1024, len(phone_ids)).to(self.device).unsqueeze(0)
        ja_bert = torch.zeros(768, len(phone_ids)).to(self.device).unsqueeze(0)

        # speed > 1.0 = faster speech, < 1.0 = slower speech
        length_scale = 1.0 / speed

        with torch.no_grad():
            audio, *_ = self.model.infer(
                x, x_lengths, sid, tone, language, bert, ja_bert,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=length_scale
            )

        # Model output is [B, 1, samples]; take the single batch item/channel.
        audio_np = audio[0, 0].cpu().numpy()
        sf.write(output_path, audio_np, SAMPLING_RATE)
        print(f"Saved audio to {output_path}")
        return audio_np
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from numpy import zeros, int32, float32
|
|
2
|
+
from torch import from_numpy
|
|
3
|
+
|
|
4
|
+
from .core import viterbi_decode_kernel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def viterbi_decode(neg_cent, mask):
    """Run the monotonic-alignment Viterbi kernel on `neg_cent`.

    Converts the score tensor to NumPy, lets the compiled kernel write the
    hard alignment path in place, and returns the path as a tensor on the
    input's original device/dtype.
    """
    out_device, out_dtype = neg_cent.device, neg_cent.dtype

    # The kernel works on contiguous float32/int32 NumPy buffers.
    scores = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(scores.shape, dtype=int32)

    # Per-batch valid extents along each axis, recovered from the joint mask.
    frame_lens = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    phone_lens = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)

    viterbi_decode_kernel(path, scores, frame_lens, phone_lens)
    return from_numpy(path).to(device=out_device, dtype=out_dtype)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import numba
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],    # paths: output alignment, written in place
        numba.float32[:, :, ::1],  # values: scores, reused in place as the DP table
        numba.int32[::1],          # t_ys: valid frame count per batch item
        numba.int32[::1],          # t_xs: valid phone count per batch item
    ),
    nopython=True,
    nogil=True,
)
def viterbi_decode_kernel(paths, values, t_ys, t_xs):
    """Batched monotonic-alignment Viterbi.

    Forward pass accumulates best-path scores directly into `values`
    (mutating it as the DP table), then a backtrack marks the chosen
    monotonic path with ones in `paths`.
    """
    b = paths.shape[0]
    # Effectively -inf for float32 scores.
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        # Backtrack starts at the last phone.
        index = t_x - 1

        # Forward pass: value[y, x] += max(stay on phone x, advance from x-1).
        for y in range(t_y):
            # Restrict x to the band where a monotonic path can still start
            # at phone 0 and finish at phone t_x - 1.
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    # On the leading diagonal there is no "stay" predecessor.
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0  # path origin
                    else:
                        v_prev = max_neg_val  # cannot advance from phone -1
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # Backtrack from the last frame, stepping to the previous phone when
        # forced onto the diagonal or when its predecessor scored strictly better.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
|
tiny_tts/infer.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import re
|
|
4
|
+
import torch
|
|
5
|
+
import soundfile as sf
|
|
6
|
+
import argparse
|
|
7
|
+
from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
|
|
8
|
+
from tiny_tts.text import phonemes_to_ids
|
|
9
|
+
from tiny_tts.nn import commons
|
|
10
|
+
from tiny_tts.models import VoiceSynthesizer
|
|
11
|
+
from tiny_tts.text.symbols import symbols
|
|
12
|
+
from tiny_tts.utils import (
|
|
13
|
+
SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
|
|
14
|
+
N_SPEAKERS, SPK2ID, MODEL_PARAMS,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def load_engine(checkpoint_path, device='cuda'):
    """Build a VoiceSynthesizer, load `checkpoint_path`, and prep it for inference.

    Checkpoint keys are normalized (DataParallel ``module.`` prefix stripped)
    and tensors whose shape disagrees with the freshly built model are skipped,
    so partially compatible checkpoints still load (via ``strict=False``).

    Args:
        checkpoint_path: Path to a ``.pth`` checkpoint containing a ``'model'`` dict.
        device: Torch device string to place the model on.

    Returns:
        The synthesizer in eval mode with weight norm folded for speed.
    """
    print(f"Loading model from {checkpoint_path}")
    net_g = VoiceSynthesizer(
        len(symbols),
        SPEC_CHANNELS,
        SEGMENT_FRAMES,
        n_speakers=N_SPEAKERS,
        **MODEL_PARAMS
    ).to(device)

    # Count model parameters (informational only)
    total_params = sum(p.numel() for p in net_g.parameters())
    trainable_params = sum(p.numel() for p in net_g.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True on torch
    # versions that support it).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model']

    # Remove module. prefix and filter shape mismatches
    model_state = net_g.state_dict()
    new_state_dict = {}
    skipped = []
    for k, v in state_dict.items():
        key = k[7:] if k.startswith('module.') else k
        if key in model_state:
            if v.shape == model_state[key].shape:
                new_state_dict[key] = v
            else:
                skipped.append(f"{key}: ckpt{v.shape} vs model{model_state[key].shape}")
        else:
            # Keys unknown to the model pass through; strict=False ignores them.
            new_state_dict[key] = v

    if skipped:
        print(f"Skipped {len(skipped)} mismatched keys:")
        for s in skipped[:5]:
            print(f"  {s}")
        if len(skipped) > 5:
            print(f"  ... and {len(skipped)-5} more")

    net_g.load_state_dict(new_state_dict, strict=False)
    net_g.eval()

    # Fold weight_norm into weight tensors for faster inference (~18% speedup)
    net_g.dec.remove_weight_norm()

    return net_g
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def synthesize(text, output_path, model, speaker="MALE", device='cuda', speed=1.0):
    """Synthesize `text` with `model` and write a WAV file to `output_path`.

    Args:
        text: Input text (English).
        output_path: Destination audio file path.
        model: A loaded synthesizer exposing ``.infer(...)``.
        speaker: Speaker name looked up in SPK2ID (unknown -> id 0).
        device: Torch device string for the input tensors.
        speed: >1.0 speaks faster, <1.0 slower.
    """
    print(f"Synthesizing: {text}")

    # Text -> normalized text -> phonemes -> integer id sequences.
    phones, tones, _ = grapheme_to_phoneme(normalize_text(text))
    phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

    # Optionally interleave blank tokens; all three sequences stay aligned.
    if ADD_BLANK:
        phone_ids = commons.insert_blanks(phone_ids, 0)
        tone_ids = commons.insert_blanks(tone_ids, 0)
        lang_ids = commons.insert_blanks(lang_ids, 0)

    seq_len = len(phone_ids)

    # Batch-of-one input tensors: [1, T] ids plus the sequence length.
    x = torch.LongTensor(phone_ids).unsqueeze(0).to(device)
    x_lengths = torch.LongTensor([seq_len]).to(device)
    tone = torch.LongTensor(tone_ids).unsqueeze(0).to(device)
    language = torch.LongTensor(lang_ids).unsqueeze(0).to(device)

    # Resolve the speaker id, warning and falling back to 0 for unknown names.
    if speaker in SPK2ID:
        sid = torch.LongTensor([SPK2ID[speaker]]).to(device)
    else:
        print(f"Warning: Speaker {speaker} not found, using ID 0")
        sid = torch.LongTensor([0]).to(device)

    # BERT conditioning is disabled; feed all-zero feature tensors.
    bert = torch.zeros(1024, seq_len).to(device).unsqueeze(0)
    ja_bert = torch.zeros(768, seq_len).to(device).unsqueeze(0)

    # Duration scale is the inverse of speed: larger speed -> shorter frames.
    length_scale = 1.0 / speed

    with torch.no_grad():
        audio, *_ = model.infer(
            x, x_lengths, sid, tone, language, bert, ja_bert,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=length_scale
        )

    # Output is [B, 1, samples]; keep the single batch item/channel.
    waveform = audio[0, 0].cpu().numpy()
    sf.write(output_path, waveform, SAMPLING_RATE)
    print(f"Saved audio to {output_path}")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-step ``G_*.pth`` file in *checkpoint_dir*.

    Returns ``None`` when the directory contains no matching checkpoints.
    """
    step_pattern = re.compile(r'_(\d+)\.pth')

    def step_of(name):
        # Files without a parseable step number sort below every numbered one.
        m = step_pattern.search(name)
        return int(m.group(1)) if m else -1

    candidates = [
        name for name in os.listdir(checkpoint_dir)
        if name.startswith('G_') and name.endswith('.pth')
    ]
    if candidates:
        return os.path.join(checkpoint_dir, max(candidates, key=step_of))
    return None
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def main():
    """CLI entry point: parse arguments, resolve the checkpoint (Hub
    auto-download, directory auto-detect, or explicit file), then synthesize
    into ``infer_outputs/``. ``--speaker all`` renders every known speaker."""
    parser = argparse.ArgumentParser(description="TinyTTS — English Text-to-Speech Inference")
    parser.add_argument("--text", "-t", type=str, default="The weather is nice today, and I feel very relaxed.", help="Text to synthesize")
    parser.add_argument("--checkpoint", "-c", type=str, default=None, help="Path to checkpoint. Auto-downloads if not provided.")
    parser.add_argument("--output", "-o", type=str, default="output.wav", help="Output audio file path")
    parser.add_argument("--speaker", "-s", type=str, default="MALE", help="Speaker ID")
    parser.add_argument("--speed", type=float, default=1.0, help="Speech speed (1.0=normal, 1.5=faster, 0.7=slower)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")

    args = parser.parse_args()

    # No checkpoint given: pull the released weights from the Hub.
    if args.checkpoint is None:
        try:
            from huggingface_hub import hf_hub_download
            print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
            args.checkpoint = hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
        except ImportError:
            print("Error: huggingface_hub is required for auto-download. Run: pip install huggingface_hub")
            sys.exit(1)
        except Exception as e:
            print(f"Error downloading checkpoint: {e}")
            sys.exit(1)

    if not os.path.exists(args.checkpoint):
        print(f"Error: Checkpoint or directory not found at {args.checkpoint}")
        sys.exit(1)

    # A directory argument means "use the newest G_*.pth inside it".
    if os.path.isdir(args.checkpoint):
        latest_ckpt = get_latest_checkpoint(args.checkpoint)
        if not latest_ckpt:
            print(f"Error: No G_*.pth checkpoints found in directory {args.checkpoint}")
            sys.exit(1)
        args.checkpoint = latest_ckpt
        print(f"Auto-detected latest checkpoint: {args.checkpoint}")

    # Extract step from checkpoint filename (used only to tag output names)
    ckpt_basename = os.path.basename(args.checkpoint)
    match = re.search(r'_(\d+)\.pth', ckpt_basename)
    step_str = match.group(1) if match else "unknown"

    # Save to output folder
    out_dir = "infer_outputs"
    os.makedirs(out_dir, exist_ok=True)

    out_name = os.path.basename(args.output)
    name, ext = os.path.splitext(out_name)
    model = load_engine(args.checkpoint, args.device)

    if args.speaker.lower() == "all":
        if not SPK2ID:
            print("Error: No speakers found")
            sys.exit(1)
        print(f"Synthesizing for all {len(SPK2ID)} speakers...")
        for spk in SPK2ID.keys():
            final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{spk}{ext}")
            synthesize(args.text, final_output, model, speaker=spk, device=args.device, speed=args.speed)
    else:
        final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{args.speaker}{ext}")
        synthesize(args.text, final_output, model, speaker=args.speaker, device=args.device, speed=args.speed)

if __name__ == "__main__":
    main()
|
tiny_tts/infer_onnx.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ONNX Runtime inference engine for TinyTTS.
|
|
3
|
+
|
|
4
|
+
Replaces the PyTorch VoiceSynthesizer.infer() with equivalent
|
|
5
|
+
ONNX Runtime sessions + NumPy ops for the non-exported parts
|
|
6
|
+
(alignment path computation).
|
|
7
|
+
"""
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
import soundfile as sf
|
|
11
|
+
|
|
12
|
+
from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
|
|
13
|
+
from tiny_tts.text import phonemes_to_ids
|
|
14
|
+
from tiny_tts.nn import commons
|
|
15
|
+
from tiny_tts.utils.config import (
|
|
16
|
+
SAMPLING_RATE, ADD_BLANK, SPK2ID,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import onnxruntime as ort
|
|
21
|
+
except ImportError:
|
|
22
|
+
raise ImportError("onnxruntime is required. Run: pip install onnxruntime")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _build_session(path: str, use_gpu: bool = False):
    """Build an onnxruntime InferenceSession for the model at *path*.

    When *use_gpu* is true, CUDA is preferred with a CPU fallback.
    """
    providers = ["CPUExecutionProvider"]
    if use_gpu:
        providers.insert(0, "CUDAExecutionProvider")

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Use every available core for intra-op parallelism (fallback: 4 threads).
    session_options.intra_op_num_threads = os.cpu_count() or 4

    return ort.InferenceSession(path, sess_options=session_options, providers=providers)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _create_length_mask_np(lengths, max_len=None):
|
|
39
|
+
"""NumPy equivalent of commons.create_length_mask."""
|
|
40
|
+
if max_len is None:
|
|
41
|
+
max_len = int(lengths.max())
|
|
42
|
+
ids = np.arange(max_len, dtype=np.float32) # [T]
|
|
43
|
+
mask = (ids[None, :] < lengths[:, None]).astype(np.float32) # [B, T]
|
|
44
|
+
return mask
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _compute_alignment_path_np(w_ceil, attn_mask):
|
|
48
|
+
"""
|
|
49
|
+
Monotonic alignment path - vectorized via cumsum (much faster than Python loops).
|
|
50
|
+
w_ceil: [B, 1, T_x] — integer duration per phone
|
|
51
|
+
attn_mask: [B, 1, T_y, T_x] — joint mask
|
|
52
|
+
Returns attn: [B, 1, T_y, T_x]
|
|
53
|
+
"""
|
|
54
|
+
B, _, T_x = w_ceil.shape
|
|
55
|
+
T_y = attn_mask.shape[2]
|
|
56
|
+
|
|
57
|
+
# Build duration matrix: for each phone column expand the duration
|
|
58
|
+
# cumulative sum of durations gives us the end frame index for each phone
|
|
59
|
+
dur = w_ceil[:, 0, :] # [B, T_x]
|
|
60
|
+
cum_dur = np.cumsum(dur, axis=1) # [B, T_x] — end frame (1-indexed)
|
|
61
|
+
cum_dur_prev = np.pad(cum_dur[:, :-1], ((0,0),(1,0))) # [B, T_x] — start frame
|
|
62
|
+
|
|
63
|
+
# Frame indices: [1, T_y, 1]
|
|
64
|
+
frame_idx = np.arange(T_y, dtype=np.float32)[None, :, None] # [1, T_y, 1]
|
|
65
|
+
# For each phone, mark frames [start, end)
|
|
66
|
+
# cum_dur_prev: [B,1,T_x], cum_dur: [B,1,T_x]
|
|
67
|
+
start = cum_dur_prev[:, None, :] # [B, 1, T_x]
|
|
68
|
+
end = cum_dur[:, None, :] # [B, 1, T_x]
|
|
69
|
+
attn = ((frame_idx >= start) & (frame_idx < end)).astype(np.float32) # [B, T_y, T_x]
|
|
70
|
+
attn = attn[:, None, :, :] # [B, 1, T_y, T_x]
|
|
71
|
+
return attn * attn_mask
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class OnnxTinyTTS:
    """
    Inference using ONNX Runtime.

    Runs the four exported sub-graphs (text encoder, duration predictor,
    flow, decoder) and stitches them together with NumPy for the alignment
    step that was not exported.

    Args:
        onnx_dir: directory containing the 4 .onnx files
        use_gpu: if True, try CUDAExecutionProvider
    """

    def __init__(self, onnx_dir: str = "onnx", use_gpu: bool = False):
        onnx_dir = os.path.abspath(onnx_dir)
        print(f"Loading ONNX sessions from: {onnx_dir}")

        # One InferenceSession per exported sub-module.
        self._enc = _build_session(os.path.join(onnx_dir, "text_encoder.onnx"), use_gpu)
        self._dp = _build_session(os.path.join(onnx_dir, "duration_predictor.onnx"), use_gpu)
        self._flow = _build_session(os.path.join(onnx_dir, "flow.onnx"), use_gpu)
        self._dec = _build_session(os.path.join(onnx_dir, "decoder.onnx"), use_gpu)

        print("ONNX sessions ready ✅")

    def _text_to_ids(self, text: str):
        # Text -> phoneme/tone/language id sequences (English front end),
        # with optional blank interleaving to match training.
        normalized = normalize_text(text)
        phones, tones, _ = grapheme_to_phoneme(normalized)
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        if ADD_BLANK:
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        return phone_ids, tone_ids, lang_ids

    def speak(
        self,
        text: str,
        output_path: str = "onnx_output.wav",
        speaker: str = "female",
        noise_scale: float = 0.667,
        noise_scale_w: float = 0.8,
        length_scale: float = 1.0,
        output_sr: int = None,
    ) -> np.ndarray:
        """Synthesize speech and save to output_path.

        Args:
            output_sr: If set (e.g. 22050), resample the output from 44100 Hz.
                Useful to reduce file size while keeping quality.

        Returns:
            The synthesized waveform as a 1-D numpy array.
        """
        print(f"[ONNX] Synthesizing: {text}")

        phone_ids, tone_ids, lang_ids = self._text_to_ids(text)
        T = len(phone_ids)

        # Prepare inputs as float32 / int64 arrays
        x = np.array(phone_ids, dtype=np.int64)[None, :]     # [1, T]
        x_len = np.array([T], dtype=np.int64)                # [1]
        tone = np.array(tone_ids, dtype=np.int64)[None, :]   # [1, T]
        lang = np.array(lang_ids, dtype=np.int64)[None, :]   # [1, T]
        bert = np.zeros((1, 1024, T), dtype=np.float32)
        ja_bert = np.zeros((1, 768, T), dtype=np.float32)
        # NOTE(review): default speaker is "female" here but "MALE" in the
        # PyTorch paths; unknown names silently fall back to id 0 — confirm
        # SPK2ID key casing.
        sid_val = SPK2ID.get(speaker, 0)
        sid = np.array([sid_val], dtype=np.int64)            # [1]

        # ── 1. Text Encoder ──────────────────────────────────────────────
        x_enc, m_p, logs_p, x_mask, g = self._enc.run(
            None,
            {
                "phone_ids": x,
                "phone_lengths": x_len,
                "tone_ids": tone,
                "language_ids": lang,
                "bert": bert,
                "ja_bert": ja_bert,
                "speaker_id": sid,
            },
        )

        # ── 2. Duration Predictor ─────────────────────────────────────────
        logw = self._dp.run(None, {"x": x_enc, "x_mask": x_mask, "g": g})[0]

        # ── 3. Alignment Path (NumPy) ─────────────────────────────────────
        w = np.exp(logw) * x_mask * length_scale  # [1, 1, T]
        w_ceil = np.ceil(w)                       # [1, 1, T]
        # Total output frames; clamp to at least one frame.
        y_len = max(1, int(w_ceil.sum()))
        y_lens = np.array([y_len], dtype=np.int64)

        y_mask = _create_length_mask_np(y_lens, y_len)  # [1, T_y]
        y_mask = y_mask[:, None, :]                     # [1, 1, T_y]
        # attn_mask: [1, 1, T_y, T_x] (outer product of frame mask and phone mask)
        attn_mask = y_mask[:, :, :, None] * x_mask[:, :, None, :]  # [1,1,T_y,T_x]
        attn = _compute_alignment_path_np(w_ceil, attn_mask)       # [1, 1, T_y, T_x]

        # Expand prior stats via alignment
        m_p_exp = np.matmul(attn[:, 0], m_p.transpose(0, 2, 1)).transpose(0, 2, 1)
        logs_p_exp = np.matmul(attn[:, 0], logs_p.transpose(0, 2, 1)).transpose(0, 2, 1)

        # ── 4. Sample z_p ─────────────────────────────────────────────────
        z_p = m_p_exp + np.random.randn(*m_p_exp.shape).astype(np.float32) * \
            np.exp(logs_p_exp) * noise_scale

        # ── 5. Flow (reverse) ─────────────────────────────────────────────
        z = self._flow.run(
            None,
            {"z_p": z_p, "y_mask": y_mask.astype(np.float32), "g": g},
        )[0]

        # ── 6. Decoder ────────────────────────────────────────────────────
        z_masked = (z * y_mask).astype(np.float32)
        audio = self._dec.run(None, {"z": z_masked, "g": g})[0]  # [1, 1, samples]

        audio_np = audio[0, 0]
        save_sr = SAMPLING_RATE
        if output_sr is not None and output_sr != SAMPLING_RATE:
            # Best-effort resample via torchaudio; on failure keep the native rate.
            try:
                import torchaudio
                import torch
                wav_t = torch.from_numpy(audio_np).unsqueeze(0)
                resampler = torchaudio.transforms.Resample(SAMPLING_RATE, output_sr)
                audio_np = resampler(wav_t).squeeze(0).numpy()
                save_sr = output_sr
            except Exception as e:
                print(f"[ONNX] Resampling failed ({e}), saving at {SAMPLING_RATE}Hz")

        sf.write(output_path, audio_np, save_sr)
        print(f"[ONNX] Saved: {output_path} ({save_sr}Hz)")
        return audio_np
|