sopro 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sopro/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+ from .model import SoproTTS
4
+
5
+ __all__ = ["SoproTTS"]
6
+ __version__ = "1.0.0"
sopro/audio.py ADDED
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from .constants import TARGET_SR
10
+
11
+ try:
12
+ import torchaudio
13
+ import torchaudio.functional as AF
14
+ except Exception:
15
+ torchaudio = None
16
+ AF = None
17
+
18
+ try:
19
+ import soundfile as sf
20
+ except Exception:
21
+ sf = None
22
+
23
+
24
def device_str() -> str:
    """Return the default compute device: ``"cuda"`` when available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+
30
def trim_silence_energy(
    wav: torch.Tensor,
    sr: int,
    frame_ms: float = 25.0,
    hop_ms: float = 10.0,
    thresh_db_floor: float = -40.0,
    prepad_ms: float = 30.0,
    postpad_ms: float = 30.0,
    min_keep_sec: float = 0.5,
) -> torch.Tensor:
    """Trim leading and trailing low-energy audio from *wav*.

    Frame energies are computed on a mono downmix; frames louder than a
    threshold (relative to the loudest frame, floored at ``thresh_db_floor``
    dB) are kept, with a little padding on each side. The input is returned
    unchanged whenever trimming is impossible or would leave too little
    audio. Accepts 1-D ``(T,)`` or 2-D ``(C, T)`` tensors and preserves the
    input rank.
    """
    was_1d = wav.ndim == 1
    work = wav.unsqueeze(0) if was_1d else wav

    def unchanged() -> torch.Tensor:
        # Hand back the input in its original rank.
        return work.squeeze(0) if was_1d else work

    _, total = work.shape
    # Empty or shorter than 100 ms: nothing sensible to trim.
    if total == 0 or total < int(sr * 0.1):
        return unchanged()

    win = max(1, int(sr * frame_ms / 1000.0))
    step = max(1, int(sr * hop_ms / 1000.0))
    if total < win:
        return unchanged()

    mono = work.mean(dim=0, keepdim=True)
    energy = mono.unfold(-1, win, step).pow(2).mean(dim=-1).squeeze(0)
    # Epsilon keeps log10 finite on all-zero frames (~ -100 dB).
    db = 10.0 * torch.log10(energy + 1e-10)

    # Gate relative to the loudest frame, never below the absolute floor.
    gate = max(float(db.max().item()) + thresh_db_floor, thresh_db_floor)
    active = torch.nonzero(db > gate, as_tuple=False)
    if active.numel() == 0:
        return unchanged()

    lo_frame = int(active[0, 0].item())
    hi_frame = int(active[-1, 0].item())

    lo = max(0, lo_frame * step - int(sr * prepad_ms / 1000.0))
    hi = min(total, hi_frame * step + win + int(sr * postpad_ms / 1000.0))

    # Refuse to trim below the minimum keep duration.
    if hi <= lo or (hi - lo) < int(min_keep_sec * sr):
        return unchanged()

    clipped = work[:, lo:hi]
    return clipped.squeeze(0) if was_1d else clipped
87
+
88
+
89
def load_audio_file(path: str) -> Tuple[torch.Tensor, int]:
    """Read an audio file as a mono float32 tensor.

    Prefers soundfile, falls back to torchaudio. Returns ``(wav, sr)`` where
    ``wav`` has shape ``(1, T)``; multi-channel input is averaged to mono.

    Raises:
        RuntimeError: if neither soundfile nor torchaudio is importable.
    """
    if sf is not None:
        # soundfile returns (T, C) float32; put channels first.
        samples, sr = sf.read(path, dtype="float32", always_2d=True)
        wav = torch.from_numpy(samples).transpose(0, 1)
    elif torchaudio is not None:
        wav, sr = torchaudio.load(path)
        if wav.dtype != torch.float32:
            # int16 PCM is scaled into [-1, 1); anything else is cast as-is.
            wav = wav.float() / (2**15) if wav.dtype == torch.int16 else wav.float()
    else:
        raise RuntimeError("Install 'soundfile' or 'torchaudio' to read audio.")

    # Downmix so callers always receive a single channel.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    return wav, sr
106
+
107
+
108
def resample(
    wav: torch.Tensor, sr_in: int, sr_out: int, device: Optional[str] = None
) -> torch.Tensor:
    """Move *wav* to *device* (default: auto-detected) and resample it.

    Returns the tensor unchanged (apart from the device move) when the
    sample rates already match.

    Raises:
        RuntimeError: when resampling is needed but torchaudio is missing.
    """
    target = device or device_str()
    wav = wav.to(target)
    if sr_in == sr_out:
        return wav
    if AF is None:
        # Reaching here means sr_in != sr_out, so we truly need torchaudio.
        raise RuntimeError("Resampling requires torchaudio. pip install torchaudio")
    return AF.resample(wav, sr_in, sr_out)
118
+
119
+
120
def save_audio(path: str, wav: torch.Tensor, sr: int = TARGET_SR) -> None:
    """Write a waveform to *path*, creating parent directories as needed.

    Accepts 1-D ``(T,)``, 2-D ``(C, T)`` or 3-D ``(B, C, T)`` tensors; only
    the first batch item is kept and multi-channel audio is downmixed to
    mono before writing.

    Raises:
        ValueError: for tensors outside 1-3 dimensions.
        RuntimeError: if neither soundfile nor torchaudio is importable.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)

    data = wav.detach().cpu()

    # Normalize everything to shape (C, T).
    if data.ndim == 1:
        data = data.unsqueeze(0)
    elif data.ndim == 3:
        data = data[0]  # keep the first batch item only
    elif data.ndim != 2:
        raise ValueError(f"Expected wav with 1-3 dims, got shape {tuple(data.shape)}")

    if data.size(0) > 1:
        data = data.mean(dim=0, keepdim=True)

    if sf is not None:
        sf.write(path, data[0].numpy(), sr)
    elif torchaudio is not None:
        torchaudio.save(path, data, sample_rate=sr)
    else:
        raise RuntimeError("Install 'soundfile' or 'torchaudio' to write audio.")
