cartesia 2.0.4__py3-none-any.whl → 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/__init__.py +60 -1
- cartesia/auth/client.py +8 -8
- cartesia/auth/requests/token_grant.py +7 -1
- cartesia/auth/requests/token_request.py +3 -3
- cartesia/auth/types/token_grant.py +7 -2
- cartesia/auth/types/token_request.py +3 -3
- cartesia/base_client.py +2 -0
- cartesia/client.py +5 -0
- cartesia/core/client_wrapper.py +1 -1
- cartesia/stt/__init__.py +57 -0
- cartesia/stt/_async_websocket.py +293 -0
- cartesia/stt/_websocket.py +294 -0
- cartesia/stt/client.py +456 -0
- cartesia/stt/requests/__init__.py +29 -0
- cartesia/stt/requests/done_message.py +14 -0
- cartesia/stt/requests/error_message.py +16 -0
- cartesia/stt/requests/flush_done_message.py +14 -0
- cartesia/stt/requests/streaming_transcription_response.py +41 -0
- cartesia/stt/requests/transcript_message.py +40 -0
- cartesia/stt/requests/transcription_response.py +28 -0
- cartesia/stt/requests/transcription_word.py +20 -0
- cartesia/stt/socket_client.py +138 -0
- cartesia/stt/types/__init__.py +33 -0
- cartesia/stt/types/done_message.py +26 -0
- cartesia/stt/types/error_message.py +27 -0
- cartesia/stt/types/flush_done_message.py +26 -0
- cartesia/stt/types/streaming_transcription_response.py +94 -0
- cartesia/stt/types/stt_encoding.py +7 -0
- cartesia/stt/types/timestamp_granularity.py +5 -0
- cartesia/stt/types/transcript_message.py +50 -0
- cartesia/stt/types/transcription_response.py +38 -0
- cartesia/stt/types/transcription_word.py +32 -0
- cartesia/tts/__init__.py +8 -0
- cartesia/tts/client.py +50 -8
- cartesia/tts/requests/__init__.py +4 -0
- cartesia/tts/requests/generation_request.py +4 -4
- cartesia/tts/requests/sse_output_format.py +11 -0
- cartesia/tts/requests/ttssse_request.py +47 -0
- cartesia/tts/requests/web_socket_chunk_response.py +0 -3
- cartesia/tts/requests/web_socket_response.py +1 -2
- cartesia/tts/requests/web_socket_tts_request.py +9 -1
- cartesia/tts/types/__init__.py +4 -0
- cartesia/tts/types/generation_request.py +4 -4
- cartesia/tts/types/sse_output_format.py +22 -0
- cartesia/tts/types/ttssse_request.py +58 -0
- cartesia/tts/types/web_socket_chunk_response.py +1 -3
- cartesia/tts/types/web_socket_response.py +1 -2
- cartesia/tts/types/web_socket_tts_request.py +11 -3
- cartesia/voice_changer/requests/streaming_response.py +0 -2
- cartesia/voice_changer/types/streaming_response.py +0 -2
- {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/METADATA +256 -2
- {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/RECORD +53 -26
- {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/WHEEL +0 -0
cartesia/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
|
-
from . import api_status, auth, datasets, embedding, infill, tts, voice_changer, voices
|
3
|
+
from . import api_status, auth, datasets, embedding, infill, stt, tts, voice_changer, voices
|
4
4
|
from .api_status import ApiInfo, ApiInfoParams
|
5
5
|
from .auth import TokenGrant, TokenGrantParams, TokenRequest, TokenRequestParams, TokenResponse, TokenResponseParams
|
6
6
|
from .client import AsyncCartesia, Cartesia
|
@@ -19,6 +19,32 @@ from .datasets import (
|
|
19
19
|
)
|
20
20
|
from .embedding import Embedding
|
21
21
|
from .environment import CartesiaEnvironment
|
22
|
+
from .stt import (
|
23
|
+
DoneMessage,
|
24
|
+
DoneMessageParams,
|
25
|
+
ErrorMessage,
|
26
|
+
ErrorMessageParams,
|
27
|
+
FlushDoneMessage,
|
28
|
+
FlushDoneMessageParams,
|
29
|
+
StreamingTranscriptionResponse,
|
30
|
+
StreamingTranscriptionResponseParams,
|
31
|
+
StreamingTranscriptionResponse_Done,
|
32
|
+
StreamingTranscriptionResponse_DoneParams,
|
33
|
+
StreamingTranscriptionResponse_Error,
|
34
|
+
StreamingTranscriptionResponse_ErrorParams,
|
35
|
+
StreamingTranscriptionResponse_FlushDone,
|
36
|
+
StreamingTranscriptionResponse_FlushDoneParams,
|
37
|
+
StreamingTranscriptionResponse_Transcript,
|
38
|
+
StreamingTranscriptionResponse_TranscriptParams,
|
39
|
+
SttEncoding,
|
40
|
+
TimestampGranularity,
|
41
|
+
TranscriptMessage,
|
42
|
+
TranscriptMessageParams,
|
43
|
+
TranscriptionResponse,
|
44
|
+
TranscriptionResponseParams,
|
45
|
+
TranscriptionWord,
|
46
|
+
TranscriptionWordParams,
|
47
|
+
)
|
22
48
|
from .tts import (
|
23
49
|
CancelContextRequest,
|
24
50
|
CancelContextRequestParams,
|
@@ -49,6 +75,8 @@ from .tts import (
|
|
49
75
|
RawOutputFormatParams,
|
50
76
|
Speed,
|
51
77
|
SpeedParams,
|
78
|
+
SseOutputFormat,
|
79
|
+
SseOutputFormatParams,
|
52
80
|
SupportedLanguage,
|
53
81
|
TtsRequest,
|
54
82
|
TtsRequestEmbeddingSpecifier,
|
@@ -58,6 +86,8 @@ from .tts import (
|
|
58
86
|
TtsRequestParams,
|
59
87
|
TtsRequestVoiceSpecifier,
|
60
88
|
TtsRequestVoiceSpecifierParams,
|
89
|
+
TtssseRequest,
|
90
|
+
TtssseRequestParams,
|
61
91
|
WavOutputFormat,
|
62
92
|
WavOutputFormatParams,
|
63
93
|
WebSocketBaseResponse,
|
@@ -173,13 +203,19 @@ __all__ = [
|
|
173
203
|
"DatasetFile",
|
174
204
|
"DatasetFileParams",
|
175
205
|
"DatasetParams",
|
206
|
+
"DoneMessage",
|
207
|
+
"DoneMessageParams",
|
176
208
|
"Embedding",
|
177
209
|
"EmbeddingResponse",
|
178
210
|
"EmbeddingResponseParams",
|
179
211
|
"EmbeddingSpecifier",
|
180
212
|
"EmbeddingSpecifierParams",
|
181
213
|
"Emotion",
|
214
|
+
"ErrorMessage",
|
215
|
+
"ErrorMessageParams",
|
182
216
|
"FilePurpose",
|
217
|
+
"FlushDoneMessage",
|
218
|
+
"FlushDoneMessageParams",
|
183
219
|
"FlushId",
|
184
220
|
"Gender",
|
185
221
|
"GenderPresentation",
|
@@ -227,6 +263,8 @@ __all__ = [
|
|
227
263
|
"RawOutputFormatParams",
|
228
264
|
"Speed",
|
229
265
|
"SpeedParams",
|
266
|
+
"SseOutputFormat",
|
267
|
+
"SseOutputFormatParams",
|
230
268
|
"StreamingResponse",
|
231
269
|
"StreamingResponseParams",
|
232
270
|
"StreamingResponse_Chunk",
|
@@ -235,13 +273,31 @@ __all__ = [
|
|
235
273
|
"StreamingResponse_DoneParams",
|
236
274
|
"StreamingResponse_Error",
|
237
275
|
"StreamingResponse_ErrorParams",
|
276
|
+
"StreamingTranscriptionResponse",
|
277
|
+
"StreamingTranscriptionResponseParams",
|
278
|
+
"StreamingTranscriptionResponse_Done",
|
279
|
+
"StreamingTranscriptionResponse_DoneParams",
|
280
|
+
"StreamingTranscriptionResponse_Error",
|
281
|
+
"StreamingTranscriptionResponse_ErrorParams",
|
282
|
+
"StreamingTranscriptionResponse_FlushDone",
|
283
|
+
"StreamingTranscriptionResponse_FlushDoneParams",
|
284
|
+
"StreamingTranscriptionResponse_Transcript",
|
285
|
+
"StreamingTranscriptionResponse_TranscriptParams",
|
286
|
+
"SttEncoding",
|
238
287
|
"SupportedLanguage",
|
288
|
+
"TimestampGranularity",
|
239
289
|
"TokenGrant",
|
240
290
|
"TokenGrantParams",
|
241
291
|
"TokenRequest",
|
242
292
|
"TokenRequestParams",
|
243
293
|
"TokenResponse",
|
244
294
|
"TokenResponseParams",
|
295
|
+
"TranscriptMessage",
|
296
|
+
"TranscriptMessageParams",
|
297
|
+
"TranscriptionResponse",
|
298
|
+
"TranscriptionResponseParams",
|
299
|
+
"TranscriptionWord",
|
300
|
+
"TranscriptionWordParams",
|
245
301
|
"TtsRequest",
|
246
302
|
"TtsRequestEmbeddingSpecifier",
|
247
303
|
"TtsRequestEmbeddingSpecifierParams",
|
@@ -250,6 +306,8 @@ __all__ = [
|
|
250
306
|
"TtsRequestParams",
|
251
307
|
"TtsRequestVoiceSpecifier",
|
252
308
|
"TtsRequestVoiceSpecifierParams",
|
309
|
+
"TtssseRequest",
|
310
|
+
"TtssseRequestParams",
|
253
311
|
"UpdateVoiceRequest",
|
254
312
|
"UpdateVoiceRequestParams",
|
255
313
|
"Voice",
|
@@ -307,6 +365,7 @@ __all__ = [
|
|
307
365
|
"datasets",
|
308
366
|
"embedding",
|
309
367
|
"infill",
|
368
|
+
"stt",
|
310
369
|
"tts",
|
311
370
|
"voice_changer",
|
312
371
|
"voices",
|
cartesia/auth/client.py
CHANGED
@@ -22,7 +22,7 @@ class AuthClient:
|
|
22
22
|
def access_token(
|
23
23
|
self,
|
24
24
|
*,
|
25
|
-
grants: TokenGrantParams,
|
25
|
+
grants: typing.Optional[TokenGrantParams] = OMIT,
|
26
26
|
expires_in: typing.Optional[int] = OMIT,
|
27
27
|
request_options: typing.Optional[RequestOptions] = None,
|
28
28
|
) -> TokenResponse:
|
@@ -31,8 +31,8 @@ class AuthClient:
|
|
31
31
|
|
32
32
|
Parameters
|
33
33
|
----------
|
34
|
-
grants : TokenGrantParams
|
35
|
-
The permissions to be granted via the token.
|
34
|
+
grants : typing.Optional[TokenGrantParams]
|
35
|
+
The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
|
36
36
|
|
37
37
|
expires_in : typing.Optional[int]
|
38
38
|
The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
|
@@ -52,7 +52,7 @@ class AuthClient:
|
|
52
52
|
api_key="YOUR_API_KEY",
|
53
53
|
)
|
54
54
|
client.auth.access_token(
|
55
|
-
grants={"tts": True},
|
55
|
+
grants={"tts": True, "stt": True},
|
56
56
|
expires_in=60,
|
57
57
|
)
|
58
58
|
"""
|
@@ -90,7 +90,7 @@ class AsyncAuthClient:
|
|
90
90
|
async def access_token(
|
91
91
|
self,
|
92
92
|
*,
|
93
|
-
grants: TokenGrantParams,
|
93
|
+
grants: typing.Optional[TokenGrantParams] = OMIT,
|
94
94
|
expires_in: typing.Optional[int] = OMIT,
|
95
95
|
request_options: typing.Optional[RequestOptions] = None,
|
96
96
|
) -> TokenResponse:
|
@@ -99,8 +99,8 @@ class AsyncAuthClient:
|
|
99
99
|
|
100
100
|
Parameters
|
101
101
|
----------
|
102
|
-
grants : TokenGrantParams
|
103
|
-
The permissions to be granted via the token.
|
102
|
+
grants : typing.Optional[TokenGrantParams]
|
103
|
+
The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
|
104
104
|
|
105
105
|
expires_in : typing.Optional[int]
|
106
106
|
The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
|
@@ -125,7 +125,7 @@ class AsyncAuthClient:
|
|
125
125
|
|
126
126
|
async def main() -> None:
|
127
127
|
await client.auth.access_token(
|
128
|
-
grants={"tts": True},
|
128
|
+
grants={"tts": True, "stt": True},
|
129
129
|
expires_in=60,
|
130
130
|
)
|
131
131
|
|
@@ -1,10 +1,16 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
import typing_extensions
|
4
|
+
import typing_extensions
|
4
5
|
|
5
6
|
|
6
7
|
class TokenGrantParams(typing_extensions.TypedDict):
|
7
|
-
tts: bool
|
8
|
+
tts: typing_extensions.NotRequired[bool]
|
8
9
|
"""
|
9
10
|
The `tts` grant allows the token to be used to access any TTS endpoint.
|
10
11
|
"""
|
12
|
+
|
13
|
+
stt: typing_extensions.NotRequired[bool]
|
14
|
+
"""
|
15
|
+
The `stt` grant allows the token to be used to access any STT endpoint.
|
16
|
+
"""
|
@@ -1,14 +1,14 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
import typing_extensions
|
4
|
-
from .token_grant import TokenGrantParams
|
5
4
|
import typing_extensions
|
5
|
+
from .token_grant import TokenGrantParams
|
6
6
|
|
7
7
|
|
8
8
|
class TokenRequestParams(typing_extensions.TypedDict):
|
9
|
-
grants: TokenGrantParams
|
9
|
+
grants: typing_extensions.NotRequired[TokenGrantParams]
|
10
10
|
"""
|
11
|
-
The permissions to be granted via the token.
|
11
|
+
The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
|
12
12
|
"""
|
13
13
|
|
14
14
|
expires_in: typing_extensions.NotRequired[int]
|
@@ -1,17 +1,22 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
from ...core.pydantic_utilities import UniversalBaseModel
|
4
|
+
import typing
|
4
5
|
import pydantic
|
5
6
|
from ...core.pydantic_utilities import IS_PYDANTIC_V2
|
6
|
-
import typing
|
7
7
|
|
8
8
|
|
9
9
|
class TokenGrant(UniversalBaseModel):
|
10
|
-
tts: bool = pydantic.Field()
|
10
|
+
tts: typing.Optional[bool] = pydantic.Field(default=None)
|
11
11
|
"""
|
12
12
|
The `tts` grant allows the token to be used to access any TTS endpoint.
|
13
13
|
"""
|
14
14
|
|
15
|
+
stt: typing.Optional[bool] = pydantic.Field(default=None)
|
16
|
+
"""
|
17
|
+
The `stt` grant allows the token to be used to access any STT endpoint.
|
18
|
+
"""
|
19
|
+
|
15
20
|
if IS_PYDANTIC_V2:
|
16
21
|
model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
|
17
22
|
else:
|
@@ -1,16 +1,16 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
from ...core.pydantic_utilities import UniversalBaseModel
|
4
|
+
import typing
|
4
5
|
from .token_grant import TokenGrant
|
5
6
|
import pydantic
|
6
|
-
import typing
|
7
7
|
from ...core.pydantic_utilities import IS_PYDANTIC_V2
|
8
8
|
|
9
9
|
|
10
10
|
class TokenRequest(UniversalBaseModel):
|
11
|
-
grants: TokenGrant = pydantic.Field()
|
11
|
+
grants: typing.Optional[TokenGrant] = pydantic.Field(default=None)
|
12
12
|
"""
|
13
|
-
The permissions to be granted via the token.
|
13
|
+
The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
|
14
14
|
"""
|
15
15
|
|
16
16
|
expires_in: typing.Optional[int] = pydantic.Field(default=None)
|
cartesia/base_client.py
CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
|
|
7
7
|
from .api_status.client import ApiStatusClient
|
8
8
|
from .auth.client import AuthClient
|
9
9
|
from .infill.client import InfillClient
|
10
|
+
from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
|
10
11
|
from .tts.client import TtsClient
|
11
12
|
from .voice_changer.client import VoiceChangerClient
|
12
13
|
from .voices.client import VoicesClient
|
@@ -80,6 +81,7 @@ class BaseCartesia:
|
|
80
81
|
self.api_status = ApiStatusClient(client_wrapper=self._client_wrapper)
|
81
82
|
self.auth = AuthClient(client_wrapper=self._client_wrapper)
|
82
83
|
self.infill = InfillClient(client_wrapper=self._client_wrapper)
|
84
|
+
self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
|
83
85
|
self.tts = TtsClient(client_wrapper=self._client_wrapper)
|
84
86
|
self.voice_changer = VoiceChangerClient(client_wrapper=self._client_wrapper)
|
85
87
|
self.voices = VoicesClient(client_wrapper=self._client_wrapper)
|
cartesia/client.py
CHANGED
@@ -8,6 +8,7 @@ import httpx
|
|
8
8
|
|
9
9
|
from .base_client import AsyncBaseCartesia, BaseCartesia
|
10
10
|
from .environment import CartesiaEnvironment
|
11
|
+
from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
|
11
12
|
from .tts.socket_client import AsyncTtsClientWithWebsocket, TtsClientWithWebsocket
|
12
13
|
|
13
14
|
|
@@ -66,6 +67,7 @@ class Cartesia(BaseCartesia):
|
|
66
67
|
follow_redirects=follow_redirects,
|
67
68
|
httpx_client=httpx_client,
|
68
69
|
)
|
70
|
+
self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
|
69
71
|
self.tts = TtsClientWithWebsocket(client_wrapper=self._client_wrapper)
|
70
72
|
|
71
73
|
def __enter__(self):
|
@@ -143,6 +145,9 @@ class AsyncCartesia(AsyncBaseCartesia):
|
|
143
145
|
self._session = None
|
144
146
|
self._loop = None
|
145
147
|
self.max_num_connections = max_num_connections
|
148
|
+
self.stt = AsyncSttClientWithWebsocket(
|
149
|
+
client_wrapper=self._client_wrapper, get_session=self._get_session
|
150
|
+
)
|
146
151
|
self.tts = AsyncTtsClientWithWebsocket(
|
147
152
|
client_wrapper=self._client_wrapper, get_session=self._get_session
|
148
153
|
)
|
cartesia/core/client_wrapper.py
CHANGED
@@ -16,7 +16,7 @@ class BaseClientWrapper:
|
|
16
16
|
headers: typing.Dict[str, str] = {
|
17
17
|
"X-Fern-Language": "Python",
|
18
18
|
"X-Fern-SDK-Name": "cartesia",
|
19
|
-
"X-Fern-SDK-Version": "2.0.
|
19
|
+
"X-Fern-SDK-Version": "2.0.6",
|
20
20
|
}
|
21
21
|
headers["X-API-Key"] = self.api_key
|
22
22
|
headers["Cartesia-Version"] = "2024-11-13"
|
cartesia/stt/__init__.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
from .types import (
|
4
|
+
DoneMessage,
|
5
|
+
ErrorMessage,
|
6
|
+
FlushDoneMessage,
|
7
|
+
StreamingTranscriptionResponse,
|
8
|
+
StreamingTranscriptionResponse_Done,
|
9
|
+
StreamingTranscriptionResponse_Error,
|
10
|
+
StreamingTranscriptionResponse_FlushDone,
|
11
|
+
StreamingTranscriptionResponse_Transcript,
|
12
|
+
SttEncoding,
|
13
|
+
TimestampGranularity,
|
14
|
+
TranscriptMessage,
|
15
|
+
TranscriptionResponse,
|
16
|
+
TranscriptionWord,
|
17
|
+
)
|
18
|
+
from .requests import (
|
19
|
+
DoneMessageParams,
|
20
|
+
ErrorMessageParams,
|
21
|
+
FlushDoneMessageParams,
|
22
|
+
StreamingTranscriptionResponseParams,
|
23
|
+
StreamingTranscriptionResponse_DoneParams,
|
24
|
+
StreamingTranscriptionResponse_ErrorParams,
|
25
|
+
StreamingTranscriptionResponse_FlushDoneParams,
|
26
|
+
StreamingTranscriptionResponse_TranscriptParams,
|
27
|
+
TranscriptMessageParams,
|
28
|
+
TranscriptionResponseParams,
|
29
|
+
TranscriptionWordParams,
|
30
|
+
)
|
31
|
+
|
32
|
+
__all__ = [
|
33
|
+
"DoneMessage",
|
34
|
+
"DoneMessageParams",
|
35
|
+
"ErrorMessage",
|
36
|
+
"ErrorMessageParams",
|
37
|
+
"FlushDoneMessage",
|
38
|
+
"FlushDoneMessageParams",
|
39
|
+
"StreamingTranscriptionResponse",
|
40
|
+
"StreamingTranscriptionResponseParams",
|
41
|
+
"StreamingTranscriptionResponse_Done",
|
42
|
+
"StreamingTranscriptionResponse_DoneParams",
|
43
|
+
"StreamingTranscriptionResponse_Error",
|
44
|
+
"StreamingTranscriptionResponse_ErrorParams",
|
45
|
+
"StreamingTranscriptionResponse_FlushDone",
|
46
|
+
"StreamingTranscriptionResponse_FlushDoneParams",
|
47
|
+
"StreamingTranscriptionResponse_Transcript",
|
48
|
+
"StreamingTranscriptionResponse_TranscriptParams",
|
49
|
+
"SttEncoding",
|
50
|
+
"TimestampGranularity",
|
51
|
+
"TranscriptMessage",
|
52
|
+
"TranscriptMessageParams",
|
53
|
+
"TranscriptionResponse",
|
54
|
+
"TranscriptionResponseParams",
|
55
|
+
"TranscriptionWord",
|
56
|
+
"TranscriptionWordParams",
|
57
|
+
]
|
@@ -0,0 +1,293 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import typing
|
4
|
+
import uuid
|
5
|
+
from typing import Any, Awaitable, AsyncGenerator, Callable, Dict, Optional, Union
|
6
|
+
|
7
|
+
import aiohttp
|
8
|
+
|
9
|
+
from cartesia.stt.types import (
|
10
|
+
StreamingTranscriptionResponse,
|
11
|
+
StreamingTranscriptionResponse_Error,
|
12
|
+
StreamingTranscriptionResponse_Transcript,
|
13
|
+
)
|
14
|
+
from cartesia.stt.types.stt_encoding import SttEncoding
|
15
|
+
|
16
|
+
from ..core.pydantic_utilities import parse_obj_as
|
17
|
+
from ._websocket import SttWebsocket
|
18
|
+
|
19
|
+
|
20
|
+
class AsyncSttWebsocket(SttWebsocket):
    """This class contains methods to transcribe audio using WebSocket asynchronously.

    The socket is opened by ``connect`` (or ``transcribe``). If ``send`` or
    ``receive`` is called while no socket is open, they reconnect using the
    parameters passed to the most recent ``connect`` call, or the defaults set
    in ``__init__`` if ``connect`` was never called.
    """

    # Human-readable explanations for HTTP status codes reported by a failed
    # connection attempt (exceptions exposing a ``status`` attribute).
    _CONNECT_ERROR_MESSAGES: typing.ClassVar[Dict[int, str]] = {
        402: "Payment required. Your API key may have insufficient credits or permissions.",
        401: "Unauthorized. Please check your API key.",
        403: "Forbidden. You don't have permission to access this resource.",
        404: "Not found. The requested resource doesn't exist.",
    }

    # Strong references to close() tasks scheduled from __del__. asyncio only
    # keeps weak references to tasks, so an unreferenced task could be garbage
    # collected before it finishes closing the socket.
    _pending_close_tasks: typing.ClassVar[typing.Set["asyncio.Task[None]"]] = set()

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
        get_session: Callable[[], Awaitable[Optional[aiohttp.ClientSession]]],
        timeout: float = 30,
    ):
        """
        Args:
            ws_url: The WebSocket URL for the Cartesia API.
            api_key: The API key to use for authorization.
            cartesia_version: The version of the Cartesia API to use.
            get_session: A function that returns an awaitable of aiohttp.ClientSession object.
            timeout: The timeout for responses on the WebSocket in seconds.
                NOTE(review): stored on the instance but not used by this class.
        """
        super().__init__(ws_url, api_key, cartesia_version)
        self.timeout = timeout
        self._get_session = get_session
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        # Connection parameters reused when send()/receive() have to reconnect.
        self._default_model: str = "ink-whisper"
        self._default_language: Optional[str] = "en"
        self._default_encoding: SttEncoding = "pcm_s16le"
        self._default_sample_rate: int = 16000
        self._default_min_volume: Optional[float] = None
        self._default_max_silence_duration_secs: Optional[float] = None

    def __del__(self):
        """Best-effort close of a still-open socket when this object is garbage collected.

        Prefer calling ``close()`` explicitly; finalizer timing is not guaranteed.
        """
        # getattr: __init__ may have failed before ``websocket`` was assigned.
        websocket = getattr(self, "websocket", None)
        if websocket is None or websocket.closed:
            # Nothing to release; don't create an event loop just to no-op.
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            asyncio.run(self.close())
        else:
            # get_running_loop() only ever returns a running loop.
            task = loop.create_task(self.close())
            self._pending_close_tasks.add(task)
            task.add_done_callback(self._pending_close_tasks.discard)

    async def connect(
        self,
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: SttEncoding = "pcm_s16le",
        sample_rate: int = 16000,
        min_volume: Optional[float] = None,
        max_silence_duration_secs: Optional[float] = None,
    ):
        """Connect to the STT WebSocket with the specified parameters.

        The parameters are always remembered for later automatic reconnects.
        If a socket is already open, it is kept as-is and no new connection is made.

        Args:
            model: ID of the model to use for transcription.
            language: The language of the input audio in ISO-639-1 format (defaults to "en").
                Omitted from the request when ``None``.
            encoding: The encoding format of the audio data.
            sample_rate: The sample rate of the audio in Hz.
            min_volume: Volume threshold for voice activity detection (0.0-1.0).
            max_silence_duration_secs: Maximum duration of silence before endpointing.

        Raises:
            RuntimeError: If the connection to the WebSocket fails (the original
                exception is attached as ``__cause__``).
        """
        from urllib.parse import urlencode

        self._default_model = model
        self._default_language = language
        self._default_encoding = encoding
        self._default_sample_rate = sample_rate
        self._default_min_volume = min_volume
        self._default_max_silence_duration_secs = max_silence_duration_secs

        if not self._is_websocket_closed():
            return

        route = "stt/websocket"
        session = await self._get_session()

        params: Dict[str, str] = {
            "model": model,
            "api_key": self.api_key,
            "cartesia_version": self.cartesia_version,
            "encoding": encoding,
            "sample_rate": str(sample_rate),
        }
        if language is not None:
            params["language"] = language
        if min_volume is not None:
            params["min_volume"] = str(min_volume)
        if max_silence_duration_secs is not None:
            params["max_silence_duration_secs"] = str(max_silence_duration_secs)

        endpoint = f"{self.ws_url}/{route}"
        # urlencode escapes reserved characters ('&', '=', '+', spaces, ...) that
        # would otherwise corrupt the query string.
        url = f"{endpoint}?{urlencode(params)}"

        try:
            if session is None:
                raise RuntimeError("Session is not available")
            self.websocket = await session.ws_connect(url)
        except Exception as e:
            status_code = getattr(e, "status", None)
            if status_code is not None:
                error_message = self._CONNECT_ERROR_MESSAGES.get(status_code, str(e))
                raise RuntimeError(
                    f"Failed to connect to WebSocket.\nStatus: {status_code}. Error message: {error_message}"
                ) from e
            # Report only the endpoint: the full URL carries the API key in its query string.
            raise RuntimeError(f"Failed to connect to WebSocket at {endpoint}. {e}") from e

    def _is_websocket_closed(self):
        """Return True when there is no socket or the socket has been closed."""
        return self.websocket is None or self.websocket.closed

    async def _connect_if_needed(self) -> aiohttp.ClientWebSocketResponse:
        """Return the open socket, reconnecting with the last-used parameters if needed."""
        if self._is_websocket_closed():
            await self.connect(
                model=self._default_model,
                language=self._default_language,
                encoding=self._default_encoding,
                sample_rate=self._default_sample_rate,
                min_volume=self._default_min_volume,
                max_silence_duration_secs=self._default_max_silence_duration_secs,
            )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.websocket is None:
            raise RuntimeError("WebSocket should be connected after connect() call")
        return self.websocket

    async def close(self):
        """This method closes the websocket connection. Highly recommended to call this method when done."""
        if self.websocket is not None and not self._is_websocket_closed():
            await self.websocket.close()
        self.websocket = None

    async def send(self, data: Union[bytes, str]):
        """Send audio data or control commands to the WebSocket.

        Reconnects with the most recent connection parameters if no socket is open.

        Args:
            data: Binary audio data (``bytes``, ``bytearray`` or ``memoryview``)
                or a text command ("finalize" or "done").

        Raises:
            TypeError: If ``data`` is neither bytes-like nor ``str``.
            RuntimeError: If (re)connecting fails.
        """
        # Validate before touching the network so a bad argument never opens a socket.
        if not isinstance(data, (str, bytes, bytearray, memoryview)):
            raise TypeError("Data must be bytes (audio) or str (command)")

        websocket = await self._connect_if_needed()
        if isinstance(data, str):
            await websocket.send_str(data)
        else:
            await websocket.send_bytes(data)

    async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:  # type: ignore[override]
        """Receive transcription results from the WebSocket.

        Reconnects with the most recent connection parameters if no socket is open.
        Stops after yielding a ``done`` message or when the socket closes.
        On any error, the socket is closed and the exception is re-raised.

        Yields:
            Dictionary containing transcription results, flush_done, or done messages.

        Raises:
            RuntimeError: On an ``error`` message from the server or a WebSocket error frame.
        """
        websocket = await self._connect_if_needed()

        try:
            async for message in websocket:
                if message.type == aiohttp.WSMsgType.TEXT:
                    raw_data = json.loads(message.data)
                    message_type = raw_data.get("type")

                    if message_type == "error":
                        raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")

                    if message_type == "transcript":
                        # Tolerate missing required fields by filling defaults.
                        result: Dict[str, Any] = {
                            "type": raw_data["type"],
                            "request_id": raw_data.get("request_id", ""),
                            "text": raw_data.get("text", ""),
                            "is_final": raw_data.get("is_final", False),
                        }
                        # Optional fields are forwarded only when present.
                        for key in ("duration", "language", "words"):
                            if key in raw_data:
                                result[key] = raw_data[key]
                        yield result

                    elif message_type in ("flush_done", "done"):
                        yield {
                            "type": raw_data["type"],
                            "request_id": raw_data.get("request_id", ""),
                        }
                        if message_type == "done":
                            break  # The server has finished this session.

                elif message.type == aiohttp.WSMsgType.ERROR:
                    raise RuntimeError(f"WebSocket error: {websocket.exception()}")
                elif message.type == aiohttp.WSMsgType.CLOSE:
                    break  # WebSocket was closed
        except Exception:
            await self.close()
            raise

    async def transcribe(  # type: ignore[override]
        self,
        audio_chunks: typing.AsyncIterator[bytes],
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: SttEncoding = "pcm_s16le",
        sample_rate: int = 16000,
        min_volume: Optional[float] = None,
        max_silence_duration_secs: Optional[float] = None,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Transcribe audio chunks using the WebSocket.

        Sends every chunk, then "finalize" and "done", and then yields server
        responses until the "done" acknowledgment. The socket is always closed
        afterwards.

        Args:
            audio_chunks: Async iterator of audio chunks as bytes.
            model: ID of the model to use for transcription.
            language: The language of the input audio in ISO-639-1 format (defaults to "en").
            encoding: The encoding format of the audio data.
            sample_rate: The sample rate of the audio in Hz.
            min_volume: Volume threshold for voice activity detection (0.0-1.0).
            max_silence_duration_secs: Maximum duration of silence before endpointing.

        Yields:
            Dictionary containing transcription results, flush_done, or done messages.
        """
        await self.connect(
            model=model,
            language=language,
            encoding=encoding,
            sample_rate=sample_rate,
            min_volume=min_volume,
            max_silence_duration_secs=max_silence_duration_secs,
        )

        try:
            async for chunk in audio_chunks:
                await self.send(chunk)

            # Flush any buffered audio, then ask the server to end the session cleanly.
            await self.send("finalize")
            await self.send("done")

            async for result in self.receive():
                yield result
        finally:
            await self.close()
|