sopro 1.0.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +1 -1
- sopro/cli.py +37 -53
- sopro/config.py +15 -20
- sopro/hub.py +2 -3
- sopro/model.py +264 -534
- sopro/nn/__init__.py +7 -3
- sopro/nn/blocks.py +78 -0
- sopro/nn/embeddings.py +16 -0
- sopro/nn/generator.py +130 -0
- sopro/nn/nar.py +116 -0
- sopro/nn/ref.py +160 -0
- sopro/nn/speaker.py +14 -17
- sopro/nn/text.py +132 -0
- sopro/sampling.py +3 -3
- sopro/streaming.py +25 -38
- {sopro-1.0.2.dist-info → sopro-1.5.0.dist-info}/METADATA +30 -7
- sopro-1.5.0.dist-info/RECORD +26 -0
- {sopro-1.0.2.dist-info → sopro-1.5.0.dist-info}/WHEEL +1 -1
- sopro/nn/xattn.py +0 -98
- sopro-1.0.2.dist-info/RECORD +0 -23
- {sopro-1.0.2.dist-info → sopro-1.5.0.dist-info}/entry_points.txt +0 -0
- {sopro-1.0.2.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt +0 -0
- {sopro-1.0.2.dist-info → sopro-1.5.0.dist-info}/top_level.txt +0 -0
sopro/__init__.py
CHANGED
sopro/cli.py
CHANGED
|
@@ -32,8 +32,6 @@ def main() -> None:
|
|
|
32
32
|
ap.add_argument("--temperature", type=float, default=1.05)
|
|
33
33
|
ap.add_argument("--no_anti_loop", action="store_true")
|
|
34
34
|
|
|
35
|
-
ap.add_argument("--no_prefix", action="store_true")
|
|
36
|
-
ap.add_argument("--prefix_sec", type=float, default=None)
|
|
37
35
|
ap.add_argument("--style_strength", type=float, default=None)
|
|
38
36
|
ap.add_argument("--ref_seconds", type=float, default=None)
|
|
39
37
|
|
|
@@ -77,6 +75,7 @@ def main() -> None:
|
|
|
77
75
|
torch.cuda.manual_seed_all(args.seed)
|
|
78
76
|
|
|
79
77
|
t0 = time.perf_counter()
|
|
78
|
+
|
|
80
79
|
tts = SoproTTS.from_pretrained(
|
|
81
80
|
args.repo_id,
|
|
82
81
|
revision=args.revision,
|
|
@@ -84,6 +83,7 @@ def main() -> None:
|
|
|
84
83
|
token=args.hf_token,
|
|
85
84
|
device=device,
|
|
86
85
|
)
|
|
86
|
+
|
|
87
87
|
t1 = time.perf_counter()
|
|
88
88
|
if not args.quiet:
|
|
89
89
|
print(f"[Load] {t1 - t0:.2f}s")
|
|
@@ -99,7 +99,7 @@ def main() -> None:
|
|
|
99
99
|
|
|
100
100
|
with torch.inference_mode():
|
|
101
101
|
text_ids = tts.encode_text(args.text)
|
|
102
|
-
ref = tts.
|
|
102
|
+
ref = tts.prepare_reference(
|
|
103
103
|
ref_audio_path=args.ref_audio,
|
|
104
104
|
ref_tokens_tq=ref_tokens_tq,
|
|
105
105
|
ref_seconds=args.ref_seconds,
|
|
@@ -119,58 +119,42 @@ def main() -> None:
|
|
|
119
119
|
|
|
120
120
|
t_start = time.perf_counter()
|
|
121
121
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
unit="frame",
|
|
127
|
-
disable=args.quiet,
|
|
128
|
-
)
|
|
122
|
+
hist_A: list[int] = []
|
|
123
|
+
pbar = tqdm(
|
|
124
|
+
total=args.max_frames + 1, desc="AR sampling", unit="step", disable=args.quiet
|
|
125
|
+
)
|
|
129
126
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
stop_patience=args.stop_patience,
|
|
140
|
-
stop_threshold=args.stop_threshold,
|
|
141
|
-
):
|
|
142
|
-
hist_A.append(int(rvq1))
|
|
127
|
+
for _t, tok, is_eos in tts.model.ar_stream(
|
|
128
|
+
prep,
|
|
129
|
+
max_frames=args.max_frames,
|
|
130
|
+
top_p=args.top_p,
|
|
131
|
+
temperature=args.temperature,
|
|
132
|
+
anti_loop=(not args.no_anti_loop),
|
|
133
|
+
):
|
|
134
|
+
if is_eos:
|
|
135
|
+
pbar.set_postfix(eos="yes")
|
|
143
136
|
pbar.update(1)
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
return
|
|
166
|
-
|
|
167
|
-
tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
|
|
168
|
-
cond_seq = prep["cond_all"][:, :T, :]
|
|
169
|
-
tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
|
|
170
|
-
tokens_tq = tokens_1xTQ.squeeze(0)
|
|
171
|
-
|
|
172
|
-
wav = tts.codec.decode_full(tokens_tq)
|
|
173
|
-
save_audio(args.out, wav, sr=TARGET_SR)
|
|
137
|
+
break
|
|
138
|
+
hist_A.append(int(tok))
|
|
139
|
+
pbar.update(1)
|
|
140
|
+
|
|
141
|
+
t_after_sampling = time.perf_counter()
|
|
142
|
+
|
|
143
|
+
pbar.n = len(hist_A)
|
|
144
|
+
pbar.close()
|
|
145
|
+
|
|
146
|
+
T = len(hist_A)
|
|
147
|
+
if T == 0:
|
|
148
|
+
save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
|
|
149
|
+
return
|
|
150
|
+
|
|
151
|
+
tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
|
|
152
|
+
cond_seq = prep["cond_ar"][:, :T, :]
|
|
153
|
+
tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
|
|
154
|
+
tokens_tq = tokens_1xTQ.squeeze(0)
|
|
155
|
+
|
|
156
|
+
wav = tts.codec.decode_full(tokens_tq)
|
|
157
|
+
save_audio(args.out, wav, sr=TARGET_SR)
|
|
174
158
|
|
|
175
159
|
t_end = time.perf_counter()
|
|
176
160
|
if not args.quiet:
|
sopro/config.py
CHANGED
|
@@ -13,36 +13,31 @@ class SoproTTSConfig:
|
|
|
13
13
|
audio_sr: int = TARGET_SR
|
|
14
14
|
|
|
15
15
|
d_model: int = 384
|
|
16
|
-
n_layers_text: int =
|
|
17
|
-
n_layers_ar: int = 6
|
|
18
|
-
n_layers_nar: int = 6
|
|
16
|
+
n_layers_text: int = 2
|
|
19
17
|
dropout: float = 0.05
|
|
20
|
-
|
|
21
18
|
pos_emb_max: int = 4096
|
|
22
19
|
max_text_len: int = 2048
|
|
23
20
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
stop_patience: int = 5
|
|
21
|
+
n_layers_ar: int = 6
|
|
22
|
+
ar_kernel: int = 13
|
|
23
|
+
ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
|
|
24
|
+
ar_text_attn_freq: int = 2
|
|
29
25
|
min_gen_frames: int = 12
|
|
30
26
|
|
|
27
|
+
n_layers_nar: int = 6
|
|
28
|
+
nar_head_dim: int = 256
|
|
29
|
+
nar_kernel_size: int = 11
|
|
30
|
+
nar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 8)
|
|
31
|
+
|
|
31
32
|
stage_B: Tuple[int, int] = (2, 4)
|
|
32
33
|
stage_C: Tuple[int, int] = (5, 8)
|
|
33
34
|
stage_D: Tuple[int, int] = (9, 16)
|
|
34
35
|
stage_E: Tuple[int, int] = (17, 32)
|
|
35
36
|
|
|
36
|
-
ar_lookback: int = 4
|
|
37
|
-
ar_kernel: int = 13
|
|
38
|
-
ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
|
|
39
|
-
|
|
40
|
-
ar_text_attn_freq: int = 2
|
|
41
|
-
|
|
42
|
-
ref_attn_heads: int = 2
|
|
43
|
-
ref_seconds_max: float = 12.0
|
|
44
|
-
|
|
45
|
-
preprompt_sec_max: float = 4.0
|
|
46
|
-
|
|
47
37
|
sv_student_dim: int = 192
|
|
48
38
|
style_strength: float = 1.0
|
|
39
|
+
|
|
40
|
+
ref_enc_layers: int = 2
|
|
41
|
+
ref_xattn_heads: int = 2
|
|
42
|
+
ref_xattn_layers: int = 3
|
|
43
|
+
ref_xattn_gmax: float = 0.35
|
sopro/hub.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import os
|
|
4
5
|
import struct
|
|
5
6
|
from typing import Any, Dict, Optional
|
|
6
7
|
|
|
@@ -44,9 +45,7 @@ def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
|
|
|
44
45
|
for k in SoproTTSConfig.__annotations__.keys():
|
|
45
46
|
if k in cfg_dict:
|
|
46
47
|
init[k] = cfg_dict[k]
|
|
47
|
-
|
|
48
|
-
cfg = SoproTTSConfig(**init)
|
|
49
|
-
return cfg
|
|
48
|
+
return SoproTTSConfig(**init)
|
|
50
49
|
|
|
51
50
|
|
|
52
51
|
def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]:
|