cartesia 0.0.6__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.6
3
+ Version: 0.1.1
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -35,7 +35,7 @@ import os
35
35
 
36
36
  client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
37
37
  voices = client.get_voices()
38
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
38
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
39
39
  transcript = "Hello! Welcome to Cartesia"
40
40
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
41
41
 
@@ -72,7 +72,7 @@ import os
72
72
  async def write_stream():
73
73
  client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
74
74
  voices = client.get_voices()
75
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
75
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
76
76
  transcript = "Hello! Welcome to Cartesia"
77
77
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
78
78
 
@@ -114,7 +114,7 @@ from cartesia.tts import CartesiaTTS
114
114
 
115
115
  with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
116
116
  voices = client.get_voices()
117
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
117
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
118
118
  transcript = "Hello! Welcome to Cartesia"
119
119
 
120
120
  # Create a BytesIO object to store the audio data
@@ -146,7 +146,7 @@ from cartesia.tts import AsyncCartesiaTTS
146
146
 
147
147
  async with AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
148
148
  voices = client.get_voices()
149
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
149
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
150
150
  transcript = "Hello! Welcome to Cartesia"
151
151
 
152
152
  # Create a BytesIO object to store the audio data
@@ -19,7 +19,7 @@ import os
19
19
 
20
20
  client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
21
21
  voices = client.get_voices()
22
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
22
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
23
23
  transcript = "Hello! Welcome to Cartesia"
24
24
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
25
25
 
@@ -56,7 +56,7 @@ import os
56
56
  async def write_stream():
57
57
  client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
58
58
  voices = client.get_voices()
59
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
59
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
60
60
  transcript = "Hello! Welcome to Cartesia"
61
61
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
62
62
 
@@ -98,7 +98,7 @@ from cartesia.tts import CartesiaTTS
98
98
 
99
99
  with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
100
100
  voices = client.get_voices()
101
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
101
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
102
102
  transcript = "Hello! Welcome to Cartesia"
103
103
 
104
104
  # Create a BytesIO object to store the audio data
@@ -130,7 +130,7 @@ from cartesia.tts import AsyncCartesiaTTS
130
130
 
131
131
  async with AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
132
132
  voices = client.get_voices()
133
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
133
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
134
134
  transcript = "Hello! Welcome to Cartesia"
135
135
 
136
136
  # Create a BytesIO object to store the audio data
@@ -0,0 +1,43 @@
1
+ from enum import Enum
2
+ from typing import List, Optional, TypedDict, Union
3
+
4
+ try:
5
+ import numpy as np
6
+
7
+ _NUMPY_AVAILABLE = True
8
+ except ImportError:
9
+ _NUMPY_AVAILABLE = False
10
+
11
+
12
+ class AudioDataReturnType(Enum):
13
+ BYTES = "bytes"
14
+ ARRAY = "array"
15
+
16
+
17
+ class AudioOutputFormat(Enum):
18
+ """Supported output formats for the audio."""
19
+
20
+ FP32 = "fp32" # float32
21
+ PCM = "pcm" # 16-bit signed integer PCM
22
+ FP32_16000 = "fp32_16000" # float32, 16 kHz
23
+ FP32_22050 = "fp32_22050" # float32, 22.05 kHz
24
+ FP32_44100 = "fp32_44100" # float32, 44.1 kHz
25
+ PCM_16000 = "pcm_16000" # 16-bit signed integer PCM, 16 kHz
26
+ PCM_22050 = "pcm_22050" # 16-bit signed integer PCM, 22.05 kHz
27
+ PCM_44100 = "pcm_44100" # 16-bit signed integer PCM, 44.1 kHz
28
+ MULAW_8000 = "mulaw_8000" # 8-bit mu-law, 8 kHz
29
+
30
+
31
+ class AudioOutput(TypedDict):
32
+ audio: Union[bytes, "np.ndarray"]
33
+ sampling_rate: int
34
+
35
+
36
+ Embedding = List[float]
37
+
38
+
39
+ class VoiceMetadata(TypedDict):
40
+ id: str
41
+ name: str
42
+ description: str
43
+ embedding: Optional[Embedding]
@@ -4,7 +4,17 @@ import json
4
4
  import os
5
5
  import uuid
6
6
  from types import TracebackType
7
- from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, TypedDict, Union
7
+ from typing import (
8
+ Any,
9
+ AsyncGenerator,
10
+ Dict,
11
+ Generator,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ TypedDict,
16
+ Union,
17
+ )
8
18
 
9
19
  import aiohttp
10
20
  import httpx
@@ -13,6 +23,21 @@ import requests
13
23
  from websockets.sync.client import connect
14
24
 
15
25
  from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
26
+ from cartesia._types import (
27
+ AudioDataReturnType,
28
+ AudioOutputFormat,
29
+ AudioOutput,
30
+ Embedding,
31
+ VoiceMetadata,
32
+ )
33
+
34
+ try:
35
+ import numpy as np
36
+
37
+ _NUMPY_AVAILABLE = True
38
+ except ImportError:
39
+ _NUMPY_AVAILABLE = False
40
+
16
41
 
17
42
  DEFAULT_MODEL_ID = ""
18
43
  DEFAULT_BASE_URL = "api.cartesia.ai"
@@ -25,20 +50,6 @@ MAX_RETRIES = 3
25
50
 
26
51
  logger = logging.getLogger(__name__)
27
52
 
28
- class AudioOutput(TypedDict):
29
- audio: bytes
30
- sampling_rate: int
31
-
32
-
33
- Embedding = List[float]
34
-
35
-
36
- class VoiceMetadata(TypedDict):
37
- id: str
38
- name: str
39
- description: str
40
- embedding: Optional[Embedding]
41
-
42
53
 
