sopro-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +6 -0
- sopro/audio.py +155 -0
- sopro/cli.py +185 -0
- sopro/codec/__init__.py +3 -0
- sopro/codec/mimi.py +181 -0
- sopro/config.py +48 -0
- sopro/constants.py +5 -0
- sopro/hub.py +53 -0
- sopro/model.py +853 -0
- sopro/nn/__init__.py +20 -0
- sopro/nn/blocks.py +110 -0
- sopro/nn/embeddings.py +96 -0
- sopro/nn/speaker.py +88 -0
- sopro/nn/xattn.py +98 -0
- sopro/sampling.py +101 -0
- sopro/streaming.py +165 -0
- sopro/tokenizer.py +38 -0
- sopro-1.0.0.dist-info/METADATA +182 -0
- sopro-1.0.0.dist-info/RECORD +23 -0
- sopro-1.0.0.dist-info/WHEEL +5 -0
- sopro-1.0.0.dist-info/entry_points.txt +2 -0
- sopro-1.0.0.dist-info/licenses/LICENSE.txt +201 -0
- sopro-1.0.0.dist-info/top_level.txt +1 -0
sopro/nn/__init__.py
ADDED
@@ -0,0 +1,20 @@
+from .blocks import GLU, AttentiveStatsPool, DepthwiseConv1d, RMSNorm, SSMLiteBlock
+from .embeddings import CodebookEmbedding, SinusoidalPositionalEmbedding, TextEmbedding
+from .speaker import SpeakerFiLM, Token2SV
+from .xattn import RefXAttn, RefXAttnBlock, TextXAttnBlock
+
+__all__ = [
+    "GLU",
+    "RMSNorm",
+    "DepthwiseConv1d",
+    "SSMLiteBlock",
+    "AttentiveStatsPool",
+    "SinusoidalPositionalEmbedding",
+    "TextEmbedding",
+    "CodebookEmbedding",
+    "Token2SV",
+    "SpeakerFiLM",
+    "RefXAttn",
+    "RefXAttnBlock",
+    "TextXAttnBlock",
+]
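Since everything is re-exported through `__all__`, downstream code can import these layers from sopro.nn directly; a trivial sketch:

from sopro.nn import RMSNorm, SSMLiteBlock  # equivalent to importing from sopro.nn.blocks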
sopro/nn/blocks.py
ADDED
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GLU(nn.Module):
+    def __init__(self, d: int):
+        super().__init__()
+        self.pro = nn.Linear(d, 2 * d)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        a, b = self.pro(x).chunk(2, dim=-1)
+        return a * torch.sigmoid(b)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x32 = x.float()
+        var = x32.pow(2).mean(dim=-1, keepdim=True)
+        y32 = x32 * torch.rsqrt(var + self.eps)
+        y32 = y32 * self.weight.float()
+        return y32.to(dtype=x.dtype)
+
+
+class DepthwiseConv1d(nn.Module):
+    def __init__(
+        self, d: int, kernel_size: int = 7, causal: bool = False, dilation: int = 1
+    ):
+        super().__init__()
+        self.causal = causal
+        self.dilation = dilation
+        self.kernel_size = kernel_size
+        self.dw = nn.Conv1d(d, d, kernel_size, groups=d, padding=0, dilation=dilation)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        xt = x.transpose(1, 2)
+        if self.causal:
+            pad_left = (self.kernel_size - 1) * self.dilation
+            xt = F.pad(xt, (pad_left, 0))
+        else:
+            total = (self.kernel_size - 1) * self.dilation
+            left = total // 2
+            right = total - left
+            xt = F.pad(xt, (left, right))
+        y = self.dw(xt)
+        return y.transpose(1, 2)
+
+
+class SSMLiteBlock(nn.Module):
+    def __init__(
+        self,
+        d: int,
+        dropout: float = 0.05,
+        causal: bool = False,
+        kernel_size: int = 7,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        self.norm = RMSNorm(d)
+        self.glu = GLU(d)
+        self.dw = DepthwiseConv1d(
+            d, kernel_size=kernel_size, causal=causal, dilation=dilation
+        )
+        self.ff = nn.Sequential(
+            RMSNorm(d),
+            nn.Linear(d, 4 * d),
+            nn.GELU(),
+            nn.Linear(4 * d, d),
+        )
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.glu(self.norm(x))
+        h = self.dw(h)
+        x = x + self.drop(h)
+        x = x + self.drop(self.ff(x))
+        return x
+
+
+class AttentiveStatsPool(nn.Module):
+    def __init__(self, d: int):
+        super().__init__()
+        self.attn = nn.Sequential(
+            nn.Linear(d, d),
+            nn.Tanh(),
+            nn.Linear(d, 1),
+        )
+
+    def forward(
+        self, h: torch.Tensor, lengths: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        B, T, D = h.shape
+        logits = self.attn(h).squeeze(-1)
+
+        if lengths is not None:
+            mask = torch.arange(T, device=h.device)[None, :] < lengths[:, None]
+            logits = logits.masked_fill(~mask, -1e9)
+
+        w = torch.softmax(logits, dim=1).unsqueeze(-1)
+        mu = (h * w).sum(dim=1)
+        var = (w * (h - mu.unsqueeze(1)).pow(2)).sum(dim=1).clamp_min(1e-6)
+        std = torch.sqrt(var)
+        return torch.cat([mu, std], dim=-1)
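For orientation, a minimal shape check for the blocks above (a sketch, not part of the package; it assumes only the code shown in this diff):

import torch
from sopro.nn.blocks import SSMLiteBlock, AttentiveStatsPool

block = SSMLiteBlock(d=192, causal=True)   # depthwise-conv + GLU residual block
pool = AttentiveStatsPool(d=192)           # attention-weighted mean/std pooling

x = torch.randn(2, 50, 192)                # (batch, time, channels)
y = block(x)                               # shape preserved: (2, 50, 192)
lengths = torch.tensor([50, 30])           # valid frames per batch item
e = pool(y, lengths=lengths)               # (2, 384): concat of mean and std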
sopro/nn/embeddings.py
ADDED
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import math
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+    def __init__(self, d_model: int, max_len: int = 10000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2, dtype=torch.float32)
+            * (-math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe, persistent=False)
+
+    def forward(self, positions: torch.Tensor) -> torch.Tensor:
+        return self.pe.index_select(0, positions.long())
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int):
+        super().__init__()
+        self.emb = nn.Embedding(vocab_size, d_model)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.emb(x)
+
+
+class CodebookEmbedding(nn.Module):
+    def __init__(
+        self, num_codebooks: int, vocab_size: int, d_model: int, use_bos: bool = True
+    ):
+        super().__init__()
+        self.Q = int(num_codebooks)
+        self.V = int(vocab_size)
+        self.D = int(d_model)
+        self.use_bos = bool(use_bos)
+
+        table_size = self.Q * self.V + (1 if self.use_bos else 0)
+        self.emb = nn.Embedding(table_size, d_model)
+        self.bos_id = (self.Q * self.V) if self.use_bos else None
+
+    def _indices_for(self, tokens: torch.Tensor, cb_index: int) -> torch.Tensor:
+        return cb_index * self.V + tokens
+
+    def embed_tokens(self, tokens: torch.Tensor, cb_index: int) -> torch.Tensor:
+        return self.emb(self._indices_for(tokens, cb_index))
+
+    def embed_shift_by_k(
+        self, tokens: torch.Tensor, cb_index: int, k: int
+    ) -> torch.Tensor:
+        idx = self._indices_for(tokens, cb_index)
+        B, T = idx.shape
+        if (not self.use_bos) or (self.bos_id is None) or k <= 0:
+            pad_tok = idx[:, :1]
+        else:
+            pad_tok = torch.full(
+                (B, 1), self.bos_id, dtype=torch.long, device=idx.device
+            )
+
+        if k >= T:
+            idx_shift = pad_tok.expand(-1, T)
+        else:
+            pad = pad_tok.expand(-1, k)
+            idx_shift = torch.cat([pad, idx[:, :-k]], dim=1)
+
+        return self.emb(idx_shift)
+
+    def sum_embed_subset(
+        self,
+        tokens_subset: Optional[torch.Tensor],
+        cb_indices: Optional[List[int]],
+        keep_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if tokens_subset is None or cb_indices is None or len(cb_indices) == 0:
+            return 0.0
+
+        B, T, K = tokens_subset.shape
+        idx_list = []
+        for k, cb in enumerate(cb_indices):
+            idx_list.append(self._indices_for(tokens_subset[..., k], cb).unsqueeze(2))
+        idx = torch.cat(idx_list, dim=2)
+        emb = self.emb(idx)
+
+        if keep_mask is not None:
+            emb = emb * keep_mask.unsqueeze(-1).to(emb.dtype)
+
+        return emb.sum(dim=2)
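CodebookEmbedding packs Q codebooks of V entries each into a single table, with an optional BOS row at index Q*V; embed_shift_by_k right-shifts one codebook's token stream by k frames and front-pads with BOS, the usual way to feed delayed RVQ codebooks to a decoder. A small indexing sketch (sizes are hypothetical):

import torch
from sopro.nn.embeddings import CodebookEmbedding

emb = CodebookEmbedding(num_codebooks=8, vocab_size=2048, d_model=256)
tokens = torch.randint(0, 2048, (2, 10))                  # one codebook's tokens, (B, T)

e0 = emb.embed_tokens(tokens, cb_index=0)                 # reads table rows 0..2047
e3 = emb.embed_tokens(tokens, cb_index=3)                 # reads rows 3*2048..4*2048-1
shifted = emb.embed_shift_by_k(tokens, cb_index=3, k=1)   # frame t sees token t-1; frame 0 sees BOS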
sopro/nn/speaker.py
ADDED
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .blocks import AttentiveStatsPool, DepthwiseConv1d
+
+
+class Token2SV(nn.Module):
+    def __init__(
+        self, Q: int, V: int, d: int = 192, out_dim: int = 256, dropout: float = 0.05
+    ):
+        super().__init__()
+        self.Q, self.V = int(Q), int(V)
+        self.emb = nn.Embedding(self.Q * self.V, d)
+
+        initial_weights = torch.linspace(1.0, 0.1, self.Q)
+        self.cb_weights = nn.Parameter(initial_weights)
+
+        self.enc = nn.Sequential(
+            DepthwiseConv1d(d, 7, causal=False),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            DepthwiseConv1d(d, 7, causal=False),
+            nn.GELU(),
+        )
+
+        self.pool = AttentiveStatsPool(d)
+        self.proj = nn.Linear(2 * d, out_dim)
+
+    def _get_mixed_embedding(self, embed_btqd: torch.Tensor) -> torch.Tensor:
+        w = F.softmax(self.cb_weights, dim=0).view(1, 1, self.Q, 1)
+        return (embed_btqd * w).sum(dim=2)
+
+    def forward(
+        self, tokens_btq: torch.Tensor, lengths: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        B, T, Q = tokens_btq.shape
+        q_idx = torch.arange(Q, device=tokens_btq.device, dtype=torch.long).view(
+            1, 1, Q
+        )
+        idx = q_idx * self.V + tokens_btq.long()
+        raw_emb = self.emb(idx)
+
+        if self.training:
+            keep_prob = 0.95
+            mask = torch.rand(B, T, device=tokens_btq.device) < keep_prob
+            bad = mask.sum(dim=1) == 0
+            if bad.any():
+                bad_idx = bad.nonzero(as_tuple=False).squeeze(1)
+                rand_pos = torch.randint(
+                    0, T, (bad_idx.numel(),), device=tokens_btq.device
+                )
+                mask[bad_idx, rand_pos] = True
+            raw_emb = raw_emb * mask.float().unsqueeze(-1).unsqueeze(-1)
+
+        x = self._get_mixed_embedding(raw_emb)
+        h = self.enc(x)
+        pooled = self.pool(h, lengths=lengths)
+        e = self.proj(pooled)
+        return F.normalize(e, dim=-1, eps=1e-6)
+
+
+class SpeakerFiLM(nn.Module):
+    def __init__(self, d_model: int, sv_dim: int):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(sv_dim, d_model),
+            nn.GELU(),
+            nn.Linear(d_model, 2 * d_model),
+        )
+        self.norm = nn.LayerNorm(d_model)
+        nn.init.zeros_(self.mlp[-1].weight)
+        nn.init.zeros_(self.mlp[-1].bias)
+
+    def forward(
+        self, base_bt_d: torch.Tensor, spk_b_d: torch.Tensor, strength: float = 1.0
+    ) -> torch.Tensor:
+        B, T, D = base_bt_d.shape
+        film = self.mlp(spk_b_d)
+        gamma, beta = film.chunk(2, dim=-1)
+        gamma = gamma.unsqueeze(1).expand(B, T, D)
+        beta = beta.unsqueeze(1).expand(B, T, D)
+        x = self.norm(base_bt_d)
+        return x * (1 + strength * torch.tanh(gamma)) + strength * torch.tanh(beta)
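Token2SV derives a speaker vector straight from codec tokens: a learned softmax mixture over the Q codebooks, two non-causal depthwise conv stages, attentive stats pooling, and an L2-normalized projection. A usage sketch (shapes assumed from the code above):

import torch
from sopro.nn.speaker import Token2SV, SpeakerFiLM

t2sv = Token2SV(Q=8, V=2048, d=192, out_dim=256).eval()
tokens = torch.randint(0, 2048, (1, 120, 8))   # (B, T, Q) codec tokens
spk = t2sv(tokens)                             # (1, 256), unit-norm speaker vector

film = SpeakerFiLM(d_model=512, sv_dim=256)    # zero-init: starts as plain LayerNorm
h = torch.randn(1, 40, 512)
h = film(h, spk, strength=1.0)                 # FiLM-style scale/shift conditioning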
sopro/nn/xattn.py
ADDED
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .blocks import RMSNorm
+
+
+def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
+
+
+class RefXAttnBlock(nn.Module):
+    def __init__(self, d_model: int, heads: int = 2, dropout: float = 0.0):
+        super().__init__()
+        self.nq = RMSNorm(d_model)
+        self.nkv = RMSNorm(d_model)
+        self.attn = nn.MultiheadAttention(
+            d_model, heads, batch_first=True, dropout=dropout
+        )
+        self.gate = nn.Parameter(torch.tensor(0.5))
+        self.gmax = 0.35
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        ref: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        q = self.nq(x)
+        kv = self.nkv(ref.float())
+
+        with torch.autocast(device_type=q.device.type, enabled=False):
+            a, _ = self.attn(
+                q.float(), kv, kv, key_padding_mask=key_padding_mask, need_weights=False
+            )
+
+        a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
+
+        rms_x = rms(x.float())
+        rms_a = rms(a)
+        scale = rms_x / rms_a
+        a = (a * scale).to(x.dtype)
+
+        gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
+        return x + gate_eff * a
+
+
+class RefXAttn(nn.Module):
+    def __init__(
+        self, d_model: int, heads: int = 2, layers: int = 3, dropout: float = 0.0
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [RefXAttnBlock(d_model, heads, dropout) for _ in range(layers)]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        ref: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        for blk in self.blocks:
+            x = blk(x, ref, key_padding_mask)
+        return x
+
+
+class TextXAttnBlock(nn.Module):
+    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
+        super().__init__()
+        self.nq = RMSNorm(d_model)
+        self.nkv = RMSNorm(d_model)
+        self.attn = nn.MultiheadAttention(
+            d_model, num_heads=heads, dropout=dropout, batch_first=True
+        )
+        self.gate = nn.Parameter(torch.tensor(0.0))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        q = self.nq(x)
+        kv = self.nkv(context)
+        with torch.autocast(device_type=q.device.type, enabled=False):
+            out, _ = self.attn(
+                q.float(),
+                kv.float(),
+                kv.float(),
+                key_padding_mask=key_padding_mask,
+                need_weights=False,
+            )
+        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
+        return x + torch.tanh(self.gate) * out
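Both cross-attention blocks run the attention in float32 inside a disabled-autocast region, scrub NaN/inf, and add the result back through a bounded tanh gate (capped at gmax = 0.35 on the reference path, with RMS rescaling to match the residual's magnitude), so a freshly initialized block only mildly perturbs x. A minimal sketch of the reference path:

import torch
from sopro.nn.xattn import RefXAttn

xa = RefXAttn(d_model=256, heads=2, layers=3)
x = torch.randn(1, 40, 256)                   # decoder states
ref = torch.randn(1, 100, 256)                # reference-audio features
pad = torch.zeros(1, 100, dtype=torch.bool)   # True marks padded key positions
y = xa(x, ref, key_padding_mask=pad)          # same shape as x: (1, 40, 256)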
sopro/sampling.py
ADDED
@@ -0,0 +1,101 @@
+from __future__ import annotations
+
+from typing import List, Tuple
+
+import torch
+
+
+def center_crop_tokens(ref_tq: torch.Tensor, win_frames: int) -> torch.Tensor:
+    T = int(ref_tq.size(0))
+    if T <= win_frames:
+        return ref_tq
+    s = (T - win_frames) // 2
+    return ref_tq[s : s + win_frames]
+
+
+def repeated_tail(hist: List[int], max_n: int = 16) -> bool:
+    L = len(hist)
+    for n in range(3, min(max_n, L // 2) + 1):
+        if hist[-n:] == hist[-2 * n : -n]:
+            return True
+    return False
+
+
+def sample_token(
+    logits_1x1v: torch.Tensor,
+    history: List[int],
+    top_p: float = 0.9,
+    top_k: int = 0,
+    temperature: float = 1.0,
+    repetition_penalty: float = 1.0,
+    eps: float = 1e-12,
+) -> int:
+    x = logits_1x1v
+
+    x = torch.nan_to_num(x, nan=-1e9, posinf=1e9, neginf=-1e9)
+
+    if temperature and temperature != 1.0:
+        x = x / float(temperature)
+
+    if repetition_penalty != 1.0 and len(history) > 0:
+        context = history[-50:]
+        ids = torch.tensor(list(set(context)), device=x.device, dtype=torch.long)
+        if ids.numel() > 0:
+            vals = x[0, 0, ids]
+            neg = vals < 0
+            vals = torch.where(
+                neg, vals * repetition_penalty, vals / repetition_penalty
+            )
+            x = x.clone()
+            x[0, 0, ids] = vals
+
+    probs = torch.softmax(x, dim=-1).view(1, -1)
+    probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
+
+    V = int(probs.size(-1))
+    if top_k and top_k > 0:
+        k = min(int(top_k), V)
+        val, idx = torch.topk(probs, k, dim=-1)
+        newp = torch.zeros_like(probs)
+        newp.scatter_(1, idx, val)
+        probs = newp
+
+    s = probs.sum(dim=-1, keepdim=True)
+    if float(s.item()) <= eps:
+        return int(torch.argmax(x[0, 0]).item())
+    probs = probs / s
+
+    if top_p is not None and top_p < 1.0:
+        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
+        cum = torch.cumsum(sorted_probs, dim=-1)
+
+        remove = cum > float(top_p)
+        remove[..., 1:] = remove[..., :-1].clone()
+        remove[..., 0] = False
+
+        sorted_probs = sorted_probs.masked_fill(remove, 0.0)
+
+        s = sorted_probs.sum(dim=-1, keepdim=True)
+        if float(s.item()) <= eps:
+            return int(torch.argmax(x[0, 0]).item())
+        sorted_probs = sorted_probs / s
+
+        j = torch.multinomial(sorted_probs, 1).item()
+        token = int(sorted_idx[0, j].item())
+
+        return token
+
+    s = probs.sum(dim=-1, keepdim=True)
+    if float(s.item()) <= eps:
+        return int(torch.argmax(x[0, 0]).item())
+    probs = probs / s
+
+    return int(torch.multinomial(probs, 1).item())
+
+
+def rf_ar(ar_kernel: int, dilations: Tuple[int, ...]) -> int:
+    return 1 + (ar_kernel - 1) * int(sum(dilations))
+
+
+def rf_nar(n_layers_nar: int, kernel_size: int = 7, dilation: int = 1) -> int:
+    return 1 + (kernel_size - 1) * int(n_layers_nar) * int(dilation)
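sample_token expects logits shaped (1, 1, V) and applies, in order: temperature, a repetition penalty over the last 50 sampled tokens, optional top-k, then nucleus (top-p) filtering, falling back to argmax whenever the filtered probability mass underflows. A minimal sketch of the sampling loop (vocabulary size and the reuse of one logits tensor are purely illustrative):

import torch
from sopro.sampling import sample_token, repeated_tail

logits = torch.randn(1, 1, 2048)     # (1, 1, V) logits for one decoding step
history = []
for _ in range(20):
    tok = sample_token(logits, history, top_p=0.9, temperature=1.05,
                       repetition_penalty=1.1)
    history.append(tok)
if repeated_tail(history):           # flags a repeating tail of period >= 3
    print("loop detected")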
sopro/streaming.py
ADDED
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import Iterator, List, Optional, Tuple
+
+import torch
+
+from .codec.mimi import MimiDecodeState, MimiStreamDecoder
+from .model import SoproTTS, SoproTTSModel
+
+
+@dataclass
+class StreamConfig:
+    chunk_frames: int = 16
+    nar_context_frames: Optional[int] = None
+    cond_chunk_size: int = 32
+
+
+class SoproTTSStreamer:
+    def __init__(self, tts: SoproTTS, cfg: Optional[StreamConfig] = None):
+        self.tts = tts
+        self.cfg = cfg or StreamConfig()
+        self.mimi_stream = MimiStreamDecoder(tts.codec)
+
+    @torch.inference_mode()
+    def stream(
+        self,
+        text: str,
+        *,
+        ref_audio_path: Optional[str] = None,
+        ref_tokens_tq: Optional[torch.Tensor] = None,
+        max_frames: int = 400,
+        top_p: float = 0.9,
+        temperature: float = 1.05,
+        anti_loop: bool = True,
+        use_prefix: bool = True,
+        prefix_sec_fixed: Optional[float] = None,
+        style_strength: Optional[float] = None,
+        ref_seconds: Optional[float] = None,
+        chunk_frames: Optional[int] = None,
+        nar_context_frames: Optional[int] = None,
+        cond_chunk_size: Optional[int] = None,
+        use_stop_head: Optional[bool] = None,
+        stop_patience: Optional[int] = None,
+        stop_threshold: Optional[float] = None,
+        min_gen_frames: Optional[int] = None,
+    ) -> Iterator[torch.Tensor]:
+        model: SoproTTSModel = self.tts.model
+        device = self.tts.device
+
+        text_ids = self.tts.encode_text(text)
+        ref = self.tts.encode_reference(
+            ref_audio_path=ref_audio_path,
+            ref_tokens_tq=ref_tokens_tq,
+            ref_seconds=ref_seconds,
+        )
+
+        prep = model.prepare_conditioning_lazy(
+            text_ids,
+            ref,
+            max_frames=max_frames,
+            device=device,
+            style_strength=float(
+                style_strength
+                if style_strength is not None
+                else self.tts.cfg.style_strength
+            ),
+        )
+
+        cf = int(chunk_frames if chunk_frames is not None else self.cfg.chunk_frames)
+        cond_cs = int(
+            cond_chunk_size if cond_chunk_size is not None else self.cfg.cond_chunk_size
+        )
+
+        nar_ctx = (
+            nar_context_frames
+            if nar_context_frames is not None
+            else self.cfg.nar_context_frames
+        )
+        if nar_ctx is None:
+            nar_ctx = int(model.rf_nar())
+        nar_ctx = int(nar_ctx)
+
+        hist_A: List[int] = []
+
+        frames_emitted = 0
+
+        mimi_state = MimiDecodeState()
+
+        def refine_and_emit(end: int) -> Optional[torch.Tensor]:
+            nonlocal frames_emitted, mimi_state
+
+            new_start = frames_emitted
+            if end <= new_start:
+                return None
+
+            win_start = max(0, new_start - nar_ctx)
+            win_end = end
+
+            cond_win = prep["cond_all"][:, win_start:win_end, :]
+            tokens_A_win = torch.as_tensor(
+                hist_A[win_start:win_end], device=device, dtype=torch.long
+            ).unsqueeze(0)
+
+            tokens_win_tq = model.nar_refine(cond_win, tokens_A_win).squeeze(0)
+
+            tail_i = new_start - win_start
+            emit_tokens = tokens_win_tq[tail_i:, :]
+
+            wav_chunk, mimi_state = self.mimi_stream.decode_step(
+                emit_tokens, mimi_state
+            )
+            frames_emitted = end
+            return wav_chunk if wav_chunk.numel() > 0 else None
+
+        for _t, rvq1_id, _p_stop in model.ar_stream(
+            prep,
+            max_frames=max_frames,
+            top_p=top_p,
+            temperature=temperature,
+            anti_loop=anti_loop,
+            use_prefix=use_prefix,
+            prefix_sec_fixed=prefix_sec_fixed,
+            cond_chunk_size=cond_cs,
+            use_stop_head=use_stop_head,
+            stop_patience=stop_patience,
+            stop_threshold=stop_threshold,
+            min_gen_frames=min_gen_frames,
+        ):
+            hist_A.append(int(rvq1_id))
+            T = len(hist_A)
+
+            is_boundary = (T % cf) == 0
+            if not is_boundary and T < max_frames:
+                continue
+
+            wav = refine_and_emit(T)
+            if wav is not None:
+                yield wav
+
+        T_final = len(hist_A)
+        if frames_emitted < T_final:
+            wav = refine_and_emit(T_final)
+            if wav is not None:
+                yield wav
+
+
+def stream(
+    tts: SoproTTS,
+    text: str,
+    *,
+    ref_audio_path: Optional[str] = None,
+    ref_tokens_tq: Optional[torch.Tensor] = None,
+    chunk_frames: int = 6,
+    **kwargs,
+) -> Iterator[torch.Tensor]:
+    streamer = SoproTTSStreamer(tts, StreamConfig(chunk_frames=chunk_frames))
+    return streamer.stream(
+        text,
+        ref_audio_path=ref_audio_path,
+        ref_tokens_tq=ref_tokens_tq,
+        chunk_frames=chunk_frames,
+        **kwargs,
+    )
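End to end, the streamer yields waveform chunks as soon as each chunk_frames-sized window of coarse AR tokens has been NAR-refined (with nar_context_frames of left context) and decoded by the incremental Mimi decoder. A hedged usage sketch; `tts` stands in for an already-constructed SoproTTS (defined in sopro/model.py, outside this section), and the time-axis concatenation is an assumption about the chunk layout:

import torch
from sopro.streaming import stream

# `tts` is assumed to be a ready SoproTTS instance; construction is not shown here.
chunks = [wav.cpu() for wav in stream(tts, "Hello there.",
                                      ref_audio_path="ref.wav", chunk_frames=6)]
audio = torch.cat(chunks, dim=-1)  # join chunk waveforms along time (assumed axis)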