sopro-1.0.1-py3-none-any.whl → sopro-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +1 -1
- sopro/cli.py +31 -46
- sopro/config.py +15 -20
- sopro/hub.py +2 -3
- sopro/model.py +265 -535
- sopro/nn/__init__.py +7 -3
- sopro/nn/blocks.py +78 -0
- sopro/nn/embeddings.py +16 -0
- sopro/nn/generator.py +130 -0
- sopro/nn/nar.py +116 -0
- sopro/nn/ref.py +160 -0
- sopro/nn/speaker.py +14 -17
- sopro/nn/text.py +132 -0
- sopro/sampling.py +3 -3
- sopro/streaming.py +25 -38
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/METADATA +30 -7
- sopro-1.5.0.dist-info/RECORD +26 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/WHEEL +1 -1
- sopro/nn/xattn.py +0 -98
- sopro-1.0.1.dist-info/RECORD +0 -23
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/entry_points.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/top_level.txt +0 -0
sopro/nn/text.py
ADDED
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sopro.config import SoproTTSConfig
+from sopro.nn.embeddings import SinusoidalPositionalEmbedding, TextEmbedding
+from sopro.tokenizer import TextTokenizer
+
+from .blocks import RMSNorm, SSMLiteBlock
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self, cfg: SoproTTSConfig, d_model: int, n_layers: int, tokenizer: TextTokenizer
+    ):
+        super().__init__()
+        self.tok = tokenizer
+        self.embed = TextEmbedding(self.tok.vocab_size, d_model)
+        self.layers = nn.ModuleList(
+            [SSMLiteBlock(d_model, cfg.dropout, causal=False) for _ in range(n_layers)]
+        )
+        self.pos = SinusoidalPositionalEmbedding(d_model, max_len=cfg.max_text_len + 8)
+        self.norm = RMSNorm(d_model)
+
+    def forward(
+        self, text_ids: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.embed(text_ids)
+        L = x.size(1)
+        pos = self.pos(torch.arange(L, device=x.device))
+        x = x + pos.unsqueeze(0)
+
+        x = x * mask.unsqueeze(-1).float()
+        for layer in self.layers:
+            x = layer(x)
+        x = self.norm(x)
+
+        mask_f = mask.float().unsqueeze(-1)
+        pooled = (x * mask_f).sum(dim=1) / (mask_f.sum(dim=1) + 1e-6)
+        return x, pooled
+
+
+class TextXAttnBlock(nn.Module):
+    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
+        super().__init__()
+        assert d_model % heads == 0
+
+        self.d_model = int(d_model)
+        self.heads = int(heads)
+        self.head_dim = self.d_model // self.heads
+        self.dropout = float(dropout)
+
+        self.nq = RMSNorm(self.d_model)
+        self.nkv = RMSNorm(self.d_model)
+
+        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
+
+        self.gate = nn.Parameter(torch.tensor(0.0))
+
+    def _to_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, T, D = t.shape
+        return t.view(B, T, self.heads, self.head_dim).transpose(1, 2)
+
+    def _from_heads(self, t: torch.Tensor) -> torch.Tensor:
+        B, H, T, Hd = t.shape
+        return t.transpose(1, 2).contiguous().view(B, T, H * Hd)
+
+    def build_kv_cache(
+        self,
+        context: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        kv = self.nkv(context)
+        k = self._to_heads(self.k_proj(kv))
+        v = self._to_heads(self.v_proj(kv))
+        return {"k": k, "v": v, "key_padding_mask": key_padding_mask}
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        kv_cache: Optional[Dict[str, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ):
+        q = self.nq(x)
+        q = self._to_heads(self.q_proj(q))
+
+        if kv_cache is None:
+            if context is None:
+                raise ValueError("context must be provided when kv_cache is None")
+            kv_cache = self.build_kv_cache(context, key_padding_mask=key_padding_mask)
+
+        k = kv_cache["k"]
+        v = kv_cache["v"]
+        kpm = kv_cache.get("key_padding_mask", None)
+
+        attn_mask = None
+        if kpm is not None:
+            kpm = kpm.to(torch.bool)
+
+            keep = ~kpm
+
+            bad = ~keep.any(dim=1)
+            if bad.any():
+                keep = keep.clone()
+                keep[bad, 0] = True
+
+            attn_mask = keep[:, None, None, :]
+
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            a = F.scaled_dot_product_attention(
+                q.float(),
+                k.float(),
+                v.float(),
+                attn_mask=attn_mask,
+                dropout_p=self.dropout if self.training else 0.0,
+                is_causal=False,
+            )
+
+        a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
+        a = self.out_proj(self._from_heads(a))
+
+        y = x + torch.tanh(self.gate) * a
+        return (y, kv_cache) if use_cache else y
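The new `TextXAttnBlock` replaces the `nn.MultiheadAttention`-based version from the deleted `sopro/nn/xattn.py` with explicit Q/K/V projections and a `build_kv_cache` hook, so the text keys and values can be projected once and reused at every autoregressive step instead of being recomputed. A minimal sketch of driving the cached path; `d_model`, the sequence lengths, and the random tensors are illustrative assumptions, not the shipped configuration:

```python
import torch

from sopro.nn.text import TextXAttnBlock

# Illustrative sizes only; the real config may differ.
blk = TextXAttnBlock(d_model=256, heads=4).eval()
text_ctx = torch.randn(1, 42, 256)           # encoded text, (B, T_text, D)
pad = torch.zeros(1, 42, dtype=torch.bool)   # True marks padded key positions

# Project the text keys/values once, before decoding starts.
cache = blk.build_kv_cache(text_ctx, key_padding_mask=pad)

x = torch.randn(1, 1, 256)                   # one decoder frame per AR step
y, cache = blk(x, kv_cache=cache, use_cache=True)  # reuses the cached K/V
```

The `keep[bad, 0] = True` branch guards against rows whose keys are all padding, which would otherwise make `scaled_dot_product_attention` produce NaNs; the `nan_to_num` after the attention call is a second layer of the same defense.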
sopro/sampling.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List
+from typing import List, Tuple
 
 import torch
 
@@ -97,5 +97,5 @@ def rf_ar(ar_kernel: int, dilations: Tuple[int, ...]) -> int:
     return 1 + (ar_kernel - 1) * int(sum(dilations))
 
 
-def rf_nar(
-    return 1 + (kernel_size - 1) * int(
+def rf_nar(kernel_size: int, dilations: Tuple[int, ...]) -> int:
+    return 1 + (kernel_size - 1) * int(sum(dilations))
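The second hunk restores `rf_nar` so it mirrors `rf_ar`: a stack of dilated convolutions has receptive field 1 + (kernel_size − 1) · Σ dilations, since each layer with dilation d widens coverage by (kernel_size − 1) · d frames. A quick check with assumed values:

```python
from sopro.sampling import rf_nar

# Kernel 3 with dilations (1, 2, 4, 8): each layer adds (3 - 1) * d frames,
# so the stack sees 1 + 2 * (1 + 2 + 4 + 8) = 31 frames.
assert rf_nar(3, (1, 2, 4, 8)) == 31
```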
sopro/streaming.py
CHANGED
@@ -1,20 +1,18 @@
 from __future__ import annotations
 
-import time
 from dataclasses import dataclass
-from typing import Iterator, List, Optional
+from typing import Iterator, List, Optional
 
 import torch
 
 from .codec.mimi import MimiDecodeState, MimiStreamDecoder
-from .model import SoproTTS, SoproTTSModel
+from .model import PreparedReference, SoproTTS, SoproTTSModel
 
 
 @dataclass
 class StreamConfig:
     chunk_frames: int = 16
     nar_context_frames: Optional[int] = None
-    cond_chunk_size: int = 32
 
 
 class SoproTTSStreamer:
@@ -30,33 +28,30 @@ class SoproTTSStreamer:
         *,
         ref_audio_path: Optional[str] = None,
         ref_tokens_tq: Optional[torch.Tensor] = None,
+        ref: Optional[PreparedReference] = None,
         max_frames: int = 400,
         top_p: float = 0.9,
         temperature: float = 1.05,
         anti_loop: bool = True,
-        use_prefix: bool = True,
-        prefix_sec_fixed: Optional[float] = None,
         style_strength: Optional[float] = None,
         ref_seconds: Optional[float] = None,
         chunk_frames: Optional[int] = None,
         nar_context_frames: Optional[int] = None,
-        cond_chunk_size: Optional[int] = None,
-        use_stop_head: Optional[bool] = None,
-        stop_patience: Optional[int] = None,
-        stop_threshold: Optional[float] = None,
         min_gen_frames: Optional[int] = None,
     ) -> Iterator[torch.Tensor]:
         model: SoproTTSModel = self.tts.model
         device = self.tts.device
 
         text_ids = self.tts.encode_text(text)
-        ref = self.tts.encode_reference(
-            ref_audio_path=ref_audio_path,
-            ref_tokens_tq=ref_tokens_tq,
-            ref_seconds=ref_seconds,
-        )
 
-
+        if ref is None:
+            ref = self.tts.prepare_reference(
+                ref_audio_path=ref_audio_path,
+                ref_tokens_tq=ref_tokens_tq,
+                ref_seconds=ref_seconds,
+            )
+
+        prep = model.prepare_conditioning(
             text_ids,
             ref,
             max_frames=max_frames,
@@ -69,9 +64,6 @@ class SoproTTSStreamer:
         )
 
         cf = int(chunk_frames if chunk_frames is not None else self.cfg.chunk_frames)
-        cond_cs = int(
-            cond_chunk_size if cond_chunk_size is not None else self.cfg.cond_chunk_size
-        )
 
         nar_ctx = (
             nar_context_frames
@@ -83,14 +75,11 @@ class SoproTTSStreamer:
         nar_ctx = int(nar_ctx)
 
         hist_A: List[int] = []
-
        frames_emitted = 0
-
         mimi_state = MimiDecodeState()
 
         def refine_and_emit(end: int) -> Optional[torch.Tensor]:
             nonlocal frames_emitted, mimi_state
-
             new_start = frames_emitted
             if end <= new_start:
                 return None
@@ -98,7 +87,7 @@ class SoproTTSStreamer:
             win_start = max(0, new_start - nar_ctx)
             win_end = end
 
-            cond_win = prep["
+            cond_win = prep["cond_ar"][:, win_start:win_end, :]
             tokens_A_win = torch.as_tensor(
                 hist_A[win_start:win_end], device=device, dtype=torch.long
             ).unsqueeze(0)
@@ -114,30 +103,25 @@ class SoproTTSStreamer:
             frames_emitted = end
             return wav_chunk if wav_chunk.numel() > 0 else None
 
-        for _t,
+        for _t, tok, is_eos in model.ar_stream(
             prep,
             max_frames=max_frames,
             top_p=top_p,
             temperature=temperature,
             anti_loop=anti_loop,
-            use_prefix=use_prefix,
-            prefix_sec_fixed=prefix_sec_fixed,
-            cond_chunk_size=cond_cs,
-            use_stop_head=use_stop_head,
-            stop_patience=stop_patience,
-            stop_threshold=stop_threshold,
             min_gen_frames=min_gen_frames,
         ):
-
-
+            if is_eos:
+                break
 
-
-
-            continue
+            hist_A.append(int(tok))
+            T = len(hist_A)
 
-
-            if
-
+            boundary = (T % cf) == 0
+            if boundary:
+                wav = refine_and_emit(T)
+                if wav is not None:
+                    yield wav
 
         T_final = len(hist_A)
         if frames_emitted < T_final:
@@ -146,12 +130,14 @@ class SoproTTSStreamer:
             yield wav
 
 
+@torch.inference_mode()
 def stream(
     tts: SoproTTS,
     text: str,
     *,
     ref_audio_path: Optional[str] = None,
     ref_tokens_tq: Optional[torch.Tensor] = None,
+    ref: Optional[PreparedReference] = None,
     chunk_frames: int = 6,
     **kwargs,
 ) -> Iterator[torch.Tensor]:
@@ -160,6 +146,7 @@ def stream(
         text,
         ref_audio_path=ref_audio_path,
         ref_tokens_tq=ref_tokens_tq,
+        ref=ref,
         chunk_frames=chunk_frames,
         **kwargs,
     )
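With the stop-head and prefix knobs gone, the loop reduces to: `ar_stream` yields one coarse token per frame, and whenever the frame count crosses a `chunk_frames` boundary the streamer refines a window with up to `nar_context_frames` of look-back and decodes only the new slice. A standalone sketch of that schedule, with assumed values for `cf` and `nar_ctx`:

```python
# Emission schedule of SoproTTSStreamer, reduced to plain arithmetic.
cf, nar_ctx = 6, 16        # assumed chunk_frames and nar_context_frames
frames_emitted = 0

for T in range(1, 25):     # T = len(hist_A) after each AR step
    if T % cf == 0:        # chunk boundary reached
        win_start = max(0, frames_emitted - nar_ctx)
        # The NAR pass refines frames [win_start, T), but only the new
        # slice [frames_emitted, T) is decoded to audio and yielded.
        print(f"refine [{win_start}, {T}), emit [{frames_emitted}, {T})")
        frames_emitted = T
```

The look-back window is what keeps chunk boundaries from being audible: each new chunk is refined together with up to `nar_ctx` already-emitted frames.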
{sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sopro
-Version: 1.0.1
+Version: 1.5.0
 Summary: A lightweight text-to-speech model with zero-shot voice cloning.
 Author-email: Samuel Vitorino <samvitorino@gmail.com>
 License: Apache 2.0
@@ -27,14 +27,18 @@ https://github.com/user-attachments/assets/40254391-248f-45ff-b9a4-107d64fbb95f
 
 [![Hugging Face Model](https://img.shields.io/badge/🤗%20model-SoproTTS-yellow)](https://huggingface.co/samuel-vitorino/sopro)
 
+### 📰 News
+
+**2026.02.04 – SoproTTS v1.5 is out: more stable, faster, and smaller. Trained for just $100, it reaches 250 ms TTFA streaming and 0.05 RTF (~20× realtime) on CPU.**
+
 Sopro (from the Portuguese word for “breath/blow”) is a lightweight English text-to-speech model I trained as a side project. Sopro is composed of dilated convs (à la WaveNet) and lightweight cross-attention layers, instead of the common Transformer architecture. Even though Sopro is not SOTA across most voices and situations, I still think it’s a cool project made with a very low budget (trained on a single L40S GPU), and it can be improved with better data.
 
 Some of the main features are:
 
-- **
+- **147M parameters**
 - **Streaming**
 - **Zero-shot voice cloning**
-- **0.
+- **0.05 RTF on CPU** (measured on an M3 base model), meaning it generates 32 seconds of audio in 1.77 seconds
 - **3-12 seconds of reference audio** for voice cloning
 
 ---
@@ -53,7 +57,7 @@ conda activate soprotts
 ### From PyPI
 
 ```bash
-pip install sopro
+pip install -U sopro
 ```
 
 ### From the repo
@@ -79,9 +83,7 @@ soprotts \
 
 You have the expected `temperature` and `top_p` parameters, alongside:
 
-- `--style_strength` (controls the FiLM strength; increasing it can improve or reduce voice similarity; default `1.
-- `--no_stop_head` to disable early stopping
-- `--stop_threshold` and `--stop_patience` (number of consecutive frames that must be classified as final before **stopping**). For short sentences, the stop head may fail to trigger, in which case you can lower these values. Likewise, if the model stops before producing the full text, adjusting these parameters up can help.
+- `--style_strength` (controls the FiLM strength; increasing it can improve or reduce voice similarity; default `1.2`)
 
 ### Python
 
@@ -119,6 +121,27 @@ wav = torch.cat(chunks, dim=-1)
 tts.save_wav("out_stream.wav", wav)
 ```
 
+You can also precalculate the reference to reduce TTFA:
+
+```python
+import torch
+from sopro import SoproTTS
+
+tts = SoproTTS.from_pretrained("samuel-vitorino/sopro", device="cpu")
+
+ref = tts.prepare_reference(ref_audio_path="ref.mp3")
+
+chunks = []
+for chunk in tts.stream(
+    "Hello! This is a streaming Sopro TTS example.",
+    ref=ref,
+):
+    chunks.append(chunk.cpu())
+
+wav = torch.cat(chunks, dim=-1)
+tts.save_wav("out_stream.wav", wav)
+```
+
 ---
 
 ## Interactive streaming demo
sopro-1.5.0.dist-info/RECORD
ADDED
@@ -0,0 +1,26 @@
+sopro/__init__.py,sha256=OpqL73InBJ22Ja8QXeGLr09igFvKn-OXPj_smU9t98g,110
+sopro/audio.py,sha256=xlp6aYzzGlOMcNZ-p9lDeeU0TUkSHMcvmLantwg_4-0,4162
+sopro/cli.py,sha256=HKZ8CD7TtjdOPy7iOgilv1aplvWUb4jaTCEvBHE0Cmo,5108
+sopro/config.py,sha256=CBTmHbsJs7hpf0mfyea5BWu-_PImL3WdSmUrzKvNC64,1052
+sopro/constants.py,sha256=wSjFKeFIcLCxyVUVb3njxMK666IuxjlNzVT4_jfPovQ,97
+sopro/hub.py,sha256=Axc19LlO3Vlo0sigJNDR42U6ByMtDOYvhRl_HicMMqU,1386
+sopro/model.py,sha256=hhzbCP-PLe1NaZPC3lYcjWxHoqn7ignjfnYRuAOQl3s,18314
+sopro/sampling.py,sha256=MXdP_oYcpW9Hf9vqaKuygOUz9VycZ7nOhIOXXfMobks,2930
+sopro/streaming.py,sha256=iq_ukrktT6vPd1bIRhBg6yZuiXFahn2ZXJ6t1YM4lb0,4476
+sopro/tokenizer.py,sha256=ucb86Jr-EaAyD9OHDoCmwB9Nh9AFIZK_TlZmMkv46KQ,1325
+sopro/codec/__init__.py,sha256=6D6Q0M-SUZZnq79OT1nATenEc8zIZDrhZBpm7zdPEE4,129
+sopro/codec/mimi.py,sha256=RNKnXfhWXUqHiU27C90wj18Rb3R2IZHpm5_cS_XAs9Y,5798
+sopro/nn/__init__.py,sha256=48i83Bq5R2Z1q21TrxlZtyBgOBWnD2DmyU7qX-JHo9c,680
+sopro/nn/blocks.py,sha256=QpRzwvzf4ea0JvHPlonfms2lRp93VRZI3Q9iE-ltldU,5814
+sopro/nn/embeddings.py,sha256=UBIJiKFca3kGUBkCw3d2Iwt_zd0NgsBfZq4912KLTug,3844
+sopro/nn/generator.py,sha256=Xnb4b9xeOYHlYWzXFjBPzxCKPdWCf0ZjWs6IJ7TkKy4,4354
+sopro/nn/nar.py,sha256=Swz8TrnLecV-ODB1tsODJyFTqd3VbucGaAgjxrKb82I,3682
+sopro/nn/ref.py,sha256=3QoxtY4MHAVNwofoBAty_-iuQSm9Hol03bOknsTiWl8,5385
+sopro/nn/speaker.py,sha256=sVpVqJoIUo8Brhuk3VDSRyr7brxjpudr5aF9201kmvw,2815
+sopro/nn/text.py,sha256=QdSXOOLOjDaRdiKoPFG7UD6t9MpqOYfLuihyrnqwgh0,4352
+sopro-1.5.0.dist-info/licenses/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sopro-1.5.0.dist-info/METADATA,sha256=LHe2O4Du_4cHRsmv9G0lWg4EfKMBFJgkr3eMkjTTh7c,6732
+sopro-1.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+sopro-1.5.0.dist-info/entry_points.txt,sha256=OWcKgC5Syk8rzOhNzTZ3QR5GJEG88UfiShkovrwb2cI,44
+sopro-1.5.0.dist-info/top_level.txt,sha256=Tik26_lEwzSKDuwQdqwoqA_O0b7CDATzousa0Q17PBo,6
+sopro-1.5.0.dist-info/RECORD,,
sopro/nn/xattn.py
DELETED
@@ -1,98 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from .blocks import RMSNorm
-
-
-def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
-    return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
-
-
-class RefXAttnBlock(nn.Module):
-    def __init__(self, d_model: int, heads: int = 2, dropout: float = 0.0):
-        super().__init__()
-        self.nq = RMSNorm(d_model)
-        self.nkv = RMSNorm(d_model)
-        self.attn = nn.MultiheadAttention(
-            d_model, heads, batch_first=True, dropout=dropout
-        )
-        self.gate = nn.Parameter(torch.tensor(0.5))
-        self.gmax = 0.35
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        ref: torch.Tensor,
-        key_padding_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        q = self.nq(x)
-        kv = self.nkv(ref.float())
-
-        with torch.autocast(device_type=q.device.type, enabled=False):
-            a, _ = self.attn(
-                q.float(), kv, kv, key_padding_mask=key_padding_mask, need_weights=False
-            )
-
-        a = torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
-
-        rms_x = rms(x.float())
-        rms_a = rms(a)
-        scale = rms_x / rms_a
-        a = (a * scale).to(x.dtype)
-
-        gate_eff = (self.gmax * torch.tanh(self.gate)).to(x.dtype)
-        return x + gate_eff * a
-
-
-class RefXAttn(nn.Module):
-    def __init__(
-        self, d_model: int, heads: int = 2, layers: int = 3, dropout: float = 0.0
-    ):
-        super().__init__()
-        self.blocks = nn.ModuleList(
-            [RefXAttnBlock(d_model, heads, dropout) for _ in range(layers)]
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        ref: torch.Tensor,
-        key_padding_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        for blk in self.blocks:
-            x = blk(x, ref, key_padding_mask)
-        return x
-
-
-class TextXAttnBlock(nn.Module):
-    def __init__(self, d_model: int, heads: int = 4, dropout: float = 0.0):
-        super().__init__()
-        self.nq = RMSNorm(d_model)
-        self.nkv = RMSNorm(d_model)
-        self.attn = nn.MultiheadAttention(
-            d_model, num_heads=heads, dropout=dropout, batch_first=True
-        )
-        self.gate = nn.Parameter(torch.tensor(0.0))
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        context: torch.Tensor,
-        key_padding_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        q = self.nq(x)
-        kv = self.nkv(context)
-        with torch.autocast(device_type=q.device.type, enabled=False):
-            out, _ = self.attn(
-                q.float(),
-                kv.float(),
-                kv.float(),
-                key_padding_mask=key_padding_mask,
-                need_weights=False,
-            )
-        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
-        return x + torch.tanh(self.gate) * out
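The deleted `RefXAttnBlock` combined a bounded gate (`gmax * tanh(gate)`) with an RMS-matching rescale: the attention output is scaled so its per-position RMS matches the residual stream's before being gated in. A self-contained illustration of that rescale, with made-up tensors:

```python
import torch

def rms(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(2, 5, 8)            # residual stream
a = 10.0 * torch.randn(2, 5, 8)     # attention output at a mismatched scale
a = a * (rms(x) / rms(a))           # rescaled to the residual's RMS
assert torch.allclose(rms(a), rms(x), atol=1e-4)
```

The replacement `TextXAttnBlock` in `sopro/nn/text.py` drops the rescale and keeps only the `tanh` gate.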
sopro-1.0.1.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-sopro/__init__.py,sha256=NFZuESqdCL7bGXuTB8c61XxUJqhkHPUOSTqzH4pyUfU,110
-sopro/audio.py,sha256=xlp6aYzzGlOMcNZ-p9lDeeU0TUkSHMcvmLantwg_4-0,4162
-sopro/cli.py,sha256=YKfGalyhbRuvjVrGJuo1NlIC7h8CszlMxuTwhYgUSwQ,5751
-sopro/config.py,sha256=OBD-k2z5GUdjFS545MyBXx-dAGhwnhRG11LW-zQt1-g,1063
-sopro/constants.py,sha256=wSjFKeFIcLCxyVUVb3njxMK666IuxjlNzVT4_jfPovQ,97
-sopro/hub.py,sha256=xsHfeO8X7v__FELvaQxWHYG8P39ygrgbluPs5GQjoCM,1391
-sopro/model.py,sha256=YXwcVGN3v5T0kvKttmo9WNPpewF-b5aOZoTMVypkzO8,28624
-sopro/sampling.py,sha256=Q5rbuef_BIuy12cv5J7v6k9ob3zQ0OFJIlMHssOkiuU,2951
-sopro/streaming.py,sha256=O5Kkl4cUBjzgjTrEwQK2ka5h6sgcYaEZmIp66-obcPM,4975
-sopro/tokenizer.py,sha256=ucb86Jr-EaAyD9OHDoCmwB9Nh9AFIZK_TlZmMkv46KQ,1325
-sopro/codec/__init__.py,sha256=6D6Q0M-SUZZnq79OT1nATenEc8zIZDrhZBpm7zdPEE4,129
-sopro/codec/mimi.py,sha256=RNKnXfhWXUqHiU27C90wj18Rb3R2IZHpm5_cS_XAs9Y,5798
-sopro/nn/__init__.py,sha256=JewW6GvQPMBsCDkmnm9u5G3tvaAzClUVMIgcVH4N7aw,561
-sopro/nn/blocks.py,sha256=zDEVUH2LXapXuQ4DyhplNh1I0iJYrNUL20IxHoz8ucs,3221
-sopro/nn/embeddings.py,sha256=7YfYKj1v1oafTV4-iucJG4fmeT43fP_rQiJ6ACRKPNI,3185
-sopro/nn/speaker.py,sha256=L2bs-bPlyxoWZyMTctBBuMTaEWm6FP7K1udrXehnTGM,2964
-sopro/nn/xattn.py,sha256=OeRo1HbRZs0AkQ6AV6Q8cqYZP9K4vI-IwT3uVn9jOqg,2939
-sopro-1.0.1.dist-info/licenses/LICENSE.txt,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sopro-1.0.1.dist-info/METADATA,sha256=tlq9mTTsNEFgMyCtle7om5hqKRm5LwrVCFLo4olQ3_s,6470
-sopro-1.0.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-sopro-1.0.1.dist-info/entry_points.txt,sha256=OWcKgC5Syk8rzOhNzTZ3QR5GJEG88UfiShkovrwb2cI,44
-sopro-1.0.1.dist-info/top_level.txt,sha256=Tik26_lEwzSKDuwQdqwoqA_O0b7CDATzousa0Q17PBo,6
-sopro-1.0.1.dist-info/RECORD,,
{sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/entry_points.txt
File without changes

{sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt
File without changes

{sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/top_level.txt
File without changes