cartesia 2.0.5__py3-none-any.whl → 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. cartesia/__init__.py +14 -0
  2. cartesia/auth/client.py +8 -8
  3. cartesia/auth/requests/token_grant.py +7 -1
  4. cartesia/auth/requests/token_request.py +3 -3
  5. cartesia/auth/types/token_grant.py +7 -2
  6. cartesia/auth/types/token_request.py +3 -3
  7. cartesia/core/client_wrapper.py +1 -1
  8. cartesia/stt/__init__.py +6 -0
  9. cartesia/stt/_async_websocket.py +81 -72
  10. cartesia/stt/_websocket.py +42 -20
  11. cartesia/stt/client.py +456 -0
  12. cartesia/stt/requests/__init__.py +2 -0
  13. cartesia/stt/requests/streaming_transcription_response.py +2 -0
  14. cartesia/stt/requests/transcript_message.py +8 -1
  15. cartesia/stt/requests/transcription_response.py +8 -1
  16. cartesia/stt/requests/transcription_word.py +20 -0
  17. cartesia/stt/socket_client.py +52 -109
  18. cartesia/stt/types/__init__.py +4 -0
  19. cartesia/stt/types/streaming_transcription_response.py +2 -0
  20. cartesia/stt/types/stt_encoding.py +3 -1
  21. cartesia/stt/types/timestamp_granularity.py +5 -0
  22. cartesia/stt/types/transcript_message.py +7 -1
  23. cartesia/stt/types/transcription_response.py +7 -1
  24. cartesia/stt/types/transcription_word.py +32 -0
  25. cartesia/tts/__init__.py +8 -0
  26. cartesia/tts/client.py +50 -8
  27. cartesia/tts/requests/__init__.py +4 -0
  28. cartesia/tts/requests/generation_request.py +4 -4
  29. cartesia/tts/requests/sse_output_format.py +11 -0
  30. cartesia/tts/requests/ttssse_request.py +47 -0
  31. cartesia/tts/requests/web_socket_chunk_response.py +0 -3
  32. cartesia/tts/requests/web_socket_response.py +1 -2
  33. cartesia/tts/requests/web_socket_tts_request.py +9 -1
  34. cartesia/tts/types/__init__.py +4 -0
  35. cartesia/tts/types/generation_request.py +4 -4
  36. cartesia/tts/types/sse_output_format.py +22 -0
  37. cartesia/tts/types/ttssse_request.py +58 -0
  38. cartesia/tts/types/web_socket_chunk_response.py +1 -3
  39. cartesia/tts/types/web_socket_response.py +1 -2
  40. cartesia/tts/types/web_socket_tts_request.py +11 -3
  41. cartesia/voice_changer/requests/streaming_response.py +0 -2
  42. cartesia/voice_changer/types/streaming_response.py +0 -2
  43. {cartesia-2.0.5.dist-info → cartesia-2.0.6.dist-info}/METADATA +113 -16
  44. {cartesia-2.0.5.dist-info → cartesia-2.0.6.dist-info}/RECORD +45 -37
  45. {cartesia-2.0.5.dist-info → cartesia-2.0.6.dist-info}/WHEEL +0 -0
cartesia/__init__.py CHANGED
@@ -37,10 +37,13 @@ from .stt import (
37
37
  StreamingTranscriptionResponse_Transcript,
38
38
  StreamingTranscriptionResponse_TranscriptParams,
39
39
  SttEncoding,
40
+ TimestampGranularity,
40
41
  TranscriptMessage,
41
42
  TranscriptMessageParams,
42
43
  TranscriptionResponse,
43
44
  TranscriptionResponseParams,
45
+ TranscriptionWord,
46
+ TranscriptionWordParams,
44
47
  )
45
48
  from .tts import (
46
49
  CancelContextRequest,
@@ -72,6 +75,8 @@ from .tts import (
72
75
  RawOutputFormatParams,
73
76
  Speed,
74
77
  SpeedParams,
78
+ SseOutputFormat,
79
+ SseOutputFormatParams,
75
80
  SupportedLanguage,
76
81
  TtsRequest,
77
82
  TtsRequestEmbeddingSpecifier,
@@ -81,6 +86,8 @@ from .tts import (
81
86
  TtsRequestParams,
82
87
  TtsRequestVoiceSpecifier,
83
88
  TtsRequestVoiceSpecifierParams,
89
+ TtssseRequest,
90
+ TtssseRequestParams,
84
91
  WavOutputFormat,
85
92
  WavOutputFormatParams,
86
93
  WebSocketBaseResponse,
@@ -256,6 +263,8 @@ __all__ = [
256
263
  "RawOutputFormatParams",
257
264
  "Speed",
258
265
  "SpeedParams",
266
+ "SseOutputFormat",
267
+ "SseOutputFormatParams",
259
268
  "StreamingResponse",
260
269
  "StreamingResponseParams",
261
270
  "StreamingResponse_Chunk",
@@ -276,6 +285,7 @@ __all__ = [
276
285
  "StreamingTranscriptionResponse_TranscriptParams",
277
286
  "SttEncoding",
278
287
  "SupportedLanguage",
288
+ "TimestampGranularity",
279
289
  "TokenGrant",
280
290
  "TokenGrantParams",
281
291
  "TokenRequest",
@@ -286,6 +296,8 @@ __all__ = [
286
296
  "TranscriptMessageParams",
287
297
  "TranscriptionResponse",
288
298
  "TranscriptionResponseParams",
299
+ "TranscriptionWord",
300
+ "TranscriptionWordParams",
289
301
  "TtsRequest",
290
302
  "TtsRequestEmbeddingSpecifier",
291
303
  "TtsRequestEmbeddingSpecifierParams",
@@ -294,6 +306,8 @@ __all__ = [
294
306
  "TtsRequestParams",
295
307
  "TtsRequestVoiceSpecifier",
296
308
  "TtsRequestVoiceSpecifierParams",
309
+ "TtssseRequest",
310
+ "TtssseRequestParams",
297
311
  "UpdateVoiceRequest",
298
312
  "UpdateVoiceRequestParams",
299
313
  "Voice",
cartesia/auth/client.py CHANGED
@@ -22,7 +22,7 @@ class AuthClient:
22
22
  def access_token(
23
23
  self,
24
24
  *,
25
- grants: TokenGrantParams,
25
+ grants: typing.Optional[TokenGrantParams] = OMIT,
26
26
  expires_in: typing.Optional[int] = OMIT,
27
27
  request_options: typing.Optional[RequestOptions] = None,
28
28
  ) -> TokenResponse:
@@ -31,8 +31,8 @@ class AuthClient:
31
31
 
32
32
  Parameters
33
33
  ----------
34
- grants : TokenGrantParams
35
- The permissions to be granted via the token.
34
+ grants : typing.Optional[TokenGrantParams]
35
+ The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
36
36
 
37
37
  expires_in : typing.Optional[int]
38
38
  The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
@@ -52,7 +52,7 @@ class AuthClient:
52
52
  api_key="YOUR_API_KEY",
53
53
  )
54
54
  client.auth.access_token(
55
- grants={"tts": True},
55
+ grants={"tts": True, "stt": True},
56
56
  expires_in=60,
57
57
  )
58
58
  """
@@ -90,7 +90,7 @@ class AsyncAuthClient:
90
90
  async def access_token(
91
91
  self,
92
92
  *,
93
- grants: TokenGrantParams,
93
+ grants: typing.Optional[TokenGrantParams] = OMIT,
94
94
  expires_in: typing.Optional[int] = OMIT,
95
95
  request_options: typing.Optional[RequestOptions] = None,
96
96
  ) -> TokenResponse:
@@ -99,8 +99,8 @@ class AsyncAuthClient:
99
99
 
100
100
  Parameters
101
101
  ----------
102
- grants : TokenGrantParams
103
- The permissions to be granted via the token.
102
+ grants : typing.Optional[TokenGrantParams]
103
+ The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
104
104
 
105
105
  expires_in : typing.Optional[int]
106
106
  The number of seconds the token will be valid for since the time of generation. The maximum is 1 hour (3600 seconds).
@@ -125,7 +125,7 @@ class AsyncAuthClient:
125
125
 
126
126
  async def main() -> None:
127
127
  await client.auth.access_token(
128
- grants={"tts": True},
128
+ grants={"tts": True, "stt": True},
129
129
  expires_in=60,
130
130
  )
131
131
 
cartesia/auth/requests/token_grant.py CHANGED
@@ -1,10 +1,16 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  import typing_extensions
4
+ import typing_extensions
4
5
 
5
6
 
6
7
  class TokenGrantParams(typing_extensions.TypedDict):
7
- tts: bool
8
+ tts: typing_extensions.NotRequired[bool]
8
9
  """
9
10
  The `tts` grant allows the token to be used to access any TTS endpoint.
10
11
  """
12
+
13
+ stt: typing_extensions.NotRequired[bool]
14
+ """
15
+ The `stt` grant allows the token to be used to access any STT endpoint.
16
+ """
cartesia/auth/requests/token_request.py CHANGED
@@ -1,14 +1,14 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  import typing_extensions
4
- from .token_grant import TokenGrantParams
5
4
  import typing_extensions
5
+ from .token_grant import TokenGrantParams
6
6
 
7
7
 
8
8
  class TokenRequestParams(typing_extensions.TypedDict):
9
- grants: TokenGrantParams
9
+ grants: typing_extensions.NotRequired[TokenGrantParams]
10
10
  """
11
- The permissions to be granted via the token.
11
+ The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
12
12
  """
13
13
 
14
14
  expires_in: typing_extensions.NotRequired[int]
cartesia/auth/types/token_grant.py CHANGED
@@ -1,17 +1,22 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  from ...core.pydantic_utilities import UniversalBaseModel
4
+ import typing
4
5
  import pydantic
5
6
  from ...core.pydantic_utilities import IS_PYDANTIC_V2
6
- import typing
7
7
 
8
8
 
9
9
  class TokenGrant(UniversalBaseModel):
10
- tts: bool = pydantic.Field()
10
+ tts: typing.Optional[bool] = pydantic.Field(default=None)
11
11
  """
12
12
  The `tts` grant allows the token to be used to access any TTS endpoint.
13
13
  """
14
14
 
15
+ stt: typing.Optional[bool] = pydantic.Field(default=None)
16
+ """
17
+ The `stt` grant allows the token to be used to access any STT endpoint.
18
+ """
19
+
15
20
  if IS_PYDANTIC_V2:
16
21
  model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True) # type: ignore # Pydantic v2
17
22
  else:
cartesia/auth/types/token_request.py CHANGED
@@ -1,16 +1,16 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  from ...core.pydantic_utilities import UniversalBaseModel
4
+ import typing
4
5
  from .token_grant import TokenGrant
5
6
  import pydantic
6
- import typing
7
7
  from ...core.pydantic_utilities import IS_PYDANTIC_V2
8
8
 
9
9
 
10
10
  class TokenRequest(UniversalBaseModel):
11
- grants: TokenGrant = pydantic.Field()
11
+ grants: typing.Optional[TokenGrant] = pydantic.Field(default=None)
12
12
  """
13
- The permissions to be granted via the token.
13
+ The permissions to be granted via the token. Both TTS and STT grants are optional - specify only the capabilities you need.
14
14
  """
15
15
 
16
16
  expires_in: typing.Optional[int] = pydantic.Field(default=None)
cartesia/core/client_wrapper.py CHANGED
@@ -16,7 +16,7 @@ class BaseClientWrapper:
16
16
  headers: typing.Dict[str, str] = {
17
17
  "X-Fern-Language": "Python",
18
18
  "X-Fern-SDK-Name": "cartesia",
19
- "X-Fern-SDK-Version": "2.0.5",
19
+ "X-Fern-SDK-Version": "2.0.6",
20
20
  }
21
21
  headers["X-API-Key"] = self.api_key
22
22
  headers["Cartesia-Version"] = "2024-11-13"
cartesia/stt/__init__.py CHANGED
@@ -10,8 +10,10 @@ from .types import (
10
10
  StreamingTranscriptionResponse_FlushDone,
11
11
  StreamingTranscriptionResponse_Transcript,
12
12
  SttEncoding,
13
+ TimestampGranularity,
13
14
  TranscriptMessage,
14
15
  TranscriptionResponse,
16
+ TranscriptionWord,
15
17
  )
16
18
  from .requests import (
17
19
  DoneMessageParams,
@@ -24,6 +26,7 @@ from .requests import (
24
26
  StreamingTranscriptionResponse_TranscriptParams,
25
27
  TranscriptMessageParams,
26
28
  TranscriptionResponseParams,
29
+ TranscriptionWordParams,
27
30
  )
28
31
 
29
32
  __all__ = [
@@ -44,8 +47,11 @@ __all__ = [
44
47
  "StreamingTranscriptionResponse_Transcript",
45
48
  "StreamingTranscriptionResponse_TranscriptParams",
46
49
  "SttEncoding",
50
+ "TimestampGranularity",
47
51
  "TranscriptMessage",
48
52
  "TranscriptMessageParams",
49
53
  "TranscriptionResponse",
50
54
  "TranscriptionResponseParams",
55
+ "TranscriptionWord",
56
+ "TranscriptionWordParams",
51
57
  ]
cartesia/stt/_async_websocket.py CHANGED
@@ -11,6 +11,7 @@ from cartesia.stt.types import (
11
11
  StreamingTranscriptionResponse_Error,
12
12
  StreamingTranscriptionResponse_Transcript,
13
13
  )
14
+ from cartesia.stt.types.stt_encoding import SttEncoding
14
15
 
15
16
  from ..core.pydantic_utilities import parse_obj_as
16
17
  from ._websocket import SttWebsocket
@@ -41,8 +42,10 @@ class AsyncSttWebsocket(SttWebsocket):
41
42
  self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
42
43
  self._default_model: str = "ink-whisper"
43
44
  self._default_language: Optional[str] = "en"
44
- self._default_encoding: Optional[str] = "pcm_s16le"
45
+ self._default_encoding: SttEncoding = "pcm_s16le"
45
46
  self._default_sample_rate: int = 16000
47
+ self._default_min_volume: Optional[float] = None
48
+ self._default_max_silence_duration_secs: Optional[float] = None
46
49
 
47
50
  def __del__(self):
48
51
  try:
@@ -60,16 +63,20 @@ class AsyncSttWebsocket(SttWebsocket):
60
63
  *,
61
64
  model: str = "ink-whisper",
62
65
  language: Optional[str] = "en",
63
- encoding: Optional[str] = "pcm_s16le",
66
+ encoding: SttEncoding = "pcm_s16le",
64
67
  sample_rate: int = 16000,
68
+ min_volume: Optional[float] = None,
69
+ max_silence_duration_secs: Optional[float] = None,
65
70
  ):
66
71
  """Connect to the STT WebSocket with the specified parameters.
67
72
 
68
73
  Args:
69
- model: ID of the model to use for transcription
70
- language: The language of the input audio in ISO-639-1 format
71
- encoding: The encoding format of the audio data
72
- sample_rate: The sample rate of the audio in Hz
74
+ model: ID of the model to use for transcription (required)
75
+ language: The language of the input audio in ISO-639-1 format (defaults to "en")
76
+ encoding: The encoding format of the audio data (required)
77
+ sample_rate: The sample rate of the audio in Hz (required)
78
+ min_volume: Volume threshold for voice activity detection (0.0-1.0)
79
+ max_silence_duration_secs: Maximum duration of silence before endpointing
73
80
 
74
81
  Raises:
75
82
  RuntimeError: If the connection to the WebSocket fails.
@@ -78,6 +85,8 @@ class AsyncSttWebsocket(SttWebsocket):
78
85
  self._default_language = language
79
86
  self._default_encoding = encoding
80
87
  self._default_sample_rate = sample_rate
88
+ self._default_min_volume = min_volume
89
+ self._default_max_silence_duration_secs = max_silence_duration_secs
81
90
 
82
91
  if self.websocket is None or self._is_websocket_closed():
83
92
  route = "stt/websocket"
@@ -87,13 +96,15 @@ class AsyncSttWebsocket(SttWebsocket):
87
96
  "model": model,
88
97
  "api_key": self.api_key,
89
98
  "cartesia_version": self.cartesia_version,
99
+ "encoding": encoding,
100
+ "sample_rate": str(sample_rate),
90
101
  }
91
102
  if language is not None:
92
103
  params["language"] = language
93
- if encoding is not None:
94
- params["encoding"] = encoding
95
- if sample_rate is not None:
96
- params["sample_rate"] = str(sample_rate)
104
+ if min_volume is not None:
105
+ params["min_volume"] = str(min_volume)
106
+ if max_silence_duration_secs is not None:
107
+ params["max_silence_duration_secs"] = str(max_silence_duration_secs)
97
108
 
98
109
  query_string = "&".join([f"{k}={v}" for k, v in params.items()])
99
110
  url = f"{self.ws_url}/{route}?{query_string}"
@@ -143,6 +154,8 @@ class AsyncSttWebsocket(SttWebsocket):
143
154
  language=self._default_language,
144
155
  encoding=self._default_encoding,
145
156
  sample_rate=self._default_sample_rate,
157
+ min_volume=self._default_min_volume,
158
+ max_silence_duration_secs=self._default_max_silence_duration_secs,
146
159
  )
147
160
 
148
161
  assert self.websocket is not None, "WebSocket should be connected after connect() call"
@@ -166,76 +179,66 @@ class AsyncSttWebsocket(SttWebsocket):
166
179
  language=self._default_language,
167
180
  encoding=self._default_encoding,
168
181
  sample_rate=self._default_sample_rate,
182
+ min_volume=self._default_min_volume,
183
+ max_silence_duration_secs=self._default_max_silence_duration_secs,
169
184
  )
170
185
 
171
186
  assert self.websocket is not None, "WebSocket should be connected after connect() call"
172
187
 
173
188
  try:
174
- while True:
175
- try:
176
- msg = await asyncio.wait_for(self.websocket.receive(), timeout=self.timeout)
189
+ async for message in self.websocket:
190
+ if message.type == aiohttp.WSMsgType.TEXT:
191
+ raw_data = json.loads(message.data)
177
192
 
178
- if msg.type == aiohttp.WSMsgType.TEXT:
179
- raw_data = json.loads(msg.data)
180
-
181
- # Handle error responses
182
- if raw_data.get("type") == "error":
183
- raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")
184
-
185
- # Handle transcript responses with flexible parsing
186
- if raw_data.get("type") == "transcript":
187
- # Provide defaults for missing required fields
188
- result = {
189
- "type": raw_data["type"],
190
- "request_id": raw_data.get("request_id", ""),
191
- "text": raw_data.get("text", ""), # Default to empty string if missing
192
- "is_final": raw_data.get("is_final", False), # Default to False if missing
193
- }
194
-
195
- # Add optional fields if present
196
- if "duration" in raw_data:
197
- result["duration"] = raw_data["duration"]
198
- if "language" in raw_data:
199
- result["language"] = raw_data["language"]
200
-
201
- yield result
193
+ # Handle error responses
194
+ if raw_data.get("type") == "error":
195
+ raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")
196
+
197
+ # Handle transcript responses with flexible parsing
198
+ if raw_data.get("type") == "transcript":
199
+ # Provide defaults for missing required fields
200
+ result = {
201
+ "type": raw_data["type"],
202
+ "request_id": raw_data.get("request_id", ""),
203
+ "text": raw_data.get("text", ""), # Default to empty string if missing
204
+ "is_final": raw_data.get("is_final", False), # Default to False if missing
205
+ }
202
206
 
203
- # Handle flush_done acknowledgment
204
- elif raw_data.get("type") == "flush_done":
205
- result = {
206
- "type": raw_data["type"],
207
- "request_id": raw_data.get("request_id", ""),
208
- }
209
- yield result
207
+ # Add optional fields if present
208
+ if "duration" in raw_data:
209
+ result["duration"] = raw_data["duration"]
210
+ if "language" in raw_data:
211
+ result["language"] = raw_data["language"]
212
+ if "words" in raw_data:
213
+ result["words"] = raw_data["words"]
210
214
 
211
- # Handle done acknowledgment - session complete
212
- elif raw_data.get("type") == "done":
213
- result = {
214
- "type": raw_data["type"],
215
- "request_id": raw_data.get("request_id", ""),
216
- }
217
- yield result
218
- # Session is complete, break out of loop
219
- break
215
+ yield result
220
216
 
221
- elif msg.type == aiohttp.WSMsgType.ERROR:
222
- websocket_exception = self.websocket.exception() if self.websocket else None
223
- await self.close()
224
- raise RuntimeError(f"WebSocket error: {websocket_exception}")
217
+ # Handle flush_done acknowledgment
218
+ elif raw_data.get("type") == "flush_done":
219
+ result = {
220
+ "type": raw_data["type"],
221
+ "request_id": raw_data.get("request_id", ""),
222
+ }
223
+ yield result
225
224
 
226
- elif msg.type == aiohttp.WSMsgType.CLOSE:
227
- break
225
+ # Handle done acknowledgment
226
+ elif raw_data.get("type") == "done":
227
+ result = {
228
+ "type": raw_data["type"],
229
+ "request_id": raw_data.get("request_id", ""),
230
+ }
231
+ yield result
232
+ break # Exit the loop when done
228
233
 
229
- except asyncio.TimeoutError:
230
- await self.close()
231
- raise RuntimeError("Timeout while waiting for transcription")
232
- except Exception as inner_e:
233
- await self.close()
234
- raise RuntimeError(f"Error receiving transcription: {inner_e}")
235
-
234
+ elif message.type == aiohttp.WSMsgType.ERROR:
235
+ error_message = f"WebSocket error: {self.websocket.exception()}"
236
+ raise RuntimeError(error_message)
237
+ elif message.type == aiohttp.WSMsgType.CLOSE:
238
+ break # WebSocket was closed
236
239
  except Exception as e:
237
240
  await self.close()
238
- raise RuntimeError(f"Failed to receive transcription. {e}")
241
+ raise e
239
242
 
240
243
  async def transcribe( # type: ignore[override]
241
244
  self,
@@ -243,17 +246,21 @@ class AsyncSttWebsocket(SttWebsocket):
243
246
  *,
244
247
  model: str = "ink-whisper",
245
248
  language: Optional[str] = "en",
246
- encoding: Optional[str] = "pcm_s16le",
249
+ encoding: SttEncoding = "pcm_s16le",
247
250
  sample_rate: int = 16000,
251
+ min_volume: Optional[float] = None,
252
+ max_silence_duration_secs: Optional[float] = None,
248
253
  ) -> AsyncGenerator[Dict[str, Any], None]:
249
254
  """Transcribe audio chunks using the WebSocket.
250
255
 
251
256
  Args:
252
257
  audio_chunks: Async iterator of audio chunks as bytes
253
- model: ID of the model to use for transcription
254
- language: The language of the input audio in ISO-639-1 format
255
- encoding: The encoding format of the audio data
256
- sample_rate: The sample rate of the audio in Hz
258
+ model: ID of the model to use for transcription (required)
259
+ language: The language of the input audio in ISO-639-1 format (defaults to "en")
260
+ encoding: The encoding format of the audio data (required)
261
+ sample_rate: The sample rate of the audio in Hz (required)
262
+ min_volume: Volume threshold for voice activity detection (0.0-1.0)
263
+ max_silence_duration_secs: Maximum duration of silence before endpointing
257
264
 
258
265
  Yields:
259
266
  Dictionary containing transcription results, flush_done, done, or error messages
@@ -263,6 +270,8 @@ class AsyncSttWebsocket(SttWebsocket):
263
270
  language=language,
264
271
  encoding=encoding,
265
272
  sample_rate=sample_rate,
273
+ min_volume=min_volume,
274
+ max_silence_duration_secs=max_silence_duration_secs,
266
275
  )
267
276
 
268
277
  try:
cartesia/stt/_websocket.py CHANGED
@@ -14,6 +14,7 @@ from cartesia.stt.types import (
14
14
  StreamingTranscriptionResponse_Error,
15
15
  StreamingTranscriptionResponse_Transcript,
16
16
  )
17
+ from cartesia.stt.types.stt_encoding import SttEncoding
17
18
 
18
19
  from ..core.pydantic_utilities import parse_obj_as
19
20
 
@@ -45,8 +46,10 @@ class SttWebsocket:
45
46
  # Store default connection parameters for auto-connect with proper typing
46
47
  self._default_model: str = "ink-whisper"
47
48
  self._default_language: Optional[str] = "en"
48
- self._default_encoding: Optional[str] = "pcm_s16le"
49
+ self._default_encoding: SttEncoding = "pcm_s16le"
49
50
  self._default_sample_rate: int = 16000
51
+ self._default_min_volume: Optional[float] = None
52
+ self._default_max_silence_duration_secs: Optional[float] = None
50
53
 
51
54
  def __del__(self):
52
55
  try:
@@ -59,16 +62,20 @@ class SttWebsocket:
59
62
  *,
60
63
  model: str = "ink-whisper",
61
64
  language: Optional[str] = "en",
62
- encoding: Optional[str] = "pcm_s16le",
65
+ encoding: SttEncoding = "pcm_s16le",
63
66
  sample_rate: int = 16000,
67
+ min_volume: Optional[float] = None,
68
+ max_silence_duration_secs: Optional[float] = None,
64
69
  ):
65
70
  """Connect to the STT WebSocket with the specified parameters.
66
71
 
67
72
  Args:
68
73
  model: ID of the model to use for transcription
69
74
  language: The language of the input audio in ISO-639-1 format
70
- encoding: The encoding format of the audio data
71
- sample_rate: The sample rate of the audio in Hz
75
+ encoding: The encoding format of the audio data (required)
76
+ sample_rate: The sample rate of the audio in Hz (required)
77
+ min_volume: Volume threshold for voice activity detection (0.0-1.0)
78
+ max_silence_duration_secs: Maximum duration of silence before endpointing
72
79
 
73
80
  Raises:
74
81
  RuntimeError: If the connection to the WebSocket fails.
@@ -78,6 +85,8 @@ class SttWebsocket:
78
85
  self._default_language = language
79
86
  self._default_encoding = encoding
80
87
  self._default_sample_rate = sample_rate
88
+ self._default_min_volume = min_volume
89
+ self._default_max_silence_duration_secs = max_silence_duration_secs
81
90
 
82
91
  if not IS_WEBSOCKET_SYNC_AVAILABLE:
83
92
  raise ImportError(
@@ -89,13 +98,15 @@ class SttWebsocket:
89
98
  "model": model,
90
99
  "api_key": self.api_key,
91
100
  "cartesia_version": self.cartesia_version,
101
+ "encoding": encoding,
102
+ "sample_rate": str(sample_rate),
92
103
  }
93
104
  if language is not None:
94
105
  params["language"] = language
95
- if encoding is not None:
96
- params["encoding"] = encoding
97
- if sample_rate is not None:
98
- params["sample_rate"] = str(sample_rate)
106
+ if min_volume is not None:
107
+ params["min_volume"] = str(min_volume)
108
+ if max_silence_duration_secs is not None:
109
+ params["max_silence_duration_secs"] = str(max_silence_duration_secs)
99
110
 
100
111
  query_string = "&".join([f"{k}={v}" for k, v in params.items()])
101
112
  url = f"{self.ws_url}/{route}?{query_string}"
@@ -143,6 +154,8 @@ class SttWebsocket:
143
154
  language=self._default_language,
144
155
  encoding=self._default_encoding,
145
156
  sample_rate=self._default_sample_rate,
157
+ min_volume=self._default_min_volume,
158
+ max_silence_duration_secs=self._default_max_silence_duration_secs,
146
159
  )
147
160
 
148
161
  assert self.websocket is not None, "WebSocket should be connected after connect() call"
@@ -167,6 +180,8 @@ class SttWebsocket:
167
180
  language=self._default_language,
168
181
  encoding=self._default_encoding,
169
182
  sample_rate=self._default_sample_rate,
183
+ min_volume=self._default_min_volume,
184
+ max_silence_duration_secs=self._default_max_silence_duration_secs,
170
185
  )
171
186
 
172
187
  assert self.websocket is not None, "WebSocket should be connected after connect() call"
@@ -197,6 +212,8 @@ class SttWebsocket:
197
212
  result["duration"] = raw_data["duration"]
198
213
  if "language" in raw_data:
199
214
  result["language"] = raw_data["language"]
215
+ if "words" in raw_data:
216
+ result["words"] = raw_data["words"]
200
217
 
201
218
  yield result
202
219
 
@@ -208,23 +225,22 @@ class SttWebsocket:
208
225
  }
