sopro 1.0.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
sopro/__init__.py CHANGED
@@ -3,4 +3,4 @@ from __future__ import annotations
3
3
  from .model import SoproTTS
4
4
 
5
5
  __all__ = ["SoproTTS"]
6
- __version__ = "1.0.1"
6
+ __version__ = "1.5.0"
sopro/cli.py CHANGED
@@ -32,8 +32,6 @@ def main() -> None:
32
32
  ap.add_argument("--temperature", type=float, default=1.05)
33
33
  ap.add_argument("--no_anti_loop", action="store_true")
34
34
 
35
- ap.add_argument("--no_prefix", action="store_true")
36
- ap.add_argument("--prefix_sec", type=float, default=None)
37
35
  ap.add_argument("--style_strength", type=float, default=None)
38
36
  ap.add_argument("--ref_seconds", type=float, default=None)
39
37
 
@@ -77,6 +75,7 @@ def main() -> None:
77
75
  torch.cuda.manual_seed_all(args.seed)
78
76
 
79
77
  t0 = time.perf_counter()
78
+
80
79
  tts = SoproTTS.from_pretrained(
81
80
  args.repo_id,
82
81
  revision=args.revision,
@@ -84,6 +83,7 @@ def main() -> None:
84
83
  token=args.hf_token,
85
84
  device=device,
86
85
  )
86
+
87
87
  t1 = time.perf_counter()
88
88
  if not args.quiet:
89
89
  print(f"[Load] {t1 - t0:.2f}s")
@@ -97,74 +97,59 @@ def main() -> None:
97
97
  arr = np.load(args.ref_tokens)
98
98
  ref_tokens_tq = torch.from_numpy(arr).long()
99
99
 
100
- text_ids = tts.encode_text(args.text)
101
- ref = tts.encode_reference(
102
- ref_audio_path=args.ref_audio,
103
- ref_tokens_tq=ref_tokens_tq,
104
- ref_seconds=args.ref_seconds,
105
- )
100
+ with torch.inference_mode():
101
+ text_ids = tts.encode_text(args.text)
102
+ ref = tts.prepare_reference(
103
+ ref_audio_path=args.ref_audio,
104
+ ref_tokens_tq=ref_tokens_tq,
105
+ ref_seconds=args.ref_seconds,
106
+ )
106
107
 
107
- prep = tts.model.prepare_conditioning(
108
- text_ids,
109
- ref,
110
- max_frames=args.max_frames,
111
- device=tts.device,
112
- style_strength=float(
113
- args.style_strength
114
- if args.style_strength is not None
115
- else cfg.style_strength
116
- ),
117
- )
108
+ prep = tts.model.prepare_conditioning(
109
+ text_ids,
110
+ ref,
111
+ max_frames=args.max_frames,
112
+ device=tts.device,
113
+ style_strength=float(
114
+ args.style_strength
115
+ if args.style_strength is not None
116
+ else cfg.style_strength
117
+ ),
118
+ )
118
119
 
119
- t_start = time.perf_counter()
120
+ t_start = time.perf_counter()
120
121
 
121
122
  hist_A: list[int] = []
122
123
  pbar = tqdm(
123
- total=args.max_frames,
124
- desc="AR sampling",
125
- unit="frame",
126
- disable=args.quiet,
124
+ total=args.max_frames + 1, desc="AR sampling", unit="step", disable=args.quiet
127
125
  )
128
126
 
