sopro 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sopro/__init__.py +6 -0
- sopro/audio.py +155 -0
- sopro/cli.py +185 -0
- sopro/codec/__init__.py +3 -0
- sopro/codec/mimi.py +181 -0
- sopro/config.py +48 -0
- sopro/constants.py +5 -0
- sopro/hub.py +53 -0
- sopro/model.py +853 -0
- sopro/nn/__init__.py +20 -0
- sopro/nn/blocks.py +110 -0
- sopro/nn/embeddings.py +96 -0
- sopro/nn/speaker.py +88 -0
- sopro/nn/xattn.py +98 -0
- sopro/sampling.py +101 -0
- sopro/streaming.py +165 -0
- sopro/tokenizer.py +38 -0
- sopro-1.0.0.dist-info/METADATA +182 -0
- sopro-1.0.0.dist-info/RECORD +23 -0
- sopro-1.0.0.dist-info/WHEEL +5 -0
- sopro-1.0.0.dist-info/entry_points.txt +2 -0
- sopro-1.0.0.dist-info/licenses/LICENSE.txt +201 -0
- sopro-1.0.0.dist-info/top_level.txt +1 -0
sopro/__init__.py
ADDED
sopro/audio.py
ADDED
@@ -0,0 +1,155 @@
from __future__ import annotations

import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .constants import TARGET_SR

try:
    import torchaudio
    import torchaudio.functional as AF
except Exception:
    torchaudio = None
    AF = None

try:
    import soundfile as sf
except Exception:
    sf = None


def device_str() -> str:
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def trim_silence_energy(
    wav: torch.Tensor,
    sr: int,
    frame_ms: float = 25.0,
    hop_ms: float = 10.0,
    thresh_db_floor: float = -40.0,
    prepad_ms: float = 30.0,
    postpad_ms: float = 30.0,
    min_keep_sec: float = 0.5,
) -> torch.Tensor:
    orig_1d = wav.ndim == 1
    if orig_1d:
        wav = wav.unsqueeze(0)

    C, T = wav.shape
    if T == 0:
        return wav.squeeze(0) if orig_1d else wav
    if T < int(sr * 0.1):
        return wav.squeeze(0) if orig_1d else wav

    frame_len = max(1, int(sr * frame_ms / 1000.0))
    hop = max(1, int(sr * hop_ms / 1000.0))
    if T < frame_len:
        return wav.squeeze(0) if orig_1d else wav

    mono = wav.mean(dim=0, keepdim=True)
    frames = mono.unfold(-1, frame_len, hop)
    energy = frames.pow(2).mean(dim=-1).squeeze(0)

    eps = 1e-10
    energy_db = 10.0 * torch.log10(energy + eps)
    max_db = float(energy_db.max().item())

    rel_thresh = max_db + thresh_db_floor
    thresh_db = max(rel_thresh, thresh_db_floor)

    voiced = energy_db > thresh_db
    idx = torch.nonzero(voiced, as_tuple=False)
    if idx.numel() == 0:
        return wav.squeeze(0) if orig_1d else wav

    first_frame = int(idx[0, 0].item())
    last_frame = int(idx[-1, 0].item())

    prepad_samples = int(sr * prepad_ms / 1000.0)
    postpad_samples = int(sr * postpad_ms / 1000.0)

    start = max(0, first_frame * hop - prepad_samples)

    end = min(T, last_frame * hop + frame_len + postpad_samples)

    min_keep = int(min_keep_sec * sr)
    if end <= start or (end - start) < min_keep:
        return wav.squeeze(0) if orig_1d else wav

    out = wav[:, start:end]
    return out.squeeze(0) if orig_1d else out


def load_audio_file(path: str) -> Tuple[torch.Tensor, int]:
    if sf is not None:
        wav_np, sr = sf.read(path, dtype="float32", always_2d=True)
        wav = torch.from_numpy(wav_np).transpose(0, 1)
    elif torchaudio is not None:
        wav, sr = torchaudio.load(path)
        if wav.dtype != torch.float32:
            if wav.dtype == torch.int16:
                wav = wav.float() / (2**15)
            else:
                wav = wav.float()
    else:
        raise RuntimeError("Install 'soundfile' or 'torchaudio' to read audio.")

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    return wav, sr


def resample(
    wav: torch.Tensor, sr_in: int, sr_out: int, device: Optional[str] = None
) -> torch.Tensor:
    device = device or device_str()
    wav = wav.to(device)
    if sr_in == sr_out:
        return wav
    if sr_in != sr_out and AF is None:
        raise RuntimeError("Resampling requires torchaudio. pip install torchaudio")
    return AF.resample(wav, sr_in, sr_out)


def save_audio(path: str, wav: torch.Tensor, sr: int = TARGET_SR) -> None:
    os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)

    wav = wav.detach().cpu()

    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    elif wav.ndim == 2:
        pass
    elif wav.ndim == 3:
        wav = wav[0]
    else:
        raise ValueError(f"Expected wav with 1-3 dims, got shape {tuple(wav.shape)}")

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sf is not None:
        sf.write(path, wav[0].numpy(), sr)
        return

    if torchaudio is not None:
        torchaudio.save(path, wav, sample_rate=sr)
        return

    raise RuntimeError("Install 'soundfile' or 'torchaudio' to write audio.")


def center_crop_audio(wav: torch.Tensor, win_samples: int) -> torch.Tensor:
    if win_samples <= 0:
        return wav
    T = int(wav.shape[-1])
    if T <= win_samples:
        return wav
    s = (T - win_samples) // 2
    return wav[..., s : s + win_samples]
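For context, these helpers compose into a simple read/trim/resample/write pipeline, preferring soundfile and falling back to torchaudio. A minimal usage sketch (not part of the package; the input path is hypothetical, and TARGET_SR comes from sopro/constants.py, whose value is not shown in this diff):

    import torch
    from sopro.audio import load_audio_file, trim_silence_energy, resample, save_audio
    from sopro.constants import TARGET_SR

    wav, sr = load_audio_file("speech.wav")    # mono float32, shape [1, T]
    wav = trim_silence_energy(wav, sr)         # energy-gated trim with pre/post padding
    wav = resample(wav, sr, TARGET_SR)         # requires torchaudio when sr != TARGET_SR
    save_audio("speech_out.wav", wav, sr=TARGET_SR)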
sopro/cli.py
ADDED
@@ -0,0 +1,185 @@
from __future__ import annotations

import argparse
import time

import torch
from tqdm.auto import tqdm

from .audio import save_audio
from .constants import TARGET_SR
from .model import SoproTTS


def main() -> None:
    ap = argparse.ArgumentParser(description="SoproTTS cli inference")

    ap.add_argument("--repo_id", type=str, default="samuel-vitorino/sopro")
    ap.add_argument(
        "--revision", type=str, default=None, help="Optional git revision/branch/tag"
    )
    ap.add_argument("--cache_dir", type=str, default=None, help="Optional HF cache dir")
    ap.add_argument("--hf_token", type=str, default=None, help="HF token")

    ap.add_argument("--text", type=str, required=True)
    ap.add_argument("--ref_audio", type=str, default=None)
    ap.add_argument("--ref_tokens", type=str, default=None, help="Path to .npy [T,Q]")
    ap.add_argument("--out", type=str, required=True)

    ap.add_argument("--max_frames", type=int, default=400)

    ap.add_argument("--top_p", type=float, default=0.9)
    ap.add_argument("--temperature", type=float, default=1.05)
    ap.add_argument("--no_anti_loop", action="store_true")

    ap.add_argument("--no_prefix", action="store_true")
    ap.add_argument("--prefix_sec", type=float, default=None)
    ap.add_argument("--style_strength", type=float, default=None)
    ap.add_argument("--ref_seconds", type=float, default=None)

    ap.add_argument("--seed", type=int, default=None, help="Random seed for sampling")

    ap.add_argument(
        "--no_stop_head", action="store_true", help="Disable stop head early stopping"
    )
    ap.add_argument(
        "--stop_patience", type=int, default=None, help="Override cfg.stop_patience"
    )
    ap.add_argument(
        "--stop_threshold", type=float, default=None, help="Override cfg.stop_threshold"
    )

    ap.add_argument(
        "--device",
        type=str,
        default=None,
        choices=["cpu", "cuda", "mps"],
        help="Device to run on (default: cuda if available else cpu)",
    )

    ap.add_argument("--quiet", action="store_true")

    args = ap.parse_args()

    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = args.device or default_device

    if device == "cuda" and not torch.cuda.is_available():
        raise SystemExit("Error: --device cuda requested but CUDA is not available.")
    if device == "mps" and not (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    ):
        raise SystemExit("Error: --device mps requested but MPS is not available.")

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    t0 = time.perf_counter()
    tts = SoproTTS.from_pretrained(
        args.repo_id,
        revision=args.revision,
        cache_dir=args.cache_dir,
        token=args.hf_token,
        device=device,
    )
    t1 = time.perf_counter()
    if not args.quiet:
        print(f"[Load] {t1 - t0:.2f}s")

    cfg = tts.cfg

    ref_tokens_tq = None
    if args.ref_tokens is not None:
        import numpy as np

        arr = np.load(args.ref_tokens)
        ref_tokens_tq = torch.from_numpy(arr).long()

    text_ids = tts.encode_text(args.text)
    ref = tts.encode_reference(
        ref_audio_path=args.ref_audio,
        ref_tokens_tq=ref_tokens_tq,
        ref_seconds=args.ref_seconds,
    )

    prep = tts.model.prepare_conditioning(
        text_ids,
        ref,
        max_frames=args.max_frames,
        device=tts.device,
        style_strength=float(
            args.style_strength
            if args.style_strength is not None
            else cfg.style_strength
        ),
    )

    t_start = time.perf_counter()

    hist_A: list[int] = []
    pbar = tqdm(
        total=args.max_frames,
        desc="AR sampling",
        unit="frame",
        disable=args.quiet,
    )

    for _t, rvq1, p_stop in tts.model.ar_stream(
        prep,
        max_frames=args.max_frames,
        top_p=args.top_p,
        temperature=args.temperature,
        anti_loop=(not args.no_anti_loop),
        use_prefix=(not args.no_prefix),
        prefix_sec_fixed=args.prefix_sec,
        use_stop_head=(False if args.no_stop_head else None),
        stop_patience=args.stop_patience,
        stop_threshold=args.stop_threshold,
    ):
        hist_A.append(int(rvq1))
        pbar.update(1)
        if p_stop is None:
            pbar.set_postfix(p_stop="off")
        else:
            pbar.set_postfix(p_stop=f"{float(p_stop):.2f}")

    pbar.n = len(hist_A)
    pbar.close()

    t_after_sampling = time.perf_counter()

    T = len(hist_A)
    if T == 0:
        save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
        t_end = time.perf_counter()
        if not args.quiet:
            print(
                f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
                f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
                f"total={t_end - t_start:.2f}s"
            )
            print(f"[Done] Wrote {args.out}")
        return

    tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
    cond_seq = prep["cond_all"][:, :T, :]
    tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
    tokens_tq = tokens_1xTQ.squeeze(0)

    wav = tts.codec.decode_full(tokens_tq)
    save_audio(args.out, wav, sr=TARGET_SR)

    t_end = time.perf_counter()
    if not args.quiet:
        print(
            f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
            f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
            f"total={t_end - t_start:.2f}s"
        )
        print(f"[Done] Wrote {args.out}")


if __name__ == "__main__":
    main()
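The CLI doubles as a reference for the library API: load, encode text and reference, prepare conditioning, autoregressively sample the first RVQ codebook, non-autoregressively refine the rest, decode with the codec, save. It can be run as python -m sopro.cli (a console script is also registered; see entry_points.txt). A condensed library-form sketch of the same flow (not part of the package; the reference path is hypothetical, and keyword arguments mirror the CLI defaults above):

    import torch
    from sopro.audio import save_audio
    from sopro.constants import TARGET_SR
    from sopro.model import SoproTTS

    tts = SoproTTS.from_pretrained("samuel-vitorino/sopro", device="cpu")
    text_ids = tts.encode_text("Hello world.")
    ref = tts.encode_reference(ref_audio_path="ref.wav", ref_tokens_tq=None, ref_seconds=None)
    prep = tts.model.prepare_conditioning(
        text_ids, ref, max_frames=400, device=tts.device,
        style_strength=float(tts.cfg.style_strength),
    )

    hist = []
    for _t, rvq1, _p_stop in tts.model.ar_stream(
        prep, max_frames=400, top_p=0.9, temperature=1.05,
        anti_loop=True, use_prefix=True, prefix_sec_fixed=None,
        use_stop_head=None, stop_patience=None, stop_threshold=None,
    ):
        hist.append(int(rvq1))                       # coarse RVQ-1 token per frame

    tokens_A = torch.tensor(hist, device=tts.device, dtype=torch.long).unsqueeze(0)
    tokens_tq = tts.model.nar_refine(prep["cond_all"][:, : len(hist), :], tokens_A).squeeze(0)
    save_audio("out.wav", tts.codec.decode_full(tokens_tq), sr=TARGET_SR)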
sopro/codec/__init__.py
ADDED
sopro/codec/mimi.py
ADDED
@@ -0,0 +1,181 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import torch

from ..audio import center_crop_audio, load_audio_file, resample, trim_silence_energy
from ..constants import DEFAULT_MIMI_ID, TARGET_SR

try:
    from transformers import MimiConfig, MimiModel
except Exception:
    MimiModel = None
    MimiConfig = None


class MimiCodec:
    def __init__(
        self, num_quantizers: int, device: str = "cuda", model_id: str = DEFAULT_MIMI_ID
    ):
        if MimiModel is None or MimiConfig is None:
            raise RuntimeError(
                "MimiModel missing. Install a recent transformers version."
            )

        self.device = torch.device(device)
        cfg = MimiConfig.from_pretrained(model_id, num_quantizers=int(num_quantizers))
        self.model = (
            MimiModel.from_pretrained(model_id, config=cfg).to(self.device).eval()
        )

    @property
    def codebook_size(self) -> int:
        return int(getattr(self.model.config, "codebook_size", 2048))

    @property
    def num_quantizers(self) -> int:
        return int(getattr(self.model.config, "num_quantizers", 32))

    @torch.no_grad()
    def encode_file(
        self, wav_path: str, *, crop_seconds: Optional[float] = None
    ) -> torch.Tensor:
        wav, sr = load_audio_file(wav_path)

        wav = trim_silence_energy(wav, sr)

        cfg = self.model.config
        sr_target = int(getattr(cfg, "sampling_rate", TARGET_SR))
        wav = resample(wav, sr, sr_target, device=str(self.device))

        if crop_seconds is not None and crop_seconds > 0:
            fps = float(getattr(cfg, "frame_rate", 12.5))
            hop = int(round(sr_target / fps))
            win_frames = max(1, int(round(crop_seconds * fps)))
            win_samples = win_frames * hop
            wav = center_crop_audio(wav, win_samples)

        wav = wav.unsqueeze(0)
        out = self.model.encode(wav, return_dict=True)
        codes_bqt = out.audio_codes
        return codes_bqt[0].permute(1, 0).contiguous()

    @torch.no_grad()
    def decode_full(self, codes_tq: torch.Tensor) -> torch.Tensor:
        audio_codes = codes_tq.to(self.device).permute(1, 0).unsqueeze(0).contiguous()
        out = self.model.decode(audio_codes=audio_codes, return_dict=True)
        wav = out.audio_values
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        return wav


@dataclass
class MimiDecodeState:
    decoder_past_key_values: Optional[object] = None
    frames_seen: int = 0
    samples_emitted: int = 0
    tail_codes_tq: Optional[torch.Tensor] = None


class MimiStreamDecoder:
    def __init__(self, codec: MimiCodec):
        self.codec = codec
        self.codec.model.config.use_cache = True

    def drop_cache_tail(self, pkv: Any, n: int):
        if pkv is None or n <= 0:
            return pkv

        if hasattr(pkv, "to_legacy_cache") and hasattr(type(pkv), "from_legacy_cache"):
            legacy = pkv.to_legacy_cache()

            trimmed = tuple(
                (
                    k[..., : max(0, k.shape[-2] - n), :].contiguous(),
                    v[..., : max(0, v.shape[-2] - n), :].contiguous(),
                )
                for (k, v) in legacy
            )
            return type(pkv).from_legacy_cache(trimmed)

        if isinstance(pkv, tuple):
            return tuple(
                (
                    k[..., : max(0, k.shape[-2] - n), :].contiguous(),
                    v[..., : max(0, v.shape[-2] - n), :].contiguous(),
                )
                for (k, v) in pkv
            )

        return pkv

    @torch.inference_mode()
    def decode_step(
        self,
        codes_chunk_tq: torch.Tensor,
        state: Optional[MimiDecodeState] = None,
        *,
        overlap_frames: int = 2,
    ) -> Tuple[torch.Tensor, MimiDecodeState]:
        if state is None:
            state = MimiDecodeState()

        cfg = self.codec.model.config
        sr = int(getattr(cfg, "sampling_rate", 24000))
        fps = float(getattr(cfg, "frame_rate", 12.5))
        hop = int(round(sr / fps))

        n_new = int(codes_chunk_tq.size(0))
        if n_new == 0:
            return torch.zeros(1, 0, device=self.codec.device), state

        tail = state.tail_codes_tq
        ov = 0
        if overlap_frames > 0 and tail is not None and tail.numel() > 0:
            ov = min(int(overlap_frames), int(tail.size(0)))
            tail = tail[-ov:]
            codes_in_tq = torch.cat([tail, codes_chunk_tq], dim=0)
        else:
            codes_in_tq = codes_chunk_tq

        pkv = state.decoder_past_key_values
        if ov > 0 and pkv is not None:
            pkv = self.drop_cache_tail(pkv, ov)

        audio_codes = (
            codes_in_tq.to(self.codec.device).permute(1, 0).unsqueeze(0).contiguous()
        )

        out = self.codec.model.decode(
            audio_codes=audio_codes,
            decoder_past_key_values=pkv,
            return_dict=True,
        )

        wav = out.audio_values
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.reshape(1, -1)

        expected_total = int((ov + n_new) * hop)
        if wav.size(1) >= expected_total:
            wav = wav[:, :expected_total]

        ov_samp = min(int(ov * hop), int(wav.size(1)))
        wav_new = wav[:, ov_samp:]

        state.decoder_past_key_values = getattr(out, "decoder_past_key_values", None)
        state.frames_seen += n_new
        state.samples_emitted += int(wav_new.size(1))

        if overlap_frames > 0:
            keep = min(int(overlap_frames), int(codes_in_tq.size(0)))
            state.tail_codes_tq = codes_in_tq[-keep:].detach()
        else:
            state.tail_codes_tq = None

        return wav_new, state
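MimiStreamDecoder trades a little redundant compute for seam-free chunked decoding: each decode_step re-decodes overlap_frames trailing frames from the previous chunk (rolling the decoder KV cache back by the same amount via drop_cache_tail) and then discards the overlapping samples from the emitted audio. A usage sketch (not part of the package; the file path and 25-frame chunking are hypothetical, chosen as roughly 2 s at the 12.5 fps default above):

    import torch
    from sopro.codec.mimi import MimiCodec, MimiStreamDecoder

    codec = MimiCodec(num_quantizers=32, device="cpu")
    dec = MimiStreamDecoder(codec)

    codes_tq = codec.encode_file("ref.wav")            # [T, Q] Mimi codes
    state = None
    chunks = []
    for start in range(0, codes_tq.size(0), 25):       # feed fixed-size chunks
        wav_new, state = dec.decode_step(codes_tq[start : start + 25], state)
        chunks.append(wav_new)                         # only the non-overlap samples
    wav = torch.cat(chunks, dim=1)                     # [1, T_samples]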
sopro/config.py
ADDED
@@ -0,0 +1,48 @@
from dataclasses import dataclass
from typing import Tuple

from sopro.constants import TARGET_SR


@dataclass
class SoproTTSConfig:
    num_codebooks: int = 32
    codebook_size: int = 2048
    mimi_fps: float = 12.5
    max_frames: int = 400
    audio_sr: int = TARGET_SR

    d_model: int = 384
    n_layers_text: int = 4
    n_layers_ar: int = 6
    n_layers_nar: int = 6
    dropout: float = 0.05

    pos_emb_max: int = 4096
    max_text_len: int = 2048

    nar_head_dim: int = 256

    use_stop_head: bool = True
    stop_threshold: float = 0.8
    stop_patience: int = 5
    min_gen_frames: int = 12

    stage_B: Tuple[int, int] = (2, 4)
    stage_C: Tuple[int, int] = (5, 8)
    stage_D: Tuple[int, int] = (9, 16)
    stage_E: Tuple[int, int] = (17, 32)

    ar_lookback: int = 4
    ar_kernel: int = 13
    ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)

    ar_text_attn_freq: int = 2

    ref_attn_heads: int = 2
    ref_seconds_max: float = 12.0

    preprompt_sec_max: float = 4.0

    sv_student_dim: int = 192
    style_strength: float = 1.0
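Every field is a plain dataclass default, so variants can be built with keyword overrides, and load_cfg_from_safetensors in sopro/hub.py reconstructs the same dataclass from checkpoint metadata. A minimal sketch (not part of the package):

    from sopro.config import SoproTTSConfig

    cfg = SoproTTSConfig(max_frames=800, style_strength=0.8)   # override two fields
    assert cfg.d_model == 384 and cfg.stage_E == (17, 32)      # others keep defaults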
sopro/constants.py
ADDED
sopro/hub.py
ADDED
@@ -0,0 +1,53 @@
from __future__ import annotations

import json
import struct
from typing import Any, Dict, Optional

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from sopro.config import SoproTTSConfig


def download_repo(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    token: Optional[str] = None,
) -> str:
    return snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=cache_dir,
        token=token,
    )


def _read_safetensors_metadata(path: str) -> Dict[str, str]:
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_len).decode("utf-8"))
    meta = header.get("__metadata__", {}) or {}
    return {str(k): str(v) for k, v in meta.items()}


def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
    meta = _read_safetensors_metadata(path)
    if "cfg" not in meta:
        raise RuntimeError(f"No 'cfg' metadata found in {path}.")

    cfg_dict = json.loads(meta["cfg"])
    init: Dict[str, Any] = {}
    for k in SoproTTSConfig.__annotations__.keys():
        if k in cfg_dict:
            init[k] = cfg_dict[k]

    cfg = SoproTTSConfig(**init)
    return cfg


def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]:
    return load_file(path)
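Together these helpers implement the download-then-load flow: snapshot the Hugging Face repo, parse the safetensors header (8-byte little-endian length prefix, then a JSON header whose __metadata__ carries a serialized 'cfg' blob), rebuild SoproTTSConfig, and load the weights. A sketch (not part of the package; the weights filename inside the repo is an assumption, as it is not shown in this diff):

    import os
    from sopro.hub import (
        download_repo,
        load_cfg_from_safetensors,
        load_state_dict_from_safetensors,
    )

    repo_dir = download_repo("samuel-vitorino/sopro")
    ckpt = os.path.join(repo_dir, "model.safetensors")   # hypothetical filename
    cfg = load_cfg_from_safetensors(ckpt)                # config from header metadata
    state = load_state_dict_from_safetensors(ckpt)       # name -> tensor dict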