tiny-tts 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tiny_tts/__init__.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import torch
3
+ import soundfile as sf
4
+ from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
5
+ from tiny_tts.text.advanced_normalization import advanced_normalize
6
+ from tiny_tts.text import phonemes_to_ids
7
+ from tiny_tts.nn import commons
8
+ from tiny_tts.models.synthesizer import VoiceSynthesizer
9
+ from tiny_tts.text.symbols import symbols
10
+ from tiny_tts.utils.config import (
11
+ SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
12
+ N_SPEAKERS, SPK2ID, MODEL_PARAMS,
13
+ )
14
+ from tiny_tts.infer import load_engine
15
+
16
class TinyTTS:
    """High-level PyTorch inference API: load a checkpoint once, then call
    :meth:`speak` to synthesize English text to a wav file.

    Checkpoint resolution order (when ``checkpoint_path`` is None):
      1. ``<package parent>/checkpoints/G.pth`` on disk,
      2. auto-download of ``G.pth`` from the Hugging Face Hub repo
         ``backtracking/tiny-tts`` (requires ``huggingface_hub``).
    """

    def __init__(self, checkpoint_path=None, device=None):
        # Prefer CUDA when available unless the caller pins a device string.
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        if checkpoint_path is None:
            # 1. Look for the default checkpoint shipped next to the package.
            pkg_dir = os.path.dirname(os.path.abspath(__file__))
            default_ckpt = os.path.join(os.path.dirname(pkg_dir), "checkpoints", "G.pth")
            # 2. Fall back to the Hugging Face Hub cache / download.
            if not os.path.exists(default_ckpt):
                try:
                    from huggingface_hub import hf_hub_download
                    print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
                    default_ckpt = hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
                except ImportError:
                    raise ImportError("huggingface_hub is required to auto-download the model. Run: pip install huggingface_hub")
                except Exception as e:
                    raise ValueError(f"Failed to download checkpoint from Hugging Face: {e}")

            checkpoint_path = default_ckpt

        # load_engine builds the synthesizer, loads weights, and sets eval mode.
        self.model = load_engine(checkpoint_path, self.device)

    def speak(self, text, output_path="output.wav", speaker="MALE", speed=1.0, use_advanced_normalization=True):
        """Synthesize text to speech, save a wav to output_path, and return it.

        Args:
            text: Input text
            output_path: Output audio file path
            speaker: Speaker ID (default: "MALE"); unknown names fall back to ID 0
            speed: Speech speed (1.0=normal, 1.5=faster, 0.7=slower)
            use_advanced_normalization: Enable advanced text normalization (URLs, emails, etc.)

        Returns:
            The synthesized waveform as a 1-D numpy array at SAMPLING_RATE.
        """
        print(f"Synthesizing: {text}")

        # Normalize text (advanced path handles URLs/emails per its module)
        if use_advanced_normalization:
            normalized = advanced_normalize(text)
        else:
            normalized = normalize_text(text)

        # Phonemize: graphemes -> phones/tones plus word-to-phone alignment
        phones, tones, word2ph = grapheme_to_phoneme(normalized)

        # Convert symbols to integer id sequences for the model
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        # Optionally interleave blank tokens (id 0) between symbols
        if ADD_BLANK:
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        # Batch dimension of 1 for single-utterance inference
        x = torch.LongTensor(phone_ids).unsqueeze(0).to(self.device)
        x_lengths = torch.LongTensor([len(phone_ids)]).to(self.device)
        tone = torch.LongTensor(tone_ids).unsqueeze(0).to(self.device)
        language = torch.LongTensor(lang_ids).unsqueeze(0).to(self.device)

        # Speaker ID lookup with a warning-and-fallback for unknown names
        if speaker not in SPK2ID:
            print(f"Warning: Speaker '{speaker}' not found, using ID 0. Available: {list(SPK2ID.keys())}")
            sid = torch.LongTensor([0]).to(self.device)
        else:
            sid = torch.LongTensor([SPK2ID[speaker]]).to(self.device)

        # BERT features (disabled - using zero tensors of the expected widths)
        bert = torch.zeros(1024, len(phone_ids)).to(self.device).unsqueeze(0)
        ja_bert = torch.zeros(768, len(phone_ids)).to(self.device).unsqueeze(0)

        # The model consumes a length scale: speed > 1.0 = faster speech, < 1.0 = slower
        length_scale = 1.0 / speed

        with torch.no_grad():
            audio, *_ = self.model.infer(
                x, x_lengths, sid, tone, language, bert, ja_bert,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=length_scale
            )

        # infer returns batched audio; take the single utterance and write it out
        audio_np = audio[0, 0].cpu().numpy()
        sf.write(output_path, audio_np, SAMPLING_RATE)
        print(f"Saved audio to {output_path}")
        return audio_np
