cartesia 2.0.4__py3-none-any.whl → 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. cartesia/__init__.py +60 -1
  2. cartesia/auth/client.py +8 -8
  3. cartesia/auth/requests/token_grant.py +7 -1
  4. cartesia/auth/requests/token_request.py +3 -3
  5. cartesia/auth/types/token_grant.py +7 -2
  6. cartesia/auth/types/token_request.py +3 -3
  7. cartesia/base_client.py +2 -0
  8. cartesia/client.py +5 -0
  9. cartesia/core/client_wrapper.py +1 -1
  10. cartesia/stt/__init__.py +57 -0
  11. cartesia/stt/_async_websocket.py +293 -0
  12. cartesia/stt/_websocket.py +294 -0
  13. cartesia/stt/client.py +456 -0
  14. cartesia/stt/requests/__init__.py +29 -0
  15. cartesia/stt/requests/done_message.py +14 -0
  16. cartesia/stt/requests/error_message.py +16 -0
  17. cartesia/stt/requests/flush_done_message.py +14 -0
  18. cartesia/stt/requests/streaming_transcription_response.py +41 -0
  19. cartesia/stt/requests/transcript_message.py +40 -0
  20. cartesia/stt/requests/transcription_response.py +28 -0
  21. cartesia/stt/requests/transcription_word.py +20 -0
  22. cartesia/stt/socket_client.py +138 -0
  23. cartesia/stt/types/__init__.py +33 -0
  24. cartesia/stt/types/done_message.py +26 -0
  25. cartesia/stt/types/error_message.py +27 -0
  26. cartesia/stt/types/flush_done_message.py +26 -0
  27. cartesia/stt/types/streaming_transcription_response.py +94 -0
  28. cartesia/stt/types/stt_encoding.py +7 -0
  29. cartesia/stt/types/timestamp_granularity.py +5 -0
  30. cartesia/stt/types/transcript_message.py +50 -0
  31. cartesia/stt/types/transcription_response.py +38 -0
  32. cartesia/stt/types/transcription_word.py +32 -0
  33. cartesia/tts/__init__.py +8 -0
  34. cartesia/tts/client.py +50 -8
  35. cartesia/tts/requests/__init__.py +4 -0
  36. cartesia/tts/requests/generation_request.py +4 -4
  37. cartesia/tts/requests/sse_output_format.py +11 -0
  38. cartesia/tts/requests/ttssse_request.py +47 -0
  39. cartesia/tts/requests/web_socket_chunk_response.py +0 -3
  40. cartesia/tts/requests/web_socket_response.py +1 -2
  41. cartesia/tts/requests/web_socket_tts_request.py +9 -1
  42. cartesia/tts/types/__init__.py +4 -0
  43. cartesia/tts/types/generation_request.py +4 -4
  44. cartesia/tts/types/sse_output_format.py +22 -0
  45. cartesia/tts/types/ttssse_request.py +58 -0
  46. cartesia/tts/types/web_socket_chunk_response.py +1 -3
  47. cartesia/tts/types/web_socket_response.py +1 -2
  48. cartesia/tts/types/web_socket_tts_request.py +11 -3
  49. cartesia/voice_changer/requests/streaming_response.py +0 -2
  50. cartesia/voice_changer/types/streaming_response.py +0 -2
  51. {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/METADATA +256 -2
  52. {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/RECORD +53 -26
  53. {cartesia-2.0.4.dist-info → cartesia-2.0.6.dist-info}/WHEEL +0 -0
cartesia/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # This file was auto-generated by Fern from our API Definition.

- from . import api_status, auth, datasets, embedding, infill, tts, voice_changer, voices
+ from . import api_status, auth, datasets, embedding, infill, stt, tts, voice_changer, voices
  from .api_status import ApiInfo, ApiInfoParams
  from .auth import TokenGrant, TokenGrantParams, TokenRequest, TokenRequestParams, TokenResponse, TokenResponseParams
  from .client import AsyncCartesia, Cartesia
@@ -19,6 +19,32 @@ from .datasets import (
  )
  from .embedding import Embedding
  from .environment import CartesiaEnvironment
+ from .stt import (
+     DoneMessage,
+     DoneMessageParams,
+     ErrorMessage,
+     ErrorMessageParams,
+     FlushDoneMessage,
+     FlushDoneMessageParams,
+     StreamingTranscriptionResponse,
+     StreamingTranscriptionResponseParams,
+     StreamingTranscriptionResponse_Done,
+     StreamingTranscriptionResponse_DoneParams,
+     StreamingTranscriptionResponse_Error,
+     StreamingTranscriptionResponse_ErrorParams,
+     StreamingTranscriptionResponse_FlushDone,
+     StreamingTranscriptionResponse_FlushDoneParams,
+     StreamingTranscriptionResponse_Transcript,
+     StreamingTranscriptionResponse_TranscriptParams,
+     SttEncoding,
+     TimestampGranularity,
+     TranscriptMessage,
+     TranscriptMessageParams,
+     TranscriptionResponse,
+     TranscriptionResponseParams,
+     TranscriptionWord,
+     TranscriptionWordParams,
+ )
  from .tts import (
      CancelContextRequest,
      CancelContextRequestParams,
@@ -49,6 +75,8 @@ from .tts import (
      RawOutputFormatParams,
      Speed,
      SpeedParams,
+     SseOutputFormat,
+     SseOutputFormatParams,
      SupportedLanguage,
      TtsRequest,
      TtsRequestEmbeddingSpecifier,
@@ -58,6 +86,8 @@ from .tts import (
      TtsRequestParams,
      TtsRequestVoiceSpecifier,
      TtsRequestVoiceSpecifierParams,
+     TtssseRequest,
+     TtssseRequestParams,
      WavOutputFormat,
      WavOutputFormatParams,
      WebSocketBaseResponse,
@@ -173,13 +203,19 @@ __all__ = [
      "DatasetFile",
      "DatasetFileParams",
      "DatasetParams",
+     "DoneMessage",
+     "DoneMessageParams",
      "Embedding",
      "EmbeddingResponse",
      "EmbeddingResponseParams",
      "EmbeddingSpecifier",
      "EmbeddingSpecifierParams",
      "Emotion",
+     "ErrorMessage",
+     "ErrorMessageParams",
      "FilePurpose",
+     "FlushDoneMessage",
+     "FlushDoneMessageParams",
      "FlushId",
      "Gender",
      "GenderPresentation",
@@ -227,6 +263,8 @@ __all__ = [
      "RawOutputFormatParams",
      "Speed",
      "SpeedParams",
+     "SseOutputFormat",
+     "SseOutputFormatParams",
      "StreamingResponse",
      "StreamingResponseParams",
      "StreamingResponse_Chunk",
@@ -235,13 +273,31 @@ __all__ = [
      "StreamingResponse_DoneParams",
      "StreamingResponse_Error",
      "StreamingResponse_ErrorParams",
+     "StreamingTranscriptionResponse",
+     "StreamingTranscriptionResponseParams",
+     "StreamingTranscriptionResponse_Done",
+     "StreamingTranscriptionResponse_DoneParams",
+     "StreamingTranscriptionResponse_Error",
+     "StreamingTranscriptionResponse_ErrorParams",
+     "StreamingTranscriptionResponse_FlushDone",
+     "StreamingTranscriptionResponse_FlushDoneParams",
+     "StreamingTranscriptionResponse_Transcript",
+     "StreamingTranscriptionResponse_TranscriptParams",
+     "SttEncoding",
      "SupportedLanguage",
+     "TimestampGranularity",
      "TokenGrant",
      "TokenGrantParams",
      "TokenRequest",
      "TokenRequestParams",
      "TokenResponse",
      "TokenResponseParams",
+     "TranscriptMessage",
+     "TranscriptMessageParams",
+     "TranscriptionResponse",
+     "TranscriptionResponseParams",
+     "TranscriptionWord",
+     "TranscriptionWordParams",
      "TtsRequest",
      "TtsRequestEmbeddingSpecifier",
      "TtsRequestEmbeddingSpecifierParams",
@@ -250,6 +306,8 @@ __all__ = [
      "TtsRequestParams",
      "TtsRequestVoiceSpecifier",
      "TtsRequestVoiceSpecifierParams",
+     "TtssseRequest",
+     "TtssseRequestParams",
      "UpdateVoiceRequest",
      "UpdateVoiceRequestParams",
      "Voice",
@@ -307,6 +365,7 @@ __all__ = [
      "datasets",
      "embedding",
      "infill",
+     "stt",
      "tts",
      "voice_changer",
      "voices",
cartesia/auth/client.py CHANGED
@@ -22,7 +22,7 @@ class AuthClient:
      def access_token(
          self,
          *,
-         grants: TokenGrantParams,
+         grants: typing.Optional[TokenGrantParams] = OMIT,
          expires_in: typing.Optional[int] = OMIT,
          request_options: typing.Optional[RequestOptions] = None,
      ) -> TokenResponse:
@@ -31,8 +31,8 @@

          Parameters
          ----------
-         grants : TokenGrantParams
-             The permissions to be granted via the token.
+         grants : typing.Optional[TokenGrantParams]
+             The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.

          expires_in : typing.Optional[int]
              The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
@@ -52,7 +52,7 @@
              api_key="YOUR_API_KEY",
          )
          client.auth.access_token(
-             grants={"tts": True},
+             grants={"tts": True, "stt": True},
              expires_in=60,
          )
          """
@@ -90,7 +90,7 @@ class AsyncAuthClient:
      async def access_token(
          self,
          *,
-         grants: TokenGrantParams,
+         grants: typing.Optional[TokenGrantParams] = OMIT,
          expires_in: typing.Optional[int] = OMIT,
          request_options: typing.Optional[RequestOptions] = None,
      ) -> TokenResponse:
@@ -99,8 +99,8 @@

          Parameters
          ----------
-         grants : TokenGrantParams
-             The permissions to be granted via the token.
+         grants : typing.Optional[TokenGrantParams]
+             The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.

          expires_in : typing.Optional[int]
              The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
@@ -125,7 +125,7 @@

          async def main() -> None:
              await client.auth.access_token(
-                 grants={"tts": True},
+                 grants={"tts": True, "stt": True},
                  expires_in=60,
              )

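Because grants is now optional and each capability flag inside it is NotRequired, a token can be scoped to a single capability. A small sketch against the updated signature (key and values are illustrative):

    from cartesia import Cartesia

    client = Cartesia(api_key="YOUR_API_KEY")

    # Request a token that can only call STT endpoints; "tts" is simply omitted.
    token = client.auth.access_token(
        grants={"stt": True},
        expires_in=3600,  # maximum lifetime per the docstring above (1 hour)
    )
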
cartesia/auth/requests/token_grant.py CHANGED
@@ -1,10 +1,16 @@
  # This file was auto-generated by Fern from our API Definition.

  import typing_extensions
+ import typing_extensions


  class TokenGrantParams(typing_extensions.TypedDict):
-     tts: bool
+     tts: typing_extensions.NotRequired[bool]
      """
      The `tts` grant allows the token to be used to access any TTS endpoint.
      """
+
+     stt: typing_extensions.NotRequired[bool]
+     """
+     The `stt` grant allows the token to be used to access any STT endpoint.
+     """
cartesia/auth/requests/token_request.py CHANGED
@@ -1,14 +1,14 @@
  # This file was auto-generated by Fern from our API Definition.

  import typing_extensions
- from .token_grant import TokenGrantParams
  import typing_extensions
+ from .token_grant import TokenGrantParams


  class TokenRequestParams(typing_extensions.TypedDict):
-     grants: TokenGrantParams
+     grants: typing_extensions.NotRequired[TokenGrantParams]
      """
-     The permissions to be granted via the token.
+     The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
      """

      expires_in: typing_extensions.NotRequired[int]
cartesia/auth/types/token_grant.py CHANGED
@@ -1,17 +1,22 @@
  # This file was auto-generated by Fern from our API Definition.

  from ...core.pydantic_utilities import UniversalBaseModel
+ import typing
  import pydantic
  from ...core.pydantic_utilities import IS_PYDANTIC_V2
- import typing


  class TokenGrant(UniversalBaseModel):
-     tts: bool = pydantic.Field()
+     tts: typing.Optional[bool] = pydantic.Field(default=None)
      """
      The `tts` grant allows the token to be used to access any TTS endpoint.
      """

+     stt: typing.Optional[bool] = pydantic.Field(default=None)
+     """
+     The `stt` grant allows the token to be used to access any STT endpoint.
+     """
+
      if IS_PYDANTIC_V2:
          model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
      else:
cartesia/auth/types/token_request.py CHANGED
@@ -1,16 +1,16 @@
  # This file was auto-generated by Fern from our API Definition.

  from ...core.pydantic_utilities import UniversalBaseModel
+ import typing
  from .token_grant import TokenGrant
  import pydantic
- import typing
  from ...core.pydantic_utilities import IS_PYDANTIC_V2


  class TokenRequest(UniversalBaseModel):
-     grants: TokenGrant = pydantic.Field()
+     grants: typing.Optional[TokenGrant] = pydantic.Field(default=None)
      """
-     The permissions to be granted via the token.
+     The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
      """

      expires_in: typing.Optional[int] = pydantic.Field(default=None)
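
The pydantic models mirror the request types: both fields on TokenGrant default to None, and TokenRequest.grants itself may be omitted. A brief sketch of what now validates (values are illustrative):

    from cartesia.auth import TokenGrant, TokenRequest

    stt_only = TokenGrant(stt=True)                        # tts is omitted and defaults to None
    request = TokenRequest(grants=stt_only, expires_in=60)
    bare = TokenRequest(expires_in=60)                     # grants itself is optional too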
cartesia/base_client.py CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
  from .api_status.client import ApiStatusClient
  from .auth.client import AuthClient
  from .infill.client import InfillClient
+ from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
  from .tts.client import TtsClient
  from .voice_changer.client import VoiceChangerClient
  from .voices.client import VoicesClient
@@ -80,6 +81,7 @@ class BaseCartesia:
          self.api_status = ApiStatusClient(client_wrapper=self._client_wrapper)
          self.auth = AuthClient(client_wrapper=self._client_wrapper)
          self.infill = InfillClient(client_wrapper=self._client_wrapper)
+         self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
          self.tts = TtsClient(client_wrapper=self._client_wrapper)
          self.voice_changer = VoiceChangerClient(client_wrapper=self._client_wrapper)
          self.voices = VoicesClient(client_wrapper=self._client_wrapper)
cartesia/client.py CHANGED
@@ -8,6 +8,7 @@ import httpx

  from .base_client import AsyncBaseCartesia, BaseCartesia
  from .environment import CartesiaEnvironment
+ from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
  from .tts.socket_client import AsyncTtsClientWithWebsocket, TtsClientWithWebsocket


@@ -66,6 +67,7 @@ class Cartesia(BaseCartesia):
              follow_redirects=follow_redirects,
              httpx_client=httpx_client,
          )
+         self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
          self.tts = TtsClientWithWebsocket(client_wrapper=self._client_wrapper)

      def __enter__(self):
@@ -143,6 +145,9 @@ class AsyncCartesia(AsyncBaseCartesia):
          self._session = None
          self._loop = None
          self.max_num_connections = max_num_connections
+         self.stt = AsyncSttClientWithWebsocket(
+             client_wrapper=self._client_wrapper, get_session=self._get_session
+         )
          self.tts = AsyncTtsClientWithWebsocket(
              client_wrapper=self._client_wrapper, get_session=self._get_session
          )
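
With this wiring in place, both the sync and async clients expose an stt attribute next to tts. A quick sketch of where the new clients hang off the SDK objects:

    from cartesia import AsyncCartesia, Cartesia

    client = Cartesia(api_key="YOUR_API_KEY")
    assert client.stt is not None            # SttClientWithWebsocket

    async_client = AsyncCartesia(api_key="YOUR_API_KEY")
    assert async_client.stt is not None      # AsyncSttClientWithWebsocket, sharing the aiohttp session getter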
cartesia/core/client_wrapper.py CHANGED
@@ -16,7 +16,7 @@ class BaseClientWrapper:
          headers: typing.Dict[str, str] = {
              "X-Fern-Language": "Python",
              "X-Fern-SDK-Name": "cartesia",
-             "X-Fern-SDK-Version": "2.0.4",
+             "X-Fern-SDK-Version": "2.0.6",
          }
          headers["X-API-Key"] = self.api_key
          headers["Cartesia-Version"] = "2024-11-13"
cartesia/stt/__init__.py ADDED
@@ -0,0 +1,57 @@
+ # This file was auto-generated by Fern from our API Definition.
+
+ from .types import (
+     DoneMessage,
+     ErrorMessage,
+     FlushDoneMessage,
+     StreamingTranscriptionResponse,
+     StreamingTranscriptionResponse_Done,
+     StreamingTranscriptionResponse_Error,
+     StreamingTranscriptionResponse_FlushDone,
+     StreamingTranscriptionResponse_Transcript,
+     SttEncoding,
+     TimestampGranularity,
+     TranscriptMessage,
+     TranscriptionResponse,
+     TranscriptionWord,
+ )
+ from .requests import (
+     DoneMessageParams,
+     ErrorMessageParams,
+     FlushDoneMessageParams,
+     StreamingTranscriptionResponseParams,
+     StreamingTranscriptionResponse_DoneParams,
+     StreamingTranscriptionResponse_ErrorParams,
+     StreamingTranscriptionResponse_FlushDoneParams,
+     StreamingTranscriptionResponse_TranscriptParams,
+     TranscriptMessageParams,
+     TranscriptionResponseParams,
+     TranscriptionWordParams,
+ )
+
+ __all__ = [
+     "DoneMessage",
+     "DoneMessageParams",
+     "ErrorMessage",
+     "ErrorMessageParams",
+     "FlushDoneMessage",
+     "FlushDoneMessageParams",
+     "StreamingTranscriptionResponse",
+     "StreamingTranscriptionResponseParams",
+     "StreamingTranscriptionResponse_Done",
+     "StreamingTranscriptionResponse_DoneParams",
+     "StreamingTranscriptionResponse_Error",
+     "StreamingTranscriptionResponse_ErrorParams",
+     "StreamingTranscriptionResponse_FlushDone",
+     "StreamingTranscriptionResponse_FlushDoneParams",
+     "StreamingTranscriptionResponse_Transcript",
+     "StreamingTranscriptionResponse_TranscriptParams",
+     "SttEncoding",
+     "TimestampGranularity",
+     "TranscriptMessage",
+     "TranscriptMessageParams",
+     "TranscriptionResponse",
+     "TranscriptionResponseParams",
+     "TranscriptionWord",
+     "TranscriptionWordParams",
+ ]
cartesia/stt/_async_websocket.py ADDED
@@ -0,0 +1,293 @@
+ import asyncio
+ import json
+ import typing
+ import uuid
+ from typing import Any, Awaitable, AsyncGenerator, Callable, Dict, Optional, Union
+
+ import aiohttp
+
+ from cartesia.stt.types import (
+     StreamingTranscriptionResponse,
+     StreamingTranscriptionResponse_Error,
+     StreamingTranscriptionResponse_Transcript,
+ )
+ from cartesia.stt.types.stt_encoding import SttEncoding
+
+ from ..core.pydantic_utilities import parse_obj_as
+ from ._websocket import SttWebsocket
+
+
+ class AsyncSttWebsocket(SttWebsocket):
+     """This class contains methods to transcribe audio using WebSocket asynchronously."""
+
+     def __init__(
+         self,
+         ws_url: str,
+         api_key: str,
+         cartesia_version: str,
+         get_session: Callable[[], Awaitable[Optional[aiohttp.ClientSession]]],
+         timeout: float = 30,
+     ):
+         """
+         Args:
+             ws_url: The WebSocket URL for the Cartesia API.
+             api_key: The API key to use for authorization.
+             cartesia_version: The version of the Cartesia API to use.
+             timeout: The timeout for responses on the WebSocket in seconds.
+             get_session: A function that returns an awaitable of aiohttp.ClientSession object.
+         """
+         super().__init__(ws_url, api_key, cartesia_version)
+         self.timeout = timeout
+         self._get_session = get_session
+         self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
+         self._default_model: str = "ink-whisper"
+         self._default_language: Optional[str] = "en"
+         self._default_encoding: SttEncoding = "pcm_s16le"
+         self._default_sample_rate: int = 16000
+         self._default_min_volume: Optional[float] = None
+         self._default_max_silence_duration_secs: Optional[float] = None
+
+     def __del__(self):
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             loop = None
+
+         if loop is None:
+             asyncio.run(self.close())
+         elif loop.is_running():
+             loop.create_task(self.close())
+
+     async def connect(
+         self,
+         *,
+         model: str = "ink-whisper",
+         language: Optional[str] = "en",
+         encoding: SttEncoding = "pcm_s16le",
+         sample_rate: int = 16000,
+         min_volume: Optional[float] = None,
+         max_silence_duration_secs: Optional[float] = None,
+     ):
+         """Connect to the STT WebSocket with the specified parameters.
+
+         Args:
+             model: ID of the model to use for transcription (required)
+             language: The language of the input audio in ISO-639-1 format (defaults to "en")
+             encoding: The encoding format of the audio data (required)
+             sample_rate: The sample rate of the audio in Hz (required)
+             min_volume: Volume threshold for voice activity detection (0.0-1.0)
+             max_silence_duration_secs: Maximum duration of silence before endpointing
+
+         Raises:
+             RuntimeError: If the connection to the WebSocket fails.
+         """
+         self._default_model = model
+         self._default_language = language
+         self._default_encoding = encoding
+         self._default_sample_rate = sample_rate
+         self._default_min_volume = min_volume
+         self._default_max_silence_duration_secs = max_silence_duration_secs
+
+         if self.websocket is None or self._is_websocket_closed():
+             route = "stt/websocket"
+             session = await self._get_session()
+
+             params = {
+                 "model": model,
+                 "api_key": self.api_key,
+                 "cartesia_version": self.cartesia_version,
+                 "encoding": encoding,
+                 "sample_rate": str(sample_rate),
+             }
+             if language is not None:
+                 params["language"] = language
+             if min_volume is not None:
+                 params["min_volume"] = str(min_volume)
+             if max_silence_duration_secs is not None:
+                 params["max_silence_duration_secs"] = str(max_silence_duration_secs)
+
+             query_string = "&".join([f"{k}={v}" for k, v in params.items()])
+             url = f"{self.ws_url}/{route}?{query_string}"
+
+             try:
+                 if session is None:
+                     raise RuntimeError("Session is not available")
+                 self.websocket = await session.ws_connect(url)
+             except Exception as e:
+                 status_code = None
+                 error_message = str(e)
+
+                 if hasattr(e, 'status') and e.status is not None:
+                     status_code = e.status
+
+                     if status_code == 402:
+                         error_message = "Payment required. Your API key may have insufficient credits or permissions."
+                     elif status_code == 401:
+                         error_message = "Unauthorized. Please check your API key."
+                     elif status_code == 403:
+                         error_message = "Forbidden. You don't have permission to access this resource."
+                     elif status_code == 404:
+                         error_message = "Not found. The requested resource doesn't exist."
+
+                     raise RuntimeError(f"Failed to connect to WebSocket.\nStatus: {status_code}. Error message: {error_message}")
+                 else:
+                     raise RuntimeError(f"Failed to connect to WebSocket at {url}. {e}")
+
+     def _is_websocket_closed(self):
+         return self.websocket is None or self.websocket.closed
+
+     async def close(self):
+         """This method closes the websocket connection. Highly recommended to call this method when done."""
+         if self.websocket is not None and not self._is_websocket_closed():
+             await self.websocket.close()
+             self.websocket = None
+
+     async def send(self, data: Union[bytes, str]):
+         """Send audio data or control commands to the WebSocket.
+
+         Args:
+             data: Binary audio data or text command ("finalize" or "done")
+         """
+         if self.websocket is None or self._is_websocket_closed():
+             await self.connect(
+                 model=self._default_model,
+                 language=self._default_language,
+                 encoding=self._default_encoding,
+                 sample_rate=self._default_sample_rate,
+                 min_volume=self._default_min_volume,
+                 max_silence_duration_secs=self._default_max_silence_duration_secs,
+             )
+
+         assert self.websocket is not None, "WebSocket should be connected after connect() call"
+
+         if isinstance(data, bytes):
+             await self.websocket.send_bytes(data)
+         elif isinstance(data, str):
+             await self.websocket.send_str(data)
+         else:
+             raise TypeError("Data must be bytes (audio) or str (command)")
+
+     async def receive(self) -> AsyncGenerator[Dict[str, Any], None]: # type: ignore[override]
+         """Receive transcription results from the WebSocket.
+
+         Yields:
+             Dictionary containing transcription results, flush_done, done, or error messages
+         """
+         if self.websocket is None or self._is_websocket_closed():
+             await self.connect(
+                 model=self._default_model,
+                 language=self._default_language,
+                 encoding=self._default_encoding,
+                 sample_rate=self._default_sample_rate,
+                 min_volume=self._default_min_volume,
+                 max_silence_duration_secs=self._default_max_silence_duration_secs,
+             )
+
+         assert self.websocket is not None, "WebSocket should be connected after connect() call"
+
+         try:
+             async for message in self.websocket:
+                 if message.type == aiohttp.WSMsgType.TEXT:
+                     raw_data = json.loads(message.data)
+
+                     # Handle error responses
+                     if raw_data.get("type") == "error":
+                         raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")
+
+                     # Handle transcript responses with flexible parsing
+                     if raw_data.get("type") == "transcript":
+                         # Provide defaults for missing required fields
+                         result = {
+                             "type": raw_data["type"],
+                             "request_id": raw_data.get("request_id", ""),
+                             "text": raw_data.get("text", ""), # Default to empty string if missing
+                             "is_final": raw_data.get("is_final", False), # Default to False if missing
+                         }
+
+                         # Add optional fields if present
+                         if "duration" in raw_data:
+                             result["duration"] = raw_data["duration"]
+                         if "language" in raw_data:
+                             result["language"] = raw_data["language"]
+                         if "words" in raw_data:
+                             result["words"] = raw_data["words"]
+
+                         yield result
+
+                     # Handle flush_done acknowledgment
+                     elif raw_data.get("type") == "flush_done":
+                         result = {
+                             "type": raw_data["type"],
+                             "request_id": raw_data.get("request_id", ""),
+                         }
+                         yield result
+
+                     # Handle done acknowledgment
+                     elif raw_data.get("type") == "done":
+                         result = {
+                             "type": raw_data["type"],
+                             "request_id": raw_data.get("request_id", ""),
+                         }
+                         yield result
+                         break # Exit the loop when done
+
+                 elif message.type == aiohttp.WSMsgType.ERROR:
+                     error_message = f"WebSocket error: {self.websocket.exception()}"
+                     raise RuntimeError(error_message)
+                 elif message.type == aiohttp.WSMsgType.CLOSE:
+                     break # WebSocket was closed
+         except Exception as e:
+             await self.close()
+             raise e
+
+     async def transcribe( # type: ignore[override]
+         self,
+         audio_chunks: typing.AsyncIterator[bytes],
+         *,
+         model: str = "ink-whisper",
+         language: Optional[str] = "en",
+         encoding: SttEncoding = "pcm_s16le",
+         sample_rate: int = 16000,
+         min_volume: Optional[float] = None,
+         max_silence_duration_secs: Optional[float] = None,
+     ) -> AsyncGenerator[Dict[str, Any], None]:
+         """Transcribe audio chunks using the WebSocket.
+
+         Args:
+             audio_chunks: Async iterator of audio chunks as bytes
+             model: ID of the model to use for transcription (required)
+             language: The language of the input audio in ISO-639-1 format (defaults to "en")
+             encoding: The encoding format of the audio data (required)
+             sample_rate: The sample rate of the audio in Hz (required)
+             min_volume: Volume threshold for voice activity detection (0.0-1.0)
+             max_silence_duration_secs: Maximum duration of silence before endpointing
+
+         Yields:
+             Dictionary containing transcription results, flush_done, done, or error messages
+         """
+         await self.connect(
+             model=model,
+             language=language,
+             encoding=encoding,
+             sample_rate=sample_rate,
+             min_volume=min_volume,
+             max_silence_duration_secs=max_silence_duration_secs,
+         )
+
+         try:
+             # Send all audio chunks
+             async for chunk in audio_chunks:
+                 await self.send(chunk)
+
+             # Send finalize command to flush remaining audio
+             await self.send("finalize")
+
+             # Send done command to close session cleanly
+             await self.send("done")
+
+             # Receive all responses until done
+             async for result in self.receive():
+                 yield result
+
+         finally:
+             await self.close()
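
Taken together, transcribe() drives a full session: it connects with the given parameters, streams every chunk, sends the finalize and done commands, then yields messages until the server acknowledges done. A hedged usage sketch follows; obtaining the AsyncSttWebsocket through client.stt.websocket() is an assumption based on the socket_client module added in this release (its contents are not shown in this diff), and the file name and chunk size are illustrative:

    import asyncio
    from cartesia import AsyncCartesia

    async def pcm_chunks():
        # Hypothetical source of raw pcm_s16le mono audio at 16 kHz.
        with open("speech.pcm", "rb") as f:
            while chunk := f.read(3200):  # roughly 100 ms of 16-bit audio at 16 kHz
                yield chunk

    async def main() -> None:
        client = AsyncCartesia(api_key="YOUR_API_KEY")
        ws = await client.stt.websocket()  # assumed helper on AsyncSttClientWithWebsocket
        async for message in ws.transcribe(
            pcm_chunks(),
            model="ink-whisper",
            encoding="pcm_s16le",
            sample_rate=16000,
        ):
            if message["type"] == "transcript":
                print(message["text"], "(final)" if message["is_final"] else "")

    asyncio.run(main())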