openspeechapi 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openspeech/__init__.py +75 -0
- openspeech/__main__.py +5 -0
- openspeech/cli.py +413 -0
- openspeech/client/__init__.py +4 -0
- openspeech/client/client.py +145 -0
- openspeech/config.py +212 -0
- openspeech/core/__init__.py +0 -0
- openspeech/core/base.py +75 -0
- openspeech/core/enums.py +39 -0
- openspeech/core/models.py +61 -0
- openspeech/core/registry.py +37 -0
- openspeech/core/settings.py +8 -0
- openspeech/demo.py +675 -0
- openspeech/dispatch/__init__.py +0 -0
- openspeech/dispatch/context.py +34 -0
- openspeech/dispatch/dispatcher.py +661 -0
- openspeech/dispatch/executors/__init__.py +0 -0
- openspeech/dispatch/executors/base.py +34 -0
- openspeech/dispatch/executors/in_process.py +66 -0
- openspeech/dispatch/executors/remote.py +64 -0
- openspeech/dispatch/executors/subprocess_exec.py +446 -0
- openspeech/dispatch/fanout.py +95 -0
- openspeech/dispatch/filters.py +73 -0
- openspeech/dispatch/lifecycle.py +178 -0
- openspeech/dispatch/watcher.py +82 -0
- openspeech/engine_catalog.py +236 -0
- openspeech/engine_registry.yaml +347 -0
- openspeech/exceptions.py +51 -0
- openspeech/factory.py +325 -0
- openspeech/local_engines/__init__.py +12 -0
- openspeech/local_engines/aim_resolver.py +91 -0
- openspeech/local_engines/backends/__init__.py +1 -0
- openspeech/local_engines/backends/docker_backend.py +490 -0
- openspeech/local_engines/backends/native_backend.py +902 -0
- openspeech/local_engines/base.py +30 -0
- openspeech/local_engines/engines/__init__.py +1 -0
- openspeech/local_engines/engines/faster_whisper.py +36 -0
- openspeech/local_engines/engines/fish_speech.py +33 -0
- openspeech/local_engines/engines/sherpa_onnx.py +56 -0
- openspeech/local_engines/engines/whisper.py +41 -0
- openspeech/local_engines/engines/whisperlivekit.py +60 -0
- openspeech/local_engines/manager.py +208 -0
- openspeech/local_engines/models.py +50 -0
- openspeech/local_engines/progress.py +69 -0
- openspeech/local_engines/registry.py +19 -0
- openspeech/local_engines/task_store.py +52 -0
- openspeech/local_engines/tasks.py +71 -0
- openspeech/logging_config.py +607 -0
- openspeech/observe/__init__.py +0 -0
- openspeech/observe/base.py +79 -0
- openspeech/observe/debug.py +44 -0
- openspeech/observe/latency.py +19 -0
- openspeech/observe/metrics.py +47 -0
- openspeech/observe/tracing.py +44 -0
- openspeech/observe/usage.py +27 -0
- openspeech/providers/__init__.py +0 -0
- openspeech/providers/_template.py +101 -0
- openspeech/providers/stt/__init__.py +0 -0
- openspeech/providers/stt/alibaba.py +86 -0
- openspeech/providers/stt/assemblyai.py +135 -0
- openspeech/providers/stt/azure_speech.py +99 -0
- openspeech/providers/stt/baidu.py +135 -0
- openspeech/providers/stt/deepgram.py +311 -0
- openspeech/providers/stt/elevenlabs.py +385 -0
- openspeech/providers/stt/faster_whisper.py +211 -0
- openspeech/providers/stt/google_cloud.py +106 -0
- openspeech/providers/stt/iflytek.py +427 -0
- openspeech/providers/stt/macos_speech.py +226 -0
- openspeech/providers/stt/openai.py +84 -0
- openspeech/providers/stt/sherpa_onnx.py +353 -0
- openspeech/providers/stt/tencent.py +212 -0
- openspeech/providers/stt/volcengine.py +107 -0
- openspeech/providers/stt/whisper.py +153 -0
- openspeech/providers/stt/whisperlivekit.py +530 -0
- openspeech/providers/stt/windows_speech.py +249 -0
- openspeech/providers/tts/__init__.py +0 -0
- openspeech/providers/tts/alibaba.py +95 -0
- openspeech/providers/tts/azure_speech.py +123 -0
- openspeech/providers/tts/baidu.py +143 -0
- openspeech/providers/tts/coqui.py +64 -0
- openspeech/providers/tts/cosyvoice.py +90 -0
- openspeech/providers/tts/deepgram.py +174 -0
- openspeech/providers/tts/elevenlabs.py +311 -0
- openspeech/providers/tts/fish_speech.py +158 -0
- openspeech/providers/tts/google_cloud.py +107 -0
- openspeech/providers/tts/iflytek.py +209 -0
- openspeech/providers/tts/macos_say.py +251 -0
- openspeech/providers/tts/minimax.py +122 -0
- openspeech/providers/tts/openai.py +104 -0
- openspeech/providers/tts/piper.py +104 -0
- openspeech/providers/tts/tencent.py +189 -0
- openspeech/providers/tts/volcengine.py +117 -0
- openspeech/providers/tts/windows_sapi.py +234 -0
- openspeech/server/__init__.py +1 -0
- openspeech/server/app.py +72 -0
- openspeech/server/auth.py +42 -0
- openspeech/server/middleware.py +75 -0
- openspeech/server/routes/__init__.py +1 -0
- openspeech/server/routes/management.py +848 -0
- openspeech/server/routes/stt.py +121 -0
- openspeech/server/routes/tts.py +159 -0
- openspeech/server/routes/webui.py +29 -0
- openspeech/server/webui/app.js +2649 -0
- openspeech/server/webui/index.html +216 -0
- openspeech/server/webui/styles.css +617 -0
- openspeech/server/ws/__init__.py +1 -0
- openspeech/server/ws/stt_stream.py +263 -0
- openspeech/server/ws/tts_stream.py +207 -0
- openspeech/telemetry/__init__.py +21 -0
- openspeech/telemetry/perf.py +307 -0
- openspeech/utils/__init__.py +5 -0
- openspeech/utils/audio_converter.py +406 -0
- openspeech/utils/audio_playback.py +156 -0
- openspeech/vendor_registry.yaml +74 -0
- openspeechapi-0.1.0.dist-info/METADATA +101 -0
- openspeechapi-0.1.0.dist-info/RECORD +118 -0
- openspeechapi-0.1.0.dist-info/WHEEL +4 -0
- openspeechapi-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""ElevenLabs STT provider adapter (batch + realtime WS)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import base64
|
|
6
|
+
import inspect
|
|
7
|
+
import json
|
|
8
|
+
from urllib.parse import urlencode
|
|
9
|
+
from collections.abc import AsyncIterator
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import httpx
|
|
14
|
+
|
|
15
|
+
from openspeech.core.base import STTProvider
|
|
16
|
+
from openspeech.core.enums import AudioFormat, Capability, ExecMode, ProviderType
|
|
17
|
+
from openspeech.core.models import AudioData, STTOptions, Transcription, Word
|
|
18
|
+
from openspeech.core.settings import BaseSettings
|
|
19
|
+
from openspeech.logging_config import logger
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _ws_connect_with_headers(websockets_mod, url: str, headers: dict[str, str]):
    """Call ``websockets_mod.connect`` with *headers* under the right keyword.

    websockets>=15 uses ``additional_headers``; older versions use
    ``extra_headers``. The keyword is chosen by inspecting the signature of
    ``websockets_mod.connect``; if the signature cannot be inspected, the
    legacy ``extra_headers`` keyword is used.

    Only the signature probe is guarded. Errors raised by ``connect`` itself
    propagate unchanged instead of being masked by a retry with the other
    keyword.

    Returns:
        Whatever ``websockets_mod.connect`` returns (the caller uses it as an
        async context manager).
    """
    try:
        params = inspect.signature(websockets_mod.connect).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these for callables it cannot introspect.
        params = {}
    if "additional_headers" in params:
        return websockets_mod.connect(url, additional_headers=headers)
    return websockets_mod.connect(url, extra_headers=headers)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class ElevenLabsSTTSettings(BaseSettings):
    """User-configurable options for :class:`ElevenLabsSTT`.

    Fields without a ``realtime_`` prefix feed the batch ``transcribe`` form.
    ``api_key`` and ``language_code`` are also used by streaming. The
    ``realtime_*`` fields become query parameters of the realtime WebSocket
    URL. Field order defines the positional dataclass constructor, so keep it
    stable.
    """

    # Sent as the ``xi-api-key`` header on batch and realtime requests.
    api_key: str = ""
    # Batch model, sent as form field ``model_id``.
    model_id: str = "scribe_v2"
    # Realtime model, sent as query parameter ``model_id``.
    realtime_model_id: str = "scribe_v2_realtime"
    # Omitted from requests when empty; in batch mode ``opts.language`` takes precedence.
    language_code: str = ""
    # Batch form switch, serialised as "true"/"false".
    diarize: bool = False
    # Batch form switch, serialised as "true"/"false".
    tag_audio_events: bool = True
    # Batch form value, sent only when not None.
    num_speakers: int | None = None
    # Batch form value, always sent.
    timestamps_granularity: str = "word"
    # Batch form value, sent only when not None.
    diarization_threshold: float | None = None
    # Batch form switch, serialised as "true"/"false".
    use_multi_channel: bool = False
    # Batch form switch, serialised as "true"/"false".
    no_verbatim: bool = False
    # Batch form switch, serialised as "true"/"false".
    detect_speaker_roles: bool = False
    # Batch form value, sent only when not None.
    temperature: float | None = None
    # Batch form value, sent only when non-empty.
    entity_detection: str = ""
    # Realtime query parameter ``audio_format``.
    realtime_audio_format: str = "pcm_16000"
    # "manual" commits every sent chunk; otherwise server VAD decides, and a
    # final commit is sent once the input stream ends.
    realtime_commit_strategy: str = "vad"  # manual | vad
    # Realtime query parameter ``include_timestamps`` ("true"/"false").
    realtime_include_timestamps: bool = True
    # Realtime query parameter ``include_language_detection`` ("true"/"false").
    realtime_include_language_detection: bool = True
    # Realtime query parameter ``vad_silence_threshold_secs``.
    realtime_vad_silence_threshold_secs: float = 1.0
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ElevenLabsSTT(STTProvider):
    """ElevenLabs Scribe speech-to-text adapter.

    ``transcribe`` uploads a whole clip to the batch REST endpoint.
    ``transcribe_stream`` relays audio chunks over the realtime WebSocket
    endpoint and yields cumulative partial and committed transcripts.
    """

    name = "elevenlabs-stt"
    provider_type = ProviderType.STT
    execution_mode = ExecMode.REMOTE
    settings_cls = ElevenLabsSTTSettings
    capabilities = {Capability.BATCH, Capability.STREAMING, Capability.MULTILINGUAL}
    field_options = {
        "model_id": ["scribe_v2", "scribe_v1"],
        "realtime_model_id": ["scribe_v2_realtime"],
        "realtime_audio_format": ["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"],
        "realtime_commit_strategy": ["vad", "manual"],
        "timestamps_granularity": ["none", "word", "character"],
        "language_code": [
            "",
            "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "it", "hi",
            "id", "nl", "tr", "pl", "sv", "ar", "ru", "uk", "vi", "hu",
            "no", "da", "fi", "cs", "el", "ro", "bg", "hr", "sk", "ms",
            "ta", "fil",
        ],
        "diarize": [False, True],
        "tag_audio_events": [True, False],
        "no_verbatim": [False, True],
        "detect_speaker_roles": [False, True],
        "use_multi_channel": [False, True],
    }

    def __init__(self, settings: ElevenLabsSTTSettings | None = None) -> None:
        self.settings = settings or ElevenLabsSTTSettings()
        self._client: httpx.AsyncClient | None = None
        # False when the client was injected via set_http_client(); stop() then
        # leaves closing it to its owner.
        self._owns_client: bool = True

    def set_http_client(self, client) -> None:
        """Use a shared HTTP client instead of creating one in ``start``."""
        self._client = client
        self._owns_client = False

    async def start(self) -> None:
        """Create the HTTP client unless one was injected."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=60.0, trust_env=False)
            self._owns_client = True
        logger.info("{} provider started", self.name)

    async def stop(self) -> None:
        """Close the HTTP client if this provider created it."""
        if self._client and self._owns_client:
            await self._client.aclose()
            self._client = None
        logger.info("{} provider stopped", self.name)

    async def health_check(self) -> bool:
        """Report healthy when an API key is configured (no network call is made)."""
        return bool(self.settings.api_key)

    async def transcribe(
        self, audio: AudioData, opts: STTOptions | None = None
    ) -> Transcription:
        """Transcribe a complete clip with the batch REST API.

        Args:
            audio: Clip to upload. Header-less raw PCM (``PCM_16K``/``PCM_44K``)
                is wrapped in a WAV container first, because it is uploaded
                with an ``audio/wav`` MIME type.
            opts: Per-request options; ``opts.language`` overrides
                ``settings.language_code``.

        Returns:
            The transcript. ``confidence`` carries the API's
            ``language_probability`` when present, and ``words`` carries
            word-type timing entries (or None).

        Raises:
            RuntimeError: If the provider is not started, or if the API
                answers with a non-200 status.
        """
        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")

        opts = opts or STTOptions()
        language_code = (opts.language or self.settings.language_code or "").strip()

        files = {
            "file": (
                "audio",
                self._upload_bytes(audio),
                self._mime_for_format(audio.format),
            )
        }
        headers = {"xi-api-key": self.settings.api_key}

        resp = await self._client.post(
            "https://api.elevenlabs.io/v1/speech-to-text",
            headers=headers,
            data=self._build_batch_form(language_code),
            files=files,
        )
        if resp.status_code != 200:
            raise RuntimeError(
                f"ElevenLabs STT API error {resp.status_code}: {resp.text}"
            )

        payload = resp.json()
        probability = payload.get("language_probability")
        return Transcription(
            text=self._batch_text(payload),
            language=payload.get("language_code") or (language_code or None),
            confidence=float(probability or 0.0) if probability is not None else None,
            words=self._batch_words(payload),
        )

    def _build_batch_form(self, language_code: str) -> dict[str, str]:
        """Build the multipart form fields for the batch endpoint from settings."""
        s = self.settings

        def flag(value: bool) -> str:
            return "true" if value else "false"

        data: dict[str, str] = {
            "model_id": s.model_id,
            "diarize": flag(s.diarize),
            "tag_audio_events": flag(s.tag_audio_events),
            "timestamps_granularity": s.timestamps_granularity,
            "no_verbatim": flag(s.no_verbatim),
            "detect_speaker_roles": flag(s.detect_speaker_roles),
            "use_multi_channel": flag(s.use_multi_channel),
        }
        if language_code:
            data["language_code"] = language_code
        if s.num_speakers is not None:
            data["num_speakers"] = str(s.num_speakers)
        if s.diarization_threshold is not None:
            data["diarization_threshold"] = str(s.diarization_threshold)
        if s.temperature is not None:
            data["temperature"] = str(s.temperature)
        if s.entity_detection:
            data["entity_detection"] = s.entity_detection
        return data

    @staticmethod
    def _upload_bytes(audio: AudioData) -> bytes:
        """Return the bytes to upload, wrapping header-less PCM in a WAV container."""
        raw = audio.data
        if audio.format not in (AudioFormat.PCM_16K, AudioFormat.PCM_44K):
            return raw
        if raw[:4] == b"RIFF":
            return raw  # already a WAV file

        import io
        import wave

        default_rate = 16000 if audio.format == AudioFormat.PCM_16K else 44100
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(int(getattr(audio, "channels", 0) or 1))
            # Assumes 16-bit samples, as the faster-whisper adapter does — TODO confirm.
            wf.setsampwidth(2)
            wf.setframerate(int(getattr(audio, "sample_rate", 0) or default_rate))
            wf.writeframes(raw)
        return buf.getvalue()

    @staticmethod
    def _batch_text(payload: dict) -> str:
        """Return top-level ``text``, or the non-empty per-channel texts joined by newlines."""
        text = str(payload.get("text", "") or "")
        transcripts = payload.get("transcripts")
        if text or not isinstance(transcripts, list):
            return text
        channel_texts = (
            str((item or {}).get("text", "") or "").strip() for item in transcripts
        )
        return "\n".join(t for t in channel_texts if t)

    @classmethod
    def _batch_words(cls, payload: dict) -> list[Word] | None:
        """Collect word timings from the top-level and per-channel ``words`` lists.

        Uses the same parser as the realtime path, so only ``type == "word"``
        entries are kept (spacing and audio-event entries are skipped).
        """
        sources: list[list] = []
        if isinstance(payload.get("words"), list):
            sources.append(payload["words"])
        transcripts = payload.get("transcripts")
        if isinstance(transcripts, list):
            for item in transcripts:
                channel_words = (item or {}).get("words")
                if isinstance(channel_words, list):
                    sources.append(channel_words)

        words: list[Word] = []
        for source in sources:
            words.extend(cls._parse_stream_words(source) or [])
        return words or None

    @staticmethod
    def _sample_rate_for_format(audio_format: str, default: int = 16000) -> int:
        """Return the rate encoded in a realtime format name such as ``pcm_22050``.

        Falls back to *default* when the name has no numeric ``_<rate>`` suffix.
        """
        _, sep, rate = (audio_format or "").rpartition("_")
        if sep and rate.isdigit():
            return int(rate)
        return default

    async def transcribe_stream(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[Any]:
        """Realtime STT over ElevenLabs WebSocket.

        Protocol summary (official docs):
        - Connect: ``wss://api.elevenlabs.io/v1/speech-to-text/realtime``
        - Send audio frames as ``input_audio_chunk`` with base64 payload
        - Receive ``partial_transcript`` and ``committed_transcript*`` events

        Yields cumulative :class:`Transcription` snapshots (all committed text
        plus the latest partial). The ``sample_rate`` sent with each chunk is
        derived from ``settings.realtime_audio_format``.

        The sender/receiver tasks are always cancelled and awaited, even when
        the consumer stops iterating early. On a normal finish, an exception
        raised by either task is re-raised (sender first).
        """
        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")

        import websockets

        settings = self.settings
        params = {
            "model_id": settings.realtime_model_id,
            "audio_format": settings.realtime_audio_format,
            "commit_strategy": settings.realtime_commit_strategy,
            "include_timestamps": "true" if settings.realtime_include_timestamps else "false",
            "include_language_detection": "true" if settings.realtime_include_language_detection else "false",
            "vad_silence_threshold_secs": str(settings.realtime_vad_silence_threshold_secs),
        }
        language_code = (settings.language_code or "").strip()
        if language_code:
            params["language_code"] = language_code

        url = "wss://api.elevenlabs.io/v1/speech-to-text/realtime?" + urlencode(params)
        headers = {"xi-api-key": settings.api_key}
        manual_commit = settings.realtime_commit_strategy == "manual"
        sample_rate = self._sample_rate_for_format(settings.realtime_audio_format)

        logger.debug("{}: connecting realtime WS...", self.name)
        frames_sent = 0
        sender_done = asyncio.Event()
        confirmed_parts: list[str] = []
        last_partial: str = ""
        results: asyncio.Queue[Transcription | None] = asyncio.Queue()

        async with _ws_connect_with_headers(websockets, url, headers) as ws:

            async def _sender() -> None:
                # Forward every non-empty chunk; flush with a final commit in VAD mode.
                nonlocal frames_sent
                try:
                    async for chunk in stream:
                        if not chunk:
                            continue
                        message = {
                            "message_type": "input_audio_chunk",
                            "audio_base_64": base64.b64encode(chunk).decode("ascii"),
                            "sample_rate": sample_rate,
                        }
                        # In manual mode, we explicitly commit each frame.
                        if manual_commit:
                            message["commit"] = True
                        await ws.send(json.dumps(message))
                        frames_sent += 1
                finally:
                    # Trigger a final commit after stream ends to flush any buffered text.
                    if not manual_commit:
                        try:
                            await ws.send(json.dumps({
                                "message_type": "input_audio_chunk",
                                "audio_base_64": "",
                                "sample_rate": sample_rate,
                                "commit": True,
                            }))
                        except Exception as exc:
                            # Best effort: the socket may already be closed.
                            logger.debug("{}: final commit not sent: {}", self.name, exc)
                    sender_done.set()

            async def _receiver() -> None:
                # Turn server events into cumulative transcripts on `results`.
                nonlocal last_partial
                try:
                    async for raw in ws:
                        data = json.loads(raw)
                        mtype = str(data.get("message_type", "") or "")
                        if mtype == "session_started":
                            continue

                        # Message types starting with "scribe", or equal to "error", are treated as failures.
                        if mtype.startswith("scribe") or mtype == "error":
                            detail = data.get("message") or data.get("detail") or data
                            raise RuntimeError(f"ElevenLabs realtime STT error: {detail}")

                        if mtype == "partial_transcript":
                            partial = str(data.get("text", "") or "").strip()
                            if not partial:
                                continue
                            last_partial = partial
                            snapshot = self._snapshot_text(confirmed_parts, last_partial)
                            if not snapshot:
                                continue
                            await results.put(Transcription(
                                text=snapshot,
                                language=data.get("language_code") or language_code or None,
                                confidence=data.get("confidence"),
                                is_partial=True,
                            ))
                            continue

                        if mtype in {"committed_transcript", "committed_transcript_with_timestamps"}:
                            committed = str(data.get("text", "") or "").strip()
                            # Skip an immediate duplicate of the last committed segment.
                            if committed and (not confirmed_parts or confirmed_parts[-1] != committed):
                                confirmed_parts.append(committed)
                            last_partial = ""
                            snapshot = self._snapshot_text(confirmed_parts, last_partial)
                            if not snapshot:
                                continue
                            words = None
                            if mtype == "committed_transcript_with_timestamps":
                                words = self._parse_stream_words(data.get("words") or [])
                            await results.put(Transcription(
                                text=snapshot,
                                language=data.get("language_code") or language_code or None,
                                confidence=data.get("confidence"),
                                words=words,
                                is_partial=False,
                            ))
                finally:
                    # Sentinel: wakes the consumer loop when the socket closes or errors.
                    await results.put(None)

            sender_task = asyncio.create_task(_sender())
            receiver_task = asyncio.create_task(_receiver())

            try:
                while True:
                    # After all audio is sent, stop once the server is quiet for 1.2 s.
                    timeout = 1.2 if sender_done.is_set() else None
                    try:
                        item = await asyncio.wait_for(results.get(), timeout=timeout)
                    except asyncio.TimeoutError:
                        break
                    if item is None:
                        break
                    yield item
            finally:
                # Runs on normal exit, on errors and on early aclose() alike, so
                # the background tasks never outlive the WebSocket.
                for task in (sender_task, receiver_task):
                    if not task.done():
                        task.cancel()
                outcomes = await asyncio.gather(
                    sender_task, receiver_task, return_exceptions=True
                )
                logger.debug("{}: realtime WS finished, sent {} frames", self.name, frames_sent)

            # Only reached on a normal finish: surface task failures (CancelledError
            # is a BaseException, so cancellations are ignored here).
            for outcome in outcomes:
                if isinstance(outcome, Exception):
                    raise outcome

    @staticmethod
    def _mime_for_format(fmt: AudioFormat) -> str:
        """Map an :class:`AudioFormat` to the MIME type used for the batch upload."""
        if fmt == AudioFormat.WAV:
            return "audio/wav"
        if fmt == AudioFormat.MP3:
            return "audio/mpeg"
        if fmt == AudioFormat.FLAC:
            return "audio/flac"
        if fmt == AudioFormat.OGG:
            return "audio/ogg"
        if fmt == AudioFormat.OPUS:
            return "audio/opus"
        if fmt in (AudioFormat.PCM_16K, AudioFormat.PCM_44K):
            # _upload_bytes adds a WAV header to raw PCM, so this label is accurate.
            return "audio/wav"
        return "application/octet-stream"

    @staticmethod
    def _snapshot_text(confirmed_parts: list[str], partial: str) -> str:
        """Join committed parts with the current partial into one cumulative text.

        If the partial already starts with the committed text, the partial is
        returned as-is to avoid duplicating it.
        """
        base = " ".join([p for p in confirmed_parts if p]).strip()
        p = (partial or "").strip()
        if not base:
            return p
        if not p:
            return base
        if p.startswith(base):
            return p
        return f"{base} {p}".strip()

    @staticmethod
    def _parse_stream_words(words_payload: list[dict]) -> list[Word] | None:
        """Convert ``words`` entries of ``type == "word"`` (default) into :class:`Word`.

        ``start``/``end`` are read as seconds and converted to milliseconds.
        Entries with blank text are skipped. Returns None when nothing remains.
        """
        words: list[Word] = []
        for item in words_payload:
            if str(item.get("type", "word")) != "word":
                continue
            txt = str(item.get("text", "") or "").strip()
            if not txt:
                continue
            start_ms = int(float(item.get("start", 0.0) or 0.0) * 1000)
            end_ms = int(float(item.get("end", 0.0) or 0.0) * 1000)
            words.append(Word(text=txt, start_ms=start_ms, end_ms=end_ms))
        return words or None
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
"""Faster-Whisper STT provider adapter (local, subprocess mode)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from openspeech.logging_config import logger
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
import tempfile
|
|
8
|
+
import time
|
|
9
|
+
from collections.abc import AsyncIterator
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from openspeech.core.base import STTProvider
|
|
14
|
+
|
|
15
|
+
from openspeech.core.enums import Capability, ExecMode, ProviderType
|
|
16
|
+
from openspeech.core.models import AudioData, STTOptions, Transcription, Word
|
|
17
|
+
from openspeech.core.settings import BaseSettings
|
|
18
|
+
|
|
19
|
+
@dataclass
class FasterWhisperSTTSettings(BaseSettings):
    """Defaults for :class:`FasterWhisperSTT`.

    Per-request ``STTOptions`` (model, device, compute type, beam size) may
    override them. Field order defines the positional dataclass constructor,
    so keep it stable.
    """

    # Model name such as "base" or "large-v3", or an absolute path to a local
    # model directory / Hugging Face cache folder.
    model_size: str = "base"
    # Forwarded to ``WhisperModel(device=...)``.
    device: str = "auto"
    # Forwarded to ``WhisperModel(compute_type=...)``.
    compute_type: str = "default"
    # Forwarded to ``WhisperModel(download_root=...)``; None defers to faster-whisper's default.
    download_root: str | None = None
    # Beam width used when a request leaves ``opts.beam_size`` unset.
    beam_size: int = 5