@@ -0,0 +1,16 @@
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import viterbi_decode_kernel
5
+
6
+
7
def viterbi_decode(neg_cent, mask):
    """Compute the best monotonic alignment path for a batch of score matrices.

    Moves the (masked) scores to CPU numpy, runs the compiled Viterbi kernel
    in place, and returns the resulting 0/1 path tensor on the input's
    original device and dtype.
    """
    out_device, out_dtype = neg_cent.device, neg_cent.dtype

    scores = neg_cent.data.cpu().numpy().astype(float32)
    best_path = zeros(scores.shape, dtype=int32)

    # Per-item valid extents along each axis, recovered from the joint mask.
    rows_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    cols_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)

    # Kernel fills best_path in place (and mutates scores as scratch space).
    viterbi_decode_kernel(best_path, scores, rows_max, cols_max)

    return from_numpy(best_path).to(device=out_device, dtype=out_dtype)
@@ -0,0 +1,46 @@
1
+ import numba
2
+
3
+
4
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],   # paths: output 0/1 alignment, written in place
        numba.float32[:, :, ::1], # values: scores, mutated in place as the DP table
        numba.int32[::1],         # t_ys: valid rows per batch item
        numba.int32[::1],         # t_xs: valid cols per batch item
    ),
    nopython=True,
    nogil=True,
)
def viterbi_decode_kernel(paths, values, t_ys, t_xs):
    """Batched monotonic-alignment Viterbi.

    For each batch item, a forward pass accumulates best-predecessor scores
    directly into ``values`` (so the input scores are clobbered), then a
    backward pass traces the best path and marks it with 1s in ``paths``.
    Entries outside (t_y, t_x) are never touched.
    """
    b = paths.shape[0]
    # Effectively -inf for the float32 scores: blocks invalid transitions.
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        # Backtracking starts from the last valid column.
        index = t_x - 1

        # Forward pass: value[y, x] += best of (stay on x, advance from x-1).
        for y in range(t_y):
            # Only x in [t_x + y - t_y, y] can still reach (t_y-1, t_x-1)
            # monotonically; everything else is unreachable.
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    # Can't have stayed on x: row y-1 had fewer columns than x.
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        # Origin cell has no predecessor cost.
                        v_prev = 0.0
                    else:
                        # Column 0 can only be reached by staying on column 0.
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: walk rows top-down, stepping the column pointer left
        # whenever the diagonal predecessor scored at least as well.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