43
54
  def update_buffer(buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
44
55
  buffer += chunk_bytes.decode("utf-8")
@@ -79,7 +90,6 @@ class CartesiaTTS:
79
90
  and generate speech from text.
80
91
 
81
92
  The client also supports generating audio using a websocket for lower latency.
82
- To enable interrupt handling along the websocket, set `experimental_ws_handle_interrupts=True`.
83
93
 
84
94
  Examples:
85
95
  >>> client = CartesiaTTS()
@@ -102,21 +112,17 @@ class CartesiaTTS:
102
112
  ... audio, sr = audio_chunk["audio"], audio_chunk["sampling_rate"]
103
113
  """
104
114
 
105
- def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
115
+ def __init__(self, *, api_key: str = None):
106
116
  """Args:
107
117
  api_key: The API key to use for authorization.
108
118
  If not specified, the API key will be read from the environment variable
109
119
  `CARTESIA_API_KEY`.
110
- experimental_ws_handle_interrupts: Whether to handle interrupts when generating
111
- audio using the websocket. This is an experimental feature and may have bugs
112
- or be deprecated in the future.
113
120
  """
114
121
  self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
115
122
  self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
116
123
  self.api_version = os.environ.get("CARTESIA_API_VERSION", DEFAULT_API_VERSION)
117
124
  self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
118
125
  self.websocket = None
119
- self.experimental_ws_handle_interrupts = experimental_ws_handle_interrupts
120
126
 
121
127
  def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
122
128
  """Returns a mapping from voice name -> voice metadata.
@@ -165,7 +171,9 @@ class CartesiaTTS:
165
171
  voice["embedding"] = json.loads(voice["embedding"])
166
172
  return {voice["name"]: voice for voice in voices}
167
173
 
168
- @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
174
+ @retry_on_connection_error(
175
+ max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
176
+ )
169
177
  def get_voice_embedding(
170
178
  self, *, voice_id: str = None, filepath: str = None, link: str = None
171
179
  ) -> Embedding:
@@ -222,16 +230,28 @@ class CartesiaTTS:
222
230
  """
223
231
  if self.websocket is None or self._is_websocket_closed():
224
232
  route = "audio/websocket"
225
- if self.experimental_ws_handle_interrupts:
226
- route = f"experimental/{route}"
227
233
  self.websocket = connect(f"{self._ws_url()}/{route}?api_key={self.api_key}")
228
234
 
229
235
  def _is_websocket_closed(self):
230
236
  return self.websocket.socket.fileno() == -1
231
237
 
232
238
  def _check_inputs(
233
- self, transcript: str, duration: Optional[float], chunk_time: Optional[float]
239
+ self,
240
+ transcript: str,
241
+ duration: Optional[float],
242
+ chunk_time: Optional[float],
243
+ output_format: Union[str, AudioOutputFormat],
244
+ data_rtype: Union[str, AudioDataReturnType],
234
245
  ):
246
+ # This will try the casting and raise an error.
247
+ _ = AudioOutputFormat(output_format)
248
+
249
+ if AudioDataReturnType(data_rtype) == AudioDataReturnType.ARRAY and not _NUMPY_AVAILABLE:
250
+ raise ImportError(
251
+ "The 'numpy' package is required to use the 'array' return type. "
252
+ "Please install 'numpy' or use 'bytes' as the return type."
253
+ )
254
+
235
255
  if chunk_time is not None:
236
256
  if chunk_time < 0.1 or chunk_time > 0.5:
237
257
  raise ValueError("`chunk_time` must be between 0.1 and 0.5")
@@ -249,7 +269,7 @@ class CartesiaTTS:
249
269
  transcript: str,
250
270
  voice: Embedding,
251
271
  model_id: str,
252
- output_format: str,
272
+ output_format: AudioOutputFormat,
253
273
  duration: int = None,
254
274
  chunk_time: float = None,
255
275
  ) -> Dict[str, Any]:
@@ -259,6 +279,7 @@ class CartesiaTTS:
259
279
  filtered out otherwise.
260
280
  """
261
281
  body = dict(transcript=transcript, model_id=model_id, voice=voice)
282
+ output_format = output_format.value
262
283
 
263
284
  optional_body = dict(
264
285
  duration=duration,
@@ -279,7 +300,8 @@ class CartesiaTTS:
279
300
  chunk_time: float = None,
280
301
  stream: bool = False,
281
302
  websocket: bool = True,
282
- output_format: str = "fp32",
303
+ output_format: Union[str, AudioOutputFormat] = "fp32",
304
+ data_rtype: str = "bytes",
283
305
  ) -> Union[AudioOutput, Generator[AudioOutput, None, None]]:
284
306
  """Generate audio from a transcript.
285
307
 
@@ -293,6 +315,9 @@ class CartesiaTTS:
293
315
  If True this function returns a generator. False by default.
294
316
  websocket (bool, optional): Whether to use a websocket for streaming audio.
295
317
  Using the websocket reduces latency by pre-poning the handshake. True by default.
318
+ data_rtype: The return type for the 'audio' key in the returned dictionary.
319
+ One of `'bytes' | 'array'`.
320
+ Note this field is experimental and may be deprecated in the future.
296
321
 
297
322
  Returns:
298
323
  A generator if `stream` is True, otherwise a dictionary.
@@ -300,13 +325,16 @@ class CartesiaTTS:
300
325
  * "audio": The audio as a bytes buffer.
301
326
  * "sampling_rate": The sampling rate of the audio.
302
327
  """
303
- self._check_inputs(transcript, duration, chunk_time)
328
+ self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
329
+
330
+ data_rtype = AudioDataReturnType(data_rtype)
331
+ output_format = AudioOutputFormat(output_format)
304
332
 
305
333
  body = self._generate_request_body(
306
- transcript=transcript,
307
- voice=voice,
334
+ transcript=transcript,
335
+ voice=voice,
308
336
  model_id=model_id,
309
- duration=duration,
337
+ duration=duration,
310
338
  chunk_time=chunk_time,
311
339
  output_format=output_format,
312
340
  )
@@ -316,6 +344,9 @@ class CartesiaTTS:
316
344
  else:
317
345
  generator = self._generate_http_wrapper(body)
318
346
 
347
+ generator = self._postprocess_audio(
348
+ generator, data_rtype=data_rtype, output_format=output_format
349
+ )
319
350
  if stream:
320
351
  return generator
321
352
 
@@ -326,9 +357,45 @@ class CartesiaTTS:
326
357
  sampling_rate = chunk["sampling_rate"]
327
358
  chunks.append(chunk["audio"])
328
359
 
329
- return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}
360
+ if data_rtype == AudioDataReturnType.ARRAY:
361
+ cat = np.concatenate
362
+ else:
363
+ cat = b"".join
364
+
365
+ return {"audio": cat(chunks), "sampling_rate": sampling_rate}
366
+
367
+ def _postprocess_audio(
368
+ self,
369
+ generator: Generator[AudioOutput, None, None],
370
+ *,
371
+ data_rtype: AudioDataReturnType,
372
+ output_format: AudioOutputFormat,
373
+ ) -> Generator[AudioOutput, None, None]:
374
+ """Perform postprocessing on the generator outputs.
375
+
376
+ The postprocessing should be minimal (e.g. converting to array, casting dtype).
377
+ This code should not perform heavy operations like changing the sampling rate.
378
+
379
+ Args:
380
+ generator: A generator that yields audio chunks.
381
+ data_rtype: The data return type.
382
+ output_format: The output format for the audio.
383
+
384
+ Returns:
385
+ A generator that yields audio chunks.
386
+ """
387
+ dtype = None
388
+ if data_rtype == AudioDataReturnType.ARRAY:
389
+ dtype = np.float32 if "fp32" in output_format.value else np.int16
390
+
391
+ for chunk in generator:
392
+ if dtype is not None:
393
+ chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
394
+ yield chunk
330
395
 
331
- @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
396
+ @retry_on_connection_error(
397
+ max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
398
+ )
332
399
  def _generate_http_wrapper(self, body: Dict[str, Any]):
333
400
  """Need to wrap the http generator in a function for the retry decorator to work."""
334
401
  try:
@@ -389,15 +456,6 @@ class CartesiaTTS:
389
456
  break
390
457
 
391
458
  yield convert_response(response, include_context_id)
392
-
393
- if self.experimental_ws_handle_interrupts:
394
- self.websocket.send(json.dumps({"context_id": context_id}))
395
- except GeneratorExit:
396
- # The exit is only called when the generator is garbage collected.
397
- # It may not be called directly after a break statement.
398
- # However, the generator will be automatically cancelled on the next request.
399
- if self.experimental_ws_handle_interrupts:
400
- self.websocket.send(json.dumps({"context_id": context_id, "action": "cancel"}))
401
459
  except Exception as e:
402
460
  # Close the websocket connection if an error occurs.
403
461
  if self.websocket and not self._is_websocket_closed():
@@ -408,23 +466,6 @@ class CartesiaTTS:
408
466
  if self.websocket and not self._is_websocket_closed():
409
467
  self.websocket.close()
410
468
 
411
- @retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
412
- def transcribe(self, raw_audio: Union[bytes, str]) -> str:
413
- raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
414
- response = httpx.post(
415
- f"{self._http_url()}/audio/transcriptions",
416
- headers=headers,
417
- files={"clip": ("input.wav", raw_audio_bytes)},
418
- timeout=DEFAULT_TIMEOUT,
419
- )
420
-
421
- if not response.is_success:
422
- raise ValueError(f"Failed to transcribe audio. Error: {response.text()}")
423
-
424
- transcript = response.json()
425
- return transcript["text"]
426
-
427
-
428
469
  def prepare_audio_and_headers(
429
470
  self, raw_audio: Union[bytes, str]
430
471
  ) -> Tuple[bytes, Dict[str, Any]]:
@@ -466,13 +507,11 @@ class CartesiaTTS:
466
507
 
467
508
 
468
509
  class AsyncCartesiaTTS(CartesiaTTS):
469
- def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
510
+ def __init__(self, *, api_key: str = None):
470
511
  self._session = None
471
512
  self._loop = None
472
- super().__init__(
473
- api_key=api_key, experimental_ws_handle_interrupts=experimental_ws_handle_interrupts
474
- )
475
-
513
+ super().__init__(api_key=api_key)
514
+
476
515
  async def _get_session(self):
477
516
  current_loop = asyncio.get_event_loop()
478
517
  if self._loop is not current_loop:
@@ -481,29 +520,25 @@ class AsyncCartesiaTTS(CartesiaTTS):
481
520
  if self._session is None or self._session.closed:
482
521
  timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
483
522
  connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
484
- self._session = aiohttp.ClientSession(
485
- timeout=timeout, connector=connector
486
- )
523
+ self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
487
524
  self._loop = current_loop
488
525
  return self._session
489
-
526
+
490
527
  async def refresh_websocket(self):
