sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +1 -1
- sopro/cli.py +31 -46
- sopro/config.py +15 -20
- sopro/hub.py +2 -3
- sopro/model.py +265 -535
- sopro/nn/__init__.py +7 -3
- sopro/nn/blocks.py +78 -0
- sopro/nn/embeddings.py +16 -0
- sopro/nn/generator.py +130 -0
- sopro/nn/nar.py +116 -0
- sopro/nn/ref.py +160 -0
- sopro/nn/speaker.py +14 -17
- sopro/nn/text.py +132 -0
- sopro/sampling.py +3 -3
- sopro/streaming.py +25 -38
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/METADATA +30 -7
- sopro-1.5.0.dist-info/RECORD +26 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/WHEEL +1 -1
- sopro/nn/xattn.py +0 -98
- sopro-1.0.1.dist-info/RECORD +0 -23
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/entry_points.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/licenses/LICENSE.txt +0 -0
- {sopro-1.0.1.dist-info → sopro-1.5.0.dist-info}/top_level.txt +0 -0
sopro/__init__.py
CHANGED
sopro/cli.py
CHANGED
|
@@ -32,8 +32,6 @@ def main() -> None:
|
|
|
32
32
|
ap.add_argument("--temperature", type=float, default=1.05)
|
|
33
33
|
ap.add_argument("--no_anti_loop", action="store_true")
|
|
34
34
|
|
|
35
|
-
ap.add_argument("--no_prefix", action="store_true")
|
|
36
|
-
ap.add_argument("--prefix_sec", type=float, default=None)
|
|
37
35
|
ap.add_argument("--style_strength", type=float, default=None)
|
|
38
36
|
ap.add_argument("--ref_seconds", type=float, default=None)
|
|
39
37
|
|
|
@@ -77,6 +75,7 @@ def main() -> None:
|
|
|
77
75
|
torch.cuda.manual_seed_all(args.seed)
|
|
78
76
|
|
|
79
77
|
t0 = time.perf_counter()
|
|
78
|
+
|
|
80
79
|
tts = SoproTTS.from_pretrained(
|
|
81
80
|
args.repo_id,
|
|
82
81
|
revision=args.revision,
|
|
@@ -84,6 +83,7 @@ def main() -> None:
|
|
|
84
83
|
token=args.hf_token,
|
|
85
84
|
device=device,
|
|
86
85
|
)
|
|
86
|
+
|
|
87
87
|
t1 = time.perf_counter()
|
|
88
88
|
if not args.quiet:
|
|
89
89
|
print(f"[Load] {t1 - t0:.2f}s")
|
|
@@ -97,74 +97,59 @@ def main() -> None:
|
|
|
97
97
|
arr = np.load(args.ref_tokens)
|
|
98
98
|
ref_tokens_tq = torch.from_numpy(arr).long()
|
|
99
99
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
100
|
+
with torch.inference_mode():
|
|
101
|
+
text_ids = tts.encode_text(args.text)
|
|
102
|
+
ref = tts.prepare_reference(
|
|
103
|
+
ref_audio_path=args.ref_audio,
|
|
104
|
+
ref_tokens_tq=ref_tokens_tq,
|
|
105
|
+
ref_seconds=args.ref_seconds,
|
|
106
|
+
)
|
|
106
107
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
108
|
+
prep = tts.model.prepare_conditioning(
|
|
109
|
+
text_ids,
|
|
110
|
+
ref,
|
|
111
|
+
max_frames=args.max_frames,
|
|
112
|
+
device=tts.device,
|
|
113
|
+
style_strength=float(
|
|
114
|
+
args.style_strength
|
|
115
|
+
if args.style_strength is not None
|
|
116
|
+
else cfg.style_strength
|
|
117
|
+
),
|
|
118
|
+
)
|
|
118
119
|
|
|
119
|
-
|
|
120
|
+
t_start = time.perf_counter()
|
|
120
121
|
|
|
121
122
|
hist_A: list[int] = []
|
|
122
123
|
pbar = tqdm(
|
|
123
|
-
total=args.max_frames,
|
|
124
|
-
desc="AR sampling",
|
|
125
|
-
unit="frame",
|
|
126
|
-
disable=args.quiet,
|
|
124
|
+
total=args.max_frames + 1, desc="AR sampling", unit="step", disable=args.quiet
|
|
127
125
|
)
|
|
128
126
|
|
|
129
|
-
for _t,
|
|
127
|
+
for _t, tok, is_eos in tts.model.ar_stream(
|
|
130
128
|
prep,
|
|
131
129
|
max_frames=args.max_frames,
|
|
132
130
|
top_p=args.top_p,
|
|
133
131
|
temperature=args.temperature,
|
|
134
132
|
anti_loop=(not args.no_anti_loop),
|
|
135
|
-
use_prefix=(not args.no_prefix),
|
|
136
|
-
prefix_sec_fixed=args.prefix_sec,
|
|
137
|
-
use_stop_head=(False if args.no_stop_head else None),
|
|
138
|
-
stop_patience=args.stop_patience,
|
|
139
|
-
stop_threshold=args.stop_threshold,
|
|
140
133
|
):
|
|
141
|
-
|
|
134
|
+
if is_eos:
|
|
135
|
+
pbar.set_postfix(eos="yes")
|
|
136
|
+
pbar.update(1)
|
|
137
|
+
break
|
|
138
|
+
hist_A.append(int(tok))
|
|
142
139
|
pbar.update(1)
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
else:
|
|
146
|
-
pbar.set_postfix(p_stop=f"{float(p_stop):.2f}")
|
|
140
|
+
|
|
141
|
+
t_after_sampling = time.perf_counter()
|
|
147
142
|
|
|
148
143
|
pbar.n = len(hist_A)
|
|
149
144
|
pbar.close()
|
|
150
145
|
|
|
151
|
-
t_after_sampling = time.perf_counter()
|
|
152
|
-
|
|
153
146
|
T = len(hist_A)
|
|
154
147
|
if T == 0:
|
|
155
148
|
save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
|
|
156
|
-
t_end = time.perf_counter()
|
|
157
|
-
if not args.quiet:
|
|
158
|
-
print(
|
|
159
|
-
f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
|
|
160
|
-
f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
|
|
161
|
-
f"total={t_end - t_start:.2f}s"
|
|
162
|
-
)
|
|
163
|
-
print(f"[Done] Wrote {args.out}")
|
|
164
149
|
return
|
|
165
150
|
|
|
166
151
|
tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
|
|
167
|
-
cond_seq = prep["
|
|
152
|
+
cond_seq = prep["cond_ar"][:, :T, :]
|
|
168
153
|
tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
|
|
169
154
|
tokens_tq = tokens_1xTQ.squeeze(0)
|
|
170
155
|
|
sopro/config.py
CHANGED
|
@@ -13,36 +13,31 @@ class SoproTTSConfig:
|
|
|
13
13
|
audio_sr: int = TARGET_SR
|
|
14
14
|
|
|
15
15
|
d_model: int = 384
|
|
16
|
-
n_layers_text: int =
|
|
17
|
-
n_layers_ar: int = 6
|
|
18
|
-
n_layers_nar: int = 6
|
|
16
|
+
n_layers_text: int = 2
|
|
19
17
|
dropout: float = 0.05
|
|
20
|
-
|
|
21
18
|
pos_emb_max: int = 4096
|
|
22
19
|
max_text_len: int = 2048
|
|
23
20
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
stop_patience: int = 5
|
|
21
|
+
n_layers_ar: int = 6
|
|
22
|
+
ar_kernel: int = 13
|
|
23
|
+
ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
|
|
24
|
+
ar_text_attn_freq: int = 2
|
|
29
25
|
min_gen_frames: int = 12
|
|
30
26
|
|
|
27
|
+
n_layers_nar: int = 6
|
|
28
|
+
nar_head_dim: int = 256
|
|
29
|
+
nar_kernel_size: int = 11
|
|
30
|
+
nar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 8)
|
|
31
|
+
|
|
31
32
|
stage_B: Tuple[int, int] = (2, 4)
|
|
32
33
|
stage_C: Tuple[int, int] = (5, 8)
|
|
33
34
|
stage_D: Tuple[int, int] = (9, 16)
|
|
34
35
|
stage_E: Tuple[int, int] = (17, 32)
|
|
35
36
|
|
|
36
|
-
ar_lookback: int = 4
|
|
37
|
-
ar_kernel: int = 13
|
|
38
|
-
ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
|
|
39
|
-
|
|
40
|
-
ar_text_attn_freq: int = 2
|
|
41
|
-
|
|
42
|
-
ref_attn_heads: int = 2
|
|
43
|
-
ref_seconds_max: float = 12.0
|
|
44
|
-
|
|
45
|
-
preprompt_sec_max: float = 4.0
|
|
46
|
-
|
|
47
37
|
sv_student_dim: int = 192
|
|
48
38
|
style_strength: float = 1.0
|
|
39
|
+
|
|
40
|
+
ref_enc_layers: int = 2
|
|
41
|
+
ref_xattn_heads: int = 2
|
|
42
|
+
ref_xattn_layers: int = 3
|
|
43
|
+
ref_xattn_gmax: float = 0.35
|
sopro/hub.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import os
|
|
4
5
|
import struct
|
|
5
6
|
from typing import Any, Dict, Optional
|
|
6
7
|
|
|
@@ -44,9 +45,7 @@ def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
|
|
|
44
45
|
for k in SoproTTSConfig.__annotations__.keys():
|
|
45
46
|
if k in cfg_dict:
|
|
46
47
|
init[k] = cfg_dict[k]
|
|
47
|
-
|
|
48
|
-
cfg = SoproTTSConfig(**init)
|
|
49
|
-
return cfg
|
|
48
|
+
return SoproTTSConfig(**init)
|
|
50
49
|
|
|
51
50
|
|
|
52
51
|
def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]:
|