abstractvoice-0.5.1-py3-none-any.whl → abstractvoice-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. abstractvoice/__init__.py +2 -5
  2. abstractvoice/__main__.py +82 -3
  3. abstractvoice/adapters/__init__.py +12 -0
  4. abstractvoice/adapters/base.py +207 -0
  5. abstractvoice/adapters/stt_faster_whisper.py +401 -0
  6. abstractvoice/adapters/tts_piper.py +480 -0
  7. abstractvoice/aec/__init__.py +10 -0
  8. abstractvoice/aec/webrtc_apm.py +56 -0
  9. abstractvoice/artifacts.py +173 -0
  10. abstractvoice/audio/__init__.py +7 -0
  11. abstractvoice/audio/recorder.py +46 -0
  12. abstractvoice/audio/resample.py +25 -0
  13. abstractvoice/cloning/__init__.py +7 -0
  14. abstractvoice/cloning/engine_chroma.py +738 -0
  15. abstractvoice/cloning/engine_f5.py +546 -0
  16. abstractvoice/cloning/manager.py +349 -0
  17. abstractvoice/cloning/store.py +362 -0
  18. abstractvoice/compute/__init__.py +6 -0
  19. abstractvoice/compute/device.py +73 -0
  20. abstractvoice/config/__init__.py +2 -0
  21. abstractvoice/config/voice_catalog.py +19 -0
  22. abstractvoice/dependency_check.py +0 -1
  23. abstractvoice/examples/cli_repl.py +2403 -243
  24. abstractvoice/examples/voice_cli.py +64 -63
  25. abstractvoice/integrations/__init__.py +2 -0
  26. abstractvoice/integrations/abstractcore.py +116 -0
  27. abstractvoice/integrations/abstractcore_plugin.py +253 -0
  28. abstractvoice/prefetch.py +82 -0
  29. abstractvoice/recognition.py +424 -42
  30. abstractvoice/stop_phrase.py +103 -0
  31. abstractvoice/tts/__init__.py +3 -3
  32. abstractvoice/tts/adapter_tts_engine.py +210 -0
  33. abstractvoice/tts/tts_engine.py +257 -1208
  34. abstractvoice/vm/__init__.py +2 -0
  35. abstractvoice/vm/common.py +21 -0
  36. abstractvoice/vm/core.py +139 -0
  37. abstractvoice/vm/manager.py +108 -0
  38. abstractvoice/vm/stt_mixin.py +158 -0
  39. abstractvoice/vm/tts_mixin.py +550 -0
  40. abstractvoice/voice_manager.py +6 -1061
  41. abstractvoice-0.6.1.dist-info/METADATA +213 -0
  42. abstractvoice-0.6.1.dist-info/RECORD +52 -0
  43. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/WHEEL +1 -1
  44. abstractvoice-0.6.1.dist-info/entry_points.txt +6 -0
  45. abstractvoice/instant_setup.py +0 -83
  46. abstractvoice/simple_model_manager.py +0 -539
  47. abstractvoice-0.5.1.dist-info/METADATA +0 -1458
  48. abstractvoice-0.5.1.dist-info/RECORD +0 -23
  49. abstractvoice-0.5.1.dist-info/entry_points.txt +0 -2
  50. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/licenses/LICENSE +0 -0
  51. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/top_level.txt +0 -0
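The single hunk below is the largest new module in this release: abstractvoice/cloning/engine_f5.py, the in-process F5-TTS voice cloning engine. As orientation before the diff, here is a minimal usage sketch of the public surface visible in it — a sketch only, assuming the optional extra is installed (pip install "abstractvoice[cloning]") and using placeholder file paths and transcripts:

    from abstractvoice.cloning.engine_f5 import F5TTSVoiceCloningEngine

    engine = F5TTSVoiceCloningEngine(device="auto")

    # Explicit one-time artifact download; synthesis paths never download.
    if not engine.are_openf5_artifacts_available():
        engine.ensure_openf5_artifacts_downloaded()

    engine.set_quality_preset("fast")  # 8 diffusion steps; "balanced"=16, "high"=24

    wav_bytes = engine.infer_to_wav_bytes(
        text="Hello from a cloned voice.",
        reference_paths=["reference.wav"],              # WAV/FLAC/OGG; merged, clipped to 15 s
        reference_text="Transcript of reference.wav.",  # required; the engine won't auto-transcribe
    )
    with open("cloned.wav", "wb") as f:
        f.write(wav_bytes)

    print(engine.runtime_info())  # resolved device, torch/CUDA/MPS availability
    engine.unload()               # best-effort release of model/vocoder memory

Note that reference_text is mandatory: per the comments in the diff, the engine deliberately never transcribes or downloads models inside synthesis paths.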
abstractvoice/cloning/engine_f5.py (new file)
@@ -0,0 +1,546 @@
+ from __future__ import annotations
+
+ import gc
+ import os
+ import tempfile
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable, List, Optional, Tuple
+
+ import numpy as np
+ import soundfile as sf
+
+ from ..audio.resample import linear_resample_mono
+
+
+ def _load_as_mono_float(path: Path) -> Tuple[np.ndarray, int]:
+     audio, sr = sf.read(str(path), always_2d=True, dtype="float32")
+     # downmix
+     mono = np.mean(audio, axis=1).astype(np.float32)
+     return mono, int(sr)
+
+
+ def _load_as_torch_channels_first(path: Path):
+     """Load audio as a float32 torch Tensor shaped (channels, frames).
+
+     We prefer `soundfile` over `torchaudio.load()` because torchaudio's I/O backend
+     can vary by version (e.g. TorchCodec requirements) and may emit noisy stderr
+     logs during decode that corrupt interactive CLI output.
+     """
+     try:
+         import torch
+     except Exception as e:  # pragma: no cover - torch is required by f5_tts runtime anyway
+         raise RuntimeError("torch is required for F5 cloning inference") from e
+
+     audio, sr = sf.read(str(path), always_2d=True, dtype="float32")
+     # soundfile: (frames, channels) -> torch: (channels, frames)
+     arr = np.ascontiguousarray(audio.T, dtype=np.float32)
+     return torch.from_numpy(arr), int(sr)
+
+
+ @dataclass(frozen=True)
+ class OpenF5Artifacts:
+     model_cfg: Path
+     ckpt_file: Path
+     vocab_file: Path
+
+
+ class F5TTSVoiceCloningEngine:
+     """In-process F5-TTS voice cloning engine (optional extra).
+
+     Why in-process
+     --------------
+     The CLI approach re-loads a multi-GB model on every utterance (very slow).
+     Running in-process allows:
+     - one-time model/vocoder load
+     - per-voice reference preprocessing cache
+     - much lower latency per utterance
+     """
+
+     def __init__(
+         self,
+         *,
+         whisper_model: str = "tiny",
+         debug: bool = False,
+         device: str = "auto",
+         nfe_step: int = 16,
+         cfg_strength: float = 2.0,
+         sway_sampling_coef: float = -1.0,
+         vocoder_name: str = "vocos",
+         target_rms: float = 0.1,
+         cross_fade_duration: float = 0.15,
+     ):
+         self.debug = debug
+         self._whisper_model = whisper_model
+         self._stt = None
+         self._device_pref = device
+
+         # Speed/quality knobs (lower nfe_step = faster, usually lower quality).
+         self._nfe_step = int(nfe_step)
+         self._cfg_strength = float(cfg_strength)
+         self._sway_sampling_coef = float(sway_sampling_coef)
+         self._vocoder_name = str(vocoder_name)
+         self._target_rms = float(target_rms)
+         self._cross_fade_duration = float(cross_fade_duration)
+         self._quality_preset = "balanced"
+
+         # Lazy heavy objects (loaded on first inference).
+         self._f5_model = None
+         self._f5_vocoder = None
+         self._f5_device = None
+
+     def unload(self) -> None:
+         """Best-effort release of loaded model/vocoder to free memory."""
+         self._f5_model = None
+         self._f5_vocoder = None
+         self._f5_device = None
+         # STT adapter can also hold memory; drop references.
+         self._stt = None
+         try:
+             gc.collect()
+         except Exception:
+             pass
+         try:
+             import torch
+
+             try:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+             except Exception:
+                 pass
+             try:
+                 if hasattr(torch, "mps") and torch.backends.mps.is_available():
+                     torch.mps.empty_cache()
+             except Exception:
+                 pass
+         except Exception:
+             pass
+
+     def runtime_info(self) -> dict:
+         """Return best-effort runtime info for debugging/perf validation."""
+         info = {
+             "requested_device": self._device_pref,
+             "resolved_device": self._f5_device,
+             "quality_preset": self._quality_preset,
+         }
+         try:
+             m = self._f5_model
+             if m is not None and hasattr(m, "parameters"):
+                 p = next(iter(m.parameters()), None)
+                 if p is not None and hasattr(p, "device"):
+                     info["model_param_device"] = str(p.device)
+         except Exception:
+             pass
+         try:
+             import torch
+
+             info["torch_version"] = getattr(torch, "__version__", "?")
+             info["cuda_available"] = bool(torch.cuda.is_available())
+             try:
+                 info["mps_available"] = bool(torch.backends.mps.is_available())
+             except Exception:
+                 info["mps_available"] = False
+         except Exception:
+             pass
+         return info
+
+     def set_quality_preset(self, preset: str) -> None:
+         """Set speed/quality preset.
+
+         Presets tune diffusion steps; lower steps are faster but can reduce quality.
+         """
+         p = (preset or "").strip().lower()
+         if p not in ("fast", "balanced", "high"):
+             raise ValueError("preset must be one of: fast|balanced|high")
+         self._quality_preset = p
+         if p == "fast":
+             self._nfe_step = 8
+             self._cfg_strength = 1.8
+         elif p == "balanced":
+             self._nfe_step = 16
+             self._cfg_strength = 2.0
+         else:
+             self._nfe_step = 24
+             self._cfg_strength = 2.2
+
+     def _get_stt(self):
+         """Lazy-load STT to avoid surprise model downloads."""
+         if self._stt is None:
+             from ..adapters.stt_faster_whisper import FasterWhisperAdapter
+
+             self._stt = FasterWhisperAdapter(
+                 model_size=self._whisper_model,
+                 device="cpu",
+                 compute_type="int8",
+             )
+         return self._stt
+
+     def _ensure_f5_runtime(self) -> None:
+         try:
+             import importlib.util
+
+             if importlib.util.find_spec("f5_tts") is None:
+                 raise ImportError("f5_tts not installed")
+         except Exception as e:
+             raise RuntimeError(
+                 "Voice cloning requires the optional dependency group.\n"
+                 "Install with:\n"
+                 " pip install \"abstractvoice[cloning]\"\n"
+                 f"Original error: {e}"
+             ) from e
+
+     def _artifact_root(self) -> Path:
+         cache_dir = Path(os.path.expanduser("~/.cache/abstractvoice/openf5"))
+         cache_dir.mkdir(parents=True, exist_ok=True)
+         return cache_dir
+
+     def _resolve_openf5_artifacts_local(self) -> OpenF5Artifacts:
+         """Resolve artifacts from the local cache directory without any network calls."""
+         root = self._artifact_root()
+         cfg = next(iter(root.rglob("*.yaml")), None) or next(iter(root.rglob("*.yml")), None)
+         ckpt = next(iter(root.rglob("*.pt")), None)
+         vocab = next(iter(root.rglob("vocab*.txt")), None) or next(iter(root.rglob("*.txt")), None)
+         if not (cfg and ckpt and vocab):
+             raise RuntimeError(
+                 "OpenF5 artifacts are not present locally.\n"
+                 "In the REPL run: /cloning_download\n"
+                 f"Looked under: {root}"
+             )
+         return OpenF5Artifacts(model_cfg=cfg, ckpt_file=ckpt, vocab_file=vocab)
+
+     def ensure_openf5_artifacts_downloaded(self) -> OpenF5Artifacts:
+         """Explicit prefetch entry point (REPL should call this, not speak())."""
+         # Keep output quiet in interactive contexts.
+         os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+         # Lazy import: keep core install light.
+         try:
+             from huggingface_hub import snapshot_download
+         except Exception as e:
+             raise RuntimeError(
+                 "huggingface_hub is required to download OpenF5 artifacts.\n"
+                 "Install with: pip install huggingface_hub"
+             ) from e
+
+         import warnings
+
+         with warnings.catch_warnings():
+             # huggingface_hub deprecated `local_dir_use_symlinks`; keep prefetch UX clean.
+             warnings.filterwarnings(
+                 "ignore",
+                 category=UserWarning,
+                 message=r".*local_dir_use_symlinks.*deprecated.*",
+             )
+             snapshot_download(
+                 repo_id="mrfakename/OpenF5-TTS-Base",
+                 local_dir=str(self._artifact_root()),
+             )
+         return self._resolve_openf5_artifacts_local()
+
+     def are_openf5_artifacts_available(self) -> bool:
+         """Return True if artifacts are already present locally (no downloads)."""
+         root = self._artifact_root()
+         cfg = next(iter(root.rglob("*.yaml")), None) or next(iter(root.rglob("*.yml")), None)
+         ckpt = next(iter(root.rglob("*.pt")), None)
+         vocab = next(iter(root.rglob("vocab*.txt")), None) or next(iter(root.rglob("*.txt")), None)
+         return bool(cfg and ckpt and vocab)
+
+     def _resolve_device(self) -> str:
+         if self._device_pref and self._device_pref != "auto":
+             return str(self._device_pref)
+         from ..compute import best_torch_device
+
+         return best_torch_device()
+
+     def _ensure_model_loaded(self) -> None:
+         """Load vocoder + model once (expensive)."""
+         if self._f5_model is not None and self._f5_vocoder is not None and self._f5_device is not None:
+             return
+
+         self._ensure_f5_runtime()
+         artifacts = self._resolve_openf5_artifacts_local()
+
+         # Silence HF progress bars during internal downloads (REPL UX).
+         os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+
+         # Some f5_tts utilities print; keep it quiet unless debug.
+         import contextlib
+         import io
+
+         from omegaconf import OmegaConf
+         from hydra.utils import get_class
+
+         from f5_tts.infer.utils_infer import load_model, load_vocoder
+
+         device = self._resolve_device()
+
+         model_cfg = OmegaConf.load(str(artifacts.model_cfg))
+         model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
+         model_arc = model_cfg.model.arch
+
+         # load vocoder + model
+         if self.debug:
+             self._f5_vocoder = load_vocoder(vocoder_name=self._vocoder_name, device=device)
+             self._f5_model = load_model(
+                 model_cls,
+                 model_arc,
+                 str(artifacts.ckpt_file),
+                 mel_spec_type=self._vocoder_name,
+                 vocab_file=str(artifacts.vocab_file),
+                 device=device,
+             )
+         else:
+             buf_out = io.StringIO()
+             buf_err = io.StringIO()
+             with contextlib.redirect_stdout(buf_out), contextlib.redirect_stderr(buf_err):
+                 self._f5_vocoder = load_vocoder(vocoder_name=self._vocoder_name, device=device)
+                 self._f5_model = load_model(
+                     model_cls,
+                     model_arc,
+                     str(artifacts.ckpt_file),
+                     mel_spec_type=self._vocoder_name,
+                     vocab_file=str(artifacts.vocab_file),
+                     device=device,
+                 )
+
+         self._f5_device = device
+
+     def _prepare_reference_wav(
+         self, reference_paths: Iterable[str | Path], *, target_sr: int = 24000, max_seconds: float = 15.0
+     ) -> Path:
+         paths = [Path(p) for p in reference_paths]
+         if not paths:
+             raise ValueError("reference_paths must contain at least one path")
+
+         # Only support WAV/FLAC/OGG that soundfile can read reliably without extra system deps.
+         supported = {".wav", ".flac", ".ogg"}
+         for p in paths:
+             if p.suffix.lower() not in supported:
+                 raise ValueError(
+                     f"Unsupported reference audio format: {p.suffix}. "
+                     f"Provide WAV/FLAC/OGG (got: {p})."
+                 )
+
+         merged: List[np.ndarray] = []
+         for p in paths:
+             mono, sr = _load_as_mono_float(p)
+             mono = linear_resample_mono(mono, sr, target_sr)
+             merged.append(mono)
+
+         audio = np.concatenate(merged) if merged else np.zeros((0,), dtype=np.float32)
+         max_len = int(target_sr * max_seconds)
+         if len(audio) > max_len:
+             audio = audio[:max_len]
+
+         # Write PCM16 WAV for maximum compatibility with downstream tools.
+         tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+         tmp.close()
+         sf.write(tmp.name, audio, target_sr, subtype="PCM_16")
+         return Path(tmp.name)
+
+     def infer_to_wav_bytes(
+         self,
+         *,
+         text: str,
+         reference_paths: Iterable[str | Path],
+         reference_text: Optional[str] = None,
+         speed: Optional[float] = None,
+     ) -> bytes:
+         self._ensure_model_loaded()
+
+         ref_wav = self._prepare_reference_wav(reference_paths)
+         try:
+             if not reference_text:
+                 # Deliberately do NOT auto-transcribe in the engine layer:
+                 # it can implicitly download STT weights and pollute interactive UX.
+                 raise RuntimeError(
+                     "Missing reference_text for cloning.\n"
+                     "Provide reference_text when cloning, or set it via the voice store."
+                 )
+
+             import io
+             import warnings
+
+             with warnings.catch_warnings():
+                 # Torchaudio emits noisy deprecation warnings; they don't help users here.
+                 warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
+                 # NOTE: f5_tts prints a lot (progress bars, ref_text, batching info), but
+                 # redirecting sys.stdout is process-global and would break the REPL spinner
+                 # thread, so we avoid redirects here. Instead, we call lower-level
+                 # primitives that don't print when properly configured.
+
+                 # Minimal normalization: f5_tts expects sentence-ending punctuation.
+                 ref_text = str(reference_text or "").strip()
+                 if ref_text.endswith("."):
+                     ref_text = ref_text + " "
+                 elif ref_text and not ref_text.endswith("。"):
+                     ref_text = ref_text + ". "
+
+                 # Avoid f5_tts preprocess_ref_audio_text() because it prints loudly.
+                 # We already clipped/resampled reference audio in _prepare_reference_wav().
+                 ref_audio_path = str(ref_wav)
+
+                 # Build gen_text batches with a simple chunker (no prints).
+                 gen_text = str(text)
+                 batches: List[str] = []
+                 max_chars = 160
+                 cur = ""
+                 for part in gen_text.replace("\n", " ").split(" "):
+                     if not part:
+                         continue
+                     if len((cur + " " + part).strip()) <= max_chars:
+                         cur = (cur + " " + part).strip()
+                     else:
+                         if cur:
+                             batches.append(cur)
+                         cur = part.strip()
+                 if cur:
+                     batches.append(cur)
+
+                 from f5_tts.infer.utils_infer import infer_batch_process
+
+                 audio, sr = _load_as_torch_channels_first(Path(ref_audio_path))
+                 # infer_batch_process is a generator; with streaming=False it yields a
+                 # single (final_wave, final_sr, spectrogram) tuple at the end.
+                 final_wave, final_sr, _spec = next(
+                     infer_batch_process(
+                         (audio, sr),
+                         ref_text if ref_text else " ",  # must not be empty
+                         batches or [" "],
+                         self._f5_model,
+                         self._f5_vocoder,
+                         mel_spec_type=self._vocoder_name,
+                         progress=None,
+                         target_rms=self._target_rms,
+                         cross_fade_duration=self._cross_fade_duration,
+                         nfe_step=self._nfe_step,
+                         cfg_strength=self._cfg_strength,
+                         sway_sampling_coef=self._sway_sampling_coef,
+                         speed=float(speed) if speed is not None else 1.0,
+                         fix_duration=None,
+                         device=self._f5_device,
+                         streaming=False,
+                     )
+                 )
+
+                 audio_segment = np.asarray(final_wave, dtype=np.float32)
+
+                 buf = io.BytesIO()
+                 sf.write(buf, audio_segment, int(final_sr), format="WAV", subtype="PCM_16")
+                 return buf.getvalue()
+         finally:
+             try:
+                 Path(ref_wav).unlink(missing_ok=True)  # type: ignore[arg-type]
+             except Exception:
+                 pass
+
+     def infer_to_audio_chunks(
+         self,
+         *,
+         text: str,
+         reference_paths: Iterable[str | Path],
+         reference_text: Optional[str] = None,
+         speed: Optional[float] = None,
+         max_chars: int = 120,
+         chunk_size: int = 2048,
+     ):
+         """Yield (audio_chunk, sample_rate) pairs for progressive playback.
+
+         Note: F5 sampling itself is not truly streaming mid-step. This yields chunks
+         after each batch completes, which is still valuable for perceived latency.
+         """
+         self._ensure_model_loaded()
+
+         ref_wav = self._prepare_reference_wav(reference_paths)
+         try:
+             if not reference_text:
+                 raise RuntimeError(
+                     "Missing reference_text for cloning.\n"
+                     "Provide reference_text when cloning, or set it via the voice store."
+                 )
+
+             import warnings
+
+             with warnings.catch_warnings():
+                 # Keep REPL and API logs clean.
+                 warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
+                 warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")
+
+                 # Minimal normalization: f5_tts expects sentence-ending punctuation.
+                 ref_text = str(reference_text or "").strip()
+                 if ref_text.endswith("."):
+                     ref_text = ref_text + " "
+                 elif ref_text and not ref_text.endswith("。"):
+                     ref_text = ref_text + ". "
+
+                 # Prefer sentence boundaries to reduce audible "cuts".
+                 import re
+
+                 def _split_batches(s: str, limit: int) -> List[str]:
+                     s = " ".join(str(s).replace("\n", " ").split()).strip()
+                     if not s:
+                         return []
+                     sentences = re.split(r"(?<=[\.\!\?\。])\s+", s)
+                     out: List[str] = []
+                     cur_s = ""
+                     for sent in sentences:
+                         sent = sent.strip()
+                         if not sent:
+                             continue
+                         if len(sent) > limit:
+                             # Fallback: word-based chunking for very long sentences.
+                             words = sent.split(" ")
+                             tmp = ""
+                             for w in words:
+                                 cand = (tmp + " " + w).strip()
+                                 if len(cand) <= limit:
+                                     tmp = cand
+                                 else:
+                                     if tmp:
+                                         out.append(tmp)
+                                     tmp = w
+                             if tmp:
+                                 out.append(tmp)
+                             continue
+                         cand = (cur_s + " " + sent).strip()
+                         if len(cand) <= limit:
+                             cur_s = cand
+                         else:
+                             if cur_s:
+                                 out.append(cur_s)
+                             cur_s = sent
+                     if cur_s:
+                         out.append(cur_s)
+                     return out
+
+                 batches = _split_batches(text, int(max_chars)) or [" "]
+
+                 from f5_tts.infer.utils_infer import infer_batch_process
+
+                 audio, sr = _load_as_torch_channels_first(Path(ref_wav))
+
+                 for chunk, sr_out in infer_batch_process(
+                     (audio, sr),
+                     ref_text if ref_text else " ",
+                     batches,
+                     self._f5_model,
+                     self._f5_vocoder,
+                     mel_spec_type=self._vocoder_name,
+                     progress=None,
+                     target_rms=self._target_rms,
+                     cross_fade_duration=self._cross_fade_duration,
+                     nfe_step=self._nfe_step,
+                     cfg_strength=self._cfg_strength,
+                     sway_sampling_coef=self._sway_sampling_coef,
+                     speed=float(speed) if speed is not None else 1.0,
+                     fix_duration=None,
+                     device=self._f5_device,
+                     streaming=True,
+                     chunk_size=int(chunk_size),
+                 ):
+                     yield np.asarray(chunk, dtype=np.float32), int(sr_out)
+         finally:
+             try:
+                 Path(ref_wav).unlink(missing_ok=True)  # type: ignore[arg-type]
+             except Exception:
+                 pass
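
For progressive playback, infer_to_audio_chunks() yields (float32 mono chunk, sample rate) pairs as each sentence batch finishes rendering. A sketch of consuming that generator, assuming artifacts were already prefetched as above; the sounddevice playback is illustrative only and is not a dependency introduced by this diff:

    import sounddevice as sd  # illustrative playback backend, not part of abstractvoice

    from abstractvoice.cloning.engine_f5 import F5TTSVoiceCloningEngine

    engine = F5TTSVoiceCloningEngine()
    stream = None
    try:
        for chunk, sr in engine.infer_to_audio_chunks(
            text="A longer passage. It is split at sentence boundaries. Then streamed.",
            reference_paths=["reference.wav"],
            reference_text="Transcript of reference.wav.",
            max_chars=120,
        ):
            if stream is None:
                # Open the output stream lazily, once the first chunk reveals the rate.
                stream = sd.OutputStream(samplerate=sr, channels=1, dtype="float32")
                stream.start()
            stream.write(chunk.reshape(-1, 1))  # sounddevice expects (frames, channels)
    finally:
        if stream is not None:
            stream.stop()
            stream.close()

Because F5 sampling is not streaming mid-step, the first chunk still arrives only after the first batch completes; a smaller max_chars lowers time-to-first-audio at the cost of more cross-faded joins.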