129
- for _t, rvq1, p_stop in tts.model.ar_stream(
127
+ for _t, tok, is_eos in tts.model.ar_stream(
130
128
  prep,
131
129
  max_frames=args.max_frames,
132
130
  top_p=args.top_p,
133
131
  temperature=args.temperature,
134
132
  anti_loop=(not args.no_anti_loop),
135
- use_prefix=(not args.no_prefix),
136
- prefix_sec_fixed=args.prefix_sec,
137
- use_stop_head=(False if args.no_stop_head else None),
138
- stop_patience=args.stop_patience,
139
- stop_threshold=args.stop_threshold,
140
133
  ):
141
- hist_A.append(int(rvq1))
134
+ if is_eos:
135
+ pbar.set_postfix(eos="yes")
136
+ pbar.update(1)
137
+ break
138
+ hist_A.append(int(tok))
142
139
  pbar.update(1)
143
- if p_stop is None:
144
- pbar.set_postfix(p_stop="off")
145
- else:
146
- pbar.set_postfix(p_stop=f"{float(p_stop):.2f}")
140
+
141
+ t_after_sampling = time.perf_counter()
147
142
 
148
143
  pbar.n = len(hist_A)
149
144
  pbar.close()
150
145
 
151
- t_after_sampling = time.perf_counter()
152
-
153
146
  T = len(hist_A)
154
147
  if T == 0:
155
148
  save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
156
- t_end = time.perf_counter()
157
- if not args.quiet:
158
- print(
159
- f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
160
- f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
161
- f"total={t_end - t_start:.2f}s"
162
- )
163
- print(f"[Done] Wrote {args.out}")
164
149
  return
165
150
 
166
151
  tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
167
- cond_seq = prep["cond_all"][:, :T, :]
152
+ cond_seq = prep["cond_ar"][:, :T, :]
168
153
  tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
169
154
  tokens_tq = tokens_1xTQ.squeeze(0)
170
155
 
sopro/config.py CHANGED
@@ -13,36 +13,31 @@ class SoproTTSConfig:
13
13
  audio_sr: int = TARGET_SR
14
14
 
15
15
  d_model: int = 384
16
- n_layers_text: int = 4
17
- n_layers_ar: int = 6
18
- n_layers_nar: int = 6
16
+ n_layers_text: int = 2
19
17
  dropout: float = 0.05
20
-
21
18
  pos_emb_max: int = 4096
22
19
  max_text_len: int = 2048
23
20
 
24
- nar_head_dim: int = 256
25
-
26
- use_stop_head: bool = True
27
- stop_threshold: float = 0.8
28
- stop_patience: int = 5
21
+ n_layers_ar: int = 6
22
+ ar_kernel: int = 13
23
+ ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
24
+ ar_text_attn_freq: int = 2
29
25
  min_gen_frames: int = 12
30
26
 
27
+ n_layers_nar: int = 6
28
+ nar_head_dim: int = 256
29
+ nar_kernel_size: int = 11
30
+ nar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 8)
31
+
31
32
  stage_B: Tuple[int, int] = (2, 4)
32
33
  stage_C: Tuple[int, int] = (5, 8)
33
34
  stage_D: Tuple[int, int] = (9, 16)
34
35
  stage_E: Tuple[int, int] = (17, 32)
35
36
 
36
- ar_lookback: int = 4
37
- ar_kernel: int = 13
38
- ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)
39
-
40
- ar_text_attn_freq: int = 2
41
-
42
- ref_attn_heads: int = 2
43
- ref_seconds_max: float = 12.0
44
-
45
- preprompt_sec_max: float = 4.0
46
-
47
37
  sv_student_dim: int = 192
48
38
  style_strength: float = 1.0
39
+
40
+ ref_enc_layers: int = 2
41
+ ref_xattn_heads: int = 2
42
+ ref_xattn_layers: int = 3
43
+ ref_xattn_gmax: float = 0.35
sopro/hub.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
+ import os
4
5
  import struct
5
6
  from typing import Any, Dict, Optional
6
7
 
@@ -44,9 +45,7 @@ def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
44
45
  for k in SoproTTSConfig.__annotations__.keys():
45
46
  if k in cfg_dict:
46
47
  init[k] = cfg_dict[k]
47
-
48
- cfg = SoproTTSConfig(**init)
49
- return cfg
48
+ return SoproTTSConfig(**init)
50
49
 
51
50
 
52
51
  def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]: