sopro-1.0.2-py3-none-any.whl → sopro-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sopro/__init__.py CHANGED
@@ -3,4 +3,4 @@ from __future__ import annotations
3
3
  from .model import SoproTTS
4
4
 
5
5
  __all__ = ["SoproTTS"]
6
- __version__ = "1.0.2"
6
+ __version__ = "1.5.0"
sopro/cli.py CHANGED
@@ -32,8 +32,6 @@ def main() -> None:
32
32
  ap.add_argument("--temperature", type=float, default=1.05)
33
33
  ap.add_argument("--no_anti_loop", action="store_true")
34
34
 
35
- ap.add_argument("--no_prefix", action="store_true")
36
- ap.add_argument("--prefix_sec", type=float, default=None)
37
35
  ap.add_argument("--style_strength", type=float, default=None)
38
36
  ap.add_argument("--ref_seconds", type=float, default=None)
39
37
 
@@ -77,6 +75,7 @@ def main() -> None:
77
75
  torch.cuda.manual_seed_all(args.seed)
78
76
 
79
77
  t0 = time.perf_counter()
78
+
80
79
  tts = SoproTTS.from_pretrained(
81
80
  args.repo_id,
82
81
  revision=args.revision,
@@ -84,6 +83,7 @@ def main() -> None:
84
83
  token=args.hf_token,
85
84
  device=device,
86
85
  )
86
+
87
87
  t1 = time.perf_counter()
88
88
  if not args.quiet:
89
89
  print(f"[Load] {t1 - t0:.2f}s")
@@ -99,7 +99,7 @@ def main() -> None:
99
99
 
100
100
  with torch.inference_mode():
101
101
  text_ids = tts.encode_text(args.text)
102
- ref = tts.encode_reference(
102
+ ref = tts.prepare_reference(
103
103
  ref_audio_path=args.ref_audio,
104
104
  ref_tokens_tq=ref_tokens_tq,
105
105
  ref_seconds=args.ref_seconds,
@@ -119,58 +119,42 @@ def main() -> None:
119
119
 
120
120
  t_start = time.perf_counter()
121
121
 
122
- hist_A: list[int] = []
123
- pbar = tqdm(
124
- total=args.max_frames,
125
- desc="AR sampling",
126
- unit="frame",
127
- disable=args.quiet,
128
- )
122
+ hist_A: list[int] = []
123
+ pbar = tqdm(
124
+ total=args.max_frames + 1, desc="AR sampling", unit="step", disable=args.quiet
125
+ )
129
126
 
130
- for _t, rvq1, p_stop in tts.model.ar_stream(
131
- prep,
132
- max_frames=args.max_frames,
133
- top_p=args.top_p,
134
- temperature=args.temperature,
135
- anti_loop=(not args.no_anti_loop),
136
- use_prefix=(not args.no_prefix),
137
- prefix_sec_fixed=args.prefix_sec,
138
- use_stop_head=(False if args.no_stop_head else None),
139
- stop_patience=args.stop_patience,
140
- stop_threshold=args.stop_threshold,
141
- ):
142
- hist_A.append(int(rvq1))
127
+ for _t, tok, is_eos in tts.model.ar_stream(
128
+ prep,
129
+ max_frames=args.max_frames,
130
+ top_p=args.top_p,
131
+ temperature=args.temperature,
132
+ anti_loop=(not args.no_anti_loop),
133
+ ):
134
+ if is_eos:
135
+ pbar.set_postfix(eos="yes")
143
136
  pbar.update(1)
144
- if p_stop is None:
145
- pbar.set_postfix(p_stop="off")
146
- else:
147
- pbar.set_postfix(p_stop=f"{float(p_stop):.2f}")
148
-
149
- pbar.n = len(hist_A)
150
- pbar.close()
151
-
152
- t_after_sampling = time.perf_counter()
153
-
154
- T = len(hist_A)
155
- if T == 0:
156
- save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
157
- t_end = time.perf_counter()
158
- if not args.quiet:
159
- print(
160
- f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
161
- f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
162
- f"total={t_end - t_start:.2f}s"
163
- )
164
- print(f"[Done] Wrote {args.out}")
165
- return
166
-
167
- tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
168
- cond_seq = prep["cond_all"][:, :T, :]
169
- tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
170
- tokens_tq = tokens_1xTQ.squeeze(0)
171
-
172
- wav = tts.codec.decode_full(tokens_tq)
173
- save_audio(args.out, wav, sr=TARGET_SR)
137
+ break
138
+ hist_A.append(int(tok))
139
+ pbar.update(1)
140
+
141
+ t_after_sampling = time.perf_counter()
142
+
143
+ pbar.n = len(hist_A)
144
+ pbar.close()
145
+
146
+ T = len(hist_A)
147
+ if T == 0:
148
+ save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
149
+ return
150
+
151
+ tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
152
+ cond_seq = prep["cond_ar"][:, :T, :]
153
+ tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
154
+ tokens_tq = tokens_1xTQ.squeeze(0)
155
+
156
+ wav = tts.codec.decode_full(tokens_tq)
157
+ save_audio(args.out, wav, sr=TARGET_SR)
174
158
 
175
159
  t_end = time.perf_counter()
176
160
  if not args.quiet:
sopro/config.py CHANGED
@@ -13,36 +13,31 @@ class SoproTTSConfig:
13
13
  audio_sr: int = TARGET_SR
14
14
 
15
15
  d_model: int = 384
16
- n_layers_text: int = 4
17
- n_layers_ar: int = 6
18
- n_layers_nar: int = 6
16
+ n_layers_text: int = 2
19
17
  dropout: float = 0.05
20
-
21
18
  pos_emb_max: int = 4096
22
19
  max_text_len: int = 2048
23
20
 
24
- nar_head_dim: int = 256
25
-
26
- use_stop_head: bool = True
27
- stop_threshold: float = 0.8
28
- stop_patience: int = 5
21
+ n_layers_ar: int = 6
22
+ ar_kernel: int = 13
23
+ ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
24
+ ar_text_attn_freq: int = 2
29
25
  min_gen_frames: int = 12
30
26
 
27
+ n_layers_nar: int = 6
28
+ nar_head_dim: int = 256
29
+ nar_kernel_size: int = 11
30
+ nar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 8)
31
+
31
32
  stage_B: Tuple[int, int] = (2, 4)
32
33
  stage_C: Tuple[int, int] = (5, 8)
33
34
  stage_D: Tuple[int, int] = (9, 16)
34
35
  stage_E: Tuple[int, int] = (17, 32)
35
36
 
36
- ar_lookback: int = 4
37
- ar_kernel: int = 13
38
- ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
39
-
40
- ar_text_attn_freq: int = 2
41
-
42
- ref_attn_heads: int = 2
43
- ref_seconds_max: float = 12.0
44
-
45
- preprompt_sec_max: float = 4.0
46
-
47
37
  sv_student_dim: int = 192
48
38
  style_strength: float = 1.0
39
+
40
+ ref_enc_layers: int = 2
41
+ ref_xattn_heads: int = 2
42
+ ref_xattn_layers: int = 3
43
+ ref_xattn_gmax: float = 0.35
sopro/hub.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
+ import os
4
5
  import struct
5
6
  from typing import Any, Dict, Optional
6
7
 
@@ -44,9 +45,7 @@ def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
44
45
  for k in SoproTTSConfig.__annotations__.keys():
45
46
  if k in cfg_dict:
46
47
  init[k] = cfg_dict[k]
47
-
48
- cfg = SoproTTSConfig(**init)
49
- return cfg
48
+ return SoproTTSConfig(**init)
50
49
 
51
50
 
52
51
  def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]: