abstractvoice-0.5.1-py3-none-any.whl → abstractvoice-0.6.1-py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- abstractvoice/__init__.py +2 -5
- abstractvoice/__main__.py +82 -3
- abstractvoice/adapters/__init__.py +12 -0
- abstractvoice/adapters/base.py +207 -0
- abstractvoice/adapters/stt_faster_whisper.py +401 -0
- abstractvoice/adapters/tts_piper.py +480 -0
- abstractvoice/aec/__init__.py +10 -0
- abstractvoice/aec/webrtc_apm.py +56 -0
- abstractvoice/artifacts.py +173 -0
- abstractvoice/audio/__init__.py +7 -0
- abstractvoice/audio/recorder.py +46 -0
- abstractvoice/audio/resample.py +25 -0
- abstractvoice/cloning/__init__.py +7 -0
- abstractvoice/cloning/engine_chroma.py +738 -0
- abstractvoice/cloning/engine_f5.py +546 -0
- abstractvoice/cloning/manager.py +349 -0
- abstractvoice/cloning/store.py +362 -0
- abstractvoice/compute/__init__.py +6 -0
- abstractvoice/compute/device.py +73 -0
- abstractvoice/config/__init__.py +2 -0
- abstractvoice/config/voice_catalog.py +19 -0
- abstractvoice/dependency_check.py +0 -1
- abstractvoice/examples/cli_repl.py +2403 -243
- abstractvoice/examples/voice_cli.py +64 -63
- abstractvoice/integrations/__init__.py +2 -0
- abstractvoice/integrations/abstractcore.py +116 -0
- abstractvoice/integrations/abstractcore_plugin.py +253 -0
- abstractvoice/prefetch.py +82 -0
- abstractvoice/recognition.py +424 -42
- abstractvoice/stop_phrase.py +103 -0
- abstractvoice/tts/__init__.py +3 -3
- abstractvoice/tts/adapter_tts_engine.py +210 -0
- abstractvoice/tts/tts_engine.py +257 -1208
- abstractvoice/vm/__init__.py +2 -0
- abstractvoice/vm/common.py +21 -0
- abstractvoice/vm/core.py +139 -0
- abstractvoice/vm/manager.py +108 -0
- abstractvoice/vm/stt_mixin.py +158 -0
- abstractvoice/vm/tts_mixin.py +550 -0
- abstractvoice/voice_manager.py +6 -1061
- abstractvoice-0.6.1.dist-info/METADATA +213 -0
- abstractvoice-0.6.1.dist-info/RECORD +52 -0
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/WHEEL +1 -1
- abstractvoice-0.6.1.dist-info/entry_points.txt +6 -0
- abstractvoice/instant_setup.py +0 -83
- abstractvoice/simple_model_manager.py +0 -539
- abstractvoice-0.5.1.dist-info/METADATA +0 -1458
- abstractvoice-0.5.1.dist-info/RECORD +0 -23
- abstractvoice-0.5.1.dist-info/entry_points.txt +0 -2
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/top_level.txt +0 -0
abstractvoice/cloning/engine_f5.py (new file)
@@ -0,0 +1,546 @@
from __future__ import annotations

import gc
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

import numpy as np
import soundfile as sf

from ..audio.resample import linear_resample_mono


def _load_as_mono_float(path: Path) -> Tuple[np.ndarray, int]:
    audio, sr = sf.read(str(path), always_2d=True, dtype="float32")
    # downmix
    mono = np.mean(audio, axis=1).astype(np.float32)
    return mono, int(sr)


def _load_as_torch_channels_first(path: Path):
    """Load audio as a float32 torch Tensor shaped (channels, frames).

    We prefer `soundfile` over `torchaudio.load()` because torchaudio's I/O backend
    can vary by version (e.g. TorchCodec requirements) and may emit noisy stderr
    logs during decode that corrupt interactive CLI output.
    """
    try:
        import torch
    except Exception as e:  # pragma: no cover - torch is required by f5_tts runtime anyway
        raise RuntimeError("torch is required for F5 cloning inference") from e

    audio, sr = sf.read(str(path), always_2d=True, dtype="float32")
    # soundfile: (frames, channels) -> torch: (channels, frames)
    arr = np.ascontiguousarray(audio.T, dtype=np.float32)
    return torch.from_numpy(arr), int(sr)

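# The three artifact fields below mirror the lookup patterns used later in this
# module: a model config (*.yaml/*.yml), a checkpoint (*.pt), and a token list
# (vocab*.txt), all searched for under the ~/.cache/abstractvoice/openf5 root.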
@dataclass(frozen=True)
class OpenF5Artifacts:
    model_cfg: Path
    ckpt_file: Path
    vocab_file: Path


class F5TTSVoiceCloningEngine:
    """In-process F5-TTS voice cloning engine (optional extra).

    Why in-process
    --------------
    The CLI approach re-loads a multi-GB model on every utterance (very slow).
    Running in-process allows:
    - one-time model/vocoder load
    - per-voice reference preprocessing cache
    - much lower latency per utterance
    """
    def __init__(
        self,
        *,
        whisper_model: str = "tiny",
        debug: bool = False,
        device: str = "auto",
        nfe_step: int = 16,
        cfg_strength: float = 2.0,
        sway_sampling_coef: float = -1.0,
        vocoder_name: str = "vocos",
        target_rms: float = 0.1,
        cross_fade_duration: float = 0.15,
    ):
        self.debug = debug
        self._whisper_model = whisper_model
        self._stt = None
        self._device_pref = device

        # Speed/quality knobs (lower nfe_step = faster, usually lower quality).
        self._nfe_step = int(nfe_step)
        self._cfg_strength = float(cfg_strength)
        self._sway_sampling_coef = float(sway_sampling_coef)
        self._vocoder_name = str(vocoder_name)
        self._target_rms = float(target_rms)
        self._cross_fade_duration = float(cross_fade_duration)
        self._quality_preset = "balanced"

        # Lazy heavy objects (loaded on first inference).
        self._f5_model = None
        self._f5_vocoder = None
        self._f5_device = None

    def unload(self) -> None:
        """Best-effort release of loaded model/vocoder to free memory."""
        self._f5_model = None
        self._f5_vocoder = None
        self._f5_device = None
        # STT adapter can also hold memory; drop references.
        self._stt = None
        try:
            gc.collect()
        except Exception:
            pass
        try:
            import torch

            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                if hasattr(torch, "mps") and torch.backends.mps.is_available():
                    torch.mps.empty_cache()
            except Exception:
                pass
        except Exception:
            pass

    def runtime_info(self) -> dict:
        """Return best-effort runtime info for debugging/perf validation."""
        info = {
            "requested_device": self._device_pref,
            "resolved_device": self._f5_device,
            "quality_preset": self._quality_preset,
        }
        try:
            m = self._f5_model
            if m is not None and hasattr(m, "parameters"):
                p = next(iter(m.parameters()), None)
                if p is not None and hasattr(p, "device"):
                    info["model_param_device"] = str(p.device)
        except Exception:
            pass
        try:
            import torch

            info["torch_version"] = getattr(torch, "__version__", "?")
            info["cuda_available"] = bool(torch.cuda.is_available())
            try:
                info["mps_available"] = bool(torch.backends.mps.is_available())
            except Exception:
                info["mps_available"] = False
        except Exception:
            pass
        return info
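    # nfe_step is the number of function evaluations the flow-matching sampler
    # runs per utterance, so "fast" (8) costs roughly half the compute of
    # "balanced" (16) and a third of "high" (24), at some quality cost.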
    def set_quality_preset(self, preset: str) -> None:
        """Set speed/quality preset.

        Presets tune diffusion steps; lower steps are faster but can reduce quality.
        """
        p = (preset or "").strip().lower()
        if p not in ("fast", "balanced", "high"):
            raise ValueError("preset must be one of: fast|balanced|high")
        self._quality_preset = p
        if p == "fast":
            self._nfe_step = 8
            self._cfg_strength = 1.8
        elif p == "balanced":
            self._nfe_step = 16
            self._cfg_strength = 2.0
        else:
            self._nfe_step = 24
            self._cfg_strength = 2.2

    def _get_stt(self):
        """Lazy-load STT to avoid surprise model downloads."""
        if self._stt is None:
            from ..adapters.stt_faster_whisper import FasterWhisperAdapter

            self._stt = FasterWhisperAdapter(
                model_size=self._whisper_model,
                device="cpu",
                compute_type="int8",
            )
        return self._stt

    def _ensure_f5_runtime(self) -> None:
        try:
            import importlib.util

            if importlib.util.find_spec("f5_tts") is None:
                raise ImportError("f5_tts not installed")
        except Exception as e:
            raise RuntimeError(
                "Voice cloning requires the optional dependency group.\n"
                "Install with:\n"
                "  pip install \"abstractvoice[cloning]\"\n"
                f"Original error: {e}"
            ) from e

    def _artifact_root(self) -> Path:
        cache_dir = Path(os.path.expanduser("~/.cache/abstractvoice/openf5"))
        cache_dir.mkdir(parents=True, exist_ok=True)
        return cache_dir

    def _resolve_openf5_artifacts_local(self) -> OpenF5Artifacts:
        """Resolve artifacts from the local cache directory without any network calls."""
        root = self._artifact_root()
        cfg = next(iter(root.rglob("*.yaml")), None) or next(iter(root.rglob("*.yml")), None)
        ckpt = next(iter(root.rglob("*.pt")), None)
        vocab = next(iter(root.rglob("vocab*.txt")), None) or next(iter(root.rglob("*.txt")), None)
        if not (cfg and ckpt and vocab):
            raise RuntimeError(
                "OpenF5 artifacts are not present locally.\n"
                "In the REPL run: /cloning_download\n"
                f"Looked under: {root}"
            )
        return OpenF5Artifacts(model_cfg=cfg, ckpt_file=ckpt, vocab_file=vocab)
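    # Intended flow (as the docstrings describe it): the REPL's /cloning_download
    # command triggers the one-time snapshot download below; all synthesis paths
    # then resolve artifacts strictly from the local cache.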
    def ensure_openf5_artifacts_downloaded(self) -> OpenF5Artifacts:
        """Explicit prefetch entry point (REPL should call this, not speak())."""
        # Keep output quiet in interactive contexts.
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        # Lazy import: keep core install light.
        try:
            from huggingface_hub import snapshot_download
        except Exception as e:
            raise RuntimeError(
                "huggingface_hub is required to download OpenF5 artifacts.\n"
                "Install with: pip install huggingface_hub"
            ) from e

        import warnings

        with warnings.catch_warnings():
            # huggingface_hub deprecated `local_dir_use_symlinks`; keep prefetch UX clean.
            warnings.filterwarnings(
                "ignore",
                category=UserWarning,
                message=r".*local_dir_use_symlinks.*deprecated.*",
            )
            snapshot_download(
                repo_id="mrfakename/OpenF5-TTS-Base",
                local_dir=str(self._artifact_root()),
            )
        return self._resolve_openf5_artifacts_local()

    def are_openf5_artifacts_available(self) -> bool:
        """Return True if artifacts are already present locally (no downloads)."""
        root = self._artifact_root()
        cfg = next(iter(root.rglob("*.yaml")), None) or next(iter(root.rglob("*.yml")), None)
        ckpt = next(iter(root.rglob("*.pt")), None)
        vocab = next(iter(root.rglob("vocab*.txt")), None) or next(iter(root.rglob("*.txt")), None)
        return bool(cfg and ckpt and vocab)
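    # "auto" defers to ..compute.best_torch_device(); an explicit preference
    # such as "cuda", "mps", or "cpu" (illustrative values) passes through as-is.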
    def _resolve_device(self) -> str:
        if self._device_pref and self._device_pref != "auto":
            return str(self._device_pref)
        from ..compute import best_torch_device

        return best_torch_device()

    def _ensure_model_loaded(self) -> None:
        """Load vocoder + model once (expensive)."""
        if self._f5_model is not None and self._f5_vocoder is not None and self._f5_device is not None:
            return

        self._ensure_f5_runtime()
        artifacts = self._resolve_openf5_artifacts_local()

        # Silence HF progress bars during internal downloads (REPL UX).
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

        # Some f5_tts utilities print; keep it quiet unless debug.
        import contextlib
        import io

        from omegaconf import OmegaConf
        from hydra.utils import get_class

        from f5_tts.infer.utils_infer import load_model, load_vocoder

        device = self._resolve_device()

        model_cfg = OmegaConf.load(str(artifacts.model_cfg))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        # load vocoder + model
        if self.debug:
            self._f5_vocoder = load_vocoder(vocoder_name=self._vocoder_name, device=device)
            self._f5_model = load_model(
                model_cls,
                model_arc,
                str(artifacts.ckpt_file),
                mel_spec_type=self._vocoder_name,
                vocab_file=str(artifacts.vocab_file),
                device=device,
            )
        else:
            buf_out = io.StringIO()
            buf_err = io.StringIO()
            with contextlib.redirect_stdout(buf_out), contextlib.redirect_stderr(buf_err):
                self._f5_vocoder = load_vocoder(vocoder_name=self._vocoder_name, device=device)
                self._f5_model = load_model(
                    model_cls,
                    model_arc,
                    str(artifacts.ckpt_file),
                    mel_spec_type=self._vocoder_name,
                    vocab_file=str(artifacts.vocab_file),
                    device=device,
                )

        self._f5_device = device
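    # Worked example (illustrative): two 10 s reference clips at 44.1 kHz are
    # each downmixed to mono, resampled to 24 kHz, concatenated (20 s), then
    # clipped to the 15 s cap before being written out as a PCM16 WAV.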
    def _prepare_reference_wav(
        self, reference_paths: Iterable[str | Path], *, target_sr: int = 24000, max_seconds: float = 15.0
    ) -> Path:
        paths = [Path(p) for p in reference_paths]
        if not paths:
            raise ValueError("reference_paths must contain at least one path")

        # Only support WAV/FLAC/OGG that soundfile can read reliably without extra system deps.
        supported = {".wav", ".flac", ".ogg"}
        for p in paths:
            if p.suffix.lower() not in supported:
                raise ValueError(
                    f"Unsupported reference audio format: {p.suffix}. "
                    f"Provide WAV/FLAC/OGG (got: {p})."
                )

        merged: List[np.ndarray] = []
        for p in paths:
            mono, sr = _load_as_mono_float(p)
            mono = linear_resample_mono(mono, sr, target_sr)
            merged.append(mono)

        audio = np.concatenate(merged) if merged else np.zeros((0,), dtype=np.float32)
        max_len = int(target_sr * max_seconds)
        if len(audio) > max_len:
            audio = audio[:max_len]

        # Write PCM16 WAV for maximum compatibility with downstream tools.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        sf.write(tmp.name, audio, target_sr, subtype="PCM_16")
        return Path(tmp.name)
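    # Blocking one-shot synthesis: returns a complete PCM16 WAV byte string
    # suitable for writing to disk or handing to any WAV-capable player.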
    def infer_to_wav_bytes(
        self,
        *,
        text: str,
        reference_paths: Iterable[str | Path],
        reference_text: Optional[str] = None,
        speed: Optional[float] = None,
    ) -> bytes:
        self._ensure_model_loaded()

        ref_wav = self._prepare_reference_wav(reference_paths)
        try:
            if not reference_text:
                # Deliberately do NOT auto-transcribe in the engine layer:
                # it can implicitly download STT weights and pollute interactive UX.
                raise RuntimeError(
                    "Missing reference_text for cloning.\n"
                    "Provide reference_text when cloning, or set it via the voice store."
                )

            import io
            import warnings

            # f5_tts's high-level helpers print a lot (progress bars, ref_text,
            # batching info); keep default UX clean unless debug is enabled.
            with warnings.catch_warnings():
                # Torchaudio emits noisy deprecation warnings; they don't help users here.
                warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
                # NOTE: redirecting sys.stdout is process-global and would affect the
                # REPL spinner thread, so we avoid redirects here. Instead, we call
                # lower-level primitives that don't print when properly configured.

                # Minimal normalization: f5_tts expects sentence-ending punctuation.
                ref_text = str(reference_text or "").strip()
                if ref_text and not (ref_text.endswith(". ") or ref_text.endswith("。") or ref_text.endswith(".")):
                    ref_text = ref_text + ". "
                elif ref_text.endswith("."):
                    ref_text = ref_text + " "

                # Avoid f5_tts preprocess_ref_audio_text() because it prints loudly.
                # We already clipped/resampled reference audio in _prepare_reference_wav().
                ref_audio_path = str(ref_wav)

                # Build gen_text batches with a simple chunker (no prints).
                gen_text = str(text)
                batches: List[str] = []
                max_chars = 160
                cur = ""
                for part in gen_text.replace("\n", " ").split(" "):
                    if not part:
                        continue
                    if len((cur + " " + part).strip()) <= max_chars:
                        cur = (cur + " " + part).strip()
                    else:
                        if cur:
                            batches.append(cur)
                        cur = part.strip()
                if cur:
                    batches.append(cur)

                from f5_tts.infer.utils_infer import infer_batch_process

                audio, sr = _load_as_torch_channels_first(Path(ref_audio_path))
                # infer_batch_process returns a generator yielding final_wave at the end.
                final_wave, final_sr, _spec = next(
                    infer_batch_process(
                        (audio, sr),
                        ref_text if ref_text else " ",  # must not be empty
                        batches or [" "],
                        self._f5_model,
                        self._f5_vocoder,
                        mel_spec_type=self._vocoder_name,
                        progress=None,
                        target_rms=self._target_rms,
                        cross_fade_duration=self._cross_fade_duration,
                        nfe_step=self._nfe_step,
                        cfg_strength=self._cfg_strength,
                        sway_sampling_coef=self._sway_sampling_coef,
                        speed=float(speed) if speed is not None else 1.0,
                        fix_duration=None,
                        device=self._f5_device,
                        streaming=False,
                    )
                )

                audio_segment = np.asarray(final_wave, dtype=np.float32)
                final_sr = int(final_sr)

                buf = io.BytesIO()
                sf.write(buf, audio_segment, final_sr, format="WAV", subtype="PCM_16")
                return buf.getvalue()
        finally:
            try:
                Path(ref_wav).unlink(missing_ok=True)  # type: ignore[arg-type]
            except Exception:
                pass
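    # Streaming variant of infer_to_wav_bytes(): the same model invocation, but
    # with streaming=True audio chunks are yielded as each text batch finishes.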
    def infer_to_audio_chunks(
        self,
        *,
        text: str,
        reference_paths: Iterable[str | Path],
        reference_text: Optional[str] = None,
        speed: Optional[float] = None,
        max_chars: int = 120,
        chunk_size: int = 2048,
    ):
        """Yield (audio_chunk, sample_rate) for progressive playback.

        Note: F5 sampling itself is not truly streaming mid-step. This yields chunks
        after each batch completes, which is still valuable for perceived latency.
        """
        self._ensure_model_loaded()

        ref_wav = self._prepare_reference_wav(reference_paths)
        try:
            if not reference_text:
                raise RuntimeError(
                    "Missing reference_text for cloning.\n"
                    "Provide reference_text when cloning, or set it via the voice store."
                )

            import warnings

            with warnings.catch_warnings():
                # Keep REPL and API logs clean.
                warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
                warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")

                ref_text = str(reference_text or "").strip()
                if ref_text and not (ref_text.endswith(". ") or ref_text.endswith("。") or ref_text.endswith(".")):
                    ref_text = ref_text + ". "
                elif ref_text.endswith("."):
                    ref_text = ref_text + " "

                # Prefer sentence boundaries to reduce audible "cuts".
                import re

                def _split_batches(s: str, limit: int) -> List[str]:
                    s = " ".join(str(s).replace("\n", " ").split()).strip()
                    if not s:
                        return []
                    sentences = re.split(r"(?<=[\.\!\?\。])\s+", s)
                    out: List[str] = []
                    cur_s = ""
                    for sent in sentences:
                        sent = sent.strip()
                        if not sent:
                            continue
                        if len(sent) > limit:
                            # Fallback: word-based chunking for very long sentences.
                            words = sent.split(" ")
                            tmp = ""
                            for w in words:
                                cand = (tmp + " " + w).strip()
                                if len(cand) <= limit:
                                    tmp = cand
                                else:
                                    if tmp:
                                        out.append(tmp)
                                    tmp = w
                            if tmp:
                                out.append(tmp)
                            continue
                        cand = (cur_s + " " + sent).strip()
                        if len(cand) <= limit:
                            cur_s = cand
                        else:
                            if cur_s:
                                out.append(cur_s)
                            cur_s = sent
                    if cur_s:
                        out.append(cur_s)
                    return out

                batches = _split_batches(text, int(max_chars)) or [" "]

                from f5_tts.infer.utils_infer import infer_batch_process

                audio, sr = _load_as_torch_channels_first(Path(ref_wav))

                for chunk, sr_out in infer_batch_process(
                    (audio, sr),
                    ref_text if ref_text else " ",
                    batches,
                    self._f5_model,
                    self._f5_vocoder,
                    mel_spec_type=self._vocoder_name,
                    progress=None,
                    target_rms=self._target_rms,
                    cross_fade_duration=self._cross_fade_duration,
                    nfe_step=self._nfe_step,
                    cfg_strength=self._cfg_strength,
                    sway_sampling_coef=self._sway_sampling_coef,
                    speed=float(speed) if speed is not None else 1.0,
                    fix_duration=None,
                    device=self._f5_device,
                    streaming=True,
                    chunk_size=int(chunk_size),
                ):
                    yield np.asarray(chunk, dtype=np.float32), int(sr_out)
        finally:
            try:
                Path(ref_wav).unlink(missing_ok=True)  # type: ignore[arg-type]
            except Exception:
                pass
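
Usage sketch (not part of the diff): a minimal driver for the new engine, assuming the optional extra is installed (pip install "abstractvoice[cloning]") and that ref.wav and its transcript are placeholders supplied by the caller.

from abstractvoice.cloning.engine_f5 import F5TTSVoiceCloningEngine

engine = F5TTSVoiceCloningEngine(device="auto")
if not engine.are_openf5_artifacts_available():
    engine.ensure_openf5_artifacts_downloaded()  # explicit one-time download
engine.set_quality_preset("fast")  # 8 NFE steps for lower latency

# One-shot synthesis to a WAV byte string.
wav_bytes = engine.infer_to_wav_bytes(
    text="Hello from a cloned voice.",
    reference_paths=["ref.wav"],
    reference_text="Transcript of ref.wav.",
)
with open("out.wav", "wb") as f:
    f.write(wav_bytes)

# Progressive playback: chunks arrive as each text batch finishes.
for chunk, sr in engine.infer_to_audio_chunks(
    text="Longer inputs are split at sentence boundaries before synthesis.",
    reference_paths=["ref.wav"],
    reference_text="Transcript of ref.wav.",
):
    pass  # feed (float32 ndarray, sample rate) into an audio sink here

engine.unload()  # best-effort release of model/vocoder memory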