ttscli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ttscli/__init__.py +3 -0
- ttscli/__main__.py +9 -0
- ttscli/audio.py +8 -0
- ttscli/backends/__init__.py +182 -0
- ttscli/backends/mlx.py +350 -0
- ttscli/backends/pytorch.py +234 -0
- ttscli/cli.py +139 -0
- ttscli/commands.py +787 -0
- ttscli/config.py +72 -0
- ttscli/output_format.py +25 -0
- ttscli/platform.py +57 -0
- ttscli/storage.py +20 -0
- ttscli/voices.py +241 -0
- ttscli-0.1.0.dist-info/METADATA +230 -0
- ttscli-0.1.0.dist-info/RECORD +18 -0
- ttscli-0.1.0.dist-info/WHEEL +4 -0
- ttscli-0.1.0.dist-info/entry_points.txt +2 -0
- ttscli-0.1.0.dist-info/licenses/LICENSE +21 -0
ttscli/__init__.py
ADDED
ttscli/__main__.py
ADDED
ttscli/audio.py
ADDED
ttscli/backends/__init__.py
ADDED
@@ -0,0 +1,182 @@
+"""
+Backend abstraction layer for TTS and STT.
+
+Provides a unified interface for MLX and PyTorch backends.
+"""
+
+from typing import Protocol, Optional, Tuple, List
+from typing_extensions import runtime_checkable
+import numpy as np
+
+from ..platform import get_backend_type
+
+
+@runtime_checkable
+class TTSBackend(Protocol):
+    """Protocol for TTS backend implementations."""
+
+    async def load_model(self, model_size: str) -> None:
+        """Load TTS model."""
+        ...
+
+    async def create_voice_prompt(
+        self,
+        audio_path: str,
+        reference_text: str,
+        use_cache: bool = True,
+    ) -> Tuple[dict, bool]:
+        """
+        Create voice prompt from reference audio.
+
+        Returns:
+            Tuple of (voice_prompt_dict, was_cached)
+        """
+        ...
+
+    async def combine_voice_prompts(
+        self,
+        audio_paths: List[str],
+        reference_texts: List[str],
+    ) -> Tuple[np.ndarray, str]:
+        """
+        Combine multiple voice prompts.
+
+        Returns:
+            Tuple of (combined_audio_array, combined_text)
+        """
+        ...
+
+    async def generate(
+        self,
+        text: str,
+        voice_prompt: dict,
+        language: str = "en",
+        seed: Optional[int] = None,
+        instruct: Optional[str] = None,
+    ) -> Tuple[np.ndarray, int]:
+        """
+        Generate audio from text.
+
+        Returns:
+            Tuple of (audio_array, sample_rate)
+        """
+        ...
+
+    async def generate_stream(
+        self,
+        text: str,
+        voice_prompt: dict,
+        language: str = "en",
+        seed: Optional[int] = None,
+        instruct: Optional[str] = None,
+    ):
+        """
+        Generate audio from text, yielding chunks as they become available.
+
+        Yields:
+            Tuple of (audio_chunk: np.ndarray, sample_rate: int, is_final: bool)
+        """
+        ...
+
+    def unload_model(self) -> None:
+        """Unload model to free memory."""
+        ...
+
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        ...
+
+    def _get_model_path(self, model_size: str) -> str:
+        """
+        Get model path for a given size.
+
+        Returns:
+            Model path or HuggingFace Hub ID
+        """
+        ...
+
+
+@runtime_checkable
+class STTBackend(Protocol):
+    """Protocol for STT (Speech-to-Text) backend implementations."""
+
+    async def load_model(self, model_size: str) -> None:
+        """Load STT model."""
+        ...
+
+    async def transcribe(
+        self,
+        audio_path: str,
+        language: Optional[str] = None,
+    ) -> str:
+        """
+        Transcribe audio to text.
+
+        Returns:
+            Transcribed text
+        """
+        ...
+
+    def unload_model(self) -> None:
+        """Unload model to free memory."""
+        ...
+
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        ...
+
+
+# Global backend instances
+_tts_backend: Optional[TTSBackend] = None
+_stt_backend: Optional[STTBackend] = None
+
+
+def get_tts_backend() -> TTSBackend:
+    """
+    Get or create TTS backend instance based on platform.
+
+    Returns:
+        TTS backend instance (MLX or PyTorch)
+    """
+    global _tts_backend
+
+    if _tts_backend is None:
+        backend_type = get_backend_type()
+
+        if backend_type == "mlx":
+            from .mlx import MLXTTSBackend
+            _tts_backend = MLXTTSBackend()
+        else:
+            from .pytorch import PyTorchTTSBackend
+            _tts_backend = PyTorchTTSBackend()
+
+    return _tts_backend
+
+
+def get_stt_backend() -> STTBackend:
+    """
+    Get or create STT backend instance based on platform.
+
+    Returns:
+        STT backend instance (MLX or PyTorch)
+    """
+    global _stt_backend
+
+    if _stt_backend is None:
+        backend_type = get_backend_type()
+
+        if backend_type == "mlx":
+            from .mlx import MLXSTTBackend
+            _stt_backend = MLXSTTBackend()
+        else:
+            from .pytorch import PyTorchSTTBackend
+            _stt_backend = PyTorchSTTBackend()
+
+    return _stt_backend
+
+
+def reset_backends():
+    """Reset backend instances (useful for testing)."""
+    global _tts_backend, _stt_backend
+    _tts_backend = None
+    _stt_backend = None
ttscli/backends/mlx.py
ADDED
@@ -0,0 +1,350 @@
+"""MLX backend for TTS using mlx-audio (optimized for Apple Silicon)."""
+
+from typing import Optional, List, Tuple
+import asyncio
+import hashlib
+import io
+import os
+import sys
+import time
+import warnings
+import numpy as np
+from pathlib import Path
+
+
+# Simple in-memory voice prompt cache
+_prompt_cache: dict[str, dict] = {}
+
+# Default streaming interval (seconds per chunk)
+STREAMING_INTERVAL = 2.0
+
+
+def _cache_key(audio_path: str, text: str) -> str:
+    with open(audio_path, "rb") as f:
+        return hashlib.md5(f.read() + text.encode()).hexdigest()
+
+
+def _suppress_library_noise():
+    """Suppress noisy warnings from transformers/tokenizers/mlx_audio."""
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    warnings.filterwarnings("ignore", message=".*incorrect regex pattern.*")
+    warnings.filterwarnings("ignore", message=".*model of type.*to instantiate.*")
+    warnings.filterwarnings("ignore", message=".*not supported for all configurations.*")
+    # Suppress transformers logging
+    try:
+        import logging
+        logging.getLogger("transformers").setLevel(logging.ERROR)
+        logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
+    except Exception:
+        pass
+
+
+class _QuietOutput:
+    """Context manager to suppress stdout/stderr including C-level output."""
+    def __enter__(self):
+        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+        # Save Python-level streams
+        self._old_out = sys.stdout
+        self._old_err = sys.stderr
+        sys.stdout = io.StringIO()
+        sys.stderr = io.StringIO()
+        # Save and redirect OS-level file descriptors
+        try:
+            self._devnull = open(os.devnull, "w")
+            self._orig_fd_out = os.dup(1)
+            self._orig_fd_err = os.dup(2)
+            os.dup2(self._devnull.fileno(), 1)
+            os.dup2(self._devnull.fileno(), 2)
+            self._fd_redirected = True
+        except Exception:
+            self._fd_redirected = False
+        return self
+
+    def __exit__(self, *args):
+        # Restore OS-level file descriptors
+        if self._fd_redirected:
+            os.dup2(self._orig_fd_out, 1)
+            os.dup2(self._orig_fd_err, 2)
+            os.close(self._orig_fd_out)
+            os.close(self._orig_fd_err)
+            self._devnull.close()
+        # Restore Python-level streams
+        sys.stdout = self._old_out
+        sys.stderr = self._old_err
+        os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
+
+
+class MLXTTSBackend:
+    """MLX-based TTS backend using mlx-audio (Apple Silicon accelerated)."""
+
+    # Set to True for verbose output (e.g. during `tts init`)
+    verbose = False
+
+    def __init__(self, model_size: str = "0.6B"):
+        self.model = None
+        self.model_size = model_size
+        self._current_model_size = None
+        self._warmed_up = False
+
+    def is_loaded(self) -> bool:
+        return self.model is not None
+
+    def _get_model_path(self, model_size: str) -> str:
+        """Get the MLX-community model path for a given size."""
+        models = {
+            "0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-4bit",
+            "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-8bit",
+        }
+        if model_size not in models:
+            return model_size
+        return models[model_size]
+
+    async def load_model_async(self, model_size: Optional[str] = None):
+        if model_size is None:
+            model_size = self.model_size
+        if self.model is not None and self._current_model_size == model_size:
+            return
+        if self.model is not None:
+            self.unload_model()
+        await asyncio.to_thread(self._load_model_sync, model_size)
+
+    load_model = load_model_async
+
+    def _load_model_sync(self, model_size: str):
+        import logging
+        import warnings as _w
+
+        _suppress_library_noise()
+        model_path = self._get_model_path(model_size)
+
+        # Silence everything during model load
+        _w.filterwarnings("ignore")
+        logging.disable(logging.CRITICAL)
+        try:
+            from mlx_audio.tts.utils import load_model
+            with _QuietOutput():
+                self.model = load_model(model_path)
+        finally:
+            logging.disable(logging.NOTSET)
+            _w.resetwarnings()
+
+        self._current_model_size = model_size
+        self.model_size = model_size
+        self._warmed_up = False
+
+    def _warmup(self):
+        """Run a short generation to trigger MLX JIT compilation."""
+        if self._warmed_up or self.model is None:
+            return
+        try:
+            with _QuietOutput():
+                for _ in self.model.generate(
+                    text="Hello.",
+                    stream=True,
+                    streaming_interval=STREAMING_INTERVAL,
+                    verbose=False,
+                    max_tokens=20,
+                ):
+                    pass
+        except Exception:
+            pass
+        self._warmed_up = True
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.sample_rate if self.model else 24000
+
+    def unload_model(self):
+        if self.model is not None:
+            del self.model
+            self.model = None
+        self._current_model_size = None
+        self._warmed_up = False
+        try:
+            import mlx.core as mx
+            mx.clear_cache()
+        except Exception:
+            pass
+
+    async def create_voice_prompt(
+        self, audio_path: str, reference_text: str, use_cache: bool = True,
+    ) -> Tuple[dict, bool]:
+        await self.load_model_async(None)
+
+        if use_cache:
+            key = _cache_key(audio_path, reference_text)
+            if key in _prompt_cache:
+                cached = _prompt_cache[key]
+                ref = cached.get("ref_audio")
+                if ref and Path(ref).exists():
+                    return cached, True
+
+        prompt = {"ref_audio": str(audio_path), "ref_text": reference_text}
+
+        if use_cache:
+            _prompt_cache[_cache_key(audio_path, reference_text)] = prompt
+
+        return prompt, False
+
+    async def combine_voice_prompts(
+        self,
+        audio_paths: List[str],
+        reference_texts: List[str],
+    ) -> Tuple[np.ndarray, str]:
+        if not audio_paths:
+            raise ValueError("No audio paths provided")
+        combined_text = " ".join(reference_texts)
+        return audio_paths[0], combined_text
+
+    async def generate(
+        self, text: str, voice_prompt: dict,
+        language: str = "en", seed: Optional[int] = None, instruct: Optional[str] = None,
+    ) -> Tuple[np.ndarray, int]:
+        await self.load_model_async(None)
+
+        def _sync():
+            self._warmup()
+
+            audio_chunks = []
+            sr = self.sample_rate
+
+            if seed is not None:
+                import mlx.core as mx
+                np.random.seed(seed)
+                mx.random.seed(seed)
+
+            ref_audio = voice_prompt.get("ref_audio")
+            ref_text = voice_prompt.get("ref_text", "")
+            if ref_audio and not Path(ref_audio).exists():
+                ref_audio = None
+
+            gen_kwargs = dict(text=text, verbose=False, max_tokens=4096)
+            if ref_audio:
+                gen_kwargs["ref_audio"] = ref_audio
+                gen_kwargs["ref_text"] = ref_text
+            if instruct:
+                gen_kwargs["instruct"] = instruct
+
+            for result in self.model.generate(**gen_kwargs):
+                audio_chunks.append(np.array(result.audio))
+                sr = result.sample_rate
+
+            if audio_chunks:
+                return np.concatenate([c.astype(np.float32) for c in audio_chunks]), sr
+            return np.array([], dtype=np.float32), sr
+
+        return await asyncio.to_thread(_sync)
+
+    async def generate_stream(
+        self, text: str, voice_prompt: dict,
+        language: str = "en", seed: Optional[int] = None, instruct: Optional[str] = None,
+    ):
+        """Yield (chunk, sample_rate, is_final) as model generates."""
+        await self.load_model_async(None)
+
+        import queue, threading
+
+        q: queue.Queue = queue.Queue()
+        DONE = object()
+
+        def _produce():
+            self._warmup()
+
+            if seed is not None:
+                import mlx.core as mx
+                np.random.seed(seed)
+                mx.random.seed(seed)
+
+            ref_audio = voice_prompt.get("ref_audio")
+            ref_text = voice_prompt.get("ref_text", "")
+            if ref_audio and not Path(ref_audio).exists():
+                ref_audio = None
+
+            gen_kwargs = dict(
+                text=text, stream=True,
+                streaming_interval=STREAMING_INTERVAL,
+                verbose=False, max_tokens=4096,
+            )
+            if ref_audio:
+                gen_kwargs["ref_audio"] = ref_audio
+                gen_kwargs["ref_text"] = ref_text
+            if instruct:
+                gen_kwargs["instruct"] = instruct
+
+            try:
+                for result in self.model.generate(**gen_kwargs):
+                    audio = np.asarray(result.audio, dtype=np.float32)
+                    if len(audio) > 0:
+                        q.put((audio, result.sample_rate))
+            except Exception:
+                pass
+            q.put(DONE)
+
+        t = threading.Thread(target=_produce, daemon=True)
+        t.start()
+
+        while True:
+            while q.empty():
+                await asyncio.sleep(0.01)
+            item = q.get()
+            if item is DONE:
+                break
+            chunk, sr = item
+            is_final = False
+            try:
+                is_final = not q.empty() and q.queue[0] is DONE
+            except (IndexError, AttributeError):
+                pass
+            yield chunk, sr, is_final
+
+        t.join(timeout=5.0)
+
+
+class MLXSTTBackend:
+    """MLX-based STT backend using mlx-audio Whisper."""
+
+    def __init__(self, model_size: str = "base"):
+        self.model = None
+        self.model_size = model_size
+
+    def is_loaded(self) -> bool:
+        return self.model is not None
+
+    async def load_model_async(self, model_size: Optional[str] = None):
+        if model_size is None:
+            model_size = self.model_size
+        if self.model is not None and self.model_size == model_size:
+            return
+        await asyncio.to_thread(self._load_sync, model_size)
+
+    load_model = load_model_async
+
+    def _load_sync(self, model_size: str):
+        _suppress_library_noise()
+        from mlx_audio.stt import load
+        with _QuietOutput():
+            self.model = load(f"openai/whisper-{model_size}")
+        self.model_size = model_size
+
+    def unload_model(self):
+        if self.model is not None:
+            del self.model
+            self.model = None
+
+    async def transcribe(self, audio_path: str, language: Optional[str] = None) -> str:
+        await self.load_model_async(None)
+
+        def _sync():
+            opts = {}
+            if language:
+                opts["language"] = language
+            result = self.model.generate(str(audio_path), **opts)
+            if isinstance(result, str):
+                return result.strip()
+            if isinstance(result, dict):
+                return result.get("text", "").strip()
+            if hasattr(result, "text"):
+                return result.text.strip()
+            return str(result).strip()
+
+        return await asyncio.to_thread(_sync)
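In this module, `generate_stream` is an async generator that yields `(chunk, sample_rate, is_final)` tuples roughly every `STREAMING_INTERVAL` seconds while a background thread drives the model. A hedged sketch of consuming it follows; it is not from the package, and the reference clip path and soundfile dependency are assumptions for illustration.

# Hedged example: consuming MLXTTSBackend.generate_stream chunk by chunk.
# "reference.wav" and soundfile are placeholders/assumptions, not ttscli APIs.
import asyncio

import numpy as np
import soundfile as sf  # assumed helper for writing the final WAV

from ttscli.backends.mlx import MLXTTSBackend


async def stream_to_file(text: str, out_path: str = "stream.wav") -> None:
    backend = MLXTTSBackend(model_size="0.6B")
    # create_voice_prompt loads the model on first use
    prompt, _cached = await backend.create_voice_prompt(
        "reference.wav", "Text spoken in the reference clip.",
    )

    chunks: list[np.ndarray] = []
    sample_rate = backend.sample_rate
    async for chunk, sr, is_final in backend.generate_stream(text, prompt):
        chunks.append(chunk)  # each chunk covers roughly STREAMING_INTERVAL seconds
        sample_rate = sr      # is_final flags the chunk the producer believes is last

    if chunks:
        sf.write(out_path, np.concatenate(chunks), sample_rate)


if __name__ == "__main__":
    asyncio.run(stream_to_file("Streaming synthesis with ttscli."))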