sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sopro/__init__.py +1 -1
- sopro/cli.py +31 -46
- sopro/config.py +15 -20
- sopro/hub.py +2 -3
- sopro/model.py +265 -535
- sopro/nn/__init__.py +7 -3
- sopro/nn/blocks.py +78 -0
- sopro/nn/embeddings.py +16 -0
- sopro/nn/generator.py +130 -0
- sopro/nn/nar.py +116 -0
- sopro/nn/ref.py +160 -0
- sopro/nn/speaker.py +14 -17
- sopro/nn/text.py +132 -0
- sopro/sampling.py +3 -3
- sopro/streaming.py +25 -38
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/METADATA +30 -7
- sopro-1.5.0.dist-info/RECORD +26 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/WHEEL +1 -1
- sopro/nn/xattn.py +0 -98
- sopro-1.0.1.dist-info/RECORD +0 -23
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/entry_points.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/top_level.txt +0 -0
sopro/model.py
CHANGED
@@ -1,11 +1,11 @@
 from __future__ import annotations

 import os
+from dataclasses import dataclass
 from typing import Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from sopro.config import SoproTTSConfig
 from sopro.hub import (
@@ -18,152 +18,36 @@ from .audio import save_audio
 from .codec.mimi import MimiCodec
 from .constants import TARGET_SR
 from .nn import (
+    ARRVQ1Generator,
     CodebookEmbedding,
-
+    RefXAttnStack,
     RMSNorm,
     SinusoidalPositionalEmbedding,
     SpeakerFiLM,
     SSMLiteBlock,
-
-    TextXAttnBlock,
+    TextEncoder,
     Token2SV,
 )
-from .
+from .nn.nar import NARSinglePass
+from .sampling import repeated_tail
 from .sampling import rf_ar as rf_ar_fn
 from .sampling import rf_nar as rf_nar_fn
 from .sampling import sample_token
 from .tokenizer import TextTokenizer


-
-
-
-
-        d_model: int,
-        n_layers: int,
-        tokenizer: TextTokenizer,
-    ):
-        super().__init__()
-        self.tok = tokenizer
-        self.embed = TextEmbedding(self.tok.vocab_size, d_model)
-        self.layers = nn.ModuleList(
-            [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
-        )
-        self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
-        self.norm = RMSNorm(d_model)
-
-    def forward(
-        self, text_ids: torch.Tensor, mask: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        x = self.embed(text_ids)
-        L = x.size(1)
-        pos = self.pos(torch.arange(L, device=x.device))
-        x = x + pos.unsqueeze(0)
-
-        x = x * mask.unsqueeze(-1).float()
-        for layer in self.layers:
-            x = layer(x)
-        x = self.norm(x)
-
-        mask_f = mask.float().unsqueeze(-1)
-        pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
-        return x, pooled
-
-
-class ARRVQ1Generator(nn.Module):
-    def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
-        super().__init__()
-        ks = cfg.ar_kernel
-        dils: List[int] = []
-        while len(dils) < cfg.n_layers_ar:
-            dils.extend(list(cfg.ar_dilation_cycle))
-        dils = dils[: cfg.n_layers_ar]
-
-        self.dils = tuple(int(d) for d in dils)
-        self.blocks = nn.ModuleList(
-            [
-                SSMLiteBlock(
-                    d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
-                )
-                for d in self.dils
-            ]
-        )
-
-        self.attn_freq = int(cfg.ar_text_attn_freq)
-        self.x_attns = nn.ModuleList()
-        for i in range(len(self.blocks)):
-            if (i + 1) % self.attn_freq == 0:
-                self.x_attns.append(
-                    TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
-                )
-            else:
-                self.x_attns.append(nn.Identity())
-
-        self.norm = RMSNorm(d_model)
-        self.head = nn.Linear(d_model, vocab)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        text_emb: Optional[torch.Tensor] = None,
-        text_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        key_padding_mask = ~text_mask if text_mask is not None else None
-
-        if key_padding_mask is not None:
-            bad_rows = key_padding_mask.all(dim=1)
-            if bad_rows.any():
-                key_padding_mask = key_padding_mask.clone()
-                idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
-                key_padding_mask[idx, 0] = False
-                if text_emb is not None:
-                    text_emb = text_emb.clone()
-                    text_emb[idx, 0, :] = 0
-
-        h = x
-        for i, lyr in enumerate(self.blocks):
-            h = lyr(h)
-            if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
-                h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)
-
-        h = self.norm(h)
-        return self.head(h)
-
-
-class StageRefiner(nn.Module):
-    def __init__(self, cfg: SoproTTSConfig, D: int, num_heads: int, codebook_size: int):
-        super().__init__()
-        self.blocks = nn.ModuleList(
-            [SSMLiteBlock(D, cfg.dropout, causal=True) for _ in range(cfg.n_layers_nar)]
-        )
-        self.norm = RMSNorm(D)
-        self.pre = nn.Linear(D, cfg.nar_head_dim)
-        self.heads = nn.ModuleList(
-            [nn.Linear(cfg.nar_head_dim, codebook_size) for _ in range(num_heads)]
-        )
-        self.mix = nn.Parameter(torch.ones(2, dtype=torch.float32))
-
-    def forward_hidden(
-        self, cond_bt_d: torch.Tensor, prev_bt_d: torch.Tensor
-    ) -> torch.Tensor:
-        w = torch.softmax(self.mix, dim=0)
-        x = w[0] * cond_bt_d + w[1] * prev_bt_d
-        for b in self.blocks:
-            x = b(x)
-        return self.norm(x)
-
-    def forward_heads(self, h: torch.Tensor) -> List[torch.Tensor]:
-        z = self.pre(h)
-        return [head(z) for head in self.heads]
-
+def _stage_range_to_indices(stage_rng: Tuple[int, int], Q: int) -> List[int]:
+    lo, hi = int(stage_rng[0]), int(stage_rng[1])
+    idxs = list(range(lo - 1, hi))
+    return [i for i in idxs if 1 <= i < Q]

-class StopHead(nn.Module):
-    def __init__(self, D: int):
-        super().__init__()
-        self.proj = nn.Linear(D, 1)

-
-
+@dataclass
+class PreparedReference:
+    ref_tokens_btq: torch.Tensor
+    sv_ref: torch.Tensor
+    ref_seq: torch.Tensor
+    ref_kv_caches: List[Dict[str, torch.Tensor]]


 class SoproTTSModel(nn.Module):
@@ -172,327 +56,165 @@ class SoproTTSModel(nn.Module):
         self.cfg = cfg
         D = int(cfg.d_model)

-        self.
-
+        self.eos_id = int(cfg.codebook_size)
+
+        self.text_enc = TextEncoder(cfg, D, int(cfg.n_layers_text), tokenizer)
+        self.frame_pos = SinusoidalPositionalEmbedding(
+            D, max_len=int(cfg.pos_emb_max) + 8
+        )

         self.cb_embed = CodebookEmbedding(
             cfg.num_codebooks, cfg.codebook_size, D, use_bos=True
         )
-
+
+        self.nar_prev_cb_weights = nn.Parameter(
+            torch.zeros(cfg.num_codebooks, dtype=torch.float32)
+        )

         self.token2sv = Token2SV(
             cfg.num_codebooks,
             cfg.codebook_size,
             d=192,
-            out_dim=cfg.sv_student_dim,
+            out_dim=int(cfg.sv_student_dim),
             dropout=cfg.dropout,
         )
-        self.spk_film = SpeakerFiLM(D, sv_dim=cfg.sv_student_dim)
-        self.cond_norm = RMSNorm(D)
+        self.spk_film = SpeakerFiLM(D, sv_dim=int(cfg.sv_student_dim))

-        self.ar = ARRVQ1Generator(cfg, D, cfg.codebook_size)
-        if cfg.ar_lookback > 0:
-            self.ar_hist_w = nn.Parameter(torch.zeros(cfg.ar_lookback))
-
-        def idxs(rng: Tuple[int, int]) -> List[int]:
-            lo, hi = rng
-            return list(range(lo - 1, hi))
+        self.ar = ARRVQ1Generator(cfg, D, int(cfg.codebook_size) + 1)

+        Q = int(cfg.num_codebooks)
         self.stage_indices: Dict[str, List[int]] = {
-            "B":
-            "C":
-            "D":
-            "E":
+            "B": _stage_range_to_indices(cfg.stage_B, Q),
+            "C": _stage_range_to_indices(cfg.stage_C, Q),
+            "D": _stage_range_to_indices(cfg.stage_D, Q),
+            "E": _stage_range_to_indices(cfg.stage_E, Q),
         }
-        self.
-
-
-
-
-        )
+        self.stage_order = [
+            s for s in ["B", "C", "D", "E"] if len(self.stage_indices[s]) > 0
+        ]
+
+        self.nar = NARSinglePass(cfg, D, stage_specs=self.stage_indices)

-        self.
+        self.cond_norm = RMSNorm(D)

+        ref_enc_layers = int(getattr(cfg, "ref_enc_layers", 2))
         self.ref_enc_blocks = nn.ModuleList(
-            [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(
+            [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(ref_enc_layers)]
         )
         self.ref_enc_norm = RMSNorm(D)
-
-
+
+        self.ref_xattn = RefXAttnStack(
+            D,
+            heads=cfg.ref_xattn_heads,
+            layers=cfg.ref_xattn_layers,
+            gmax=cfg.ref_xattn_gmax,
+        )
+
+        self.register_buffer(
+            "ref_cb_weights",
+            torch.linspace(1.0, 0.1, int(cfg.num_codebooks)),
+            persistent=True,
         )

     def rf_ar(self) -> int:
-        return rf_ar_fn(
+        return rf_ar_fn(
+            int(self.cfg.ar_kernel),
+            getattr(self.ar, "dils", tuple(int(x) for x in self.cfg.ar_dilation_cycle)),
+        )

     def rf_nar(self) -> int:
-
-
-
-
-
-        return
-            x.transpose(1, 2), kernel_size=factor, stride=factor
-        ).transpose(1, 2)
-
-    def _normalize_ref_mask(
-        self, ref_mask: Optional[torch.Tensor], device: torch.device
-    ) -> Optional[torch.Tensor]:
-        if ref_mask is None:
-            return None
-        mk = ref_mask.to(device).bool()
-        if mk.ndim == 1:
-            mk = mk.unsqueeze(0)
-        return mk
-
-    def _encode_reference_seq(self, ref_tokens: torch.Tensor) -> torch.Tensor:
-        B, Tr, Q = ref_tokens.shape
-        emb_sum = 0.0
-        for q in range(Q):
-            emb_sum = emb_sum + self.cb_embed.embed_tokens(
-                ref_tokens[:, :, q], cb_index=q
-            )
-        x = emb_sum / float(Q)
-        for b in self.ref_enc_blocks:
-            x = b(x)
-        return self.ref_enc_norm(x)
+        cycle = tuple(int(x) for x in self.cfg.nar_dilation_cycle) or (1,)
+        dils: List[int] = []
+        while len(dils) < int(self.cfg.n_layers_nar):
+            dils.extend(list(cycle))
+        dils = dils[: int(self.cfg.n_layers_nar)]
+        return rf_nar_fn(int(self.cfg.nar_kernel_size), tuple(dils))

-
-
-
-        ref_seq: torch.Tensor,
-        ref_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        ref_seq_p = self._pool_time(ref_seq, 1)
-
-        key_padding_mask = None
-        if ref_mask is not None:
-            mk_bool = ref_mask.bool()
-            B, Tr = mk_bool.shape
-            pooled_len = ref_seq_p.size(1)
-            if pooled_len == Tr:
-                key_padding_mask = ~mk_bool
-            else:
-                cut = pooled_len * 2
-                mk2 = mk_bool[:, :cut].reshape(B, pooled_len, 2).any(dim=2)
-                key_padding_mask = ~mk2
+    @torch.no_grad()
+    def _encode_reference_seq(self, ref_tokens_btq: torch.Tensor) -> torch.Tensor:
+        B, Tr, Q = ref_tokens_btq.shape

-
-
+        w = torch.softmax(self.ref_cb_weights.float(), dim=0).to(
+            device=ref_tokens_btq.device
         )

-
-
-
-
-        return txt_pool[:, None, :] + pos
-
-    def _ar_prev_from_seq(self, seq_1xT: torch.Tensor) -> torch.Tensor:
-        K = int(self.cfg.ar_lookback)
-        if K <= 0 or getattr(self, "ar_hist_w", None) is None:
-            return self.cb_embed.embed_shift_by_k(seq_1xT, cb_index=0, k=1)
-
-        ws = torch.softmax(self.ar_hist_w, dim=0)
-        acc = 0.0
-        k_max = min(K, int(seq_1xT.size(1)))
-        for k in range(1, k_max + 1):
-            acc = acc + ws[k - 1] * self.cb_embed.embed_shift_by_k(
-                seq_1xT, cb_index=0, k=k
-            )
-        return acc
+        x = 0.0
+        for q in range(Q):
+            e = self.cb_embed.embed_tokens(ref_tokens_btq[:, :, q], cb_index=q)
+            x = x + w[q].to(e.dtype) * e

-
-
-        self,
-        text_ids_1d: torch.Tensor,
-        ref_tokens_tq: torch.Tensor,
-        *,
-        max_frames: int,
-        device: torch.device,
-        style_strength: float = 1.0,
-        ref_mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
-    ) -> Dict[str, torch.Tensor]:
-        self.eval()
+        for b in self.ref_enc_blocks:
+            x = b(x)

-
-        text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
-        txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)
+        return self.ref_enc_norm(x)

-
+    @torch.no_grad()
+    def prepare_reference(
+        self, ref_tokens_tq: torch.Tensor, *, device: torch.device
+    ) -> PreparedReference:
+        ref_tokens_btq = ref_tokens_tq.unsqueeze(0).to(device=device, dtype=torch.long)
+        Tr = int(ref_tokens_btq.size(1))

-
-
-        ref_mask_btr = self._normalize_ref_mask(ref_mask, device)
+        lengths = torch.tensor([Tr], device=device, dtype=torch.long)
+        sv_ref = self.token2sv(ref_tokens_btq, lengths=lengths)

-
-        if T <= 0:
-            cond_all = torch.zeros(
-                (1, 0, txt_pool.size(-1)), device=device, dtype=txt_pool.dtype
-            )
-        else:
-            pos = self.frame_pos(torch.arange(T, device=device)).unsqueeze(0)
-            base_all = txt_pool[:, None, :] + pos
-            base_all = self.spk_film(base_all, sv_ref, strength=float(style_strength))
-
-            if chunk_size is None or int(chunk_size) >= T:
-                out = self._single_pass_ref_xattn(
-                    base_all, ref_seq, ref_mask=ref_mask_btr
-                )
-                cond_all = self.cond_norm(out)
-            else:
-                cs = int(chunk_size)
-                chunks: List[torch.Tensor] = []
-                for s in range(0, T, cs):
-                    e = min(T, s + cs)
-                    q = base_all[:, s:e, :]
-                    out = self._single_pass_ref_xattn(q, ref_seq, ref_mask=ref_mask_btr)
-                    chunks.append(self.cond_norm(out))
-                cond_all = torch.cat(chunks, dim=1)
+        ref_seq = self._encode_reference_seq(ref_tokens_btq)

-
-
-
-
-
-
-
-
-            "ref_mask": (
-                ref_mask_btr
-                if ref_mask_btr is not None
-                else torch.empty(0, device=device, dtype=torch.bool)
-            ),
-            "style_strength": torch.tensor(float(style_strength), device=device),
-        }
+        ref_kv_caches = self.ref_xattn.build_kv_caches(ref_seq, key_padding_mask=None)
+
+        return PreparedReference(
+            ref_tokens_btq=ref_tokens_btq,
+            sv_ref=sv_ref,
+            ref_seq=ref_seq,
+            ref_kv_caches=ref_kv_caches,
+        )

     @torch.no_grad()
-    def
+    def prepare_conditioning(
         self,
         text_ids_1d: torch.Tensor,
-
+        ref: PreparedReference,
         *,
         max_frames: int,
         device: torch.device,
-        style_strength: float = 1.
-        ref_mask: Optional[torch.Tensor] = None,
+        style_strength: float = 1.2,
     ) -> Dict[str, torch.Tensor]:
         self.eval()
+        sv_ref = ref.sv_ref.to(device)

         text_ids = text_ids_1d.to(device)
         text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
         txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)

-
-
-
-
+        if sv_ref is not None:
+            if sv_ref.dim() == 1:
+                sv_ref = sv_ref.unsqueeze(0)
+            sv_ref = sv_ref.to(device)
+        else:
+            ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
+            ref_len = torch.tensor(
+                [int(ref_btq.size(1))], device=device, dtype=torch.long
+            )
+            sv_ref = self.token2sv(ref_btq, lengths=ref_len)
+
+        Tar = int(max_frames) + 1
+        pos = self.frame_pos(torch.arange(Tar, device=device)).unsqueeze(0)
+        base_ar = txt_pool[:, None, :] + pos
+        cond_ar = self.spk_film(base_ar, sv_ref, strength=float(style_strength))

-
-
+        cond_ar, _ = self.ref_xattn(
+            cond_ar, kv_caches=ref.ref_kv_caches, use_cache=True
+        )
+        cond_ar = self.cond_norm(cond_ar)

         return {
             "txt_seq": txt_seq,
             "text_mask": text_mask,
-            "cond_all": cond_all,
-            "ref_btq": ref_btq,
             "txt_pool": txt_pool,
             "sv_ref": sv_ref,
-            "
-            "ref_mask": (
-                ref_mask_btr
-                if ref_mask_btr is not None
-                else torch.empty(0, device=device, dtype=torch.bool)
-            ),
-            "style_strength": torch.tensor(float(style_strength), device=device),
-            "max_frames": torch.tensor(int(max_frames), device=device),
+            "cond_ar": cond_ar,
         }

-    @torch.no_grad()
-    def ensure_cond_upto(
-        self,
-        prep: Dict[str, torch.Tensor],
-        t_inclusive: int,
-        *,
-        chunk_size: int = 64,
-    ) -> None:
-        if t_inclusive < 0:
-            return
-
-        cond_all = prep.get("cond_all", None)
-        if cond_all is None:
-            raise KeyError("prep dict missing 'cond_all'.")
-
-        have = int(cond_all.size(1))
-        need_min = int(t_inclusive) + 1
-        if have >= need_min:
-            return
-
-        if "txt_pool" not in prep or "sv_ref" not in prep or "ref_seq" not in prep:
-            raise RuntimeError(
-                "Lazy conditioning requested but prep lacks txt_pool/sv_ref/ref_seq. "
-                "Use prepare_conditioning_lazy() or prepare_conditioning(..., chunk_size=...)."
-            )
-
-        device = cond_all.device
-        txt_pool = prep["txt_pool"]
-        sv_ref = prep["sv_ref"]
-        ref_seq = prep["ref_seq"]
-
-        style_strength = float(
-            prep.get(
-                "style_strength", torch.tensor(self.cfg.style_strength, device=device)
-            ).item()
-        )
-
-        ref_mask = prep.get("ref_mask", None)
-        if ref_mask is not None and ref_mask.numel() == 0:
-            ref_mask = None
-
-        max_frames = prep.get("max_frames", None)
-        maxT = int(max_frames.item()) if max_frames is not None else None
-
-        cs = max(1, int(chunk_size))
-
-        need = ((need_min + cs - 1) // cs) * cs
-        if maxT is not None:
-            need = min(need, maxT)
-
-        if have >= need:
-            return
-
-        new_chunks: List[torch.Tensor] = []
-        for s in range(have, need, cs):
-            e = min(need, s + cs)
-            pos = self.frame_pos(torch.arange(s, e, device=device)).unsqueeze(0)
-            base = txt_pool[:, None, :] + pos
-            base = self.spk_film(base, sv_ref, strength=style_strength)
-            out = self._single_pass_ref_xattn(base, ref_seq, ref_mask=ref_mask)
-            new_chunks.append(self.cond_norm(out))
-
-        prep["cond_all"] = torch.cat([cond_all] + new_chunks, dim=1)
-
-    def build_ar_prefix(
-        self,
-        ref_btq: torch.Tensor,
-        device: torch.device,
-        prefix_sec_fixed: Optional[float],
-        use_prefix: bool,
-    ) -> torch.Tensor:
-        if not use_prefix or ref_btq.size(1) == 0:
-            return torch.zeros(1, 0, dtype=torch.long, device=device)
-
-        avail = int(ref_btq.size(1))
-        fps = float(self.cfg.mimi_fps)
-
-        if prefix_sec_fixed is not None and prefix_sec_fixed > 0:
-            P = min(avail, int(round(prefix_sec_fixed * fps)))
-        else:
-            P = min(avail, max(1, int(round(self.cfg.preprompt_sec_max * fps))))
-
-        if P <= 0:
-            return torch.zeros(1, 0, dtype=torch.long, device=device)
-        return ref_btq[:, :P, 0].contiguous()
-
     @torch.no_grad()
     def ar_stream(
         self,
@@ -505,200 +227,178 @@ class SoproTTSModel(nn.Module):
         loop_streak: int = 8,
         recovery_top_p: float = 0.85,
         recovery_temp: float = 1.2,
-        use_prefix: bool = True,
-        prefix_sec_fixed: Optional[float] = None,
-        cond_chunk_size: Optional[int] = None,
-        use_stop_head: Optional[bool] = None,
-        stop_patience: Optional[int] = None,
-        stop_threshold: Optional[float] = None,
         min_gen_frames: Optional[int] = None,
-    ) -> Iterator[Tuple[int, int,
-        device = prep["
-
+    ) -> Iterator[Tuple[int, int, bool]]:
+        device = prep["cond_ar"].device
+        cond_ar = prep["cond_ar"]
         txt_seq = prep["txt_seq"]
         text_mask = prep["text_mask"]
-        ref_btq = prep["ref_btq"]
-
-        R_AR = self.rf_ar()
-
-        stop_head = self.stop_head
-        if use_stop_head is not None:
-            if not bool(use_stop_head):
-                stop_head = None

-
-
-        )
-        eff_stop_threshold = float(
-            stop_threshold if stop_threshold is not None else self.cfg.stop_threshold
-        )
-        eff_min_gen_frames = int(
+        eos_id = int(self.eos_id)
+        eff_min_gen = int(
             min_gen_frames if min_gen_frames is not None else self.cfg.min_gen_frames
         )

-
-
-        )
-        P = int(A_prefix.size(1))
+        max_steps = int(max_frames) + 1
+        ctx_ids = torch.zeros((1, max_steps), dtype=torch.long, device=device)

-
-
+        ar_state = self.ar.init_stream_state(
+            batch_size=1,
+            device=device,
+            dtype=cond_ar.dtype,
+            text_emb=txt_seq,
+            text_mask=text_mask,
         )
-        if P > 0:
-            ctx_ids[:, :P] = A_prefix

-
+        hist: List[int] = []
         loop_streak_count = 0
-        stop_streak_count = 0
         last_a: Optional[int] = None

-
-
-
-
-
-
-
-            L_ar = min(t + 1, R_AR)
-            s_ar = t + 1 - L_ar
-            cond_win_ar = cond_all[:, s_ar : t + 1, :]
+        if self.cb_embed.bos_id is None:
+            raise RuntimeError(
+                "CodebookEmbedding.use_bos must be True for streaming AR cache"
+            )
+        bos_idx = torch.full(
+            (1, 1), int(self.cb_embed.bos_id), device=device, dtype=torch.long
+        )

-
-
+        for t in range(max_steps):
+            if t == 0:
+                prev_emb = self.cb_embed.emb(bos_idx)
+            else:
+                prev_tok = ctx_ids[:, t - 1 : t]
+                prev_emb = self.cb_embed.embed_tokens(prev_tok, cb_index=0)

-
-            prev_ctx_win = prev_ctx_full[:, -L_ar:, :]
+            x_t = cond_ar[:, t : t + 1, :] + prev_emb

             cur_top_p, cur_temp = top_p, temperature
             if anti_loop:
-                if repeated_tail(
+                if repeated_tail(hist, max_n=16):
                     cur_top_p, cur_temp = recovery_top_p, recovery_temp
                 elif last_a is not None and loop_streak_count >= loop_streak:
                     cur_top_p, cur_temp = recovery_top_p, recovery_temp

-
-
+            logits_t, ar_state = self.ar.step(
+                x_t, ar_state, text_emb=txt_seq, text_mask=text_mask
             )
-
-
-
-                ar_logits_t,
-                history=hist_A,
+            tok = sample_token(
+                logits_t,
+                history=hist,
                 top_p=cur_top_p,
                 temperature=cur_temp,
                 top_k=50,
                 repetition_penalty=1.1,
             )

-            ctx_ids[0,
-
+            ctx_ids[0, t] = int(tok)
+            hist.append(int(tok))

-            hist_A.append(int(rvq1_id))
             loop_streak_count = (
-                (loop_streak_count + 1)
-                if (last_a is not None and rvq1_id == last_a)
-                else 0
+                (loop_streak_count + 1) if (last_a is not None and tok == last_a) else 0
             )
-            last_a = int(
-
-
-
-
-
-                cond_all[:, t : t + 1, :]
-                + self.cb_embed.embed_tokens(A_now, cb_index=0).detach()
-            )
-            stop_logits = stop_head(stop_inp)
-            p_stop = float(torch.sigmoid(stop_logits).item())
-
-            if t + 1 >= eff_min_gen_frames and p_stop > eff_stop_threshold:
-                stop_streak_count += 1
-            else:
-                stop_streak_count = 0
-
-            yield t, int(rvq1_id), p_stop
-
-            if stop_head is not None and stop_streak_count >= eff_stop_patience:
+            last_a = int(tok)
+
+            is_eos = int(tok) == eos_id
+            yield t, int(tok), bool(is_eos)
+
+            if is_eos and (t + 1) >= eff_min_gen:
                 break

     @torch.no_grad()
     def nar_refine(
-        self, cond_seq: torch.Tensor,
+        self, cond_seq: torch.Tensor, rvq1_1xT: torch.Tensor
     ) -> torch.Tensor:
-
-
+        B, T, D = cond_seq.shape
+        Q = int(self.cfg.num_codebooks)
+
+        out_btq = torch.zeros((B, T, Q), device=cond_seq.device, dtype=torch.long)
+        out_btq[:, :, 0] = rvq1_1xT
+
+        prev_tokens_list: List[torch.Tensor] = [rvq1_1xT.unsqueeze(-1)]
         prev_cb_list: List[List[int]] = [[0]]

-        for
-            idxs = self.stage_indices[
+        for stage in self.stage_order:
+            idxs = self.stage_indices[stage]
+            if len(idxs) == 0:
+                continue
+
             prev_tokens_cat = torch.cat(prev_tokens_list, dim=-1)
             prev_cbs_cat = sum(prev_cb_list, [])
-            prev_emb_sum = self.cb_embed.sum_embed_subset(prev_tokens_cat, prev_cbs_cat)

-
-
-
+            prev_emb_sum = self.cb_embed.sum_embed_subset(
+                prev_tokens_cat,
+                prev_cbs_cat,
+                keep_mask=None,
+                cb_weights=self.nar_prev_cb_weights,
+            )

-
-
+            logits_list = self.nar.forward_stage(stage, cond_seq, prev_emb_sum)
+            if len(logits_list) == 0:
+                continue
+
+            preds = torch.stack([lg.argmax(dim=-1) for lg in logits_list], dim=-1)
+
+            for k, cb in enumerate(idxs):
+                out_btq[:, :, cb] = preds[:, :, k]
+
+            prev_tokens_list.append(preds.detach())
             prev_cb_list.append(idxs)

-
-        return tokens_btq
+        return out_btq

     @torch.no_grad()
     def generate_tokens(
         self,
         text_ids_1d: torch.Tensor,
-
+        ref: PreparedReference,
         *,
         max_frames: int,
         device: torch.device,
         top_p: float = 0.9,
         temperature: float = 1.05,
         anti_loop: bool = True,
-
-        prefix_sec_fixed: Optional[float] = None,
-        style_strength: float = 1.0,
-        use_stop_head: Optional[bool] = None,
-        stop_patience: Optional[int] = None,
-        stop_threshold: Optional[float] = None,
+        style_strength: float = 1.2,
         min_gen_frames: Optional[int] = None,
     ) -> torch.Tensor:
         prep = self.prepare_conditioning(
             text_ids_1d,
-
+            ref,
             max_frames=max_frames,
             device=device,
             style_strength=style_strength,
         )

-
-
+        eos_id = int(self.eos_id)
+        hist: List[int] = []
+        for _t, tok, is_eos in self.ar_stream(
             prep,
             max_frames=max_frames,
             top_p=top_p,
             temperature=temperature,
             anti_loop=anti_loop,
-            use_prefix=use_prefix,
-            prefix_sec_fixed=prefix_sec_fixed,
-            use_stop_head=use_stop_head,
-            stop_patience=stop_patience,
-            stop_threshold=stop_threshold,
             min_gen_frames=min_gen_frames,
         ):
-
+            hist.append(tok)
+            if is_eos:
+                break
+
+        Tfull = len(hist)
+        cut = Tfull
+        for i, v in enumerate(hist):
+            if int(v) == eos_id:
+                cut = i
+                break

-        T =
-        if T
+        T = int(cut)
+        if T <= 0:
             return torch.zeros(
-                0, self.cfg.num_codebooks, dtype=torch.long, device=device
+                (0, int(self.cfg.num_codebooks)), dtype=torch.long, device=device
             )

-
-        cond_seq = prep["
-
-        return
+        rvq1 = torch.tensor(hist[:T], device=device, dtype=torch.long).unsqueeze(0)
+        cond_seq = prep["cond_ar"][:, :T, :]
+        tokens_1xTQ = self.nar_refine(cond_seq, rvq1)
+        return tokens_1xTQ.squeeze(0)


 class SoproTTS:
@@ -738,16 +438,14 @@ class SoproTTS:
             raise FileNotFoundError(f"Expected {model_path} in repo snapshot.")

         cfg = load_cfg_from_safetensors(model_path)
-
         tokenizer = TextTokenizer(model_name=local_dir)

         model = SoproTTSModel(cfg, tokenizer).to(dev).eval()
         state = load_state_dict_from_safetensors(model_path)

-        model.load_state_dict(state)
+        model.load_state_dict(state, strict=False)

         codec = MimiCodec(num_quantizers=cfg.num_codebooks, device=device)
-
         return cls(
             model=model, cfg=cfg, tokenizer=tokenizer, codec=codec, device=device
         )
@@ -756,6 +454,26 @@ class SoproTTS:
         ids = self.tokenizer.encode(text)
         return torch.tensor(ids, dtype=torch.long, device=self.device)

+    @torch.inference_mode()
+    def encode_speaker(
+        self,
+        *,
+        ref_audio_path: Optional[str] = None,
+        ref_tokens_tq: Optional[torch.Tensor] = None,
+        ref_seconds: Optional[float] = None,
+    ) -> torch.Tensor:
+        ref = self.encode_reference(
+            ref_audio_path=ref_audio_path,
+            ref_tokens_tq=ref_tokens_tq,
+            ref_seconds=ref_seconds,
+        )
+        ref_btq = ref.unsqueeze(0)
+        lengths = torch.tensor(
+            [int(ref_btq.size(1))], device=self.device, dtype=torch.long
+        )
+        sv = self.model.token2sv(ref_btq, lengths=lengths)
+        return sv.squeeze(0).detach()
+
     def encode_reference(
         self,
         *,
@@ -763,6 +481,8 @@ class SoproTTS:
         ref_tokens_tq: Optional[torch.Tensor] = None,
         ref_seconds: Optional[float] = None,
     ) -> torch.Tensor:
+        from .sampling import center_crop_tokens
+
         if (ref_tokens_tq is None) and (ref_audio_path is None):
             raise RuntimeError(
                 "SoproTTS requires a reference. Provide ref_audio_path=... or ref_tokens_tq=..."
@@ -773,11 +493,11 @@ class SoproTTS:
             )

         if ref_seconds is None:
-            ref_seconds =
+            ref_seconds = 12.0

         if ref_tokens_tq is not None:
             ref = ref_tokens_tq.to(self.device).long()
-            if ref_seconds > 0:
+            if ref_seconds and ref_seconds > 0:
                 fps = float(self.cfg.mimi_fps)
                 win = max(1, int(round(ref_seconds * fps)))
                 ref = center_crop_tokens(ref, win)
@@ -793,51 +513,61 @@ class SoproTTS:
             )
         return ref

-    @torch.
+    @torch.inference_mode()
+    def prepare_reference(
+        self,
+        *,
+        ref_audio_path: Optional[str] = None,
+        ref_tokens_tq: Optional[torch.Tensor] = None,
+        ref_seconds: Optional[float] = None,
+    ) -> PreparedReference:
+        tokens_tq = self.encode_reference(
+            ref_audio_path=ref_audio_path,
+            ref_tokens_tq=ref_tokens_tq,
+            ref_seconds=ref_seconds,
+        )
+        return self.model.prepare_reference(tokens_tq, device=self.device)
+
+    @torch.inference_mode()
     def synthesize(
         self,
         text: str,
         *,
+        ref: Optional[PreparedReference] = None,
         ref_audio_path: Optional[str] = None,
         ref_tokens_tq: Optional[torch.Tensor] = None,
         max_frames: int = 400,
         top_p: float = 0.9,
         temperature: float = 1.05,
         anti_loop: bool = True,
-        use_prefix: bool = True,
-        prefix_sec_fixed: Optional[float] = None,
         style_strength: Optional[float] = None,
         ref_seconds: Optional[float] = None,
-        use_stop_head: Optional[bool] = None,
-        stop_patience: Optional[int] = None,
-        stop_threshold: Optional[float] = None,
         min_gen_frames: Optional[int] = None,
     ) -> torch.Tensor:
         text_ids = self.encode_text(text)
-
-
-
-
-
+
+        if ref is None:
+            ref = self.prepare_reference(
+                ref_audio_path=ref_audio_path,
+                ref_tokens_tq=ref_tokens_tq,
+                ref_seconds=ref_seconds,
+            )
+
+        text_ids = self.encode_text(text)

         tokens_tq = self.model.generate_tokens(
             text_ids,
-            ref,
+            ref=ref,
             max_frames=max_frames,
             device=self.device,
             top_p=top_p,
             temperature=temperature,
             anti_loop=anti_loop,
-            use_prefix=use_prefix,
-            prefix_sec_fixed=prefix_sec_fixed,
             style_strength=float(
                 style_strength
                 if style_strength is not None
                 else self.cfg.style_strength
             ),
-            use_stop_head=use_stop_head,
-            stop_patience=stop_patience,
-            stop_threshold=stop_threshold,
             min_gen_frames=min_gen_frames,
         )
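Taken together, the model.py changes replace the stop-head/prefix control flow with an explicit EOS token (`eos_id = codebook_size`, hence the AR head's vocabulary of `codebook_size + 1`) and introduce a reusable `PreparedReference` carrying the reference tokens, speaker vector, encoded reference sequence, and prebuilt cross-attention KV caches. Below is a minimal usage sketch of the 1.5.0 API, inferred from the new signatures in this diff; the `SoproTTS.from_pretrained` loader name and the repo-id placeholder are assumptions (the diff does not show the constructor's public name), and `synthesize` is presumed to return a decoded waveform tensor.

# Sketch of the 1.5.0 reference-reuse workflow (assumptions noted above).
from sopro.model import SoproTTS

tts = SoproTTS.from_pretrained("<repo-id>", device="cuda")  # loader name assumed

# Encode the reference once: tokenize the audio, compute the speaker vector
# with token2sv, run the reference encoder, and prebuild the RefXAttnStack
# KV caches that prepare_conditioning() later reuses.
ref = tts.prepare_reference(ref_audio_path="speaker.wav", ref_seconds=12.0)

# Reuse the PreparedReference across calls. The 1.0.1 knobs removed in this
# diff (use_prefix, prefix_sec_fixed, use_stop_head, stop_patience,
# stop_threshold) are gone; generation now ends when the AR stream samples
# the EOS token, subject to min_gen_frames.
wav_a = tts.synthesize("First sentence.", ref=ref, max_frames=400)
wav_b = tts.synthesize("Second sentence.", ref=ref, top_p=0.9, temperature=1.05)

# The one-shot path still works: without ref=, synthesize() builds a
# PreparedReference internally from ref_audio_path / ref_tokens_tq.
wav_c = tts.synthesize("Third sentence.", ref_audio_path="speaker.wav")

Streaming consumers should also note that `ar_stream` now yields `(frame_index, rvq1_token, is_eos)` triples instead of `(frame_index, rvq1_token, p_stop)`, so a consuming loop should break on the boolean flag rather than thresholding a stop probability.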