sopro-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sopro/model.py ADDED
@@ -0,0 +1,853 @@
+ from __future__ import annotations
+
+ import os
+ from typing import Dict, Iterator, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from sopro.config import SoproTTSConfig
+ from sopro.hub import (
+     download_repo,
+     load_cfg_from_safetensors,
+     load_state_dict_from_safetensors,
+ )
+
+ from .audio import save_audio
+ from .codec.mimi import MimiCodec
+ from .constants import TARGET_SR
+ from .nn import (
+     CodebookEmbedding,
+     RefXAttn,
+     RMSNorm,
+     SinusoidalPositionalEmbedding,
+     SpeakerFiLM,
+     SSMLiteBlock,
+     TextEmbedding,
+     TextXAttnBlock,
+     Token2SV,
+ )
+ from .sampling import center_crop_tokens, repeated_tail
+ from .sampling import rf_ar as rf_ar_fn
+ from .sampling import rf_nar as rf_nar_fn
+ from .sampling import sample_token
+ from .tokenizer import TextTokenizer
+
+
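+ # Non-causal text encoder: token embeddings plus sinusoidal positions, run
+ # through bidirectional SSMLite blocks. Returns the per-token sequence and a
+ # mask-weighted mean pool used as a global utterance vector.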
+ class TextEncoder(nn.Module):
+     def __init__(
+         self,
+         cfg: SoproTTSConfig,
+         d_model: int,
+         n_layers: int,
+         tokenizer: TextTokenizer,
+     ):
+         super().__init__()
+         self.tok = tokenizer
+         self.embed = TextEmbedding(self.tok.vocab_size, d_model)
+         self.layers = nn.ModuleList(
+             [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
+         )
+         self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
+         self.norm = RMSNorm(d_model)
+
+     def forward(
+         self, text_ids: torch.Tensor, mask: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         x = self.embed(text_ids)
+         L = x.size(1)
+         pos = self.pos(torch.arange(L, device=x.device))
+         x = x + pos.unsqueeze(0)
+
+         x = x * mask.unsqueeze(-1).float()
+         for layer in self.layers:
+             x = layer(x)
+         x = self.norm(x)
+
+         mask_f = mask.float().unsqueeze(-1)
+         pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
+         return x, pooled
+
+
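+ # Autoregressive generator for the first RVQ codebook: dilated causal SSMLite
+ # blocks with text cross-attention interleaved every `ar_text_attn_freq`
+ # layers. Rows whose text mask is fully padded get one unmasked, zeroed key
+ # so attention never sees an all-masked row.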
+ class ARRVQ1Generator(nn.Module):
+     def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
+         super().__init__()
+         ks = cfg.ar_kernel
+         dils: List[int] = []
+         while len(dils) < cfg.n_layers_ar:
+             dils.extend(list(cfg.ar_dilation_cycle))
+         dils = dils[: cfg.n_layers_ar]
+
+         self.dils = tuple(int(d) for d in dils)
+         self.blocks = nn.ModuleList(
+             [
+                 SSMLiteBlock(
+                     d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
+                 )
+                 for d in self.dils
+             ]
+         )
+
+         self.attn_freq = int(cfg.ar_text_attn_freq)
+         self.x_attns = nn.ModuleList()
+         for i in range(len(self.blocks)):
+             if (i + 1) % self.attn_freq == 0:
+                 self.x_attns.append(
+                     TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
+                 )
+             else:
+                 self.x_attns.append(nn.Identity())
+
+         self.norm = RMSNorm(d_model)
+         self.head = nn.Linear(d_model, vocab)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         text_emb: Optional[torch.Tensor] = None,
+         text_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         key_padding_mask = ~text_mask if text_mask is not None else None
+
+         if key_padding_mask is not None:
+             bad_rows = key_padding_mask.all(dim=1)
+             if bad_rows.any():
+                 key_padding_mask = key_padding_mask.clone()
+                 idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
+                 key_padding_mask[idx, 0] = False
+                 if text_emb is not None:
+                     text_emb = text_emb.clone()
+                     text_emb[idx, 0, :] = 0
+
+         h = x
+         for i, lyr in enumerate(self.blocks):
+             h = lyr(h)
+             if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
+                 h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)
+
+         h = self.norm(h)
+         return self.head(h)
+
+
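+ # One non-autoregressive refinement stage: a learned softmax mix of the frame
+ # conditioning and the embeddings of already-decoded codebooks feeds a small
+ # SSMLite stack, with one linear head per codebook in the stage's group.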
+ class StageRefiner(nn.Module):
+     def __init__(self, cfg: SoproTTSConfig, D: int, num_heads: int, codebook_size: int):
+         super().__init__()
+         self.blocks = nn.ModuleList(
+             [SSMLiteBlock(D, cfg.dropout, causal=True) for _ in range(cfg.n_layers_nar)]
+         )
+         self.norm = RMSNorm(D)
+         self.pre = nn.Linear(D, cfg.nar_head_dim)
+         self.heads = nn.ModuleList(
+             [nn.Linear(cfg.nar_head_dim, codebook_size) for _ in range(num_heads)]
+         )
+         self.mix = nn.Parameter(torch.ones(2, dtype=torch.float32))
+
+     def forward_hidden(
+         self, cond_bt_d: torch.Tensor, prev_bt_d: torch.Tensor
+     ) -> torch.Tensor:
+         w = torch.softmax(self.mix, dim=0)
+         x = w[0] * cond_bt_d + w[1] * prev_bt_d
+         for b in self.blocks:
+             x = b(x)
+         return self.norm(x)
+
+     def forward_heads(self, h: torch.Tensor) -> List[torch.Tensor]:
+         z = self.pre(h)
+         return [head(z) for head in self.heads]
+
+
+ class StopHead(nn.Module):
+     def __init__(self, D: int):
+         super().__init__()
+         self.proj = nn.Linear(D, 1)
+
+     def forward(self, h: torch.Tensor) -> torch.Tensor:
+         return self.proj(h).squeeze(-1)
+
+
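+ # Full acoustic model. The AR path emits the first codebook frame by frame;
+ # stages "B".."E" then fill in the remaining codebooks non-autoregressively;
+ # an optional StopHead scores end-of-speech per frame.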
+ class SoproTTSModel(nn.Module):
+     def __init__(self, cfg: SoproTTSConfig, tokenizer: TextTokenizer):
+         super().__init__()
+         self.cfg = cfg
+         D = int(cfg.d_model)
+
+         self.text_enc = TextEncoder(cfg, D, cfg.n_layers_text, tokenizer)
+         self.frame_pos = SinusoidalPositionalEmbedding(D, max_len=cfg.pos_emb_max + 2)
+
+         self.cb_embed = CodebookEmbedding(
+             cfg.num_codebooks, cfg.codebook_size, D, use_bos=True
+         )
+         self.rvq1_bos_id = self.cb_embed.bos_id
+
+         self.token2sv = Token2SV(
+             cfg.num_codebooks,
+             cfg.codebook_size,
+             d=192,
+             out_dim=cfg.sv_student_dim,
+             dropout=cfg.dropout,
+         )
+         self.spk_film = SpeakerFiLM(D, sv_dim=cfg.sv_student_dim)
+         self.cond_norm = RMSNorm(D)
+
+         self.ar = ARRVQ1Generator(cfg, D, cfg.codebook_size)
+         if cfg.ar_lookback > 0:
+             self.ar_hist_w = nn.Parameter(torch.zeros(cfg.ar_lookback))
+
+         def idxs(rng: Tuple[int, int]) -> List[int]:
+             lo, hi = rng
+             return list(range(lo - 1, hi))
+
+         self.stage_indices: Dict[str, List[int]] = {
+             "B": idxs(cfg.stage_B),
+             "C": idxs(cfg.stage_C),
+             "D": idxs(cfg.stage_D),
+             "E": idxs(cfg.stage_E),
+         }
+         self.stages = nn.ModuleDict(
+             {
+                 s: StageRefiner(cfg, D, len(self.stage_indices[s]), cfg.codebook_size)
+                 for s in ["B", "C", "D", "E"]
+             }
+         )
+
+         self.stop_head = StopHead(D) if cfg.use_stop_head else None
+
+         self.ref_enc_blocks = nn.ModuleList(
+             [SSMLiteBlock(D, cfg.dropout, causal=False) for _ in range(2)]
+         )
+         self.ref_enc_norm = RMSNorm(D)
+         self.ref_xattn_stack = RefXAttn(
+             D, heads=cfg.ref_attn_heads, layers=3, dropout=cfg.dropout
+         )
+
+     def rf_ar(self) -> int:
+         return rf_ar_fn(self.cfg.ar_kernel, self.ar.dils)
+
+     def rf_nar(self) -> int:
+         return rf_nar_fn(self.cfg.n_layers_nar, kernel_size=7, dilation=1)
+
+     def _pool_time(self, x: torch.Tensor, factor: int) -> torch.Tensor:
+         if factor <= 1 or x.size(1) < 2 * factor:
+             return x
+         return F.avg_pool1d(
+             x.transpose(1, 2), kernel_size=factor, stride=factor
+         ).transpose(1, 2)
+
+     def _normalize_ref_mask(
+         self, ref_mask: Optional[torch.Tensor], device: torch.device
+     ) -> Optional[torch.Tensor]:
+         if ref_mask is None:
+             return None
+         mk = ref_mask.to(device).bool()
+         if mk.ndim == 1:
+             mk = mk.unsqueeze(0)
+         return mk
+
+     def _encode_reference_seq(self, ref_tokens: torch.Tensor) -> torch.Tensor:
+         B, Tr, Q = ref_tokens.shape
+         emb_sum = 0.0
+         for q in range(Q):
+             emb_sum = emb_sum + self.cb_embed.embed_tokens(
+                 ref_tokens[:, :, q], cb_index=q
+             )
+         x = emb_sum / float(Q)
+         for b in self.ref_enc_blocks:
+             x = b(x)
+         return self.ref_enc_norm(x)
+
+     def _single_pass_ref_xattn(
+         self,
+         cond_btd: torch.Tensor,
+         ref_seq: torch.Tensor,
+         ref_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         ref_seq_p = self._pool_time(ref_seq, 1)
+
+         key_padding_mask = None
+         if ref_mask is not None:
+             mk_bool = ref_mask.bool()
+             B, Tr = mk_bool.shape
+             pooled_len = ref_seq_p.size(1)
+             if pooled_len == Tr:
+                 key_padding_mask = ~mk_bool
+             else:
+                 cut = pooled_len * 2
+                 mk2 = mk_bool[:, :cut].reshape(B, pooled_len, 2).any(dim=2)
+                 key_padding_mask = ~mk2
+
+         return self.ref_xattn_stack(
+             cond_btd, ref_seq_p, key_padding_mask=key_padding_mask
+         )
+
+     def _base_cond_at(
+         self, t: int, txt_pool: torch.Tensor, device: torch.device
+     ) -> torch.Tensor:
+         pos = self.frame_pos(torch.tensor([t], device=device)).unsqueeze(0)
+         return txt_pool[:, None, :] + pos
+
+     def _ar_prev_from_seq(self, seq_1xT: torch.Tensor) -> torch.Tensor:
+         K = int(self.cfg.ar_lookback)
+         if K <= 0 or getattr(self, "ar_hist_w", None) is None:
+             return self.cb_embed.embed_shift_by_k(seq_1xT, cb_index=0, k=1)
+
+         ws = torch.softmax(self.ar_hist_w, dim=0)
+         acc = 0.0
+         k_max = min(K, int(seq_1xT.size(1)))
+         for k in range(1, k_max + 1):
+             acc = acc + ws[k - 1] * self.cb_embed.embed_shift_by_k(
+                 seq_1xT, cb_index=0, k=k
+             )
+         return acc
+
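+     # Precompute per-frame conditioning for up to `max_frames` frames: pooled
+     # text + frame position, speaker FiLM from the reference tokens, then
+     # cross-attention into the encoded reference. `chunk_size` trades a little
+     # speed for lower peak memory on long utterances.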
+     @torch.no_grad()
+     def prepare_conditioning(
+         self,
+         text_ids_1d: torch.Tensor,
+         ref_tokens_tq: torch.Tensor,
+         *,
+         max_frames: int,
+         device: torch.device,
+         style_strength: float = 1.0,
+         ref_mask: Optional[torch.Tensor] = None,
+         chunk_size: Optional[int] = None,
+     ) -> Dict[str, torch.Tensor]:
+         self.eval()
+
+         text_ids = text_ids_1d.to(device)
+         text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
+         txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)
+
+         ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
+
+         sv_ref = self.token2sv(ref_btq, lengths=None)
+         ref_seq = self._encode_reference_seq(ref_btq)
+         ref_mask_btr = self._normalize_ref_mask(ref_mask, device)
+
+         T = int(max_frames)
+         if T <= 0:
+             cond_all = torch.zeros(
+                 (1, 0, txt_pool.size(-1)), device=device, dtype=txt_pool.dtype
+             )
+         else:
+             pos = self.frame_pos(torch.arange(T, device=device)).unsqueeze(0)
+             base_all = txt_pool[:, None, :] + pos
+             base_all = self.spk_film(base_all, sv_ref, strength=float(style_strength))
+
+             if chunk_size is None or int(chunk_size) >= T:
+                 out = self._single_pass_ref_xattn(
+                     base_all, ref_seq, ref_mask=ref_mask_btr
+                 )
+                 cond_all = self.cond_norm(out)
+             else:
+                 cs = int(chunk_size)
+                 chunks: List[torch.Tensor] = []
+                 for s in range(0, T, cs):
+                     e = min(T, s + cs)
+                     q = base_all[:, s:e, :]
+                     out = self._single_pass_ref_xattn(q, ref_seq, ref_mask=ref_mask_btr)
+                     chunks.append(self.cond_norm(out))
+                 cond_all = torch.cat(chunks, dim=1)
+
+         return {
+             "txt_seq": txt_seq,
+             "text_mask": text_mask,
+             "cond_all": cond_all,
+             "ref_btq": ref_btq,
+             "txt_pool": txt_pool,
+             "sv_ref": sv_ref,
+             "ref_seq": ref_seq,
+             "ref_mask": (
+                 ref_mask_btr
+                 if ref_mask_btr is not None
+                 else torch.empty(0, device=device, dtype=torch.bool)
+             ),
+             "style_strength": torch.tensor(float(style_strength), device=device),
+         }
+
+     @torch.no_grad()
+     def prepare_conditioning_lazy(
+         self,
+         text_ids_1d: torch.Tensor,
+         ref_tokens_tq: torch.Tensor,
+         *,
+         max_frames: int,
+         device: torch.device,
+         style_strength: float = 1.0,
+         ref_mask: Optional[torch.Tensor] = None,
+     ) -> Dict[str, torch.Tensor]:
+         self.eval()
+
+         text_ids = text_ids_1d.to(device)
+         text_mask = torch.ones_like(text_ids, dtype=torch.bool).unsqueeze(0)
+         txt_seq, txt_pool = self.text_enc(text_ids.unsqueeze(0), text_mask)
+
+         ref_btq = ref_tokens_tq.unsqueeze(0).to(device)
+         sv_ref = self.token2sv(ref_btq, lengths=None)
+         ref_seq = self._encode_reference_seq(ref_btq)
+         ref_mask_btr = self._normalize_ref_mask(ref_mask, device)
+
+         D = int(txt_pool.size(-1))
+         cond_all = torch.zeros((1, 0, D), device=device, dtype=txt_pool.dtype)
+
+         return {
+             "txt_seq": txt_seq,
+             "text_mask": text_mask,
+             "cond_all": cond_all,
+             "ref_btq": ref_btq,
+             "txt_pool": txt_pool,
+             "sv_ref": sv_ref,
+             "ref_seq": ref_seq,
+             "ref_mask": (
+                 ref_mask_btr
+                 if ref_mask_btr is not None
+                 else torch.empty(0, device=device, dtype=torch.bool)
+             ),
+             "style_strength": torch.tensor(float(style_strength), device=device),
+             "max_frames": torch.tensor(int(max_frames), device=device),
+         }
+
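+     # Extend prep["cond_all"] in place (in multiples of `chunk_size`) until it
+     # covers frame `t_inclusive`; a no-op when enough frames already exist.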
+     @torch.no_grad()
+     def ensure_cond_upto(
+         self,
+         prep: Dict[str, torch.Tensor],
+         t_inclusive: int,
+         *,
+         chunk_size: int = 64,
+     ) -> None:
+         if t_inclusive < 0:
+             return
+
+         cond_all = prep.get("cond_all", None)
+         if cond_all is None:
+             raise KeyError("prep dict missing 'cond_all'.")
+
+         have = int(cond_all.size(1))
+         need_min = int(t_inclusive) + 1
+         if have >= need_min:
+             return
+
+         if "txt_pool" not in prep or "sv_ref" not in prep or "ref_seq" not in prep:
+             raise RuntimeError(
+                 "Lazy conditioning requested but prep lacks txt_pool/sv_ref/ref_seq. "
+                 "Use prepare_conditioning_lazy() or prepare_conditioning(..., chunk_size=...)."
+             )
+
+         device = cond_all.device
+         txt_pool = prep["txt_pool"]
+         sv_ref = prep["sv_ref"]
+         ref_seq = prep["ref_seq"]
+
+         style_strength = float(
+             prep.get(
+                 "style_strength", torch.tensor(self.cfg.style_strength, device=device)
+             ).item()
+         )
+
+         ref_mask = prep.get("ref_mask", None)
+         if ref_mask is not None and ref_mask.numel() == 0:
+             ref_mask = None
+
+         max_frames = prep.get("max_frames", None)
+         maxT = int(max_frames.item()) if max_frames is not None else None
+
+         cs = max(1, int(chunk_size))
+
+         need = ((need_min + cs - 1) // cs) * cs
+         if maxT is not None:
+             need = min(need, maxT)
+
+         if have >= need:
+             return
+
+         new_chunks: List[torch.Tensor] = []
+         for s in range(have, need, cs):
+             e = min(need, s + cs)
+             pos = self.frame_pos(torch.arange(s, e, device=device)).unsqueeze(0)
+             base = txt_pool[:, None, :] + pos
+             base = self.spk_film(base, sv_ref, strength=style_strength)
+             out = self._single_pass_ref_xattn(base, ref_seq, ref_mask=ref_mask)
+             new_chunks.append(self.cond_norm(out))
+
+         prep["cond_all"] = torch.cat([cond_all] + new_chunks, dim=1)
+
+     def build_ar_prefix(
+         self,
+         ref_btq: torch.Tensor,
+         device: torch.device,
+         prefix_sec_fixed: Optional[float],
+         use_prefix: bool,
+     ) -> torch.Tensor:
+         if not use_prefix or ref_btq.size(1) == 0:
+             return torch.zeros(1, 0, dtype=torch.long, device=device)
+
+         avail = int(ref_btq.size(1))
+         fps = float(self.cfg.mimi_fps)
+
+         if prefix_sec_fixed is not None and prefix_sec_fixed > 0:
+             P = min(avail, int(round(prefix_sec_fixed * fps)))
+         else:
+             P = min(avail, max(1, int(round(self.cfg.preprompt_sec_max * fps))))
+
+         if P <= 0:
+             return torch.zeros(1, 0, dtype=torch.long, device=device)
+         return ref_btq[:, :P, 0].contiguous()
+
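+     # Streaming AR decode of the first codebook. Yields (t, rvq1_id, p_stop)
+     # per frame, growing lazy conditioning on demand and switching to recovery
+     # sampling (lower top-p, higher temperature) when the recent tail repeats.
+     # A minimal consumption sketch (variable names are illustrative):
+     #
+     #     prep = model.prepare_conditioning_lazy(
+     #         text_ids, ref_tokens, max_frames=400, device=device
+     #     )
+     #     ids = [rvq1 for _t, rvq1, _p in model.ar_stream(prep, max_frames=400)]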
+     @torch.no_grad()
+     def ar_stream(
+         self,
+         prep: Dict[str, torch.Tensor],
+         *,
+         max_frames: int,
+         top_p: float = 0.9,
+         temperature: float = 1.05,
+         anti_loop: bool = True,
+         loop_streak: int = 8,
+         recovery_top_p: float = 0.85,
+         recovery_temp: float = 1.2,
+         use_prefix: bool = True,
+         prefix_sec_fixed: Optional[float] = None,
+         cond_chunk_size: Optional[int] = None,
+         use_stop_head: Optional[bool] = None,
+         stop_patience: Optional[int] = None,
+         stop_threshold: Optional[float] = None,
+         min_gen_frames: Optional[int] = None,
+     ) -> Iterator[Tuple[int, int, Optional[float]]]:
+         device = prep["cond_all"].device
+         cond_all = prep["cond_all"]
+         txt_seq = prep["txt_seq"]
+         text_mask = prep["text_mask"]
+         ref_btq = prep["ref_btq"]
+
+         R_AR = self.rf_ar()
+
+         stop_head = self.stop_head
+         if use_stop_head is not None and not bool(use_stop_head):
+             stop_head = None
+
+         eff_stop_patience = int(
+             stop_patience if stop_patience is not None else self.cfg.stop_patience
+         )
+         eff_stop_threshold = float(
+             stop_threshold if stop_threshold is not None else self.cfg.stop_threshold
+         )
+         eff_min_gen_frames = int(
+             min_gen_frames if min_gen_frames is not None else self.cfg.min_gen_frames
+         )
+
+         A_prefix = self.build_ar_prefix(
+             ref_btq, device, prefix_sec_fixed, use_prefix=use_prefix
+         )
+         P = int(A_prefix.size(1))
+
+         ctx_ids = torch.zeros(
+             (1, P + int(max_frames) + 1), dtype=torch.long, device=device
+         )
+         if P > 0:
+             ctx_ids[:, :P] = A_prefix
+
+         hist_A: List[int] = []
+         loop_streak_count = 0
+         stop_streak_count = 0
+         last_a: Optional[int] = None
+
+         gen_len = 0
+
+         for t in range(int(max_frames)):
+             if prep["cond_all"].size(1) < (t + 1):
+                 self.ensure_cond_upto(prep, t, chunk_size=int(cond_chunk_size or 64))
+                 cond_all = prep["cond_all"]
+
+             L_ar = min(t + 1, R_AR)
+             s_ar = t + 1 - L_ar
+             cond_win_ar = cond_all[:, s_ar : t + 1, :]
+
+             total_len = P + gen_len + 1
+             A_ctx_full = ctx_ids[:, :total_len]
+
+             prev_ctx_full = self._ar_prev_from_seq(A_ctx_full)
+             prev_ctx_win = prev_ctx_full[:, -L_ar:, :]
+
+             cur_top_p, cur_temp = top_p, temperature
+             if anti_loop:
+                 if repeated_tail(hist_A, max_n=16):
+                     cur_top_p, cur_temp = recovery_top_p, recovery_temp
+                 elif last_a is not None and loop_streak_count >= loop_streak:
+                     cur_top_p, cur_temp = recovery_top_p, recovery_temp
+
+             ar_logits_win = self.ar(
+                 cond_win_ar + prev_ctx_win, text_emb=txt_seq, text_mask=text_mask
+             )
+             ar_logits_t = ar_logits_win[:, -1:, :]
+
+             rvq1_id = sample_token(
+                 ar_logits_t,
+                 history=hist_A,
+                 top_p=cur_top_p,
+                 temperature=cur_temp,
+                 top_k=50,
+                 repetition_penalty=1.1,
+             )
+
+             ctx_ids[0, P + gen_len] = int(rvq1_id)
+             gen_len += 1
+
+             hist_A.append(int(rvq1_id))
+             loop_streak_count = (
+                 (loop_streak_count + 1)
+                 if (last_a is not None and rvq1_id == last_a)
+                 else 0
+             )
+             last_a = int(rvq1_id)
+
+             p_stop: Optional[float] = None
+             if stop_head is not None:
+                 A_now = torch.tensor([[rvq1_id]], device=device, dtype=torch.long)
+                 stop_inp = (
+                     cond_all[:, t : t + 1, :]
+                     + self.cb_embed.embed_tokens(A_now, cb_index=0).detach()
+                 )
+                 stop_logits = stop_head(stop_inp)
+                 p_stop = float(torch.sigmoid(stop_logits).item())
+
+                 if t + 1 >= eff_min_gen_frames and p_stop > eff_stop_threshold:
+                     stop_streak_count += 1
+                 else:
+                     stop_streak_count = 0
+
+             yield t, int(rvq1_id), p_stop
+
+             if stop_head is not None and stop_streak_count >= eff_stop_patience:
+                 break
+
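+     # Fill codebooks 2..Q given the AR codebook sequence: stages run in order
+     # B -> C -> D -> E, each conditioning on the summed embeddings of every
+     # previously decoded codebook and taking the argmax of its heads.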
+     @torch.no_grad()
+     def nar_refine(
+         self, cond_seq: torch.Tensor, tokens_A_1xT: torch.Tensor
+     ) -> torch.Tensor:
+         preds_all: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
+         prev_tokens_list: List[torch.Tensor] = [tokens_A_1xT.unsqueeze(-1)]
+         prev_cb_list: List[List[int]] = [[0]]
+
+         for stage_name in ["B", "C", "D", "E"]:
+             idxs = self.stage_indices[stage_name]
+             prev_tokens_cat = torch.cat(prev_tokens_list, dim=-1)
+             prev_cbs_cat = sum(prev_cb_list, [])
+             prev_emb_sum = self.cb_embed.sum_embed_subset(prev_tokens_cat, prev_cbs_cat)
+
+             h = self.stages[stage_name].forward_hidden(cond_seq, prev_emb_sum)
+             logits_list = self.stages[stage_name].forward_heads(h)
+             preds = torch.stack([x.argmax(dim=-1) for x in logits_list], dim=-1)
+
+             preds_all.append(preds)
+             prev_tokens_list.append(preds)
+             prev_cb_list.append(idxs)
+
+         tokens_btq = torch.cat(preds_all, dim=-1)
+         return tokens_btq
+
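+     # End-to-end token generation: run the AR stream to completion, then NAR-
+     # refine the collected first-codebook sequence into a (T, Q) token grid.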
+     @torch.no_grad()
+     def generate_tokens(
+         self,
+         text_ids_1d: torch.Tensor,
+         ref_tokens_tq: torch.Tensor,
+         *,
+         max_frames: int,
+         device: torch.device,
+         top_p: float = 0.9,
+         temperature: float = 1.05,
+         anti_loop: bool = True,
+         use_prefix: bool = True,
+         prefix_sec_fixed: Optional[float] = None,
+         style_strength: float = 1.0,
+         use_stop_head: Optional[bool] = None,
+         stop_patience: Optional[int] = None,
+         stop_threshold: Optional[float] = None,
+         min_gen_frames: Optional[int] = None,
+     ) -> torch.Tensor:
+         prep = self.prepare_conditioning(
+             text_ids_1d,
+             ref_tokens_tq,
+             max_frames=max_frames,
+             device=device,
+             style_strength=style_strength,
+         )
+
+         hist_A: List[int] = []
+         for _t, rvq1, _p_stop in self.ar_stream(
+             prep,
+             max_frames=max_frames,
+             top_p=top_p,
+             temperature=temperature,
+             anti_loop=anti_loop,
+             use_prefix=use_prefix,
+             prefix_sec_fixed=prefix_sec_fixed,
+             use_stop_head=use_stop_head,
+             stop_patience=stop_patience,
+             stop_threshold=stop_threshold,
+             min_gen_frames=min_gen_frames,
+         ):
+             hist_A.append(rvq1)
+
+         T = len(hist_A)
+         if T == 0:
+             return torch.zeros(
+                 0, self.cfg.num_codebooks, dtype=torch.long, device=device
+             )
+
+         tokens_A = torch.tensor(hist_A, device=device, dtype=torch.long).unsqueeze(0)
+         cond_seq = prep["cond_all"][:, :T, :]
+         tokens_btq_1xTQ = self.nar_refine(cond_seq, tokens_A)
+         return tokens_btq_1xTQ.squeeze(0)
+
+
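+ # High-level inference wrapper tying together tokenizer, acoustic model, and
+ # Mimi codec. A minimal usage sketch (the repo id below is illustrative, not
+ # a published checkpoint):
+ #
+ #     tts = SoproTTS.from_pretrained("someuser/sopro-tts")
+ #     wav = tts.synthesize("Hello there!", ref_audio_path="speaker.wav")
+ #     tts.save_wav("out.wav", wav)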
+ class SoproTTS:
+     def __init__(
+         self,
+         model: SoproTTSModel,
+         cfg: SoproTTSConfig,
+         tokenizer: TextTokenizer,
+         codec: MimiCodec,
+         device: str,
+     ):
+         self.model = model
+         self.cfg = cfg
+         self.tokenizer = tokenizer
+         self.codec = codec
+         self.device = torch.device(device)
+
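+     # Download a repo snapshot, read the config and weights from
+     # model.safetensors, and build the tokenizer and Mimi codec on the chosen
+     # device (CUDA when available, else CPU).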
+     @classmethod
+     def from_pretrained(
+         cls,
+         repo_id: str,
+         *,
+         revision: Optional[str] = None,
+         cache_dir: Optional[str] = None,
+         token: Optional[str] = None,
+         device: Optional[str] = None,
+     ) -> "SoproTTS":
+         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         dev = torch.device(device)
+
+         local_dir = download_repo(
+             repo_id, revision=revision, cache_dir=cache_dir, token=token
+         )
+
+         model_path = os.path.join(local_dir, "model.safetensors")
+         if not os.path.exists(model_path):
+             raise FileNotFoundError(f"Expected {model_path} in repo snapshot.")
+
+         cfg = load_cfg_from_safetensors(model_path)
+
+         tokenizer = TextTokenizer(model_name=local_dir)
+
+         model = SoproTTSModel(cfg, tokenizer).to(dev).eval()
+         state = load_state_dict_from_safetensors(model_path)
+
+         model.load_state_dict(state)
+
+         codec = MimiCodec(num_quantizers=cfg.num_codebooks, device=device)
+
+         return cls(
+             model=model, cfg=cfg, tokenizer=tokenizer, codec=codec, device=device
+         )
+
+     def encode_text(self, text: str) -> torch.Tensor:
+         ids = self.tokenizer.encode(text)
+         return torch.tensor(ids, dtype=torch.long, device=self.device)
+
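+     # Accepts exactly one of ref_audio_path / ref_tokens_tq and returns a
+     # (T, Q) long tensor of Mimi tokens, cropped to at most `ref_seconds`.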
+     def encode_reference(
+         self,
+         *,
+         ref_audio_path: Optional[str] = None,
+         ref_tokens_tq: Optional[torch.Tensor] = None,
+         ref_seconds: Optional[float] = None,
+     ) -> torch.Tensor:
+         if (ref_tokens_tq is None) and (ref_audio_path is None):
+             raise RuntimeError(
+                 "SoproTTS requires a reference. Provide ref_audio_path=... or ref_tokens_tq=..."
+             )
+         if (ref_tokens_tq is not None) and (ref_audio_path is not None):
+             raise RuntimeError(
+                 "Provide only one of ref_audio_path or ref_tokens_tq (not both)."
+             )
+
+         if ref_seconds is None:
+             ref_seconds = float(self.cfg.ref_seconds_max)
+
+         if ref_tokens_tq is not None:
+             ref = ref_tokens_tq.to(self.device).long()
+             if ref_seconds > 0:
+                 fps = float(self.cfg.mimi_fps)
+                 win = max(1, int(round(ref_seconds * fps)))
+                 ref = center_crop_tokens(ref, win)
+             return ref
+
+         crop_seconds = (
+             ref_seconds if (ref_seconds is not None and ref_seconds > 0) else None
+         )
+         ref = (
+             self.codec.encode_file(ref_audio_path, crop_seconds=crop_seconds)
+             .to(self.device)
+             .long()
+         )
+         return ref
+
+     @torch.no_grad()
+     def synthesize(
+         self,
+         text: str,
+         *,
+         ref_audio_path: Optional[str] = None,
+         ref_tokens_tq: Optional[torch.Tensor] = None,
+         max_frames: int = 400,
+         top_p: float = 0.9,
+         temperature: float = 1.05,
+         anti_loop: bool = True,
+         use_prefix: bool = True,
+         prefix_sec_fixed: Optional[float] = None,
+         style_strength: Optional[float] = None,
+         ref_seconds: Optional[float] = None,
+         use_stop_head: Optional[bool] = None,
+         stop_patience: Optional[int] = None,
+         stop_threshold: Optional[float] = None,
+         min_gen_frames: Optional[int] = None,
+     ) -> torch.Tensor:
+         text_ids = self.encode_text(text)
+         ref = self.encode_reference(
+             ref_audio_path=ref_audio_path,
+             ref_tokens_tq=ref_tokens_tq,
+             ref_seconds=ref_seconds,
+         )
+
+         tokens_tq = self.model.generate_tokens(
+             text_ids,
+             ref,
+             max_frames=max_frames,
+             device=self.device,
+             top_p=top_p,
+             temperature=temperature,
+             anti_loop=anti_loop,
+             use_prefix=use_prefix,
+             prefix_sec_fixed=prefix_sec_fixed,
+             style_strength=float(
+                 style_strength
+                 if style_strength is not None
+                 else self.cfg.style_strength
+             ),
+             use_stop_head=use_stop_head,
+             stop_patience=stop_patience,
+             stop_threshold=stop_threshold,
+             min_gen_frames=min_gen_frames,
+         )
+
+         wav = self.codec.decode_full(tokens_tq)
+         return wav
+
+     def stream(self, text: str, **kwargs) -> Iterator[torch.Tensor]:
+         from .streaming import stream as _stream
+
+         return _stream(self, text, **kwargs)
+
+     def save_wav(self, path: str, wav_1xT: torch.Tensor) -> None:
+         save_audio(path, wav_1xT, sr=TARGET_SR)