sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +1 -1
- sopro/cli.py +31 -46
- sopro/config.py +15 -20
- sopro/hub.py +2 -3
- sopro/model.py +265 -535
- sopro/nn/__init__.py +7 -3
- sopro/nn/blocks.py +78 -0
- sopro/nn/embeddings.py +16 -0
- sopro/nn/generator.py +130 -0
- sopro/nn/nar.py +116 -0
- sopro/nn/ref.py +160 -0
- sopro/nn/speaker.py +14 -17
- sopro/nn/text.py +132 -0
- sopro/sampling.py +3 -3
- sopro/streaming.py +25 -38
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/METADATA +30 -7
- sopro-1.5.0.dist-info/RECORD +26 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/WHEEL +1 -1
- sopro/nn/xattn.py +0 -98
- sopro-1.0.1.dist-info/RECORD +0 -23
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/entry_points.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/top_level.txt +0 -0
sopro/nn/__init__.py
CHANGED

@@ -1,7 +1,9 @@
 from .blocks import GLU, AttentiveStatsPool, DepthwiseConv1d, RMSNorm, SSMLiteBlock
 from .embeddings import CodebookEmbedding, SinusoidalPositionalEmbedding, TextEmbedding
+from .generator import ARRVQ1Generator
+from .ref import RefXAttnBlock, RefXAttnStack
 from .speaker import SpeakerFiLM, Token2SV
-from .
+from .text import TextEncoder, TextXAttnBlock
 
 __all__ = [
     "GLU",
@@ -14,7 +16,9 @@ __all__ = [
     "CodebookEmbedding",
     "Token2SV",
     "SpeakerFiLM",
-    "
-    "RefXAttnBlock",
+    "TextEncoder",
     "TextXAttnBlock",
+    "RefXAttnBlock",
+    "RefXAttnStack",
+    "ARRVQ1Generator",
 ]
sopro/nn/blocks.py
CHANGED

@@ -1,10 +1,18 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclass
+class DepthwiseConv1dState:
+    buf: torch.Tensor
+
+
 class GLU(nn.Module):
     def __init__(self, d: int):
         super().__init__()
@@ -39,6 +47,19 @@ class DepthwiseConv1d(nn.Module):
         self.kernel_size = kernel_size
         self.dw = nn.Conv1d(d, d, kernel_size, groups=d, padding=0, dilation=dilation)
 
+    def _ctx_len(self) -> int:
+        return (self.kernel_size - 1) * self.dilation + 1
+
+    def init_state(
+        self, batch_size: int, device: torch.device, dtype: torch.dtype
+    ) -> DepthwiseConv1dState:
+        if not self.causal:
+            raise ValueError("init_state is only valid for causal convs")
+        L = self._ctx_len()
+        D = int(self.dw.in_channels)
+        buf = torch.zeros((batch_size, L, D), device=device, dtype=dtype)
+        return DepthwiseConv1dState(buf=buf)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         xt = x.transpose(1, 2)
         if self.causal:
@@ -52,6 +73,42 @@ class DepthwiseConv1d(nn.Module):
         y = self.dw(xt)
         return y.transpose(1, 2)
 
+    def forward_step(
+        self, x_bt_d: torch.Tensor, state: Optional[DepthwiseConv1dState]
+    ) -> Tuple[torch.Tensor, DepthwiseConv1dState]:
+        if not self.causal:
+            raise ValueError("forward_step is only valid for causal convs")
+
+        if x_bt_d.dim() == 2:
+            x_bt_d = x_bt_d.unsqueeze(1)
+
+        B, T, D = x_bt_d.shape
+        if T != 1:
+            raise ValueError("forward_step expects a single timestep [B,1,D]")
+
+        if state is None:
+            state = self.init_state(B, x_bt_d.device, x_bt_d.dtype)
+
+        buf = state.buf
+        if buf.size(1) > 1:
+            buf = torch.cat([buf[:, 1:, :], x_bt_d], dim=1)
+        else:
+            buf = x_bt_d
+
+        k = int(self.kernel_size)
+        d = int(self.dilation)
+        idx = torch.arange(0, k * d, d, device=x_bt_d.device)
+        x_bkd = buf.index_select(1, idx)  # [B,k,D]
+
+        w_dk = self.dw.weight.squeeze(1).to(dtype=x_bt_d.dtype)
+        y_bd = (x_bkd.transpose(1, 2) * w_dk.unsqueeze(0)).sum(dim=-1)
+        if self.dw.bias is not None:
+            y_bd = y_bd + self.dw.bias.to(dtype=y_bd.dtype).unsqueeze(0)
+
+        y_bt_d = y_bd.unsqueeze(1)
+        state.buf = buf
+        return y_bt_d, state
+
 
 class SSMLiteBlock(nn.Module):
     def __init__(
@@ -76,6 +133,13 @@ class SSMLiteBlock(nn.Module):
         )
         self.drop = nn.Dropout(dropout)
 
+    def init_state(
+        self, batch_size: int, device: torch.device, dtype: torch.dtype
+    ) -> dict:
+        if not self.dw.causal:
+            raise ValueError("SSMLiteBlock.init_state only valid for causal blocks")
+        return {"dw": self.dw.init_state(batch_size, device, dtype)}
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = self.glu(self.norm(x))
         h = self.dw(h)
@@ -83,6 +147,20 @@ class SSMLiteBlock(nn.Module):
         x = x + self.drop(self.ff(x))
         return x
 
+    def forward_step(
+        self, x_bt_d: torch.Tensor, state: dict
+    ) -> Tuple[torch.Tensor, dict]:
+        if not self.dw.causal:
+            raise ValueError("forward_step only valid for causal blocks")
+
+        h = self.glu(self.norm(x_bt_d))
+        y, dw_state = self.dw.forward_step(h, state.get("dw", None))
+        state["dw"] = dw_state
+
+        x = x_bt_d + self.drop(y)
+        x = x + self.drop(self.ff(x))
+        return x, state
+
 
 class AttentiveStatsPool(nn.Module):
     def __init__(self, d: int):
sopro/nn/embeddings.py
CHANGED

@@ -79,6 +79,7 @@ class CodebookEmbedding(nn.Module):
         tokens_subset: Optional[torch.Tensor],
         cb_indices: Optional[List[int]],
         keep_mask: Optional[torch.Tensor] = None,
+        cb_weights: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if tokens_subset is None or cb_indices is None or len(cb_indices) == 0:
             return 0.0
@@ -90,6 +91,21 @@
         idx = torch.cat(idx_list, dim=2)
         emb = self.emb(idx)
 
+        if cb_weights is not None:
+            w = cb_weights
+            if w.dim() != 1:
+                raise ValueError("cb_weights must be 1D")
+            if w.numel() == self.Q:
+                cb_t = torch.tensor(cb_indices, device=emb.device, dtype=torch.long)
+                w = w.to(emb.device).index_select(0, cb_t)
+            elif w.numel() != K:
+                raise ValueError(
+                    f"cb_weights must have len Q={self.Q} or K={K}, got {w.numel()}"
+                )
+
+            w = F.softmax(w.float(), dim=0).to(dtype=emb.dtype)
+            emb = emb * w.view(1, 1, K, 1)
+
         if keep_mask is not None:
             emb = emb * keep_mask.unsqueeze(-1).to(emb.dtype)
 
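The new `cb_weights` argument re-weights the per-codebook embeddings before masking: a 1-D tensor of length `Q` (indexed down to the selected codebooks via `cb_indices`) or length `K` (already the selected subset), passed through a softmax and broadcast over `[B, T, K, D]`. The weighting math in isolation, with toy shapes rather than the sopro API:

```python
import torch
import torch.nn.functional as F

B, T, K, D = 2, 5, 3, 8
emb = torch.randn(B, T, K, D)                # embeddings of the K selected codebooks
cb_weights = torch.tensor([0.0, 1.0, -1.0])  # one raw (unnormalized) weight per selected codebook

w = F.softmax(cb_weights.float(), dim=0).to(dtype=emb.dtype)
weighted = emb * w.view(1, 1, K, 1)          # same broadcast the diff applies before keep_mask

print(w)                # sums to 1 over the codebook axis
print(weighted.shape)   # torch.Size([2, 5, 3, 8])
```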
sopro/nn/generator.py
ADDED

@@ -0,0 +1,130 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from sopro.config import SoproTTSConfig
+from sopro.nn.blocks import RMSNorm, SSMLiteBlock
+from sopro.nn.text import TextXAttnBlock
+
+
+class ARRVQ1Generator(nn.Module):
+    def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
+        super().__init__()
+        ks = int(cfg.ar_kernel)
+
+        dils: List[int] = []
+        while len(dils) < int(cfg.n_layers_ar):
+            dils.extend(list(cfg.ar_dilation_cycle))
+        dils = dils[: int(cfg.n_layers_ar)]
+        self.dils = tuple(int(d) for d in dils)
+
+        self.blocks = nn.ModuleList(
+            [
+                SSMLiteBlock(
+                    d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
+                )
+                for d in self.dils
+            ]
+        )
+
+        self.attn_freq = int(cfg.ar_text_attn_freq)
+        self.x_attns = nn.ModuleList()
+        for i in range(len(self.blocks)):
+            if (i + 1) % self.attn_freq == 0:
+                self.x_attns.append(
+                    TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
+                )
+            else:
+                self.x_attns.append(nn.Identity())
+
+        self.norm = RMSNorm(d_model)
+        self.head = nn.Linear(d_model, vocab)
+
+    @torch.no_grad()
+    def init_stream_state(
+        self,
+        batch_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        *,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, object]:
+        layer_states = [
+            blk.init_state(batch_size, device, dtype) for blk in self.blocks
+        ]
+
+        kv_caches: List[Optional[Dict[str, torch.Tensor]]] = []
+        key_padding_mask = (~text_mask) if text_mask is not None else None
+        for xa in self.x_attns:
+            if isinstance(xa, nn.Identity) or (text_emb is None):
+                kv_caches.append(None)
+            else:
+                kv_caches.append(
+                    xa.build_kv_cache(text_emb, key_padding_mask=key_padding_mask)
+                )
+
+        return {"layer_states": layer_states, "kv_caches": kv_caches}
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        key_padding_mask = ~text_mask if text_mask is not None else None
+
+        if key_padding_mask is not None:
+            bad_rows = key_padding_mask.all(dim=1)
+            if bad_rows.any():
+                key_padding_mask = key_padding_mask.clone()
+                idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
+                key_padding_mask[idx, 0] = False
+                if text_emb is not None:
+                    text_emb = text_emb.clone()
+                    text_emb[idx, 0, :] = 0
+
+        h = x
+        for i, lyr in enumerate(self.blocks):
+            h = lyr(h)
+            if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
+                h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)
+
+        h = self.norm(h)
+
+        return self.head(h)
+
+    @torch.no_grad()
+    def step(
+        self,
+        x_bt_d: torch.Tensor,
+        state: Dict[str, object],
+        *,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, object]]:
+        h = x_bt_d
+        key_padding_mask = (~text_mask) if text_mask is not None else None
+
+        layer_states: List[dict] = state["layer_states"]
+        kv_caches: List[Optional[Dict[str, torch.Tensor]]] = state["kv_caches"]
+
+        for i, blk in enumerate(self.blocks):
+            h, layer_states[i] = blk.forward_step(h, layer_states[i])
+
+            xa = self.x_attns[i]
+            if (not isinstance(xa, nn.Identity)) and (text_emb is not None):
+                kv = kv_caches[i]
+                if kv is None:
+                    kv = xa.build_kv_cache(text_emb, key_padding_mask=key_padding_mask)
+                h, kv = xa(h, kv_cache=kv, use_cache=True)
+                kv_caches[i] = kv
+
+        state["layer_states"] = layer_states
+        state["kv_caches"] = kv_caches
+
+        h = self.norm(h)
+        logits = self.head(h)
+
+        return logits, state
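`ARRVQ1Generator` splits the training-style `forward` from an explicit streaming path: `init_stream_state` builds one conv state per causal `SSMLiteBlock` plus a text KV cache for each interleaved `TextXAttnBlock`, and `step` advances a single `[B, 1, D]` frame. A hedged sketch of a decode loop; the `embed_frame` and `sample` helpers are hypothetical caller-side plumbing, and only the `init_stream_state` / `step` calls follow the code above (`text_mask` uses True for valid positions, since the model negates it into a key padding mask):

```python
import torch

def decode(gen, embed_frame, sample, text_emb, text_mask, bos_ids, n_frames):
    # gen: ARRVQ1Generator; text_emb: [B, S, D]; text_mask: [B, S] bool (True = real token)
    B = text_emb.size(0)
    state = gen.init_stream_state(
        B, text_emb.device, text_emb.dtype,
        text_emb=text_emb, text_mask=text_mask,
    )

    token_ids = bos_ids  # [B, 1] starting RVQ-1 tokens (hypothetical convention)
    out = []
    for _ in range(n_frames):
        x_t = embed_frame(token_ids)          # [B, 1, D] input frame
        logits, state = gen.step(
            x_t, state, text_emb=text_emb, text_mask=text_mask
        )
        token_ids = sample(logits[:, -1, :])  # [B, 1] next tokens
        out.append(token_ids)
    return torch.cat(out, dim=1)              # [B, n_frames]
```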
sopro/nn/nar.py
ADDED

@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+
+from sopro.config import SoproTTSConfig
+
+from .blocks import RMSNorm, SSMLiteBlock
+
+
+class NARStageAdapter(nn.Module):
+    def __init__(self, d_model: int, hidden: int = 256):
+        super().__init__()
+        self.norm = RMSNorm(d_model)
+        self.mlp = nn.Sequential(
+            nn.Linear(d_model, hidden),
+            nn.GELU(),
+            nn.Linear(hidden, 2 * d_model),
+        )
+        nn.init.zeros_(self.mlp[-1].weight)
+        nn.init.zeros_(self.mlp[-1].bias)
+
+    def forward(self, x: torch.Tensor, stage_vec: torch.Tensor) -> torch.Tensor:
+        if stage_vec.dim() == 1:
+            stage_vec = stage_vec.unsqueeze(0).expand(x.size(0), -1)
+        g, b = self.mlp(stage_vec).chunk(2, dim=-1)
+        g = g.unsqueeze(1)
+        b = b.unsqueeze(1)
+        x = self.norm(x)
+        return x * (1 + torch.tanh(g)) + torch.tanh(b)
+
+
+class NARSinglePass(nn.Module):
+    def __init__(
+        self, cfg: SoproTTSConfig, d_model: int, stage_specs: Dict[str, List[int]]
+    ):
+        super().__init__()
+        self.cfg = cfg
+        self.stage_names = [
+            s for s in ["B", "C", "D", "E"] if len(stage_specs.get(s, [])) > 0
+        ]
+        self.stage_to_id = {s: i for i, s in enumerate(self.stage_names)}
+        self.stage_specs = {s: stage_specs[s] for s in self.stage_names}
+
+        ks = int(cfg.nar_kernel_size)
+        cycle = tuple(int(x) for x in cfg.nar_dilation_cycle) or (1,)
+        dils: List[int] = []
+        while len(dils) < int(cfg.n_layers_nar):
+            dils.extend(cycle)
+        dils = dils[: int(cfg.n_layers_nar)]
+
+        self.blocks = nn.ModuleList(
+            [
+                SSMLiteBlock(
+                    d_model, cfg.dropout, causal=False, kernel_size=ks, dilation=int(d)
+                )
+                for d in dils
+            ]
+        )
+        self.norm = RMSNorm(d_model)
+        self.pre = nn.Linear(d_model, int(cfg.nar_head_dim))
+
+        self.stage_emb = nn.Embedding(len(self.stage_names), d_model)
+        self.adapter = NARStageAdapter(d_model, hidden=256)
+
+        self.heads = nn.ModuleDict()
+        self.head_id_emb = nn.ModuleDict()
+        for s in self.stage_names:
+            n_heads = len(self.stage_specs[s])
+            self.heads[s] = nn.ModuleList(
+                [
+                    nn.Linear(int(cfg.nar_head_dim), int(cfg.codebook_size))
+                    for _ in range(n_heads)
+                ]
+            )
+            emb = nn.Embedding(n_heads, int(cfg.nar_head_dim))
+            nn.init.zeros_(emb.weight)
+            self.head_id_emb[s] = emb
+
+        self.mix = nn.ParameterDict(
+            {
+                s: nn.Parameter(torch.zeros(2, dtype=torch.float32))
+                for s in self.stage_names
+            }
+        )
+
+    def forward_stage(
+        self, stage: str, cond: torch.Tensor, prev_emb: torch.Tensor
+    ) -> List[torch.Tensor]:
+        if stage not in self.heads:
+            return []
+
+        w = torch.softmax(self.mix[stage], dim=0)
+        x = w[0] * cond + w[1] * prev_emb
+
+        sid = self.stage_to_id[stage]
+        stage_vec = self.stage_emb.weight[sid]
+        x = self.adapter(x, stage_vec)
+
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        z = self.pre(x)
+        outs: List[torch.Tensor] = []
+        for i, head in enumerate(self.heads[stage]):
+            hb = (
+                self.head_id_emb[stage]
+                .weight[i]
+                .view(1, 1, -1)
+                .to(dtype=z.dtype, device=z.device)
+            )
+            outs.append(head(z + hb))
+        return outs
sopro/nn/ref.py
ADDED

@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .blocks import RMSNorm
+
+
+def _rms_per_token(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    return torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
+
+
+class RefXAttnBlock(nn.Module):
+    def __init__(self, d_model: int, heads: int = 2, gmax: float = 0.35):
+        super().__init__()
+        assert d_model % heads == 0
+
+        self.d_model = int(d_model)
+        self.heads = int(heads)
+        self.head_dim = self.d_model // self.heads
+        self.gmax = float(gmax)
+
+        self.nq = RMSNorm(self.d_model)
+        self.nkv = RMSNorm(self.d_model)
+
+        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+
+        self.gate = nn.Parameter(torch.tensor(0.0))
+
+    def _to_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, T, D = t.shape
+        return t.view(B, T, self.heads, self.head_dim).transpose(1, 2)
+
+    def _from_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, H, T, Hd = t.shape
+        return t.transpose(1, 2).contiguous().view(B, T, H * Hd)
+
+    def build_kv_cache(
+        self,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        kv = self.nkv(context)
+        k = self._to_heads(self.k_proj(kv))
+        v = self._to_heads(self.v_proj(kv))
+        return {"k": k, "v": v, "key_padding_mask": key_padding_mask}
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        context: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        kv_cache: Optional[Dict[str, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ):
+        q = self.nq(x)
+        q = self._to_heads(self.q_proj(q))
+
+        if kv_cache is None:
+            if context is None:
+                raise ValueError("context must be provided when kv_cache is None")
+            kv_cache = self.build_kv_cache(context, key_padding_mask=key_padding_mask)
+
+        k = kv_cache["k"]
+        v = kv_cache["v"]
+        kpm = kv_cache.get("key_padding_mask", None)
+
+        attn_bias = None
+        if kpm is not None:
+            kpm = kpm.to(torch.bool)
+            B = q.size(0)
+            S = k.size(-2)
+            attn_bias = torch.zeros((B, 1, 1, S), device=q.device, dtype=torch.float32)
+            attn_bias = attn_bias.masked_fill(kpm[:, None, None, :], float("-inf"))
+
+            bad = kpm.all(dim=1)
+            if bad.any():
+                attn_bias = attn_bias.clone()
+                attn_bias[bad, :, :, 0] = 0.0
+
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            a = F.scaled_dot_product_attention(
+                q.float(),
+                k.float(),
+                v.float(),
+                attn_mask=attn_bias,
+                dropout_p=0.0,
+                is_causal=False,
+            )
+
+        a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
+        a = self._from_heads(a)
+
+        scale = (_rms_per_token(x) / _rms_per_token(a)).clamp(0.0, 10.0)
+        a = (a * scale).to(x.dtype)
+
+        a = self.out_proj(a)
+
+        gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
+        y = x + gate_eff * a
+        return (y, kv_cache) if use_cache else y
+
+
+class RefXAttnStack(nn.Module):
+    def __init__(
+        self, d_model: int, heads: int = 2, layers: int = 3, gmax: float = 0.35
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [RefXAttnBlock(d_model, heads=heads, gmax=gmax) for _ in range(int(layers))]
+        )
+
+    def build_kv_caches(
+        self,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> List[Dict[str, torch.Tensor]]:
+        return [
+            blk.build_kv_cache(context, key_padding_mask=key_padding_mask)
+            for blk in self.blocks
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        context: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        kv_caches: Optional[List[Dict[str, torch.Tensor]]] = None,
+        use_cache: bool = False,
+    ):
+        if use_cache:
+            if kv_caches is None:
+                if context is None:
+                    raise ValueError("context must be provided when kv_caches is None")
+                kv_caches = self.build_kv_caches(
+                    context, key_padding_mask=key_padding_mask
+                )
+
+            assert kv_caches is not None and len(kv_caches) == len(self.blocks)
+            new_caches: List[Dict[str, torch.Tensor]] = []
+            h = x
+            for blk, cache in zip(self.blocks, kv_caches):
+                h, cache2 = blk(h, kv_cache=cache, use_cache=True)
+                new_caches.append(cache2)
+            return h, new_caches
+
+        if context is None:
+            raise ValueError("context must be provided when use_cache=False")
+        h = x
+        for blk in self.blocks:
+            h = blk(h, context=context, key_padding_mask=key_padding_mask)
+        return h
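`RefXAttnStack` attends a running stream to a fixed reference context through gated, RMS-rescaled cross-attention; because the reference never changes during decoding, the K/V projections can be computed once with `build_kv_caches` and reused on every step. A small usage sketch with toy tensors (module arguments follow the constructors above; `key_padding_mask` marks padded reference positions with True):

```python
import torch
from sopro.nn.ref import RefXAttnStack

stack = RefXAttnStack(d_model=64, heads=2, layers=3, gmax=0.35).eval()

ref = torch.randn(2, 40, 64)                # reference context [B, S, D]
kpm = torch.zeros(2, 40, dtype=torch.bool)  # True = padded reference frame
kpm[1, 30:] = True

with torch.no_grad():
    # Whole-sequence path: K/V are built internally from `context`.
    x = torch.randn(2, 12, 64)
    y = stack(x, context=ref, key_padding_mask=kpm)

    # Streaming path: build the per-layer caches once, then reuse them per frame.
    caches = stack.build_kv_caches(ref, key_padding_mask=kpm)
    for _ in range(3):
        frame = torch.randn(2, 1, 64)
        out, caches = stack(frame, kv_caches=caches, use_cache=True)
```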
sopro/nn/speaker.py
CHANGED

@@ -11,7 +11,7 @@ from .blocks import AttentiveStatsPool, DepthwiseConv1d
 
 class Token2SV(nn.Module):
     def __init__(
-        self, Q: int, V: int, d: int = 192, out_dim: int =
+        self, Q: int, V: int, d: int = 192, out_dim: int = 192, dropout: float = 0.05
     ):
         super().__init__()
         self.Q, self.V = int(Q), int(V)
@@ -27,7 +27,6 @@ class Token2SV(nn.Module):
             DepthwiseConv1d(d, 7, causal=False),
             nn.GELU(),
         )
-
        self.pool = AttentiveStatsPool(d)
        self.proj = nn.Linear(2 * d, out_dim)
 
@@ -39,26 +38,24 @@
         self, tokens_btq: torch.Tensor, lengths: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         B, T, Q = tokens_btq.shape
-
-
-
+        device = tokens_btq.device
+
+        if lengths is not None:
+            valid = torch.arange(T, device=device)[None, :] < lengths[:, None]
+        else:
+            valid = torch.ones(B, T, device=device, dtype=torch.bool)
+
+        q_idx = torch.arange(Q, device=device, dtype=torch.long).view(1, 1, Q)
         idx = q_idx * self.V + tokens_btq.long()
         raw_emb = self.emb(idx)
-
-        if self.training:
-            keep_prob = 0.95
-            mask = torch.rand(B, T, device=tokens_btq.device) < keep_prob
-            bad = mask.sum(dim=1) == 0
-            if bad.any():
-                bad_idx = bad.nonzero(as_tuple=False).squeeze(1)
-                rand_pos = torch.randint(
-                    0, T, (bad_idx.numel(),), device=tokens_btq.device
-                )
-                mask[bad_idx, rand_pos] = True
-            raw_emb = raw_emb * mask.float().unsqueeze(-1).unsqueeze(-1)
+        raw_emb = raw_emb * valid[:, :, None, None].to(raw_emb.dtype)
 
         x = self._get_mixed_embedding(raw_emb)
+        x = x * valid[:, :, None].to(x.dtype)
+
         h = self.enc(x)
+        h = h * valid[:, :, None].to(h.dtype)
+
         pooled = self.pool(h, lengths=lengths)
         e = self.proj(pooled)
         return F.normalize(e, dim=-1, eps=1e-6)
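The Token2SV change drops the old training-time random token dropout in favor of deterministic length masking: a `[B, T]` validity mask derived from `lengths` zeroes padded frames before mixing, encoding, and pooling, so padding can no longer leak into the speaker embedding. The mask construction in isolation (toy shapes):

```python
import torch

B, T = 3, 10
lengths = torch.tensor([10, 7, 4])

# True where the frame is real, False where it is padding (same comparison as the diff).
valid = torch.arange(T)[None, :] < lengths[:, None]  # [B, T] bool

raw_emb = torch.randn(B, T, 4, 8)                    # [B, T, Q, d] token embeddings
raw_emb = raw_emb * valid[:, :, None, None].float()  # padded frames contribute nothing downstream
print(valid.int())
```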