146
+
147
+
148
def center_crop_audio(wav: torch.Tensor, win_samples: int) -> torch.Tensor:
    """Return a centered window of *win_samples* along the last axis.

    The input is returned untouched when *win_samples* is non-positive or
    not smaller than the audio length.
    """
    if win_samples <= 0:
        return wav
    excess = int(wav.shape[-1]) - win_samples
    if excess <= 0:
        return wav
    # narrow() is the view-based equivalent of wav[..., s : s + win_samples].
    return wav.narrow(-1, excess // 2, win_samples)
sopro/cli.py ADDED
@@ -0,0 +1,185 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import time
5
+
6
+ import torch
7
+ from tqdm.auto import tqdm
8
+
9
+ from .audio import save_audio
10
+ from .constants import TARGET_SR
11
+ from .model import SoproTTS
12
+
13
+
14
def main() -> None:
    """Command-line entry point: load SoproTTS from the Hub, synthesize speech
    for ``--text`` (optionally voice-conditioned on ``--ref_audio`` or
    ``--ref_tokens``), and write the result to ``--out``."""
    ap = argparse.ArgumentParser(description="SoproTTS cli inference")

    # Model source on the Hugging Face Hub.
    ap.add_argument("--repo_id", type=str, default="samuel-vitorino/sopro")
    ap.add_argument(
        "--revision", type=str, default=None, help="Optional git revision/branch/tag"
    )
    ap.add_argument("--cache_dir", type=str, default=None, help="Optional HF cache dir")
    ap.add_argument("--hf_token", type=str, default=None, help="HF token")

    # Synthesis inputs/outputs.
    ap.add_argument("--text", type=str, required=True)
    ap.add_argument("--ref_audio", type=str, default=None)
    ap.add_argument("--ref_tokens", type=str, default=None, help="Path to .npy [T,Q]")
    ap.add_argument("--out", type=str, required=True)

    # Hard cap on autoregressive generation length (in codec frames).
    ap.add_argument("--max_frames", type=int, default=400)

    # Sampling controls.
    ap.add_argument("--top_p", type=float, default=0.9)
    ap.add_argument("--temperature", type=float, default=1.05)
    ap.add_argument("--no_anti_loop", action="store_true")

    # Reference/style conditioning controls (None -> model config defaults).
    ap.add_argument("--no_prefix", action="store_true")
    ap.add_argument("--prefix_sec", type=float, default=None)
    ap.add_argument("--style_strength", type=float, default=None)
    ap.add_argument("--ref_seconds", type=float, default=None)

    ap.add_argument("--seed", type=int, default=None, help="Random seed for sampling")

    # Early-stopping (stop head) overrides.
    ap.add_argument(
        "--no_stop_head", action="store_true", help="Disable stop head early stopping"
    )
    ap.add_argument(
        "--stop_patience", type=int, default=None, help="Override cfg.stop_patience"
    )
    ap.add_argument(
        "--stop_threshold", type=float, default=None, help="Override cfg.stop_threshold"
    )

    ap.add_argument(
        "--device",
        type=str,
        default=None,
        choices=["cpu", "cuda", "mps"],
        help="Device to run on (default: cuda if available else cpu)",
    )

    ap.add_argument("--quiet", action="store_true")

    args = ap.parse_args()

    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = args.device or default_device

    # Fail fast on an explicitly requested but unavailable accelerator.
    if device == "cuda" and not torch.cuda.is_available():
        raise SystemExit("Error: --device cuda requested but CUDA is not available.")
    if device == "mps" and not (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    ):
        raise SystemExit("Error: --device mps requested but MPS is not available.")

    # Seed both CPU and (if present) CUDA RNGs for reproducible sampling.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    t0 = time.perf_counter()
    tts = SoproTTS.from_pretrained(
        args.repo_id,
        revision=args.revision,
        cache_dir=args.cache_dir,
        token=args.hf_token,
        device=device,
    )
    t1 = time.perf_counter()
    if not args.quiet:
        print(f"[Load] {t1 - t0:.2f}s")

    cfg = tts.cfg

    # Pre-extracted codec tokens take the place of reference audio when given.
    ref_tokens_tq = None
    if args.ref_tokens is not None:
        import numpy as np

        arr = np.load(args.ref_tokens)
        ref_tokens_tq = torch.from_numpy(arr).long()

    text_ids = tts.encode_text(args.text)
    ref = tts.encode_reference(
        ref_audio_path=args.ref_audio,
        ref_tokens_tq=ref_tokens_tq,
        ref_seconds=args.ref_seconds,
    )

    prep = tts.model.prepare_conditioning(
        text_ids,
        ref,
        max_frames=args.max_frames,
        device=tts.device,
        # CLI value wins over the config default when provided.
        style_strength=float(
            args.style_strength
            if args.style_strength is not None
            else cfg.style_strength
        ),
    )

    t_start = time.perf_counter()

    # First-codebook (RVQ level 1) token per generated frame.
    hist_A: list[int] = []
    pbar = tqdm(
        total=args.max_frames,
        desc="AR sampling",
        unit="frame",
        disable=args.quiet,
    )

    # use_stop_head=None defers to the model config; False force-disables it.
    for _t, rvq1, p_stop in tts.model.ar_stream(
        prep,
        max_frames=args.max_frames,
        top_p=args.top_p,
        temperature=args.temperature,
        anti_loop=(not args.no_anti_loop),
        use_prefix=(not args.no_prefix),
        prefix_sec_fixed=args.prefix_sec,
        use_stop_head=(False if args.no_stop_head else None),
        stop_patience=args.stop_patience,
        stop_threshold=args.stop_threshold,
    ):
        hist_A.append(int(rvq1))
        pbar.update(1)
        if p_stop is None:
            pbar.set_postfix(p_stop="off")
        else:
            pbar.set_postfix(p_stop=f"{float(p_stop):.2f}")

    # Snap the bar to the real frame count in case sampling stopped early.
    pbar.n = len(hist_A)
    pbar.close()

    t_after_sampling = time.perf_counter()

    T = len(hist_A)
    if T == 0:
        # Nothing was generated; still write a (zero-length) file.
        save_audio(args.out, torch.zeros(1, 0), sr=TARGET_SR)
        t_end = time.perf_counter()
        if not args.quiet:
            print(
                f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
                f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
                f"total={t_end - t_start:.2f}s"
            )
            print(f"[Done] Wrote {args.out}")
        return

    # NAR pass refines the AR level-1 tokens into the full [T, Q] token grid.
    tokens_A = torch.tensor(hist_A, device=tts.device, dtype=torch.long).unsqueeze(0)
    cond_seq = prep["cond_all"][:, :T, :]
    tokens_1xTQ = tts.model.nar_refine(cond_seq, tokens_A)
    tokens_tq = tokens_1xTQ.squeeze(0)

    wav = tts.codec.decode_full(tokens_tq)
    save_audio(args.out, wav, sr=TARGET_SR)

    t_end = time.perf_counter()
    if not args.quiet:
        print(
            f"[Timing] sampling={t_after_sampling - t_start:.2f}s, "
            f"postproc+decode+save={t_end - t_after_sampling:.2f}s, "
            f"total={t_end - t_start:.2f}s"
        )
        print(f"[Done] Wrote {args.out}")
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()
@@ -0,0 +1,3 @@
1
+ from .mimi import MimiCodec, MimiDecodeState, MimiStreamDecoder
2
+
3
+ __all__ = ["MimiCodec", "MimiStreamDecoder", "MimiDecodeState"]
sopro/codec/mimi.py ADDED
@@ -0,0 +1,181 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from ..audio import center_crop_audio, load_audio_file, resample, trim_silence_energy
9
+ from ..constants import DEFAULT_MIMI_ID, TARGET_SR
10
+
11
+ try:
12
+ from transformers import MimiConfig, MimiModel
13
+ except Exception:
14
+ MimiModel = None
15
+ MimiConfig = None
16
+
17
+
18
class MimiCodec:
    """Wrapper around the HF Transformers Mimi neural codec.

    Handles loading the pretrained model with a chosen number of RVQ
    quantizers and converting between waveforms and ``[T, Q]`` token grids.
    """

    def __init__(
        self, num_quantizers: int, device: str = "cuda", model_id: str = DEFAULT_MIMI_ID
    ):
        # MimiModel/MimiConfig are None when the installed transformers
        # version does not ship the Mimi classes (see module-level try/except).
        if MimiModel is None or MimiConfig is None:
            raise RuntimeError(
                "MimiModel missing. Install a recent transformers version."
            )

        self.device = torch.device(device)
        # num_quantizers is baked into the config so encode/decode use
        # exactly that many RVQ levels.
        cfg = MimiConfig.from_pretrained(model_id, num_quantizers=int(num_quantizers))
        self.model = (
            MimiModel.from_pretrained(model_id, config=cfg).to(self.device).eval()
        )

    @property
    def codebook_size(self) -> int:
        """Number of entries per RVQ codebook (falls back to 2048)."""
        return int(getattr(self.model.config, "codebook_size", 2048))

    @property
    def num_quantizers(self) -> int:
        """Number of RVQ levels the loaded model uses (falls back to 32)."""
        return int(getattr(self.model.config, "num_quantizers", 32))

    @torch.no_grad()
    def encode_file(
        self, wav_path: str, *, crop_seconds: Optional[float] = None
    ) -> torch.Tensor:
        """Load an audio file and encode it to Mimi tokens of shape ``[T, Q]``.

        The audio is silence-trimmed, resampled to the codec's sample rate,
        and optionally center-cropped to roughly *crop_seconds* (rounded to a
        whole number of codec frames) before encoding.
        """
        wav, sr = load_audio_file(wav_path)

        wav = trim_silence_energy(wav, sr)

        cfg = self.model.config
        sr_target = int(getattr(cfg, "sampling_rate", TARGET_SR))
        wav = resample(wav, sr, sr_target, device=str(self.device))

        if crop_seconds is not None and crop_seconds > 0:
            # Crop on frame boundaries: hop = samples per codec frame.
            fps = float(getattr(cfg, "frame_rate", 12.5))
            hop = int(round(sr_target / fps))
            win_frames = max(1, int(round(crop_seconds * fps)))
            win_samples = win_frames * hop
            wav = center_crop_audio(wav, win_samples)

        wav = wav.unsqueeze(0)  # add batch dim -> (1, C, T)
        out = self.model.encode(wav, return_dict=True)
        # audio_codes is batch-first with quantizers before time; return [T, Q].
        codes_bqt = out.audio_codes
        return codes_bqt[0].permute(1, 0).contiguous()

    @torch.no_grad()
    def decode_full(self, codes_tq: torch.Tensor) -> torch.Tensor:
        """Decode a full ``[T, Q]`` token grid into a waveform tensor.

        Returns the decoded audio with at least 2 dimensions.
        """
        # Model expects batch-first codes with quantizers before time.
        audio_codes = codes_tq.to(self.device).permute(1, 0).unsqueeze(0).contiguous()
        out = self.model.decode(audio_codes=audio_codes, return_dict=True)
        wav = out.audio_values
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        return wav
73
+
74
+
75
@dataclass
class MimiDecodeState:
    """Mutable state carried between successive MimiStreamDecoder.decode_step calls."""

    # Transformer decoder KV cache returned by the previous decode call.
    decoder_past_key_values: Optional[object] = None
    # Total number of NEW codec frames decoded so far (overlap excluded).
    frames_seen: int = 0
    # Total number of audio samples emitted so far.
    samples_emitted: int = 0
    # Last few code frames, re-fed as overlap on the next step ([T, Q]).
    tail_codes_tq: Optional[torch.Tensor] = None
81
+
82
+
83
class MimiStreamDecoder:
    """Incrementally decode Mimi tokens to audio, chunk by chunk.

    Each call re-decodes a small overlap of trailing frames from the previous
    chunk (with the matching KV-cache entries dropped) to smooth chunk
    boundaries, then emits only the newly produced samples.
    """

    def __init__(self, codec: MimiCodec):
        self.codec = codec
        # Streaming relies on the decoder KV cache being returned each call.
        self.codec.model.config.use_cache = True

    def drop_cache_tail(self, pkv: Any, n: int):
        """Return *pkv* with the last *n* positions removed from every layer.

        Needed because the overlap frames are decoded twice; their cached
        keys/values from the previous call must be discarded first.
        Supports both Cache objects (via to/from_legacy_cache) and legacy
        tuple-of-(k, v) caches; anything else is returned unchanged.
        """
        if pkv is None or n <= 0:
            return pkv

        if hasattr(pkv, "to_legacy_cache") and hasattr(type(pkv), "from_legacy_cache"):
            legacy = pkv.to_legacy_cache()

            # Sequence length is the second-to-last dim of each k/v tensor.
            trimmed = tuple(
                (
                    k[..., : max(0, k.shape[-2] - n), :].contiguous(),
                    v[..., : max(0, v.shape[-2] - n), :].contiguous(),
                )
                for (k, v) in legacy
            )
            return type(pkv).from_legacy_cache(trimmed)

        if isinstance(pkv, tuple):
            return tuple(
                (
                    k[..., : max(0, k.shape[-2] - n), :].contiguous(),
                    v[..., : max(0, v.shape[-2] - n), :].contiguous(),
                )
                for (k, v) in pkv
            )

        # Unknown cache type: leave untouched rather than corrupt it.
        return pkv

    @torch.inference_mode()
    def decode_step(
        self,
        codes_chunk_tq: torch.Tensor,
        state: Optional[MimiDecodeState] = None,
        *,
        overlap_frames: int = 2,
    ) -> Tuple[torch.Tensor, MimiDecodeState]:
        """Decode one ``[T, Q]`` chunk of codes; return (new_audio, state).

        *new_audio* has shape ``(1, samples)`` and contains only samples for
        the frames in this chunk (the overlap region is cut off). Pass the
        returned *state* back in on the next call.
        """
        if state is None:
            state = MimiDecodeState()

        cfg = self.codec.model.config
        sr = int(getattr(cfg, "sampling_rate", 24000))
        fps = float(getattr(cfg, "frame_rate", 12.5))
        hop = int(round(sr / fps))  # audio samples per codec frame

        n_new = int(codes_chunk_tq.size(0))
        if n_new == 0:
            return torch.zeros(1, 0, device=self.codec.device), state

        # Prepend up to `overlap_frames` trailing frames from the last chunk.
        tail = state.tail_codes_tq
        ov = 0
        if overlap_frames > 0 and tail is not None and tail.numel() > 0:
            ov = min(int(overlap_frames), int(tail.size(0)))
            tail = tail[-ov:]
            codes_in_tq = torch.cat([tail, codes_chunk_tq], dim=0)
        else:
            codes_in_tq = codes_chunk_tq

        # The overlap frames will be re-decoded, so drop their cache entries.
        pkv = state.decoder_past_key_values
        if ov > 0 and pkv is not None:
            pkv = self.drop_cache_tail(pkv, ov)

        # Model expects batch-first codes with quantizers before time.
        audio_codes = (
            codes_in_tq.to(self.codec.device).permute(1, 0).unsqueeze(0).contiguous()
        )

        out = self.codec.model.decode(
            audio_codes=audio_codes,
            decoder_past_key_values=pkv,
            return_dict=True,
        )

        # Flatten whatever the model returns to shape (1, samples).
        wav = out.audio_values
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.reshape(1, -1)

        # Clip any extra synthesis beyond the expected frame count.
        expected_total = int((ov + n_new) * hop)
        if wav.size(1) >= expected_total:
            wav = wav[:, :expected_total]

        # Discard the overlap samples; emit only audio for the new frames.
        ov_samp = min(int(ov * hop), int(wav.size(1)))
        wav_new = wav[:, ov_samp:]

        state.decoder_past_key_values = getattr(out, "decoder_past_key_values", None)
        state.frames_seen += n_new
        state.samples_emitted += int(wav_new.size(1))

        # Remember the trailing frames for the next call's overlap.
        if overlap_frames > 0:
            keep = min(int(overlap_frames), int(codes_in_tq.size(0)))
            state.tail_codes_tq = codes_in_tq[-keep:].detach()
        else:
            state.tail_codes_tq = None

        return wav_new, state
sopro/config.py ADDED
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass
2
+ from typing import Tuple
3
+
4
+ from sopro.constants import TARGET_SR
5
+
6
+
7
@dataclass
class SoproTTSConfig:
    """Hyperparameters for the SoproTTS model and its inference behavior."""

    # Codec / sequence geometry.
    num_codebooks: int = 32  # RVQ levels (Q) in the Mimi token grid
    codebook_size: int = 2048  # entries per codebook
    mimi_fps: float = 12.5  # codec frames per second
    max_frames: int = 400  # default generation cap, in frames
    audio_sr: int = TARGET_SR

    # Transformer backbone sizes.
    d_model: int = 384
    n_layers_text: int = 4
    n_layers_ar: int = 6
    n_layers_nar: int = 6
    dropout: float = 0.05

    pos_emb_max: int = 4096
    max_text_len: int = 2048

    nar_head_dim: int = 256

    # Stop-head early stopping (see cli --no_stop_head / --stop_* overrides).
    use_stop_head: bool = True
    stop_threshold: float = 0.8
    stop_patience: int = 5
    min_gen_frames: int = 12

    # Inclusive (start, end) codebook ranges grouping the NAR refinement
    # stages — presumably coarse-to-fine RVQ levels; confirm against model.py.
    stage_B: Tuple[int, int] = (2, 4)
    stage_C: Tuple[int, int] = (5, 8)
    stage_D: Tuple[int, int] = (9, 16)
    stage_E: Tuple[int, int] = (17, 32)

    # Autoregressive decoder convolution settings.
    ar_lookback: int = 4
    ar_kernel: int = 13
    ar_dilation_cycle: Tuple[int, ...] = (1, 2, 4, 1)

    # Attend to text every N AR layers.
    ar_text_attn_freq: int = 2

    # Reference (voice prompt) encoding.
    ref_attn_heads: int = 2
    ref_seconds_max: float = 12.0

    preprompt_sec_max: float = 4.0

    # Speaker-verification student embedding size and style conditioning gain.
    sv_student_dim: int = 192
    style_strength: float = 1.0
sopro/constants.py ADDED
@@ -0,0 +1,5 @@
1
from __future__ import annotations

# Sample rate (Hz) used throughout the package for generated audio.
TARGET_SR: int = 24000

# Default Hugging Face model id for the Mimi codec (see sopro/codec/mimi.py).
DEFAULT_MIMI_ID: str = "kyutai/mimi"
sopro/hub.py ADDED
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import struct
5
+ from typing import Any, Dict, Optional
6
+
7
+ import torch
8
+ from huggingface_hub import snapshot_download
9
+ from safetensors.torch import load_file
10
+
11
+ from sopro.config import SoproTTSConfig
12
+
13
+
14
+ def download_repo(
15
+ repo_id: str,
16
+ *,
17
+ revision: Optional[str] = None,
18
+ cache_dir: Optional[str] = None,
19
+ token: Optional[str] = None,
20
+ ) -> str:
21
+ return snapshot_download(
22
+ repo_id=repo_id,
23
+ revision=revision,
24
+ cache_dir=cache_dir,
25
+ token=token,
26
+ )
27
+
28
+
29
+ def _read_safetensors_metadata(path: str) -> Dict[str, str]:
30
+ with open(path, "rb") as f:
31
+ header_len = struct.unpack("<Q", f.read(8))[0]
32
+ header = json.loads(f.read(header_len).decode("utf-8"))
33
+ meta = header.get("__metadata__", {}) or {}
34
+ return {str(k): str(v) for k, v in meta.items()}
35
+
36
+
37
+ def load_cfg_from_safetensors(path: str) -> SoproTTSConfig:
38
+ meta = _read_safetensors_metadata(path)
39
+ if "cfg" not in meta:
40
+ raise RuntimeError(f"No 'cfg' metadata found in {path}.")
41
+
42
+ cfg_dict = json.loads(meta["cfg"])
43
+ init: Dict[str, Any] = {}
44
+ for k in SoproTTSConfig.__annotations__.keys():
45
+ if k in cfg_dict:
46
+ init[k] = cfg_dict[k]
47
+
48
+ cfg = SoproTTSConfig(**init)
49
+ return cfg
50
+
51
+
52
+ def load_state_dict_from_safetensors(path: str) -> Dict[str, torch.Tensor]:
53
+ return load_file(path)