openspeechapi 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openspeech/__init__.py +75 -0
- openspeech/__main__.py +5 -0
- openspeech/cli.py +413 -0
- openspeech/client/__init__.py +4 -0
- openspeech/client/client.py +145 -0
- openspeech/config.py +212 -0
- openspeech/core/__init__.py +0 -0
- openspeech/core/base.py +75 -0
- openspeech/core/enums.py +39 -0
- openspeech/core/models.py +61 -0
- openspeech/core/registry.py +37 -0
- openspeech/core/settings.py +8 -0
- openspeech/demo.py +675 -0
- openspeech/dispatch/__init__.py +0 -0
- openspeech/dispatch/context.py +34 -0
- openspeech/dispatch/dispatcher.py +661 -0
- openspeech/dispatch/executors/__init__.py +0 -0
- openspeech/dispatch/executors/base.py +34 -0
- openspeech/dispatch/executors/in_process.py +66 -0
- openspeech/dispatch/executors/remote.py +64 -0
- openspeech/dispatch/executors/subprocess_exec.py +446 -0
- openspeech/dispatch/fanout.py +95 -0
- openspeech/dispatch/filters.py +73 -0
- openspeech/dispatch/lifecycle.py +178 -0
- openspeech/dispatch/watcher.py +82 -0
- openspeech/engine_catalog.py +236 -0
- openspeech/engine_registry.yaml +347 -0
- openspeech/exceptions.py +51 -0
- openspeech/factory.py +325 -0
- openspeech/local_engines/__init__.py +12 -0
- openspeech/local_engines/aim_resolver.py +91 -0
- openspeech/local_engines/backends/__init__.py +1 -0
- openspeech/local_engines/backends/docker_backend.py +490 -0
- openspeech/local_engines/backends/native_backend.py +902 -0
- openspeech/local_engines/base.py +30 -0
- openspeech/local_engines/engines/__init__.py +1 -0
- openspeech/local_engines/engines/faster_whisper.py +36 -0
- openspeech/local_engines/engines/fish_speech.py +33 -0
- openspeech/local_engines/engines/sherpa_onnx.py +56 -0
- openspeech/local_engines/engines/whisper.py +41 -0
- openspeech/local_engines/engines/whisperlivekit.py +60 -0
- openspeech/local_engines/manager.py +208 -0
- openspeech/local_engines/models.py +50 -0
- openspeech/local_engines/progress.py +69 -0
- openspeech/local_engines/registry.py +19 -0
- openspeech/local_engines/task_store.py +52 -0
- openspeech/local_engines/tasks.py +71 -0
- openspeech/logging_config.py +607 -0
- openspeech/observe/__init__.py +0 -0
- openspeech/observe/base.py +79 -0
- openspeech/observe/debug.py +44 -0
- openspeech/observe/latency.py +19 -0
- openspeech/observe/metrics.py +47 -0
- openspeech/observe/tracing.py +44 -0
- openspeech/observe/usage.py +27 -0
- openspeech/providers/__init__.py +0 -0
- openspeech/providers/_template.py +101 -0
- openspeech/providers/stt/__init__.py +0 -0
- openspeech/providers/stt/alibaba.py +86 -0
- openspeech/providers/stt/assemblyai.py +135 -0
- openspeech/providers/stt/azure_speech.py +99 -0
- openspeech/providers/stt/baidu.py +135 -0
- openspeech/providers/stt/deepgram.py +311 -0
- openspeech/providers/stt/elevenlabs.py +385 -0
- openspeech/providers/stt/faster_whisper.py +211 -0
- openspeech/providers/stt/google_cloud.py +106 -0
- openspeech/providers/stt/iflytek.py +427 -0
- openspeech/providers/stt/macos_speech.py +226 -0
- openspeech/providers/stt/openai.py +84 -0
- openspeech/providers/stt/sherpa_onnx.py +353 -0
- openspeech/providers/stt/tencent.py +212 -0
- openspeech/providers/stt/volcengine.py +107 -0
- openspeech/providers/stt/whisper.py +153 -0
- openspeech/providers/stt/whisperlivekit.py +530 -0
- openspeech/providers/stt/windows_speech.py +249 -0
- openspeech/providers/tts/__init__.py +0 -0
- openspeech/providers/tts/alibaba.py +95 -0
- openspeech/providers/tts/azure_speech.py +123 -0
- openspeech/providers/tts/baidu.py +143 -0
- openspeech/providers/tts/coqui.py +64 -0
- openspeech/providers/tts/cosyvoice.py +90 -0
- openspeech/providers/tts/deepgram.py +174 -0
- openspeech/providers/tts/elevenlabs.py +311 -0
- openspeech/providers/tts/fish_speech.py +158 -0
- openspeech/providers/tts/google_cloud.py +107 -0
- openspeech/providers/tts/iflytek.py +209 -0
- openspeech/providers/tts/macos_say.py +251 -0
- openspeech/providers/tts/minimax.py +122 -0
- openspeech/providers/tts/openai.py +104 -0
- openspeech/providers/tts/piper.py +104 -0
- openspeech/providers/tts/tencent.py +189 -0
- openspeech/providers/tts/volcengine.py +117 -0
- openspeech/providers/tts/windows_sapi.py +234 -0
- openspeech/server/__init__.py +1 -0
- openspeech/server/app.py +72 -0
- openspeech/server/auth.py +42 -0
- openspeech/server/middleware.py +75 -0
- openspeech/server/routes/__init__.py +1 -0
- openspeech/server/routes/management.py +848 -0
- openspeech/server/routes/stt.py +121 -0
- openspeech/server/routes/tts.py +159 -0
- openspeech/server/routes/webui.py +29 -0
- openspeech/server/webui/app.js +2649 -0
- openspeech/server/webui/index.html +216 -0
- openspeech/server/webui/styles.css +617 -0
- openspeech/server/ws/__init__.py +1 -0
- openspeech/server/ws/stt_stream.py +263 -0
- openspeech/server/ws/tts_stream.py +207 -0
- openspeech/telemetry/__init__.py +21 -0
- openspeech/telemetry/perf.py +307 -0
- openspeech/utils/__init__.py +5 -0
- openspeech/utils/audio_converter.py +406 -0
- openspeech/utils/audio_playback.py +156 -0
- openspeech/vendor_registry.yaml +74 -0
- openspeechapi-0.1.0.dist-info/METADATA +101 -0
- openspeechapi-0.1.0.dist-info/RECORD +118 -0
- openspeechapi-0.1.0.dist-info/WHEEL +4 -0
- openspeechapi-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Tencent Cloud ASR STT provider adapter (TC3-HMAC-SHA256 signed)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import base64
|
|
5
|
+
import hashlib
|
|
6
|
+
import hmac
|
|
7
|
+
import json
|
|
8
|
+
from openspeech.logging_config import logger
|
|
9
|
+
import time
|
|
10
|
+
from collections.abc import AsyncIterator
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from openspeech.core.base import STTProvider
|
|
18
|
+
|
|
19
|
+
from openspeech.core.enums import Capability, ExecMode, ProviderType
|
|
20
|
+
from openspeech.core.models import AudioData, STTOptions, Transcription
|
|
21
|
+
from openspeech.core.settings import BaseSettings
|
|
22
|
+
|
|
23
|
+
@dataclass
class TencentSTTSettings(BaseSettings):
    """Credentials and engine configuration for the Tencent Cloud ASR API."""

    # TC3 credential pair from the Tencent Cloud console; both must be set
    # for health_check() to report healthy.
    secret_id: str = ""
    secret_key: str = ""
    # Recognition engine model, e.g. "16k_zh" = 16 kHz Mandarin
    # (see TencentSTT.field_options for the supported values).
    engine_type: str = "16k_zh"
    # Tencent Cloud region sent in the X-TC-Region header.
    region: str = "ap-guangzhou"