|
|
26
|
+
|
|
27
|
+
class FasterWhisperSTT(STTProvider):
    """Local Whisper transcription through the ``faster-whisper`` package.

    The ``WhisperModel`` is reloaded whenever a request asks for a different
    model, device or compute type. Model construction and inference are
    blocking calls, so they run in the default thread-pool executor. An
    asyncio lock serialises them, so one request never swaps the model out
    from under another and the event loop stays responsive.
    """

    name = "faster-whisper"
    provider_type = ProviderType.STT
    execution_mode = ExecMode.SUBPROCESS
    settings_cls = FasterWhisperSTTSettings
    capabilities = {
        Capability.BATCH,
        Capability.MULTILINGUAL,
        Capability.WORD_TIMESTAMPS,
    }
    field_options = {
        "device": ["auto", "cpu", "cuda", "mps"],
        "compute_type": ["default", "auto", "float16", "float32", "int8", "int8_float16"],
        "model_size": ["tiny", "base", "small", "medium", "large-v2", "large-v3", "large-v3-turbo"],
    }

    def __init__(self, settings: FasterWhisperSTTSettings | None = None) -> None:
        self.settings = settings or FasterWhisperSTTSettings()
        self._client: Any = None  # mirrors _model; health_check() inspects it
        self._model: Any = None
        self._loaded_model_size = self.settings.model_size
        self._loaded_device = self.settings.device
        self._loaded_compute_type = self.settings.compute_type
        # asyncio.Lock is created lazily so it binds to the loop that runs us.
        self._model_lock: Any = None

    def _get_model_lock(self) -> Any:
        """Return the lock that serialises model (re)loading and inference."""
        import asyncio

        lock = getattr(self, "_model_lock", None)
        if lock is None:
            lock = asyncio.Lock()
            self._model_lock = lock
        return lock

    async def start(self) -> None:
        """Load the configured model up front."""
        async with self._get_model_lock():
            await self._ensure_model(
                model_size=self.settings.model_size,
                device=self.settings.device,
                compute_type=self.settings.compute_type,
            )
        logger.info("{} provider started", self.name)

    async def _ensure_model(
        self,
        *,
        model_size: str,
        device: str,
        compute_type: str,
    ) -> None:
        """Load a ``WhisperModel`` unless one with the same parameters is loaded.

        Callers hold the model lock. Construction, which may download weights,
        runs in the default executor.

        Raises:
            ImportError: If ``faster-whisper`` is not installed.
        """
        import asyncio

        try:
            from faster_whisper import WhisperModel
        except ImportError as exc:
            raise ImportError(
                "Install faster-whisper: pip install openspeech[faster-whisper]"
            ) from exc

        normalized_model_size = self._resolve_model_dir(model_size)

        if (
            self._model is not None
            and self._loaded_model_size == normalized_model_size
            and self._loaded_device == device
            and self._loaded_compute_type == compute_type
        ):
            return

        download_root = self.settings.download_root
        loop = asyncio.get_running_loop()
        model = await loop.run_in_executor(
            None,
            lambda: WhisperModel(
                normalized_model_size,
                device=device,
                compute_type=compute_type,
                download_root=download_root,
            ),
        )
        self._model = model
        self._client = model  # for health_check
        self._loaded_model_size = normalized_model_size
        self._loaded_device = device
        self._loaded_compute_type = compute_type

    @staticmethod
    def _resolve_model_dir(model_size: str) -> str:
        """Resolve an absolute model path to the directory containing ``model.bin``.

        Non-absolute or non-existent values (e.g. "base") are returned stripped
        but otherwise unchanged. For a Hugging Face cache folder, the snapshot
        named by ``refs/main`` is preferred. Otherwise the first snapshot (in
        sorted order) containing ``model.bin`` is used. If nothing matches, the
        given directory is returned.
        """
        raw = (model_size or "").strip()
        root = Path(raw).expanduser()
        if not (root.is_absolute() and root.is_dir()):
            return raw
        if (root / "model.bin").exists():
            return str(root)

        snapshots = root / "snapshots"
        if snapshots.is_dir():
            try:
                ref = (root / "refs" / "main").read_text(encoding="utf-8").strip()
            except OSError:
                ref = ""
            # Accept only a single path component so the ref cannot escape snapshots/.
            if ref and Path(ref).name == ref:
                pinned = snapshots / ref
                if (pinned / "model.bin").exists():
                    return str(pinned)
            for candidate in sorted(snapshots.iterdir()):
                if candidate.is_dir() and (candidate / "model.bin").exists():
                    return str(candidate)
        return str(root)

    async def stop(self) -> None:
        """Drop the model reference (an in-flight request keeps its own reference)."""
        self._client = None
        self._model = None
        logger.info("{} provider stopped", self.name)

    async def health_check(self) -> bool:
        """Report healthy while a model is loaded."""
        return self._client is not None

    async def transcribe(
        self, audio: AudioData, opts: STTOptions | None = None
    ) -> Transcription:
        """Transcribe a complete clip.

        Args:
            audio: WAV bytes (detected by a RIFF header), or raw PCM that is
                wrapped as 16-bit WAV using ``audio.channels``/``audio.sample_rate``.
            opts: Per-request overrides for model, device, compute type, beam
                size and language.

        Returns:
            The transcript with optional word timings, detected language,
            language probability (as ``confidence``) and duration.

        Raises:
            RuntimeError: If the provider has not been started.
        """
        import asyncio

        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")
        logger.info("{}: request received, audio={} bytes", self.name, len(audio.data))
        t0 = time.perf_counter()

        opts = opts or STTOptions()
        requested_model = (opts.model or self.settings.model_size).strip()
        requested_device = (opts.device or self.settings.device).strip()
        requested_compute_type = (opts.compute_type or self.settings.compute_type).strip()
        beam_size = int(opts.beam_size) if opts.beam_size is not None else int(self.settings.beam_size)
        kwargs: dict[str, Any] = {"beam_size": beam_size}
        if opts.language:
            kwargs["language"] = opts.language

        loop = asyncio.get_running_loop()
        async with self._get_model_lock():
            await self._ensure_model(
                model_size=requested_model,
                device=requested_device,
                compute_type=requested_compute_type,
            )
            model = self._model
            tmp_path = self._write_temp_wav(audio)
            try:
                result = await loop.run_in_executor(
                    None, self._run_inference, model, tmp_path, kwargs
                )
            finally:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass

        logger.info("{}: completed in {:.0f}ms, result={} chars", self.name, (time.perf_counter() - t0) * 1000, len(result.text))
        return result

    @staticmethod
    def _write_temp_wav(audio: AudioData) -> str:
        """Write *audio* to a new temporary ``.wav`` file and return its path.

        The WAV payload is built before the file is created, and the file is
        removed if writing fails, so no temp file is left behind on error. The
        caller deletes the returned file.
        """
        # Ensure proper WAV headers — raw PCM from TTS providers needs wrapping
        import io
        import wave

        data = audio.data
        if len(data) > 4 and data[:4] == b"RIFF":
            payload = data
        else:
            buf = io.BytesIO()
            with wave.open(buf, "wb") as wf:
                wf.setnchannels(audio.channels)
                wf.setsampwidth(2)
                wf.setframerate(audio.sample_rate)
                wf.writeframes(data)
            payload = buf.getvalue()

        fd, path = tempfile.mkstemp(suffix=".wav")
        try:
            with open(fd, "wb") as fh:
                fh.write(payload)
        except BaseException:
            try:
                os.unlink(path)
            except OSError:
                pass
            raise
        return path

    @staticmethod
    def _run_inference(model: Any, path: str, kwargs: dict[str, Any]) -> Transcription:
        """Run ``model.transcribe`` on *path* and build a :class:`Transcription` (blocking)."""
        segments, info = model.transcribe(path, **kwargs)

        text_parts: list[str] = []
        words_list: list[Word] = []
        max_end_s = 0.0
        # faster-whisper yields segments lazily, so decoding happens while
        # iterating here (inside the worker thread).
        for segment in segments:
            text_parts.append(segment.text)
            seg_end = getattr(segment, "end", None)
            if seg_end is not None:
                try:
                    max_end_s = max(max_end_s, float(seg_end))
                except (TypeError, ValueError):
                    pass
            if getattr(segment, "words", None):
                for w in segment.words:
                    words_list.append(
                        Word(
                            text=w.word,
                            start_ms=int(w.start * 1000),
                            end_ms=int(w.end * 1000),
                            confidence=getattr(w, "probability", None),
                        )
                    )
                    try:
                        max_end_s = max(max_end_s, float(w.end))
                    except (TypeError, ValueError):
                        pass

        # Prefer the reported audio duration; fall back to the last timestamp seen.
        duration_ms: int | None = None
        info_duration = getattr(info, "duration", None)
        if info_duration is not None:
            try:
                duration_ms = int(float(info_duration) * 1000)
            except (TypeError, ValueError, OverflowError):
                duration_ms = None
        if duration_ms is None and max_end_s > 0:
            duration_ms = int(max_end_s * 1000)

        return Transcription(
            text=" ".join(text_parts).strip(),
            language=getattr(info, "language", None),
            confidence=getattr(info, "language_probability", None),
            words=words_list if words_list else None,
            duration_ms=duration_ms,
        )

    async def transcribe_stream(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[Any]:
        """Not supported; raises ``NotImplementedError`` when first iterated."""
        raise NotImplementedError(
            "FasterWhisperSTT does not support streaming transcription"
        )
        yield  # pragma: no cover
|