sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sopro/nn/__init__.py CHANGED
@@ -1,7 +1,9 @@
 from .blocks import GLU, AttentiveStatsPool, DepthwiseConv1d, RMSNorm, SSMLiteBlock
 from .embeddings import CodebookEmbedding, SinusoidalPositionalEmbedding, TextEmbedding
+from .generator import ARRVQ1Generator
+from .ref import RefXAttnBlock, RefXAttnStack
 from .speaker import SpeakerFiLM, Token2SV
-from .xattn import RefXAttn, RefXAttnBlock, TextXAttnBlock
+from .text import TextEncoder, TextXAttnBlock
 
 __all__ = [
     "GLU",
@@ -14,7 +16,9 @@ __all__ = [
     "CodebookEmbedding",
     "Token2SV",
     "SpeakerFiLM",
-    "RefXAttn",
-    "RefXAttnBlock",
+    "TextEncoder",
     "TextXAttnBlock",
+    "RefXAttnBlock",
+    "RefXAttnStack",
+    "ARRVQ1Generator",
 ]
sopro/nn/blocks.py CHANGED
@@ -1,10 +1,18 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclass
+class DepthwiseConv1dState:
+    buf: torch.Tensor
+
+
 class GLU(nn.Module):
     def __init__(self, d: int):
         super().__init__()
@@ -39,6 +47,19 @@ class DepthwiseConv1d(nn.Module):
         self.kernel_size = kernel_size
         self.dw = nn.Conv1d(d, d, kernel_size, groups=d, padding=0, dilation=dilation)
 
+    def _ctx_len(self) -> int:
+        return (self.kernel_size - 1) * self.dilation + 1
+
+    def init_state(
+        self, batch_size: int, device: torch.device, dtype: torch.dtype
+    ) -> DepthwiseConv1dState:
+        if not self.causal:
+            raise ValueError("init_state is only valid for causal convs")
+        L = self._ctx_len()
+        D = int(self.dw.in_channels)
+        buf = torch.zeros((batch_size, L, D), device=device, dtype=dtype)
+        return DepthwiseConv1dState(buf=buf)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         xt = x.transpose(1, 2)
         if self.causal:
@@ -52,6 +73,42 @@ class DepthwiseConv1d(nn.Module):
         y = self.dw(xt)
         return y.transpose(1, 2)
 
+    def forward_step(
+        self, x_bt_d: torch.Tensor, state: Optional[DepthwiseConv1dState]
+    ) -> Tuple[torch.Tensor, DepthwiseConv1dState]:
+        if not self.causal:
+            raise ValueError("forward_step is only valid for causal convs")
+
+        if x_bt_d.dim() == 2:
+            x_bt_d = x_bt_d.unsqueeze(1)
+
+        B, T, D = x_bt_d.shape
+        if T != 1:
+            raise ValueError("forward_step expects a single timestep [B,1,D]")
+
+        if state is None:
+            state = self.init_state(B, x_bt_d.device, x_bt_d.dtype)
+
+        buf = state.buf
+        if buf.size(1) > 1:
+            buf = torch.cat([buf[:, 1:, :], x_bt_d], dim=1)
+        else:
+            buf = x_bt_d
+
+        k = int(self.kernel_size)
+        d = int(self.dilation)
+        idx = torch.arange(0, k * d, d, device=x_bt_d.device)
+        x_bkd = buf.index_select(1, idx)  # [B,k,D]
+
+        w_dk = self.dw.weight.squeeze(1).to(dtype=x_bt_d.dtype)
+        y_bd = (x_bkd.transpose(1, 2) * w_dk.unsqueeze(0)).sum(dim=-1)
+        if self.dw.bias is not None:
+            y_bd = y_bd + self.dw.bias.to(dtype=y_bd.dtype).unsqueeze(0)
+
+        y_bt_d = y_bd.unsqueeze(1)
+        state.buf = buf
+        return y_bt_d, state
+
 
 class SSMLiteBlock(nn.Module):
     def __init__(
@@ -76,6 +133,13 @@ class SSMLiteBlock(nn.Module):
         )
         self.drop = nn.Dropout(dropout)
 
+    def init_state(
+        self, batch_size: int, device: torch.device, dtype: torch.dtype
+    ) -> dict:
+        if not self.dw.causal:
+            raise ValueError("SSMLiteBlock.init_state only valid for causal blocks")
+        return {"dw": self.dw.init_state(batch_size, device, dtype)}
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = self.glu(self.norm(x))
         h = self.dw(h)
@@ -83,6 +147,20 @@ class SSMLiteBlock(nn.Module):
         x = x + self.drop(self.ff(x))
         return x
 
+    def forward_step(
+        self, x_bt_d: torch.Tensor, state: dict
+    ) -> Tuple[torch.Tensor, dict]:
+        if not self.dw.causal:
+            raise ValueError("forward_step only valid for causal blocks")
+
+        h = self.glu(self.norm(x_bt_d))
+        y, dw_state = self.dw.forward_step(h, state.get("dw", None))
+        state["dw"] = dw_state
+
+        x = x_bt_d + self.drop(y)
+        x = x + self.drop(self.ff(x))
+        return x, state
+
 
 class AttentiveStatsPool(nn.Module):
     def __init__(self, d: int):
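The forward_step path added here is meant to reproduce the causal forward output one frame at a time, using a rolling zero-initialized buffer instead of left padding. A minimal equivalence sketch, assuming the DepthwiseConv1d constructor accepts (d, kernel_size, causal=..., dilation=...) as the call sites in this diff suggest:

import torch

from sopro.nn.blocks import DepthwiseConv1d

conv = DepthwiseConv1d(8, 3, causal=True, dilation=2).eval()
x = torch.randn(2, 16, 8)  # [B, T, D]

with torch.no_grad():
    y_full = conv(x)  # whole sequence in one call
    state = None
    chunks = []
    for t in range(x.size(1)):  # one timestep at a time
        y_t, state = conv.forward_step(x[:, t : t + 1, :], state)
        chunks.append(y_t)
    y_step = torch.cat(chunks, dim=1)

# Should agree up to float tolerance, provided the causal forward left-pads with
# zeros, which is what the zero-initialized streaming buffer assumes.
print(torch.allclose(y_full, y_step, atol=1e-5))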
sopro/nn/embeddings.py CHANGED
@@ -79,6 +79,7 @@ class CodebookEmbedding(nn.Module):
         tokens_subset: Optional[torch.Tensor],
         cb_indices: Optional[List[int]],
         keep_mask: Optional[torch.Tensor] = None,
+        cb_weights: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if tokens_subset is None or cb_indices is None or len(cb_indices) == 0:
             return 0.0
@@ -90,6 +91,21 @@ class CodebookEmbedding(nn.Module):
         idx = torch.cat(idx_list, dim=2)
         emb = self.emb(idx)
 
+        if cb_weights is not None:
+            w = cb_weights
+            if w.dim() != 1:
+                raise ValueError("cb_weights must be 1D")
+            if w.numel() == self.Q:
+                cb_t = torch.tensor(cb_indices, device=emb.device, dtype=torch.long)
+                w = w.to(emb.device).index_select(0, cb_t)
+            elif w.numel() != K:
+                raise ValueError(
+                    f"cb_weights must have len Q={self.Q} or K={K}, got {w.numel()}"
+                )
+
+            w = F.softmax(w.float(), dim=0).to(dtype=emb.dtype)
+            emb = emb * w.view(1, 1, K, 1)
+
         if keep_mask is not None:
             emb = emb * keep_mask.unsqueeze(-1).to(emb.dtype)
 
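The new cb_weights argument takes a 1D tensor of logits, either one per codebook in the model (length Q, gathered through cb_indices) or one per selected codebook (length K); they are softmax-normalized and broadcast over the per-codebook embeddings. A standalone sketch of that weighting with made-up shapes:

import torch
import torch.nn.functional as F

B, T, K, d = 2, 5, 3, 16
emb = torch.randn(B, T, K, d)  # per-codebook embeddings for the K selected codebooks
cb_weights = torch.tensor([2.0, 0.0, -1.0])  # one logit per selected codebook

w = F.softmax(cb_weights.float(), dim=0).to(dtype=emb.dtype)
weighted = emb * w.view(1, 1, K, 1)  # codebooks with lower logits contribute less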
sopro/nn/generator.py ADDED
@@ -0,0 +1,130 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from sopro.config import SoproTTSConfig
+from sopro.nn.blocks import RMSNorm, SSMLiteBlock
+from sopro.nn.text import TextXAttnBlock
+
+
+class ARRVQ1Generator(nn.Module):
+    def __init__(self, cfg: SoproTTSConfig, d_model: int, vocab: int):
+        super().__init__()
+        ks = int(cfg.ar_kernel)
+
+        dils: List[int] = []
+        while len(dils) < int(cfg.n_layers_ar):
+            dils.extend(list(cfg.ar_dilation_cycle))
+        dils = dils[: int(cfg.n_layers_ar)]
+        self.dils = tuple(int(d) for d in dils)
+
+        self.blocks = nn.ModuleList(
+            [
+                SSMLiteBlock(
+                    d_model, cfg.dropout, causal=True, kernel_size=ks, dilation=d
+                )
+                for d in self.dils
+            ]
+        )
+
+        self.attn_freq = int(cfg.ar_text_attn_freq)
+        self.x_attns = nn.ModuleList()
+        for i in range(len(self.blocks)):
+            if (i + 1) % self.attn_freq == 0:
+                self.x_attns.append(
+                    TextXAttnBlock(d_model, heads=4, dropout=cfg.dropout)
+                )
+            else:
+                self.x_attns.append(nn.Identity())
+
+        self.norm = RMSNorm(d_model)
+        self.head = nn.Linear(d_model, vocab)
+
+    @torch.no_grad()
+    def init_stream_state(
+        self,
+        batch_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        *,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, object]:
+        layer_states = [
+            blk.init_state(batch_size, device, dtype) for blk in self.blocks
+        ]
+
+        kv_caches: List[Optional[Dict[str, torch.Tensor]]] = []
+        key_padding_mask = (~text_mask) if text_mask is not None else None
+        for xa in self.x_attns:
+            if isinstance(xa, nn.Identity) or (text_emb is None):
+                kv_caches.append(None)
+            else:
+                kv_caches.append(
+                    xa.build_kv_cache(text_emb, key_padding_mask=key_padding_mask)
+                )
+
+        return {"layer_states": layer_states, "kv_caches": kv_caches}
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        key_padding_mask = ~text_mask if text_mask is not None else None
+
+        if key_padding_mask is not None:
+            bad_rows = key_padding_mask.all(dim=1)
+            if bad_rows.any():
+                key_padding_mask = key_padding_mask.clone()
+                idx = torch.nonzero(bad_rows, as_tuple=False).squeeze(1)
+                key_padding_mask[idx, 0] = False
+                if text_emb is not None:
+                    text_emb = text_emb.clone()
+                    text_emb[idx, 0, :] = 0
+
+        h = x
+        for i, lyr in enumerate(self.blocks):
+            h = lyr(h)
+            if not isinstance(self.x_attns[i], nn.Identity) and text_emb is not None:
+                h = self.x_attns[i](h, text_emb, key_padding_mask=key_padding_mask)
+
+        h = self.norm(h)
+
+        return self.head(h)
+
+    @torch.no_grad()
+    def step(
+        self,
+        x_bt_d: torch.Tensor,
+        state: Dict[str, object],
+        *,
+        text_emb: Optional[torch.Tensor] = None,
+        text_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, object]]:
+        h = x_bt_d
+        key_padding_mask = (~text_mask) if text_mask is not None else None
+
+        layer_states: List[dict] = state["layer_states"]
+        kv_caches: List[Optional[Dict[str, torch.Tensor]]] = state["kv_caches"]
+
+        for i, blk in enumerate(self.blocks):
+            h, layer_states[i] = blk.forward_step(h, layer_states[i])
+
+            xa = self.x_attns[i]
+            if (not isinstance(xa, nn.Identity)) and (text_emb is not None):
+                kv = kv_caches[i]
+                if kv is None:
+                    kv = xa.build_kv_cache(text_emb, key_padding_mask=key_padding_mask)
+                h, kv = xa(h, kv_cache=kv, use_cache=True)
+                kv_caches[i] = kv
+
+        state["layer_states"] = layer_states
+        state["kv_caches"] = kv_caches
+
+        h = self.norm(h)
+        logits = self.head(h)
+
+        return logits, state
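The init_stream_state/step pair is the streaming counterpart of forward: per-layer conv buffers plus per-cross-attention KV caches built once from the text encoding. A rough decode-loop sketch; the config defaults, sizes, and the stand-in token embedding below are assumptions, not taken from the package:

import torch
import torch.nn as nn

from sopro.config import SoproTTSConfig
from sopro.nn.generator import ARRVQ1Generator

cfg = SoproTTSConfig()  # assumes defaults exist for the ar_* fields used above
d_model, vocab = 512, 1024  # illustrative sizes only
gen = ARRVQ1Generator(cfg, d_model, vocab).eval()
tok_embed = nn.Embedding(vocab, d_model)  # stand-in for the model's real frame embedding

text_emb = torch.randn(1, 20, d_model)  # [B, S, d_model] encoded text
text_mask = torch.ones(1, 20, dtype=torch.bool)  # True = valid text position

state = gen.init_stream_state(
    1, text_emb.device, text_emb.dtype, text_emb=text_emb, text_mask=text_mask
)

tok = torch.zeros(1, dtype=torch.long)  # illustrative start token id
generated = []
for _ in range(50):
    x_t = tok_embed(tok).unsqueeze(1)  # [B, 1, d_model]
    logits, state = gen.step(x_t, state, text_emb=text_emb, text_mask=text_mask)
    tok = logits[:, -1].argmax(dim=-1)  # greedy choice over the RVQ-1 vocabulary
    generated.append(tok)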
sopro/nn/nar.py ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+
+from sopro.config import SoproTTSConfig
+
+from .blocks import RMSNorm, SSMLiteBlock
+
+
+class NARStageAdapter(nn.Module):
+    def __init__(self, d_model: int, hidden: int = 256):
+        super().__init__()
+        self.norm = RMSNorm(d_model)
+        self.mlp = nn.Sequential(
+            nn.Linear(d_model, hidden),
+            nn.GELU(),
+            nn.Linear(hidden, 2 * d_model),
+        )
+        nn.init.zeros_(self.mlp[-1].weight)
+        nn.init.zeros_(self.mlp[-1].bias)
+
+    def forward(self, x: torch.Tensor, stage_vec: torch.Tensor) -> torch.Tensor:
+        if stage_vec.dim() == 1:
+            stage_vec = stage_vec.unsqueeze(0).expand(x.size(0), -1)
+        g, b = self.mlp(stage_vec).chunk(2, dim=-1)
+        g = g.unsqueeze(1)
+        b = b.unsqueeze(1)
+        x = self.norm(x)
+        return x * (1 + torch.tanh(g)) + torch.tanh(b)
+
+
+class NARSinglePass(nn.Module):
+    def __init__(
+        self, cfg: SoproTTSConfig, d_model: int, stage_specs: Dict[str, List[int]]
+    ):
+        super().__init__()
+        self.cfg = cfg
+        self.stage_names = [
+            s for s in ["B", "C", "D", "E"] if len(stage_specs.get(s, [])) > 0
+        ]
+        self.stage_to_id = {s: i for i, s in enumerate(self.stage_names)}
+        self.stage_specs = {s: stage_specs[s] for s in self.stage_names}
+
+        ks = int(cfg.nar_kernel_size)
+        cycle = tuple(int(x) for x in cfg.nar_dilation_cycle) or (1,)
+        dils: List[int] = []
+        while len(dils) < int(cfg.n_layers_nar):
+            dils.extend(cycle)
+        dils = dils[: int(cfg.n_layers_nar)]
+
+        self.blocks = nn.ModuleList(
+            [
+                SSMLiteBlock(
+                    d_model, cfg.dropout, causal=False, kernel_size=ks, dilation=int(d)
+                )
+                for d in dils
+            ]
+        )
+        self.norm = RMSNorm(d_model)
+        self.pre = nn.Linear(d_model, int(cfg.nar_head_dim))
+
+        self.stage_emb = nn.Embedding(len(self.stage_names), d_model)
+        self.adapter = NARStageAdapter(d_model, hidden=256)
+
+        self.heads = nn.ModuleDict()
+        self.head_id_emb = nn.ModuleDict()
+        for s in self.stage_names:
+            n_heads = len(self.stage_specs[s])
+            self.heads[s] = nn.ModuleList(
+                [
+                    nn.Linear(int(cfg.nar_head_dim), int(cfg.codebook_size))
+                    for _ in range(n_heads)
+                ]
+            )
+            emb = nn.Embedding(n_heads, int(cfg.nar_head_dim))
+            nn.init.zeros_(emb.weight)
+            self.head_id_emb[s] = emb
+
+        self.mix = nn.ParameterDict(
+            {
+                s: nn.Parameter(torch.zeros(2, dtype=torch.float32))
+                for s in self.stage_names
+            }
+        )
+
+    def forward_stage(
+        self, stage: str, cond: torch.Tensor, prev_emb: torch.Tensor
+    ) -> List[torch.Tensor]:
+        if stage not in self.heads:
+            return []
+
+        w = torch.softmax(self.mix[stage], dim=0)
+        x = w[0] * cond + w[1] * prev_emb
+
+        sid = self.stage_to_id[stage]
+        stage_vec = self.stage_emb.weight[sid]
+        x = self.adapter(x, stage_vec)
+
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        z = self.pre(x)
+        outs: List[torch.Tensor] = []
+        for i, head in enumerate(self.heads[stage]):
+            hb = (
+                self.head_id_emb[stage]
+                .weight[i]
+                .view(1, 1, -1)
+                .to(dtype=z.dtype, device=z.device)
+            )
+            outs.append(head(z + hb))
+        return outs
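NARSinglePass.forward_stage blends the conditioning sequence with the embedding of already-decoded codebooks through a learned two-way softmax mix, applies the stage adapter, runs the shared non-causal trunk, and returns one logit tensor per head in the stage. A call sketch with illustrative shapes and a hand-rolled stage_specs (the real values come from the package config, which this diff does not show):

import torch

from sopro.config import SoproTTSConfig
from sopro.nn.nar import NARSinglePass

cfg = SoproTTSConfig()  # assumes defaults for the nar_* and codebook_size fields
d_model = 512
stage_specs = {"B": [1, 2], "C": [3, 4, 5]}  # codebook indices per stage (illustrative)

nar = NARSinglePass(cfg, d_model, stage_specs).eval()

cond = torch.randn(2, 100, d_model)      # conditioning features, [B, T, d_model]
prev_emb = torch.randn(2, 100, d_model)  # embedding of codebooks decoded so far

logits = nar.forward_stage("B", cond, prev_emb)
# -> list of len(stage_specs["B"]) tensors, each [B, T, cfg.codebook_size]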
sopro/nn/ref.py ADDED
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .blocks import RMSNorm
+
+
+def _rms_per_token(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    return torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
+
+
+class RefXAttnBlock(nn.Module):
+    def __init__(self, d_model: int, heads: int = 2, gmax: float = 0.35):
+        super().__init__()
+        assert d_model % heads == 0
+
+        self.d_model = int(d_model)
+        self.heads = int(heads)
+        self.head_dim = self.d_model // self.heads
+        self.gmax = float(gmax)
+
+        self.nq = RMSNorm(self.d_model)
+        self.nkv = RMSNorm(self.d_model)
+
+        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+
+        self.gate = nn.Parameter(torch.tensor(0.0))
+
+    def _to_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, T, D = t.shape
+        return t.view(B, T, self.heads, self.head_dim).transpose(1, 2)
+
+    def _from_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, H, T, Hd = t.shape
+        return t.transpose(1, 2).contiguous().view(B, T, H * Hd)
+
+    def build_kv_cache(
+        self,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        kv = self.nkv(context)
+        k = self._to_heads(self.k_proj(kv))
+        v = self._to_heads(self.v_proj(kv))
+        return {"k": k, "v": v, "key_padding_mask": key_padding_mask}
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        context: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        kv_cache: Optional[Dict[str, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ):
+        q = self.nq(x)
+        q = self._to_heads(self.q_proj(q))
+
+        if kv_cache is None:
+            if context is None:
+                raise ValueError("context must be provided when kv_cache is None")
+            kv_cache = self.build_kv_cache(context, key_padding_mask=key_padding_mask)
+
+        k = kv_cache["k"]
+        v = kv_cache["v"]
+        kpm = kv_cache.get("key_padding_mask", None)
+
+        attn_bias = None
+        if kpm is not None:
+            kpm = kpm.to(torch.bool)
+            B = q.size(0)
+            S = k.size(-2)
+            attn_bias = torch.zeros((B, 1, 1, S), device=q.device, dtype=torch.float32)
+            attn_bias = attn_bias.masked_fill(kpm[:, None, None, :], float("-inf"))
+
+            bad = kpm.all(dim=1)
+            if bad.any():
+                attn_bias = attn_bias.clone()
+                attn_bias[bad, :, :, 0] = 0.0
+
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            a = F.scaled_dot_product_attention(
+                q.float(),
+                k.float(),
+                v.float(),
+                attn_mask=attn_bias,
+                dropout_p=0.0,
+                is_causal=False,
+            )
+
+        a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
+        a = self._from_heads(a)
+
+        scale = (_rms_per_token(x) / _rms_per_token(a)).clamp(0.0, 10.0)
+        a = (a * scale).to(x.dtype)
+
+        a = self.out_proj(a)
+
+        gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
+        y = x + gate_eff * a
+        return (y, kv_cache) if use_cache else y
+
+
+class RefXAttnStack(nn.Module):
+    def __init__(
+        self, d_model: int, heads: int = 2, layers: int = 3, gmax: float = 0.35
+    ):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [RefXAttnBlock(d_model, heads=heads, gmax=gmax) for _ in range(int(layers))]
+        )
+
+    def build_kv_caches(
+        self,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> List[Dict[str, torch.Tensor]]:
+        return [
+            blk.build_kv_cache(context, key_padding_mask=key_padding_mask)
+            for blk in self.blocks
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        context: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        kv_caches: Optional[List[Dict[str, torch.Tensor]]] = None,
+        use_cache: bool = False,
+    ):
+        if use_cache:
+            if kv_caches is None:
+                if context is None:
+                    raise ValueError("context must be provided when kv_caches is None")
+                kv_caches = self.build_kv_caches(
+                    context, key_padding_mask=key_padding_mask
+                )
+
+            assert kv_caches is not None and len(kv_caches) == len(self.blocks)
+            new_caches: List[Dict[str, torch.Tensor]] = []
+            h = x
+            for blk, cache in zip(self.blocks, kv_caches):
+                h, cache2 = blk(h, kv_cache=cache, use_cache=True)
+                new_caches.append(cache2)
+            return h, new_caches
+
+        if context is None:
+            raise ValueError("context must be provided when use_cache=False")
+        h = x
+        for blk in self.blocks:
+            h = blk(h, context=context, key_padding_mask=key_padding_mask)
+        return h
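RefXAttnStack supports two call patterns: a plain pass that re-projects the reference context inside every block, and a cached pass where build_kv_caches projects keys/values once so repeated queries (e.g. during autoregressive decoding) can reuse them. A sketch with illustrative shapes:

import torch

from sopro.nn.ref import RefXAttnStack

stack = RefXAttnStack(d_model=256, heads=2, layers=3).eval()

x = torch.randn(2, 50, 256)    # decoder states, [B, T, D]
ref = torch.randn(2, 120, 256) # reference features, [B, S, D]
ref_pad = torch.zeros(2, 120, dtype=torch.bool)  # True marks padded reference frames

# Plain pass: K/V are projected inside every block.
y = stack(x, context=ref, key_padding_mask=ref_pad)

# Cached pass: project K/V once, then reuse across steps.
caches = stack.build_kv_caches(ref, key_padding_mask=ref_pad)
y_step, caches = stack(x[:, -1:, :], kv_caches=caches, use_cache=True)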
sopro/nn/speaker.py CHANGED
@@ -11,7 +11,7 @@ from .blocks import AttentiveStatsPool, DepthwiseConv1d
 
 class Token2SV(nn.Module):
     def __init__(
-        self, Q: int, V: int, d: int = 192, out_dim: int = 256, dropout: float = 0.05
+        self, Q: int, V: int, d: int = 192, out_dim: int = 192, dropout: float = 0.05
     ):
         super().__init__()
         self.Q, self.V = int(Q), int(V)
@@ -27,7 +27,6 @@ class Token2SV(nn.Module):
             DepthwiseConv1d(d, 7, causal=False),
             nn.GELU(),
         )
-
        self.pool = AttentiveStatsPool(d)
        self.proj = nn.Linear(2 * d, out_dim)
 
@@ -39,26 +38,24 @@ class Token2SV(nn.Module):
         self, tokens_btq: torch.Tensor, lengths: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         B, T, Q = tokens_btq.shape
-        q_idx = torch.arange(Q, device=tokens_btq.device, dtype=torch.long).view(
-            1, 1, Q
-        )
+        device = tokens_btq.device
+
+        if lengths is not None:
+            valid = torch.arange(T, device=device)[None, :] < lengths[:, None]
+        else:
+            valid = torch.ones(B, T, device=device, dtype=torch.bool)
+
+        q_idx = torch.arange(Q, device=device, dtype=torch.long).view(1, 1, Q)
         idx = q_idx * self.V + tokens_btq.long()
         raw_emb = self.emb(idx)
-
-        if self.training:
-            keep_prob = 0.95
-            mask = torch.rand(B, T, device=tokens_btq.device) < keep_prob
-            bad = mask.sum(dim=1) == 0
-            if bad.any():
-                bad_idx = bad.nonzero(as_tuple=False).squeeze(1)
-                rand_pos = torch.randint(
-                    0, T, (bad_idx.numel(),), device=tokens_btq.device
-                )
-                mask[bad_idx, rand_pos] = True
-            raw_emb = raw_emb * mask.float().unsqueeze(-1).unsqueeze(-1)
+        raw_emb = raw_emb * valid[:, :, None, None].to(raw_emb.dtype)
 
         x = self._get_mixed_embedding(raw_emb)
+        x = x * valid[:, :, None].to(x.dtype)
+
         h = self.enc(x)
+        h = h * valid[:, :, None].to(h.dtype)
+
         pooled = self.pool(h, lengths=lengths)
         e = self.proj(pooled)
         return F.normalize(e, dim=-1, eps=1e-6)
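For reference, the edited method takes padded token batches together with per-utterance lengths; the validity mask derived from lengths zeros out padded frames before mixing, encoding, and pooling. A call sketch, assuming the hunk above belongs to the module's forward (the def line sits outside this hunk) and using illustrative Q/V values:

import torch

from sopro.nn.speaker import Token2SV

t2sv = Token2SV(Q=8, V=1024).eval()  # out_dim now defaults to 192

tokens = torch.randint(0, 1024, (2, 300, 8))  # [B, T, Q] codec tokens, right-padded
lengths = torch.tensor([300, 180])            # valid frames per utterance

emb = t2sv(tokens, lengths=lengths)  # [B, 192], L2-normalized speaker embedding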