ttscli-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ttscli/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """TTS CLI - Command-line interface for Voicebox TTS."""
+
+ __version__ = "0.1.0"
ttscli/__main__.py ADDED
@@ -0,0 +1,9 @@
+ """Entry point for python -m ttscli."""
+
+ import warnings
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ from .cli import main
+
+ if __name__ == "__main__":
+     main()
ttscli/audio.py ADDED
@@ -0,0 +1,8 @@
+ """Audio utilities."""
+
+ import soundfile as sf
+
+
+ def save_audio(audio, path: str, sample_rate: int = 24000):
+     """Save audio to WAV file."""
+     sf.write(path, audio, sample_rate)
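
For reference, a minimal usage sketch of save_audio (a hedged example, not part of the package; the tone, file name, and sample rate below are illustrative):

    import numpy as np
    from ttscli.audio import save_audio

    # Synthesize one second of a 440 Hz sine tone and write it as WAV.
    sr = 24000
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    tone = 0.2 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
    save_audio(tone, "tone.wav", sample_rate=sr)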
ttscli/backends/__init__.py ADDED
@@ -0,0 +1,182 @@
+ """
+ Backend abstraction layer for TTS and STT.
+
+ Provides a unified interface for MLX and PyTorch backends.
+ """
+
+ from typing import Protocol, Optional, Tuple, List
+ from typing_extensions import runtime_checkable
+ import numpy as np
+
+ from ..platform import get_backend_type
+
+
+ @runtime_checkable
+ class TTSBackend(Protocol):
+     """Protocol for TTS backend implementations."""
+
+     async def load_model(self, model_size: str) -> None:
+         """Load TTS model."""
+         ...
+
+     async def create_voice_prompt(
+         self,
+         audio_path: str,
+         reference_text: str,
+         use_cache: bool = True,
+     ) -> Tuple[dict, bool]:
+         """
+         Create voice prompt from reference audio.
+
+         Returns:
+             Tuple of (voice_prompt_dict, was_cached)
+         """
+         ...
+
+     async def combine_voice_prompts(
+         self,
+         audio_paths: List[str],
+         reference_texts: List[str],
+     ) -> Tuple[np.ndarray, str]:
+         """
+         Combine multiple voice prompts.
+
+         Returns:
+             Tuple of (combined_audio_array, combined_text)
+         """
+         ...
+
+     async def generate(
+         self,
+         text: str,
+         voice_prompt: dict,
+         language: str = "en",
+         seed: Optional[int] = None,
+         instruct: Optional[str] = None,
+     ) -> Tuple[np.ndarray, int]:
+         """
+         Generate audio from text.
+
+         Returns:
+             Tuple of (audio_array, sample_rate)
+         """
+         ...
+
+     async def generate_stream(
+         self,
+         text: str,
+         voice_prompt: dict,
+         language: str = "en",
+         seed: Optional[int] = None,
+         instruct: Optional[str] = None,
+     ):
+         """
+         Generate audio from text, yielding chunks as they become available.
+
+         Yields:
+             Tuple of (audio_chunk: np.ndarray, sample_rate: int, is_final: bool)
+         """
+         ...
+
+     def unload_model(self) -> None:
+         """Unload model to free memory."""
+         ...
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded."""
+         ...
+
+     def _get_model_path(self, model_size: str) -> str:
+         """
+         Get model path for a given size.
+
+         Returns:
+             Model path or HuggingFace Hub ID
+         """
+         ...
+
+
+ @runtime_checkable
+ class STTBackend(Protocol):
+     """Protocol for STT (Speech-to-Text) backend implementations."""
+
+     async def load_model(self, model_size: str) -> None:
+         """Load STT model."""
+         ...
+
+     async def transcribe(
+         self,
+         audio_path: str,
+         language: Optional[str] = None,
+     ) -> str:
+         """
+         Transcribe audio to text.
+
+         Returns:
+             Transcribed text
+         """
+         ...
+
+     def unload_model(self) -> None:
+         """Unload model to free memory."""
+         ...
+
+     def is_loaded(self) -> bool:
+         """Check if model is loaded."""
+         ...
+
+
+ # Global backend instances
+ _tts_backend: Optional[TTSBackend] = None
+ _stt_backend: Optional[STTBackend] = None
+
+
+ def get_tts_backend() -> TTSBackend:
+     """
+     Get or create TTS backend instance based on platform.
+
+     Returns:
+         TTS backend instance (MLX or PyTorch)
+     """
+     global _tts_backend
+
+     if _tts_backend is None:
+         backend_type = get_backend_type()
+
+         if backend_type == "mlx":
+             from .mlx import MLXTTSBackend
+             _tts_backend = MLXTTSBackend()
+         else:
+             from .pytorch import PyTorchTTSBackend
+             _tts_backend = PyTorchTTSBackend()
+
+     return _tts_backend
+
+
+ def get_stt_backend() -> STTBackend:
+     """
+     Get or create STT backend instance based on platform.
+
+     Returns:
+         STT backend instance (MLX or PyTorch)
+     """
+     global _stt_backend
+
+     if _stt_backend is None:
+         backend_type = get_backend_type()
+
+         if backend_type == "mlx":
+             from .mlx import MLXSTTBackend
+             _stt_backend = MLXSTTBackend()
+         else:
+             from .pytorch import PyTorchSTTBackend
+             _stt_backend = PyTorchSTTBackend()
+
+     return _stt_backend
+
+
+ def reset_backends():
+     """Reset backend instances (useful for testing)."""
+     global _tts_backend, _stt_backend
+     _tts_backend = None
+     _stt_backend = None
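
A minimal usage sketch of the factory and protocol above, assuming this hunk is ttscli/backends/__init__.py (its relative imports suggest so), that get_backend_type() resolves to "mlx", and that a reference clip voice.wav with a matching transcript exists; all file names and strings below are illustrative:

    import asyncio
    from ttscli.backends import get_tts_backend, TTSBackend
    from ttscli.audio import save_audio

    async def demo():
        backend = get_tts_backend()
        # runtime_checkable protocols verify method presence only, not signatures.
        assert isinstance(backend, TTSBackend)
        await backend.load_model("0.6B")
        prompt, was_cached = await backend.create_voice_prompt(
            "voice.wav", "A short transcript of the clip."
        )
        audio, sr = await backend.generate("Hello from ttscli.", prompt)
        save_audio(audio, "out.wav", sample_rate=sr)

    asyncio.run(demo())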
ttscli/backends/mlx.py ADDED
@@ -0,0 +1,350 @@
+ """MLX backend for TTS using mlx-audio (optimized for Apple Silicon)."""
+
+ from typing import Optional, List, Tuple
+ import asyncio
+ import hashlib
+ import io
+ import os
+ import sys
+ import time
+ import warnings
+ import numpy as np
+ from pathlib import Path
+
+
+ # Simple in-memory voice prompt cache
+ _prompt_cache: dict[str, dict] = {}
+
+ # Default streaming interval (seconds per chunk)
+ STREAMING_INTERVAL = 2.0
+
+
+ def _cache_key(audio_path: str, text: str) -> str:
+     with open(audio_path, "rb") as f:
+         return hashlib.md5(f.read() + text.encode()).hexdigest()
+
+
+ def _suppress_library_noise():
+     """Suppress noisy warnings from transformers/tokenizers/mlx_audio."""
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+     warnings.filterwarnings("ignore", message=".*incorrect regex pattern.*")
+     warnings.filterwarnings("ignore", message=".*model of type.*to instantiate.*")
+     warnings.filterwarnings("ignore", message=".*not supported for all configurations.*")
+     # Suppress transformers logging
+     try:
+         import logging
+         logging.getLogger("transformers").setLevel(logging.ERROR)
+         logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
+     except Exception:
+         pass
+
+
+ class _QuietOutput:
+     """Context manager to suppress stdout/stderr including C-level output."""
+     def __enter__(self):
+         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+         # Save Python-level streams
+         self._old_out = sys.stdout
+         self._old_err = sys.stderr
+         sys.stdout = io.StringIO()
+         sys.stderr = io.StringIO()
+         # Save and redirect OS-level file descriptors
+         try:
+             self._devnull = open(os.devnull, "w")
+             self._orig_fd_out = os.dup(1)
+             self._orig_fd_err = os.dup(2)
+             os.dup2(self._devnull.fileno(), 1)
+             os.dup2(self._devnull.fileno(), 2)
+             self._fd_redirected = True
+         except Exception:
+             self._fd_redirected = False
+         return self
+
+     def __exit__(self, *args):
+         # Restore OS-level file descriptors
+         if self._fd_redirected:
+             os.dup2(self._orig_fd_out, 1)
+             os.dup2(self._orig_fd_err, 2)
+             os.close(self._orig_fd_out)
+             os.close(self._orig_fd_err)
+             self._devnull.close()
+         # Restore Python-level streams
+         sys.stdout = self._old_out
+         sys.stderr = self._old_err
+         os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
+
+
+ class MLXTTSBackend:
+     """MLX-based TTS backend using mlx-audio (Apple Silicon accelerated)."""
+
+     # Set to True for verbose output (e.g. during `tts init`)
+     verbose = False
+
+     def __init__(self, model_size: str = "0.6B"):
+         self.model = None
+         self.model_size = model_size
+         self._current_model_size = None
+         self._warmed_up = False
+
+     def is_loaded(self) -> bool:
+         return self.model is not None
+
+     def _get_model_path(self, model_size: str) -> str:
+         """Get the MLX-community model path for a given size."""
+         models = {
+             "0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-4bit",
+             "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-8bit",
+         }
+         if model_size not in models:
+             return model_size
+         return models[model_size]
+
+     async def load_model_async(self, model_size: Optional[str] = None):
+         if model_size is None:
+             model_size = self.model_size
+         if self.model is not None and self._current_model_size == model_size:
+             return
+         if self.model is not None:
+             self.unload_model()
+         await asyncio.to_thread(self._load_model_sync, model_size)
+
+     load_model = load_model_async
+
+     def _load_model_sync(self, model_size: str):
+         import logging
+         import warnings as _w
+
+         _suppress_library_noise()
+         model_path = self._get_model_path(model_size)
+
+         # Silence load-time noise; catch_warnings restores the caller's filters
+         logging.disable(logging.CRITICAL)
+         try:
+             with _w.catch_warnings():
+                 _w.simplefilter("ignore")
+                 from mlx_audio.tts.utils import load_model
+                 with _QuietOutput():
+                     self.model = load_model(model_path)
+         finally:
+             logging.disable(logging.NOTSET)
+
+         self._current_model_size = model_size
+         self.model_size = model_size
+         self._warmed_up = False
+
+     def _warmup(self):
+         """Run a short generation to trigger MLX JIT compilation."""
+         if self._warmed_up or self.model is None:
+             return
+         try:
+             with _QuietOutput():
+                 for _ in self.model.generate(
+                     text="Hello.",
+                     stream=True,
+                     streaming_interval=STREAMING_INTERVAL,
+                     verbose=False,
+                     max_tokens=20,
+                 ):
+                     pass
+         except Exception:
+             pass
+         self._warmed_up = True
+
+     @property
+     def sample_rate(self) -> int:
+         return self.model.sample_rate if self.model else 24000
+
+     def unload_model(self):
+         if self.model is not None:
+             del self.model
+             self.model = None
+         self._current_model_size = None
+         self._warmed_up = False
+         try:
+             import mlx.core as mx
+             mx.clear_cache()
+         except Exception:
+             pass
+
+     async def create_voice_prompt(
+         self, audio_path: str, reference_text: str, use_cache: bool = True,
+     ) -> Tuple[dict, bool]:
+         await self.load_model_async(None)
+
+         if use_cache:
+             key = _cache_key(audio_path, reference_text)
+             if key in _prompt_cache:
+                 cached = _prompt_cache[key]
+                 ref = cached.get("ref_audio")
+                 if ref and Path(ref).exists():
+                     return cached, True
+
+         prompt = {"ref_audio": str(audio_path), "ref_text": reference_text}
+
+         if use_cache:
+             _prompt_cache[_cache_key(audio_path, reference_text)] = prompt
+
+         return prompt, False
+
+     async def combine_voice_prompts(
+         self,
+         audio_paths: List[str],
+         reference_texts: List[str],
+     ) -> Tuple[str, str]:  # passes the first path through; no array mixing
+         if not audio_paths:
+             raise ValueError("No audio paths provided")
+         combined_text = " ".join(reference_texts)
+         return audio_paths[0], combined_text
+
+     async def generate(
+         self, text: str, voice_prompt: dict,
+         language: str = "en", seed: Optional[int] = None, instruct: Optional[str] = None,
+     ) -> Tuple[np.ndarray, int]:
+         await self.load_model_async(None)
+
+         def _sync():
+             self._warmup()
+
+             audio_chunks = []
+             sr = self.sample_rate
+
+             if seed is not None:
+                 import mlx.core as mx
+                 np.random.seed(seed)
+                 mx.random.seed(seed)
+
+             ref_audio = voice_prompt.get("ref_audio")
+             ref_text = voice_prompt.get("ref_text", "")
+             if ref_audio and not Path(ref_audio).exists():
+                 ref_audio = None
+
+             gen_kwargs = dict(text=text, verbose=False, max_tokens=4096)
+             if ref_audio:
+                 gen_kwargs["ref_audio"] = ref_audio
+                 gen_kwargs["ref_text"] = ref_text
+             if instruct:
+                 gen_kwargs["instruct"] = instruct
+
+             for result in self.model.generate(**gen_kwargs):
+                 audio_chunks.append(np.array(result.audio))
+                 sr = result.sample_rate
+
+             if audio_chunks:
+                 return np.concatenate([c.astype(np.float32) for c in audio_chunks]), sr
+             return np.array([], dtype=np.float32), sr
+
+         return await asyncio.to_thread(_sync)
+
+     async def generate_stream(
+         self, text: str, voice_prompt: dict,
+         language: str = "en", seed: Optional[int] = None, instruct: Optional[str] = None,
+     ):
+         """Yield (chunk, sample_rate, is_final) as model generates."""
+         await self.load_model_async(None)
+
+         import queue, threading
+
+         q: queue.Queue = queue.Queue()
+         DONE = object()
+
+         def _produce():
+             self._warmup()
+
+             if seed is not None:
+                 import mlx.core as mx
+                 np.random.seed(seed)
+                 mx.random.seed(seed)
+
+             ref_audio = voice_prompt.get("ref_audio")
+             ref_text = voice_prompt.get("ref_text", "")
+             if ref_audio and not Path(ref_audio).exists():
+                 ref_audio = None
+
+             gen_kwargs = dict(
+                 text=text, stream=True,
+                 streaming_interval=STREAMING_INTERVAL,
+                 verbose=False, max_tokens=4096,
+             )
+             if ref_audio:
+                 gen_kwargs["ref_audio"] = ref_audio
+                 gen_kwargs["ref_text"] = ref_text
+             if instruct:
+                 gen_kwargs["instruct"] = instruct
+
+             try:
+                 for result in self.model.generate(**gen_kwargs):
+                     audio = np.asarray(result.audio, dtype=np.float32)
+                     if len(audio) > 0:
+                         q.put((audio, result.sample_rate))
+             except Exception:
+                 pass
+             q.put(DONE)
+
+         t = threading.Thread(target=_produce, daemon=True)
+         t.start()
+
+         while True:
+             while q.empty():
+                 await asyncio.sleep(0.01)
+             item = q.get()
+             if item is DONE:
+                 break
+             chunk, sr = item
+             is_final = False
+             try:
+                 is_final = not q.empty() and q.queue[0] is DONE
+             except (IndexError, AttributeError):
+                 pass
+             yield chunk, sr, is_final
+
+         t.join(timeout=5.0)
+
+
+ class MLXSTTBackend:
+     """MLX-based STT backend using mlx-audio Whisper."""
+
+     def __init__(self, model_size: str = "base"):
+         self.model = None
+         self.model_size = model_size
+
+     def is_loaded(self) -> bool:
+         return self.model is not None
+
+     async def load_model_async(self, model_size: Optional[str] = None):
+         if model_size is None:
+             model_size = self.model_size
+         if self.model is not None and self.model_size == model_size:
+             return
+         await asyncio.to_thread(self._load_sync, model_size)
+
+     load_model = load_model_async
+
+     def _load_sync(self, model_size: str):
+         _suppress_library_noise()
+         from mlx_audio.stt import load
+         with _QuietOutput():
+             self.model = load(f"openai/whisper-{model_size}")
+         self.model_size = model_size
+
+     def unload_model(self):
+         if self.model is not None:
+             del self.model
+             self.model = None
+
+     async def transcribe(self, audio_path: str, language: Optional[str] = None) -> str:
+         await self.load_model_async(None)
+
+         def _sync():
+             opts = {}
+             if language:
+                 opts["language"] = language
+             result = self.model.generate(str(audio_path), **opts)
+             if isinstance(result, str):
+                 return result.strip()
+             if isinstance(result, dict):
+                 return result.get("text", "").strip()
+             if hasattr(result, "text"):
+                 return result.text.strip()
+             return str(result).strip()
+
+         return await asyncio.to_thread(_sync)
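
A minimal consumption sketch for the streaming path, assuming Apple Silicon with mlx-audio installed; the clip path, transcript, and output file name are illustrative:

    import asyncio
    import numpy as np
    from ttscli.backends.mlx import MLXTTSBackend
    from ttscli.audio import save_audio

    async def demo():
        backend = MLXTTSBackend(model_size="0.6B")
        prompt, _ = await backend.create_voice_prompt(
            "voice.wav", "A short transcript of the clip."
        )
        chunks, sr = [], backend.sample_rate
        # generate_stream bridges the model's blocking generator through a
        # queue-fed worker thread, so chunks arrive without stalling the loop.
        async for chunk, sr, _is_final in backend.generate_stream("Hello there.", prompt):
            chunks.append(chunk)
        if chunks:
            save_audio(np.concatenate(chunks), "streamed.wav", sample_rate=sr)
        backend.unload_model()

    asyncio.run(demo())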