491
528
  """Refresh the websocket connection."""
492
529
  if self.websocket is None or self._is_websocket_closed():
493
530
  route = "audio/websocket"
494
- if self.experimental_ws_handle_interrupts:
495
- route = f"experimental/{route}"
496
531
  session = await self._get_session()
497
532
  self.websocket = await session.ws_connect(
498
533
  f"{self._ws_url()}/{route}?api_key={self.api_key}"
499
534
  )
500
-
535
+
501
536
  def _is_websocket_closed(self):
502
537
  return self.websocket.closed
503
538
 
504
539
  async def close(self):
505
540
  """This method closes the websocket and the session.
506
-
541
+
507
542
  It is *strongly* recommended to call this method when you are done using the client.
508
543
  """
509
544
  if self.websocket is not None and not self._is_websocket_closed():
@@ -521,35 +556,22 @@ class AsyncCartesiaTTS(CartesiaTTS):
521
556
  chunk_time: float = None,
522
557
  stream: bool = False,
523
558
  websocket: bool = True,
524
- output_format: str = "fp32"
559
+ output_format: Union[str, AudioOutputFormat] = "fp32",
560
+ data_rtype: Union[str, AudioDataReturnType] = "bytes",
525
561
  ) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
526
562
  """Asynchronously generate audio from a transcript.
527
- NOTE: This overrides the non-asynchronous generate method from the base class.
528
-
529
- Args:
530
- transcript (str): The text to generate audio for.
531
- voice (Embedding (List[float])): The voice to use for generating audio.
532
- duration (int, optional): The maximum duration of the audio in seconds.
533
- chunk_time (float, optional): How long each audio segment should be in seconds.
534
- This should not need to be adjusted.
535
- stream (bool, optional): Whether to stream the audio or not.
536
- If True this function returns a generator. False by default.
537
- websocket (bool, optional): Whether to use a websocket for streaming audio.
538
- Using the websocket reduces latency by pre-poning the handshake. True by default.
539
563
 
540
- Returns:
541
- A generator if `stream` is True, otherwise a dictionary.
542
- Dictionary from both generator and non-generator return types have the following keys:
543
- * "audio": The audio as a bytes buffer.
544
- * "sampling_rate": The sampling rate of the audio.
564
+ For more information on the arguments, see the synchronous :meth:`CartesiaTTS.generate`.
545
565
  """
546
- self._check_inputs(transcript, duration, chunk_time)
566
+ self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
567
+ data_rtype = AudioDataReturnType(data_rtype)
568
+ output_format = AudioOutputFormat(output_format)
547
569
 
548
570
  body = self._generate_request_body(
549
- transcript=transcript,
571
+ transcript=transcript,
550
572
  voice=voice,
551
573
  model_id=model_id,
552
- duration=duration,
574
+ duration=duration,
553
575
  chunk_time=chunk_time,
554
576
  output_format=output_format,
555
577
  )
@@ -558,7 +580,9 @@ class AsyncCartesiaTTS(CartesiaTTS):
558
580
  generator = self._generate_ws(body)
559
581
  else:
560
582
  generator = self._generate_http_wrapper(body)
561
-
583
+ generator = self._postprocess_audio(
584
+ generator, data_rtype=data_rtype, output_format=output_format
585
+ )
562
586
  if stream:
563
587
  return generator
564
588
 
@@ -569,14 +593,38 @@ class AsyncCartesiaTTS(CartesiaTTS):
569
593
  sampling_rate = chunk["sampling_rate"]
570
594
  chunks.append(chunk["audio"])
571
595
 
572
- return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}
596
+ if data_rtype == AudioDataReturnType.ARRAY:
597
+ cat = np.concatenate
598
+ else:
599
+ cat = b"".join
600
+
601
+ return {"audio": cat(chunks), "sampling_rate": sampling_rate}
602
+
603
+ async def _postprocess_audio(
604
+ self,
605
+ generator: AsyncGenerator[AudioOutput, None],
606
+ *,
607
+ data_rtype: AudioDataReturnType,
608
+ output_format: AudioOutputFormat,
609
+ ) -> AsyncGenerator[AudioOutput, None]:
610
+ """See :meth:`CartesiaTTS._postprocess_audio`."""
611
+ dtype = None
612
+ if data_rtype == AudioDataReturnType.ARRAY:
613
+ dtype = np.float32 if "fp32" in output_format.value else np.int16
614
+
615
+ async for chunk in generator:
616
+ if dtype is not None:
617
+ chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
618
+ yield chunk
573
619
 
