sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
sopro/nn/text.py ADDED
@@ -0,0 +1,132 @@
+ from __future__ import annotations
+
+ from typing import Dict, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from sopro.config import SoproTTSConfig
+ from sopro.nn.embeddings import SinusoidalPositionalEmbedding, TextEmbedding
+ from sopro.tokenizer import TextTokenizer
+
+ from .blocks import RMSNorm, SSMLiteBlock
+
+
+ class TextEncoder(nn.Module):
+     def __init__(
+         self, cfg: SoproTTSConfig, d_model: int, n_layers: int, tokenizer: TextTokenizer
+     ):
+         super().__init__()
+         self.tok = tokenizer
+         self.embed = TextEmbedding(self.tok.vocab_size, d_model)
+         self.layers = nn.ModuleList(
+             [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
+         )
+         self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
+         self.norm = RMSNorm(d_model)
+
+     def forward(
+         self, text_ids: torch.Tensor, mask: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         x = self.embed(text_ids)
+         L = x.size(1)
+         pos = self.pos(torch.arange(L, device=x.device))
+         x = x + pos.unsqueeze(0)
+
+         x = x * mask.unsqueeze(-1).float()
+         for layer in self.layers:
+             x = layer(x)
+         x = self.norm(x)
+
+         mask_f = mask.float().unsqueeze(-1)
+         pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
+         return x, pooled
+
+
+ class TextXAttnBlock(nn.Module):
+     def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
+         super().__init__()
+         assert d_model % heads == 0
+
+         self.d_model = int(d_model)
+         self.heads = int(heads)
+         self.head_dim = self.d_model // self.heads
+         self.dropout = float(dropout)
+
+         self.nq = RMSNorm(self.d_model)
+         self.nkv = RMSNorm(self.d_model)
+
+         self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+         self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+         self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+         self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+
+         self.gate = nn.Parameter(torch.tensor(0.0))
+
+     def _to_heads(self, t: torch.Tensor) -> torch.Tensor:
+         B, T, D = t.shape
+         return t.view(B, T, self.heads, self.head_dim).transpose(1, 2)
+
+     def _from_heads(self, t: torch.Tensor) -> torch.Tensor:
+         B, H, T, Hd = t.shape
+         return t.transpose(1, 2).contiguous().view(B, T, H * Hd)
+
+     def build_kv_cache(
+         self,
+         context: torch.Tensor,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> Dict[str, torch.Tensor]:
+         kv = self.nkv(context)
+         k = self._to_heads(self.k_proj(kv))
+         v = self._to_heads(self.v_proj(kv))
+         return {"k": k, "v": v, "key_padding_mask": key_padding_mask}
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: Optional[torch.Tensor] = None,
+         key_padding_mask: Optional[torch.Tensor] = None,
+         kv_cache: Optional[Dict[str, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ):
+         q = self.nq(x)
+         q = self._to_heads(self.q_proj(q))
+
+         if kv_cache is None:
+             if context is None:
+                 raise ValueError("context must be provided when kv_cache is None")
+             kv_cache = self.build_kv_cache(context, key_padding_mask=key_padding_mask)
+
+         k = kv_cache["k"]
+         v = kv_cache["v"]
+         kpm = kv_cache.get("key_padding_mask", None)
+
+         attn_mask = None
+         if kpm is not None:
+             kpm = kpm.to(torch.bool)
+
+             keep = ~kpm
+
+             bad = ~keep.any(dim=1)
+             if bad.any():
+                 keep = keep.clone()
+                 keep[bad, 0] = True
+
+             attn_mask = keep[:, None, None, :]
+
+         with torch.autocast(device_type=x.device.type, enabled=False):
+             a = F.scaled_dot_product_attention(
+                 q.float(),
+                 k.float(),
+                 v.float(),
+                 attn_mask=attn_mask,
+                 dropout_p=self.dropout if self.training else 0.0,
+                 is_causal=False,
+             )
+
+         a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
+         a = self.out_proj(self._from_heads(a))
+
+         y = x + torch.tanh(self.gate) * a
+         return (y, kv_cache) if use_cache else y
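The new `build_kv_cache`/`kv_cache` path lets a caller project the text context into keys and values once and reuse them at every autoregressive step, instead of re-projecting the full context per frame. A minimal usage sketch, assuming sopro 1.5.0 is installed; `d_model=256` and the random tensors are illustrative, not the model's real configuration:

```python
import torch

from sopro.nn.text import TextXAttnBlock

block = TextXAttnBlock(d_model=256, heads=4).eval()

text_ctx = torch.randn(1, 50, 256)          # encoded text, (B, L_text, D)
pad = torch.zeros(1, 50, dtype=torch.bool)  # True marks padded positions

# Project the context to K/V once, up front.
cache = block.build_kv_cache(text_ctx, key_padding_mask=pad)

# Reuse the cache for each decoder frame; no context re-projection.
x = torch.randn(1, 1, 256)                  # one frame of decoder state
with torch.no_grad():
    y, cache = block(x, kv_cache=cache, use_cache=True)
print(y.shape)  # torch.Size([1, 1, 256])
```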
sopro/sampling.py CHANGED
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from typing import List
+ from typing import List, Tuple

  import torch

@@ -97,5 +97,5 @@ def rf_ar(ar_kernel: int, dilations: Tuple[int, ...]) -> int:
      return 1 + (ar_kernel - 1) * int(sum(dilations))


- def rf_nar(n_layers_nar: int, kernel_size: int = 7, dilation: int = 1) -> int:
-     return 1 + (kernel_size - 1) * int(n_layers_nar) * int(dilation)
+ def rf_nar(kernel_size: int, dilations: Tuple[int, ...]) -> int:
+     return 1 + (kernel_size - 1) * int(sum(dilations))
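Both helpers now share the same receptive-field formula for a dilated conv stack: each layer with kernel `k` and dilation `d` widens the field by `(k - 1) * d` frames, so the total is `1 + (k - 1) * sum(dilations)`. A quick sanity check, with illustrative dilations rather than the model's actual configuration:

```python
from sopro.sampling import rf_nar

# kernel_size=7, dilations (1, 2, 4):
# 1 + (7 - 1) * (1 + 2 + 4) = 1 + 6 * 7 = 43 frames
assert rf_nar(7, (1, 2, 4)) == 43
```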
sopro/streaming.py CHANGED
@@ -1,20 +1,18 @@
  from __future__ import annotations

- import time
  from dataclasses import dataclass
- from typing import Iterator, List, Optional, Tuple
+ from typing import Iterator, List, Optional

  import torch

  from .codec.mimi import MimiDecodeState, MimiStreamDecoder
- from .model import SoproTTS, SoproTTSModel
+ from .model import PreparedReference, SoproTTS, SoproTTSModel


  @dataclass
  class StreamConfig:
      chunk_frames: int = 16
      nar_context_frames: Optional[int] = None
-     cond_chunk_size: int = 32


  class SoproTTSStreamer:
@@ -30,33 +28,30 @@ class SoproTTSStreamer:
          *,
          ref_audio_path: Optional[str] = None,
          ref_tokens_tq: Optional[torch.Tensor] = None,
+         ref: Optional[PreparedReference] = None,
          max_frames: int = 400,
          top_p: float = 0.9,
          temperature: float = 1.05,
          anti_loop: bool = True,
-         use_prefix: bool = True,
-         prefix_sec_fixed: Optional[float] = None,
          style_strength: Optional[float] = None,
          ref_seconds: Optional[float] = None,
          chunk_frames: Optional[int] = None,
          nar_context_frames: Optional[int] = None,
-         cond_chunk_size: Optional[int] = None,
-         use_stop_head: Optional[bool] = None,
-         stop_patience: Optional[int] = None,
-         stop_threshold: Optional[float] = None,
          min_gen_frames: Optional[int] = None,
      ) -> Iterator[torch.Tensor]:
          model: SoproTTSModel = self.tts.model
          device = self.tts.device

          text_ids = self.tts.encode_text(text)
-         ref = self.tts.encode_reference(
-             ref_audio_path=ref_audio_path,
-             ref_tokens_tq=ref_tokens_tq,
-             ref_seconds=ref_seconds,
-         )

-         prep = model.prepare_conditioning_lazy(
+         if ref is None:
+             ref = self.tts.prepare_reference(
+                 ref_audio_path=ref_audio_path,
+                 ref_tokens_tq=ref_tokens_tq,
+                 ref_seconds=ref_seconds,
+             )
+
+         prep = model.prepare_conditioning(
              text_ids,
              ref,
              max_frames=max_frames,
@@ -69,9 +64,6 @@ class SoproTTSStreamer:
          )

          cf = int(chunk_frames if chunk_frames is not None else self.cfg.chunk_frames)
-         cond_cs = int(
-             cond_chunk_size if cond_chunk_size is not None else self.cfg.cond_chunk_size
-         )

          nar_ctx = (
              nar_context_frames
@@ -83,14 +75,11 @@ class SoproTTSStreamer:
          nar_ctx = int(nar_ctx)

          hist_A: List[int] = []
-
          frames_emitted = 0
-
          mimi_state = MimiDecodeState()

          def refine_and_emit(end: int) -> Optional[torch.Tensor]:
              nonlocal frames_emitted, mimi_state
-
              new_start = frames_emitted
              if end <= new_start:
                  return None
@@ -98,7 +87,7 @@ class SoproTTSStreamer:
              win_start = max(0, new_start - nar_ctx)
              win_end = end

-             cond_win = prep["cond_all"][:, win_start:win_end, :]
+             cond_win = prep["cond_ar"][:, win_start:win_end, :]
              tokens_A_win = torch.as_tensor(
                  hist_A[win_start:win_end], device=device, dtype=torch.long
              ).unsqueeze(0)
@@ -114,30 +103,25 @@ class SoproTTSStreamer:
              frames_emitted = end
              return wav_chunk if wav_chunk.numel() > 0 else None

-         for _t, rvq1_id, _p_stop in model.ar_stream(
+         for _t, tok, is_eos in model.ar_stream(
              prep,
              max_frames=max_frames,
              top_p=top_p,
              temperature=temperature,
              anti_loop=anti_loop,
-             use_prefix=use_prefix,
-             prefix_sec_fixed=prefix_sec_fixed,
-             cond_chunk_size=cond_cs,
-             use_stop_head=use_stop_head,
-             stop_patience=stop_patience,
-             stop_threshold=stop_threshold,
              min_gen_frames=min_gen_frames,
          ):
-             hist_A.append(int(rvq1_id))
-             T = len(hist_A)
+             if is_eos:
+                 break

-             is_boundary = (T % cf) == 0
-             if not is_boundary and T < max_frames:
-                 continue
+             hist_A.append(int(tok))
+             T = len(hist_A)

-             wav = refine_and_emit(T)
-             if wav is not None:
-                 yield wav
+             boundary = (T % cf) == 0
+             if boundary:
+                 wav = refine_and_emit(T)
+                 if wav is not None:
+                     yield wav

          T_final = len(hist_A)
          if frames_emitted < T_final:
@@ -146,12 +130,14 @@ class SoproTTSStreamer:
              yield wav


+ @torch.inference_mode()
  def stream(
      tts: SoproTTS,
      text: str,
      *,
      ref_audio_path: Optional[str] = None,
      ref_tokens_tq: Optional[torch.Tensor] = None,
+     ref: Optional[PreparedReference] = None,
      chunk_frames: int = 6,
      **kwargs,
  ) -> Iterator[torch.Tensor]:
@@ -160,6 +146,7 @@ def stream(
          text,
          ref_audio_path=ref_audio_path,
          ref_tokens_tq=ref_tokens_tq,
+         ref=ref,
          chunk_frames=chunk_frames,
          **kwargs,
      )
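The chunked emission loop is easiest to see with concrete numbers: every `chunk_frames` AR tokens, `refine_and_emit` re-runs the NAR refiner over the new frames plus up to `nar_context_frames` of already-emitted left context. A plain-Python sketch of that windowing, with illustrative values rather than the library defaults:

```python
cf, nar_ctx = 16, 8           # chunk_frames and nar_context_frames (illustrative)
frames_emitted = 0

for T in range(1, 49):        # T = AR frames generated so far
    if T % cf == 0:           # chunk boundary reached
        win_start = max(0, frames_emitted - nar_ctx)
        print(f"emit [{frames_emitted}, {T}), refine window [{win_start}, {T})")
        frames_emitted = T

# emit [0, 16), refine window [0, 16)
# emit [16, 32), refine window [8, 32)
# emit [32, 48), refine window [24, 48)
```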
sopro-1.5.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sopro
- Version: 1.0.1
+ Version: 1.5.0
  Summary: A lightweight text-to-speech model with zero-shot voice cloning.
  Author-email: Samuel Vitorino <samvitorino@gmail.com>
  License: Apache 2.0
@@ -27,14 +27,18 @@ https://github.com/user-attachments/assets/40254391-248f-45ff-b9a4-107d64fbb95f

  [![Alt Text](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/samuel-vitorino/sopro)

+ ### 📰 News
+
+ **2026.02.04 – SoproTTS v1.5 is out: more stable, faster, and smaller. Trained for just $100, it reaches 250 ms TTFA streaming and 0.05 RTF (~20× realtime) on CPU.**
+
  Sopro (from the Portuguese word for “breath/blow”) is a lightweight English text-to-speech model I trained as a side project. Sopro is composed of dilated convs (à la WaveNet) and lightweight cross-attention layers, instead of the common Transformer architecture. Even though Sopro is not SOTA across most voices and situations, I still think it’s a cool project made with a very low budget (trained on a single L40S GPU), and it can be improved with better data.

  Some of the main features are:

- - **169M parameters**
+ - **147M parameters**
  - **Streaming**
  - **Zero-shot voice cloning**
- - **0.25 RTF on CPU** (measured on an M3 base model), meaning it generates 30 seconds of audio in 7.5 seconds
+ - **0.05 RTF on CPU** (measured on an M3 base model), meaning it generates 32 seconds of audio in 1.77 seconds
  - **3-12 seconds of reference audio** for voice cloning

  ---
@@ -53,7 +57,7 @@ conda activate soprotts
  ### From PyPI

  ```bash
- pip install sopro
+ pip install -U sopro
  ```

  ### From the repo
@@ -79,9 +83,7 @@ soprotts \

  You have the expected `temperature` and `top_p` parameters, alongside:

- - `--style_strength` (controls the FiLM strength; increasing it can improve or reduce voice similarity; default `1.0`)
- - `--no_stop_head` to disable early stopping
- - `--stop_threshold` and `--stop_patience` (number of consecutive frames that must be classified as final before **stopping**). For short sentences, the stop head may fail to trigger, in which case you can lower these values. Likewise, if the model stops before producing the full text, adjusting these parameters up can help.
+ - `--style_strength` (controls the FiLM strength; increasing it can improve or reduce voice similarity; default `1.2`)

  ### Python

@@ -119,6 +121,27 @@ wav = torch.cat(chunks, dim=-1)
  tts.save_wav("out_stream.wav", wav)
  ```

+ You can also precalculate the reference to reduce TTFA:
+
+ ```python
+ import torch
+ from sopro import SoproTTS
+
+ tts = SoproTTS.from_pretrained("samuel-vitorino/sopro", device="cpu")
+
+ ref = tts.prepare_reference(ref_audio_path="ref.mp3")
+
+ chunks = []
+ for chunk in tts.stream(
+     "Hello! This is a streaming Sopro TTS example.",
+     ref=ref,
+ ):
+     chunks.append(chunk.cpu())
+
+ wav = torch.cat(chunks, dim=-1)
+ tts.save_wav("out_stream.wav", wav)
+ ```
+
  ---

  ## Interactive streaming demo
sopro-1.5.0.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+ sopro/__init__.py,sha256=OpqL73InBJ22Ja8QXeGLr09igFvKn-OXPj_smU9t98g,110
+ sopro/audio.py,sha256=xlp6aYzzGlOMcNZ-p9lDeeU0TUkSHMcvmLantwg_4-0,4162
+ sopro/cli.py,sha256=HKZ8CD7TtjdOPy7iOgilv1aplvWUb4jaTCEvBHE0Cmo,5108
+ sopro/config.py,sha256=CBTmHbsJs7hpf0mfyea5BWu-_PImL3WdSmUrzKvNC64,1052
+ sopro/constants.py,sha256=wSjFKeFIcLCxyVUVb3njxMK666IuxjlNzVT4_jfPovQ,97
+ sopro/hub.py,sha256=Axc19LlO3Vlo0sigJNDR42U6ByMtDOYvhRl_HicMMqU,1386
+ sopro/model.py,sha256=hhzbCP-PLe1NaZPC3lYcjWxHoqn7ignjfnYRuAOQl3s,18314
+ sopro/sampling.py,sha256=MXdP_oYcpW9Hf9vqaKuygOUz9VycZ7nOhIOXXfMobks,2930
+ sopro/streaming.py,sha256=iq_ukrktT6vPd1bIRhBg6yZuiXFahn2ZXJ6t1YM4lb0,4476
+ sopro/tokenizer.py,sha256=ucb86Jr-EaAyD9OHDoCmwB9Nh9AFIZK_TlZmMkv46KQ,1325
+ sopro/codec/__init__.py,sha256=6D6Q0M-SUZZnq79OT1nATenEc8zIZDrhZBpm7zdPEE4,129
+ sopro/codec/mimi.py,sha256=RNKnXfhWXUqHiU27C90wj18Rb3R2IZHpm5_cS_XAs9Y,5798
+ sopro/nn/__init__.py,sha256=48i83Bq5R2Z1q21TrxlZtyBgOBWnD2DmyU7qX-JHo9c,680
+ sopro/nn/blocks.py,sha256=QpRzwvzf4ea0JvHPlonfms2lRp93VRZI3Q9iE-ltldU,5814
+ sopro/nn/embeddings.py,sha256=UBIJiKFca3kGUBkCw3d2Iwt_zd0NgsBfZq4912KLTug,3844
+ sopro/nn/generator.py,sha256=Xnb4b9xeOYHlYWzXFjBPzxCKPdWCf0ZjWs6IJ7TkKy4,4354
+ sopro/nn/nar.py,sha256=Swz8TrnLecV-ODB1tsODJyFTqd3VbucGaAgjxrKb82I,3682
+ sopro/nn/ref.py,sha256=3QoxtY4MHAVNwofoBAty_-iuQSm9Hol03bOknsTiWl8,5385
+ sopro/nn/speaker.py,sha256=sVpVqJoIUo8Brhuk3VDSRyr7brxjpudr5aF9201kmvw,2815
+ sopro/nn/text.py,sha256=QdSXOOLOjDaRdiKoPFG7UD6t9MpqOYfLuihyrnqwgh0,4352
+ sopro-1.5.0.dist-info/licenses/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ sopro-1.5.0.dist-info/METADATA,sha256=LHe2O4Du_4cHRsmv9G0lWg4EfKMBFJgkr3eMkjTTh7c,6732
+ sopro-1.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ sopro-1.5.0.dist-info/entry_points.txt,sha256=OWcKgC5Syk8rzOhNzTZ3QR5GJEG88UfiShkovrwb2cI,44
+ sopro-1.5.0.dist-info/top_level.txt,sha256=Tik26_lEwzSKDuwQdqwoqA_O0b7CDATzousa0Q17PBo,6
+ sopro-1.5.0.dist-info/RECORD,,
sopro-1.5.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

sopro/nn/xattn.py DELETED
@@ -1,98 +0,0 @@
- from __future__ import annotations
-
- from typing import Optional
-
- import torch
- import torch.nn as nn
-
- from .blocks import RMSNorm
-
-
- def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
-     return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
-
-
- class RefXAttnBlock(nn.Module):
-     def __init__(self, d_model: int, heads: int = 2, dropout: float = 0.0):
-         super().__init__()
-         self.nq = RMSNorm(d_model)
-         self.nkv = RMSNorm(d_model)
-         self.attn = nn.MultiheadAttention(
-             d_model, heads, batch_first=True, dropout=dropout
-         )
-         self.gate = nn.Parameter(torch.tensor(0.5))
-         self.gmax = 0.35
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         ref: torch.Tensor,
-         key_padding_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         q = self.nq(x)
-         kv = self.nkv(ref.float())
-
-         with torch.autocast(device_type=q.device.type, enabled=False):
-             a, _ = self.attn(
-                 q.float(), kv, kv, key_padding_mask=key_padding_mask, need_weights=False
-             )
-
-         a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
-
-         rms_x = rms(x.float())
-         rms_a = rms(a)
-         scale = rms_x / rms_a
-         a = (a * scale).to(x.dtype)
-
-         gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
-         return x + gate_eff * a
-
-
- class RefXAttn(nn.Module):
-     def __init__(
-         self, d_model: int, heads: int = 2, layers: int = 3, dropout: float = 0.0
-     ):
-         super().__init__()
-         self.blocks = nn.ModuleList(
-             [RefXAttnBlock(d_model, heads, dropout) for _ in range(layers)]
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         ref: torch.Tensor,
-         key_padding_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         for blk in self.blocks:
-             x = blk(x, ref, key_padding_mask)
-         return x
-
-
- class TextXAttnBlock(nn.Module):
-     def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
-         super().__init__()
-         self.nq = RMSNorm(d_model)
-         self.nkv = RMSNorm(d_model)
-         self.attn = nn.MultiheadAttention(
-             d_model, num_heads=heads, dropout=dropout, batch_first=True
-         )
-         self.gate = nn.Parameter(torch.tensor(0.0))
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         context: torch.Tensor,
-         key_padding_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         q = self.nq(x)
-         kv = self.nkv(context)
-         with torch.autocast(device_type=q.device.type, enabled=False):
-             out, _ = self.attn(
-                 q.float(),
-                 kv.float(),
-                 kv.float(),
-                 key_padding_mask=key_padding_mask,
-                 need_weights=False,
-             )
-         out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
-         return x + torch.tanh(self.gate) * out
sopro-1.0.1.dist-info/RECORD DELETED
@@ -1,23 +0,0 @@
- sopro/__init__.py,sha256=NFZuESqdCL7bGXuTB8c61XxUJqhkHPUOSTqzH4pyUfU,110
- sopro/audio.py,sha256=xlp6aYzzGlOMcNZ-p9lDeeU0TUkSHMcvmLantwg_4-0,4162
- sopro/cli.py,sha256=YKfGalyhbRuvjVrGJuo1NlIC7h8CszlMxuTwhYgUSwQ,5751
- sopro/config.py,sha256=OBD-k2z5GUdjFS545MyBXx-dAGhwnhRG11LW-zQt1-g,1063
- sopro/constants.py,sha256=wSjFKeFIcLCxyVUVb3njxMK666IuxjlNzVT4_jfPovQ,97
- sopro/hub.py,sha256=xsHfeO8X7v__FELvaQxWHYG8P39ygrgbluPs5GQjoCM,1391
- sopro/model.py,sha256=YXwcVGN3v5T0kvKttmo9WNPpewF-b5aOZoTMVypkzO8,28624
- sopro/sampling.py,sha256=Q5rbuef_BIuy12cv5J7v6k9ob3zQ0OFJIlMHssOkiuU,2951
- sopro/streaming.py,sha256=O5Kkl4cUBjzgjTrEwQK2ka5h6sgcYaEZmIp66-obcPM,4975
- sopro/tokenizer.py,sha256=ucb86Jr-EaAyD9OHDoCmwB9Nh9AFIZK_TlZmMkv46KQ,1325
- sopro/codec/__init__.py,sha256=6D6Q0M-SUZZnq79OT1nATenEc8zIZDrhZBpm7zdPEE4,129
- sopro/codec/mimi.py,sha256=RNKnXfhWXUqHiU27C90wj18Rb3R2IZHpm5_cS_XAs9Y,5798
- sopro/nn/__init__.py,sha256=JewW6GvQPMBsCDkmnm9u5G3tvaAzClUVMIgcVH4N7aw,561
- sopro/nn/blocks.py,sha256=zDEVUH2LXapXuQ4DyhplNh1I0iJYrNUL20IxHoz8ucs,3221
- sopro/nn/embeddings.py,sha256=7YfYKj1v1oafTV4-iucJG4fmeT43fP_rQiJ6ACRKPNI,3185
- sopro/nn/speaker.py,sha256=L2bs-bPlyxoWZyMTctBBuMTaEWm6FP7K1udrXehnTGM,2964
- sopro/nn/xattn.py,sha256=OeRo1HbRZs0AkQ6AV6Q8cqYZP9K4vI-IwT3uVn9jOqg,2939
- sopro-1.0.1.dist-info/licenses/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sopro-1.0.1.dist-info/METADATA,sha256=tlq9mTTsNEFgMyCtle7om5hqKRm5LwrVCFLo4olQ3_s,6470
- sopro-1.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- sopro-1.0.1.dist-info/entry_points.txt,sha256=OWcKgC5Syk8rzOhNzTZ3QR5GJEG88UfiShkovrwb2cI,44
- sopro-1.0.1.dist-info/top_level.txt,sha256=Tik26_lEwzSKDuwQdqwoqA_O0b7CDATzousa0Q17PBo,6
- sopro-1.0.1.dist-info/RECORD,,