sopro 1.0.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
sopro/model.py CHANGED
@@ -1,11 +1,11 @@
  from __future__ import annotations

  import os
+ from dataclasses import dataclass
  from typing import Dict, Iterator, List, Optional, Tuple

  import torch
  import torch.nn as nn
- import torch.nn.functional as F

  from sopro.config import SoproTTSConfig
  from sopro.hub import (
@@ -18,152 +18,36 @@ from .audio import save_audio
  from .codec.mimi import MimiCodec
  from .constants import TARGET_SR
  from .nn import (
+     ARRVQ1Generator,
      CodebookEmbedding,
-     RefXAttn,
+     RefXAttnStack,
      RMSNorm,
      SinusoidalPositionalEmbedding,
      SpeakerFiLM,
      SSMLiteBlock,
-     TextEmbedding,
-     TextXAttnBlock,
+     TextEncoder,
      Token2SV,
  )
- from .sampling import center_crop_tokens, repeated_tail
+ from .nn.nar import NARSinglePass
+ from .sampling import repeated_tail
  from .sampling import rf_ar as rf_ar_fn
  from .sampling import rf_nar as rf_nar_fn
  from .sampling import sample_token
  from .tokenizer import TextTokenizer


- class TextEncoder(nn.Module):
-     def __init__(
-         self,
-         cfg: SoproTTSConfig,
-         d_model: int,
-         n_layers: int,
-         tokenizer: TextTokenizer,
-     ):
-         super().__init__()
-         self.tok = tokenizer
-         self.embed = TextEmbedding(self.tok.vocab_size, d_model)
-         self.layers = nn.ModuleList(
-             [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
-         )
-         self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
-         self.norm = RMSNorm(d_model)
-
-     def forward(
-         self, text_ids: torch.Tensor, mask: torch.Tensor
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         x = self.embed(text_ids)
-         L = x.size(1)
-         pos = self.pos(torch.arange(L, device=x.device))
-         x = x + pos.unsqueeze(0)
-
-         x = x * mask.unsqueeze(-1).float()
-         for layer in self.layers:
-             x = layer(x)
-         x = self.norm(x)
-
-         mask_f = mask.float().unsqueeze(-1)
-         pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
-         return x, pooled
-
-
- class ARRVQ1Generator(nn.Module):
-     def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
-         super().__init__()
-         ks = cfg.ar_kernel
-         dils: List[int] = []
-         while len(dils) < cfg.n_layers_ar:
-             dils.extend(list(cfg.ar_dilation_cycle))
-         dils = dils[: cfg.n_layers_ar]
-
-         self.dils = tuple(int(d) for d in dils)
-         self.blocks = nn.ModuleList(
-             [
-                 SSMLiteBlock(
-                     d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
-                 )
-                 for d in self.dils
-             ]
-         )
-
-         self.attn_freq = int(cfg.ar_text_attn_freq)
-         self.x_attns = nn.ModuleList()
-         for i in range(len(self.blocks)):
-             if (i + 1) % self.attn_freq == 0:
-                 self.x_attns.append(
-                     TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
-                 )
-             else:
-                 self.x_attns.append(nn.Identity())
-
-         self.norm = RMSNorm(d_model)
-         self.head = nn.Linear(d_model, vocab)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         text_emb: Optional[torch.Tensor] = None,
-         text_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         key_padding_mask = ~text_mask if text_mask is not None else None
-
-         if key_padding_mask is not None:
-             bad_rows = key_padding_mask.all(dim=1)
-             if bad_rows.any():
-                 key_padding_mask = key_padding_mask.clone()
-                 idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
-                 key_padding_mask[idx, 0] = False
-                 if text_emb is not None:
-                     text_emb = text_emb.clone()
-                     text_emb[idx, 0, :] = 0
-
-         h = x
-         for i, lyr in enumerate(self.blocks):
-             h = lyr(h)
-             if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
-                 h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)
-
-         h = self.norm(h)
-         return self.head(h)
-
-
- class StageRefiner(nn.Module):
-     def __init__(self, cfg: SoproTTSConfig, D: int, num_heads: int, codebook_size: int):
-         super().__init__()
-         self.blocks = nn.ModuleList(
-             [SSMLiteBlock(D, cfg.dropout, causal=True) for _ in range(cfg.n_layers_nar)]
-         )
-         self.norm = RMSNorm(D)
-         self.pre = nn.Linear(D, cfg.nar_head_dim)
-         self.heads = nn.ModuleList(
-             [nn.Linear(cfg.nar_head_dim, codebook_size) for _ in range(num_heads)]
-         )
-         self.mix = nn.Parameter(torch.ones(2, dtype=torch.float32))
-
-     def forward_hidden(
-         self, cond_bt_d: torch.Tensor, prev_bt_d: torch.Tensor
-     ) -> torch.Tensor:
-         w = torch.softmax(self.mix, dim=0)
-         x = w[0] * cond_bt_d + w[1] * prev_bt_d
-         for b in self.blocks:
-             x = b(x)
-         return self.norm(x)
-
-     def forward_heads(self, h: torch.Tensor) -> List[torch.Tensor]:
-         z = self.pre(h)
-         return [head(z) for head in self.heads]
-
+ def _stage_range_to_indices(stage_rng: Tuple[int, int], Q: int) -> List[int]:
+     lo, hi = int(stage_rng[0]), int(stage_rng[1])
+     idxs = list(range(lo - 1, hi))
+     return [i for i in idxs if 1 <= i < Q]

- class StopHead(nn.Module):
-     def __init__(self, D: int):
-         super().__init__()
-         self.proj = nn.Linear(D, 1)

-     def forward(self, h: torch.Tensor) -> torch.Tensor:
-         return self.proj(h).squeeze(-1)
+ @dataclass
+ class PreparedReference:
+     ref_tokens_btq: torch.Tensor
+     sv_ref: torch.Tensor
+     ref_seq: torch.Tensor
+     ref_kv_caches: List[Dict[str, torch.Tensor]]


  class SoproTTSModel(nn.Module):
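The classes removed above have not vanished: TextEncoder and ARRVQ1Generator now come from sopro.nn, the per-stage refiners are folded into the NARSinglePass module, and StopHead gives way to an explicit EOS token (see the hunks below). The one piece of logic kept in model.py is the stage-range helper. A quick illustration of what it computes, with made-up stage ranges (real values come from SoproTTSConfig): ranges are 1-based and inclusive, and codebook 0 is always excluded because the AR model owns it.

    # Hypothetical stage ranges; Q is the total number of codebooks.
    assert _stage_range_to_indices((2, 3), Q=8) == [1, 2]
    assert _stage_range_to_indices((4, 8), Q=8) == [3, 4, 5, 6, 7]
    assert _stage_range_to_indices((1, 2), Q=8) == [1]  # index 0 filtered out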
@@ -172,327 +56,165 @@ class SoproTTSModel(nn.Module):
          self.cfg = cfg
          D = int(cfg.d_model)

-         self.text_enc = TextEncoder(cfg, D, cfg.n_layers_text, tokenizer)
-         self.frame_pos = SinusoidalPositionalEmbedding(D, max_len=cfg.pos_emb_max + 2)
+         self.eos_id = int(cfg.codebook_size)
+
+         self.text_enc = TextEncoder(cfg, D, int(cfg.n_layers_text), tokenizer)
+         self.frame_pos = SinusoidalPositionalEmbedding(
+             D, max_len=int(cfg.pos_emb_max) + 8
+         )

          self.cb_embed = CodebookEmbedding(
              cfg.num_codebooks, cfg.codebook_size, D, use_bos=True
          )
-         self.rvq1_bos_id = self.cb_embed.bos_id
+
+         self.nar_prev_cb_weights = nn.Parameter(
+             torch.zeros(cfg.num_codebooks, dtype=torch.float32)
+         )

          self.token2sv = Token2SV(
              cfg.num_codebooks,
              cfg.codebook_size,
              d=192,
-             out_dim=cfg.sv_student_dim,
+             out_dim=int(cfg.sv_student_dim),
              dropout=cfg.dropout,
          )
-         self.spk_film = SpeakerFiLM(D, sv_dim=cfg.sv_student_dim)
-         self.cond_norm = RMSNorm(D)
+         self.spk_film = SpeakerFiLM(D, sv_dim=int(cfg.sv_student_dim))

-         self.ar = ARRVQ1Generator(cfg, D, cfg.codebook_size)
-         if cfg.ar_lookback > 0:
-             self.ar_hist_w = nn.Parameter(torch.zeros(cfg.ar_lookback))
-
-         def idxs(rng: Tuple[int, int]) -> List[int]:
-             lo, hi = rng
-             return list(range(lo - 1, hi))
+         self.ar = ARRVQ1Generator(cfg, D, int(cfg.codebook_size) + 1)

+         Q = int(cfg.num_codebooks)
          self.stage_indices: Dict[str, List[int]] = {
-             "B": idxs(cfg.stage_B),
-             "C": idxs(cfg.stage_C),
-             "D": idxs(cfg.stage_D),
-             "E": idxs(cfg.stage_E),
+             "B": _stage_range_to_indices(cfg.stage_B, Q),
+             "C": _stage_range_to_indices(cfg.stage_C, Q),
+             "D": _stage_range_to_indices(cfg.stage_D, Q),
+             "E": _stage_range_to_indices(cfg.stage_E, Q),
          }
-         self.stages = nn.ModuleDict(
-             {
-                 s: StageRefiner(cfg, D, len(self.stage_indices[s]), cfg.codebook_size)
-                 for s in ["B", "C", "D", "E"]
-             }
-         )
+         self.stage_order = [
+             s for s in ["B", "C", "D", "E"] if len(self.stage_indices[s]) > 0
+         ]
+
+         self.nar = NARSinglePass(cfg, D, stage_specs=self.stage_indices)

-         self.stop_head = StopHead(D) if cfg.use_stop_head else None
+         self.cond_norm = RMSNorm(D)

+         ref_enc_layers = int(getattr(cfg, "ref_enc_layers", 2))
          self.ref_enc_blocks = nn.ModuleList(
-             [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(2)]
+             [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(ref_enc_layers)]
          )
          self.ref_enc_norm = RMSNorm(D)
-         self.ref_xattn_stack = RefXAttn(
-             D, heads=cfg.ref_attn_heads, layers=3, dropout=cfg.dropout
+
+         self.ref_xattn = RefXAttnStack(
+             D,
+             heads=cfg.ref_xattn_heads,
+             layers=cfg.ref_xattn_layers,
+             gmax=cfg.ref_xattn_gmax,
+         )
+
+         self.register_buffer(
+             "ref_cb_weights",
+             torch.linspace(1.0, 0.1, int(cfg.num_codebooks)),
+             persistent=True,
          )

      def rf_ar(self) -> int:
-         return rf_ar_fn(self.cfg.ar_kernel, self.ar.dils)
+         return rf_ar_fn(
+             int(self.cfg.ar_kernel),
+             getattr(self.ar, "dils", tuple(int(x) for x in self.cfg.ar_dilation_cycle)),
+         )

      def rf_nar(self) -> int:
-         return rf_nar_fn(self.cfg.n_layers_nar, kernel_size=7, dilation=1)
-
-     def _pool_time(self, x: torch.Tensor, factor: int) -> torch.Tensor:
-         if factor <= 1 or x.size(1) < 2 * factor:
-             return x
-         return F.avg_pool1d(
-             x.transpose(1, 2), kernel_size=factor, stride=factor
-         ).transpose(1, 2)
-
-     def _normalize_ref_mask(
-         self, ref_mask: Optional[torch.Tensor], device: torch.device
-     ) -> Optional[torch.Tensor]:
-         if ref_mask is None:
-             return None
-         mk = ref_mask.to(device).bool()
-         if mk.ndim == 1:
-             mk = mk.unsqueeze(0)
-         return mk
-
-     def _encode_reference_seq(self, ref_tokens: torch.Tensor) -> torch.Tensor:
-         B, Tr, Q = ref_tokens.shape
-         emb_sum = 0.0
-         for q in range(Q):
-             emb_sum = emb_sum + self.cb_embed.embed_tokens(
-                 ref_tokens[:, :, q], cb_index=q
-             )
-         x = emb_sum / float(Q)
-         for b in self.ref_enc_blocks:
-             x = b(x)
-         return self.ref_enc_norm(x)
+         cycle = tuple(int(x) for x in self.cfg.nar_dilation_cycle) or (1,)
+         dils: List[int] = []
+         while len(dils) < int(self.cfg.n_layers_nar):
+             dils.extend(list(cycle))
+         dils = dils[: int(self.cfg.n_layers_nar)]
+         return rf_nar_fn(int(self.cfg.nar_kernel_size), tuple(dils))

-     def _single_pass_ref_xattn(
-         self,
-         cond_btd: torch.Tensor,
-         ref_seq: torch.Tensor,
-         ref_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         ref_seq_p = self._pool_time(ref_seq, 1)
-
-         key_padding_mask = None
-         if ref_mask is not None:
-             mk_bool = ref_mask.bool()
-             B, Tr = mk_bool.shape
-             pooled_len = ref_seq_p.size(1)
-             if pooled_len == Tr:
-                 key_padding_mask = ~mk_bool
-             else:
-                 cut = pooled_len * 2
-                 mk2 = mk_bool[:, :cut].reshape(B, pooled_len, 2).any(dim=2)
-                 key_padding_mask = ~mk2
+     @torch.no_grad()
+     def _encode_reference_seq(self, ref_tokens_btq: torch.Tensor) -> torch.Tensor:
+         B, Tr, Q = ref_tokens_btq.shape

-         return self.ref_xattn_stack(
-             cond_btd, ref_seq_p, key_padding_mask=key_padding_mask
+         w = torch.softmax(self.ref_cb_weights.float(), dim=0).to(
+             device=ref_tokens_btq.device
          )

-     def _base_cond_at(
-         self, t: int, txt_pool: torch.Tensor, device: torch.device
-     ) -> torch.Tensor:
-         pos = self.frame_pos(torch.tensor([t], device=device)).unsqueeze(0)
-         return txt_pool[:, None, :] + pos
-
-     def _ar_prev_from_seq(self, seq_1xT: torch.Tensor) -> torch.Tensor:
-         K = int(self.cfg.ar_lookback)
-         if K <= 0 or getattr(self, "ar_hist_w", None) is None:
-             return self.cb_embed.embed_shift_by_k(seq_1xT, cb_index=0, k=1)
-
-         ws = torch.softmax(self.ar_hist_w, dim=0)
-         acc = 0.0
-         k_max = min(K, int(seq_1xT.size(1)))
-         for k in range(1, k_max + 1):
-             acc = acc + ws[k - 1] * self.cb_embed.embed_shift_by_k(
-                 seq_1xT, cb_index=0, k=k
-             )
-         return acc
+         x = 0.0
+         for q in range(Q):
+             e = self.cb_embed.embed_tokens(ref_tokens_btq[:, :, q], cb_index=q)
+             x = x + w[q].to(e.dtype) * e

-     @torch.no_grad()
-     def prepare_conditioning(
-         self,
-         text_ids_1d: torch.Tensor,
-         ref_tokens_tq: torch.Tensor,
-         *,
-         max_frames: int,
-         device: torch.device,
-         style_strength: float = 1.0,
-         ref_mask: Optional[torch.Tensor] = None,
-         chunk_size: Optional[int] = None,
-     ) -> Dict[str, torch.Tensor]:
-         self.eval()
+         for b in self.ref_enc_blocks:
+             x = b(x)

-         text_ids = text_ids_1d.to(device)
-         text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
-         txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)
+         return self.ref_enc_norm(x)

-         ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
+     @torch.no_grad()
+     def prepare_reference(
+         self, ref_tokens_tq: torch.Tensor, *, device: torch.device
+     ) -> PreparedReference:
+         ref_tokens_btq = ref_tokens_tq.unsqueeze(0).to(device=device, dtype=torch.long)
+         Tr = int(ref_tokens_btq.size(1))

-         sv_ref = self.token2sv(ref_btq, lengths=None)
-         ref_seq = self._encode_reference_seq(ref_btq)
-         ref_mask_btr = self._normalize_ref_mask(ref_mask, device)
+         lengths = torch.tensor([Tr], device=device, dtype=torch.long)
+         sv_ref = self.token2sv(ref_tokens_btq, lengths=lengths)

-         T = int(max_frames)
-         if T <= 0:
-             cond_all = torch.zeros(
-                 (1, 0, txt_pool.size(-1)), device=device, dtype=txt_pool.dtype
-             )
-         else:
-             pos = self.frame_pos(torch.arange(T, device=device)).unsqueeze(0)
-             base_all = txt_pool[:, None, :] + pos
-             base_all = self.spk_film(base_all, sv_ref, strength=float(style_strength))
-
-             if chunk_size is None or int(chunk_size) >= T:
-                 out = self._single_pass_ref_xattn(
-                     base_all, ref_seq, ref_mask=ref_mask_btr
-                 )
-                 cond_all = self.cond_norm(out)
-             else:
-                 cs = int(chunk_size)
-                 chunks: List[torch.Tensor] = []
-                 for s in range(0, T, cs):
-                     e = min(T, s + cs)
-                     q = base_all[:, s:e, :]
-                     out = self._single_pass_ref_xattn(q, ref_seq, ref_mask=ref_mask_btr)
-                     chunks.append(self.cond_norm(out))
-                 cond_all = torch.cat(chunks, dim=1)
+         ref_seq = self._encode_reference_seq(ref_tokens_btq)

-         return {
-             "txt_seq": txt_seq,
-             "text_mask": text_mask,
-             "cond_all": cond_all,
-             "ref_btq": ref_btq,
-             "txt_pool": txt_pool,
-             "sv_ref": sv_ref,
-             "ref_seq": ref_seq,
-             "ref_mask": (
-                 ref_mask_btr
-                 if ref_mask_btr is not None
-                 else torch.empty(0, device=device, dtype=torch.bool)
-             ),
-             "style_strength": torch.tensor(float(style_strength), device=device),
-         }
+         ref_kv_caches = self.ref_xattn.build_kv_caches(ref_seq, key_padding_mask=None)
+
+         return PreparedReference(
+             ref_tokens_btq=ref_tokens_btq,
+             sv_ref=sv_ref,
+             ref_seq=ref_seq,
+             ref_kv_caches=ref_kv_caches,
+         )

      @torch.no_grad()
-     def prepare_conditioning_lazy(
+     def prepare_conditioning(
          self,
          text_ids_1d: torch.Tensor,
-         ref_tokens_tq: torch.Tensor,
+         ref: PreparedReference,
          *,
          max_frames: int,
          device: torch.device,
-         style_strength: float = 1.0,
-         ref_mask: Optional[torch.Tensor] = None,
+         style_strength: float = 1.2,
      ) -> Dict[str, torch.Tensor]:
          self.eval()
+         sv_ref = ref.sv_ref.to(device)

          text_ids = text_ids_1d.to(device)
          text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
          txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)

-         ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
-         sv_ref = self.token2sv(ref_btq, lengths=None)
-         ref_seq = self._encode_reference_seq(ref_btq)
-         ref_mask_btr = self._normalize_ref_mask(ref_mask, device)
+         if sv_ref is not None:
+             if sv_ref.dim() == 1:
+                 sv_ref = sv_ref.unsqueeze(0)
+             sv_ref = sv_ref.to(device)
+         else:
+             ref_btq = ref.ref_tokens_btq.to(device)
+             ref_len = torch.tensor(
+                 [int(ref_btq.size(1))], device=device, dtype=torch.long
+             )
+             sv_ref = self.token2sv(ref_btq, lengths=ref_len)
+
+         Tar = int(max_frames) + 1
+         pos = self.frame_pos(torch.arange(Tar, device=device)).unsqueeze(0)
+         base_ar = txt_pool[:, None, :] + pos
+         cond_ar = self.spk_film(base_ar, sv_ref, strength=float(style_strength))

-         D = int(txt_pool.size(-1))
-         cond_all = torch.zeros((1, 0, D), device=device, dtype=txt_pool.dtype)
+         cond_ar, _ = self.ref_xattn(
+             cond_ar, kv_caches=ref.ref_kv_caches, use_cache=True
+         )
+         cond_ar = self.cond_norm(cond_ar)

          return {
              "txt_seq": txt_seq,
              "text_mask": text_mask,
-             "cond_all": cond_all,
-             "ref_btq": ref_btq,
              "txt_pool": txt_pool,
              "sv_ref": sv_ref,
-             "ref_seq": ref_seq,
-             "ref_mask": (
-                 ref_mask_btr
-                 if ref_mask_btr is not None
-                 else torch.empty(0, device=device, dtype=torch.bool)
-             ),
-             "style_strength": torch.tensor(float(style_strength), device=device),
-             "max_frames": torch.tensor(int(max_frames), device=device),
+             "cond_ar": cond_ar,
          }

-     @torch.no_grad()
-     def ensure_cond_upto(
-         self,
-         prep: Dict[str, torch.Tensor],
-         t_inclusive: int,
-         *,
-         chunk_size: int = 64,
-     ) -> None:
-         if t_inclusive < 0:
-             return
-
-         cond_all = prep.get("cond_all", None)
-         if cond_all is None:
-             raise KeyError("prep dict missing 'cond_all'.")
-
-         have = int(cond_all.size(1))
-         need_min = int(t_inclusive) + 1
-         if have >= need_min:
-             return
-
-         if "txt_pool" not in prep or "sv_ref" not in prep or "ref_seq" not in prep:
-             raise RuntimeError(
-                 "Lazy conditioning requested but prep lacks txt_pool/sv_ref/ref_seq. "
-                 "Use prepare_conditioning_lazy() or prepare_conditioning(..., chunk_size=...)."
-             )
-
-         device = cond_all.device
-         txt_pool = prep["txt_pool"]
-         sv_ref = prep["sv_ref"]
-         ref_seq = prep["ref_seq"]
-
-         style_strength = float(
-             prep.get(
-                 "style_strength", torch.tensor(self.cfg.style_strength, device=device)
-             ).item()
-         )
-
-         ref_mask = prep.get("ref_mask", None)
-         if ref_mask is not None and ref_mask.numel() == 0:
-             ref_mask = None
-
-         max_frames = prep.get("max_frames", None)
-         maxT = int(max_frames.item()) if max_frames is not None else None
-
-         cs = max(1, int(chunk_size))
-
-         need = ((need_min + cs - 1) // cs) * cs
-         if maxT is not None:
-             need = min(need, maxT)
-
-         if have >= need:
-             return
-
-         new_chunks: List[torch.Tensor] = []
-         for s in range(have, need, cs):
-             e = min(need, s + cs)
-             pos = self.frame_pos(torch.arange(s, e, device=device)).unsqueeze(0)
-             base = txt_pool[:, None, :] + pos
-             base = self.spk_film(base, sv_ref, strength=style_strength)
-             out = self._single_pass_ref_xattn(base, ref_seq, ref_mask=ref_mask)
-             new_chunks.append(self.cond_norm(out))
-
-         prep["cond_all"] = torch.cat([cond_all] + new_chunks, dim=1)
-
-     def build_ar_prefix(
-         self,
-         ref_btq: torch.Tensor,
-         device: torch.device,
-         prefix_sec_fixed: Optional[float],
-         use_prefix: bool,
-     ) -> torch.Tensor:
-         if not use_prefix or ref_btq.size(1) == 0:
-             return torch.zeros(1, 0, dtype=torch.long, device=device)
-
-         avail = int(ref_btq.size(1))
-         fps = float(self.cfg.mimi_fps)
-
-         if prefix_sec_fixed is not None and prefix_sec_fixed > 0:
-             P = min(avail, int(round(prefix_sec_fixed * fps)))
-         else:
-             P = min(avail, max(1, int(round(self.cfg.preprompt_sec_max * fps))))
-
-         if P <= 0:
-             return torch.zeros(1, 0, dtype=torch.long, device=device)
-         return ref_btq[:, :P, 0].contiguous()
-
      @torch.no_grad()
      def ar_stream(
          self,
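The net effect of this hunk is a two-phase conditioning API: everything derived from the reference clip (speaker vector, encoded reference sequence, cross-attention KV caches) is computed once in prepare_reference, and prepare_conditioning then builds the full AR conditioning tensor in a single pass, replacing the old lazy ensure_cond_upto machinery. A minimal sketch of the intended call order, assuming a loaded SoproTTSModel named model and tensors of the shapes the code expects:

    # ref_tokens_tq: (T_ref, Q) Mimi codes; text_ids: (T_text,) token ids.
    ref = model.prepare_reference(ref_tokens_tq, device=device)  # once per clip
    prep = model.prepare_conditioning(
        text_ids, ref, max_frames=400, device=device, style_strength=1.2
    )
    # prep["cond_ar"]: (1, max_frames + 1, d_model) -- one slot per AR step,
    # plus one extra so the model can emit EOS after the last frame.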
@@ -505,200 +227,178 @@ class SoproTTSModel(nn.Module):
          loop_streak: int = 8,
          recovery_top_p: float = 0.85,
          recovery_temp: float = 1.2,
-         use_prefix: bool = True,
-         prefix_sec_fixed: Optional[float] = None,
-         cond_chunk_size: Optional[int] = None,
-         use_stop_head: Optional[bool] = None,
-         stop_patience: Optional[int] = None,
-         stop_threshold: Optional[float] = None,
          min_gen_frames: Optional[int] = None,
-     ) -> Iterator[Tuple[int, int, Optional[float]]]:
-         device = prep["cond_all"].device
-         cond_all = prep["cond_all"]
+     ) -> Iterator[Tuple[int, int, bool]]:
+         device = prep["cond_ar"].device
+         cond_ar = prep["cond_ar"]
          txt_seq = prep["txt_seq"]
          text_mask = prep["text_mask"]
-         ref_btq = prep["ref_btq"]
-
-         R_AR = self.rf_ar()
-
-         stop_head = self.stop_head
-         if use_stop_head is not None:
-             if not bool(use_stop_head):
-                 stop_head = None

-         eff_stop_patience = int(
-             stop_patience if stop_patience is not None else self.cfg.stop_patience
-         )
-         eff_stop_threshold = float(
-             stop_threshold if stop_threshold is not None else self.cfg.stop_threshold
-         )
-         eff_min_gen_frames = int(
+         eos_id = int(self.eos_id)
+         eff_min_gen = int(
              min_gen_frames if min_gen_frames is not None else self.cfg.min_gen_frames
          )

-         A_prefix = self.build_ar_prefix(
-             ref_btq, device, prefix_sec_fixed, use_prefix=use_prefix
-         )
-         P = int(A_prefix.size(1))
+         max_steps = int(max_frames) + 1
+         ctx_ids = torch.zeros((1, max_steps), dtype=torch.long, device=device)

-         ctx_ids = torch.zeros(
-             (1, P + int(max_frames) + 1), dtype=torch.long, device=device
+         ar_state = self.ar.init_stream_state(
+             batch_size=1,
+             device=device,
+             dtype=cond_ar.dtype,
+             text_emb=txt_seq,
+             text_mask=text_mask,
          )
-         if P > 0:
-             ctx_ids[:, :P] = A_prefix

-         hist_A: List[int] = []
+         hist: List[int] = []
          loop_streak_count = 0
-         stop_streak_count = 0
          last_a: Optional[int] = None

-         gen_len = 0
-
-         for t in range(int(max_frames)):
-             if prep["cond_all"].size(1) < (t + 1):
-                 self.ensure_cond_upto(prep, t, chunk_size=int(cond_chunk_size or 64))
-                 cond_all = prep["cond_all"]
-
-             L_ar = min(t + 1, R_AR)
-             s_ar = t + 1 - L_ar
-             cond_win_ar = cond_all[:, s_ar : t + 1, :]
+         if self.cb_embed.bos_id is None:
+             raise RuntimeError(
+                 "CodebookEmbedding.use_bos must be True for streaming AR cache"
+             )
+         bos_idx = torch.full(
+             (1, 1), int(self.cb_embed.bos_id), device=device, dtype=torch.long
+         )

-             total_len = P + gen_len + 1
-             A_ctx_full = ctx_ids[:, :total_len]
+         for t in range(max_steps):
+             if t == 0:
+                 prev_emb = self.cb_embed.emb(bos_idx)
+             else:
+                 prev_tok = ctx_ids[:, t - 1 : t]
+                 prev_emb = self.cb_embed.embed_tokens(prev_tok, cb_index=0)

-             prev_ctx_full = self._ar_prev_from_seq(A_ctx_full)
-             prev_ctx_win = prev_ctx_full[:, -L_ar:, :]
+             x_t = cond_ar[:, t : t + 1, :] + prev_emb

              cur_top_p, cur_temp = top_p, temperature
              if anti_loop:
-                 if repeated_tail(hist_A, max_n=16):
+                 if repeated_tail(hist, max_n=16):
                      cur_top_p, cur_temp = recovery_top_p, recovery_temp
                  elif last_a is not None and loop_streak_count >= loop_streak:
                      cur_top_p, cur_temp = recovery_top_p, recovery_temp

-             ar_logits_win = self.ar(
-                 cond_win_ar + prev_ctx_win, text_emb=txt_seq, text_mask=text_mask
+             logits_t, ar_state = self.ar.step(
+                 x_t, ar_state, text_emb=txt_seq, text_mask=text_mask
              )
-             ar_logits_t = ar_logits_win[:, -1:, :]
-
-             rvq1_id = sample_token(
-                 ar_logits_t,
-                 history=hist_A,
+             tok = sample_token(
+                 logits_t,
+                 history=hist,
                  top_p=cur_top_p,
                  temperature=cur_temp,
                  top_k=50,
                  repetition_penalty=1.1,
              )

-             ctx_ids[0, P + gen_len] = int(rvq1_id)
-             gen_len += 1
+             ctx_ids[0, t] = int(tok)
+             hist.append(int(tok))

-             hist_A.append(int(rvq1_id))
              loop_streak_count = (
-                 (loop_streak_count + 1)
-                 if (last_a is not None and rvq1_id == last_a)
-                 else 0
+                 (loop_streak_count + 1) if (last_a is not None and tok == last_a) else 0
              )
-             last_a = int(rvq1_id)
-
-             p_stop: Optional[float] = None
-             if stop_head is not None:
-                 A_now = torch.tensor([[rvq1_id]], device=device, dtype=torch.long)
-                 stop_inp = (
-                     cond_all[:, t : t + 1, :]
-                     + self.cb_embed.embed_tokens(A_now, cb_index=0).detach()
-                 )
-                 stop_logits = stop_head(stop_inp)
-                 p_stop = float(torch.sigmoid(stop_logits).item())
-
-                 if t + 1 >= eff_min_gen_frames and p_stop > eff_stop_threshold:
-                     stop_streak_count += 1
-                 else:
-                     stop_streak_count = 0
-
-             yield t, int(rvq1_id), p_stop
-
-             if stop_head is not None and stop_streak_count >= eff_stop_patience:
+             last_a = int(tok)
+
+             is_eos = int(tok) == eos_id
+             yield t, int(tok), bool(is_eos)
+
+             if is_eos and (t + 1) >= eff_min_gen:
                  break

      @torch.no_grad()
      def nar_refine(
-         self, cond_seq: torch.Tensor, tokens_A_1xT: torch.Tensor
+         self, cond_seq: torch.Tensor, rvq1_1xT: torch.Tensor
      ) -> torch.Tensor:
-         preds_all: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
-         prev_tokens_list: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
+         B, T, D = cond_seq.shape
+         Q = int(self.cfg.num_codebooks)
+
+         out_btq = torch.zeros((B, T, Q), device=cond_seq.device, dtype=torch.long)
+         out_btq[:, :, 0] = rvq1_1xT
+
+         prev_tokens_list: List[torch.Tensor] = [rvq1_1xT.unsqueeze(-1)]
          prev_cb_list: List[List[int]] = [[0]]

-         for stage_name in ["B", "C", "D", "E"]:
-             idxs = self.stage_indices[stage_name]
+         for stage in self.stage_order:
+             idxs = self.stage_indices[stage]
+             if len(idxs) == 0:
+                 continue
+
              prev_tokens_cat = torch.cat(prev_tokens_list, dim=-1)
              prev_cbs_cat = sum(prev_cb_list, [])
-             prev_emb_sum = self.cb_embed.sum_embed_subset(prev_tokens_cat, prev_cbs_cat)

-             h = self.stages[stage_name].forward_hidden(cond_seq, prev_emb_sum)
-             logits_list = self.stages[stage_name].forward_heads(h)
-             preds = torch.stack([x.argmax(dim=-1) for x in logits_list], dim=-1)
+             prev_emb_sum = self.cb_embed.sum_embed_subset(
+                 prev_tokens_cat,
+                 prev_cbs_cat,
+                 keep_mask=None,
+                 cb_weights=self.nar_prev_cb_weights,
+             )
+
+             logits_list = self.nar.forward_stage(stage, cond_seq, prev_emb_sum)
+             if len(logits_list) == 0:
+                 continue
+
+             preds = torch.stack([lg.argmax(dim=-1) for lg in logits_list], dim=-1)
+
+             for k, cb in enumerate(idxs):
+                 out_btq[:, :, cb] = preds[:, :, k]

-             preds_all.append(preds)
-             prev_tokens_list.append(preds)
+             prev_tokens_list.append(preds.detach())
              prev_cb_list.append(idxs)

-         tokens_btq = torch.cat(preds_all, dim=-1)
-         return tokens_btq
+         return out_btq

      @torch.no_grad()
      def generate_tokens(
          self,
          text_ids_1d: torch.Tensor,
-         ref_tokens_tq: torch.Tensor,
+         ref: PreparedReference,
          *,
          max_frames: int,
          device: torch.device,
          top_p: float = 0.9,
          temperature: float = 1.05,
          anti_loop: bool = True,
-         use_prefix: bool = True,
-         prefix_sec_fixed: Optional[float] = None,
-         style_strength: float = 1.0,
-         use_stop_head: Optional[bool] = None,
-         stop_patience: Optional[int] = None,
-         stop_threshold: Optional[float] = None,
+         style_strength: float = 1.2,
          min_gen_frames: Optional[int] = None,
      ) -> torch.Tensor:
          prep = self.prepare_conditioning(
              text_ids_1d,
-             ref_tokens_tq,
+             ref,
              max_frames=max_frames,
              device=device,
              style_strength=style_strength,
          )

-         hist_A: List[int] = []
-         for _t, rvq1, _p_stop in self.ar_stream(
+         eos_id = int(self.eos_id)
+         hist: List[int] = []
+         for _t, tok, is_eos in self.ar_stream(
              prep,
              max_frames=max_frames,
              top_p=top_p,
              temperature=temperature,
              anti_loop=anti_loop,
-             use_prefix=use_prefix,
-             prefix_sec_fixed=prefix_sec_fixed,
-             use_stop_head=use_stop_head,
-             stop_patience=stop_patience,
-             stop_threshold=stop_threshold,
              min_gen_frames=min_gen_frames,
          ):
-             hist_A.append(rvq1)
+             hist.append(tok)
+             if is_eos:
+                 break
+
+         Tfull = len(hist)
+         cut = Tfull
+         for i, v in enumerate(hist):
+             if int(v) == eos_id:
+                 cut = i
+                 break

-         T = len(hist_A)
-         if T == 0:
+         T = int(cut)
+         if T <= 0:
              return torch.zeros(
-                 0, self.cfg.num_codebooks, dtype=torch.long, device=device
+                 (0, int(self.cfg.num_codebooks)), dtype=torch.long, device=device
              )

-         tokens_A = torch.tensor(hist_A, device=device, dtype=torch.long).unsqueeze(0)
-         cond_seq = prep["cond_all"][:, :T, :]
-         tokens_btq_1xTQ = self.nar_refine(cond_seq, tokens_A)
-         return tokens_btq_1xTQ.squeeze(0)
+         rvq1 = torch.tensor(hist[:T], device=device, dtype=torch.long).unsqueeze(0)
+         cond_seq = prep["cond_ar"][:, :T, :]
+         tokens_1xTQ = self.nar_refine(cond_seq, rvq1)
+         return tokens_1xTQ.squeeze(0)


  class SoproTTS:
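Stopping has changed shape as well: the sigmoid StopHead with its patience/threshold knobs is gone, the AR head now emits codebook_size + 1 logits, and self.eos_id == codebook_size serves as an end-of-sequence token, so ar_stream yields (t, token, is_eos) and generate_tokens trims at the first EOS. The trimming loop above is equivalent to this small helper (2048 stands in for codebook_size; the real value is config-dependent):

    def trim_at_eos(hist, eos_id):
        # Keep tokens strictly before the first EOS id.
        for i, tok in enumerate(hist):
            if int(tok) == eos_id:
                return hist[:i]
        return hist

    assert trim_at_eos([5, 9, 2048, 7], 2048) == [5, 9]
    assert trim_at_eos([5, 9, 7], 2048) == [5, 9, 7]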
@@ -738,16 +438,14 @@ class SoproTTS:
              raise FileNotFoundError(f"Expected {model_path} in repo snapshot.")

          cfg = load_cfg_from_safetensors(model_path)
-
          tokenizer = TextTokenizer(model_name=local_dir)

          model = SoproTTSModel(cfg, tokenizer).to(dev).eval()
          state = load_state_dict_from_safetensors(model_path)

-         model.load_state_dict(state)
+         model.load_state_dict(state, strict=False)

          codec = MimiCodec(num_quantizers=cfg.num_codebooks, device=device)
-
          return cls(
              model=model, cfg=cfg, tokenizer=tokenizer, codec=codec, device=device
          )
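Loading with strict=False lets 1.5.0 checkpoints coexist with slightly older or newer module layouts (for example, the new ref_cb_weights buffer), but it also silences genuinely missing weights. torch returns the skipped keys, so a caller who wants visibility can log them; a small sketch:

    # load_state_dict returns (missing_keys, unexpected_keys); printing them
    # keeps strict=False from hiding a truncated or mismatched checkpoint.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print("missing:", result.missing_keys)
    if result.unexpected_keys:
        print("unexpected:", result.unexpected_keys)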
@@ -756,6 +454,26 @@ class SoproTTS:
          ids = self.tokenizer.encode(text)
          return torch.tensor(ids, dtype=torch.long, device=self.device)

+     @torch.inference_mode()
+     def encode_speaker(
+         self,
+         *,
+         ref_audio_path: Optional[str] = None,
+         ref_tokens_tq: Optional[torch.Tensor] = None,
+         ref_seconds: Optional[float] = None,
+     ) -> torch.Tensor:
+         ref = self.encode_reference(
+             ref_audio_path=ref_audio_path,
+             ref_tokens_tq=ref_tokens_tq,
+             ref_seconds=ref_seconds,
+         )
+         ref_btq = ref.unsqueeze(0)
+         lengths = torch.tensor(
+             [int(ref_btq.size(1))], device=self.device, dtype=torch.long
+         )
+         sv = self.model.token2sv(ref_btq, lengths=lengths)
+         return sv.squeeze(0).detach()
+

      def encode_reference(
          self,
@@ -763,6 +481,8 @@ class SoproTTS:
          ref_tokens_tq: Optional[torch.Tensor] = None,
          ref_seconds: Optional[float] = None,
      ) -> torch.Tensor:
+         from .sampling import center_crop_tokens
+
          if (ref_tokens_tq is None) and (ref_audio_path is None):
              raise RuntimeError(
                  "SoproTTS requires a reference. Provide ref_audio_path=... or ref_tokens_tq=..."
@@ -773,11 +493,11 @@ class SoproTTS:
              )

          if ref_seconds is None:
-             ref_seconds = float(self.cfg.ref_seconds_max)
+             ref_seconds = 12.0

          if ref_tokens_tq is not None:
              ref = ref_tokens_tq.to(self.device).long()
-             if ref_seconds > 0:
+             if ref_seconds and ref_seconds > 0:
                  fps = float(self.cfg.mimi_fps)
                  win = max(1, int(round(ref_seconds * fps)))
                  ref = center_crop_tokens(ref, win)
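Note the default reference window is now a hard-coded 12.0 s rather than cfg.ref_seconds_max. center_crop_tokens itself lives in sopro.sampling and is not shown in this diff; under the natural reading of its name, an equivalent for a (T, Q) token tensor would be:

    # Assumed behavior: keep `win` frames from the middle of the clip,
    # or the whole clip if it is already short enough.
    def center_crop_tokens(ref, win):
        T = ref.size(0)
        if T <= win:
            return ref
        start = (T - win) // 2
        return ref[start : start + win]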
@@ -793,51 +513,61 @@ class SoproTTS:
              )
          return ref

+     @torch.inference_mode()
+     def prepare_reference(
+         self,
+         *,
+         ref_audio_path: Optional[str] = None,
+         ref_tokens_tq: Optional[torch.Tensor] = None,
+         ref_seconds: Optional[float] = None,
+     ) -> PreparedReference:
+         tokens_tq = self.encode_reference(
+             ref_audio_path=ref_audio_path,
+             ref_tokens_tq=ref_tokens_tq,
+             ref_seconds=ref_seconds,
+         )
+         return self.model.prepare_reference(tokens_tq, device=self.device)
+
      @torch.inference_mode()
      def synthesize(
          self,
          text: str,
          *,
+         ref: Optional[PreparedReference] = None,
          ref_audio_path: Optional[str] = None,
          ref_tokens_tq: Optional[torch.Tensor] = None,
          max_frames: int = 400,
          top_p: float = 0.9,
          temperature: float = 1.05,
          anti_loop: bool = True,
-         use_prefix: bool = True,
-         prefix_sec_fixed: Optional[float] = None,
          style_strength: Optional[float] = None,
          ref_seconds: Optional[float] = None,
-         use_stop_head: Optional[bool] = None,
-         stop_patience: Optional[int] = None,
-         stop_threshold: Optional[float] = None,
          min_gen_frames: Optional[int] = None,
      ) -> torch.Tensor:
          text_ids = self.encode_text(text)
-         ref = self.encode_reference(
-             ref_audio_path=ref_audio_path,
-             ref_tokens_tq=ref_tokens_tq,
-             ref_seconds=ref_seconds,
-         )
+
+         if ref is None:
+             ref = self.prepare_reference(
+                 ref_audio_path=ref_audio_path,
+                 ref_tokens_tq=ref_tokens_tq,
+                 ref_seconds=ref_seconds,
+             )

          tokens_tq = self.model.generate_tokens(
              text_ids,
-             ref,
+             ref=ref,
              max_frames=max_frames,
              device=self.device,
              top_p=top_p,
              temperature=temperature,
              anti_loop=anti_loop,
-             use_prefix=use_prefix,
-             prefix_sec_fixed=prefix_sec_fixed,
              style_strength=float(
                  style_strength
                  if style_strength is not None
                  else self.cfg.style_strength
              ),
-             use_stop_head=use_stop_head,
-             stop_patience=stop_patience,
-             stop_threshold=stop_threshold,
              min_gen_frames=min_gen_frames,
          )
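Taken together, the 1.5.0 surface lets callers amortize reference encoding across many utterances instead of re-encoding the clip on every synthesize call. An end-to-end sketch; the loader classmethod is cut off above, so its name and arguments (from_pretrained, repo id, device) are assumptions:

    tts = SoproTTS.from_pretrained("...", device="cuda")  # name/args assumed

    # Mimi tokens, speaker vector, and cross-attention KV caches are all
    # built once here and reused for every line below.
    ref = tts.prepare_reference(ref_audio_path="speaker.wav", ref_seconds=12.0)

    for line in ["First sentence.", "Second sentence."]:
        tokens_tq = tts.synthesize(line, ref=ref, max_frames=400)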