209
226
  yield result
210
227
 
211
- # Handle done acknowledgment - session complete
228
+ # Handle done acknowledgment
212
229
  elif raw_data.get("type") == "done":
213
230
  result = {
214
231
  "type": raw_data["type"],
215
232
  "request_id": raw_data.get("request_id", ""),
216
233
  }
217
234
  yield result
218
- # Session is complete, break out of loop
219
- break
220
-
221
- except Exception as inner_e:
222
- self.close()
223
- raise RuntimeError(f"Error receiving transcription: {inner_e}")
235
+ break # Exit the loop when done
224
236
 
225
- except Exception as e:
237
+ except Exception as e:
238
+ if "Connection closed" in str(e) or "no active connection" in str(e):
239
+ break # WebSocket was closed
240
+ raise e # Re-raise other exceptions
241
+ except KeyboardInterrupt:
226
242
  self.close()
227
- raise RuntimeError(f"Failed to receive transcription. {e}")
243
+ raise
228
244
 
229
245
  def transcribe(
230
246
  self,
@@ -232,8 +248,10 @@ class SttWebsocket:
232
248
  *,
233
249
  model: str = "ink-whisper",
234
250
  language: Optional[str] = "en",
235
- encoding: Optional[str] = "pcm_s16le",
251
+ encoding: SttEncoding = "pcm_s16le",
236
252
  sample_rate: int = 16000,
253
+ min_volume: Optional[float] = None,
254
+ max_silence_duration_secs: Optional[float] = None,
237
255
  ) -> Generator[Dict[str, Any], None, None]:
238
256
  """Transcribe audio chunks using the WebSocket.
239
257
 
@@ -241,8 +259,10 @@ class SttWebsocket:
241
259
  audio_chunks: Iterator of audio chunks as bytes
242
260
  model: ID of the model to use for transcription
243
261
  language: The language of the input audio in ISO-639-1 format
244
- encoding: The encoding format of the audio data
245
- sample_rate: The sample rate of the audio in Hz
262
+ encoding: The encoding format of the audio data (required)
263
+ sample_rate: The sample rate of the audio in Hz (required)
264
+ min_volume: Volume threshold for voice activity detection (0.0-1.0)
265
+ max_silence_duration_secs: Maximum duration of silence before endpointing
246
266
 
247
267
  Yields:
248
268
  Dictionary containing transcription results, flush_done, done, or error messages
@@ -252,6 +272,8 @@ class SttWebsocket:
252
272
  language=language,
253
273
  encoding=encoding,
254
274
  sample_rate=sample_rate,
275
+ min_volume=min_volume,
276
+ max_silence_duration_secs=max_silence_duration_secs,
255
277
  )
256
278
 
257
279
  try: