cartesia 0.0.5rc1__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cartesia/__init__.py CHANGED
@@ -1,3 +1,3 @@
- from cartesia.tts import CartesiaTTS
+ from cartesia.tts import AsyncCartesiaTTS, CartesiaTTS
 
- __all__ = ["CartesiaTTS"]
+ __all__ = ["CartesiaTTS", "AsyncCartesiaTTS"]
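A quick illustrative note (not part of the released diff): with both clients re-exported in `__all__`, they can now be imported straight from the package root, as in the sketch below.

```python
# Both clients are now importable from the package root.
from cartesia import AsyncCartesiaTTS, CartesiaTTS
```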
cartesia/_types.py ADDED
@@ -0,0 +1,42 @@
+ from enum import Enum
+ from typing import List, Optional, TypedDict, Union
+
+ try:
+     import numpy as np
+
+     _NUMPY_AVAILABLE = True
+ except ImportError:
+     _NUMPY_AVAILABLE = False
+
+
+ class AudioDataReturnType(Enum):
+     BYTES = "bytes"
+     ARRAY = "array"
+
+
+ class AudioOutputFormat(Enum):
+     """Supported output formats for the audio."""
+
+     FP32 = "fp32"  # float32
+     PCM = "pcm"  # 16-bit signed integer PCM
+     FP32_16000 = "fp32_16000"  # float32, 16 kHz
+     FP32_22050 = "fp32_22050"  # float32, 22.05 kHz
+     FP32_44100 = "fp32_44100"  # float32, 44.1 kHz
+     PCM_16000 = "pcm_16000"  # 16-bit signed integer PCM, 16 kHz
+     PCM_22050 = "pcm_22050"  # 16-bit signed integer PCM, 22.05 kHz
+     PCM_44100 = "pcm_44100"  # 16-bit signed integer PCM, 44.1 kHz
+
+
+ class AudioOutput(TypedDict):
+     audio: Union[bytes, "np.ndarray"]
+     sampling_rate: int
+
+
+ Embedding = List[float]
+
+
+ class VoiceMetadata(TypedDict):
+     id: str
+     name: str
+     description: str
+     embedding: Optional[Embedding]
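As a quick illustration of how the new `_types` module fits together (a sketch based only on the definitions above, not an official example):

```python
from cartesia._types import AudioDataReturnType, AudioOutput, AudioOutputFormat

# The enums wrap plain strings, so user-facing arguments can be given either way.
assert AudioOutputFormat("fp32_44100") is AudioOutputFormat.FP32_44100
assert AudioDataReturnType("bytes") is AudioDataReturnType.BYTES

# AudioOutput is a TypedDict: a plain dict at runtime, typed for static checkers.
chunk: AudioOutput = {"audio": b"\x00\x00\x00\x00", "sampling_rate": 44100}
print(chunk["sampling_rate"])
```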
cartesia/tts.py CHANGED
@@ -3,33 +3,52 @@ import base64
  import json
  import os
  import uuid
- from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, TypedDict, Union
+ from types import TracebackType
+ from typing import (
+     Any,
+     AsyncGenerator,
+     Dict,
+     Generator,
+     List,
+     Optional,
+     Tuple,
+     TypedDict,
+     Union,
+ )
 
  import aiohttp
  import httpx
+ import logging
  import requests
  from websockets.sync.client import connect
 
- DEFAULT_MODEL_ID = "genial-planet-1346"
- DEFAULT_BASE_URL = "api.cartesia.ai"
- DEFAULT_API_VERSION = "v0"
- DEFAULT_TIMEOUT = 60  # seconds
- DEFAULT_NUM_CONNECTIONS = 10  # connections per client
+ from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
+ from cartesia._types import (
+     AudioDataReturnType,
+     AudioOutputFormat,
+     AudioOutput,
+     Embedding,
+     VoiceMetadata,
+ )
 
+ try:
+     import numpy as np
 
- class AudioOutput(TypedDict):
-     audio: bytes
-     sampling_rate: int
+     _NUMPY_AVAILABLE = True
+ except ImportError:
+     _NUMPY_AVAILABLE = False
 
 
- Embedding = List[float]
+ DEFAULT_MODEL_ID = ""
+ DEFAULT_BASE_URL = "api.cartesia.ai"
+ DEFAULT_API_VERSION = "v0"
+ DEFAULT_TIMEOUT = 30  # seconds
+ DEFAULT_NUM_CONNECTIONS = 10  # connections per client
 
+ BACKOFF_FACTOR = 1
+ MAX_RETRIES = 3
 
- class VoiceMetadata(TypedDict):
-     id: str
-     name: str
-     description: str
-     embedding: Optional[Embedding]
+ logger = logging.getLogger(__name__)
 
 
  def update_buffer(buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
@@ -71,10 +90,8 @@ class CartesiaTTS:
      and generate speech from text.
 
      The client also supports generating audio using a websocket for lower latency.
-     To enable interrupt handling along the websocket, set `experimental_ws_handle_interrupts=True`.
 
      Examples:
-
          >>> client = CartesiaTTS()
 
          # Load available voices and their metadata (excluding the embeddings).
@@ -95,23 +112,17 @@ class CartesiaTTS:
          ... audio, sr = audio_chunk["audio"], audio_chunk["sampling_rate"]
      """
 
-     def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
-         """
-         Args:
-             api_key: The API key to use for authorization.
-                 If not specified, the API key will be read from the environment variable
-                 `CARTESIA_API_KEY`.
-             experimental_ws_handle_interrupts: Whether to handle interrupts when generating
-                 audio using the websocket. This is an experimental feature and may have bugs
-                 or be deprecated in the future.
+     def __init__(self, *, api_key: str = None):
+         """Args:
+             api_key: The API key to use for authorization.
+                 If not specified, the API key will be read from the environment variable
+                 `CARTESIA_API_KEY`.
          """
          self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
          self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
          self.api_version = os.environ.get("CARTESIA_API_VERSION", DEFAULT_API_VERSION)
          self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
          self.websocket = None
-         self.experimental_ws_handle_interrupts = experimental_ws_handle_interrupts
-         self.refresh_websocket()
 
      def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
          """Returns a mapping from voice name -> voice metadata.
@@ -144,18 +155,25 @@ class CartesiaTTS:
          >>> audio = client.generate(transcript="Hello world!", voice=embedding)
          """
          params = {"select": "id, name, description"} if skip_embeddings else None
-         response = httpx.get(f"{self._http_url()}/voices", headers=self.headers, params=params)
+         response = httpx.get(
+             f"{self._http_url()}/voices",
+             headers=self.headers,
+             params=params,
+             timeout=DEFAULT_TIMEOUT,
+         )
 
          if not response.is_success:
              raise ValueError(f"Failed to get voices. Error: {response.text}")
 
          voices = response.json()
-         # TODO: Update the API to return the embedding as a list of floats rather than string.
-         if not skip_embeddings:
-             for voice in voices:
+         for voice in voices:
+             if "embedding" in voice and isinstance(voice["embedding"], str):
                  voice["embedding"] = json.loads(voice["embedding"])
          return {voice["name"]: voice for voice in voices}
 
+     @retry_on_connection_error(
+         max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
+     )
      def get_voice_embedding(
          self, *, voice_id: str = None, filepath: str = None, link: str = None
      ) -> Embedding:
@@ -178,18 +196,18 @@ class CartesiaTTS:
 
          if voice_id:
              url = f"{self._http_url()}/voices/embedding/{voice_id}"
-             response = httpx.get(url, headers=self.headers)
+             response = httpx.get(url, headers=self.headers, timeout=DEFAULT_TIMEOUT)
          elif filepath:
              url = f"{self._http_url()}/voices/clone/clip"
              files = {"clip": open(filepath, "rb")}
              headers = self.headers.copy()
              # The default content type of JSON is incorrect for file uploads
              headers.pop("Content-Type")
-             response = httpx.post(url, headers=headers, files=files)
+             response = httpx.post(url, headers=headers, files=files, timeout=DEFAULT_TIMEOUT)
          elif link:
              url = f"{self._http_url()}/voices/clone/url"
              params = {"link": link}
-             response = httpx.post(url, headers=self.headers, params=params)
+             response = httpx.post(url, headers=self.headers, params=params, timeout=DEFAULT_TIMEOUT)
 
          if not response.is_success:
              raise ValueError(
@@ -199,9 +217,10 @@ class CartesiaTTS:
 
          # Handle successful response
          out = response.json()
-         if isinstance(out["embedding"], str):
-             out["embedding"] = json.loads(out["embedding"])
-         return out["embedding"]
+         embedding = out["embedding"]
+         if isinstance(embedding, str):
+             embedding = json.loads(embedding)
+         return embedding
 
      def refresh_websocket(self):
          """Refresh the websocket connection.
@@ -209,22 +228,30 @@ class CartesiaTTS:
          Note:
              The connection is synchronous.
          """
-         if self.websocket and not self._is_websocket_closed():
-             self.websocket.close()
-         route = "audio/websocket"
-         if self.experimental_ws_handle_interrupts:
-             route = f"experimental/{route}"
-         self.websocket = connect(
-             f"{self._ws_url()}/{route}?api_key={self.api_key}",
-             close_timeout=None,
-         )
+         if self.websocket is None or self._is_websocket_closed():
+             route = "audio/websocket"
+             self.websocket = connect(f"{self._ws_url()}/{route}?api_key={self.api_key}")
 
      def _is_websocket_closed(self):
          return self.websocket.socket.fileno() == -1
 
      def _check_inputs(
-         self, transcript: str, duration: Optional[float], chunk_time: Optional[float]
+         self,
+         transcript: str,
+         duration: Optional[float],
+         chunk_time: Optional[float],
+         output_format: Union[str, AudioOutputFormat],
+         data_rtype: Union[str, AudioDataReturnType],
      ):
+         # This will try the casting and raise an error.
+         _ = AudioOutputFormat(output_format)
+
+         if AudioDataReturnType(data_rtype) == AudioDataReturnType.ARRAY and not _NUMPY_AVAILABLE:
+             raise ImportError(
+                 "The 'numpy' package is required to use the 'array' return type. "
+                 "Please install 'numpy' or use 'bytes' as the return type."
+             )
+
          if chunk_time is not None:
              if chunk_time < 0.1 or chunk_time > 0.5:
                  raise ValueError("`chunk_time` must be between 0.1 and 0.5")
@@ -240,20 +267,24 @@ class CartesiaTTS:
          self,
          *,
          transcript: str,
+         voice: Embedding,
+         model_id: str,
+         output_format: AudioOutputFormat,
          duration: int = None,
          chunk_time: float = None,
-         voice: Embedding = None,
      ) -> Dict[str, Any]:
+         """Create the request body for a stream request.
+
+         Note that anything that's not provided will use a default if available or be
+         filtered out otherwise.
          """
-         Create the request body for a stream request.
-         Note that anything that's not provided will use a default if available or be filtered out otherwise.
-         """
-         body = dict(transcript=transcript, model_id=DEFAULT_MODEL_ID, voice=voice)
+         body = dict(transcript=transcript, model_id=model_id, voice=voice)
+         output_format = output_format.value
 
          optional_body = dict(
              duration=duration,
              chunk_time=chunk_time,
-             voice=voice,
+             output_format=output_format,
          )
          body.update({k: v for k, v in optional_body.items() if v is not None})
 
@@ -263,25 +294,30 @@ class CartesiaTTS:
          self,
          *,
          transcript: str,
+         voice: Embedding,
+         model_id: str = DEFAULT_MODEL_ID,
          duration: int = None,
          chunk_time: float = None,
-         voice: Embedding = None,
          stream: bool = False,
          websocket: bool = True,
+         output_format: Union[str, AudioOutputFormat] = "fp32",
+         data_rtype: str = "bytes",
      ) -> Union[AudioOutput, Generator[AudioOutput, None, None]]:
          """Generate audio from a transcript.
 
          Args:
-             transcript: The text to generate audio for.
-             duration: The maximum duration of the audio in seconds.
-             chunk_time: How long each audio segment should be in seconds.
+             transcript (str): The text to generate audio for.
+             voice (Embedding (List[float])): The voice to use for generating audio.
+             duration (int, optional): The maximum duration of the audio in seconds.
+             chunk_time (float, optional): How long each audio segment should be in seconds.
                  This should not need to be adjusted.
-             voice: The voice to use for generating audio.
-                 This can either be a voice id (string) or an embedding vector (List[float]).
-             stream: Whether to stream the audio or not.
-                 If ``True`` this function returns a generator.
-             websocket: Whether to use a websocket for streaming audio.
-                 Using the websocket reduces latency by pre-poning the handshake.
+             stream (bool, optional): Whether to stream the audio or not.
+                 If True this function returns a generator. False by default.
+             websocket (bool, optional): Whether to use a websocket for streaming audio.
+                 Using the websocket reduces latency by pre-poning the handshake. True by default.
+             data_rtype: The return type for the 'data' key in the dictionary.
+                 One of `'byte' | 'array'`.
+                 Note this field is experimental and may be deprecated in the future.
 
          Returns:
              A generator if `stream` is True, otherwise a dictionary.
@@ -289,17 +325,28 @@ class CartesiaTTS:
              * "audio": The audio as a bytes buffer.
              * "sampling_rate": The sampling rate of the audio.
          """
-         self._check_inputs(transcript, duration, chunk_time)
+         self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
+
+         data_rtype = AudioDataReturnType(data_rtype)
+         output_format = AudioOutputFormat(output_format)
 
          body = self._generate_request_body(
-             transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
+             transcript=transcript,
+             voice=voice,
+             model_id=model_id,
+             duration=duration,
+             chunk_time=chunk_time,
+             output_format=output_format,
          )
 
          if websocket:
              generator = self._generate_ws(body)
          else:
-             generator = self._generate_http(body)
+             generator = self._generate_http_wrapper(body)
 
+         generator = self._postprocess_audio(
+             generator, data_rtype=data_rtype, output_format=output_format
+         )
          if stream:
              return generator
 
@@ -310,14 +357,61 @@ class CartesiaTTS:
              sampling_rate = chunk["sampling_rate"]
              chunks.append(chunk["audio"])
 
-         return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}
+         if data_rtype == AudioDataReturnType.ARRAY:
+             cat = np.concatenate
+         else:
+             cat = b"".join
+
+         return {"audio": cat(chunks), "sampling_rate": sampling_rate}
+
+     def _postprocess_audio(
+         self,
+         generator: Generator[AudioOutput, None, None],
+         *,
+         data_rtype: AudioDataReturnType,
+         output_format: AudioOutputFormat,
+     ) -> Generator[AudioOutput, None, None]:
+         """Perform postprocessing on the generator outputs.
+
+         The postprocessing should be minimal (e.g. converting to array, casting dtype).
+         This code should not perform heavy operations like changing the sampling rate.
+
+         Args:
+             generator: A generator that yields audio chunks.
+             data_rtype: The data return type.
+             output_format: The output format for the audio.
+
+         Returns:
+             A generator that yields audio chunks.
+         """
+         dtype = None
+         if data_rtype == AudioDataReturnType.ARRAY:
+             dtype = np.float32 if "fp32" in output_format.value else np.int16
+
+         for chunk in generator:
+             if dtype is not None:
+                 chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
+             yield chunk
+
+     @retry_on_connection_error(
+         max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
+     )
+     def _generate_http_wrapper(self, body: Dict[str, Any]):
+         """Need to wrap the http generator in a function for the retry decorator to work."""
+         try:
+             for chunk in self._generate_http(body):
+                 yield chunk
+         except Exception as e:
+             logger.error(f"Failed to generate audio. {e}")
+             raise e
 
      def _generate_http(self, body: Dict[str, Any]):
          response = requests.post(
-             f"{self._http_url()}/audio/stream",
+             f"{self._http_url()}/audio/sse",
              stream=True,
              data=json.dumps(body),
              headers=self.headers,
+             timeout=(DEFAULT_TIMEOUT, DEFAULT_TIMEOUT),
          )
          if not response.ok:
              raise ValueError(f"Failed to generate audio. {response.text}")
@@ -356,21 +450,33 @@ class CartesiaTTS:
          try:
              while True:
                  response = json.loads(self.websocket.recv())
+                 if "error" in response:
+                     raise RuntimeError(f"Error generating audio:\n{response['error']}")
                  if response["done"]:
                      break
 
                  yield convert_response(response, include_context_id)
-
-                 if self.experimental_ws_handle_interrupts:
-                     self.websocket.send(json.dumps({"context_id": context_id}))
-         except GeneratorExit:
-             # The exit is only called when the generator is garbage collected.
-             # It may not be called directly after a break statement.
-             # However, the generator will be automatically cancelled on the next request.
-             if self.experimental_ws_handle_interrupts:
-                 self.websocket.send(json.dumps({"context_id": context_id, "action": "cancel"}))
          except Exception as e:
+             # Close the websocket connection if an error occurs.
+             if self.websocket and not self._is_websocket_closed():
+                 self.websocket.close()
              raise RuntimeError(f"Failed to generate audio. {response}") from e
+         finally:
+             # Ensure the websocket is ultimately closed.
+             if self.websocket and not self._is_websocket_closed():
+                 self.websocket.close()
+
+     def prepare_audio_and_headers(
+         self, raw_audio: Union[bytes, str]
+     ) -> Tuple[bytes, Dict[str, Any]]:
+         if isinstance(raw_audio, str):
+             with open(raw_audio, "rb") as f:
+                 raw_audio_bytes = f.read()
+         else:
+             raw_audio_bytes = raw_audio
+         # application/json is not the right content type for this request
+         headers = {k: v for k, v in self.headers.items() if k != "Content-Type"}
+         return raw_audio_bytes, headers
 
      def _http_url(self):
          prefix = "http" if "localhost" in self.base_url else "https"
@@ -380,64 +486,103 @@ class CartesiaTTS:
          prefix = "ws" if "localhost" in self.base_url else "wss"
          return f"{prefix}://{self.base_url}/{self.api_version}"
 
-     def __del__(self):
-         if self.websocket.socket.fileno() > -1:
+     def close(self):
+         if self.websocket and not self._is_websocket_closed():
              self.websocket.close()
 
+     def __del__(self):
+         self.close()
 
- class AsyncCartesiaTTS(CartesiaTTS):
-     def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
-         self.timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
-         self.connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
-         self._session = aiohttp.ClientSession(timeout=self.timeout, connector=self.connector)
-         super().__init__(
-             api_key=api_key, experimental_ws_handle_interrupts=experimental_ws_handle_interrupts
-         )
+     def __enter__(self):
+         self.refresh_websocket()
+         return self
+
+     def __exit__(
+         self,
+         exc_type: Union[type, None],
+         exc: Union[BaseException, None],
+         exc_tb: Union[TracebackType, None],
+     ):
+         self.close()
 
-     def refresh_websocket(self):
-         pass  # do not load the websocket for the client until asynchronously when it is needed
 
-     async def _async_refresh_websocket(self):
+ class AsyncCartesiaTTS(CartesiaTTS):
+     def __init__(self, *, api_key: str = None):
+         self._session = None
+         self._loop = None
+         super().__init__(api_key=api_key)
+
+     async def _get_session(self):
+         current_loop = asyncio.get_event_loop()
+         if self._loop is not current_loop:
+             # If the loop has changed, close the session and create a new one.
+             await self.close()
+         if self._session is None or self._session.closed:
+             timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
+             connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
+             self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
+             self._loop = current_loop
+         return self._session
+
+     async def refresh_websocket(self):
          """Refresh the websocket connection."""
-         if self.websocket and not self._is_websocket_closed():
-             self.websocket.close()
-         route = "audio/websocket"
-         if self.experimental_ws_handle_interrupts:
-             route = f"experimental/{route}"
-         self.websocket = await self._session.ws_connect(
-             f"{self._ws_url()}/{route}?api_key={self.api_key}"
-         )
+         if self.websocket is None or self._is_websocket_closed():
+             route = "audio/websocket"
+             session = await self._get_session()
+             self.websocket = await session.ws_connect(
+                 f"{self._ws_url()}/{route}?api_key={self.api_key}"
+             )
+
+     def _is_websocket_closed(self):
+         return self.websocket.closed
+
+     async def close(self):
+         """This method closes the websocket and the session.
+
+         It is *strongly* recommended to call this method when you are done using the client.
+         """
+         if self.websocket is not None and not self._is_websocket_closed():
+             await self.websocket.close()
+         if self._session is not None and not self._session.closed:
+             await self._session.close()
 
      async def generate(
          self,
          *,
          transcript: str,
+         voice: Embedding,
+         model_id: str = DEFAULT_MODEL_ID,
          duration: int = None,
          chunk_time: float = None,
-         voice: Embedding = None,
          stream: bool = False,
          websocket: bool = True,
+         output_format: Union[str, AudioOutputFormat] = "fp32",
+         data_rtype: Union[str, AudioDataReturnType] = "bytes",
      ) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
          """Asynchronously generate audio from a transcript.
-         NOTE: This overrides the non-asynchronous generate method from the base class.
-         Args:
-             transcript: The text to generate audio for.
-             voice: The embedding to use for generating audio.
-             options: The options to use for generating audio. See :class:`GenerateOptions`.
-         Returns:
-             A dictionary containing the following:
-                 * "audio": The audio as a 1D numpy array.
-                 * "sampling_rate": The sampling rate of the audio.
+
+         For more information on the arguments, see the synchronous :meth:`CartesiaTTS.generate`.
          """
+         self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
+         data_rtype = AudioDataReturnType(data_rtype)
+         output_format = AudioOutputFormat(output_format)
+
          body = self._generate_request_body(
-             transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
+             transcript=transcript,
+             voice=voice,
+             model_id=model_id,
+             duration=duration,
+             chunk_time=chunk_time,
+             output_format=output_format,
          )
 
          if websocket:
             generator = self._generate_ws(body)
          else:
-             generator = self._generate_http(body)
-
+             generator = self._generate_http_wrapper(body)
+         generator = self._postprocess_audio(
+             generator, data_rtype=data_rtype, output_format=output_format
+         )
          if stream:
              return generator
 
@@ -448,14 +593,49 @@ class AsyncCartesiaTTS(CartesiaTTS):
              sampling_rate = chunk["sampling_rate"]
              chunks.append(chunk["audio"])
 
-         return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}
+         if data_rtype == AudioDataReturnType.ARRAY:
+             cat = np.concatenate
+         else:
+             cat = b"".join
+
+         return {"audio": cat(chunks), "sampling_rate": sampling_rate}
+
+     async def _postprocess_audio(
+         self,
+         generator: AsyncGenerator[AudioOutput, None],
+         *,
+         data_rtype: AudioDataReturnType,
+         output_format: AudioOutputFormat,
+     ) -> AsyncGenerator[AudioOutput, None]:
+         """See :meth:`CartesiaTTS._postprocess_audio`."""
+         dtype = None
+         if data_rtype == AudioDataReturnType.ARRAY:
+             dtype = np.float32 if "fp32" in output_format.value else np.int16
+
+         async for chunk in generator:
+             if dtype is not None:
+                 chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
+             yield chunk
+
+     @retry_on_connection_error_async(
+         max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
+     )
+     async def _generate_http_wrapper(self, body: Dict[str, Any]):
+         """Need to wrap the http generator in a function for the retry decorator to work."""
+         try:
+             async for chunk in self._generate_http(body):
+                 yield chunk
+         except Exception as e:
+             logger.error(f"Failed to generate audio. {e}")
+             raise e
 
      async def _generate_http(self, body: Dict[str, Any]):
-         async with self._session.post(
-             f"{self._http_url()}/audio/stream", data=json.dumps(body), headers=self.headers
+         session = await self._get_session()
+         async with session.post(
+             f"{self._http_url()}/audio/sse", data=json.dumps(body), headers=self.headers
          ) as response:
-             if response.status < 200 or response.status >= 300:
-                 raise ValueError(f"Failed to generate audio. {response.text}")
+             if not response.ok:
+                 raise ValueError(f"Failed to generate audio. {await response.text()}")
 
              buffer = ""
              async for chunk_bytes in response.content.iter_any():
@@ -473,12 +653,8 @@ class AsyncCartesiaTTS(CartesiaTTS):
 
      async def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
          include_context_id = bool(context_id)
-         route = "audio/websocket"
-         if self.experimental_ws_handle_interrupts:
-             route = f"experimental/{route}"
-
          if not self.websocket or self._is_websocket_closed():
-             await self._async_refresh_websocket()
+             await self.refresh_websocket()
 
          ws = self.websocket
          if context_id is None:
@@ -492,26 +668,14 @@ class AsyncCartesiaTTS(CartesiaTTS):
                      break
 
                  yield convert_response(response, include_context_id)
-
-                 if self.experimental_ws_handle_interrupts:
-                     await ws.send_json({"context_id": context_id})
-         except GeneratorExit:
-             # The exit is only called when the generator is garbage collected.
-             # It may not be called directly after a break statement.
-             # However, the generator will be automatically cancelled on the next request.
-             if self.experimental_ws_handle_interrupts:
-                 await ws.send_json({"context_id": context_id, "action": "cancel"})
          except Exception as e:
-             raise RuntimeError(f"Failed to generate audio. {response}") from e
-
-     def _is_websocket_closed(self):
-         return self.websocket.closed
-
-     async def cleanup(self):
-         if self.websocket is not None and not self._is_websocket_closed():
-             await self.websocket.close()
-         if not self._session.closed:
-             await self._session.close()
+             if self.websocket and not self._is_websocket_closed():
+                 await self.websocket.close()
+             raise RuntimeError(f"Failed to generate audio. {await response.text()}") from e
+         finally:
+             # Ensure the websocket is ultimately closed.
+             if self.websocket and not self._is_websocket_closed():
+                 await self.websocket.close()
 
      def __del__(self):
          try:
@@ -520,6 +684,18 @@ class AsyncCartesiaTTS(CartesiaTTS):
              loop = None
 
          if loop is None:
-             asyncio.run(self.cleanup())
+             asyncio.run(self.close())
          else:
-             loop.create_task(self.cleanup())
+             loop.create_task(self.close())
+
+     async def __aenter__(self):
+         await self.refresh_websocket()
+         return self
+
+     async def __aexit__(
+         self,
+         exc_type: Union[type, None],
+         exc: Union[BaseException, None],
+         exc_tb: Union[TracebackType, None],
+     ):
+         await self.close()
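To make the new `generate` surface concrete, here is a hedged sketch that mirrors the README example above: the voice name, model id, and the `output_format`/`data_rtype` values are taken from the diff and README, and `data_rtype="array"` assumes numpy is installed.

```python
import os

import numpy as np

from cartesia.tts import CartesiaTTS

# Context-manager use is new in 0.1.0: __enter__ opens the websocket, __exit__ closes it.
with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
    voices = client.get_voices()
    voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])

    output = client.generate(
        transcript="Hello! Welcome to Cartesia",
        voice=voice,
        model_id="genial-planet-1346",  # model id from the README example
        output_format="fp32_44100",     # one of the AudioOutputFormat values
        data_rtype="array",             # requires numpy; "bytes" is the default
    )
    audio: np.ndarray = output["audio"]
    print(audio.dtype, output["sampling_rate"])
```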
cartesia/utils.py ADDED
@@ -0,0 +1,87 @@
+ import time
+
+ from aiohttp.client_exceptions import ServerDisconnectedError
+ import asyncio
+ from functools import wraps
+ from http.client import RemoteDisconnected
+ from httpx import TimeoutException
+ from requests.exceptions import ConnectionError
+
+
+ def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
+     """Retry a function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
+
+     Args:
+         max_retries (int): The maximum number of retries.
+         backoff_factor (int): The factor to increase the delay between retries.
+         logger (logging.Logger): The logger to use for logging.
+     """
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             retry_count = 0
+             while retry_count < max_retries:
+                 try:
+                     return func(*args, **kwargs)
+                 except (
+                     ConnectionError,
+                     RemoteDisconnected,
+                     ServerDisconnectedError,
+                     TimeoutException,
+                 ) as e:
+                     logger.info(f"Retrying after exception: {e}")
+                     retry_count += 1
+                     if retry_count < max_retries:
+                         delay = backoff_factor * (2 ** (retry_count - 1))
+                         logger.warn(
+                             f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
+                         )
+                         time.sleep(delay)
+                     else:
+                         raise Exception(f"Exception occurred after {max_retries} tries.") from e
+
+         return wrapper
+
+     return decorator
+
+
+ def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None):
+     """Retry an asynchronous function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
+
+     Args:
+         max_retries (int): The maximum number of retries.
+         backoff_factor (int): The factor to increase the delay between retries.
+         logger (logging.Logger): The logger to use for logging.
+     """
+
+     def decorator(func):
+         @wraps(func)
+         async def wrapper(*args, **kwargs):
+             retry_count = 0
+             while retry_count < max_retries:
+                 try:
+                     async for chunk in func(*args, **kwargs):
+                         yield chunk
+                     # If the function completes without raising an exception return
+                     return
+                 except (
+                     ConnectionError,
+                     RemoteDisconnected,
+                     ServerDisconnectedError,
+                     TimeoutException,
+                 ) as e:
+                     logger.info(f"Retrying after exception: {e}")
+                     retry_count += 1
+                     if retry_count < max_retries:
+                         delay = backoff_factor * (2 ** (retry_count - 1))
+                         logger.warn(
+                             f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
+                         )
+                         await asyncio.sleep(delay)
+                     else:
+                         raise Exception(f"Exception occurred after {max_retries} tries.") from e
+
+         return wrapper
+
+     return decorator
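A minimal usage sketch for the new retry helpers (the `fetch_health` function is hypothetical; only the decorator and its arguments come from the module above). The decorator logs through the logger you pass in, so supplying one is effectively required, and after `max_retries` failed attempts it raises a wrapping `Exception` chained to the last error.

```python
import logging

import httpx

from cartesia.utils import retry_on_connection_error

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Hypothetical helper: retried on the connection/timeout errors listed above,
# with exponential backoff between attempts (backoff_factor * 2 ** (attempt - 1)).
@retry_on_connection_error(max_retries=3, backoff_factor=1, logger=logger)
def fetch_health(url: str) -> int:
    return httpx.get(url, timeout=5).status_code


if __name__ == "__main__":
    print(fetch_health("https://api.cartesia.ai"))
```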
cartesia/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.0.5rc1"
+ __version__ = "0.1.0"
cartesia-0.0.5rc1.dist-info/METADATA → cartesia-0.1.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: cartesia
- Version: 0.0.5rc1
+ Version: 0.1.0
  Summary: The official Python library for the Cartesia API.
  Home-page:
  Author: Cartesia, Inc.
@@ -16,25 +16,19 @@ Requires-Dist: pytest-asyncio
  Requires-Dist: requests
  Requires-Dist: websockets
  Provides-Extra: all
- Requires-Dist: pre-commit ; extra == 'all'
- Requires-Dist: docformatter ; extra == 'all'
- Requires-Dist: black ==24.1.1 ; extra == 'all'
- Requires-Dist: isort ==5.13.2 ; extra == 'all'
- Requires-Dist: flake8 ==7.0.0 ; extra == 'all'
- Requires-Dist: flake8-bugbear ==24.2.6 ; extra == 'all'
  Requires-Dist: pytest >=8.0.2 ; extra == 'all'
  Requires-Dist: pytest-cov >=4.1.0 ; extra == 'all'
  Requires-Dist: twine ; extra == 'all'
+ Requires-Dist: setuptools ; extra == 'all'
+ Requires-Dist: wheel ; extra == 'all'
+ Requires-Dist: numpy ; extra == 'all'
  Provides-Extra: dev
- Requires-Dist: pre-commit ; extra == 'dev'
- Requires-Dist: docformatter ; extra == 'dev'
- Requires-Dist: black ==24.1.1 ; extra == 'dev'
- Requires-Dist: isort ==5.13.2 ; extra == 'dev'
- Requires-Dist: flake8 ==7.0.0 ; extra == 'dev'
- Requires-Dist: flake8-bugbear ==24.2.6 ; extra == 'dev'
  Requires-Dist: pytest >=8.0.2 ; extra == 'dev'
  Requires-Dist: pytest-cov >=4.1.0 ; extra == 'dev'
  Requires-Dist: twine ; extra == 'dev'
+ Requires-Dist: setuptools ; extra == 'dev'
+ Requires-Dist: wheel ; extra == 'dev'
+ Requires-Dist: numpy ; extra == 'dev'
 
 
  # Cartesia Python API Library
@@ -60,13 +54,14 @@ client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
60
54
  voices = client.get_voices()
61
55
  voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
62
56
  transcript = "Hello! Welcome to Cartesia"
57
+ model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
63
58
 
64
59
  p = pyaudio.PyAudio()
65
60
 
66
61
  stream = None
67
62
 
68
63
  # Generate and stream audio
69
- for output in client.generate(transcript=transcript, voice=voice, stream=True):
64
+ for output in client.generate(transcript=transcript, voice=voice, model_id=model_id, stream=True):
70
65
  buffer = output["audio"]
71
66
  rate = output["sampling_rate"]
72
67
 
@@ -84,26 +79,68 @@ stream.close()
  p.terminate()
  ```
 
- If you are using Jupyter Notebook or JupyterLab, you can use IPython.display.Audio to play the generated audio directly in the notebook. Here's an example:
+ You can also use the async client if you want to make asynchronous API calls:
+ ```python
+ from cartesia.tts import AsyncCartesiaTTS
+ import asyncio
+ import pyaudio
+ import os
+
+ async def write_stream():
+     client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
+     voices = client.get_voices()
+     voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
+     transcript = "Hello! Welcome to Cartesia"
+     model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
+
+     p = pyaudio.PyAudio()
+
+     stream = None
+
+     # Generate and stream audio
+     async for output in await client.generate(transcript=transcript, voice=voice, model_id=model_id, stream=True):
+         buffer = output["audio"]
+         rate = output["sampling_rate"]
+
+         if not stream:
+             stream = p.open(format=pyaudio.paFloat32,
+                             channels=1,
+                             rate=rate,
+                             output=True)
+
+         # Write the audio data to the stream
+         stream.write(buffer)
+
+     stream.stop_stream()
+     stream.close()
+     p.terminate()
+
+ asyncio.run(write_stream())
+ ```
+
+ If you are using Jupyter Notebook or JupyterLab, you can use IPython.display.Audio to play the generated audio directly in the notebook.
+ Additionally, in these notebook examples we show how to use the client as a context manager (though this is not required).
 
  ```python
- from cartesia.tts import CartesiaTTS
  from IPython.display import Audio
  import io
  import os
+ import numpy as np
 
- client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
- voices = client.get_voices()
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
- transcript = "Hello! Welcome to Cartesia"
+ from cartesia.tts import CartesiaTTS
 
- # Create a BytesIO object to store the audio data
- audio_data = io.BytesIO()
+ with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
+     voices = client.get_voices()
+     voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
+     transcript = "Hello! Welcome to Cartesia"
 
- # Generate and stream audio
- for output in client.generate(transcript=transcript, voice=voice, stream=True):
-     buffer = output["audio"]
-     audio_data.write(buffer)
+     # Create a BytesIO object to store the audio data
+     audio_data = io.BytesIO()
+
+     # Generate and stream audio
+     for output in client.generate(transcript=transcript, voice=voice, stream=True):
+         buffer = output["audio"]
+         audio_data.write(buffer)
 
  # Set the cursor position to the beginning of the BytesIO object
  audio_data.seek(0)
@@ -115,25 +152,27 @@ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["s
115
152
  display(audio)
116
153
  ```
117
154
 
118
- You can also use the async client if you want to make asynchronous API calls. The usage is very similar:
155
+ Below is the same example using the async client:
119
156
  ```python
120
- from cartesia.tts import AsyncCartesiaTTS
121
157
  from IPython.display import Audio
122
158
  import io
123
159
  import os
160
+ import numpy as np
124
161
 
125
- client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
126
- voices = client.get_voices()
127
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
128
- transcript = "Hello! Welcome to Cartesia"
162
+ from cartesia.tts import AsyncCartesiaTTS
129
163
 
130
- # Create a BytesIO object to store the audio data
131
- audio_data = io.BytesIO()
164
+ async with AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
165
+ voices = client.get_voices()
166
+ voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
167
+ transcript = "Hello! Welcome to Cartesia"
132
168
 
133
- # Generate and stream audio
134
- async for output in client.generate(transcript=transcript, voice=voice, stream=True):
135
- buffer = output["audio"]
136
- audio_data.write(buffer)
169
+ # Create a BytesIO object to store the audio data
170
+ audio_data = io.BytesIO()
171
+
172
+ # Generate and stream audio
173
+ async for output in await client.generate(transcript=transcript, voice=voice, stream=True):
174
+ buffer = output["audio"]
175
+ audio_data.write(buffer)
137
176
 
138
177
  # Set the cursor position to the beginning of the BytesIO object
139
178
  audio_data.seek(0)
cartesia-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ cartesia/__init__.py,sha256=uIc9xGNPs8_A6eAvbTUY1geazunYoEZVWFKhCwC9TRA,102
+ cartesia/_types.py,sha256=uf2Pe-9g7nU-RNUxNAFN3j5Cwy0WyLP1oZf6VV5rGgw,1001
+ cartesia/tts.py,sha256=hAADPdTYu7yGsY7yIQIf1hjKKJLUk9pm5LU0cEIB8gA,25806
+ cartesia/utils.py,sha256=nuwWRfu3MOVTxIQMLjYf6WLaxSlnu_GdE3QjTV0zisQ,3339
+ cartesia/version.py,sha256=kUR5RAFc7HCeiqdlX36dZOHkUI5wI6V_43RpEcD8b-0,22
+ cartesia-0.1.0.dist-info/METADATA,sha256=H7spLdviK35R839_OAB47JL2FAaGw6AZ7CnNs_xy87Q,6050
+ cartesia-0.1.0.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
+ cartesia-0.1.0.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
+ cartesia-0.1.0.dist-info/RECORD,,
cartesia-0.0.5rc1.dist-info/WHEEL → cartesia-0.1.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.41.2)
+ Generator: bdist_wheel (0.43.0)
  Root-Is-Purelib: true
  Tag: py2-none-any
  Tag: py3-none-any
cartesia-0.0.5rc1.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
- cartesia/__init__.py,sha256=m8BX-qLjsMoI_JZtgf3jNi8R3cBZqYy-z4oEhYeJLdI,64
- cartesia/tts.py,sha256=yPLz41AR0oAYPUNW48mqmwEEbLBHCnbaK_wPT0iFBVk,20543
- cartesia/version.py,sha256=VkI5lk2CFatZR200RqGd8cBjTnMDmhtZW7DI6mPe6n4,25
- cartesia-0.0.5rc1.dist-info/METADATA,sha256=632D6iZ2IU3MLySAnMtwV2zQA38XkQv1rfFF4iRdAco,4893
- cartesia-0.0.5rc1.dist-info/WHEEL,sha256=iYlv5fX357PQyRT2o6tw1bN-YcKFFHKqB_LwHO5wP-g,110
- cartesia-0.0.5rc1.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
- cartesia-0.0.5rc1.dist-info/RECORD,,