sopro-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +6 -0
- sopro/audio.py +155 -0
- sopro/cli.py +185 -0
- sopro/codec/__init__.py +3 -0
- sopro/codec/mimi.py +181 -0
- sopro/config.py +48 -0
- sopro/constants.py +5 -0
- sopro/hub.py +53 -0
- sopro/model.py +853 -0
- sopro/nn/__init__.py +20 -0
- sopro/nn/blocks.py +110 -0
- sopro/nn/embeddings.py +96 -0
- sopro/nn/speaker.py +88 -0
- sopro/nn/xattn.py +98 -0
- sopro/sampling.py +101 -0
- sopro/streaming.py +165 -0
- sopro/tokenizer.py +38 -0
- sopro-1.0.0.dist-info/METADATA +182 -0
- sopro-1.0.0.dist-info/RECORD +23 -0
- sopro-1.0.0.dist-info/WHEEL +5 -0
- sopro-1.0.0.dist-info/entry_points.txt +2 -0
- sopro-1.0.0.dist-info/licenses/LICENSE.txt +201 -0
- sopro-1.0.0.dist-info/top_level.txt +1 -0
sopro/model.py
ADDED
@@ -0,0 +1,853 @@
from __future__ import annotations

import os
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from sopro.config import SoproTTSConfig
from sopro.hub import (
    download_repo,
    load_cfg_from_safetensors,
    load_state_dict_from_safetensors,
)

from .audio import save_audio
from .codec.mimi import MimiCodec
from .constants import TARGET_SR
from .nn import (
    CodebookEmbedding,
    RefXAttn,
    RMSNorm,
    SinusoidalPositionalEmbedding,
    SpeakerFiLM,
    SSMLiteBlock,
    TextEmbedding,
    TextXAttnBlock,
    Token2SV,
)
from .sampling import center_crop_tokens, repeated_tail
from .sampling import rf_ar as rf_ar_fn
from .sampling import rf_nar as rf_nar_fn
from .sampling import sample_token
from .tokenizer import TextTokenizer


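# TextEncoder: embeds text token ids, adds sinusoidal positions, and runs a
# stack of non-causal SSMLiteBlock layers. Returns both the per-token sequence
# and a mask-weighted mean-pooled summary vector (the "txt_pool" used below).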
class TextEncoder(nn.Module):
    def __init__(
        self,
        cfg: SoproTTSConfig,
        d_model: int,
        n_layers: int,
        tokenizer: TextTokenizer,
    ):
        super().__init__()
        self.tok = tokenizer
        self.embed = TextEmbedding(self.tok.vocab_size, d_model)
        self.layers = nn.ModuleList(
            [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
        )
        self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
        self.norm = RMSNorm(d_model)

    def forward(
        self, text_ids: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.embed(text_ids)
        L = x.size(1)
        pos = self.pos(torch.arange(L, device=x.device))
        x = x + pos.unsqueeze(0)

        x = x * mask.unsqueeze(-1).float()
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)

        mask_f = mask.float().unsqueeze(-1)
        pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
        return x, pooled


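# ARRVQ1Generator: autoregressive predictor for the first RVQ codebook. It is
# a causal dilated SSMLiteBlock stack (the dilation cycle is repeated out to
# n_layers_ar), with a text cross-attention block interleaved every
# ar_text_attn_freq layers. The forward pass guards against all-padding text
# rows by unmasking (and zeroing) position 0, so attention always has at least
# one valid key.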
class ARRVQ1Generator(nn.Module):
    def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
        super().__init__()
        ks = cfg.ar_kernel
        dils: List[int] = []
        while len(dils) < cfg.n_layers_ar:
            dils.extend(list(cfg.ar_dilation_cycle))
        dils = dils[: cfg.n_layers_ar]

        self.dils = tuple(int(d) for d in dils)
        self.blocks = nn.ModuleList(
            [
                SSMLiteBlock(
                    d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
                )
                for d in self.dils
            ]
        )

        self.attn_freq = int(cfg.ar_text_attn_freq)
        self.x_attns = nn.ModuleList()
        for i in range(len(self.blocks)):
            if (i + 1) % self.attn_freq == 0:
                self.x_attns.append(
                    TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
                )
            else:
                self.x_attns.append(nn.Identity())

        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab)

    def forward(
        self,
        x: torch.Tensor,
        text_emb: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        key_padding_mask = ~text_mask if text_mask is not None else None

        if key_padding_mask is not None:
            bad_rows = key_padding_mask.all(dim=1)
            if bad_rows.any():
                key_padding_mask = key_padding_mask.clone()
                idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
                key_padding_mask[idx, 0] = False
                if text_emb is not None:
                    text_emb = text_emb.clone()
                    text_emb[idx, 0, :] = 0

        h = x
        for i, lyr in enumerate(self.blocks):
            h = lyr(h)
            if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
                h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)

        h = self.norm(h)
        return self.head(h)


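# StageRefiner: one non-autoregressive refinement stage. It mixes the frame
# conditioning with the embedding of already-decoded codebooks via a learned
# 2-way softmax blend, runs causal SSMLiteBlocks, and emits one classification
# head per codebook handled by the stage.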
class StageRefiner(nn.Module):
    def __init__(self, cfg: SoproTTSConfig, D: int, num_heads: int, codebook_size: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [SSMLiteBlock(D, cfg.dropout, causal=True) for _ in range(cfg.n_layers_nar)]
        )
        self.norm = RMSNorm(D)
        self.pre = nn.Linear(D, cfg.nar_head_dim)
        self.heads = nn.ModuleList(
            [nn.Linear(cfg.nar_head_dim, codebook_size) for _ in range(num_heads)]
        )
        self.mix = nn.Parameter(torch.ones(2, dtype=torch.float32))

    def forward_hidden(
        self, cond_bt_d: torch.Tensor, prev_bt_d: torch.Tensor
    ) -> torch.Tensor:
        w = torch.softmax(self.mix, dim=0)
        x = w[0] * cond_bt_d + w[1] * prev_bt_d
        for b in self.blocks:
            x = b(x)
        return self.norm(x)

    def forward_heads(self, h: torch.Tensor) -> List[torch.Tensor]:
        z = self.pre(h)
        return [head(z) for head in self.heads]


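# StopHead: scalar logit per frame; sigmoid(logit) is treated as p(stop) in
# ar_stream() below.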
class StopHead(nn.Module):
    def __init__(self, D: int):
        super().__init__()
        self.proj = nn.Linear(D, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.proj(h).squeeze(-1)


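# SoproTTSModel: the full token-level TTS model. Pipeline, as wired below:
#   1. TextEncoder -> per-token text sequence + pooled text vector.
#   2. Reference Mimi tokens -> speaker vector (Token2SV) and a reference
#      sequence (ref_enc_blocks), consumed via cross-attention (RefXAttn).
#   3. ARRVQ1Generator samples the first codebook frame by frame.
#   4. StageRefiner stages B..E fill in the remaining codebooks (stage ranges
#      come from cfg.stage_B..cfg.stage_E, given as 1-based inclusive pairs).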
class SoproTTSModel(nn.Module):
    def __init__(self, cfg: SoproTTSConfig, tokenizer: TextTokenizer):
        super().__init__()
        self.cfg = cfg
        D = int(cfg.d_model)

        self.text_enc = TextEncoder(cfg, D, cfg.n_layers_text, tokenizer)
        self.frame_pos = SinusoidalPositionalEmbedding(D, max_len=cfg.pos_emb_max + 2)

        self.cb_embed = CodebookEmbedding(
            cfg.num_codebooks, cfg.codebook_size, D, use_bos=True
        )
        self.rvq1_bos_id = self.cb_embed.bos_id

        self.token2sv = Token2SV(
            cfg.num_codebooks,
            cfg.codebook_size,
            d=192,
            out_dim=cfg.sv_student_dim,
            dropout=cfg.dropout,
        )
        self.spk_film = SpeakerFiLM(D, sv_dim=cfg.sv_student_dim)
        self.cond_norm = RMSNorm(D)

        self.ar = ARRVQ1Generator(cfg, D, cfg.codebook_size)
        if cfg.ar_lookback > 0:
            self.ar_hist_w = nn.Parameter(torch.zeros(cfg.ar_lookback))

        def idxs(rng: Tuple[int, int]) -> List[int]:
            lo, hi = rng
            return list(range(lo - 1, hi))

        self.stage_indices: Dict[str, List[int]] = {
            "B": idxs(cfg.stage_B),
            "C": idxs(cfg.stage_C),
            "D": idxs(cfg.stage_D),
            "E": idxs(cfg.stage_E),
        }
        self.stages = nn.ModuleDict(
            {
                s: StageRefiner(cfg, D, len(self.stage_indices[s]), cfg.codebook_size)
                for s in ["B", "C", "D", "E"]
            }
        )

        self.stop_head = StopHead(D) if cfg.use_stop_head else None

        self.ref_enc_blocks = nn.ModuleList(
            [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(2)]
        )
        self.ref_enc_norm = RMSNorm(D)
        self.ref_xattn_stack = RefXAttn(
            D, heads=cfg.ref_attn_heads, layers=3, dropout=cfg.dropout
        )

    def rf_ar(self) -> int:
        return rf_ar_fn(self.cfg.ar_kernel, self.ar.dils)

    def rf_nar(self) -> int:
        return rf_nar_fn(self.cfg.n_layers_nar, kernel_size=7, dilation=1)

    def _pool_time(self, x: torch.Tensor, factor: int) -> torch.Tensor:
        if factor <= 1 or x.size(1) < 2 * factor:
            return x
        return F.avg_pool1d(
            x.transpose(1, 2), kernel_size=factor, stride=factor
        ).transpose(1, 2)

    def _normalize_ref_mask(
        self, ref_mask: Optional[torch.Tensor], device: torch.device
    ) -> Optional[torch.Tensor]:
        if ref_mask is None:
            return None
        mk = ref_mask.to(device).bool()
        if mk.ndim == 1:
            mk = mk.unsqueeze(0)
        return mk

    def _encode_reference_seq(self, ref_tokens: torch.Tensor) -> torch.Tensor:
        B, Tr, Q = ref_tokens.shape
        emb_sum = 0.0
        for q in range(Q):
            emb_sum = emb_sum + self.cb_embed.embed_tokens(
                ref_tokens[:, :, q], cb_index=q
            )
        x = emb_sum / float(Q)
        for b in self.ref_enc_blocks:
            x = b(x)
        return self.ref_enc_norm(x)

    def _single_pass_ref_xattn(
        self,
        cond_btd: torch.Tensor,
        ref_seq: torch.Tensor,
        ref_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ref_seq_p = self._pool_time(ref_seq, 1)

        key_padding_mask = None
        if ref_mask is not None:
            mk_bool = ref_mask.bool()
            B, Tr = mk_bool.shape
            pooled_len = ref_seq_p.size(1)
            if pooled_len == Tr:
                key_padding_mask = ~mk_bool
            else:
                cut = pooled_len * 2
                mk2 = mk_bool[:, :cut].reshape(B, pooled_len, 2).any(dim=2)
                key_padding_mask = ~mk2

        return self.ref_xattn_stack(
            cond_btd, ref_seq_p, key_padding_mask=key_padding_mask
        )

    def _base_cond_at(
        self, t: int, txt_pool: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        pos = self.frame_pos(torch.tensor([t], device=device)).unsqueeze(0)
        return txt_pool[:, None, :] + pos

    def _ar_prev_from_seq(self, seq_1xT: torch.Tensor) -> torch.Tensor:
        K = int(self.cfg.ar_lookback)
        if K <= 0 or getattr(self, "ar_hist_w", None) is None:
            return self.cb_embed.embed_shift_by_k(seq_1xT, cb_index=0, k=1)

        ws = torch.softmax(self.ar_hist_w, dim=0)
        acc = 0.0
        k_max = min(K, int(seq_1xT.size(1)))
        for k in range(1, k_max + 1):
            acc = acc + ws[k - 1] * self.cb_embed.embed_shift_by_k(
                seq_1xT, cb_index=0, k=k
            )
        return acc

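    # prepare_conditioning: precomputes everything the AR/NAR decoders need
    # for one utterance: text encodings, speaker vector, reference sequence,
    # and the full per-frame conditioning "cond_all" (pooled text + frame
    # position, FiLM-modulated by the speaker vector, then cross-attended over
    # the reference). chunk_size runs the reference cross-attention over frame
    # windows rather than all max_frames at once, presumably to bound peak
    # memory.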
    @torch.no_grad()
    def prepare_conditioning(
        self,
        text_ids_1d: torch.Tensor,
        ref_tokens_tq: torch.Tensor,
        *,
        max_frames: int,
        device: torch.device,
        style_strength: float = 1.0,
        ref_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        self.eval()

        text_ids = text_ids_1d.to(device)
        text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
        txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)

        ref_btq = ref_tokens_tq.unsqueeze(0).to(device)

        sv_ref = self.token2sv(ref_btq, lengths=None)
        ref_seq = self._encode_reference_seq(ref_btq)
        ref_mask_btr = self._normalize_ref_mask(ref_mask, device)

        T = int(max_frames)
        if T <= 0:
            cond_all = torch.zeros(
                (1, 0, txt_pool.size(-1)), device=device, dtype=txt_pool.dtype
            )
        else:
            pos = self.frame_pos(torch.arange(T, device=device)).unsqueeze(0)
            base_all = txt_pool[:, None, :] + pos
            base_all = self.spk_film(base_all, sv_ref, strength=float(style_strength))

            if chunk_size is None or int(chunk_size) >= T:
                out = self._single_pass_ref_xattn(
                    base_all, ref_seq, ref_mask=ref_mask_btr
                )
                cond_all = self.cond_norm(out)
            else:
                cs = int(chunk_size)
                chunks: List[torch.Tensor] = []
                for s in range(0, T, cs):
                    e = min(T, s + cs)
                    q = base_all[:, s:e, :]
                    out = self._single_pass_ref_xattn(q, ref_seq, ref_mask=ref_mask_btr)
                    chunks.append(self.cond_norm(out))
                cond_all = torch.cat(chunks, dim=1)

        return {
            "txt_seq": txt_seq,
            "text_mask": text_mask,
            "cond_all": cond_all,
            "ref_btq": ref_btq,
            "txt_pool": txt_pool,
            "sv_ref": sv_ref,
            "ref_seq": ref_seq,
            "ref_mask": (
                ref_mask_btr
                if ref_mask_btr is not None
                else torch.empty(0, device=device, dtype=torch.bool)
            ),
            "style_strength": torch.tensor(float(style_strength), device=device),
        }

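    # prepare_conditioning_lazy: same outputs as prepare_conditioning, but
    # with an empty cond_all; frames are materialized on demand by
    # ensure_cond_upto() during streaming, bounded by the stored "max_frames".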
    @torch.no_grad()
    def prepare_conditioning_lazy(
        self,
        text_ids_1d: torch.Tensor,
        ref_tokens_tq: torch.Tensor,
        *,
        max_frames: int,
        device: torch.device,
        style_strength: float = 1.0,
        ref_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        self.eval()

        text_ids = text_ids_1d.to(device)
        text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
        txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)

        ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
        sv_ref = self.token2sv(ref_btq, lengths=None)
        ref_seq = self._encode_reference_seq(ref_btq)
        ref_mask_btr = self._normalize_ref_mask(ref_mask, device)

        D = int(txt_pool.size(-1))
        cond_all = torch.zeros((1, 0, D), device=device, dtype=txt_pool.dtype)

        return {
            "txt_seq": txt_seq,
            "text_mask": text_mask,
            "cond_all": cond_all,
            "ref_btq": ref_btq,
            "txt_pool": txt_pool,
            "sv_ref": sv_ref,
            "ref_seq": ref_seq,
            "ref_mask": (
                ref_mask_btr
                if ref_mask_btr is not None
                else torch.empty(0, device=device, dtype=torch.bool)
            ),
            "style_strength": torch.tensor(float(style_strength), device=device),
            "max_frames": torch.tensor(int(max_frames), device=device),
        }

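    # ensure_cond_upto: extends prep["cond_all"] so that frame t_inclusive is
    # available, computing in chunk_size-aligned blocks (rounded up, clamped
    # to max_frames when present) so repeated calls amortize to one pass.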
    @torch.no_grad()
    def ensure_cond_upto(
        self,
        prep: Dict[str, torch.Tensor],
        t_inclusive: int,
        *,
        chunk_size: int = 64,
    ) -> None:
        if t_inclusive < 0:
            return

        cond_all = prep.get("cond_all", None)
        if cond_all is None:
            raise KeyError("prep dict missing 'cond_all'.")

        have = int(cond_all.size(1))
        need_min = int(t_inclusive) + 1
        if have >= need_min:
            return

        if "txt_pool" not in prep or "sv_ref" not in prep or "ref_seq" not in prep:
            raise RuntimeError(
                "Lazy conditioning requested but prep lacks txt_pool/sv_ref/ref_seq. "
                "Use prepare_conditioning_lazy() or prepare_conditioning(..., chunk_size=...)."
            )

        device = cond_all.device
        txt_pool = prep["txt_pool"]
        sv_ref = prep["sv_ref"]
        ref_seq = prep["ref_seq"]

        style_strength = float(
            prep.get(
                "style_strength", torch.tensor(self.cfg.style_strength, device=device)
            ).item()
        )

        ref_mask = prep.get("ref_mask", None)
        if ref_mask is not None and ref_mask.numel() == 0:
            ref_mask = None

        max_frames = prep.get("max_frames", None)
        maxT = int(max_frames.item()) if max_frames is not None else None

        cs = max(1, int(chunk_size))

        need = ((need_min + cs - 1) // cs) * cs
        if maxT is not None:
            need = min(need, maxT)

        if have >= need:
            return

        new_chunks: List[torch.Tensor] = []
        for s in range(have, need, cs):
            e = min(need, s + cs)
            pos = self.frame_pos(torch.arange(s, e, device=device)).unsqueeze(0)
            base = txt_pool[:, None, :] + pos
            base = self.spk_film(base, sv_ref, strength=style_strength)
            out = self._single_pass_ref_xattn(base, ref_seq, ref_mask=ref_mask)
            new_chunks.append(self.cond_norm(out))

        prep["cond_all"] = torch.cat([cond_all] + new_chunks, dim=1)

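    # build_ar_prefix: takes the first-codebook tokens of the reference as an
    # AR pre-prompt, capped at prefix_sec_fixed (or cfg.preprompt_sec_max)
    # seconds worth of frames at cfg.mimi_fps.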
    def build_ar_prefix(
        self,
        ref_btq: torch.Tensor,
        device: torch.device,
        prefix_sec_fixed: Optional[float],
        use_prefix: bool,
    ) -> torch.Tensor:
        if not use_prefix or ref_btq.size(1) == 0:
            return torch.zeros(1, 0, dtype=torch.long, device=device)

        avail = int(ref_btq.size(1))
        fps = float(self.cfg.mimi_fps)

        if prefix_sec_fixed is not None and prefix_sec_fixed > 0:
            P = min(avail, int(round(prefix_sec_fixed * fps)))
        else:
            P = min(avail, max(1, int(round(self.cfg.preprompt_sec_max * fps))))

        if P <= 0:
            return torch.zeros(1, 0, dtype=torch.long, device=device)
        return ref_btq[:, :P, 0].contiguous()

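    # ar_stream: the frame-by-frame decode loop. Per step it (a) lazily
    # extends conditioning when needed, (b) re-runs the AR stack over a window
    # limited to its receptive field R_AR, (c) nucleus-samples the next
    # first-codebook token, switching to recovery top_p/temperature when
    # repeated_tail() or a long same-token streak signals looping, and (d)
    # optionally tracks a stop-probability streak that ends generation after
    # stop_patience consecutive frames above stop_threshold. Yields
    # (t, rvq1_id, p_stop) per frame.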
    @torch.no_grad()
    def ar_stream(
        self,
        prep: Dict[str, torch.Tensor],
        *,
        max_frames: int,
        top_p: float = 0.9,
        temperature: float = 1.05,
        anti_loop: bool = True,
        loop_streak: int = 8,
        recovery_top_p: float = 0.85,
        recovery_temp: float = 1.2,
        use_prefix: bool = True,
        prefix_sec_fixed: Optional[float] = None,
        cond_chunk_size: Optional[int] = None,
        use_stop_head: Optional[bool] = None,
        stop_patience: Optional[int] = None,
        stop_threshold: Optional[float] = None,
        min_gen_frames: Optional[int] = None,
    ) -> Iterator[Tuple[int, int, Optional[float]]]:
        device = prep["cond_all"].device
        cond_all = prep["cond_all"]
        txt_seq = prep["txt_seq"]
        text_mask = prep["text_mask"]
        ref_btq = prep["ref_btq"]

        R_AR = self.rf_ar()

        stop_head = self.stop_head
        if use_stop_head is not None:
            if not bool(use_stop_head):
                stop_head = None

        eff_stop_patience = int(
            stop_patience if stop_patience is not None else self.cfg.stop_patience
        )
        eff_stop_threshold = float(
            stop_threshold if stop_threshold is not None else self.cfg.stop_threshold
        )
        eff_min_gen_frames = int(
            min_gen_frames if min_gen_frames is not None else self.cfg.min_gen_frames
        )

        A_prefix = self.build_ar_prefix(
            ref_btq, device, prefix_sec_fixed, use_prefix=use_prefix
        )
        P = int(A_prefix.size(1))

        ctx_ids = torch.zeros(
            (1, P + int(max_frames) + 1), dtype=torch.long, device=device
        )
        if P > 0:
            ctx_ids[:, :P] = A_prefix

        hist_A: List[int] = []
        loop_streak_count = 0
        stop_streak_count = 0
        last_a: Optional[int] = None

        gen_len = 0

        for t in range(int(max_frames)):
            if prep["cond_all"].size(1) < (t + 1):
                self.ensure_cond_upto(prep, t, chunk_size=int(cond_chunk_size or 64))
                cond_all = prep["cond_all"]

            L_ar = min(t + 1, R_AR)
            s_ar = t + 1 - L_ar
            cond_win_ar = cond_all[:, s_ar : t + 1, :]

            total_len = P + gen_len + 1
            A_ctx_full = ctx_ids[:, :total_len]

            prev_ctx_full = self._ar_prev_from_seq(A_ctx_full)
            prev_ctx_win = prev_ctx_full[:, -L_ar:, :]

            cur_top_p, cur_temp = top_p, temperature
            if anti_loop:
                if repeated_tail(hist_A, max_n=16):
                    cur_top_p, cur_temp = recovery_top_p, recovery_temp
                elif last_a is not None and loop_streak_count >= loop_streak:
                    cur_top_p, cur_temp = recovery_top_p, recovery_temp

            ar_logits_win = self.ar(
                cond_win_ar + prev_ctx_win, text_emb=txt_seq, text_mask=text_mask
            )
            ar_logits_t = ar_logits_win[:, -1:, :]

            rvq1_id = sample_token(
                ar_logits_t,
                history=hist_A,
                top_p=cur_top_p,
                temperature=cur_temp,
                top_k=50,
                repetition_penalty=1.1,
            )

            ctx_ids[0, P + gen_len] = int(rvq1_id)
            gen_len += 1

            hist_A.append(int(rvq1_id))
            loop_streak_count = (
                (loop_streak_count + 1)
                if (last_a is not None and rvq1_id == last_a)
                else 0
            )
            last_a = int(rvq1_id)

            p_stop: Optional[float] = None
            if stop_head is not None:
                A_now = torch.tensor([[rvq1_id]], device=device, dtype=torch.long)
                stop_inp = (
                    cond_all[:, t : t + 1, :]
                    + self.cb_embed.embed_tokens(A_now, cb_index=0).detach()
                )
                stop_logits = stop_head(stop_inp)
                p_stop = float(torch.sigmoid(stop_logits).item())

                if t + 1 >= eff_min_gen_frames and p_stop > eff_stop_threshold:
                    stop_streak_count += 1
                else:
                    stop_streak_count = 0

            yield t, int(rvq1_id), p_stop

            if stop_head is not None and stop_streak_count >= eff_stop_patience:
                break

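    # nar_refine: greedy (argmax) non-autoregressive pass. Stages B..E run in
    # order; each stage conditions on the summed embeddings of all codebooks
    # decoded so far and appends its own predictions, producing a
    # (1, T, num_codebooks) token grid.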
    @torch.no_grad()
    def nar_refine(
        self, cond_seq: torch.Tensor, tokens_A_1xT: torch.Tensor
    ) -> torch.Tensor:
        preds_all: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
        prev_tokens_list: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
        prev_cb_list: List[List[int]] = [[0]]

        for stage_name in ["B", "C", "D", "E"]:
            idxs = self.stage_indices[stage_name]
            prev_tokens_cat = torch.cat(prev_tokens_list, dim=-1)
            prev_cbs_cat = sum(prev_cb_list, [])
            prev_emb_sum = self.cb_embed.sum_embed_subset(prev_tokens_cat, prev_cbs_cat)

            h = self.stages[stage_name].forward_hidden(cond_seq, prev_emb_sum)
            logits_list = self.stages[stage_name].forward_heads(h)
            preds = torch.stack([x.argmax(dim=-1) for x in logits_list], dim=-1)

            preds_all.append(preds)
            prev_tokens_list.append(preds)
            prev_cb_list.append(idxs)

        tokens_btq = torch.cat(preds_all, dim=-1)
        return tokens_btq

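    # generate_tokens: convenience wrapper chaining prepare_conditioning ->
    # ar_stream -> nar_refine; returns (T, num_codebooks) token ids.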
    @torch.no_grad()
    def generate_tokens(
        self,
        text_ids_1d: torch.Tensor,
        ref_tokens_tq: torch.Tensor,
        *,
        max_frames: int,
        device: torch.device,
        top_p: float = 0.9,
        temperature: float = 1.05,
        anti_loop: bool = True,
        use_prefix: bool = True,
        prefix_sec_fixed: Optional[float] = None,
        style_strength: float = 1.0,
        use_stop_head: Optional[bool] = None,
        stop_patience: Optional[int] = None,
        stop_threshold: Optional[float] = None,
        min_gen_frames: Optional[int] = None,
    ) -> torch.Tensor:
        prep = self.prepare_conditioning(
            text_ids_1d,
            ref_tokens_tq,
            max_frames=max_frames,
            device=device,
            style_strength=style_strength,
        )

        hist_A: List[int] = []
        for _t, rvq1, _p_stop in self.ar_stream(
            prep,
            max_frames=max_frames,
            top_p=top_p,
            temperature=temperature,
            anti_loop=anti_loop,
            use_prefix=use_prefix,
            prefix_sec_fixed=prefix_sec_fixed,
            use_stop_head=use_stop_head,
            stop_patience=stop_patience,
            stop_threshold=stop_threshold,
            min_gen_frames=min_gen_frames,
        ):
            hist_A.append(rvq1)

        T = len(hist_A)
        if T == 0:
            return torch.zeros(
                0, self.cfg.num_codebooks, dtype=torch.long, device=device
            )

        tokens_A = torch.tensor(hist_A, device=device, dtype=torch.long).unsqueeze(0)
        cond_seq = prep["cond_all"][:, :T, :]
        tokens_btq_1xTQ = self.nar_refine(cond_seq, tokens_A)
        return tokens_btq_1xTQ.squeeze(0)


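# SoproTTS: the user-facing wrapper bundling model, tokenizer, and Mimi codec.
# A minimal usage sketch based on the API below (the repo id and file paths
# are placeholders, not values shipped in this package):
#
#   tts = SoproTTS.from_pretrained("<org>/<repo>")
#   wav = tts.synthesize("Hello there.", ref_audio_path="speaker.wav")
#   tts.save_wav("out.wav", wav)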
class SoproTTS:
    def __init__(
        self,
        model: SoproTTSModel,
        cfg: SoproTTSConfig,
        tokenizer: TextTokenizer,
        codec: MimiCodec,
        device: str,
    ):
        self.model = model
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.codec = codec
        self.device = torch.device(device)

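    # from_pretrained: downloads a repo snapshot, reads the config embedded in
    # model.safetensors, builds tokenizer/model/codec on the chosen device
    # (defaulting to CUDA when available), and loads the weights.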
    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        token: Optional[str] = None,
        device: Optional[str] = None,
    ) -> "SoproTTS":
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        dev = torch.device(device)

        local_dir = download_repo(
            repo_id, revision=revision, cache_dir=cache_dir, token=token
        )

        model_path = os.path.join(local_dir, "model.safetensors")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Expected {model_path} in repo snapshot.")

        cfg = load_cfg_from_safetensors(model_path)

        tokenizer = TextTokenizer(model_name=local_dir)

        model = SoproTTSModel(cfg, tokenizer).to(dev).eval()
        state = load_state_dict_from_safetensors(model_path)

        model.load_state_dict(state)

        codec = MimiCodec(num_quantizers=cfg.num_codebooks, device=device)

        return cls(
            model=model, cfg=cfg, tokenizer=tokenizer, codec=codec, device=device
        )

    def encode_text(self, text: str) -> torch.Tensor:
        ids = self.tokenizer.encode(text)
        return torch.tensor(ids, dtype=torch.long, device=self.device)

    def encode_reference(
        self,
        *,
        ref_audio_path: Optional[str] = None,
        ref_tokens_tq: Optional[torch.Tensor] = None,
        ref_seconds: Optional[float] = None,
    ) -> torch.Tensor:
        if (ref_tokens_tq is None) and (ref_audio_path is None):
            raise RuntimeError(
                "SoproTTS requires a reference. Provide ref_audio_path=... or ref_tokens_tq=..."
            )
        if (ref_tokens_tq is not None) and (ref_audio_path is not None):
            raise RuntimeError(
                "Provide only one of ref_audio_path or ref_tokens_tq (not both)."
            )

        if ref_seconds is None:
            ref_seconds = float(self.cfg.ref_seconds_max)

        if ref_tokens_tq is not None:
            ref = ref_tokens_tq.to(self.device).long()
            if ref_seconds > 0:
                fps = float(self.cfg.mimi_fps)
                win = max(1, int(round(ref_seconds * fps)))
                ref = center_crop_tokens(ref, win)
            return ref

        crop_seconds = (
            ref_seconds if (ref_seconds is not None and ref_seconds > 0) else None
        )
        ref = (
            self.codec.encode_file(ref_audio_path, crop_seconds=crop_seconds)
            .to(self.device)
            .long()
        )
        return ref

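    # synthesize: end-to-end text + reference -> waveform. Most keyword
    # arguments are pass-throughs to generate_tokens()/ar_stream(); None means
    # "use the config default". Returns the waveform decoded by
    # MimiCodec.decode_full.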
    @torch.no_grad()
    def synthesize(
        self,
        text: str,
        *,
        ref_audio_path: Optional[str] = None,
        ref_tokens_tq: Optional[torch.Tensor] = None,
        max_frames: int = 400,
        top_p: float = 0.9,
        temperature: float = 1.05,
        anti_loop: bool = True,
        use_prefix: bool = True,
        prefix_sec_fixed: Optional[float] = None,
        style_strength: Optional[float] = None,
        ref_seconds: Optional[float] = None,
        use_stop_head: Optional[bool] = None,
        stop_patience: Optional[int] = None,
        stop_threshold: Optional[float] = None,
        min_gen_frames: Optional[int] = None,
    ) -> torch.Tensor:
        text_ids = self.encode_text(text)
        ref = self.encode_reference(
            ref_audio_path=ref_audio_path,
            ref_tokens_tq=ref_tokens_tq,
            ref_seconds=ref_seconds,
        )

        tokens_tq = self.model.generate_tokens(
            text_ids,
            ref,
            max_frames=max_frames,
            device=self.device,
            top_p=top_p,
            temperature=temperature,
            anti_loop=anti_loop,
            use_prefix=use_prefix,
            prefix_sec_fixed=prefix_sec_fixed,
            style_strength=float(
                style_strength
                if style_strength is not None
                else self.cfg.style_strength
            ),
            use_stop_head=use_stop_head,
            stop_patience=stop_patience,
            stop_threshold=stop_threshold,
            min_gen_frames=min_gen_frames,
        )

        wav = self.codec.decode_full(tokens_tq)
        return wav

    def stream(self, text: str, **kwargs) -> Iterator[torch.Tensor]:
        from .streaming import stream as _stream

        return _stream(self, text, **kwargs)

    def save_wav(self, path: str, wav_1xT: torch.Tensor) -> None:
        save_audio(path, wav_1xT, sr=TARGET_SR)