abstractvoice 0.5.1-py3-none-any.whl → 0.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- abstractvoice/__init__.py +2 -5
- abstractvoice/__main__.py +82 -3
- abstractvoice/adapters/__init__.py +12 -0
- abstractvoice/adapters/base.py +207 -0
- abstractvoice/adapters/stt_faster_whisper.py +401 -0
- abstractvoice/adapters/tts_piper.py +480 -0
- abstractvoice/aec/__init__.py +10 -0
- abstractvoice/aec/webrtc_apm.py +56 -0
- abstractvoice/artifacts.py +173 -0
- abstractvoice/audio/__init__.py +7 -0
- abstractvoice/audio/recorder.py +46 -0
- abstractvoice/audio/resample.py +25 -0
- abstractvoice/cloning/__init__.py +7 -0
- abstractvoice/cloning/engine_chroma.py +738 -0
- abstractvoice/cloning/engine_f5.py +546 -0
- abstractvoice/cloning/manager.py +349 -0
- abstractvoice/cloning/store.py +362 -0
- abstractvoice/compute/__init__.py +6 -0
- abstractvoice/compute/device.py +73 -0
- abstractvoice/config/__init__.py +2 -0
- abstractvoice/config/voice_catalog.py +19 -0
- abstractvoice/dependency_check.py +0 -1
- abstractvoice/examples/cli_repl.py +2403 -243
- abstractvoice/examples/voice_cli.py +64 -63
- abstractvoice/integrations/__init__.py +2 -0
- abstractvoice/integrations/abstractcore.py +116 -0
- abstractvoice/integrations/abstractcore_plugin.py +253 -0
- abstractvoice/prefetch.py +82 -0
- abstractvoice/recognition.py +424 -42
- abstractvoice/stop_phrase.py +103 -0
- abstractvoice/tts/__init__.py +3 -3
- abstractvoice/tts/adapter_tts_engine.py +210 -0
- abstractvoice/tts/tts_engine.py +257 -1208
- abstractvoice/vm/__init__.py +2 -0
- abstractvoice/vm/common.py +21 -0
- abstractvoice/vm/core.py +139 -0
- abstractvoice/vm/manager.py +108 -0
- abstractvoice/vm/stt_mixin.py +158 -0
- abstractvoice/vm/tts_mixin.py +550 -0
- abstractvoice/voice_manager.py +6 -1061
- abstractvoice-0.6.1.dist-info/METADATA +213 -0
- abstractvoice-0.6.1.dist-info/RECORD +52 -0
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/WHEEL +1 -1
- abstractvoice-0.6.1.dist-info/entry_points.txt +6 -0
- abstractvoice/instant_setup.py +0 -83
- abstractvoice/simple_model_manager.py +0 -539
- abstractvoice-0.5.1.dist-info/METADATA +0 -1458
- abstractvoice-0.5.1.dist-info/RECORD +0 -23
- abstractvoice-0.5.1.dist-info/entry_points.txt +0 -2
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {abstractvoice-0.5.1.dist-info → abstractvoice-0.6.1.dist-info}/top_level.txt +0 -0
abstractvoice/cloning/engine_chroma.py (new file)
@@ -0,0 +1,738 @@
from __future__ import annotations

import contextlib
import gc
import hashlib
import io
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import soundfile as sf

from ..tts.tts_engine import _SilenceStderrFD


@dataclass(frozen=True)
class ChromaArtifacts:
    root: Path
    model_id: str
    revision: Optional[str]


class ChromaVoiceCloningEngine:
    """In-process Chroma voice cloning engine (optional; GPU-heavy).

    Design principles
    -----------------
    - Never download weights implicitly from speak()/speak_to_bytes().
    - Provide explicit prefetch hooks.
    - Keep imports lightweight until the engine is actually used.
    """

    DEFAULT_MODEL_ID = "FlashLabs/Chroma-4B"
    # Pin a known revision by default (safer than floating remote code).
    DEFAULT_REVISION = "864b4aea0c1359f91af62f1367df64657dc5e90f"

    def __init__(
        self,
        *,
        debug: bool = False,
        device: str = "auto",
        model_id: str = DEFAULT_MODEL_ID,
        revision: str | None = DEFAULT_REVISION,
        temperature: float = 0.7,
        top_k: int = 50,
        do_sample: bool = True,
        max_new_tokens_per_chunk: int = 512,
    ):
        self.debug = bool(debug)
        self._device_pref = str(device or "auto")
        self._model_id = str(model_id or self.DEFAULT_MODEL_ID)
        self._revision = revision if revision else None

        self._temperature = float(temperature)
        self._top_k = int(top_k)
        self._do_sample = bool(do_sample)
        self._max_new_tokens_per_chunk = int(max_new_tokens_per_chunk)

        self._model = None
        self._processor = None
        self._resolved_device = None

    def unload(self) -> None:
        """Best-effort release of loaded model/processor to free memory."""
        self._model = None
        self._processor = None
        self._resolved_device = None
        try:
            gc.collect()
        except Exception:
            pass
        try:
            import torch

            # Best-effort GPU/MPS cache release.
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                if hasattr(torch, "mps") and torch.backends.mps.is_available():
                    torch.mps.empty_cache()
            except Exception:
                pass
        except Exception:
            pass

    def runtime_info(self) -> Dict[str, object]:
        info: Dict[str, object] = {
            "model_id": self._model_id,
            "revision": self._revision,
            "requested_device": self._device_pref,
            "resolved_device": self._resolved_device,
            "temperature": self._temperature,
            "top_k": self._top_k,
            "do_sample": self._do_sample,
            "max_new_tokens_per_chunk": self._max_new_tokens_per_chunk,
        }
        try:
            import torch

            info["torch_version"] = getattr(torch, "__version__", "?")
            info["cuda_available"] = bool(torch.cuda.is_available())
            try:
                info["mps_available"] = bool(torch.backends.mps.is_available())
            except Exception:
                info["mps_available"] = False
        except Exception:
            pass
        return info

    def set_quality_preset(self, preset: str) -> None:
        p = (preset or "").strip().lower()
        if p not in ("fast", "balanced", "high"):
            raise ValueError("preset must be one of: fast|balanced|high")
        if p == "fast":
            self._temperature = 0.6
            self._top_k = 30
            self._max_new_tokens_per_chunk = 384
        elif p == "balanced":
            self._temperature = 0.7
            self._top_k = 50
            self._max_new_tokens_per_chunk = 512
        else:
            self._temperature = 0.75
            self._top_k = 80
            self._max_new_tokens_per_chunk = 768

    def _artifact_root(self) -> Path:
        cache_dir = Path(os.path.expanduser("~/.cache/abstractvoice/chroma"))
        cache_dir.mkdir(parents=True, exist_ok=True)
        return cache_dir

    def _resolve_chroma_artifacts_local(self) -> ChromaArtifacts:
        root = self._artifact_root()
        required = [
            "config.json",
            "processor_config.json",
            "model.safetensors.index.json",
            "modeling_chroma.py",
            "processing_chroma.py",
            "configuration_chroma.py",
        ]
        missing = [name for name in required if not (root / name).exists()]
        if missing:
            raise RuntimeError(
                "Chroma artifacts are not present locally.\n"
                "Prefetch explicitly (outside the REPL):\n"
                " abstractvoice-prefetch --chroma\n"
                "or:\n"
                " python -m abstractvoice download --chroma\n"
                f"Looked under: {root}\n"
                f"Missing: {', '.join(missing)}"
            )
        return ChromaArtifacts(root=root, model_id=self._model_id, revision=self._revision)

    def ensure_chroma_artifacts_downloaded(self) -> ChromaArtifacts:
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        try:
            from huggingface_hub import snapshot_download
        except Exception as e:
            raise RuntimeError(
                "huggingface_hub is required to download Chroma artifacts.\n"
                "Install with: pip install huggingface_hub"
            ) from e

        import warnings

        root = self._artifact_root()
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                category=UserWarning,
                message=r".*local_dir_use_symlinks.*deprecated.*",
            )
            snapshot_download(
                repo_id=self._model_id,
                revision=self._revision,
                local_dir=str(root),
            )
        return self._resolve_chroma_artifacts_local()

    def are_chroma_artifacts_available(self) -> bool:
        root = self._artifact_root()
        return bool((root / "config.json").exists() and (root / "model.safetensors.index.json").exists())

    def _resolve_device(self) -> str:
        if self._device_pref and self._device_pref != "auto":
            return str(self._device_pref)
        try:
            from ..compute import best_torch_device

            return best_torch_device()
        except Exception:
            return "cpu"

    def _ensure_chroma_runtime(self) -> None:
        try:
            import importlib.util

            if importlib.util.find_spec("torch") is None:
                raise ImportError("torch not installed")
            if importlib.util.find_spec("transformers") is None:
                raise ImportError("transformers not installed")
        except Exception as e:
            raise RuntimeError(
                "Chroma requires the optional dependency group.\n"
                "Install with:\n"
                " pip install \"abstractvoice[chroma]\"\n"
                f"Original error: {e}"
            ) from e

    def _ensure_model_loaded(self) -> None:
        if self._model is not None and self._processor is not None and self._resolved_device is not None:
            return

        self._ensure_chroma_runtime()
        artifacts = self._resolve_chroma_artifacts_local()

        # Keep interactive UX quiet by default (the REPL provides its own spinner).
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ.setdefault("TRANSFORMERS_NO_TQDM", "1")

        # Keep interactive UX quiet by default (the REPL provides its own spinner).
        try:
            from transformers.utils import logging as hf_logging

            hf_logging.disable_progress_bar()
            if not self.debug:
                hf_logging.set_verbosity_error()
        except Exception:
            pass

        import torch
        from transformers import AutoModelForCausalLM, AutoProcessor

        device = self._resolve_device()
        self._resolved_device = device

        torch_dtype = None
        if device == "cuda":
            torch_dtype = torch.bfloat16
        elif device == "mps":
            torch_dtype = torch.float16

        device_map = None
        # `device_map="auto"` is most reliable on CUDA; MPS/CPU should load normally then move.
        if self._device_pref == "auto" and device == "cuda":
            device_map = "auto"

        # transformers 5.0.0 deprecates `torch_dtype` in favor of `dtype`.
        with warnings.catch_warnings():
            # Torch/Torchaudio/Transformers often emit noisy warnings that are not actionable
            # for interactive users. Keep them hidden unless debug is enabled.
            if not self.debug:
                warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
                warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")
                warnings.filterwarnings("ignore", category=UserWarning, message=r".*output_attentions.*")
            stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
            with stdout_ctx:
                with _SilenceStderrFD(enabled=not self.debug):
                    try:
                        model = AutoModelForCausalLM.from_pretrained(
                            str(artifacts.root),
                            trust_remote_code=True,
                            device_map=device_map,
                            dtype=torch_dtype,
                        )
                    except TypeError:
                        model = AutoModelForCausalLM.from_pretrained(
                            str(artifacts.root),
                            trust_remote_code=True,
                            device_map=device_map,
                            torch_dtype=torch_dtype,
                        )
                    model.eval()

        stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
        with stdout_ctx:
            with _SilenceStderrFD(enabled=not self.debug):
                processor = AutoProcessor.from_pretrained(str(artifacts.root), trust_remote_code=True)

        if device != "cuda" and device_map is None:
            try:
                model.to(device)
            except Exception:
                pass

        # Ensure tied weights are applied (Chroma relies on tied audio embeddings).
        try:
            if hasattr(model, "tie_weights"):
                model.tie_weights()
        except Exception:
            pass
        # Some releases log this key as MISSING even when it is meant to be tied.
        # Ensure backbone audio embeddings match decoder embeddings (best-effort).
        try:
            if hasattr(model, "backbone") and hasattr(model, "decoder"):
                b = getattr(model.backbone, "audio_embedding", None)
                d = getattr(model.decoder, "audio_embedding", None)
                if b is not None and d is not None:
                    bw = getattr(b, "embed_audio_tokens", None)
                    dw = getattr(d, "embed_audio_tokens", None)
                    if bw is not None and dw is not None and hasattr(bw, "weight") and hasattr(dw, "weight"):
                        if getattr(bw.weight, "shape", None) == getattr(dw.weight, "shape", None):
                            bw.weight.data.copy_(dw.weight.data)
                    # Prefer a hard tie (shared module) to avoid one side staying randomly
                    # initialized when the checkpoint omits a tied key.
                    try:
                        if bw is not None and dw is not None and bw is not dw:
                            b.embed_audio_tokens = dw
                    except Exception:
                        pass
        except Exception:
            pass

        self._patch_generation_compat(model)

        self._model = model
        self._processor = processor

    def _estimate_max_new_tokens(self, text: str, model) -> int:
        """Estimate a safe `max_new_tokens` (audio frames) for a text batch.

        Chroma audio frames decode at roughly:
            frames_per_second = sampling_rate / audio_frame_freq
        For the default config this is ~12.5 fps, so `512` frames ~= 41s.
        Without a good stopping signal, the model may generate until `max_new_tokens`,
        so we cap it based on expected speech duration to avoid long noisy tails.
        """
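        # Worked example of the estimate (illustrative numbers only, using the same
        # defaults this method falls back to below: sampling_rate=24000, audio_frame_freq=1920):
        #   fps = 24000 / 1920 = 12.5 audio frames per second
        #   a 20-word batch at 3.2 words/s -> 6.25 s + 0.4 s pad = 6.65 s of speech
        #   frames = round(6.65 * 12.5 * 1.25) = 104, well under the default 512-frame cap (~41 s)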
        s = " ".join(str(text or "").split()).strip()
        words = len(s.split()) if s else 0
        if words <= 0:
            # Fallback: roughly 5 chars/word in English.
            words = max(1, int(round(len(s) / 5.0))) if s else 1

        # Conservative speech-rate assumption (words/sec).
        wps = 3.2
        est_s = float(words) / float(wps)
        # Small pad to avoid truncation on short utterances.
        est_s += 0.4

        sr = 24000
        frame_freq = 1920
        try:
            cfg = getattr(model, "config", None)
            if cfg is not None:
                frame_freq = int(getattr(cfg, "audio_frame_freq", frame_freq) or frame_freq)
                codec_cfg = getattr(cfg, "codec_config", None)
                if codec_cfg is not None:
                    if isinstance(codec_cfg, dict):
                        sr = int(codec_cfg.get("sampling_rate", sr) or sr)
                    else:
                        sr = int(getattr(codec_cfg, "sampling_rate", sr) or sr)
        except Exception:
            pass

        fps = float(sr) / float(frame_freq) if frame_freq else 12.5
        # Add slack so we still have room for slower speech.
        frames = int(round(est_s * fps * 1.25))
        frames = max(64, frames)
        cap = int(self._max_new_tokens_per_chunk) if self._max_new_tokens_per_chunk else frames
        return int(min(frames, cap))

    def _patch_generation_compat(self, model) -> None:
        """Patch known transformers incompatibilities for Chroma remote code.

        Upstream Chroma code (pinned revision) expects a `use_model_defaults` positional arg
        in `GenerationMixin._prepare_generation_config`. transformers 5.0.0 removed it.
        """
        try:
            import inspect

            from transformers.generation import GenerationMode
            from transformers.generation.utils import GenerationMixin

            base_sig = inspect.signature(GenerationMixin._prepare_generation_config)
            if "use_model_defaults" in base_sig.parameters:
                return

            chroma_gen_cls = None
            for cls in model.__class__.mro():
                if cls.__name__ == "ChromaGenerationMixin":
                    chroma_gen_cls = cls
                    break
            if chroma_gen_cls is None:
                return

            current = getattr(chroma_gen_cls, "_prepare_generation_config", None)
            if current is None or getattr(current, "_abstractvoice_patched", False):
                return

            def _patched_prepare_generation_config(
                self, generation_config=None, use_model_defaults=None, **kwargs
            ):
                depth_decoder_kwargs = {k[len("decoder_") :]: v for k, v in kwargs.items() if k.startswith("decoder_")}
                kwargs = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}

                # transformers 5.0.0 removed the `use_model_defaults` positional arg; keep compatibility.
                try:
                    generation_config_out, model_kwargs = super(chroma_gen_cls, self)._prepare_generation_config(
                        generation_config, use_model_defaults, **kwargs
                    )
                except TypeError:
                    generation_config_out, model_kwargs = super(chroma_gen_cls, self)._prepare_generation_config(
                        generation_config, **kwargs
                    )

                try:
                    self.decoder.generation_config.update(**depth_decoder_kwargs)
                except Exception:
                    pass

                try:
                    decoder_min_new_tokens = getattr(self.decoder.generation_config, "min_new_tokens") or (
                        self.decoder.config.audio_num_codebooks - 1
                    )
                    decoder_max_new_tokens = getattr(self.decoder.generation_config, "max_new_tokens") or (
                        self.decoder.config.audio_num_codebooks - 1
                    )

                    if {decoder_min_new_tokens, decoder_max_new_tokens} != {self.decoder.config.audio_num_codebooks - 1}:
                        raise ValueError(
                            "decoder_generation_config's min_new_tokens "
                            f"({decoder_min_new_tokens}) and max_new_tokens ({decoder_max_new_tokens}) "
                            f"must be equal to self.config.num_codebooks - 1 ({self.decoder.config.audio_num_codebooks - 1})"
                        )
                    elif getattr(self.decoder.generation_config, "return_dict_in_generate", False):
                        self.decoder.generation_config.return_dict_in_generate = False
                except Exception:
                    pass

                # Monkey patch the get_generation_mode method to support Chroma model.
                try:
                    original_get_generation_mode = generation_config_out.get_generation_mode

                    def patched_get_generation_mode(assistant_model=None):
                        generation_mode = original_get_generation_mode(assistant_model)
                        if generation_mode not in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
                            raise ValueError(
                                f"Generation mode {generation_mode} is not supported for Chroma model. "
                                "Please set generation parameters to use greedy or sampling generation."
                            )
                        return generation_mode

                    generation_config_out.get_generation_mode = patched_get_generation_mode
                except Exception:
                    pass

                return generation_config_out, model_kwargs

            _patched_prepare_generation_config._abstractvoice_patched = True  # type: ignore[attr-defined]
            chroma_gen_cls._prepare_generation_config = _patched_prepare_generation_config  # type: ignore[assignment]
        except Exception:
            # Best-effort only: do not break core flows if patching fails.
            return

    def _select_prompt_audio(self, reference_paths: Iterable[str | Path]) -> Path:
        paths = [Path(p) for p in reference_paths]
        if not paths:
            raise ValueError("reference_paths must contain at least one path")
        for p in paths:
            if not p.exists():
                raise FileNotFoundError(str(p))
        # Chroma prompt_audio currently expects a single audio file.
        return paths[0]

    def _prepare_prompt_audio_for_processor(self, prompt_audio: Path) -> Path:
        """Best-effort prompt-audio normalization for Chroma.

        Chroma's processor loads audio via torchaudio and resamples internally.
        We keep this preprocessing minimal (mono + 24kHz + PCM16 WAV) to reduce
        backend variability without altering voice characteristics.
        """
        try:
            if not prompt_audio.exists():
                return prompt_audio

            # Cache by (path, size, mtime) to avoid recomputing every turn.
            try:
                st = prompt_audio.stat()
                key = f"{prompt_audio.resolve()}|{st.st_size}|{st.st_mtime_ns}"
            except Exception:
                key = str(prompt_audio)
            h = hashlib.sha1(key.encode("utf-8", errors="ignore")).hexdigest()[:16]
            out_dir = self._artifact_root() / "prompt_cache"
            out_dir.mkdir(parents=True, exist_ok=True)
            out_path = out_dir / f"prompt_{h}.wav"
            if out_path.exists():
                return out_path

            audio, sr = sf.read(str(prompt_audio), always_2d=True, dtype="float32")
            mono = np.mean(audio, axis=1).astype(np.float32).reshape(-1)

            target_sr = 24000
            if int(sr) != int(target_sr):
                # Prefer torchaudio's sinc-based resampling when available (better fidelity than linear).
                try:
                    import torch
                    import torchaudio

                    t = torch.from_numpy(mono).unsqueeze(0)
                    t = torchaudio.functional.resample(t, orig_freq=int(sr), new_freq=int(target_sr))
                    mono = t.squeeze(0).detach().cpu().numpy().astype(np.float32)
                except Exception:
                    from ..audio.resample import linear_resample_mono

                    mono = linear_resample_mono(mono, int(sr), int(target_sr))

            # Cap extreme prompts for runtime stability; upstream examples go up to ~26s.
            max_seconds = 30.0
            max_samples = int(round(float(max_seconds) * float(target_sr)))
            if mono.size > max_samples:
                mono = mono[:max_samples]

            try:
                mono = np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
                mono = np.clip(mono, -1.0, 1.0).astype(np.float32)
            except Exception:
                pass

            sf.write(str(out_path), mono, int(target_sr), format="WAV", subtype="PCM_16")
            return out_path
        except Exception:
            # If anything goes wrong, fall back to the original path.
            return prompt_audio

    def _build_conversation(self, text: str) -> List[List[dict]]:
        # Match upstream framing ("Chroma" virtual human) while keeping strict TTS behavior.
        system_prompt = (
            "You are Chroma, an advanced virtual human created by the FlashLabs. "
            "You possess the ability to understand auditory inputs and generate both text and speech. "
            "For this request, do not answer or paraphrase. Read the user's text aloud exactly as written."
        )
        return [[
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": [{"type": "text", "text": str(text)}]},
        ]]

    def infer_to_wav_bytes(
        self,
        *,
        text: str,
        reference_paths: Iterable[str | Path],
        reference_text: Optional[str] = None,
        speed: Optional[float] = None,
    ) -> bytes:
        chunks = []
        sr_out = 24000
        for chunk, sr in self.infer_to_audio_chunks(
            text=text,
            reference_paths=reference_paths,
            reference_text=reference_text,
            speed=speed,
        ):
            chunks.append(np.asarray(chunk, dtype=np.float32).reshape(-1))
            sr_out = int(sr)
        audio = np.concatenate(chunks) if chunks else np.zeros((0,), dtype=np.float32)
        buf = io.BytesIO()
        sf.write(buf, audio, int(sr_out), format="WAV", subtype="PCM_16")
        return buf.getvalue()

    def infer_to_audio_chunks(
        self,
        *,
        text: str,
        reference_paths: Iterable[str | Path],
        reference_text: Optional[str] = None,
        speed: Optional[float] = None,
        max_chars: int = 240,
    ):
        self._ensure_model_loaded()

        if not reference_text or not str(reference_text).strip():
            raise RuntimeError(
                "Missing reference_text for Chroma cloning.\n"
                "If you're using VoiceCloner/VoiceManager, reference_text should be auto-generated and cached.\n"
                "If you're calling this engine directly, provide reference_text or set it via the voice store."
            )

        prompt_audio = self._select_prompt_audio(reference_paths)
        prompt_audio = self._prepare_prompt_audio_for_processor(prompt_audio)
        prompt_text = str(reference_text).strip()

        import re

        def _split_batches(s: str, limit: int) -> List[str]:
            s = " ".join(str(s).replace("\n", " ").split()).strip()
            if not s:
                return []
            sentences = re.split(r"(?<=[\\.!\\?\\。])\\s+", s)
            out: List[str] = []
            cur_s = ""
            for sent in sentences:
                sent = sent.strip()
                if not sent:
                    continue
                if len(sent) > limit:
                    words = sent.split(" ")
                    tmp = ""
                    for w in words:
                        cand = (tmp + " " + w).strip()
                        if len(cand) <= limit:
                            tmp = cand
                        else:
                            if tmp:
                                out.append(tmp)
                            tmp = w
                    if tmp:
                        out.append(tmp)
                    continue
                cand = (cur_s + " " + sent).strip()
                if len(cand) <= limit:
                    cur_s = cand
                else:
                    if cur_s:
                        out.append(cur_s)
                    cur_s = sent
            if cur_s:
                out.append(cur_s)
            return out

        batches = _split_batches(text, int(max_chars)) or [" "]

        model = self._model
        processor = self._processor
        if model is None or processor is None:
            raise RuntimeError("Chroma model not loaded")

        import torch

        for batch_text in batches:
            conversation = self._build_conversation(batch_text)
            with warnings.catch_warnings():
                if not self.debug:
                    warnings.filterwarnings("ignore", category=UserWarning, module=r"torchaudio\..*")
                    warnings.filterwarnings("ignore", category=UserWarning, message=r".*TorchCodec.*")
                stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
                with stdout_ctx:
                    with _SilenceStderrFD(enabled=not self.debug):
                        inputs = processor(
                            conversation,
                            add_generation_prompt=True,
                            tokenize=False,
                            prompt_audio=[str(prompt_audio)],
                            prompt_text=[prompt_text],
                            return_tensors="pt",
                        )
            device = getattr(model, "device", None) or torch.device("cpu")
            # Match upstream usage: `processor(...).to(device)` (device move only).
            # We'll cast floating-point inputs to the model's dtype right before
            # generation to avoid fp16/bf16 dtype mismatches on GPU/MPS.
            try:
                inputs = inputs.to(device)
            except Exception:
                pass
            moved = dict(inputs)
            # Chroma weights often run in fp16/bf16 on GPU/MPS; ensure floating-point
            # inputs match model weight dtype to avoid runtime dtype mismatches like:
            # "Input type (float) and bias type (c10::Half) should be the same".
            try:
                target_dtype = getattr(model, "dtype", None)
                if target_dtype is None:
                    target_dtype = next(iter(model.parameters())).dtype  # type: ignore[assignment]
                if target_dtype is not None:
                    for k, v in list(moved.items()):
                        if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != target_dtype:
                            moved[k] = v.to(dtype=target_dtype)
            except Exception:
                pass

            max_new = self._estimate_max_new_tokens(batch_text, model)
            with warnings.catch_warnings():
                if not self.debug:
                    warnings.filterwarnings("ignore", category=UserWarning, message=r".*output_attentions.*")
                    warnings.filterwarnings("ignore", category=UserWarning, message=r".*sdp attention.*")
                stdout_ctx = contextlib.redirect_stdout(io.StringIO()) if not self.debug else contextlib.nullcontext()
                with stdout_ctx:
                    with _SilenceStderrFD(enabled=not self.debug):
                        with torch.inference_mode():
                            out = model.generate(
                                **moved,
                                max_new_tokens=int(max_new),
                                do_sample=bool(self._do_sample),
                                temperature=float(self._temperature),
                                top_k=int(self._top_k),
                                use_cache=True,
                                output_attentions=False,
                                output_audio=True,
                            )
            audio_list = out.audio if hasattr(out, "audio") else out
            if not audio_list:
                continue
            a = audio_list[0]
            if isinstance(a, (list, tuple)) and a:
                a = a[0]
            if isinstance(a, torch.Tensor):
                arr = a.detach().cpu().numpy()
            else:
                arr = np.asarray(a)

            mono = np.asarray(arr, dtype=np.float32).reshape(-1)
            sr = 24000

            if speed and speed != 1.0:
                try:
                    from ..tts.tts_engine import apply_speed_without_pitch_change

                    mono = apply_speed_without_pitch_change(mono, float(speed), sr=sr)
                except Exception:
                    pass

            # Sanitize before any normalization (NaNs/Infs produce loud artifacts).
            try:
                mono = np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
            except Exception:
                pass

            # Chroma's codec output can exceed [-1, 1] slightly; normalize to avoid
            # hard clipping/distortion in playback and PCM encoders.
            try:
                peak = float(np.max(np.abs(mono))) if mono.size else 0.0
                if peak > 1.0:
                    mono = mono / peak
            except Exception:
                pass

            # Final clamp for safety (avoids stray overs on some backends).
            try:
                mono = np.clip(mono, -1.0, 1.0).astype(np.float32)
            except Exception:
                pass

            yield mono, sr
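
For orientation, the engine added in this hunk is designed to be driven explicitly: artifacts are prefetched up front, reference_text must be supplied by the caller, and the caller owns playback and cleanup. A minimal driving sketch under those assumptions (requires the abstractvoice[chroma] extra; reference.wav, its transcript string, and out.wav are placeholders, not files shipped with the package):

from abstractvoice.cloning.engine_chroma import ChromaVoiceCloningEngine

engine = ChromaVoiceCloningEngine(device="auto")
engine.set_quality_preset("balanced")

# Weights are never downloaded implicitly; prefetch once, outside any hot path.
if not engine.are_chroma_artifacts_available():
    engine.ensure_chroma_artifacts_downloaded()

wav_bytes = engine.infer_to_wav_bytes(
    text="Hello from the cloned voice.",            # sample text
    reference_paths=["reference.wav"],               # placeholder prompt clip
    reference_text="Transcript of reference.wav.",   # required by this engine (VoiceManager can auto-generate it)
)
with open("out.wav", "wb") as f:
    f.write(wav_bytes)
engine.unload()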