cartesia 2.0.4__py3-none-any.whl → 2.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cartesia/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
- from . import api_status, auth, datasets, embedding, infill, tts, voice_changer, voices
3
+ from . import api_status, auth, datasets, embedding, infill, stt, tts, voice_changer, voices
4
4
  from .api_status import ApiInfo, ApiInfoParams
5
5
  from .auth import TokenGrant, TokenGrantParams, TokenRequest, TokenRequestParams, TokenResponse, TokenResponseParams
6
6
  from .client import AsyncCartesia, Cartesia
@@ -19,6 +19,29 @@ from .datasets import (
19
19
  )
20
20
  from .embedding import Embedding
21
21
  from .environment import CartesiaEnvironment
22
+ from .stt import (
23
+ DoneMessage,
24
+ DoneMessageParams,
25
+ ErrorMessage,
26
+ ErrorMessageParams,
27
+ FlushDoneMessage,
28
+ FlushDoneMessageParams,
29
+ StreamingTranscriptionResponse,
30
+ StreamingTranscriptionResponseParams,
31
+ StreamingTranscriptionResponse_Done,
32
+ StreamingTranscriptionResponse_DoneParams,
33
+ StreamingTranscriptionResponse_Error,
34
+ StreamingTranscriptionResponse_ErrorParams,
35
+ StreamingTranscriptionResponse_FlushDone,
36
+ StreamingTranscriptionResponse_FlushDoneParams,
37
+ StreamingTranscriptionResponse_Transcript,
38
+ StreamingTranscriptionResponse_TranscriptParams,
39
+ SttEncoding,
40
+ TranscriptMessage,
41
+ TranscriptMessageParams,
42
+ TranscriptionResponse,
43
+ TranscriptionResponseParams,
44
+ )
22
45
  from .tts import (
23
46
  CancelContextRequest,
24
47
  CancelContextRequestParams,
@@ -173,13 +196,19 @@ __all__ = [
173
196
  "DatasetFile",
174
197
  "DatasetFileParams",
175
198
  "DatasetParams",
199
+ "DoneMessage",
200
+ "DoneMessageParams",
176
201
  "Embedding",
177
202
  "EmbeddingResponse",
178
203
  "EmbeddingResponseParams",
179
204
  "EmbeddingSpecifier",
180
205
  "EmbeddingSpecifierParams",
181
206
  "Emotion",
207
+ "ErrorMessage",
208
+ "ErrorMessageParams",
182
209
  "FilePurpose",
210
+ "FlushDoneMessage",
211
+ "FlushDoneMessageParams",
183
212
  "FlushId",
184
213
  "Gender",
185
214
  "GenderPresentation",
@@ -235,6 +264,17 @@ __all__ = [
235
264
  "StreamingResponse_DoneParams",
236
265
  "StreamingResponse_Error",
237
266
  "StreamingResponse_ErrorParams",
267
+ "StreamingTranscriptionResponse",
268
+ "StreamingTranscriptionResponseParams",
269
+ "StreamingTranscriptionResponse_Done",
270
+ "StreamingTranscriptionResponse_DoneParams",
271
+ "StreamingTranscriptionResponse_Error",
272
+ "StreamingTranscriptionResponse_ErrorParams",
273
+ "StreamingTranscriptionResponse_FlushDone",
274
+ "StreamingTranscriptionResponse_FlushDoneParams",
275
+ "StreamingTranscriptionResponse_Transcript",
276
+ "StreamingTranscriptionResponse_TranscriptParams",
277
+ "SttEncoding",
238
278
  "SupportedLanguage",
239
279
  "TokenGrant",
240
280
  "TokenGrantParams",
@@ -242,6 +282,10 @@ __all__ = [
242
282
  "TokenRequestParams",
243
283
  "TokenResponse",
244
284
  "TokenResponseParams",
285
+ "TranscriptMessage",
286
+ "TranscriptMessageParams",
287
+ "TranscriptionResponse",
288
+ "TranscriptionResponseParams",
245
289
  "TtsRequest",
246
290
  "TtsRequestEmbeddingSpecifier",
247
291
  "TtsRequestEmbeddingSpecifierParams",
@@ -307,6 +351,7 @@ __all__ = [
307
351
  "datasets",
308
352
  "embedding",
309
353
  "infill",
354
+ "stt",
310
355
  "tts",
311
356
  "voice_changer",
312
357
  "voices",
cartesia/base_client.py CHANGED
@@ -7,6 +7,7 @@ from .core.client_wrapper import SyncClientWrapper
7
7
  from .api_status.client import ApiStatusClient
8
8
  from .auth.client import AuthClient
9
9
  from .infill.client import InfillClient
10
+ from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
10
11
  from .tts.client import TtsClient
11
12
  from .voice_changer.client import VoiceChangerClient
12
13
  from .voices.client import VoicesClient
@@ -80,6 +81,7 @@ class BaseCartesia:
80
81
  self.api_status = ApiStatusClient(client_wrapper=self._client_wrapper)
81
82
  self.auth = AuthClient(client_wrapper=self._client_wrapper)
82
83
  self.infill = InfillClient(client_wrapper=self._client_wrapper)
84
+ self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
83
85
  self.tts = TtsClient(client_wrapper=self._client_wrapper)
84
86
  self.voice_changer = VoiceChangerClient(client_wrapper=self._client_wrapper)
85
87
  self.voices = VoicesClient(client_wrapper=self._client_wrapper)
cartesia/client.py CHANGED
@@ -8,6 +8,7 @@ import httpx
8
8
 
9
9
  from .base_client import AsyncBaseCartesia, BaseCartesia
10
10
  from .environment import CartesiaEnvironment
11
+ from .stt.socket_client import AsyncSttClientWithWebsocket, SttClientWithWebsocket
11
12
  from .tts.socket_client import AsyncTtsClientWithWebsocket, TtsClientWithWebsocket
12
13
 
13
14
 
@@ -66,6 +67,7 @@ class Cartesia(BaseCartesia):
66
67
  follow_redirects=follow_redirects,
67
68
  httpx_client=httpx_client,
68
69
  )
70
+ self.stt = SttClientWithWebsocket(client_wrapper=self._client_wrapper)
69
71
  self.tts = TtsClientWithWebsocket(client_wrapper=self._client_wrapper)
70
72
 
71
73
  def __enter__(self):
@@ -143,6 +145,9 @@ class AsyncCartesia(AsyncBaseCartesia):
143
145
  self._session = None
144
146
  self._loop = None
145
147
  self.max_num_connections = max_num_connections
148
+ self.stt = AsyncSttClientWithWebsocket(
149
+ client_wrapper=self._client_wrapper, get_session=self._get_session
150
+ )
146
151
  self.tts = AsyncTtsClientWithWebsocket(
147
152
  client_wrapper=self._client_wrapper, get_session=self._get_session
148
153
  )
@@ -16,7 +16,7 @@ class BaseClientWrapper:
16
16
  headers: typing.Dict[str, str] = {
17
17
  "X-Fern-Language": "Python",
18
18
  "X-Fern-SDK-Name": "cartesia",
19
- "X-Fern-SDK-Version": "2.0.4",
19
+ "X-Fern-SDK-Version": "2.0.5",
20
20
  }
21
21
  headers["X-API-Key"] = self.api_key
22
22
  headers["Cartesia-Version"] = "2024-11-13"
@@ -0,0 +1,51 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ from .types import (
4
+ DoneMessage,
5
+ ErrorMessage,
6
+ FlushDoneMessage,
7
+ StreamingTranscriptionResponse,
8
+ StreamingTranscriptionResponse_Done,
9
+ StreamingTranscriptionResponse_Error,
10
+ StreamingTranscriptionResponse_FlushDone,
11
+ StreamingTranscriptionResponse_Transcript,
12
+ SttEncoding,
13
+ TranscriptMessage,
14
+ TranscriptionResponse,
15
+ )
16
+ from .requests import (
17
+ DoneMessageParams,
18
+ ErrorMessageParams,
19
+ FlushDoneMessageParams,
20
+ StreamingTranscriptionResponseParams,
21
+ StreamingTranscriptionResponse_DoneParams,
22
+ StreamingTranscriptionResponse_ErrorParams,
23
+ StreamingTranscriptionResponse_FlushDoneParams,
24
+ StreamingTranscriptionResponse_TranscriptParams,
25
+ TranscriptMessageParams,
26
+ TranscriptionResponseParams,
27
+ )
28
+
29
+ __all__ = [
30
+ "DoneMessage",
31
+ "DoneMessageParams",
32
+ "ErrorMessage",
33
+ "ErrorMessageParams",
34
+ "FlushDoneMessage",
35
+ "FlushDoneMessageParams",
36
+ "StreamingTranscriptionResponse",
37
+ "StreamingTranscriptionResponseParams",
38
+ "StreamingTranscriptionResponse_Done",
39
+ "StreamingTranscriptionResponse_DoneParams",
40
+ "StreamingTranscriptionResponse_Error",
41
+ "StreamingTranscriptionResponse_ErrorParams",
42
+ "StreamingTranscriptionResponse_FlushDone",
43
+ "StreamingTranscriptionResponse_FlushDoneParams",
44
+ "StreamingTranscriptionResponse_Transcript",
45
+ "StreamingTranscriptionResponse_TranscriptParams",
46
+ "SttEncoding",
47
+ "TranscriptMessage",
48
+ "TranscriptMessageParams",
49
+ "TranscriptionResponse",
50
+ "TranscriptionResponseParams",
51
+ ]
@@ -0,0 +1,284 @@
1
+ import asyncio
2
+ import json
3
+ import typing
4
+ import uuid
5
+ from typing import Any, Awaitable, AsyncGenerator, Callable, Dict, Optional, Union
6
+
7
+ import aiohttp
8
+
9
+ from cartesia.stt.types import (
10
+ StreamingTranscriptionResponse,
11
+ StreamingTranscriptionResponse_Error,
12
+ StreamingTranscriptionResponse_Transcript,
13
+ )
14
+
15
+ from ..core.pydantic_utilities import parse_obj_as
16
+ from ._websocket import SttWebsocket
17
+
18
+
19
class AsyncSttWebsocket(SttWebsocket):
    """This class contains methods to transcribe audio using WebSocket asynchronously.

    The connection is opened lazily (by ``connect``, ``send``, ``receive`` or
    ``transcribe``) using a session obtained from ``get_session`` and is reused
    until it is closed.
    """

    # Friendlier messages for common handshake failures, keyed by HTTP status.
    _STATUS_ERROR_MESSAGES: Dict[int, str] = {
        402: "Payment required. Your API key may have insufficient credits or permissions.",
        401: "Unauthorized. Please check your API key.",
        403: "Forbidden. You don't have permission to access this resource.",
        404: "Not found. The requested resource doesn't exist.",
    }

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
        get_session: Callable[[], Awaitable[Optional[aiohttp.ClientSession]]],
        timeout: float = 30,
    ):
        """
        Args:
            ws_url: The WebSocket URL for the Cartesia API.
            api_key: The API key to use for authorization.
            cartesia_version: The version of the Cartesia API to use.
            get_session: A function that returns an awaitable of an aiohttp.ClientSession
                (or None when no session is available).
            timeout: Maximum number of seconds ``receive`` waits for each WebSocket message.
        """
        super().__init__(ws_url, api_key, cartesia_version)
        self.timeout = timeout
        self._get_session = get_session
        self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
        # Parameters remembered by connect() so send()/receive() can transparently
        # reconnect with the same settings.
        self._default_model: str = "ink-whisper"
        self._default_language: Optional[str] = "en"
        self._default_encoding: Optional[str] = "pcm_s16le"
        self._default_sample_rate: int = 16000

    def __del__(self):
        # Best-effort cleanup for callers that forgot to call close().
        # getattr guards against a partially initialised instance (e.g. __init__
        # raised before ``websocket`` was assigned).
        websocket = getattr(self, "websocket", None)
        if websocket is None or websocket.closed:
            # Nothing to close: don't spin up an event loop during garbage collection.
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            asyncio.run(self.close())
        else:
            # get_running_loop() only ever returns a running loop.
            loop.create_task(self.close())

    async def connect(
        self,
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: Optional[str] = "pcm_s16le",
        sample_rate: int = 16000,
    ):
        """Connect to the STT WebSocket with the specified parameters.

        The parameters are remembered and reused if ``send``/``receive`` later need
        to reconnect. If a connection is already open it is reused as-is, even if it
        was opened with different parameters.

        Args:
            model: ID of the model to use for transcription
            language: The language of the input audio in ISO-639-1 format
            encoding: The encoding format of the audio data
            sample_rate: The sample rate of the audio in Hz

        Raises:
            RuntimeError: If no session is available or the connection to the WebSocket fails.
        """
        from urllib.parse import urlencode

        self._default_model = model
        self._default_language = language
        self._default_encoding = encoding
        self._default_sample_rate = sample_rate

        if self.websocket is not None and not self._is_websocket_closed():
            return

        route = "stt/websocket"
        session = await self._get_session()

        params = {
            "model": model,
            "api_key": self.api_key,
            "cartesia_version": self.cartesia_version,
        }
        if language is not None:
            params["language"] = language
        if encoding is not None:
            params["encoding"] = encoding
        if sample_rate is not None:
            params["sample_rate"] = str(sample_rate)

        # urlencode escapes reserved characters ('&', '#', '+', ...) in the values.
        url = f"{self.ws_url}/{route}?{urlencode(params)}"
        # The query string carries the API key, so error messages only use this form.
        redacted_url = f"{self.ws_url}/{route}"

        try:
            if session is None:
                raise RuntimeError("Session is not available")
            self.websocket = await session.ws_connect(url)
        except Exception as e:
            status_code = getattr(e, "status", None)
            if status_code is None:
                raise RuntimeError(f"Failed to connect to WebSocket at {redacted_url}. {e}") from e

            # str() of aiohttp's handshake errors embeds the request URL (and thus the
            # API key), so prefer the exception's bare ``message`` attribute.
            error_message = self._STATUS_ERROR_MESSAGES.get(status_code) or getattr(e, "message", None) or str(e)
            raise RuntimeError(
                f"Failed to connect to WebSocket.\nStatus: {status_code}. Error message: {error_message}"
            ) from e

    def _is_websocket_closed(self):
        return self.websocket is None or self.websocket.closed

    async def close(self):
        """This method closes the websocket connection. Highly recommended to call this method when done."""
        if self.websocket is not None and not self._is_websocket_closed():
            await self.websocket.close()
        self.websocket = None

    async def _ensure_connected(self) -> aiohttp.ClientWebSocketResponse:
        """Return the open WebSocket, reconnecting with the remembered parameters if needed."""
        if self.websocket is None or self._is_websocket_closed():
            await self.connect(
                model=self._default_model,
                language=self._default_language,
                encoding=self._default_encoding,
                sample_rate=self._default_sample_rate,
            )

        assert self.websocket is not None, "WebSocket should be connected after connect() call"
        return self.websocket

    async def send(self, data: Union[bytes, bytearray, memoryview, str]):
        """Send audio data or control commands to the WebSocket.

        Reconnects with the last ``connect`` parameters if the connection is closed.

        Args:
            data: Binary audio data (bytes-like) or text command ("finalize" or "done")

        Raises:
            TypeError: If ``data`` is neither bytes-like nor str.
        """
        # Validate before touching the network so bad input never opens a connection.
        if not isinstance(data, (bytes, bytearray, memoryview, str)):
            raise TypeError("Data must be bytes (audio) or str (command)")

        websocket = await self._ensure_connected()
        if isinstance(data, str):
            await websocket.send_str(data)
        else:
            await websocket.send_bytes(data)

    @staticmethod
    def _parse_text_message(text: str) -> Optional[Dict[str, Any]]:
        """Convert one JSON text frame into a result dict, or None for unrecognised types.

        Missing required transcript fields get defaults ("" / False); ``duration``
        and ``language`` are copied only when present.

        Raises:
            RuntimeError: If the server sent an ``error`` message.
        """
        raw_data = json.loads(text)
        message_type = raw_data.get("type")

        if message_type == "error":
            raise RuntimeError(f"Error transcribing audio: {raw_data.get('message', 'Unknown error')}")

        if message_type == "transcript":
            result = {
                "type": raw_data["type"],
                "request_id": raw_data.get("request_id", ""),
                "text": raw_data.get("text", ""),
                "is_final": raw_data.get("is_final", False),
            }
            for optional_key in ("duration", "language"):
                if optional_key in raw_data:
                    result[optional_key] = raw_data[optional_key]
            return result

        if message_type in ("flush_done", "done"):
            return {
                "type": raw_data["type"],
                "request_id": raw_data.get("request_id", ""),
            }

        return None

    async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:  # type: ignore[override]
        """Receive transcription results from the WebSocket.

        Stops after a ``done`` message or when the connection closes.

        Yields:
            Dictionary containing transcription results, flush_done, or done messages

        Raises:
            RuntimeError: On server error messages, WebSocket errors, malformed
                messages, or when no message arrives within ``self.timeout`` seconds.
                The connection is closed before raising.
        """
        websocket = await self._ensure_connected()

        try:
            while True:
                try:
                    msg = await asyncio.wait_for(websocket.receive(), timeout=self.timeout)
                except asyncio.TimeoutError as e:
                    raise RuntimeError("Timeout while waiting for transcription") from e

                if msg.type == aiohttp.WSMsgType.TEXT:
                    result = self._parse_text_message(msg.data)
                    if result is None:
                        continue
                    yield result
                    if result["type"] == "done":
                        # Session is complete.
                        break

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    raise RuntimeError(f"WebSocket error: {websocket.exception()}")

                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.CLOSED,
                ):
                    # aiohttp returns CLOSED immediately (and forever) once the
                    # connection is gone; without this the loop would busy-spin.
                    break

        except Exception as e:
            await self.close()
            raise RuntimeError(f"Failed to receive transcription. {e}") from e

    async def transcribe(  # type: ignore[override]
        self,
        audio_chunks: typing.AsyncIterator[bytes],
        *,
        model: str = "ink-whisper",
        language: Optional[str] = "en",
        encoding: Optional[str] = "pcm_s16le",
        sample_rate: int = 16000,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Transcribe audio chunks using the WebSocket.

        All chunks are sent (followed by "finalize" and "done") before any results
        are read, so nothing is yielded until ``audio_chunks`` is exhausted. The
        connection is always closed when the generator finishes.

        Args:
            audio_chunks: Async iterator of audio chunks as bytes
            model: ID of the model to use for transcription
            language: The language of the input audio in ISO-639-1 format
            encoding: The encoding format of the audio data
            sample_rate: The sample rate of the audio in Hz

        Yields:
            Dictionary containing transcription results, flush_done, or done messages
        """
        await self.connect(
            model=model,
            language=language,
            encoding=encoding,
            sample_rate=sample_rate,
        )

        try:
            async for chunk in audio_chunks:
                await self.send(chunk)

            # Flush remaining audio, then close the session cleanly.
            await self.send("finalize")
            await self.send("done")

            async for result in self.receive():
                yield result

        finally:
            await self.close()