574
- @retry_on_connection_error_async(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
620
+ @retry_on_connection_error_async(
621
+ max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
622
+ )
575
623
  async def _generate_http_wrapper(self, body: Dict[str, Any]):
576
624
  """Need to wrap the http generator in a function for the retry decorator to work."""
577
625
  try:
578
- async for chunk in self._generate_http(body):
579
- yield chunk
626
+ async for chunk in self._generate_http(body):
627
+ yield chunk
580
628
  except Exception as e:
581
629
  logger.error(f"Failed to generate audio. {e}")
582
630
  raise e
@@ -605,10 +653,6 @@ class AsyncCartesiaTTS(CartesiaTTS):
605
653
 
606
654
  async def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
607
655
  include_context_id = bool(context_id)
608
- route = "audio/websocket"
609
- if self.experimental_ws_handle_interrupts:
610
- route = f"experimental/{route}"
611
-
612
656
  if not self.websocket or self._is_websocket_closed():
613
657
  await self.refresh_websocket()
614
658
 
@@ -624,39 +668,16 @@ class AsyncCartesiaTTS(CartesiaTTS):
624
668
  break
625
669
 
626
670
  yield convert_response(response, include_context_id)
627
-
628
- if self.experimental_ws_handle_interrupts:
629
- await ws.send_json({"context_id": context_id})
630
- except GeneratorExit:
631
- # The exit is only called when the generator is garbage collected.
632
- # It may not be called directly after a break statement.
633
- # However, the generator will be automatically cancelled on the next request.
634
- if self.experimental_ws_handle_interrupts:
635
- await ws.send_json({"context_id": context_id, "action": "cancel"})
636
671
  except Exception as e:
637
672
  if self.websocket and not self._is_websocket_closed():
638
673
  await self.websocket.close()
639
- raise RuntimeError(f"Failed to generate audio. {await response.text()}") from e
674
+ error_msg_end = "" if response is None else f": {await response.text()}"
675
+ raise RuntimeError(f"Failed to generate audio{error_msg_end}") from e
640
676
  finally:
641
677
  # Ensure the websocket is ultimately closed.
642
678
  if self.websocket and not self._is_websocket_closed():
643
679
  await self.websocket.close()
644
680
 
645
- async def transcribe(self, raw_audio: Union[bytes, str]) -> str:
646
- raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
647
- data = aiohttp.FormData()
648
- data.add_field("clip", raw_audio_bytes, filename="input.wav", content_type="audio/wav")
649
- session = await self._get_session()
650
-
651
- async with session.post(
652
- f"{self._http_url()}/audio/transcriptions", headers=headers, data=data
653
- ) as response:
654
- if not response.ok:
655
- raise ValueError(f"Failed to transcribe audio. Error: {await response.text()}")
656
-
657
- transcript = await response.json()
658
- return transcript["text"]
659
-
660
681
  def __del__(self):
661
682
  try:
662
683
  loop = asyncio.get_running_loop()
@@ -7,6 +7,7 @@ from http.client import RemoteDisconnected
7
7
  from httpx import TimeoutException
8
8
  from requests.exceptions import ConnectionError
9
9
 
10
+
10
11
  def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
11
12
  """Retry a function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
12
13
 
@@ -15,6 +16,7 @@ def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
15
16
  backoff_factor (int): The factor to increase the delay between retries.
16
17
  logger (logging.Logger): The logger to use for logging.
17
18
  """
19
+
18
20
  def decorator(func):
19
21
  @wraps(func)
20
22
  def wrapper(*args, **kwargs):
@@ -22,18 +24,28 @@ def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
22
24
  while retry_count < max_retries:
23
25
  try:
24
26
  return func(*args, **kwargs)
25
- except (ConnectionError, RemoteDisconnected, ServerDisconnectedError, TimeoutException) as e:
27
+ except (
28
+ ConnectionError,
29
+ RemoteDisconnected,
30
+ ServerDisconnectedError,
31
+ TimeoutException,
32
+ ) as e:
26
33
  logger.info(f"Retrying after exception: {e}")
27
34
  retry_count += 1
28
35
  if retry_count < max_retries:
29
36
  delay = backoff_factor * (2 ** (retry_count - 1))
30
- logger.warn(f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds...")
37
+ logger.warn(
38
+ f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
39
+ )
31
40
  time.sleep(delay)
32
41
  else:
33
42
  raise Exception(f"Exception occurred after {max_retries} tries.") from e
43
+
34
44
  return wrapper
45
+
35
46
  return decorator
36
47
 
48
+
37
49
  def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None):
38
50
  """Retry an asynchronous function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
39
51
 
@@ -42,6 +54,7 @@ def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None
42
54
  backoff_factor (int): The factor to increase the delay between retries.
43
55
  logger (logging.Logger): The logger to use for logging.
44
56
  """
57
+
45
58
  def decorator(func):
46
59
  @wraps(func)
47
60
  async def wrapper(*args, **kwargs):
@@ -52,14 +65,23 @@ def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None
52
65
  yield chunk
53
66
  # If the function completes without raising an exception return
54
67
  return
55
- except (ConnectionError, RemoteDisconnected, ServerDisconnectedError, TimeoutException) as e:
68
+ except (
69
+ ConnectionError,
70
+ RemoteDisconnected,
71
+ ServerDisconnectedError,
72
+ TimeoutException,
73
+ ) as e:
56
74
  logger.info(f"Retrying after exception: {e}")
57
75
  retry_count += 1
58
76
  if retry_count < max_retries:
59
77
  delay = backoff_factor * (2 ** (retry_count - 1))
60
- logger.warn(f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds...")
78
+ logger.warn(
79
+ f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
80
+ )
61
81
  await asyncio.sleep(delay)
62
82
  else:
63
83
  raise Exception(f"Exception occurred after {max_retries} tries.") from e
84
+
64
85
  return wrapper
65
- return decorator
86
+
87
+ return decorator
@@ -0,0 +1 @@
1
+ __version__ = "0.1.1"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.6
3
+ Version: 0.1.1
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -35,7 +35,7 @@ import os
35
35
 
36
36
  client = CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
37
37
  voices = client.get_voices()
38
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
38
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
39
39
  transcript = "Hello! Welcome to Cartesia"
40
40
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
41
41
 
@@ -72,7 +72,7 @@ import os
72
72
  async def write_stream():
73
73
  client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
74
74
  voices = client.get_voices()
75
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
75
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
76
76
  transcript = "Hello! Welcome to Cartesia"
77
77
  model_id = "genial-planet-1346" # (Optional) We'll specify a default if you don't have a specific model in mind
78
78
 
@@ -114,7 +114,7 @@ from cartesia.tts import CartesiaTTS
114
114
 
115
115
  with CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
116
116
  voices = client.get_voices()
117
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
117
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
118
118
  transcript = "Hello! Welcome to Cartesia"
119
119
 
120
120
  # Create a BytesIO object to store the audio data
@@ -146,7 +146,7 @@ from cartesia.tts import AsyncCartesiaTTS
146
146
 
147
147
  async with AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY")) as client:
148
148
  voices = client.get_voices()
149
- voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
149
+ voice = client.get_voice_embedding(voice_id=voices["Ted"]["id"])
150
150
  transcript = "Hello! Welcome to Cartesia"
151
151
 
152
152
  # Create a BytesIO object to store the audio data
@@ -2,6 +2,7 @@ README.md
2
2
  pyproject.toml
3
3
  setup.py
4
4
  cartesia/__init__.py
5
+ cartesia/_types.py
5
6
  cartesia/tts.py
6
7
  cartesia/utils.py
7
8
  cartesia/version.py
@@ -10,6 +10,7 @@ pytest-cov>=4.1.0
10
10
  twine
11
11
  setuptools
12
12
  wheel
13
+ numpy
13
14
 
14
15
  [dev]
15
16
  pytest>=8.0.2
@@ -17,3 +18,4 @@ pytest-cov>=4.1.0
17
18
  twine
18
19
  setuptools
19
20
  wheel
21
+ numpy
@@ -8,15 +8,16 @@ general correctness.
8
8
  import logging
9
9
  import os
10
10
  import sys
11
- from cartesia.tts import DEFAULT_MODEL_ID, AsyncCartesiaTTS, CartesiaTTS, VoiceMetadata
12
- from typing import AsyncGenerator, Dict, Generator, List
13
-
11
+ from cartesia.tts import DEFAULT_MODEL_ID, AsyncCartesiaTTS, CartesiaTTS
12
+ from cartesia._types import AudioDataReturnType, AudioOutputFormat, VoiceMetadata
13
+ from typing import AsyncGenerator, Dict, Generator, List, Optional, Union
14
+ import numpy as np
14
15
  import pytest
15
16
 
16
17
  THISDIR = os.path.dirname(__file__)
17
18
  sys.path.insert(0, os.path.dirname(THISDIR))
18
19
 
19
- SAMPLE_VOICE = "Samantha"
20
+ SAMPLE_VOICE = "Newsman"
20
21
 
21
22
  logger = logging.getLogger(__name__)
22
23
 
@@ -166,6 +167,78 @@ def test_generate_context_manager_with_err():
166
167
  assert websocket.socket.fileno() == -1 # check socket is now closed
167
168
 
168
169
 
170
+ @pytest.mark.parametrize("output_format", [_fmt for _fmt in AudioOutputFormat])
171
+ @pytest.mark.parametrize("as_str", [True, False])
172
+ @pytest.mark.parametrize("stream", [True, False])
173
+ @pytest.mark.parametrize("websocket", [True, False])
174
+ def test_generate_with_output_format(
175
+ resources: _Resources,
176
+ output_format: AudioOutputFormat,
177
+ as_str: bool,
178
+ stream: bool,
179
+ websocket: bool,
180
+ ):
181
+ value = output_format.value
182
+
183
+ client = resources.client
184
+ voices = resources.voices
185
+ embedding = voices[SAMPLE_VOICE]["embedding"]
186
+ transcript = "Hello, world!"
187
+
188
+ split = value.split("_")
189
+ expected_sampling_rate = int(split[1]) if len(split) == 2 else 44_100
190
+
191
+ # Easy way to get around iterating over stream=True / False.
192
+ output_generate = client.generate(
193
+ transcript=transcript,
194
+ voice=embedding,
195
+ websocket=websocket,
196
+ stream=stream,
197
+ output_format=output_format.value if as_str else output_format,
198
+ )
199
+ if not stream:
200
+ output_generate = [output_generate]
201
+
202
+ for out in output_generate:
203
+ assert isinstance(out["audio"], bytes)
204
+ assert out["sampling_rate"] == expected_sampling_rate
205
+
206
+
207
+ @pytest.mark.parametrize("data_rtype", [_fmt for _fmt in AudioDataReturnType])
208
+ @pytest.mark.parametrize("as_str", [True, False])
209
+ @pytest.mark.parametrize("stream", [True, False])
210
+ @pytest.mark.parametrize("websocket", [True, False])
211
+ def test_generate_with_data_rtype(
212
+ resources: _Resources,
213
+ data_rtype: AudioDataReturnType,
214
+ as_str: bool,
215
+ stream: bool,
216
+ websocket: bool,
217
+ ):
218
+ client = resources.client
219
+ voices = resources.voices
220
+ embedding = voices[SAMPLE_VOICE]["embedding"]
221
+ transcript = "Hello, world!"
222
+
223
+ # Easy way to get around iterating over stream=True / False.
224
+ output_generate = client.generate(
225
+ transcript=transcript,
226
+ voice=embedding,
227
+ websocket=websocket,
228
+ stream=stream,
229
+ data_rtype=data_rtype.value if as_str else data_rtype,
230
+ )
231
+ if not stream:
232
+ output_generate = [output_generate]
233
+
234
+ for out in output_generate:
235
+ if data_rtype == AudioDataReturnType.BYTES:
236
+ assert isinstance(out["audio"], bytes)
237
+ elif data_rtype == AudioDataReturnType.ARRAY:
238
+ assert isinstance(out["audio"], np.ndarray)
239
+ assert out["audio"].dtype == np.float32
240
+
241
+
169
242
  @pytest.mark.parametrize("websocket", [True, False])
170
243
  @pytest.mark.asyncio
171
244
  async def test_async_generate(resources: _Resources, websocket: bool):
@@ -199,7 +272,9 @@ async def test_async_generate_stream(resources: _Resources, websocket: bool):
199
272
  async_client = create_async_client()
200
273
 
201
274
  try:
202
- generator = await async_client.generate(transcript=transcript, voice=embedding, websocket=websocket, stream=True)
275
+ generator = await async_client.generate(
276
+ transcript=transcript, voice=embedding, websocket=websocket, stream=True
277
+ )
203
278
  assert isinstance(generator, AsyncGenerator)
204
279
  async for output in generator:
205
280
  assert output.keys() == {"audio", "sampling_rate"}
@@ -248,18 +323,111 @@ async def test_generate_async_context_manager_with_err():
248
323
  assert websocket.closed # check websocket is now closed
249
324
 
250
325
 
326
+ @pytest.mark.parametrize("output_format", [_fmt for _fmt in AudioOutputFormat])
327
+ @pytest.mark.parametrize("as_str", [True, False])
328
+ @pytest.mark.parametrize("stream", [True, False])
329
+ @pytest.mark.asyncio
330
+ async def test_generate_async_with_output_format(
331
+ resources: _Resources, output_format: AudioOutputFormat, as_str: bool, stream: bool
332
+ ):
333
+ logger.info(
334
+ f"Testing async generate stream with output_format={output_format}, as_str={as_str}, stream={stream}"
335
+ )
336
+ voices = resources.voices
337
+ embedding = voices[SAMPLE_VOICE]["embedding"]
338
+ transcript = "Hello, world!"
339
+
340
+ split = output_format.value.split("_")
341
+ expected_sampling_rate = int(split[1]) if len(split) == 2 else 44_100
342
+
343
+ def _validate(output):
344
+ assert isinstance(output["audio"], bytes)
345
+ assert output["sampling_rate"] == expected_sampling_rate
346
+
347
+ async_client = create_async_client()
348
+
349
+ try:
350
+ output_generate = await async_client.generate(
351
+ transcript=transcript,
352
+ voice=embedding,
353
+ websocket=False,
354
+ stream=stream,
355
+ output_format=output_format.value if as_str else output_format,
356
+ )
357
+ if stream:
358
+ generator = output_generate
359
+ assert isinstance(generator, AsyncGenerator)
360
+ async for output in generator:
361
+ _validate(output)
362
+ else:
363
+ _validate(output_generate)
364
+ finally:
365
+ # Close the websocket
366
+ await async_client.close()
367
+
368
+
369
+ @pytest.mark.parametrize("data_rtype", [_fmt for _fmt in AudioDataReturnType])
370
+ @pytest.mark.parametrize("as_str", [True, False])
371
+ @pytest.mark.parametrize("stream", [True, False])
372
+ @pytest.mark.parametrize("websocket", [True, False])
373
+ @pytest.mark.asyncio
374
+ async def test_generate_async_with_data_rtype(resources: _Resources, data_rtype: AudioDataReturnType, as_str: bool, stream: bool, websocket: bool):
375
+ voices = resources.voices
376
+ embedding = voices[SAMPLE_VOICE]["embedding"]
377
+ transcript = "Hello, world!"
378
+
379
+ async_client = create_async_client()
380
+
381
+ def _validate(output):
382
+ if data_rtype == AudioDataReturnType.BYTES:
383
+ assert isinstance(output["audio"], bytes)
384
+ elif data_rtype == AudioDataReturnType.ARRAY:
385
+ assert isinstance(output["audio"], np.ndarray)
386
+ assert output["audio"].dtype == np.float32
387
+
388
+ try:
389
+ output_generate = await async_client.generate(
390
+ transcript=transcript,
391
+ voice=embedding,
392
+ websocket=False,
393
+ stream=stream,
394
+ data_rtype=data_rtype.value if as_str else data_rtype,
395
+ )
396
+ if stream:
397
+ assert isinstance(output_generate, AsyncGenerator)
398
+ async for output in output_generate:
399
+ _validate(output)
400
+ else:
401
+ _validate(output_generate)
402
+ finally:
403
+ # Close the websocket
404
+ await async_client.close()
405
+
406
+
251
407
  @pytest.mark.parametrize("chunk_time", [0.05, 0.6])
252
408
  def test_check_inputs_invalid_chunk_time(client: CartesiaTTS, chunk_time):
253
409
  logger.info(f"Testing invalid chunk_time: {chunk_time}")
254
410
  with pytest.raises(ValueError, match="`chunk_time` must be between 0.1 and 0.5"):
255
- client._check_inputs("Test", None, chunk_time)
411
+ client._check_inputs(
412
+ "Test",
413
+ None,
414
+ chunk_time,
415
+ output_format=AudioOutputFormat.FP32,
416
+ data_rtype=AudioDataReturnType.BYTES,
417
+ )
256
418
 
257
419
 
258
420
  @pytest.mark.parametrize("chunk_time", [0.1, 0.3, 0.5])
259
421
  def test_check_inputs_valid_chunk_time(client, chunk_time):
260
422
  logger.info("Testing valid chunk_time: {chunk_time}")
261
423
  try:
262
- client._check_inputs("Test", None, chunk_time)
424
+ client._check_inputs(
425
+ "Test",
426
+ None,
427
+ chunk_time,
428
+ output_format=AudioOutputFormat.FP32,
429
+ data_rtype=AudioDataReturnType.BYTES,
430
+ )
263
431
  except ValueError:
264
432
  pytest.fail("Unexpected ValueError raised")
265
433
 
@@ -267,14 +435,26 @@ def test_check_inputs_valid_chunk_time(client, chunk_time):
267
435
  def test_check_inputs_duration_less_than_chunk_time(client: CartesiaTTS):
268
436
  logger.info("Testing duration less than chunk_time")
269
437
  with pytest.raises(ValueError, match="`duration` must be greater than chunk_time"):
270
- client._check_inputs("Test", 0.2, 0.3)
438
+ client._check_inputs(
439
+ "Test",
440
+ 0.2,
441
+ 0.3,
442
+ output_format=AudioOutputFormat.FP32,
443
+ data_rtype=AudioDataReturnType.BYTES,
444
+ )
271
445
 
272
446
 
273
447
  @pytest.mark.parametrize("duration,chunk_time", [(0.5, 0.2), (1.0, 0.5), (2.0, 0.1)])
274
448
  def test_check_inputs_valid_duration_and_chunk_time(client: CartesiaTTS, duration, chunk_time):
275
449
  logger.info(f"Testing valid duration: {duration} and chunk_time: {chunk_time}")
276
450
  try:
277
- client._check_inputs("Test", duration, chunk_time)
451
+ client._check_inputs(
452
+ "Test",
453
+ duration,
454
+ chunk_time,
455
+ output_format=AudioOutputFormat.FP32,
456
+ data_rtype=AudioDataReturnType.BYTES,
457
+ )
278
458
  except ValueError:
279
459
  pytest.fail("Unexpected ValueError raised")
280
460
 
@@ -282,13 +462,99 @@ def test_check_inputs_valid_duration_and_chunk_time(client: CartesiaTTS, duratio
282
462
  def test_check_inputs_empty_transcript(client: CartesiaTTS):
283
463
  logger.info("Testing empty transcript")
284
464
  with pytest.raises(ValueError, match="`transcript` must be non empty"):
285
- client._check_inputs("", None, None)
465
+ client._check_inputs(
466
+ "",
467
+ None,
468
+ None,
469
+ output_format=AudioOutputFormat.FP32,
470
+ data_rtype=AudioDataReturnType.BYTES,
471
+ )
286
472
 
287
473
 
288
474
  @pytest.mark.parametrize("transcript", ["Hello", "Test transcript", "Lorem ipsum dolor sit amet"])
289
475
  def test_check_inputs_valid_transcript(client: CartesiaTTS, transcript):
290
476
  logger.info(f"Testing valid transcript: {transcript}")
291
477
  try:
292
- client._check_inputs(transcript, None, None)
478
+ client._check_inputs(
479
+ transcript,
480
+ None,
481
+ None,
482
+ output_format=AudioOutputFormat.FP32,
483
+ data_rtype=AudioDataReturnType.BYTES,
484
+ )
293
485
  except ValueError:
294
486
  pytest.fail("Unexpected ValueError raised")
487
+
488
+
489
+ @pytest.mark.parametrize(
490
+ "output_format,error",
491
+ [
492
+ # Valid output formats.
493
+ ("fp32", None),
494
+ ("pcm", None),
495
+ ("fp32_16000", None),
496
+ ("fp32_22050", None),
497
+ ("fp32_44100", None),
498
+ ("pcm_16000", None),
499
+ ("pcm_22050", None),
500
+ ("pcm_44100", None),
501
+ # Invalid output formats.
502
+ ("invalid", ValueError),
503
+ ("pcm_1234", ValueError), # cannot specify arbitrary sampling rate
504
+ ("fp32_1234", ValueError), # cannot specify arbitrary sampling rate
505
+ ("fp16_44100", ValueError), # fp16 not supported.
506
+ ],
507
+ )
508
+ def test_check_inputs_output_format(
509
+ client: CartesiaTTS, output_format: Union[str, AudioOutputFormat], error: Optional[Exception]
510
+ ):
511
+ if error:
512
+ with pytest.raises(error):
513
+ client._check_inputs(
514
+ "Test",
515
+ None,
516
+ None,
517
+ output_format=output_format,
518
+ data_rtype=AudioDataReturnType.BYTES,
519
+ )
520
+ else:
521
+ client._check_inputs(
522
+ "Test",
523
+ None,
524
+ None,
525
+ output_format=output_format,
526
+ data_rtype=AudioDataReturnType.BYTES,
527
+ )
528
+
529
+
530
+ @pytest.mark.parametrize(
531
+ "data_rtype,error",
532
+ [
533
+ # Valid data return types.
534
+ ("bytes", None),
535
+ ("array", None),
536
+ # Invalid data return types.
537
+ ("invalid", ValueError),
538
+ ("tensor", ValueError),
539
+ ],
540
+ )
541
+ def test_check_inputs_data_rtype(
542
+ client: CartesiaTTS, data_rtype: Union[str, AudioDataReturnType], error: Optional[Exception]
543
+ ):
544
+ if error:
545
+ with pytest.raises(error):
546
+ client._check_inputs(
547
+ "Test",
548
+ None,
549
+ None,
550
+ output_format=AudioOutputFormat.FP32,
551
+ data_rtype=data_rtype,
552
+ )
553
+ else:
554
+ client._check_inputs(
555
+ "Test",
556
+ None,
557
+ None,
558
+ output_format=AudioOutputFormat.FP32,
559
+ data_rtype=data_rtype,
560
+ )
@@ -1 +0,0 @@
1
- __version__ = "0.0.6"
File without changes
File without changes
File without changes
File without changes