tiny_tts/infer.py ADDED
@@ -0,0 +1,191 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ import torch
5
+ import soundfile as sf
6
+ import argparse
7
+ from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
8
+ from tiny_tts.text import phonemes_to_ids
9
+ from tiny_tts.nn import commons
10
+ from tiny_tts.models import VoiceSynthesizer
11
+ from tiny_tts.text.symbols import symbols
12
+ from tiny_tts.utils import (
13
+ SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
14
+ N_SPEAKERS, SPK2ID, MODEL_PARAMS,
15
+ )
16
+
17
+
18
def load_engine(checkpoint_path, device='cuda'):
    """Build a VoiceSynthesizer, load weights from a checkpoint, and prepare
    it for inference on *device*.

    The checkpoint is expected to be a dict with a 'model' state-dict entry.
    Keys are normalized (a leading 'module.' DataParallel prefix is stripped)
    and any tensor whose shape disagrees with the freshly built model is
    skipped with a warning; loading uses strict=False so missing keys are
    tolerated.

    Returns:
        The model in eval mode with decoder weight_norm folded away.
    """
    print(f"Loading model from {checkpoint_path}")
    net_g = VoiceSynthesizer(
        len(symbols),
        SPEC_CHANNELS,
        SEGMENT_FRAMES,
        n_speakers=N_SPEAKERS,
        **MODEL_PARAMS
    ).to(device)

    # Count model parameters (informational only)
    total_params = sum(p.numel() for p in net_g.parameters())
    trainable_params = sum(p.numel() for p in net_g.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (weights_only is not set here).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model']

    # Remove module. prefix and filter shape mismatches
    model_state = net_g.state_dict()
    new_state_dict = {}
    skipped = []
    for k, v in state_dict.items():
        key = k[7:] if k.startswith('module.') else k
        if key in model_state:
            if v.shape == model_state[key].shape:
                new_state_dict[key] = v
            else:
                skipped.append(f"{key}: ckpt{v.shape} vs model{model_state[key].shape}")
        else:
            # Unknown keys are kept; strict=False makes them harmless below.
            new_state_dict[key] = v

    if skipped:
        print(f"Skipped {len(skipped)} mismatched keys:")
        for s in skipped[:5]:
            print(f"  {s}")
        if len(skipped) > 5:
            print(f"  ... and {len(skipped)-5} more")

    net_g.load_state_dict(new_state_dict, strict=False)
    net_g.eval()

    # Fold weight_norm into weight tensors for faster inference (~18% speedup)
    net_g.dec.remove_weight_norm()

    return net_g
64
+
65
+
66
def synthesize(text, output_path, model, speaker="MALE", device='cuda', speed=1.0):
    """Run the full text-to-wav pipeline for a single utterance.

    Normalizes *text*, phonemizes it, runs *model*.infer, and writes the
    resulting waveform to *output_path* at SAMPLING_RATE.
    """
    print(f"Synthesizing: {text}")

    # Frontend: normalize -> phonemize -> integer id sequences.
    phones, tones, _word2ph = grapheme_to_phoneme(normalize_text(text))
    phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

    # Optionally interleave blank tokens (id 0) between symbols.
    if ADD_BLANK:
        phone_ids = commons.insert_blanks(phone_ids, 0)
        tone_ids = commons.insert_blanks(tone_ids, 0)
        lang_ids = commons.insert_blanks(lang_ids, 0)

    n_tokens = len(phone_ids)

    def _as_batch(seq):
        # [T] list -> [1, T] LongTensor on the target device.
        return torch.LongTensor(seq).unsqueeze(0).to(device)

    x = _as_batch(phone_ids)
    x_lengths = torch.LongTensor([n_tokens]).to(device)
    tone = _as_batch(tone_ids)
    language = _as_batch(lang_ids)

    # Resolve speaker name to id, falling back to 0 with a warning.
    if speaker in SPK2ID:
        sid = torch.LongTensor([SPK2ID[speaker]]).to(device)
    else:
        print(f"Warning: Speaker {speaker} not found, using ID 0")
        sid = torch.LongTensor([0]).to(device)

    # BERT conditioning is disabled: feed zeros of the expected widths.
    bert = torch.zeros(1, 1024, n_tokens).to(device)
    ja_bert = torch.zeros(1, 768, n_tokens).to(device)

    with torch.no_grad():
        audio, *_ = model.infer(
            x, x_lengths, sid, tone, language, bert, ja_bert,
            noise_scale=0.667,
            noise_scale_w=0.8,
            # speed > 1.0 = faster speech, < 1.0 = slower speech
            length_scale=1.0 / speed
        )

    waveform = audio[0, 0].cpu().numpy()
    sf.write(output_path, waveform, SAMPLING_RATE)
    print(f"Saved audio to {output_path}")
114
+
115
+
116
def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-step G_*.pth checkpoint in *checkpoint_dir*.

    Files are ranked by the integer in their ``_<step>.pth`` suffix; returns
    None when the directory contains no matching checkpoint files.
    """
    candidates = [
        name for name in os.listdir(checkpoint_dir)
        if name.startswith('G_') and name.endswith('.pth')
    ]
    if not candidates:
        return None

    def _step_of(name):
        # Names without a numeric suffix rank below every real step.
        found = re.search(r'_(\d+)\.pth', name)
        return -1 if found is None else int(found.group(1))

    return os.path.join(checkpoint_dir, max(candidates, key=_step_of))
128
+
129
+
130
def main():
    """CLI entry point: parse args, resolve a checkpoint, and synthesize.

    Checkpoint resolution: an explicit path is used as-is; a directory is
    scanned for the newest G_*.pth; no argument triggers an auto-download
    from the Hugging Face Hub. Output wavs land in ./infer_outputs with the
    step and speaker baked into the filename. Exits with status 1 on any
    resolution failure.
    """
    parser = argparse.ArgumentParser(description="TinyTTS — English Text-to-Speech Inference")
    parser.add_argument("--text", "-t", type=str, default="The weather is nice today, and I feel very relaxed.", help="Text to synthesize")
    parser.add_argument("--checkpoint", "-c", type=str, default=None, help="Path to checkpoint. Auto-downloads if not provided.")
    parser.add_argument("--output", "-o", type=str, default="output.wav", help="Output audio file path")
    parser.add_argument("--speaker", "-s", type=str, default="MALE", help="Speaker ID")
    parser.add_argument("--speed", type=float, default=1.0, help="Speech speed (1.0=normal, 1.5=faster, 0.7=slower)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")

    args = parser.parse_args()

    # No checkpoint given: pull G.pth from the Hub (uses the local HF cache).
    if args.checkpoint is None:
        try:
            from huggingface_hub import hf_hub_download
            print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
            args.checkpoint = hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
        except ImportError:
            print("Error: huggingface_hub is required for auto-download. Run: pip install huggingface_hub")
            sys.exit(1)
        except Exception as e:
            print(f"Error downloading checkpoint: {e}")
            sys.exit(1)

    if not os.path.exists(args.checkpoint):
        print(f"Error: Checkpoint or directory not found at {args.checkpoint}")
        sys.exit(1)

    # A directory argument means "use the newest G_*.pth inside it".
    if os.path.isdir(args.checkpoint):
        latest_ckpt = get_latest_checkpoint(args.checkpoint)
        if not latest_ckpt:
            print(f"Error: No G_*.pth checkpoints found in directory {args.checkpoint}")
            sys.exit(1)
        args.checkpoint = latest_ckpt
        print(f"Auto-detected latest checkpoint: {args.checkpoint}")

    # Extract the training step from the checkpoint filename for tagging outputs.
    ckpt_basename = os.path.basename(args.checkpoint)
    match = re.search(r'_(\d+)\.pth', ckpt_basename)
    step_str = match.group(1) if match else "unknown"

    # All outputs are collected under a fixed folder next to the CWD.
    out_dir = "infer_outputs"
    os.makedirs(out_dir, exist_ok=True)

    out_name = os.path.basename(args.output)
    name, ext = os.path.splitext(out_name)
    model = load_engine(args.checkpoint, args.device)

    # "--speaker all" fans out over every known speaker id.
    if args.speaker.lower() == "all":
        if not SPK2ID:
            print("Error: No speakers found")
            sys.exit(1)
        print(f"Synthesizing for all {len(SPK2ID)} speakers...")
        for spk in SPK2ID.keys():
            final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{spk}{ext}")
            synthesize(args.text, final_output, model, speaker=spk, device=args.device, speed=args.speed)
    else:
        final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{args.speaker}{ext}")
        synthesize(args.text, final_output, model, speaker=args.speaker, device=args.device, speed=args.speed)

if __name__ == "__main__":
    main()
tiny_tts/infer_onnx.py ADDED
@@ -0,0 +1,199 @@
1
+ """
2
+ ONNX Runtime inference engine for TinyTTS.
3
+
4
+ Replaces the PyTorch VoiceSynthesizer.infer() with equivalent
5
+ ONNX Runtime sessions + NumPy ops for the non-exported parts
6
+ (alignment path computation).
7
+ """
8
+ import os
9
+ import numpy as np
10
+ import soundfile as sf
11
+
12
+ from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
13
+ from tiny_tts.text import phonemes_to_ids
14
+ from tiny_tts.nn import commons
15
+ from tiny_tts.utils.config import (
16
+ SAMPLING_RATE, ADD_BLANK, SPK2ID,
17
+ )
18
+
19
+ try:
20
+ import onnxruntime as ort
21
+ except ImportError:
22
+ raise ImportError("onnxruntime is required. Run: pip install onnxruntime")
23
+
24
+
25
def _build_session(path: str, use_gpu: bool = False):
    """Build an ONNX Runtime InferenceSession for the model file at *path*.

    When *use_gpu* is true, CUDA is preferred with a CPU fallback provider.
    """
    if use_gpu:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Use all available cores; fall back to 4 when the count is unknown.
    session_options.intra_op_num_threads = os.cpu_count() or 4

    return ort.InferenceSession(path, sess_options=session_options, providers=providers)
36
+
37
+
38
+ def _create_length_mask_np(lengths, max_len=None):
39
+ """NumPy equivalent of commons.create_length_mask."""
40
+ if max_len is None:
41
+ max_len = int(lengths.max())
42
+ ids = np.arange(max_len, dtype=np.float32) # [T]
43
+ mask = (ids[None, :] < lengths[:, None]).astype(np.float32) # [B, T]
44
+ return mask
45
+
46
+
47
+ def _compute_alignment_path_np(w_ceil, attn_mask):
48
+ """
49
+ Monotonic alignment path - vectorized via cumsum (much faster than Python loops).
50
+ w_ceil: [B, 1, T_x] — integer duration per phone
51
+ attn_mask: [B, 1, T_y, T_x] — joint mask
52
+ Returns attn: [B, 1, T_y, T_x]
53
+ """
54
+ B, _, T_x = w_ceil.shape
55
+ T_y = attn_mask.shape[2]
56
+
57
+ # Build duration matrix: for each phone column expand the duration
58
+ # cumulative sum of durations gives us the end frame index for each phone
59
+ dur = w_ceil[:, 0, :] # [B, T_x]
60
+ cum_dur = np.cumsum(dur, axis=1) # [B, T_x] — end frame (1-indexed)
61
+ cum_dur_prev = np.pad(cum_dur[:, :-1], ((0,0),(1,0))) # [B, T_x] — start frame
62
+
63
+ # Frame indices: [1, T_y, 1]
64
+ frame_idx = np.arange(T_y, dtype=np.float32)[None, :, None] # [1, T_y, 1]
65
+ # For each phone, mark frames [start, end)
66
+ # cum_dur_prev: [B,1,T_x], cum_dur: [B,1,T_x]
67
+ start = cum_dur_prev[:, None, :] # [B, 1, T_x]
68
+ end = cum_dur[:, None, :] # [B, 1, T_x]
69
+ attn = ((frame_idx >= start) & (frame_idx < end)).astype(np.float32) # [B, T_y, T_x]
70
+ attn = attn[:, None, :, :] # [B, 1, T_y, T_x]
71
+ return attn * attn_mask
72
+
73
+
74
class OnnxTinyTTS:
    """
    Inference using ONNX Runtime.

    Runs four exported sub-graphs (text encoder, duration predictor, flow,
    decoder) and stitches them together with NumPy for the alignment step
    that is not part of the exported graphs.

    Args:
        onnx_dir: directory containing the 4 .onnx files
        use_gpu: if True, try CUDAExecutionProvider
    """

    def __init__(self, onnx_dir: str = "onnx", use_gpu: bool = False):
        onnx_dir = os.path.abspath(onnx_dir)
        print(f"Loading ONNX sessions from: {onnx_dir}")

        # One ORT session per exported sub-graph; filenames are fixed.
        self._enc = _build_session(os.path.join(onnx_dir, "text_encoder.onnx"), use_gpu)
        self._dp = _build_session(os.path.join(onnx_dir, "duration_predictor.onnx"), use_gpu)
        self._flow = _build_session(os.path.join(onnx_dir, "flow.onnx"), use_gpu)
        self._dec = _build_session(os.path.join(onnx_dir, "decoder.onnx"), use_gpu)

        print("ONNX sessions ready ✅")

    def _text_to_ids(self, text: str):
        # Same text frontend as the PyTorch path: normalize, phonemize,
        # map to integer ids, and optionally interleave blank tokens.
        normalized = normalize_text(text)
        phones, tones, _ = grapheme_to_phoneme(normalized)
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        if ADD_BLANK:
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        return phone_ids, tone_ids, lang_ids

    def speak(
        self,
        text: str,
        output_path: str = "onnx_output.wav",
        speaker: str = "female",
        noise_scale: float = 0.667,
        noise_scale_w: float = 0.8,
        length_scale: float = 1.0,
        output_sr: int = None,
    ) -> np.ndarray:
        """Synthesize speech and save to output_path.

        NOTE(review): the default speaker here is "female" while the PyTorch
        paths default to "MALE"; unknown names fall back to speaker id 0 —
        confirm intended default against SPK2ID.

        Args:
            output_sr: If set (e.g. 22050), resample the output from
                SAMPLING_RATE. Useful to reduce file size while keeping quality.

        Returns:
            The synthesized waveform as a 1-D numpy array.
        """
        print(f"[ONNX] Synthesizing: {text}")

        phone_ids, tone_ids, lang_ids = self._text_to_ids(text)
        T = len(phone_ids)

        # Prepare inputs as float32 / int64 arrays
        x = np.array(phone_ids, dtype=np.int64)[None, :]   # [1, T]
        x_len = np.array([T], dtype=np.int64)              # [1]
        tone = np.array(tone_ids, dtype=np.int64)[None, :] # [1, T]
        lang = np.array(lang_ids, dtype=np.int64)[None, :] # [1, T]
        # BERT conditioning disabled: zeros of the expected feature widths.
        bert = np.zeros((1, 1024, T), dtype=np.float32)
        ja_bert = np.zeros((1, 768, T), dtype=np.float32)
        sid_val = SPK2ID.get(speaker, 0)
        sid = np.array([sid_val], dtype=np.int64)          # [1]

        # ── 1. Text Encoder ──────────────────────────────────────────────
        # Input names must match the exported graph exactly.
        x_enc, m_p, logs_p, x_mask, g = self._enc.run(
            None,
            {
                "phone_ids": x,
                "phone_lengths":x_len,
                "tone_ids": tone,
                "language_ids": lang,
                "bert": bert,
                "ja_bert": ja_bert,
                "speaker_id": sid,
            },
        )

        # ── 2. Duration Predictor ─────────────────────────────────────────
        logw = self._dp.run(None, {"x": x_enc, "x_mask": x_mask, "g": g})[0]

        # ── 3. Alignment Path (NumPy) ─────────────────────────────────────
        # Durations are predicted in log space; length_scale stretches them.
        w = np.exp(logw) * x_mask * length_scale           # [1, 1, T]
        w_ceil = np.ceil(w)                                # [1, 1, T]
        y_len = max(1, int(w_ceil.sum()))                  # total output frames
        y_lens = np.array([y_len], dtype=np.int64)

        y_mask = _create_length_mask_np(y_lens, y_len)     # [1, T_y]
        y_mask = y_mask[:, None, :]                        # [1, 1, T_y]
        # attn_mask: [1, 1, T_y, T_x] (outer product of frame mask and phone mask)
        attn_mask = y_mask[:, :, :, None] * x_mask[:, :, None, :]  # [1,1,T_y,T_x]
        attn = _compute_alignment_path_np(w_ceil, attn_mask)       # [1, 1, T_y, T_x]

        # Expand prior stats via alignment (phone axis -> frame axis)
        m_p_exp = np.matmul(attn[:, 0], m_p.transpose(0, 2, 1)).transpose(0, 2, 1)
        logs_p_exp = np.matmul(attn[:, 0], logs_p.transpose(0, 2, 1)).transpose(0, 2, 1)

        # ── 4. Sample z_p ─────────────────────────────────────────────────
        # Gaussian sample from the expanded prior, scaled by noise_scale.
        z_p = m_p_exp + np.random.randn(*m_p_exp.shape).astype(np.float32) * \
            np.exp(logs_p_exp) * noise_scale

        # ── 5. Flow (reverse) ─────────────────────────────────────────────
        z = self._flow.run(
            None,
            {"z_p": z_p, "y_mask": y_mask.astype(np.float32), "g": g},
        )[0]

        # ── 6. Decoder ────────────────────────────────────────────────────
        z_masked = (z * y_mask).astype(np.float32)
        audio = self._dec.run(None, {"z": z_masked, "g": g})[0]  # [1, 1, samples]

        audio_np = audio[0, 0]
        save_sr = SAMPLING_RATE
        # Optional best-effort resampling; falls back to the native rate.
        if output_sr is not None and output_sr != SAMPLING_RATE:
            try:
                import torchaudio
                import torch
                wav_t = torch.from_numpy(audio_np).unsqueeze(0)
                resampler = torchaudio.transforms.Resample(SAMPLING_RATE, output_sr)
                audio_np = resampler(wav_t).squeeze(0).numpy()
                save_sr = output_sr
            except Exception as e:
                print(f"[ONNX] Resampling failed ({e}), saving at {SAMPLING_RATE}Hz")

        sf.write(output_path, audio_np, save_sr)
        print(f"[ONNX] Saved: {output_path} ({save_sr}Hz)")
        return audio_np
@@ -0,0 +1,6 @@
1
+ """
2
+ TinyTTS Integrations Package
3
+ """
4
+ from .langchain import TinyTTSTool, TinyTTSVoiceTool, create_tts_chain
5
+
6
+ __all__ = ["TinyTTSTool", "TinyTTSVoiceTool", "create_tts_chain"]