sopro-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sopro/nn/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from .blocks import GLU, AttentiveStatsPool, DepthwiseConv1d, RMSNorm, SSMLiteBlock
+ from .embeddings import CodebookEmbedding, SinusoidalPositionalEmbedding, TextEmbedding
+ from .speaker import SpeakerFiLM, Token2SV
+ from .xattn import RefXAttn, RefXAttnBlock, TextXAttnBlock
+
+ __all__ = [
+     "GLU",
+     "RMSNorm",
+     "DepthwiseConv1d",
+     "SSMLiteBlock",
+     "AttentiveStatsPool",
+     "SinusoidalPositionalEmbedding",
+     "TextEmbedding",
+     "CodebookEmbedding",
+     "Token2SV",
+     "SpeakerFiLM",
+     "RefXAttn",
+     "RefXAttnBlock",
+     "TextXAttnBlock",
+ ]
sopro/nn/blocks.py ADDED
@@ -0,0 +1,110 @@
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class GLU(nn.Module):
+     def __init__(self, d: int):
+         super().__init__()
+         self.pro = nn.Linear(d, 2 * d)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         a, b = self.pro(x).chunk(2, dim=-1)
+         return a * torch.sigmoid(b)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x32 = x.float()
+         var = x32.pow(2).mean(dim=-1, keepdim=True)
+         y32 = x32 * torch.rsqrt(var + self.eps)
+         y32 = y32 * self.weight.float()
+         return y32.to(dtype=x.dtype)
+
+
+ class DepthwiseConv1d(nn.Module):
+     def __init__(
+         self, d: int, kernel_size: int = 7, causal: bool = False, dilation: int = 1
+     ):
+         super().__init__()
+         self.causal = causal
+         self.dilation = dilation
+         self.kernel_size = kernel_size
+         self.dw = nn.Conv1d(d, d, kernel_size, groups=d, padding=0, dilation=dilation)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         xt = x.transpose(1, 2)
+         if self.causal:
+             pad_left = (self.kernel_size - 1) * self.dilation
+             xt = F.pad(xt, (pad_left, 0))
+         else:
+             total = (self.kernel_size - 1) * self.dilation
+             left = total // 2
+             right = total - left
+             xt = F.pad(xt, (left, right))
+         y = self.dw(xt)
+         return y.transpose(1, 2)
+
+
+ class SSMLiteBlock(nn.Module):
+     def __init__(
+         self,
+         d: int,
+         dropout: float = 0.05,
+         causal: bool = False,
+         kernel_size: int = 7,
+         dilation: int = 1,
+     ):
+         super().__init__()
+         self.norm = RMSNorm(d)
+         self.glu = GLU(d)
+         self.dw = DepthwiseConv1d(
+             d, kernel_size=kernel_size, causal=causal, dilation=dilation
+         )
+         self.ff = nn.Sequential(
+             RMSNorm(d),
+             nn.Linear(d, 4 * d),
+             nn.GELU(),
+             nn.Linear(4 * d, d),
+         )
+         self.drop = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = self.glu(self.norm(x))
+         h = self.dw(h)
+         x = x + self.drop(h)
+         x = x + self.drop(self.ff(x))
+         return x
+
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, d: int):
+         super().__init__()
+         self.attn = nn.Sequential(
+             nn.Linear(d, d),
+             nn.Tanh(),
+             nn.Linear(d, 1),
+         )
+
+     def forward(
+         self, h: torch.Tensor, lengths: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         B, T, D = h.shape
+         logits = self.attn(h).squeeze(-1)
+
+         if lengths is not None:
+             mask = torch.arange(T, device=h.device)[None, :] < lengths[:, None]
+             logits = logits.masked_fill(~mask, -1e9)
+
+         w = torch.softmax(logits, dim=1).unsqueeze(-1)
+         mu = (h * w).sum(dim=1)
+         var = (w * (h - mu.unsqueeze(1)).pow(2)).sum(dim=1).clamp_min(1e-6)
+         std = torch.sqrt(var)
+         return torch.cat([mu, std], dim=-1)
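Usage sketch (editor's illustration, not part of the package): these blocks operate on (batch, time, dim) tensors, as DepthwiseConv1d's internal transpose implies; the sizes below are arbitrary.

import torch

from sopro.nn.blocks import AttentiveStatsPool, SSMLiteBlock

x = torch.randn(2, 50, 64)               # (B, T, D) feature sequence
block = SSMLiteBlock(d=64, causal=True)  # causal padding: frame t never sees t+1
y = block(x)                             # (2, 50, 64): residual GLU+depthwise conv, then residual FFN

pool = AttentiveStatsPool(d=64)
lengths = torch.tensor([50, 30])         # second sequence is padding past frame 30
stats = pool(y, lengths=lengths)         # (2, 128): attention-weighted [mean, std]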
sopro/nn/embeddings.py ADDED
@@ -0,0 +1,98 @@
+ from __future__ import annotations
+
+ import math
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class SinusoidalPositionalEmbedding(nn.Module):
+     def __init__(self, d_model: int, max_len: int = 10000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2, dtype=torch.float32)
+             * (-math.log(10000.0) / d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe, persistent=False)
+
+     def forward(self, positions: torch.Tensor) -> torch.Tensor:
+         return self.pe.index_select(0, positions.long())
+
+
+ class TextEmbedding(nn.Module):
+     def __init__(self, vocab_size: int, d_model: int):
+         super().__init__()
+         self.emb = nn.Embedding(vocab_size, d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.emb(x)
+
+
+ class CodebookEmbedding(nn.Module):
+     def __init__(
+         self, num_codebooks: int, vocab_size: int, d_model: int, use_bos: bool = True
+     ):
+         super().__init__()
+         self.Q = int(num_codebooks)
+         self.V = int(vocab_size)
+         self.D = int(d_model)
+         self.use_bos = bool(use_bos)
+
+         table_size = self.Q * self.V + (1 if self.use_bos else 0)
+         self.emb = nn.Embedding(table_size, d_model)
+         self.bos_id = (self.Q * self.V) if self.use_bos else None
+
+     def _indices_for(self, tokens: torch.Tensor, cb_index: int) -> torch.Tensor:
+         return cb_index * self.V + tokens
+
+     def embed_tokens(self, tokens: torch.Tensor, cb_index: int) -> torch.Tensor:
+         return self.emb(self._indices_for(tokens, cb_index))
+
+     def embed_shift_by_k(
+         self, tokens: torch.Tensor, cb_index: int, k: int
+     ) -> torch.Tensor:
+         idx = self._indices_for(tokens, cb_index)
+         B, T = idx.shape
+         if k <= 0:  # no shift requested; idx[:, :-k] would be empty for k == 0
+             return self.emb(idx)
+         if (not self.use_bos) or (self.bos_id is None):
+             pad_tok = idx[:, :1]
+         else:
+             pad_tok = torch.full(
+                 (B, 1), self.bos_id, dtype=torch.long, device=idx.device
+             )
+
+         if k >= T:
+             idx_shift = pad_tok.expand(-1, T)
+         else:
+             pad = pad_tok.expand(-1, k)
+             idx_shift = torch.cat([pad, idx[:, :-k]], dim=1)
+
+         return self.emb(idx_shift)
+
+     def sum_embed_subset(
+         self,
+         tokens_subset: Optional[torch.Tensor],
+         cb_indices: Optional[List[int]],
+         keep_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if tokens_subset is None or cb_indices is None or len(cb_indices) == 0:
+             return 0.0
+
+         B, T, K = tokens_subset.shape
+         idx_list = []
+         for k, cb in enumerate(cb_indices):
+             idx_list.append(self._indices_for(tokens_subset[..., k], cb).unsqueeze(2))
+         idx = torch.cat(idx_list, dim=2)
+         emb = self.emb(idx)
+
+         if keep_mask is not None:
+             emb = emb * keep_mask.unsqueeze(-1).to(emb.dtype)
+
+         return emb.sum(dim=2)
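Sketch of how the codebook table is addressed (editor's illustration; the codebook count and vocabulary size are arbitrary). Codebook q owns rows [q*V, (q+1)*V) of the shared embedding, and embed_shift_by_k fills the first k frames with the shared BOS row:

import torch

from sopro.nn.embeddings import CodebookEmbedding

emb = CodebookEmbedding(num_codebooks=4, vocab_size=1024, d_model=32)
tokens = torch.randint(0, 1024, (2, 10))            # (B, T) codes for one codebook

e0 = emb.embed_tokens(tokens, cb_index=0)           # (2, 10, 32), no shift
e1 = emb.embed_shift_by_k(tokens, cb_index=1, k=1)  # frame 0 becomes the BOS row

summed = emb.sum_embed_subset(
    tokens.unsqueeze(-1).expand(2, 10, 2),          # (B, T, K): same codes reused for illustration
    cb_indices=[2, 3],
)                                                   # (2, 10, 32): embeddings summed over K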
sopro/nn/speaker.py ADDED
@@ -0,0 +1,88 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .blocks import AttentiveStatsPool, DepthwiseConv1d
+
+
+ class Token2SV(nn.Module):
+     def __init__(
+         self, Q: int, V: int, d: int = 192, out_dim: int = 256, dropout: float = 0.05
+     ):
+         super().__init__()
+         self.Q, self.V = int(Q), int(V)
+         self.emb = nn.Embedding(self.Q * self.V, d)
+
+         initial_weights = torch.linspace(1.0, 0.1, self.Q)
+         self.cb_weights = nn.Parameter(initial_weights)
+
+         self.enc = nn.Sequential(
+             DepthwiseConv1d(d, 7, causal=False),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             DepthwiseConv1d(d, 7, causal=False),
+             nn.GELU(),
+         )
+
+         self.pool = AttentiveStatsPool(d)
+         self.proj = nn.Linear(2 * d, out_dim)
+
+     def _get_mixed_embedding(self, embed_btqd: torch.Tensor) -> torch.Tensor:
+         w = F.softmax(self.cb_weights, dim=0).view(1, 1, self.Q, 1)
+         return (embed_btqd * w).sum(dim=2)
+
+     def forward(
+         self, tokens_btq: torch.Tensor, lengths: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         B, T, Q = tokens_btq.shape
+         q_idx = torch.arange(Q, device=tokens_btq.device, dtype=torch.long).view(
+             1, 1, Q
+         )
+         idx = q_idx * self.V + tokens_btq.long()
+         raw_emb = self.emb(idx)
+
+         if self.training:
+             keep_prob = 0.95
+             mask = torch.rand(B, T, device=tokens_btq.device) < keep_prob
+             bad = mask.sum(dim=1) == 0
+             if bad.any():
+                 bad_idx = bad.nonzero(as_tuple=False).squeeze(1)
+                 rand_pos = torch.randint(
+                     0, T, (bad_idx.numel(),), device=tokens_btq.device
+                 )
+                 mask[bad_idx, rand_pos] = True
+             raw_emb = raw_emb * mask.float().unsqueeze(-1).unsqueeze(-1)
+
+         x = self._get_mixed_embedding(raw_emb)
+         h = self.enc(x)
+         pooled = self.pool(h, lengths=lengths)
+         e = self.proj(pooled)
+         return F.normalize(e, dim=-1, eps=1e-6)
+
+
+ class SpeakerFiLM(nn.Module):
+     def __init__(self, d_model: int, sv_dim: int):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(sv_dim, d_model),
+             nn.GELU(),
+             nn.Linear(d_model, 2 * d_model),
+         )
+         self.norm = nn.LayerNorm(d_model)
+         nn.init.zeros_(self.mlp[-1].weight)
+         nn.init.zeros_(self.mlp[-1].bias)
+
+     def forward(
+         self, base_bt_d: torch.Tensor, spk_b_d: torch.Tensor, strength: float = 1.0
+     ) -> torch.Tensor:
+         B, T, D = base_bt_d.shape
+         film = self.mlp(spk_b_d)
+         gamma, beta = film.chunk(2, dim=-1)
+         gamma = gamma.unsqueeze(1).expand(B, T, D)
+         beta = beta.unsqueeze(1).expand(B, T, D)
+         x = self.norm(base_bt_d)
+         return x * (1 + strength * torch.tanh(gamma)) + strength * torch.tanh(beta)
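Sketch of the speaker path (editor's illustration with arbitrary sizes): Token2SV pools reference codec tokens into a unit-norm speaker vector, which SpeakerFiLM turns into gamma/beta modulation. Because the FiLM head is zero-initialized, the module reduces to a plain LayerNorm of its input at init.

import torch

from sopro.nn.speaker import SpeakerFiLM, Token2SV

t2sv = Token2SV(Q=8, V=1024, d=192, out_dim=256).eval()  # eval() skips the frame-dropout branch
tokens = torch.randint(0, 1024, (2, 120, 8))             # (B, T, Q) reference codec tokens
spk = t2sv(tokens)                                       # (2, 256), L2-normalized

film = SpeakerFiLM(d_model=512, sv_dim=256)
h = torch.randn(2, 40, 512)                              # (B, T, d_model) hidden states
h_cond = film(h, spk, strength=1.0)                      # (2, 40, 512), LayerNorm(h) at init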
sopro/nn/xattn.py ADDED
@@ -0,0 +1,98 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from .blocks import RMSNorm
+
+
+ def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+     return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
+
+
+ class RefXAttnBlock(nn.Module):
+     def __init__(self, d_model: int, heads: int = 2, dropout: float = 0.0):
+         super().__init__()
+         self.nq = RMSNorm(d_model)
+         self.nkv = RMSNorm(d_model)
+         self.attn = nn.MultiheadAttention(
+             d_model, heads, batch_first=True, dropout=dropout
+         )
+         self.gate = nn.Parameter(torch.tensor(0.5))
+         self.gmax = 0.35
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         ref: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         q = self.nq(x)
+         kv = self.nkv(ref.float())
+
+         with torch.autocast(device_type=q.device.type, enabled=False):
+             a, _ = self.attn(
+                 q.float(), kv, kv, key_padding_mask=key_padding_mask, need_weights=False
+             )
+
+         a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
+
+         rms_x = rms(x.float())
+         rms_a = rms(a)
+         scale = rms_x / rms_a
+         a = (a * scale).to(x.dtype)
+
+         gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
+         return x + gate_eff * a
+
+
+ class RefXAttn(nn.Module):
+     def __init__(
+         self, d_model: int, heads: int = 2, layers: int = 3, dropout: float = 0.0
+     ):
+         super().__init__()
+         self.blocks = nn.ModuleList(
+             [RefXAttnBlock(d_model, heads, dropout) for _ in range(layers)]
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         ref: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         for blk in self.blocks:
+             x = blk(x, ref, key_padding_mask)
+         return x
+
+
+ class TextXAttnBlock(nn.Module):
+     def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
+         super().__init__()
+         self.nq = RMSNorm(d_model)
+         self.nkv = RMSNorm(d_model)
+         self.attn = nn.MultiheadAttention(
+             d_model, num_heads=heads, dropout=dropout, batch_first=True
+         )
+         self.gate = nn.Parameter(torch.tensor(0.0))
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         q = self.nq(x)
+         kv = self.nkv(context)
+         with torch.autocast(device_type=q.device.type, enabled=False):
+             out, _ = self.attn(
+                 q.float(),
+                 kv.float(),
+                 kv.float(),
+                 key_padding_mask=key_padding_mask,
+                 need_weights=False,
+             )
+         out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
+         return x + torch.tanh(self.gate) * out
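Sketch of the reference cross-attention (editor's illustration): x attends to ref through the gated stack, key_padding_mask marks padded reference frames (True = ignore), and the tanh-bounded gate caps each layer's contribution at gmax = 0.35.

import torch

from sopro.nn.xattn import RefXAttn

xattn = RefXAttn(d_model=256, heads=2, layers=3)
x = torch.randn(2, 40, 256)              # queries: generated frames
ref = torch.randn(2, 100, 256)           # keys/values: reference frames
pad = torch.zeros(2, 100, dtype=torch.bool)
pad[1, 60:] = True                       # second item's reference is padding past frame 60

y = xattn(x, ref, key_padding_mask=pad)  # (2, 40, 256), RMS-matched gated residual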
sopro/sampling.py ADDED
@@ -0,0 +1,101 @@
+ from __future__ import annotations
+
+ from typing import List, Tuple
+
+ import torch
+
+
+ def center_crop_tokens(ref_tq: torch.Tensor, win_frames: int) -> torch.Tensor:
+     T = int(ref_tq.size(0))
+     if T <= win_frames:
+         return ref_tq
+     s = (T - win_frames) // 2
+     return ref_tq[s : s + win_frames]
+
+
+ def repeated_tail(hist: List[int], max_n: int = 16) -> bool:
+     L = len(hist)
+     for n in range(3, min(max_n, L // 2) + 1):
+         if hist[-n:] == hist[-2 * n : -n]:
+             return True
+     return False
+
+
+ def sample_token(
+     logits_1x1v: torch.Tensor,
+     history: List[int],
+     top_p: float = 0.9,
+     top_k: int = 0,
+     temperature: float = 1.0,
+     repetition_penalty: float = 1.0,
+     eps: float = 1e-12,
+ ) -> int:
+     x = logits_1x1v
+
+     x = torch.nan_to_num(x, nan=-1e9, posinf=1e9, neginf=-1e9)
+
+     if temperature and temperature != 1.0:
+         x = x / float(temperature)
+
+     if repetition_penalty != 1.0 and len(history) > 0:
+         context = history[-50:]
+         ids = torch.tensor(list(set(context)), device=x.device, dtype=torch.long)
+         if ids.numel() > 0:
+             vals = x[0, 0, ids]
+             neg = vals < 0
+             vals = torch.where(
+                 neg, vals * repetition_penalty, vals / repetition_penalty
+             )
+             x = x.clone()
+             x[0, 0, ids] = vals
+
+     probs = torch.softmax(x, dim=-1).view(1, -1)
+     probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
+
+     V = int(probs.size(-1))
+     if top_k and top_k > 0:
+         k = min(int(top_k), V)
+         val, idx = torch.topk(probs, k, dim=-1)
+         newp = torch.zeros_like(probs)
+         newp.scatter_(1, idx, val)
+         probs = newp
+
+     s = probs.sum(dim=-1, keepdim=True)
+     if float(s.item()) <= eps:
+         return int(torch.argmax(x[0, 0]).item())
+     probs = probs / s
+
+     if top_p is not None and top_p < 1.0:
+         sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
+         cum = torch.cumsum(sorted_probs, dim=-1)
+
+         remove = cum > float(top_p)
+         remove[..., 1:] = remove[..., :-1].clone()
+         remove[..., 0] = False
+
+         sorted_probs = sorted_probs.masked_fill(remove, 0.0)
+
+         s = sorted_probs.sum(dim=-1, keepdim=True)
+         if float(s.item()) <= eps:
+             return int(torch.argmax(x[0, 0]).item())
+         sorted_probs = sorted_probs / s
+
+         j = torch.multinomial(sorted_probs, 1).item()
+         token = int(sorted_idx[0, j].item())
+
+         return token
+
+     s = probs.sum(dim=-1, keepdim=True)
+     if float(s.item()) <= eps:
+         return int(torch.argmax(x[0, 0]).item())
+     probs = probs / s
+
+     return int(torch.multinomial(probs, 1).item())
+
+
+ def rf_ar(ar_kernel: int, dilations: Tuple[int, ...]) -> int:
+     return 1 + (ar_kernel - 1) * int(sum(dilations))
+
+
+ def rf_nar(n_layers_nar: int, kernel_size: int = 7, dilation: int = 1) -> int:
+     return 1 + (kernel_size - 1) * int(n_layers_nar) * int(dilation)
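Sketch of a decoding loop over these helpers (editor's illustration; the logits are random stand-ins for real model output):

import torch

from sopro.sampling import repeated_tail, sample_token

V = 1024
history: list[int] = []
for _ in range(32):
    logits = torch.randn(1, 1, V)      # stand-in for a (1, 1, V) model output
    tok = sample_token(
        logits,
        history,
        top_p=0.9,
        top_k=50,
        temperature=1.05,
        repetition_penalty=1.1,
    )
    history.append(tok)
    if repeated_tail(history):         # an n-gram tail repeated back-to-back => likely a loop
        break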
sopro/streaming.py ADDED
@@ -0,0 +1,165 @@
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass
+ from typing import Iterator, List, Optional, Tuple
+
+ import torch
+
+ from .codec.mimi import MimiDecodeState, MimiStreamDecoder
+ from .model import SoproTTS, SoproTTSModel
+
+
+ @dataclass
+ class StreamConfig:
+     chunk_frames: int = 16
+     nar_context_frames: Optional[int] = None
+     cond_chunk_size: int = 32
+
+
+ class SoproTTSStreamer:
+     def __init__(self, tts: SoproTTS, cfg: Optional[StreamConfig] = None):
+         self.tts = tts
+         self.cfg = cfg or StreamConfig()
+         self.mimi_stream = MimiStreamDecoder(tts.codec)
+
+     @torch.inference_mode()
+     def stream(
+         self,
+         text: str,
+         *,
+         ref_audio_path: Optional[str] = None,
+         ref_tokens_tq: Optional[torch.Tensor] = None,
+         max_frames: int = 400,
+         top_p: float = 0.9,
+         temperature: float = 1.05,
+         anti_loop: bool = True,
+         use_prefix: bool = True,
+         prefix_sec_fixed: Optional[float] = None,
+         style_strength: Optional[float] = None,
+         ref_seconds: Optional[float] = None,
+         chunk_frames: Optional[int] = None,
+         nar_context_frames: Optional[int] = None,
+         cond_chunk_size: Optional[int] = None,
+         use_stop_head: Optional[bool] = None,
+         stop_patience: Optional[int] = None,
+         stop_threshold: Optional[float] = None,
+         min_gen_frames: Optional[int] = None,
+     ) -> Iterator[torch.Tensor]:
+         model: SoproTTSModel = self.tts.model
+         device = self.tts.device
+
+         text_ids = self.tts.encode_text(text)
+         ref = self.tts.encode_reference(
+             ref_audio_path=ref_audio_path,
+             ref_tokens_tq=ref_tokens_tq,
+             ref_seconds=ref_seconds,
+         )
+
+         prep = model.prepare_conditioning_lazy(
+             text_ids,
+             ref,
+             max_frames=max_frames,
+             device=device,
+             style_strength=float(
+                 style_strength
+                 if style_strength is not None
+                 else self.tts.cfg.style_strength
+             ),
+         )
+
+         cf = int(chunk_frames if chunk_frames is not None else self.cfg.chunk_frames)
+         cond_cs = int(
+             cond_chunk_size if cond_chunk_size is not None else self.cfg.cond_chunk_size
+         )
+
+         nar_ctx = (
+             nar_context_frames
+             if nar_context_frames is not None
+             else self.cfg.nar_context_frames
+         )
+         if nar_ctx is None:
+             nar_ctx = int(model.rf_nar())
+         nar_ctx = int(nar_ctx)
+
+         hist_A: List[int] = []
+
+         frames_emitted = 0
+
+         mimi_state = MimiDecodeState()
+
+         def refine_and_emit(end: int) -> Optional[torch.Tensor]:
+             nonlocal frames_emitted, mimi_state
+
+             new_start = frames_emitted
+             if end <= new_start:
+                 return None
+
+             win_start = max(0, new_start - nar_ctx)
+             win_end = end
+
+             cond_win = prep["cond_all"][:, win_start:win_end, :]
+             tokens_A_win = torch.as_tensor(
+                 hist_A[win_start:win_end], device=device, dtype=torch.long
+             ).unsqueeze(0)
+
+             tokens_win_tq = model.nar_refine(cond_win, tokens_A_win).squeeze(0)
+
+             tail_i = new_start - win_start
+             emit_tokens = tokens_win_tq[tail_i:, :]
+
+             wav_chunk, mimi_state = self.mimi_stream.decode_step(
+                 emit_tokens, mimi_state
+             )
+             frames_emitted = end
+             return wav_chunk if wav_chunk.numel() > 0 else None
+
+         for _t, rvq1_id, _p_stop in model.ar_stream(
+             prep,
+             max_frames=max_frames,
+             top_p=top_p,
+             temperature=temperature,
+             anti_loop=anti_loop,
+             use_prefix=use_prefix,
+             prefix_sec_fixed=prefix_sec_fixed,
+             cond_chunk_size=cond_cs,
+             use_stop_head=use_stop_head,
+             stop_patience=stop_patience,
+             stop_threshold=stop_threshold,
+             min_gen_frames=min_gen_frames,
+         ):
+             hist_A.append(int(rvq1_id))
+             T = len(hist_A)
+
+             is_boundary = (T % cf) == 0
+             if not is_boundary and T < max_frames:
+                 continue
+
+             wav = refine_and_emit(T)
+             if wav is not None:
+                 yield wav
+
+         T_final = len(hist_A)
+         if frames_emitted < T_final:
+             wav = refine_and_emit(T_final)
+             if wav is not None:
+                 yield wav
+
+
+ def stream(
+     tts: SoproTTS,
+     text: str,
+     *,
+     ref_audio_path: Optional[str] = None,
+     ref_tokens_tq: Optional[torch.Tensor] = None,
+     chunk_frames: int = 6,
+     **kwargs,
+ ) -> Iterator[torch.Tensor]:
+     streamer = SoproTTSStreamer(tts, StreamConfig(chunk_frames=chunk_frames))
+     return streamer.stream(
+         text,
+         ref_audio_path=ref_audio_path,
+         ref_tokens_tq=ref_tokens_tq,
+         chunk_frames=chunk_frames,
+         **kwargs,
+     )
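Sketch of consuming the stream (editor's illustration): how a SoproTTS instance is constructed is not part of this diff, so it is left as a placeholder; each yielded chunk is the waveform for the newly refined frames.

import torch

from sopro.streaming import stream

tts = ...  # a loaded SoproTTS instance; construction is outside this diff

chunks = []
for wav in stream(
    tts,
    "Hello from the streaming path.",
    ref_audio_path="speaker.wav",  # hypothetical reference clip
    chunk_frames=6,                # emit roughly every 6 AR frames
):
    chunks.append(wav)

audio = torch.cat(chunks, dim=-1)  # assumes chunks concatenate on the time axis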