cartesia 0.0.5rc1__py2.py3-none-any.whl → 0.0.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cartesia/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from cartesia.tts import CartesiaTTS
+from cartesia.tts import AsyncCartesiaTTS, CartesiaTTS

-__all__ = ["CartesiaTTS"]
+__all__ = ["CartesiaTTS", "AsyncCartesiaTTS"]
cartesia/tts.py CHANGED
@@ -3,19 +3,27 @@ import base64
 import json
 import os
 import uuid
+from types import TracebackType
 from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, TypedDict, Union

 import aiohttp
 import httpx
+import logging
 import requests
 from websockets.sync.client import connect

-DEFAULT_MODEL_ID = "genial-planet-1346"
+from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
+
+DEFAULT_MODEL_ID = ""
 DEFAULT_BASE_URL = "api.cartesia.ai"
 DEFAULT_API_VERSION = "v0"
-DEFAULT_TIMEOUT = 60 # seconds
+DEFAULT_TIMEOUT = 30 # seconds
 DEFAULT_NUM_CONNECTIONS = 10 # connections per client

+BACKOFF_FACTOR = 1
+MAX_RETRIES = 3
+
+logger = logging.getLogger(__name__)

 class AudioOutput(TypedDict):
     audio: bytes
@@ -74,7 +82,6 @@ class CartesiaTTS:
     To enable interrupt handling along the websocket, set `experimental_ws_handle_interrupts=True`.

     Examples:
-
     >>> client = CartesiaTTS()

     # Load available voices and their metadata (excluding the embeddings).
@@ -96,14 +103,13 @@ class CartesiaTTS:
     """

     def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
-        """
-        Args:
-            api_key: The API key to use for authorization.
-                If not specified, the API key will be read from the environment variable
-                `CARTESIA_API_KEY`.
-            experimental_ws_handle_interrupts: Whether to handle interrupts when generating
-                audio using the websocket. This is an experimental feature and may have bugs
-                or be deprecated in the future.
+        """Args:
+            api_key: The API key to use for authorization.
+                If not specified, the API key will be read from the environment variable
+                `CARTESIA_API_KEY`.
+            experimental_ws_handle_interrupts: Whether to handle interrupts when generating
+                audio using the websocket. This is an experimental feature and may have bugs
+                or be deprecated in the future.
         """
         self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
         self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
@@ -111,7 +117,6 @@ class CartesiaTTS:
         self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
         self.websocket = None
         self.experimental_ws_handle_interrupts = experimental_ws_handle_interrupts
-        self.refresh_websocket()

     def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
         """Returns a mapping from voice name -> voice metadata.
@@ -144,18 +149,23 @@ class CartesiaTTS:
         >>> audio = client.generate(transcript="Hello world!", voice=embedding)
         """
         params = {"select": "id, name, description"} if skip_embeddings else None
-        response = httpx.get(f"{self._http_url()}/voices", headers=self.headers, params=params)
+        response = httpx.get(
+            f"{self._http_url()}/voices",
+            headers=self.headers,
+            params=params,
+            timeout=DEFAULT_TIMEOUT,
+        )

         if not response.is_success:
             raise ValueError(f"Failed to get voices. Error: {response.text}")

         voices = response.json()
-        # TODO: Update the API to return the embedding as a list of floats rather than string.
-        if not skip_embeddings:
-            for voice in voices:
+        for voice in voices:
+            if "embedding" in voice and isinstance(voice["embedding"], str):
                 voice["embedding"] = json.loads(voice["embedding"])
         return {voice["name"]: voice for voice in voices}

+    @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
     def get_voice_embedding(
         self, *, voice_id: str = None, filepath: str = None, link: str = None
     ) -> Embedding:
@@ -178,18 +188,18 @@ class CartesiaTTS:

         if voice_id:
             url = f"{self._http_url()}/voices/embedding/{voice_id}"
-            response = httpx.get(url, headers=self.headers)
+            response = httpx.get(url, headers=self.headers, timeout=DEFAULT_TIMEOUT)
         elif filepath:
             url = f"{self._http_url()}/voices/clone/clip"
             files = {"clip": open(filepath, "rb")}
             headers = self.headers.copy()
             # The default content type of JSON is incorrect for file uploads
             headers.pop("Content-Type")
-            response = httpx.post(url, headers=headers, files=files)
+            response = httpx.post(url, headers=headers, files=files, timeout=DEFAULT_TIMEOUT)
         elif link:
             url = f"{self._http_url()}/voices/clone/url"
             params = {"link": link}
-            response = httpx.post(url, headers=self.headers, params=params)
+            response = httpx.post(url, headers=self.headers, params=params, timeout=DEFAULT_TIMEOUT)

         if not response.is_success:
             raise ValueError(
@@ -199,9 +209,10 @@ class CartesiaTTS:

         # Handle successful response
         out = response.json()
-        if isinstance(out["embedding"], str):
-            out["embedding"] = json.loads(out["embedding"])
-        return out["embedding"]
+        embedding = out["embedding"]
+        if isinstance(embedding, str):
+            embedding = json.loads(embedding)
+        return embedding

     def refresh_websocket(self):
         """Refresh the websocket connection.
@@ -209,15 +220,11 @@ class CartesiaTTS:
         Note:
             The connection is synchronous.
         """
-        if self.websocket and not self._is_websocket_closed():
-            self.websocket.close()
-        route = "audio/websocket"
-        if self.experimental_ws_handle_interrupts:
-            route = f"experimental/{route}"
-        self.websocket = connect(
-            f"{self._ws_url()}/{route}?api_key={self.api_key}",
-            close_timeout=None,
-        )
+        if self.websocket is None or self._is_websocket_closed():
+            route = "audio/websocket"
+            if self.experimental_ws_handle_interrupts:
+                route = f"experimental/{route}"
+            self.websocket = connect(f"{self._ws_url()}/{route}?api_key={self.api_key}")

     def _is_websocket_closed(self):
         return self.websocket.socket.fileno() == -1
@@ -240,20 +247,23 @@ class CartesiaTTS:
         self,
         *,
         transcript: str,
+        voice: Embedding,
+        model_id: str,
+        output_format: str,
         duration: int = None,
         chunk_time: float = None,
-        voice: Embedding = None,
     ) -> Dict[str, Any]:
+        """Create the request body for a stream request.
+
+        Note that anything that's not provided will use a default if available or be
+        filtered out otherwise.
         """
-        Create the request body for a stream request.
-        Note that anything that's not provided will use a default if available or be filtered out otherwise.
-        """
-        body = dict(transcript=transcript, model_id=DEFAULT_MODEL_ID, voice=voice)
+        body = dict(transcript=transcript, model_id=model_id, voice=voice)

         optional_body = dict(
             duration=duration,
             chunk_time=chunk_time,
-            voice=voice,
+            output_format=output_format,
         )
         body.update({k: v for k, v in optional_body.items() if v is not None})

@@ -263,25 +273,26 @@ class CartesiaTTS:
         self,
         *,
         transcript: str,
+        voice: Embedding,
+        model_id: str = DEFAULT_MODEL_ID,
         duration: int = None,
         chunk_time: float = None,
-        voice: Embedding = None,
         stream: bool = False,
         websocket: bool = True,
+        output_format: str = "fp32",
     ) -> Union[AudioOutput, Generator[AudioOutput, None, None]]:
         """Generate audio from a transcript.

         Args:
-            transcript: The text to generate audio for.
-            duration: The maximum duration of the audio in seconds.
-            chunk_time: How long each audio segment should be in seconds.
+            transcript (str): The text to generate audio for.
+            voice (Embedding (List[float])): The voice to use for generating audio.
+            duration (int, optional): The maximum duration of the audio in seconds.
+            chunk_time (float, optional): How long each audio segment should be in seconds.
                 This should not need to be adjusted.
-            voice: The voice to use for generating audio.
-                This can either be a voice id (string) or an embedding vector (List[float]).
-            stream: Whether to stream the audio or not.
-                If ``True`` this function returns a generator.
-            websocket: Whether to use a websocket for streaming audio.
-                Using the websocket reduces latency by pre-poning the handshake.
+            stream (bool, optional): Whether to stream the audio or not.
+                If True this function returns a generator. False by default.
+            websocket (bool, optional): Whether to use a websocket for streaming audio.
+                Using the websocket reduces latency by pre-poning the handshake. True by default.

         Returns:
             A generator if `stream` is True, otherwise a dictionary.
@@ -292,13 +303,18 @@ class CartesiaTTS:
         self._check_inputs(transcript, duration, chunk_time)

         body = self._generate_request_body(
-            transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
+            transcript=transcript,
+            voice=voice,
+            model_id=model_id,
+            duration=duration,
+            chunk_time=chunk_time,
+            output_format=output_format,
         )

         if websocket:
             generator = self._generate_ws(body)
         else:
-            generator = self._generate_http(body)
+            generator = self._generate_http_wrapper(body)

         if stream:
             return generator
@@ -312,12 +328,23 @@ class CartesiaTTS:

         return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}

+    @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
+    def _generate_http_wrapper(self, body: Dict[str, Any]):
+        """Need to wrap the http generator in a function for the retry decorator to work."""
+        try:
+            for chunk in self._generate_http(body):
+                yield chunk
+        except Exception as e:
+            logger.error(f"Failed to generate audio. {e}")
+            raise e
+
     def _generate_http(self, body: Dict[str, Any]):
         response = requests.post(
-            f"{self._http_url()}/audio/stream",
+            f"{self._http_url()}/audio/sse",
             stream=True,
             data=json.dumps(body),
             headers=self.headers,
+            timeout=(DEFAULT_TIMEOUT, DEFAULT_TIMEOUT),
         )
         if not response.ok:
             raise ValueError(f"Failed to generate audio. {response.text}")
@@ -356,6 +383,8 @@ class CartesiaTTS:
         try:
             while True:
                 response = json.loads(self.websocket.recv())
+                if "error" in response:
+                    raise RuntimeError(f"Error generating audio:\n{response['error']}")
                 if response["done"]:
                     break

@@ -370,7 +399,43 @@ class CartesiaTTS:
             if self.experimental_ws_handle_interrupts:
                 self.websocket.send(json.dumps({"context_id": context_id, "action": "cancel"}))
         except Exception as e:
+            # Close the websocket connection if an error occurs.
+            if self.websocket and not self._is_websocket_closed():
+                self.websocket.close()
             raise RuntimeError(f"Failed to generate audio. {response}") from e
+        finally:
+            # Ensure the websocket is ultimately closed.
+            if self.websocket and not self._is_websocket_closed():
+                self.websocket.close()
+
+    @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
+    def transcribe(self, raw_audio: Union[bytes, str]) -> str:
+        raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
+        response = httpx.post(
+            f"{self._http_url()}/audio/transcriptions",
+            headers=headers,
+            files={"clip": ("input.wav", raw_audio_bytes)},
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+        if not response.is_success:
+            raise ValueError(f"Failed to transcribe audio. Error: {response.text()}")
+
+        transcript = response.json()
+        return transcript["text"]
+
+
+    def prepare_audio_and_headers(
+        self, raw_audio: Union[bytes, str]
+    ) -> Tuple[bytes, Dict[str, Any]]:
+        if isinstance(raw_audio, str):
+            with open(raw_audio, "rb") as f:
+                raw_audio_bytes = f.read()
+        else:
+            raw_audio_bytes = raw_audio
+        # application/json is not the right content type for this request
+        headers = {k: v for k, v in self.headers.items() if k != "Content-Type"}
+        return raw_audio_bytes, headers

     def _http_url(self):
         prefix = "http" if "localhost" in self.base_url else "https"
@@ -380,63 +445,119 @@ class CartesiaTTS:
         prefix = "ws" if "localhost" in self.base_url else "wss"
         return f"{prefix}://{self.base_url}/{self.api_version}"

-    def __del__(self):
-        if self.websocket.socket.fileno() > -1:
+    def close(self):
+        if self.websocket and not self._is_websocket_closed():
             self.websocket.close()

+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        self.refresh_websocket()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Union[type, None],
+        exc: Union[BaseException, None],
+        exc_tb: Union[TracebackType, None],
+    ):
+        self.close()
+

 class AsyncCartesiaTTS(CartesiaTTS):
     def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
-        self.timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
-        self.connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
-        self._session = aiohttp.ClientSession(timeout=self.timeout, connector=self.connector)
+        self._session = None
+        self._loop = None
         super().__init__(
             api_key=api_key, experimental_ws_handle_interrupts=experimental_ws_handle_interrupts
         )
-
-    def refresh_websocket(self):
-        pass # do not load the websocket for the client until asynchronously when it is needed
-
-    async def _async_refresh_websocket(self):
+
+    async def _get_session(self):
+        current_loop = asyncio.get_event_loop()
+        if self._loop is not current_loop:
+            # If the loop has changed, close the session and create a new one.
+            await self.close()
+        if self._session is None or self._session.closed:
+            timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
+            connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
+            self._session = aiohttp.ClientSession(
+                timeout=timeout, connector=connector
+            )
+            self._loop = current_loop
+        return self._session
+
+    async def refresh_websocket(self):
         """Refresh the websocket connection."""
-        if self.websocket and not self._is_websocket_closed():
-            self.websocket.close()
-        route = "audio/websocket"
-        if self.experimental_ws_handle_interrupts:
-            route = f"experimental/{route}"
-        self.websocket = await self._session.ws_connect(
-            f"{self._ws_url()}/{route}?api_key={self.api_key}"
-        )
+        if self.websocket is None or self._is_websocket_closed():
+            route = "audio/websocket"
+            if self.experimental_ws_handle_interrupts:
+                route = f"experimental/{route}"
+            session = await self._get_session()
+            self.websocket = await session.ws_connect(
+                f"{self._ws_url()}/{route}?api_key={self.api_key}"
+            )
+
+    def _is_websocket_closed(self):
+        return self.websocket.closed
+
+    async def close(self):
+        """This method closes the websocket and the session.
+
+        It is *strongly* recommended to call this method when you are done using the client.
+        """
+        if self.websocket is not None and not self._is_websocket_closed():
+            await self.websocket.close()
+        if self._session is not None and not self._session.closed:
+            await self._session.close()

514
  async def generate(
412
515
  self,
413
516
  *,
414
517
  transcript: str,
518
+ voice: Embedding,
519
+ model_id: str = DEFAULT_MODEL_ID,
415
520
  duration: int = None,
416
521
  chunk_time: float = None,
417
- voice: Embedding = None,
418
522
  stream: bool = False,
419
523
  websocket: bool = True,
524
+ output_format: str = "fp32"
420
525
  ) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
421
526
  """Asynchronously generate audio from a transcript.
422
527
  NOTE: This overrides the non-asynchronous generate method from the base class.
528
+
423
529
  Args:
424
- transcript: The text to generate audio for.
425
- voice: The embedding to use for generating audio.
426
- options: The options to use for generating audio. See :class:`GenerateOptions`.
530
+ transcript (str): The text to generate audio for.
531
+ voice (Embedding (List[float])): The voice to use for generating audio.
532
+ duration (int, optional): The maximum duration of the audio in seconds.
533
+ chunk_time (float, optional): How long each audio segment should be in seconds.
534
+ This should not need to be adjusted.
535
+ stream (bool, optional): Whether to stream the audio or not.
536
+ If True this function returns a generator. False by default.
537
+ websocket (bool, optional): Whether to use a websocket for streaming audio.
538
+ Using the websocket reduces latency by pre-poning the handshake. True by default.
539
+
427
540
  Returns:
428
- A dictionary containing the following:
429
- * "audio": The audio as a 1D numpy array.
541
+ A generator if `stream` is True, otherwise a dictionary.
542
+ Dictionary from both generator and non-generator return types have the following keys:
543
+ * "audio": The audio as a bytes buffer.
430
544
  * "sampling_rate": The sampling rate of the audio.
431
545
  """
546
+ self._check_inputs(transcript, duration, chunk_time)
547
+
432
548
  body = self._generate_request_body(
433
- transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
549
+ transcript=transcript,
550
+ voice=voice,
551
+ model_id=model_id,
552
+ duration=duration,
553
+ chunk_time=chunk_time,
554
+ output_format=output_format,
434
555
  )
435
556
 
436
557
  if websocket:
437
558
  generator = self._generate_ws(body)
438
559
  else:
439
- generator = self._generate_http(body)
560
+ generator = self._generate_http_wrapper(body)
440
561
 
441
562
  if stream:
442
563
  return generator
@@ -450,12 +571,23 @@ class AsyncCartesiaTTS(CartesiaTTS):

         return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}

+    @retry_on_connection_error_async(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
+    async def _generate_http_wrapper(self, body: Dict[str, Any]):
+        """Need to wrap the http generator in a function for the retry decorator to work."""
+        try:
+            async for chunk in self._generate_http(body):
+                yield chunk
+        except Exception as e:
+            logger.error(f"Failed to generate audio. {e}")
+            raise e
+
     async def _generate_http(self, body: Dict[str, Any]):
-        async with self._session.post(
-            f"{self._http_url()}/audio/stream", data=json.dumps(body), headers=self.headers
+        session = await self._get_session()
+        async with session.post(
+            f"{self._http_url()}/audio/sse", data=json.dumps(body), headers=self.headers
         ) as response:
-            if response.status < 200 or response.status >= 300:
-                raise ValueError(f"Failed to generate audio. {response.text}")
+            if not response.ok:
+                raise ValueError(f"Failed to generate audio. {await response.text()}")

             buffer = ""
             async for chunk_bytes in response.content.iter_any():
@@ -478,7 +610,7 @@ class AsyncCartesiaTTS(CartesiaTTS):
             route = f"experimental/{route}"

         if not self.websocket or self._is_websocket_closed():
-            await self._async_refresh_websocket()
+            await self.refresh_websocket()

         ws = self.websocket
         if context_id is None:
@@ -502,17 +634,29 @@ class AsyncCartesiaTTS(CartesiaTTS):
             if self.experimental_ws_handle_interrupts:
                 await ws.send_json({"context_id": context_id, "action": "cancel"})
         except Exception as e:
-            raise RuntimeError(f"Failed to generate audio. {response}") from e
-
-    def _is_websocket_closed(self):
-        return self.websocket.closed
-
-    async def cleanup(self):
-        if self.websocket is not None and not self._is_websocket_closed():
-            await self.websocket.close()
-        if not self._session.closed:
-            await self._session.close()
+            if self.websocket and not self._is_websocket_closed():
+                await self.websocket.close()
+            raise RuntimeError(f"Failed to generate audio. {await response.text()}") from e
+        finally:
+            # Ensure the websocket is ultimately closed.
+            if self.websocket and not self._is_websocket_closed():
+                await self.websocket.close()
+
+    async def transcribe(self, raw_audio: Union[bytes, str]) -> str:
+        raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
+        data = aiohttp.FormData()
+        data.add_field("clip", raw_audio_bytes, filename="input.wav", content_type="audio/wav")
+        session = await self._get_session()
+
+        async with session.post(
+            f"{self._http_url()}/audio/transcriptions", headers=headers, data=data
+        ) as response:
+            if not response.ok:
+                raise ValueError(f"Failed to transcribe audio. Error: {await response.text()}")

+            transcript = await response.json()
+            return transcript["text"]
+
     def __del__(self):
         try:
             loop = asyncio.get_running_loop()
@@ -520,6 +664,18 @@ class AsyncCartesiaTTS(CartesiaTTS):
             loop = None

         if loop is None:
-            asyncio.run(self.cleanup())
+            asyncio.run(self.close())
         else:
-            loop.create_task(self.cleanup())
+            loop.create_task(self.close())
+
+    async def __aenter__(self):
+        await self.refresh_websocket()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Union[type, None],
+        exc: Union[BaseException, None],
+        exc_tb: Union[TracebackType, None],
+    ):
+        await self.close()
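Taken together, the tts.py changes above move connection setup out of `__init__` (the websocket is now opened lazily by `refresh_websocket()`, typically via the new `__enter__`/`__exit__` or `__aenter__`/`__aexit__` hooks) and make `voice` a required keyword argument of `generate`, with `model_id` and `output_format` falling back to module defaults. A minimal sketch of an adapted call site, reusing the "Graham" voice from the README examples further down:

```python
import os

from cartesia.tts import CartesiaTTS

# Sketch of a 0.0.6-style call site; the voice name follows the README examples below.
with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
    # __enter__ calls refresh_websocket(), so the connection is opened here
    # rather than in __init__ as in 0.0.5rc1.
    voices = client.get_voices()
    voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])

    # voice is now required; model_id and output_format use the module defaults.
    output = client.generate(transcript="Hello world!", voice=voice)
    audio, rate = output["audio"], output["sampling_rate"]
# __exit__ calls close(), which shuts the websocket down if it is still open.
```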
cartesia/utils.py ADDED
@@ -0,0 +1,65 @@
+import time
+
+from aiohttp.client_exceptions import ServerDisconnectedError
+import asyncio
+from functools import wraps
+from http.client import RemoteDisconnected
+from httpx import TimeoutException
+from requests.exceptions import ConnectionError
+
+def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
+    """Retry a function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
+
+    Args:
+        max_retries (int): The maximum number of retries.
+        backoff_factor (int): The factor to increase the delay between retries.
+        logger (logging.Logger): The logger to use for logging.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except (ConnectionError, RemoteDisconnected, ServerDisconnectedError, TimeoutException) as e:
+                    logger.info(f"Retrying after exception: {e}")
+                    retry_count += 1
+                    if retry_count < max_retries:
+                        delay = backoff_factor * (2 ** (retry_count - 1))
+                        logger.warn(f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds...")
+                        time.sleep(delay)
+                    else:
+                        raise Exception(f"Exception occurred after {max_retries} tries.") from e
+        return wrapper
+    return decorator
+
+def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None):
+    """Retry an asynchronous function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
+
+    Args:
+        max_retries (int): The maximum number of retries.
+        backoff_factor (int): The factor to increase the delay between retries.
+        logger (logging.Logger): The logger to use for logging.
+    """
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(*args, **kwargs):
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    async for chunk in func(*args, **kwargs):
+                        yield chunk
+                    # If the function completes without raising an exception return
+                    return
+                except (ConnectionError, RemoteDisconnected, ServerDisconnectedError, TimeoutException) as e:
+                    logger.info(f"Retrying after exception: {e}")
+                    retry_count += 1
+                    if retry_count < max_retries:
+                        delay = backoff_factor * (2 ** (retry_count - 1))
+                        logger.warn(f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds...")
+                        await asyncio.sleep(delay)
+                    else:
+                        raise Exception(f"Exception occurred after {max_retries} tries.") from e
+        return wrapper
+    return decorator
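The new decorators retry on connection-related errors with exponential backoff (`delay = backoff_factor * 2 ** (retry_count - 1)`, i.e. 1 s, 2 s, ... with the defaults) and raise once `max_retries` is exhausted. A sketch of applying the synchronous variant to an arbitrary HTTP call (the function name and URL are illustrative, not part of the package; a logger is passed because the decorator logs on every retry):

```python
import logging

import requests

from cartesia.utils import retry_on_connection_error

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@retry_on_connection_error(max_retries=3, backoff_factor=1, logger=logger)
def fetch_index():
    # requests.exceptions.ConnectionError is one of the retried exception types.
    return requests.get("https://api.cartesia.ai/", timeout=5)

response = fetch_index()
```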
cartesia/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.0.5rc1"
+__version__ = "0.0.6"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cartesia
-Version: 0.0.5rc1
+Version: 0.0.6
 Summary: The official Python library for the Cartesia API.
 Home-page:
 Author: Cartesia, Inc.
@@ -16,25 +16,17 @@ Requires-Dist: pytest-asyncio
 Requires-Dist: requests
 Requires-Dist: websockets
 Provides-Extra: all
-Requires-Dist: pre-commit ; extra == 'all'
-Requires-Dist: docformatter ; extra == 'all'
-Requires-Dist: black ==24.1.1 ; extra == 'all'
-Requires-Dist: isort ==5.13.2 ; extra == 'all'
-Requires-Dist: flake8 ==7.0.0 ; extra == 'all'
-Requires-Dist: flake8-bugbear ==24.2.6 ; extra == 'all'
 Requires-Dist: pytest >=8.0.2 ; extra == 'all'
 Requires-Dist: pytest-cov >=4.1.0 ; extra == 'all'
 Requires-Dist: twine ; extra == 'all'
+Requires-Dist: setuptools ; extra == 'all'
+Requires-Dist: wheel ; extra == 'all'
 Provides-Extra: dev
-Requires-Dist: pre-commit ; extra == 'dev'
-Requires-Dist: docformatter ; extra == 'dev'
-Requires-Dist: black ==24.1.1 ; extra == 'dev'
-Requires-Dist: isort ==5.13.2 ; extra == 'dev'
-Requires-Dist: flake8 ==7.0.0 ; extra == 'dev'
-Requires-Dist: flake8-bugbear ==24.2.6 ; extra == 'dev'
 Requires-Dist: pytest >=8.0.2 ; extra == 'dev'
 Requires-Dist: pytest-cov >=4.1.0 ; extra == 'dev'
 Requires-Dist: twine ; extra == 'dev'
+Requires-Dist: setuptools ; extra == 'dev'
+Requires-Dist: wheel ; extra == 'dev'


 # Cartesia Python API Library
@@ -60,13 +52,14 @@ client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
 voices = client.get_voices()
 voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
 transcript = "Hello! Welcome to Cartesia"
+model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind

 p = pyaudio.PyAudio()

 stream = None

 # Generate and stream audio
-for output in client.generate(transcript=transcript, voice=voice, stream=True):
+for output in client.generate(transcript=transcript, voice=voice, model_id=model_id, stream=True):
     buffer = output["audio"]
     rate = output["sampling_rate"]

@@ -84,26 +77,68 @@ stream.close()
 p.terminate()
 ```

-If you are using Jupyter Notebook or JupyterLab, you can use IPython.display.Audio to play the generated audio directly in the notebook. Here's an example:
+You can also use the async client if you want to make asynchronous API calls:
+```python
+from cartesia.tts import AsyncCartesiaTTS
+import asyncio
+import pyaudio
+import os
+
+async def write_stream():
+    client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
+    voices = client.get_voices()
+    voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
+    transcript = "Hello! Welcome to Cartesia"
+    model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
+
+    p = pyaudio.PyAudio()
+
+    stream = None
+
+    # Generate and stream audio
+    async for output in await client.generate(transcript=transcript, voice=voice, model_id=model_id, stream=True):
+        buffer = output["audio"]
+        rate = output["sampling_rate"]
+
+        if not stream:
+            stream = p.open(format=pyaudio.paFloat32,
+                            channels=1,
+                            rate=rate,
+                            output=True)
+
+        # Write the audio data to the stream
+        stream.write(buffer)
+
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+asyncio.run(write_stream())
+```
+
+If you are using Jupyter Notebook or JupyterLab, you can use IPython.display.Audio to play the generated audio directly in the notebook.
+Additionally, in these notebook examples we show how to use the client as a context manager (though this is not required).

 ```python
-from cartesia.tts import CartesiaTTS
 from IPython.display import Audio
 import io
 import os
+import numpy as np

-client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
-voices = client.get_voices()
-voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
-transcript = "Hello! Welcome to Cartesia"
+from cartesia.tts import CartesiaTTS

-# Create a BytesIO object to store the audio data
-audio_data = io.BytesIO()
+with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
+    voices = client.get_voices()
+    voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
+    transcript = "Hello! Welcome to Cartesia"

-# Generate and stream audio
-for output in client.generate(transcript=transcript, voice=voice, stream=True):
-    buffer = output["audio"]
-    audio_data.write(buffer)
+    # Create a BytesIO object to store the audio data
+    audio_data = io.BytesIO()
+
+    # Generate and stream audio
+    for output in client.generate(transcript=transcript, voice=voice, stream=True):
+        buffer = output["audio"]
+        audio_data.write(buffer)

 # Set the cursor position to the beginning of the BytesIO object
 audio_data.seek(0)
@@ -115,25 +150,27 @@ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["s
 display(audio)
 ```

-You can also use the async client if you want to make asynchronous API calls. The usage is very similar:
+Below is the same example using the async client:
 ```python
-from cartesia.tts import AsyncCartesiaTTS
 from IPython.display import Audio
 import io
 import os
+import numpy as np

-client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
-voices = client.get_voices()
-voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
-transcript = "Hello! Welcome to Cartesia"
+from cartesia.tts import AsyncCartesiaTTS

-# Create a BytesIO object to store the audio data
-audio_data = io.BytesIO()
+async with AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
+    voices = client.get_voices()
+    voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
+    transcript = "Hello! Welcome to Cartesia"

-# Generate and stream audio
-async for output in client.generate(transcript=transcript, voice=voice, stream=True):
-    buffer = output["audio"]
-    audio_data.write(buffer)
+    # Create a BytesIO object to store the audio data
+    audio_data = io.BytesIO()
+
+    # Generate and stream audio
+    async for output in await client.generate(transcript=transcript, voice=voice, stream=True):
+        buffer = output["audio"]
+        audio_data.write(buffer)

 # Set the cursor position to the beginning of the BytesIO object
 audio_data.seek(0)
@@ -0,0 +1,8 @@
+cartesia/__init__.py,sha256=uIc9xGNPs8_A6eAvbTUY1geazunYoEZVWFKhCwC9TRA,102
+cartesia/tts.py,sha256=YjOW8mlvvPbHblhcMUY71RsKn77K_WQi8ySok3ifeJg,26734
+cartesia/utils.py,sha256=GoTJe8LZ3WpS4hXkwoZauPYjo7Mbx7BvbBjAX5vEbwg,3024
+cartesia/version.py,sha256=QiiYsv0kcJaB8wCWyT-FnI2b6be87HA-CrrIUn8LQhg,22
+cartesia-0.0.6.dist-info/METADATA,sha256=yhq7LSvLrboBPI3IOcLTvaneisqhq-v1VMQ0sKBq8kk,5974
+cartesia-0.0.6.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
+cartesia-0.0.6.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
+cartesia-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.41.2)
+Generator: bdist_wheel (0.43.0)
 Root-Is-Purelib: true
 Tag: py2-none-any
 Tag: py3-none-any
@@ -1,7 +0,0 @@
-cartesia/__init__.py,sha256=m8BX-qLjsMoI_JZtgf3jNi8R3cBZqYy-z4oEhYeJLdI,64
-cartesia/tts.py,sha256=yPLz41AR0oAYPUNW48mqmwEEbLBHCnbaK_wPT0iFBVk,20543
-cartesia/version.py,sha256=VkI5lk2CFatZR200RqGd8cBjTnMDmhtZW7DI6mPe6n4,25
-cartesia-0.0.5rc1.dist-info/METADATA,sha256=632D6iZ2IU3MLySAnMtwV2zQA38XkQv1rfFF4iRdAco,4893
-cartesia-0.0.5rc1.dist-info/WHEEL,sha256=iYlv5fX357PQyRT2o6tw1bN-YcKFFHKqB_LwHO5wP-g,110
-cartesia-0.0.5rc1.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
-cartesia-0.0.5rc1.dist-info/RECORD,,