|
|
29
|
+
|
|
30
|
+
class TencentSTT(STTProvider):
    """Batch STT provider backed by Tencent Cloud ASR.

    Requests are signed with the TC3-HMAC-SHA256 scheme (Tencent Cloud API v3).
    Recognition is asynchronous on the service side: ``transcribe`` submits a
    task via ``CreateRecTask`` and polls ``DescribeTaskStatus`` until it
    completes, fails, or the polling budget (60 polls x 2 s) is exhausted.
    """

    name = "tencent-stt"
    provider_type = ProviderType.STT
    execution_mode = ExecMode.IN_PROCESS
    settings_cls = TencentSTTSettings
    capabilities = {Capability.BATCH, Capability.MULTILINGUAL}
    # Choices surfaced to the management UI for the corresponding settings fields.
    field_options = {"engine_type": ["16k_zh", "16k_en", "16k_ja", "16k_ko", "16k_zh_dialect", "16k_yue", "16k_zh_medical"], "region": ["ap-guangzhou", "ap-shanghai", "ap-beijing", "ap-chengdu"]}

    # Service name and endpoint host used in the TC3 credential scope / Host header.
    _SERVICE = "asr"
    _HOST = "asr.tencentcloudapi.com"

    def __init__(self, settings: TencentSTTSettings | None = None) -> None:
        self.settings = settings or TencentSTTSettings()
        # Created lazily in start(); may instead be injected via set_http_client().
        self._client: httpx.AsyncClient | None = None
        # True when this provider created the client and must close it in stop().
        self._owns_client: bool = True

    def set_http_client(self, client) -> None:
        """Adopt an externally managed HTTP client; stop() will not close it."""
        self._client = client
        self._owns_client = False

    async def start(self) -> None:
        """Create an owned AsyncClient unless one was injected."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=60.0)
            self._owns_client = True
        logger.info("{} provider started", self.name)

    async def stop(self) -> None:
        """Close the HTTP client if owned, then drop the reference."""
        if self._client and self._owns_client:
            await self._client.aclose()
        self._client = None
        logger.info("{} provider stopped", self.name)

    async def health_check(self) -> bool:
        """Healthy when both credential fields are non-empty (no network call)."""
        return bool(self.settings.secret_id) and bool(self.settings.secret_key)

    # ---- TC3-HMAC-SHA256 signing ------------------------------------------------

    def _sign_request(
        self, action: str, version: str, payload_json: str
    ) -> dict[str, str]:
        """Build Tencent Cloud API v3 signed headers.

        ``payload_json`` must be the exact bytes posted later — the payload
        hash is part of the canonical request, so re-serializing would
        invalidate the signature.
        """
        timestamp = int(time.time())
        # Credential-scope date must be the UTC date of the timestamp.
        date = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(
            "%Y-%m-%d"
        )

        # 1. Canonical request
        http_method = "POST"
        canonical_uri = "/"
        canonical_querystring = ""
        ct = "application/json; charset=utf-8"
        # Header names lowercase, values verbatim, each terminated by \n,
        # listed in the same order as signed_headers below.
        canonical_headers = (
            f"content-type:{ct}\nhost:{self._HOST}\nx-tc-action:{action.lower()}\n"
        )
        signed_headers = "content-type;host;x-tc-action"
        hashed_payload = hashlib.sha256(payload_json.encode("utf-8")).hexdigest()
        canonical_request = (
            f"{http_method}\n{canonical_uri}\n{canonical_querystring}\n"
            f"{canonical_headers}\n{signed_headers}\n{hashed_payload}"
        )

        # 2. String to sign
        algorithm = "TC3-HMAC-SHA256"
        credential_scope = f"{date}/{self._SERVICE}/tc3_request"
        hashed_canonical = hashlib.sha256(
            canonical_request.encode("utf-8")
        ).hexdigest()
        string_to_sign = (
            f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical}"
        )

        # 3. Signing key — three chained HMACs: date, then service, then
        # the literal "tc3_request", seeded with "TC3" + secret key.
        secret_date = hmac.new(
            ("TC3" + self.settings.secret_key).encode("utf-8"),
            date.encode("utf-8"),
            hashlib.sha256,
        ).digest()
        secret_service = hmac.new(
            secret_date, self._SERVICE.encode("utf-8"), hashlib.sha256
        ).digest()
        secret_signing = hmac.new(
            secret_service, b"tc3_request", hashlib.sha256
        ).digest()

        # 4. Signature
        signature = hmac.new(
            secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256
        ).hexdigest()

        authorization = (
            f"{algorithm} Credential={self.settings.secret_id}/{credential_scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )

        return {
            "Authorization": authorization,
            "Content-Type": ct,
            "Host": self._HOST,
            "X-TC-Action": action,
            "X-TC-Version": version,
            "X-TC-Timestamp": str(timestamp),
            "X-TC-Region": self.settings.region,
        }

    # ---- API calls ---------------------------------------------------------------

    async def _post(
        self, action: str, version: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """POST a signed API call and return the unwrapped ``Response`` object.

        Raises RuntimeError when the service reports an ``Error`` member;
        raises httpx.HTTPStatusError on non-2xx transport responses.
        """
        payload_json = json.dumps(payload)
        # Sign the exact serialized body — see _sign_request.
        headers = self._sign_request(action, version, payload_json)
        resp = await self._client.post(  # type: ignore[union-attr]
            f"https://{self._HOST}",
            content=payload_json,
            headers=headers,
        )
        resp.raise_for_status()
        result = resp.json()
        response = result.get("Response", {})
        if "Error" in response:
            err = response["Error"]
            raise RuntimeError(
                f"Tencent ASR error [{err.get('Code')}]: {err.get('Message')}"
            )
        return response

    async def transcribe(
        self, audio: AudioData, opts: STTOptions | None = None
    ) -> Transcription:
        """Submit ``audio`` for batch recognition and await the result.

        NOTE(review): ``opts`` is currently ignored — language/model overrides
        have no effect for this provider.
        """
        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")
        logger.info("{}: request received, audio={} bytes", self.name, len(audio.data))
        _t0 = time.perf_counter()

        # SourceType=1 means the audio is inlined (base64) rather than a URL.
        base64_audio = base64.b64encode(audio.data).decode("utf-8")

        # Step 1: Create recognition task
        create_resp = await self._post(
            action="CreateRecTask",
            version="2019-06-14",
            payload={
                "EngineModelType": self.settings.engine_type,
                "ChannelNum": 1,
                "ResTextFormat": 0,
                "SourceType": 1,
                "Data": base64_audio,
                # DataLen is the raw (pre-base64) byte length.
                "DataLen": len(audio.data),
            },
        )

        task_id = create_resp.get("Data", {}).get("TaskId")
        if not task_id:
            raise RuntimeError("Tencent ASR: no TaskId in CreateRecTask response")

        # Step 2: Poll for result (60 polls x 2 s = 2 min budget)
        import asyncio

        for _ in range(60):
            await asyncio.sleep(2)
            status_resp = await self._post(
                action="DescribeTaskStatus",
                version="2019-06-14",
                payload={"TaskId": task_id},
            )
            task_data = status_resp.get("Data", {})
            status = task_data.get("StatusStr", "")
            if status == "success":
                text = task_data.get("Result", "")
                result = Transcription(text=text)
                logger.info("{}: completed in {:.0f}ms, result={} chars", self.name, (time.perf_counter() - _t0) * 1000, len(result.text))
                return result
            if status == "failed":
                raise RuntimeError(
                    f"Tencent ASR task failed: {task_data.get('ErrorMsg', 'unknown')}"
                )

        raise RuntimeError("Tencent ASR: task polling timed out")

    async def transcribe_stream(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[Any]:
        """Streaming recognition is not supported by this adapter."""
        raise NotImplementedError("Tencent STT streaming not implemented")
        yield  # noqa: unreachable — makes this an async generator
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Volcengine (ByteDance) STT provider adapter."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import base64
|
|
5
|
+
from openspeech.logging_config import logger
|
|
6
|
+
import time
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from openspeech.core.base import STTProvider
|
|
14
|
+
|
|
15
|
+
from openspeech.core.enums import Capability, ExecMode, ProviderType
|
|
16
|
+
from openspeech.core.models import AudioData, STTOptions, Transcription
|
|
17
|
+
from openspeech.core.settings import BaseSettings
|
|
18
|
+
|
|
19
|
+
@dataclass
class VolcengineSTTSettings(BaseSettings):
    """Credentials and cluster selection for the Volcengine (ByteDance) STT API."""

    # Bearer token sent in the Authorization header ("Bearer;<token>" format).
    access_token: str = ""
    # Application ID from the Volcengine console; required for health_check().
    app_id: str = ""
    # Service cluster name (see VolcengineSTT.field_options for choices).
    cluster: str = "volcengine_streaming_common"
|
|
24
|
+
|
|
25
|
+
class VolcengineSTT(STTProvider):
    """Batch speech-to-text adapter for the Volcengine (ByteDance) HTTP API.

    Audio is base64-inlined into a JSON payload and posted to the ``auc/submit``
    endpoint; the response body is parsed for a ``result`` transcript.
    """

    name = "volcengine-stt"
    provider_type = ProviderType.STT
    execution_mode = ExecMode.IN_PROCESS
    settings_cls = VolcengineSTTSettings
    capabilities = {Capability.BATCH, Capability.MULTILINGUAL}
    # Choices surfaced to the management UI for the "cluster" settings field.
    field_options = {"cluster": ["volcengine_streaming_common", "volcengine_input_common"]}

    def __init__(self, settings: VolcengineSTTSettings | None = None) -> None:
        self.settings = settings or VolcengineSTTSettings()
        # Lazily created in start(), or injected via set_http_client().
        self._client: httpx.AsyncClient | None = None
        # Whether stop() is responsible for closing the client.
        self._owns_client: bool = True

    def set_http_client(self, client) -> None:
        """Adopt an externally managed AsyncClient; stop() will not close it."""
        self._owns_client = False
        self._client = client

    async def start(self) -> None:
        """Create an owned HTTP client unless one was injected earlier."""
        if self._client is None:
            self._owns_client = True
            self._client = httpx.AsyncClient(timeout=60.0)
        logger.info("{} provider started", self.name)

    async def stop(self) -> None:
        """Close the HTTP client when owned, then drop the reference."""
        client = self._client
        if client and self._owns_client:
            await client.aclose()
        self._client = None
        logger.info("{} provider stopped", self.name)

    async def health_check(self) -> bool:
        """Healthy when both credential fields are non-empty (no network call)."""
        have_token = bool(self.settings.access_token)
        have_app_id = bool(self.settings.app_id)
        return have_token and have_app_id

    async def transcribe(
        self, audio: AudioData, opts: STTOptions | None = None
    ) -> Transcription:
        """Submit the full audio buffer for recognition and return the text.

        NOTE(review): ``opts`` is currently ignored by this provider.
        """
        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")
        logger.info("{}: request received, audio={} bytes", self.name, len(audio.data))
        started = time.perf_counter()

        encoded_audio = base64.b64encode(audio.data).decode("utf-8")
        request_body = {
            "app": {
                "appid": self.settings.app_id,
                "cluster": self.settings.cluster,
            },
            "audio": {
                # assumes the caller supplies WAV-formatted bytes — TODO confirm
                "format": "wav",
                "data": encoded_audio,
            },
            "user": {"uid": "openspeech"},
        }
        request_headers = {
            # Volcengine expects "Bearer;<token>" (semicolon separator).
            "Authorization": f"Bearer;{self.settings.access_token}",
            "Content-Type": "application/json",
        }

        resp = await self._client.post(
            "https://openspeech.bytedance.com/api/v1/auc/submit",
            json=request_body,
            headers=request_headers,
        )
        resp.raise_for_status()
        body = resp.json()

        # A present, non-zero "code" signals an application-level error.
        code = body.get("code")
        if code is not None and code != 0:
            raise RuntimeError(
                f"Volcengine STT error: {body.get('message', 'unknown error')}"
            )

        # "result" may be a bare string or a {"text": ...} object.
        raw_result = body.get("result", "")
        transcript = raw_result.get("text", "") if isinstance(raw_result, dict) else raw_result

        outcome = Transcription(text=transcript)
        logger.info("{}: completed in {:.0f}ms, result={} chars", self.name, (time.perf_counter() - started) * 1000, len(outcome.text))
        return outcome

    async def transcribe_stream(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[Any]:
        """Streaming recognition is not supported by this adapter."""
        raise NotImplementedError("Volcengine STT streaming not implemented")
        yield  # noqa: unreachable — makes this an async generator
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""OpenAI Whisper (local) STT provider adapter (subprocess mode)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
import io
|
|
7
|
+
from openspeech.logging_config import logger
|
|
8
|
+
import os
|
|
9
|
+
import tempfile
|
|
10
|
+
import time
|
|
11
|
+
from typing import Any
|
|
12
|
+
import wave
|
|
13
|
+
|
|
14
|
+
from openspeech.core.base import STTProvider
|
|
15
|
+
|
|
16
|
+
from openspeech.core.enums import Capability, ExecMode, ProviderType
|
|
17
|
+
from openspeech.core.models import AudioData, STTOptions, Transcription, Word
|
|
18
|
+
from openspeech.core.settings import BaseSettings
|
|
19
|
+
|
|
20
|
+
@dataclass
class WhisperSTTSettings(BaseSettings):
    """Configuration for the local OpenAI Whisper model."""

    # Whisper checkpoint to load (see WhisperSTT.field_options for choices).
    model_name: str = "base"
    # Inference device passed to whisper.load_model ("cpu", "cuda", or "mps").
    device: str = "cpu"
    # Use half-precision inference; False is the safe default on CPU.
    fp16: bool = False
    # Optional cache directory for downloaded model weights.
    download_root: str | None = None
    # Beam-search width for decoding.
    beam_size: int = 5
|
|
27
|
+
|
|
28
|
+
class WhisperSTT(STTProvider):
    """Local OpenAI Whisper STT adapter (runs in a subprocess executor).

    The model is loaded once and cached; per-request ``opts`` may switch the
    model or device, triggering a reload. Audio that is not already a RIFF/WAV
    container is wrapped as 16-bit PCM WAV before being handed to whisper.
    """

    name = "whisper"
    provider_type = ProviderType.STT
    execution_mode = ExecMode.SUBPROCESS
    settings_cls = WhisperSTTSettings
    capabilities = {Capability.BATCH, Capability.MULTILINGUAL}
    # Choices surfaced to the management UI for the corresponding settings fields.
    field_options = {"device": ["cpu", "cuda", "mps"], "model_name": ["tiny", "base", "small", "medium", "large-v2", "large-v3", "large-v3-turbo"]}

    def __init__(self, settings: WhisperSTTSettings | None = None) -> None:
        self.settings = settings or WhisperSTTSettings()
        # _client doubles as the "started" flag; it aliases _model once loaded.
        self._client: Any = None
        self._model: Any = None
        # Track what is currently loaded so _ensure_model can skip reloads.
        self._loaded_model_name = self.settings.model_name
        self._loaded_device = self.settings.device

    async def start(self) -> None:
        """Load the configured model so the provider is ready to serve."""
        await self._ensure_model(
            model_name=self.settings.model_name,
            device=self.settings.device,
        )
        logger.info("{} provider started", self.name)

    async def _ensure_model(self, *, model_name: str, device: str) -> None:
        """Load (or reuse) a whisper model for the given name/device pair."""
        try:
            import whisper
        except ImportError:
            raise ImportError(
                "Install whisper: pip install openspeech[whisper]"
            )
        cache_hit = (
            self._model is not None
            and self._loaded_model_name == model_name
            and self._loaded_device == device
        )
        if cache_hit:
            return
        self._model = whisper.load_model(
            model_name,
            device=device,
            download_root=self.settings.download_root,
        )
        self._client = self._model
        self._loaded_model_name = model_name
        self._loaded_device = device

    async def stop(self) -> None:
        """Release the model references; a later start() reloads from scratch."""
        self._model = None
        self._client = None
        logger.info("{} provider stopped", self.name)

    async def health_check(self) -> bool:
        """Healthy once a model has been loaded."""
        return self._client is not None

    def _spill_to_wav(self, audio: AudioData) -> str:
        """Write the audio to a temporary .wav file and return its path.

        The caller is responsible for deleting the file. RIFF-tagged input is
        written verbatim; anything else is treated as raw 16-bit PCM.
        """
        riff_tagged = len(audio.data) > 4 and audio.data[:4] == b"RIFF"
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
            if riff_tagged:
                handle.write(audio.data)
            else:
                # assumes raw samples are 16-bit PCM (sampwidth=2) — TODO confirm
                scratch = io.BytesIO()
                with wave.open(scratch, "wb") as writer:
                    writer.setnchannels(audio.channels)
                    writer.setsampwidth(2)
                    writer.setframerate(audio.sample_rate)
                    writer.writeframes(audio.data)
                handle.write(scratch.getvalue())
            return handle.name

    @staticmethod
    def _collect_words(segments: list[dict[str, Any]]) -> list[Word]:
        """Flatten per-segment word timings into Word records, skipping blanks."""
        collected: list[Word] = []
        for segment in segments:
            for entry in segment.get("words") or []:
                token = str(entry.get("word", "")).strip()
                if not token:
                    continue
                probability = entry.get("probability")
                collected.append(
                    Word(
                        text=token,
                        start_ms=int(float(entry.get("start", 0.0)) * 1000),
                        end_ms=int(float(entry.get("end", 0.0)) * 1000),
                        confidence=float(probability) if probability is not None else None,
                    )
                )
        return collected

    @staticmethod
    def _tail_duration_ms(segments: list[dict[str, Any]]) -> int | None:
        """Derive total duration from the final segment's end time, if positive."""
        if not segments:
            return None
        tail_end = float(segments[-1].get("end", 0.0))
        return int(tail_end * 1000) if tail_end > 0 else None

    async def transcribe(
        self, audio: AudioData, opts: STTOptions | None = None
    ) -> Transcription:
        """Transcribe a complete audio buffer with the local whisper model."""
        if self._client is None:
            raise RuntimeError("Provider not started — call start() first")
        logger.info("{}: request received, audio={} bytes", self.name, len(audio.data))
        started = time.perf_counter()
        opts = opts or STTOptions()

        # Per-request overrides may force a different model or device.
        wanted_model = (opts.model or self.settings.model_name).strip()
        wanted_device = (opts.device or self.settings.device).strip()
        await self._ensure_model(model_name=wanted_model, device=wanted_device)

        wav_path = self._spill_to_wav(audio)
        try:
            # Request-level options win over configured defaults when present.
            decode_args: dict[str, Any] = {
                "fp16": bool(self.settings.fp16 if opts.fp16 is None else opts.fp16),
                "beam_size": int(self.settings.beam_size if opts.beam_size is None else opts.beam_size),
                "word_timestamps": True,
            }
            if opts.language:
                decode_args["language"] = opts.language
            raw = self._model.transcribe(wav_path, **decode_args)

            segments = raw.get("segments") or []
            words = self._collect_words(segments)
            outcome = Transcription(
                text=str(raw.get("text", "")).strip(),
                language=raw.get("language"),
                words=words or None,
                duration_ms=self._tail_duration_ms(segments),
            )
            logger.info("{}: completed in {:.0f}ms, result={} chars", self.name, (time.perf_counter() - started) * 1000, len(outcome.text))
            return outcome
        finally:
            os.unlink(wav_path)

    async def transcribe_stream(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[Any]:
        """Streaming recognition is not supported by this adapter."""
        raise NotImplementedError(
            "WhisperSTT does not support streaming transcription"
        )
        yield  # pragma: no cover
|