abstractvoice 0.5.1__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. abstractvoice/__init__.py +2 -5
  2. abstractvoice/__main__.py +82 -3
  3. abstractvoice/adapters/__init__.py +12 -0
  4. abstractvoice/adapters/base.py +207 -0
  5. abstractvoice/adapters/stt_faster_whisper.py +401 -0
  6. abstractvoice/adapters/tts_piper.py +480 -0
  7. abstractvoice/aec/__init__.py +10 -0
  8. abstractvoice/aec/webrtc_apm.py +56 -0
  9. abstractvoice/artifacts.py +173 -0
  10. abstractvoice/audio/__init__.py +7 -0
  11. abstractvoice/audio/recorder.py +46 -0
  12. abstractvoice/audio/resample.py +25 -0
  13. abstractvoice/cloning/__init__.py +7 -0
  14. abstractvoice/cloning/engine_chroma.py +738 -0
  15. abstractvoice/cloning/engine_f5.py +546 -0
  16. abstractvoice/cloning/manager.py +349 -0
  17. abstractvoice/cloning/store.py +362 -0
  18. abstractvoice/compute/__init__.py +6 -0
  19. abstractvoice/compute/device.py +73 -0
  20. abstractvoice/config/__init__.py +2 -0
  21. abstractvoice/config/voice_catalog.py +19 -0
  22. abstractvoice/dependency_check.py +0 -1
  23. abstractvoice/examples/cli_repl.py +2403 -243
  24. abstractvoice/examples/voice_cli.py +64 -63
  25. abstractvoice/integrations/__init__.py +2 -0
  26. abstractvoice/integrations/abstractcore.py +116 -0
  27. abstractvoice/integrations/abstractcore_plugin.py +253 -0
  28. abstractvoice/prefetch.py +82 -0
  29. abstractvoice/recognition.py +424 -42
  30. abstractvoice/stop_phrase.py +103 -0
  31. abstractvoice/tts/__init__.py +3 -3
  32. abstractvoice/tts/adapter_tts_engine.py +210 -0
  33. abstractvoice/tts/tts_engine.py +257 -1208
  34. abstractvoice/vm/__init__.py +2 -0
  35. abstractvoice/vm/common.py +21 -0
  36. abstractvoice/vm/core.py +139 -0
  37. abstractvoice/vm/manager.py +108 -0
  38. abstractvoice/vm/stt_mixin.py +158 -0
  39. abstractvoice/vm/tts_mixin.py +550 -0
  40. abstractvoice/voice_manager.py +6 -1061
  41. abstractvoice-0.6.1.dist-info/METADATA +213 -0
  42. abstractvoice-0.6.1.dist-info/RECORD +52 -0
  43. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/WHEEL +1 -1
  44. abstractvoice-0.6.1.dist-info/entry_points.txt +6 -0
  45. abstractvoice/instant_setup.py +0 -83
  46. abstractvoice/simple_model_manager.py +0 -539
  47. abstractvoice-0.5.1.dist-info/METADATA +0 -1458
  48. abstractvoice-0.5.1.dist-info/RECORD +0 -23
  49. abstractvoice-0.5.1.dist-info/entry_points.txt +0 -2
  50. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/licenses/LICENSE +0 -0
  51. {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/top_level.txt +0 -0
abstractvoice/cloning/engine_chroma.py (new file)
@@ -0,0 +1,738 @@
+ from __future__ import annotations
+
+ import contextlib
+ import gc
+ import hashlib
+ import io
+ import os
+ import warnings
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Optional, Tuple
+
+ import numpy as np
+ import soundfile as sf
+
+ from ..tts.tts_engine import _SilenceStderrFD
+
+
+ @dataclass(frozen=True)
+ class ChromaArtifacts:
+     root: Path
+     model_id: str
+     revision: Optional[str]
+
+
+ class ChromaVoiceCloningEngine:
+     """In-process Chroma voice cloning engine (optional; GPU-heavy).
+
+     Design principles
+     -----------------
+     - Never download weights implicitly from speak()/speak_to_bytes().
+     - Provide explicit prefetch hooks.
+     - Keep imports lightweight until the engine is actually used.
+     """
+
+     DEFAULT_MODEL_ID = "FlashLabs/Chroma-4B"
+     # Pin a known revision by default (safer than floating remote code).
+     DEFAULT_REVISION = "864b4aea0c1359f91af62f1367df64657dc5e90f"
+
+     def __init__(
+         self,
+         *,
+         debug: bool = False,
+         device: str = "auto",
+         model_id: str = DEFAULT_MODEL_ID,
+         revision: str | None = DEFAULT_REVISION,
+         temperature: float = 0.7,
+         top_k: int = 50,
+         do_sample: bool = True,
+         max_new_tokens_per_chunk: int = 512,
+     ):
+         self.debug = bool(debug)
+         self._device_pref = str(device or "auto")
+         self._model_id = str(model_id or self.DEFAULT_MODEL_ID)
+         self._revision = revision if revision else None
+
+         self._temperature = float(temperature)
+         self._top_k = int(top_k)
+         self._do_sample = bool(do_sample)
+         self._max_new_tokens_per_chunk = int(max_new_tokens_per_chunk)
+
+         self._model = None
+         self._processor = None
+         self._resolved_device = None
+
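Constructing the engine stays cheap: nothing above touches torch, transformers, or the network, and the heavy work is deferred to the first call that needs the model. A minimal construction sketch, assuming the module path abstractvoice.cloning.engine_chroma (inferred from this diff, not documented here):

    from abstractvoice.cloning.engine_chroma import ChromaVoiceCloningEngine

    # No weights are loaded or downloaded at this point; the first inference
    # call (or an explicit prefetch) does the heavy lifting.
    engine = ChromaVoiceCloningEngine(debug=False, device="auto")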
+     def unload(self) -> None:
+         """Best-effort release of loaded model/processor to free memory."""
+         self._model = None
+         self._processor = None
+         self._resolved_device = None
+         try:
+             gc.collect()
+         except Exception:
+             pass
+         try:
+             import torch
+
+             # Best-effort GPU/MPS cache release.
+             try:
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+             except Exception:
+                 pass
+             try:
+                 if hasattr(torch, "mps") and torch.backends.mps.is_available():
+                     torch.mps.empty_cache()
+             except Exception:
+                 pass
+         except Exception:
+             pass
+
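unload() targets long-lived processes (such as the REPL) that switch engines or want VRAM back without exiting. A sketch of the intended call pattern, continuing the construction example above:

    engine.unload()  # drop references, then best-effort torch cache release
    # Reuse after unload is fine: the next inference call lazily reloads the model.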
+     def runtime_info(self) -> Dict[str, object]:
+         info: Dict[str, object] = {
+             "model_id": self._model_id,
+             "revision": self._revision,
+             "requested_device": self._device_pref,
+             "resolved_device": self._resolved_device,
+             "temperature": self._temperature,
+             "top_k": self._top_k,
+             "do_sample": self._do_sample,
+             "max_new_tokens_per_chunk": self._max_new_tokens_per_chunk,
+         }
+         try:
+             import torch
+
+             info["torch_version"] = getattr(torch, "__version__", "?")
+             info["cuda_available"] = bool(torch.cuda.is_available())
+             try:
+                 info["mps_available"] = bool(torch.backends.mps.is_available())
+             except Exception:
+                 info["mps_available"] = False
+         except Exception:
+             pass
+         return info
+
+     def set_quality_preset(self, preset: str) -> None:
+         p = (preset or "").strip().lower()
+         if p not in ("fast", "balanced", "high"):
+             raise ValueError("preset must be one of: fast|balanced|high")
+         if p == "fast":
+             self._temperature = 0.6
+             self._top_k = 30
+             self._max_new_tokens_per_chunk = 384
+         elif p == "balanced":
+             self._temperature = 0.7
+             self._top_k = 50
+             self._max_new_tokens_per_chunk = 512
+         else:
+             self._temperature = 0.75
+             self._top_k = 80
+             self._max_new_tokens_per_chunk = 768
+
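The presets only adjust the three sampling knobs set in __init__; they never trigger loading or downloading. A small usage sketch:

    engine.set_quality_preset("fast")   # temperature 0.6, top_k 30, frame cap 384
    engine.set_quality_preset("high")   # temperature 0.75, top_k 80, frame cap 768
    print(engine.runtime_info()["top_k"])  # the current values are reflected in runtime_info()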
+     def _artifact_root(self) -> Path:
+         cache_dir = Path(os.path.expanduser("~/.cache/abstractvoice/chroma"))
+         cache_dir.mkdir(parents=True, exist_ok=True)
+         return cache_dir
+
+     def _resolve_chroma_artifacts_local(self) -> ChromaArtifacts:
+         root = self._artifact_root()
+         required = [
+             "config.json",
+             "processor_config.json",
+             "model.safetensors.index.json",
+             "modeling_chroma.py",
+             "processing_chroma.py",
+             "configuration_chroma.py",
+         ]
+         missing = [name for name in required if not (root / name).exists()]
+         if missing:
+             raise RuntimeError(
+                 "Chroma artifacts are not present locally.\n"
+                 "Prefetch explicitly (outside the REPL):\n"
+                 "  abstractvoice-prefetch --chroma\n"
+                 "or:\n"
+                 "  python -m abstractvoice download --chroma\n"
+                 f"Looked under: {root}\n"
+                 f"Missing: {', '.join(missing)}"
+             )
+         return ChromaArtifacts(root=root, model_id=self._model_id, revision=self._revision)
+
+     def ensure_chroma_artifacts_downloaded(self) -> ChromaArtifacts:
+         os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+         try:
+             from huggingface_hub import snapshot_download
+         except Exception as e:
+             raise RuntimeError(
+                 "huggingface_hub is required to download Chroma artifacts.\n"
+                 "Install with: pip install huggingface_hub"
+             ) from e
+
+         import warnings
+
+         root = self._artifact_root()
+         with warnings.catch_warnings():
+             warnings.filterwarnings(
+                 "ignore",
+                 category=UserWarning,
+                 message=r".*local_dir_use_symlinks.*deprecated.*",
+             )
+             snapshot_download(
+                 repo_id=self._model_id,
+                 revision=self._revision,
+                 local_dir=str(root),
+             )
+         return self._resolve_chroma_artifacts_local()
+
+     def are_chroma_artifacts_available(self) -> bool:
+         root = self._artifact_root()
+         return bool((root / "config.json").exists() and (root / "model.safetensors.index.json").exists())
+
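Putting the artifact helpers together, the intended flow is check, prefetch explicitly, then synthesize. A sketch using only the methods shown above (the engine variable comes from the earlier construction example):

    if not engine.are_chroma_artifacts_available():
        # Explicit one-time download into ~/.cache/abstractvoice/chroma, the same
        # location _resolve_chroma_artifacts_local() checks. Meant for a setup step
        # or the abstractvoice-prefetch CLI, never for speak()-style call paths.
        engine.ensure_chroma_artifacts_downloaded()

    artifacts = engine._resolve_chroma_artifacts_local()
    print(artifacts.root, artifacts.model_id, artifacts.revision)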
+     def _resolve_device(self) -> str:
+         if self._device_pref and self._device_pref != "auto":
+             return str(self._device_pref)
+         try:
+             from ..compute import best_torch_device
+
+             return best_torch_device()
+         except Exception:
+             return "cpu"
+
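best_torch_device lives in the new abstractvoice/compute package (added in this release but not shown in this hunk). As an assumption about its behavior only, a minimal version compatible with how it is used here might look like:

    import torch

    def best_torch_device() -> str:
        # Hypothetical sketch: prefer CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"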
+     def _ensure_chroma_runtime(self) -> None:
+         try:
+             import importlib.util
+
+             if importlib.util.find_spec("torch") is None:
+                 raise ImportError("torch not installed")
+             if importlib.util.find_spec("transformers") is None:
+                 raise ImportError("transformers not installed")
+         except Exception as e:
+             raise RuntimeError(
+                 "Chroma requires the optional dependency group.\n"
+                 "Install with:\n"
+                 "  pip install \"abstractvoice[chroma]\"\n"
+                 f"Original error: {e}"
+             ) from e
+
+     def _ensure_model_loaded(self) -> None:
+         if self._model is not None and self._processor is not None and self._resolved_device is not None:
+             return
+
+         self._ensure_chroma_runtime()
+         artifacts = self._resolve_chroma_artifacts_local()
+
+         # Keep interactive UX quiet by default (the REPL provides its own spinner).
+         os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+         os.environ.setdefault("TRANSFORMERS_NO_TQDM", "1")
+
+         # Keep interactive UX quiet by default (the REPL provides its own spinner).
+         try:
+             from transformers.utils import logging as hf_logging
+
+             hf_logging.disable_progress_bar()
+             if not self.debug:
+                 hf_logging.set_verbosity_error()
+         except Exception:
+             pass
+
+         import torch
+         from transformers import AutoModelForCausalLM, AutoProcessor
+
+         device = self._resolve_device()
+         self._resolved_device = device
+
+         torch_dtype = None
+         if device == "cuda":
+             torch_dtype = torch.bfloat16
+         elif device == "mps":
+             torch_dtype = torch.float16
+
+         device_map = None
+         # `device_map="auto"` is most reliable on CUDA; MPS/CPU should load normally then move.
+         if self._device_pref == "auto" and device == "cuda":
+             device_map = "auto"
+
+         # transformers 5.0.0 deprecates `torch_dtype` in favor of `dtype`.
+         with warnings.catch_warnings():
+             # Torch/Torchaudio/Transformers often emit noisy warnings that are not actionable
+             # for interactive users. Keep them hidden unless debug is enabled.
+             if not self.debug:
+                 warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
+                 warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")
+                 warnings.filterwarnings("ignore", category=UserWarning, message=r".*output_attentions.*")
+             stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
+             with stdout_ctx:
+                 with _SilenceStderrFD(enabled=not self.debug):
+                     try:
+                         model = AutoModelForCausalLM.from_pretrained(
+                             str(artifacts.root),
+                             trust_remote_code=True,
+                             device_map=device_map,
+                             dtype=torch_dtype,
+                         )
+                     except TypeError:
+                         model = AutoModelForCausalLM.from_pretrained(
+                             str(artifacts.root),
+                             trust_remote_code=True,
+                             device_map=device_map,
+                             torch_dtype=torch_dtype,
+                         )
+         model.eval()
+
+         stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
+         with stdout_ctx:
+             with _SilenceStderrFD(enabled=not self.debug):
+                 processor = AutoProcessor.from_pretrained(str(artifacts.root), trust_remote_code=True)
+
+         if device != "cuda" and device_map is None:
+             try:
+                 model.to(device)
+             except Exception:
+                 pass
+
+         # Ensure tied weights are applied (Chroma relies on tied audio embeddings).
+         try:
+             if hasattr(model, "tie_weights"):
+                 model.tie_weights()
+         except Exception:
+             pass
+         # Some releases log this key as MISSING even when it is meant to be tied.
+         # Ensure backbone audio embeddings match decoder embeddings (best-effort).
+         try:
+             if hasattr(model, "backbone") and hasattr(model, "decoder"):
+                 b = getattr(model.backbone, "audio_embedding", None)
+                 d = getattr(model.decoder, "audio_embedding", None)
+                 if b is not None and d is not None:
+                     bw = getattr(b, "embed_audio_tokens", None)
+                     dw = getattr(d, "embed_audio_tokens", None)
+                     if bw is not None and dw is not None and hasattr(bw, "weight") and hasattr(dw, "weight"):
+                         if getattr(bw.weight, "shape", None) == getattr(dw.weight, "shape", None):
+                             bw.weight.data.copy_(dw.weight.data)
+                     # Prefer a hard tie (shared module) to avoid one side staying randomly
+                     # initialized when the checkpoint omits a tied key.
+                     try:
+                         if bw is not None and dw is not None and bw is not dw:
+                             b.embed_audio_tokens = dw
+                     except Exception:
+                         pass
+         except Exception:
+             pass
+
+         self._patch_generation_compat(model)
+
+         self._model = model
+         self._processor = processor
+
+     def _estimate_max_new_tokens(self, text: str, model) -> int:
+         """Estimate a safe `max_new_tokens` (audio frames) for a text batch.
+
+         Chroma audio frames decode at roughly:
+             frames_per_second = sampling_rate / audio_frame_freq
+         For the default config this is ~12.5 fps, so `512` frames ~= 41s.
+         Without a good stopping signal, the model may generate until `max_new_tokens`,
+         so we cap it based on expected speech duration to avoid long noisy tails.
+         """
+         s = " ".join(str(text or "").split()).strip()
+         words = len(s.split()) if s else 0
+         if words <= 0:
+             # Fallback: roughly 5 chars/word in English.
+             words = max(1, int(round(len(s) / 5.0))) if s else 1
+
+         # Conservative speech-rate assumption (words/sec).
+         wps = 3.2
+         est_s = float(words) / float(wps)
+         # Small pad to avoid truncation on short utterances.
+         est_s += 0.4
+
+         sr = 24000
+         frame_freq = 1920
+         try:
+             cfg = getattr(model, "config", None)
+             if cfg is not None:
+                 frame_freq = int(getattr(cfg, "audio_frame_freq", frame_freq) or frame_freq)
+                 codec_cfg = getattr(cfg, "codec_config", None)
+                 if codec_cfg is not None:
+                     if isinstance(codec_cfg, dict):
+                         sr = int(codec_cfg.get("sampling_rate", sr) or sr)
+                     else:
+                         sr = int(getattr(codec_cfg, "sampling_rate", sr) or sr)
+         except Exception:
+             pass
+
+         fps = float(sr) / float(frame_freq) if frame_freq else 12.5
+         # Add slack so we still have room for slower speech.
+         frames = int(round(est_s * fps * 1.25))
+         frames = max(64, frames)
+         cap = int(self._max_new_tokens_per_chunk) if self._max_new_tokens_per_chunk else frames
+         return int(min(frames, cap))
+
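To make the docstring's arithmetic concrete, here is the estimate for a 12-word sentence under the default 24000 Hz / 1920 config:

    words = 12
    est_s = words / 3.2 + 0.4              # = 4.15 s of expected speech
    fps = 24000 / 1920                     # = 12.5 audio frames per second
    frames = round(est_s * fps * 1.25)     # = 65 frames after the 1.25x slack
    max_new = min(max(64, frames), 512)    # = 65, well under the default 512-frame cap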
+     def _patch_generation_compat(self, model) -> None:
+         """Patch known transformers incompatibilities for Chroma remote code.
+
+         Upstream Chroma code (pinned revision) expects a `use_model_defaults` positional arg
+         in `GenerationMixin._prepare_generation_config`. transformers 5.0.0 removed it.
+         """
+         try:
+             import inspect
+
+             from transformers.generation import GenerationMode
+             from transformers.generation.utils import GenerationMixin
+
+             base_sig = inspect.signature(GenerationMixin._prepare_generation_config)
+             if "use_model_defaults" in base_sig.parameters:
+                 return
+
+             chroma_gen_cls = None
+             for cls in model.__class__.mro():
+                 if cls.__name__ == "ChromaGenerationMixin":
+                     chroma_gen_cls = cls
+                     break
+             if chroma_gen_cls is None:
+                 return
+
+             current = getattr(chroma_gen_cls, "_prepare_generation_config", None)
+             if current is None or getattr(current, "_abstractvoice_patched", False):
+                 return
+
+             def _patched_prepare_generation_config(
+                 self, generation_config=None, use_model_defaults=None, **kwargs
+             ):
+                 depth_decoder_kwargs = {k[len("decoder_") :]: v for k, v in kwargs.items() if k.startswith("decoder_")}
+                 kwargs = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
+
+                 # transformers 5.0.0 removed the `use_model_defaults` positional arg; keep compatibility.
+                 try:
+                     generation_config_out, model_kwargs = super(chroma_gen_cls, self)._prepare_generation_config(
+                         generation_config, use_model_defaults, **kwargs
+                     )
+                 except TypeError:
+                     generation_config_out, model_kwargs = super(chroma_gen_cls, self)._prepare_generation_config(
+                         generation_config, **kwargs
+                     )
+
+                 try:
+                     self.decoder.generation_config.update(**depth_decoder_kwargs)
+                 except Exception:
+                     pass
+
+                 try:
+                     decoder_min_new_tokens = getattr(self.decoder.generation_config, "min_new_tokens") or (
+                         self.decoder.config.audio_num_codebooks - 1
+                     )
+                     decoder_max_new_tokens = getattr(self.decoder.generation_config, "max_new_tokens") or (
+                         self.decoder.config.audio_num_codebooks - 1
+                     )
+
+                     if {decoder_min_new_tokens, decoder_max_new_tokens} != {self.decoder.config.audio_num_codebooks - 1}:
+                         raise ValueError(
+                             "decoder_generation_config's min_new_tokens "
+                             f"({decoder_min_new_tokens}) and max_new_tokens ({decoder_max_new_tokens}) "
+                             f"must be equal to self.config.num_codebooks - 1 ({self.decoder.config.audio_num_codebooks - 1})"
+                         )
+                     elif getattr(self.decoder.generation_config, "return_dict_in_generate", False):
+                         self.decoder.generation_config.return_dict_in_generate = False
+                 except Exception:
+                     pass
+
+                 # Monkey patch the get_generation_mode method to support Chroma model.
+                 try:
+                     original_get_generation_mode = generation_config_out.get_generation_mode
+
+                     def patched_get_generation_mode(assistant_model=None):
+                         generation_mode = original_get_generation_mode(assistant_model)
+                         if generation_mode not in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
+                             raise ValueError(
+                                 f"Generation mode {generation_mode} is not supported for Chroma model. "
+                                 "Please set generation parameters to use greedy or sampling generation."
+                             )
+                         return generation_mode
+
+                     generation_config_out.get_generation_mode = patched_get_generation_mode
+                 except Exception:
+                     pass
+
+                 return generation_config_out, model_kwargs
+
+             _patched_prepare_generation_config._abstractvoice_patched = True  # type: ignore[attr-defined]
+             chroma_gen_cls._prepare_generation_config = _patched_prepare_generation_config  # type: ignore[assignment]
+         except Exception:
+             # Best-effort only: do not break core flows if patching fails.
+             return
+
+     def _select_prompt_audio(self, reference_paths: Iterable[str | Path]) -> Path:
+         paths = [Path(p) for p in reference_paths]
+         if not paths:
+             raise ValueError("reference_paths must contain at least one path")
+         for p in paths:
+             if not p.exists():
+                 raise FileNotFoundError(str(p))
+         # Chroma prompt_audio currently expects a single audio file.
+         return paths[0]
+
+     def _prepare_prompt_audio_for_processor(self, prompt_audio: Path) -> Path:
+         """Best-effort prompt-audio normalization for Chroma.
+
+         Chroma's processor loads audio via torchaudio and resamples internally.
+         We keep this preprocessing minimal (mono + 24kHz + PCM16 WAV) to reduce
+         backend variability without altering voice characteristics.
+         """
+         try:
+             if not prompt_audio.exists():
+                 return prompt_audio
+
+             # Cache by (path, size, mtime) to avoid recomputing every turn.
+             try:
+                 st = prompt_audio.stat()
+                 key = f"{prompt_audio.resolve()}|{st.st_size}|{st.st_mtime_ns}"
+             except Exception:
+                 key = str(prompt_audio)
+             h = hashlib.sha1(key.encode("utf-8", errors="ignore")).hexdigest()[:16]
+             out_dir = self._artifact_root() / "prompt_cache"
+             out_dir.mkdir(parents=True, exist_ok=True)
+             out_path = out_dir / f"prompt_{h}.wav"
+             if out_path.exists():
+                 return out_path
+
+             audio, sr = sf.read(str(prompt_audio), always_2d=True, dtype="float32")
+             mono = np.mean(audio, axis=1).astype(np.float32).reshape(-1)
+
+             target_sr = 24000
+             if int(sr) != int(target_sr):
+                 # Prefer torchaudio's sinc-based resampling when available (better fidelity than linear).
+                 try:
+                     import torch
+                     import torchaudio
+
+                     t = torch.from_numpy(mono).unsqueeze(0)
+                     t = torchaudio.functional.resample(t, orig_freq=int(sr), new_freq=int(target_sr))
+                     mono = t.squeeze(0).detach().cpu().numpy().astype(np.float32)
+                 except Exception:
+                     from ..audio.resample import linear_resample_mono
+
+                     mono = linear_resample_mono(mono, int(sr), int(target_sr))
+
+             # Cap extreme prompts for runtime stability; upstream examples go up to ~26s.
+             max_seconds = 30.0
+             max_samples = int(round(float(max_seconds) * float(target_sr)))
+             if mono.size > max_samples:
+                 mono = mono[:max_samples]
+
+             try:
+                 mono = np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
+                 mono = np.clip(mono, -1.0, 1.0).astype(np.float32)
+             except Exception:
+                 pass
+
+             sf.write(str(out_path), mono, int(target_sr), format="WAV", subtype="PCM_16")
+             return out_path
+         except Exception:
+             # If anything goes wrong, fall back to the original path.
+             return prompt_audio
+
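The linear fallback is imported from the new abstractvoice/audio/resample.py (also added in this release, not shown in this hunk). Its call signature is visible above; the body below is only a plausible sketch of such a helper, not the package's actual code:

    import numpy as np

    def linear_resample_mono(x: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
        # Hypothetical sketch: linear interpolation onto the new sample grid,
        # adequate as a fallback when torchaudio's sinc resampler is unavailable.
        if sr_in == sr_out or x.size == 0:
            return np.asarray(x, dtype=np.float32)
        n_out = int(round(x.size * sr_out / sr_in))
        t_in = np.linspace(0.0, 1.0, num=x.size, endpoint=False)
        t_out = np.linspace(0.0, 1.0, num=n_out, endpoint=False)
        return np.interp(t_out, t_in, x).astype(np.float32)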
+     def _build_conversation(self, text: str) -> List[List[dict]]:
+         # Match upstream framing ("Chroma" virtual human) while keeping strict TTS behavior.
+         system_prompt = (
+             "You are Chroma, an advanced virtual human created by the FlashLabs. "
+             "You possess the ability to understand auditory inputs and generate both text and speech. "
+             "For this request, do not answer or paraphrase. Read the user's text aloud exactly as written."
+         )
+         return [[
+             {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
+             {"role": "user", "content": [{"type": "text", "text": str(text)}]},
+         ]]
+
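For a single utterance, the result is one conversation with two turns, which the processor then renders through the model's chat template. Illustrated with a short text:

    conv = engine._build_conversation("It works.")
    # [[
    #   {"role": "system", "content": [{"type": "text", "text": "You are Chroma, ..."}]},
    #   {"role": "user",   "content": [{"type": "text", "text": "It works."}]},
    # ]]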
+     def infer_to_wav_bytes(
+         self,
+         *,
+         text: str,
+         reference_paths: Iterable[str | Path],
+         reference_text: Optional[str] = None,
+         speed: Optional[float] = None,
+     ) -> bytes:
+         chunks = []
+         sr_out = 24000
+         for chunk, sr in self.infer_to_audio_chunks(
+             text=text,
+             reference_paths=reference_paths,
+             reference_text=reference_text,
+             speed=speed,
+         ):
+             chunks.append(np.asarray(chunk, dtype=np.float32).reshape(-1))
+             sr_out = int(sr)
+         audio = np.concatenate(chunks) if chunks else np.zeros((0,), dtype=np.float32)
+         buf = io.BytesIO()
+         sf.write(buf, audio, int(sr_out), format="WAV", subtype="PCM_16")
+         return buf.getvalue()
+
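A minimal end-to-end sketch for the bytes API (file paths and the transcript are placeholders; the prompt audio and its transcript determine the cloned voice):

    from pathlib import Path

    wav = engine.infer_to_wav_bytes(
        text="This sentence is rendered in the cloned voice.",
        reference_paths=["voice_prompt.wav"],            # placeholder prompt recording
        reference_text="Transcript of voice_prompt.wav.",
        speed=1.1,                                       # optional pitch-preserving speed-up
    )
    Path("cloned.wav").write_bytes(wav)                  # 24 kHz mono, PCM_16 WAV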
+     def infer_to_audio_chunks(
+         self,
+         *,
+         text: str,
+         reference_paths: Iterable[str | Path],
+         reference_text: Optional[str] = None,
+         speed: Optional[float] = None,
+         max_chars: int = 240,
+     ):
+         self._ensure_model_loaded()
+
+         if not reference_text or not str(reference_text).strip():
+             raise RuntimeError(
+                 "Missing reference_text for Chroma cloning.\n"
+                 "If you're using VoiceCloner/VoiceManager, reference_text should be auto-generated and cached.\n"
+                 "If you're calling this engine directly, provide reference_text or set it via the voice store."
+             )
+
+         prompt_audio = self._select_prompt_audio(reference_paths)
+         prompt_audio = self._prepare_prompt_audio_for_processor(prompt_audio)
+         prompt_text = str(reference_text).strip()
+
+         import re
+
+         def _split_batches(s: str, limit: int) -> List[str]:
+             s = " ".join(str(s).replace("\n", " ").split()).strip()
+             if not s:
+                 return []
+             # Split after sentence-ending punctuation, then greedily pack sentences up to `limit` chars.
+             sentences = re.split(r"(?<=[.!?。])\s+", s)
+             out: List[str] = []
+             cur_s = ""
+             for sent in sentences:
+                 sent = sent.strip()
+                 if not sent:
+                     continue
+                 if len(sent) > limit:
+                     words = sent.split(" ")
+                     tmp = ""
+                     for w in words:
+                         cand = (tmp + " " + w).strip()
+                         if len(cand) <= limit:
+                             tmp = cand
+                         else:
+                             if tmp:
+                                 out.append(tmp)
+                             tmp = w
+                     if tmp:
+                         out.append(tmp)
+                     continue
+                 cand = (cur_s + " " + sent).strip()
+                 if len(cand) <= limit:
+                     cur_s = cand
+                 else:
+                     if cur_s:
+                         out.append(cur_s)
+                     cur_s = sent
+             if cur_s:
+                 out.append(cur_s)
+             return out
+
+         batches = _split_batches(text, int(max_chars)) or [" "]
+
+         model = self._model
+         processor = self._processor
+         if model is None or processor is None:
+             raise RuntimeError("Chroma model not loaded")
+
+         import torch
+
+         for batch_text in batches:
+             conversation = self._build_conversation(batch_text)
+             with warnings.catch_warnings():
+                 if not self.debug:
+                     warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
+                     warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")
+                 stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
+                 with stdout_ctx:
+                     with _SilenceStderrFD(enabled=not self.debug):
+                         inputs = processor(
+                             conversation,
+                             add_generation_prompt=True,
+                             tokenize=False,
+                             prompt_audio=[str(prompt_audio)],
+                             prompt_text=[prompt_text],
+                             return_tensors="pt",
+                         )
+             device = getattr(model, "device", None) or torch.device("cpu")
+             # Match upstream usage: `processor(...).to(device)` (device move only).
+             # We'll cast floating-point inputs to the model's dtype right before
+             # generation to avoid fp16/bf16 dtype mismatches on GPU/MPS.
+             try:
+                 inputs = inputs.to(device)
+             except Exception:
+                 pass
+             moved = dict(inputs)
+             # Chroma weights often run in fp16/bf16 on GPU/MPS; ensure floating-point
+             # inputs match model weight dtype to avoid runtime dtype mismatches like:
+             # "Input type (float) and bias type (c10::Half) should be the same".
+             try:
+                 target_dtype = getattr(model, "dtype", None)
+                 if target_dtype is None:
+                     target_dtype = next(iter(model.parameters())).dtype  # type: ignore[assignment]
+                 if target_dtype is not None:
+                     for k, v in list(moved.items()):
+                         if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != target_dtype:
+                             moved[k] = v.to(dtype=target_dtype)
+             except Exception:
+                 pass
+
+             max_new = self._estimate_max_new_tokens(batch_text, model)
+             with warnings.catch_warnings():
+                 if not self.debug:
+                     warnings.filterwarnings("ignore", category=UserWarning, message=r".*output_attentions.*")
+                     warnings.filterwarnings("ignore", category=UserWarning, message=r".*sdp attention.*")
+                 stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
+                 with stdout_ctx:
+                     with _SilenceStderrFD(enabled=not self.debug):
+                         with torch.inference_mode():
+                             out = model.generate(
+                                 **moved,
+                                 max_new_tokens=int(max_new),
+                                 do_sample=bool(self._do_sample),
+                                 temperature=float(self._temperature),
+                                 top_k=int(self._top_k),
+                                 use_cache=True,
+                                 output_attentions=False,
+                                 output_audio=True,
+                             )
+             audio_list = out.audio if hasattr(out, "audio") else out
+             if not audio_list:
+                 continue
+             a = audio_list[0]
+             if isinstance(a, (list, tuple)) and a:
+                 a = a[0]
+             if isinstance(a, torch.Tensor):
+                 arr = a.detach().cpu().numpy()
+             else:
+                 arr = np.asarray(a)
+
+             mono = np.asarray(arr, dtype=np.float32).reshape(-1)
+             sr = 24000
+
+             if speed and speed != 1.0:
+                 try:
+                     from ..tts.tts_engine import apply_speed_without_pitch_change
+
+                     mono = apply_speed_without_pitch_change(mono, float(speed), sr=sr)
+                 except Exception:
+                     pass
+
+             # Sanitize before any normalization (NaNs/Infs produce loud artifacts).
+             try:
+                 mono = np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
+             except Exception:
+                 pass
+
+             # Chroma's codec output can exceed [-1, 1] slightly; normalize to avoid
+             # hard clipping/distortion in playback and PCM encoders.
+             try:
+                 peak = float(np.max(np.abs(mono))) if mono.size else 0.0
+                 if peak > 1.0:
+                     mono = mono / peak
+             except Exception:
+                 pass
+
+             # Final clamp for safety (avoids stray overs on some backends).
+             try:
+                 mono = np.clip(mono, -1.0, 1.0).astype(np.float32)
+             except Exception:
+                 pass
+
+             yield mono, sr
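Because infer_to_audio_chunks() is a generator that yields one (mono float32 array, sample rate) pair per text batch, callers can start output before the whole utterance is synthesized. A streaming sketch that appends each chunk to a WAV file (playback backends vary; the paths and transcript are placeholders):

    import soundfile as sf

    with sf.SoundFile("stream.wav", mode="w", samplerate=24000, channels=1, subtype="PCM_16") as f:
        for mono, sr in engine.infer_to_audio_chunks(
            text="A longer passage is split into sentence-sized batches and synthesized one by one.",
            reference_paths=["voice_prompt.wav"],
            reference_text="Transcript of voice_prompt.wav.",
        ):
            f.write(mono)  # each yielded chunk is already sanitized and clipped to [-1, 1]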