cartesia 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/__init__.py +46 -1
- cartesia/base_client.py +2 -0
- cartesia/client.py +5 -0
- cartesia/core/client_wrapper.py +1 -1
- cartesia/stt/__init__.py +51 -0
- cartesia/stt/_async_websocket.py +284 -0
- cartesia/stt/_websocket.py +272 -0
- cartesia/stt/requests/__init__.py +27 -0
- cartesia/stt/requests/done_message.py +14 -0
- cartesia/stt/requests/error_message.py +16 -0
- cartesia/stt/requests/flush_done_message.py +14 -0
- cartesia/stt/requests/streaming_transcription_response.py +39 -0
- cartesia/stt/requests/transcript_message.py +33 -0
- cartesia/stt/requests/transcription_response.py +21 -0
- cartesia/stt/socket_client.py +195 -0
- cartesia/stt/types/__init__.py +29 -0
- cartesia/stt/types/done_message.py +26 -0
- cartesia/stt/types/error_message.py +27 -0
- cartesia/stt/types/flush_done_message.py +26 -0
- cartesia/stt/types/streaming_transcription_response.py +92 -0
- cartesia/stt/types/stt_encoding.py +5 -0
- cartesia/stt/types/transcript_message.py +44 -0
- cartesia/stt/types/transcription_response.py +32 -0
- cartesia/tts/_websocket.py +3 -3
- {cartesia-2.0.3.dist-info → cartesia-2.0.5.dist-info}/METADATA +159 -2
- {cartesia-2.0.3.dist-info → cartesia-2.0.5.dist-info}/RECORD +27 -8
- {cartesia-2.0.3.dist-info → cartesia-2.0.5.dist-info}/WHEEL +0 -0
cartesia/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
|
-
from . import api_status, auth, datasets, embedding, infill, tts, voice_changer, voices
|
3
|
+
from . import api_status, auth, datasets, embedding, infill, stt, tts, voice_changer, voices
|
4
4
|
from .api_status import ApiInfo, ApiInfoParams
|
5
5
|
from .auth import TokenGrant, TokenGrantParams, TokenRequest, TokenRequestParams, TokenResponse, TokenResponseParams
|
6
6
|
from .client import AsyncCartesia, Cartesia
|
@@ -19,6 +19,29 @@ from .datasets import (
|
|
19
19
|
)
|
20
20
|
from .embedding import Embedding
|
21
21
|
from .environment import CartesiaEnvironment
|
22
|
+
from .stt import (
|
23
|
+
DoneMessage,
|
24
|
+
DoneMessageParams,
|
25
|
+
ErrorMessage,
|
26
|
+
ErrorMessageParams,
|
27
|
+
FlushDoneMessage,
|
28
|
+
FlushDoneMessageParams,
|
29
|
+
StreamingTranscriptionResponse,
|
30
|
+
StreamingTranscriptionResponseParams,
|
31
|
+
StreamingTranscriptionResponse_Done,
|
32
|
+
StreamingTranscriptionResponse_DoneParams,
|
33
|
+
StreamingTranscriptionResponse_Error,
|
34
|
+
StreamingTranscriptionResponse_ErrorParams,
|
35
|
+
StreamingTranscriptionResponse_FlushDone,
|
36
|
+
StreamingTranscriptionResponse_FlushDoneParams,
|
37
|
+
StreamingTranscriptionResponse_Transcript,
|
38
|
+
StreamingTranscriptionResponse_TranscriptParams,
|
39
|
+
SttEncoding,
|
40
|
+
TranscriptMessage,
|
41
|
+
TranscriptMessageParams,
|
42
|
+
TranscriptionResponse,
|
43
|
+
TranscriptionResponseParams,
|
44
|
+
)
|
22
45
|
from .tts import (
|
23
46
|
CancelContextRequest,
|
24
47
|
CancelContextRequestParams,
|
@@ -173,13 +196,19 @@ __all__ = [
|
|
173
196
|
"DatasetFile",
|
174
197
|
"DatasetFileParams",
|
175
198
|
"DatasetParams",
|
199
|
+
"DoneMessage",
|
200
|
+
"DoneMessageParams",
|
176
201
|
"Embedding",
|
177
202
|
"EmbeddingResponse",
|
178
203
|
"EmbeddingResponseParams",
|
179
204
|
"EmbeddingSpecifier",
|
180
205
|
"EmbeddingSpecifierParams",
|
181
206
|
"Emotion",
|
207
|
+
"ErrorMessage",
|
208
|
+
"ErrorMessageParams",
|
182
209
|
"FilePurpose",
|
210
|
+
"FlushDoneMessage",
|
211
|
+
"FlushDoneMessageParams",
|
183
212
|
"FlushId",
|
184
213
|
"Gender",
|
185
214
|
"GenderPresentation",
|
@@ -235,6 +264,17 @@ __all__ = [
|
|
235
264
|
"StreamingResponse_DoneParams",
|
236
265
|
"StreamingResponse_Error",
|
237
266
|
"StreamingResponse_ErrorParams",
|
267
|
+
"StreamingTranscriptionResponse",
|
268
|
+
"StreamingTranscriptionResponseParams",
|
269
|
+
"StreamingTranscriptionResponse_Done",
|
270
|
+
"StreamingTranscriptionResponse_DoneParams",
|
271
|
+
"StreamingTranscriptionResponse_Error",
|
272
|
+
"StreamingTranscriptionResponse_ErrorParams",
|
273
|
+
"StreamingTranscriptionResponse_FlushDone",
|
274
|
+
"StreamingTranscriptionResponse_FlushDoneParams",
|
275
|
+
"StreamingTranscriptionResponse_Transcript",
|
276
|
+
"StreamingTranscriptionResponse_TranscriptParams",
|
277
|
+
"SttEncoding",
|
238
278
|
"SupportedLanguage",
|
239
279
|
"TokenGrant",
|
240
280
|
"TokenGrantParams",
|
@@ -242,6 +282,10 @@ __all__ = [
|
|
242
282
|
"TokenRequestParams",
|
243
283
|
"TokenResponse",
|
244
284
|
"TokenResponseParams",
|
285
|
+
"TranscriptMessage",
|
286
|
+
"TranscriptMessageParams",
|
287
|
+
"TranscriptionResponse",
|
288
|
+
"TranscriptionResponseParams",
|
245
289
|
"TtsRequest",
|
246
290
|
"TtsRequestEmbeddingSpecifier",
|
247
291
|
"TtsRequestEmbeddingSpecifierParams",
|
@@ -307,6 +351,7 @@ __all__ = [
|
|
307
351
|
"datasets",
|
308
352
|
"embedding",
|
309
353
|
"infill",
|
354
|
+
"stt",
|
310
355
|
"tts",
|
311
356
|
"voice_changer",
|
312
357
|
"voices",
|
cartesia/base_client.py
CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
|
|
7
7
|
from .api_status.client import ApiStatusClient
|
8
8
|
from .auth.client import AuthClient
|
9
9
|
from .infill.client import InfillClient
|
10
|
+
from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
|
10
11
|
from .tts.client import TtsClient
|
11
12
|
from .voice_changer.client import VoiceChangerClient
|
12
13
|
from .voices.client import VoicesClient
|
@@ -80,6 +81,7 @@ class BaseCartesia:
|
|
80
81
|
self.api_status = ApiStatusClient(client_wrapper=self._client_wrapper)
|
81
82
|
self.auth = AuthClient(client_wrapper=self._client_wrapper)
|
82
83
|
self.infill = InfillClient(client_wrapper=self._client_wrapper)
|
84
|
+
self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
|
83
85
|
self.tts = TtsClient(client_wrapper=self._client_wrapper)
|
84
86
|
self.voice_changer = VoiceChangerClient(client_wrapper=self._client_wrapper)
|
85
87
|
self.voices = VoicesClient(client_wrapper=self._client_wrapper)
|
cartesia/client.py
CHANGED
@@ -8,6 +8,7 @@ import httpx
|
|
8
8
|
|
9
9
|
from .base_client import AsyncBaseCartesia, BaseCartesia
|
10
10
|
from .environment import CartesiaEnvironment
|
11
|
+
from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
|
11
12
|
from .tts.socket_client import AsyncTtsClientWithWebsocket, TtsClientWithWebsocket
|
12
13
|
|
13
14
|
|
@@ -66,6 +67,7 @@ class Cartesia(BaseCartesia):
|
|
66
67
|
follow_redirects=follow_redirects,
|
67
68
|
httpx_client=httpx_client,
|
68
69
|
)
|
70
|
+
self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
|
69
71
|
self.tts = TtsClientWithWebsocket(client_wrapper=self._client_wrapper)
|
70
72
|
|
71
73
|
def __enter__(self):
|
@@ -143,6 +145,9 @@ class AsyncCartesia(AsyncBaseCartesia):
|
|
143
145
|
self._session = None
|
144
146
|
self._loop = None
|
145
147
|
self.max_num_connections = max_num_connections
|
148
|
+
self.stt = AsyncSttClientWithWebsocket(
|
149
|
+
client_wrapper=self._client_wrapper, get_session=self._get_session
|
150
|
+
)
|
146
151
|
self.tts = AsyncTtsClientWithWebsocket(
|
147
152
|
client_wrapper=self._client_wrapper, get_session=self._get_session
|
148
153
|
)
|
cartesia/core/client_wrapper.py
CHANGED
@@ -16,7 +16,7 @@ class BaseClientWrapper:
|
|
16
16
|
headers: typing.Dict[str, str] = {
|
17
17
|
"X-Fern-Language": "Python",
|
18
18
|
"X-Fern-SDK-Name": "cartesia",
|
19
|
-
"X-Fern-SDK-Version": "2.0.
|
19
|
+
"X-Fern-SDK-Version": "2.0.5",
|
20
20
|
}
|
21
21
|
headers["X-API-Key"] = self.api_key
|
22
22
|
headers["Cartesia-Version"] = "2024-11-13"
|
cartesia/stt/__init__.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from .types import (
|
4
|
+
DoneMessage,
|
5
|
+
ErrorMessage,
|
6
|
+
FlushDoneMessage,
|
7
|
+
StreamingTranscriptionResponse,
|
8
|
+
StreamingTranscriptionResponse_Done,
|
9
|
+
StreamingTranscriptionResponse_Error,
|
10
|
+
StreamingTranscriptionResponse_FlushDone,
|
11
|
+
StreamingTranscriptionResponse_Transcript,
|
12
|
+
SttEncoding,
|
13
|
+
TranscriptMessage,
|
14
|
+
TranscriptionResponse,
|
15
|
+
)
|
16
|
+
from .requests import (
|
17
|
+
DoneMessageParams,
|
18
|
+
ErrorMessageParams,
|
19
|
+
FlushDoneMessageParams,
|
20
|
+
StreamingTranscriptionResponseParams,
|
21
|
+
StreamingTranscriptionResponse_DoneParams,
|
22
|
+
StreamingTranscriptionResponse_ErrorParams,
|
23
|
+
StreamingTranscriptionResponse_FlushDoneParams,
|
24
|
+
StreamingTranscriptionResponse_TranscriptParams,
|
25
|
+
TranscriptMessageParams,
|
26
|
+
TranscriptionResponseParams,
|
27
|
+
)
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
"DoneMessage",
|
31
|
+
"DoneMessageParams",
|
32
|
+
"ErrorMessage",
|
33
|
+
"ErrorMessageParams",
|
34
|
+
"FlushDoneMessage",
|
35
|
+
"FlushDoneMessageParams",
|
36
|
+
"StreamingTranscriptionResponse",
|
37
|
+
"StreamingTranscriptionResponseParams",
|
38
|
+
"StreamingTranscriptionResponse_Done",
|
39
|
+
"StreamingTranscriptionResponse_DoneParams",
|
40
|
+
"StreamingTranscriptionResponse_Error",
|
41
|
+
"StreamingTranscriptionResponse_ErrorParams",
|
42
|
+
"StreamingTranscriptionResponse_FlushDone",
|
43
|
+
"StreamingTranscriptionResponse_FlushDoneParams",
|
44
|
+
"StreamingTranscriptionResponse_Transcript",
|
45
|
+
"StreamingTranscriptionResponse_TranscriptParams",
|
46
|
+
"SttEncoding",
|
47
|
+
"TranscriptMessage",
|
48
|
+
"TranscriptMessageParams",
|
49
|
+
"TranscriptionResponse",
|
50
|
+
"TranscriptionResponseParams",
|
51
|
+
]
|
@@ -0,0 +1,284 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import typing
|
4
|
+
import uuid
|
5
|
+
from typing import Any, Awaitable, AsyncGenerator, Callable, Dict, Optional, Union
|
6
|
+
|
7
|
+
import aiohttp
|
8
|
+
|
9
|
+
from cartesia.stt.types import (
|
10
|
+
StreamingTranscriptionResponse,
|
11
|
+
StreamingTranscriptionResponse_Error,
|
12
|
+
StreamingTranscriptionResponse_Transcript,
|
13
|
+
)
|
14
|
+
|
15
|
+
from ..core.pydantic_utilities import parse_obj_as
|
16
|
+
from ._websocket import SttWebsocket
|
17
|
+
|
18
|
+
|
19
|
+
class AsyncSttWebsocket(SttWebsocket):
    """This class contains methods to transcribe audio using WebSocket asynchronously."""

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
        get_session: Callable[[], Awaitable[Optional[aiohttp.ClientSession]]],
        timeout: float = 30,
    ):
        """
        Args:
            ws_url: The WebSocket URL for the Cartesia API.
            api_key: The API key to use for authorization.
            cartesia_version: The version of the Cartesia API to use.
            get_session: A function that returns an awaitable of aiohttp.ClientSession object.
            timeout: The timeout for responses on the WebSocket in seconds.
        """
        super().__init__(ws_url, api_key, cartesia_version)
        self.timeout = timeout
        self._get_session = get_session
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        # Connection parameters remembered for automatic (re)connects from send()/receive().
        self._default_model: str = "ink-whisper"
        self._default_language: Optional[str] = "en"
        self._default_encoding: Optional[str] = "pcm_s16le"
        self._default_sample_rate: int = 16000

    def __del__(self):
        """Best-effort close of an open socket when the object is garbage-collected."""
        # getattr: __init__ may not have completed (or run at all) for this object.
        websocket = getattr(self, "websocket", None)
        if websocket is None or websocket.closed:
            # Nothing to release, so don't spin up an event loop just to no-op.
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        try:
            if loop is None:
                asyncio.run(self.close())
            elif loop.is_running():
                loop.create_task(self.close())
        except Exception:
            # Finalizers must not raise. E.g. the socket may be bound to a loop
            # that has already been closed, or the interpreter is shutting down.
            pass

    async def connect(
        self,
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: Optional[str] = "pcm_s16le",
        sample_rate: int = 16000,
    ):
        """Connect to the STT WebSocket with the specified parameters.

        The parameters are also stored as the defaults used when send() or
        receive() reconnect automatically. If a connection is already open, it
        is reused as-is, so new parameters only take effect the next time a
        connection is opened.

        Args:
            model: ID of the model to use for transcription
            language: The language of the input audio in ISO-639-1 format
            encoding: The encoding format of the audio data
            sample_rate: The sample rate of the audio in Hz

        Raises:
            RuntimeError: If the connection to the WebSocket fails.
        """
        # Function-scope import keeps the module's import block untouched.
        from urllib.parse import urlencode

        self._default_model = model
        self._default_language = language
        self._default_encoding = encoding
        self._default_sample_rate = sample_rate

        if not self._is_websocket_closed():
            return

        route = "stt/websocket"
        session = await self._get_session()

        params = {
            "model": model,
            "api_key": self.api_key,
            "cartesia_version": self.cartesia_version,
        }
        if language is not None:
            params["language"] = language
        if encoding is not None:
            params["encoding"] = encoding
        if sample_rate is not None:
            params["sample_rate"] = str(sample_rate)

        base_url = f"{self.ws_url}/{route}"
        # Percent-encode the values: the API key (or any other value) may contain
        # characters such as '&', '+' or '=' that would otherwise corrupt the query.
        url = f"{base_url}?{urlencode(params)}"

        try:
            if session is None:
                raise RuntimeError("Session is not available")
            self.websocket = await session.ws_connect(url)
        except Exception as e:
            # Handshake failures raised by aiohttp carry the HTTP status in `.status`.
            status_code = getattr(e, "status", None)
            if status_code is not None:
                status_messages = {
                    402: "Payment required. Your API key may have insufficient credits or permissions.",
                    401: "Unauthorized. Please check your API key.",
                    403: "Forbidden. You don't have permission to access this resource.",
                    404: "Not found. The requested resource doesn't exist.",
                }
                error_message = status_messages.get(status_code, str(e))
                raise RuntimeError(
                    f"Failed to connect to WebSocket.\nStatus: {status_code}. Error message: {error_message}"
                ) from e
            # Report the URL without its query string: the query contains the API key.
            raise RuntimeError(f"Failed to connect to WebSocket at {base_url}. {e}") from e

    def _is_websocket_closed(self):
        """Return True when there is no websocket or it has been closed."""
        return self.websocket is None or self.websocket.closed

    async def _ensure_async_connection(self) -> aiohttp.ClientWebSocketResponse:
        """(Re)connect with the remembered defaults if needed and return the open socket."""
        if self._is_websocket_closed():
            await self.connect(
                model=self._default_model,
                language=self._default_language,
                encoding=self._default_encoding,
                sample_rate=self._default_sample_rate,
            )

        assert self.websocket is not None, "WebSocket should be connected after connect() call"
        return self.websocket

    async def close(self):
        """This method closes the websocket connection. Highly recommended to call this method when done."""
        if self.websocket is not None and not self._is_websocket_closed():
            await self.websocket.close()
        self.websocket = None

    async def send(self, data: Union[bytes, str]):
        """Send audio data or control commands to the WebSocket.

        Connects first, using the remembered defaults, if no connection is open.

        Args:
            data: Binary audio data (bytes, bytearray or memoryview) or a text
                command ("finalize" or "done").

        Raises:
            TypeError: If ``data`` is neither bytes-like nor str.
        """
        # Validate before (re)connecting so a bad argument has no network side effects.
        if not isinstance(data, (bytes, bytearray, memoryview, str)):
            raise TypeError("Data must be bytes (audio) or str (command)")

        websocket = await self._ensure_async_connection()
        if isinstance(data, str):
            await websocket.send_str(data)
        else:
            await websocket.send_bytes(data)

    async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:  # type: ignore[override]
        """Receive transcription results from the WebSocket.

        Ends after a "done" message or when the connection is closed.

        Yields:
            Dictionary containing transcription results, flush_done, or done messages.
            Other TEXT message types are ignored.

        Raises:
            RuntimeError: On a server "error" message, a WebSocket error frame, a
                timeout waiting for a message, or any other failure while receiving.
                The socket is closed before the error propagates.
        """
        # Keep a local reference: a consumer may call close() between yields,
        # which resets self.websocket to None.
        websocket = await self._ensure_async_connection()

        try:
            while True:
                try:
                    msg = await asyncio.wait_for(websocket.receive(), timeout=self.timeout)
                except asyncio.TimeoutError:
                    raise RuntimeError("Timeout while waiting for transcription") from None

                if msg.type == aiohttp.WSMsgType.TEXT:
                    raw_data = json.loads(msg.data)
                    message_type = raw_data.get("type")

                    # Handle error responses
                    if message_type == "error":
                        raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")

                    if message_type == "transcript":
                        # Provide defaults for missing required fields.
                        result = {
                            "type": message_type,
                            "request_id": raw_data.get("request_id", ""),
                            "text": raw_data.get("text", ""),
                            "is_final": raw_data.get("is_final", False),
                        }
                        # Add optional fields only if present.
                        if "duration" in raw_data:
                            result["duration"] = raw_data["duration"]
                        if "language" in raw_data:
                            result["language"] = raw_data["language"]
                        yield result

                    elif message_type in ("flush_done", "done"):
                        yield {
                            "type": message_type,
                            "request_id": raw_data.get("request_id", ""),
                        }
                        if message_type == "done":
                            # Session is complete.
                            break

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    raise RuntimeError(f"WebSocket error: {websocket.exception()}")

                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    # The peer closed, or the socket was closed locally or lost.
                    # Stop instead of polling a dead connection.
                    break

        except RuntimeError:
            # Errors raised above already carry a specific message. Do not re-wrap them.
            await self.close()
            raise
        except Exception as e:
            await self.close()
            raise RuntimeError(f"Failed to receive transcription. {e}") from e

    async def transcribe(  # type: ignore[override]
        self,
        audio_chunks: typing.AsyncIterator[bytes],
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: Optional[str] = "pcm_s16le",
        sample_rate: int = 16000,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Transcribe audio chunks using the WebSocket.

        All chunks are sent first, followed by "finalize" and "done". Only then
        are responses read until the server reports "done". The socket is always
        closed when the generator finishes.

        NOTE(review): sending and receiving are sequential, not concurrent. For
        very long inputs, server responses are buffered until all audio has been
        sent.

        Args:
            audio_chunks: Async iterator of audio chunks as bytes
            model: ID of the model to use for transcription
            language: The language of the input audio in ISO-639-1 format
            encoding: The encoding format of the audio data
            sample_rate: The sample rate of the audio in Hz

        Yields:
            Dictionary containing transcription results, flush_done, or done messages
        """
        await self.connect(
            model=model,
            language=language,
            encoding=encoding,
            sample_rate=sample_rate,
        )

        try:
            async for chunk in audio_chunks:
                await self.send(chunk)

            # Flush any buffered audio, then ask the server to end the session cleanly.
            await self.send("finalize")
            await self.send("done")

            async for result in self.receive():
                yield result
        finally:
            await self.close()