cartesia 0.0.6__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cartesia-0.0.6 → cartesia-0.1.0}/PKG-INFO +1 -1
- cartesia-0.1.0/cartesia/_types.py +42 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia/tts.py +146 -126
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia/utils.py +27 -5
- cartesia-0.1.0/cartesia/version.py +1 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia.egg-info/PKG-INFO +1 -1
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia.egg-info/SOURCES.txt +1 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia.egg-info/requires.txt +2 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/tests/test_tts.py +277 -11
- cartesia-0.0.6/cartesia/version.py +0 -1
- {cartesia-0.0.6 → cartesia-0.1.0}/README.md +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia/__init__.py +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia.egg-info/dependency_links.txt +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/cartesia.egg-info/top_level.txt +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/pyproject.toml +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/setup.cfg +0 -0
- {cartesia-0.0.6 → cartesia-0.1.0}/setup.py +0 -0
@@ -0,0 +1,42 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import List, Optional, TypedDict, Union
|
3
|
+
|
4
|
+
try:
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
_NUMPY_AVAILABLE = True
|
8
|
+
except ImportError:
|
9
|
+
_NUMPY_AVAILABLE = False
|
10
|
+
|
11
|
+
|
12
|
+
class AudioDataReturnType(Enum):
|
13
|
+
BYTES = "bytes"
|
14
|
+
ARRAY = "array"
|
15
|
+
|
16
|
+
|
17
|
+
class AudioOutputFormat(Enum):
|
18
|
+
"""Supported output formats for the audio."""
|
19
|
+
|
20
|
+
FP32 = "fp32" # float32
|
21
|
+
PCM = "pcm" # 16-bit signed integer PCM
|
22
|
+
FP32_16000 = "fp32_16000" # float32, 16 kHz
|
23
|
+
FP32_22050 = "fp32_22050" # float32, 22.05 kHz
|
24
|
+
FP32_44100 = "fp32_44100" # float32, 44.1 kHz
|
25
|
+
PCM_16000 = "pcm_16000" # 16-bit signed integer PCM, 16 kHz
|
26
|
+
PCM_22050 = "pcm_22050" # 16-bit signed integer PCM, 22.05 kHz
|
27
|
+
PCM_44100 = "pcm_44100" # 16-bit signed integer PCM, 44.1 kHz
|
28
|
+
|
29
|
+
|
30
|
+
class AudioOutput(TypedDict):
|
31
|
+
audio: Union[bytes, "np.ndarray"]
|
32
|
+
sampling_rate: int
|
33
|
+
|
34
|
+
|
35
|
+
Embedding = List[float]
|
36
|
+
|
37
|
+
|
38
|
+
class VoiceMetadata(TypedDict):
|
39
|
+
id: str
|
40
|
+
name: str
|
41
|
+
description: str
|
42
|
+
embedding: Optional[Embedding]
|
@@ -4,7 +4,17 @@ import json
|
|
4
4
|
import os
|
5
5
|
import uuid
|
6
6
|
from types import TracebackType
|
7
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
Any,
|
9
|
+
AsyncGenerator,
|
10
|
+
Dict,
|
11
|
+
Generator,
|
12
|
+
List,
|
13
|
+
Optional,
|
14
|
+
Tuple,
|
15
|
+
TypedDict,
|
16
|
+
Union,
|
17
|
+
)
|
8
18
|
|
9
19
|
import aiohttp
|
10
20
|
import httpx
|
@@ -13,6 +23,21 @@ import requests
|
|
13
23
|
from websockets.sync.client import connect
|
14
24
|
|
15
25
|
from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
|
26
|
+
from cartesia._types import (
|
27
|
+
AudioDataReturnType,
|
28
|
+
AudioOutputFormat,
|
29
|
+
AudioOutput,
|
30
|
+
Embedding,
|
31
|
+
VoiceMetadata,
|
32
|
+
)
|
33
|
+
|
34
|
+
try:
|
35
|
+
import numpy as np
|
36
|
+
|
37
|
+
_NUMPY_AVAILABLE = True
|
38
|
+
except ImportError:
|
39
|
+
_NUMPY_AVAILABLE = False
|
40
|
+
|
16
41
|
|
17
42
|
DEFAULT_MODEL_ID = ""
|
18
43
|
DEFAULT_BASE_URL = "api.cartesia.ai"
|
@@ -25,20 +50,6 @@ MAX_RETRIES = 3
|
|
25
50
|
|
26
51
|
logger = logging.getLogger(__name__)
|
27
52
|
|
28
|
-
class AudioOutput(TypedDict):
|
29
|
-
audio: bytes
|
30
|
-
sampling_rate: int
|
31
|
-
|
32
|
-
|
33
|
-
Embedding = List[float]
|
34
|
-
|
35
|
-
|
36
|
-
class VoiceMetadata(TypedDict):
|
37
|
-
id: str
|
38
|
-
name: str
|
39
|
-
description: str
|
40
|
-
embedding: Optional[Embedding]
|
41
|
-
|
42
53
|
|
43
54
|
def update_buffer(buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
|
44
55
|
buffer += chunk_bytes.decode("utf-8")
|
@@ -79,7 +90,6 @@ class CartesiaTTS:
|
|
79
90
|
and generate speech from text.
|
80
91
|
|
81
92
|
The client also supports generating audio using a websocket for lower latency.
|
82
|
-
To enable interrupt handling along the websocket, set `experimental_ws_handle_interrupts=True`.
|
83
93
|
|
84
94
|
Examples:
|
85
95
|
>>> client = CartesiaTTS()
|
@@ -102,21 +112,17 @@ class CartesiaTTS:
|
|
102
112
|
... audio, sr = audio_chunk["audio"], audio_chunk["sampling_rate"]
|
103
113
|
"""
|
104
114
|
|
105
|
-
def __init__(self, *, api_key: str = None
|
115
|
+
def __init__(self, *, api_key: str = None):
|
106
116
|
"""Args:
|
107
117
|
api_key: The API key to use for authorization.
|
108
118
|
If not specified, the API key will be read from the environment variable
|
109
119
|
`CARTESIA_API_KEY`.
|
110
|
-
experimental_ws_handle_interrupts: Whether to handle interrupts when generating
|
111
|
-
audio using the websocket. This is an experimental feature and may have bugs
|
112
|
-
or be deprecated in the future.
|
113
120
|
"""
|
114
121
|
self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
|
115
122
|
self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
|
116
123
|
self.api_version = os.environ.get("CARTESIA_API_VERSION", DEFAULT_API_VERSION)
|
117
124
|
self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
|
118
125
|
self.websocket = None
|
119
|
-
self.experimental_ws_handle_interrupts = experimental_ws_handle_interrupts
|
120
126
|
|
121
127
|
def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
|
122
128
|
"""Returns a mapping from voice name -> voice metadata.
|
@@ -165,7 +171,9 @@ class CartesiaTTS:
|
|
165
171
|
voice["embedding"] = json.loads(voice["embedding"])
|
166
172
|
return {voice["name"]: voice for voice in voices}
|
167
173
|
|
168
|
-
@retry_on_connection_error(
|
174
|
+
@retry_on_connection_error(
|
175
|
+
max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
|
176
|
+
)
|
169
177
|
def get_voice_embedding(
|
170
178
|
self, *, voice_id: str = None, filepath: str = None, link: str = None
|
171
179
|
) -> Embedding:
|
@@ -222,16 +230,28 @@ class CartesiaTTS:
|
|
222
230
|
"""
|
223
231
|
if self.websocket is None or self._is_websocket_closed():
|
224
232
|
route = "audio/websocket"
|
225
|
-
if self.experimental_ws_handle_interrupts:
|
226
|
-
route = f"experimental/{route}"
|
227
233
|
self.websocket = connect(f"{self._ws_url()}/{route}?api_key={self.api_key}")
|
228
234
|
|
229
235
|
def _is_websocket_closed(self):
|
230
236
|
return self.websocket.socket.fileno() == -1
|
231
237
|
|
232
238
|
def _check_inputs(
|
233
|
-
self,
|
239
|
+
self,
|
240
|
+
transcript: str,
|
241
|
+
duration: Optional[float],
|
242
|
+
chunk_time: Optional[float],
|
243
|
+
output_format: Union[str, AudioOutputFormat],
|
244
|
+
data_rtype: Union[str, AudioDataReturnType],
|
234
245
|
):
|
246
|
+
# Attempt the cast up front; an unsupported `output_format` raises a ValueError here.
|
247
|
+
_ = AudioOutputFormat(output_format)
|
248
|
+
|
249
|
+
if AudioDataReturnType(data_rtype) == AudioDataReturnType.ARRAY and not _NUMPY_AVAILABLE:
|
250
|
+
raise ImportError(
|
251
|
+
"The 'numpy' package is required to use the 'array' return type. "
|
252
|
+
"Please install 'numpy' or use 'bytes' as the return type."
|
253
|
+
)
|
254
|
+
|
235
255
|
if chunk_time is not None:
|
236
256
|
if chunk_time < 0.1 or chunk_time > 0.5:
|
237
257
|
raise ValueError("`chunk_time` must be between 0.1 and 0.5")
|
@@ -249,7 +269,7 @@ class CartesiaTTS:
|
|
249
269
|
transcript: str,
|
250
270
|
voice: Embedding,
|
251
271
|
model_id: str,
|
252
|
-
output_format:
|
272
|
+
output_format: AudioOutputFormat,
|
253
273
|
duration: int = None,
|
254
274
|
chunk_time: float = None,
|
255
275
|
) -> Dict[str, Any]:
|
@@ -259,6 +279,7 @@ class CartesiaTTS:
|
|
259
279
|
filtered out otherwise.
|
260
280
|
"""
|
261
281
|
body = dict(transcript=transcript, model_id=model_id, voice=voice)
|
282
|
+
output_format = output_format.value
|
262
283
|
|
263
284
|
optional_body = dict(
|
264
285
|
duration=duration,
|
@@ -279,7 +300,8 @@ class CartesiaTTS:
|
|
279
300
|
chunk_time: float = None,
|
280
301
|
stream: bool = False,
|
281
302
|
websocket: bool = True,
|
282
|
-
output_format: str = "fp32",
|
303
|
+
output_format: Union[str, AudioOutputFormat] = "fp32",
|
304
|
+
data_rtype: str = "bytes",
|
283
305
|
) -> Union[AudioOutput, Generator[AudioOutput, None, None]]:
|
284
306
|
"""Generate audio from a transcript.
|
285
307
|
|
@@ -293,6 +315,9 @@ class CartesiaTTS:
|
|
293
315
|
If True this function returns a generator. False by default.
|
294
316
|
websocket (bool, optional): Whether to use a websocket for streaming audio.
|
295
317
|
Using the websocket reduces latency by pre-poning the handshake. True by default.
|
318
|
+
data_rtype: The return type for the 'audio' key in the dictionary.
|
319
|
+
One of `'bytes' | 'array'`.
|
320
|
+
Note this field is experimental and may be deprecated in the future.
|
296
321
|
|
297
322
|
Returns:
|
298
323
|
A generator if `stream` is True, otherwise a dictionary.
|
@@ -300,13 +325,16 @@ class CartesiaTTS:
|
|
300
325
|
* "audio": The audio as a bytes buffer, or as a NumPy array when `data_rtype` is `'array'`.
|
301
326
|
* "sampling_rate": The sampling rate of the audio.
|
302
327
|
"""
|
303
|
-
self._check_inputs(transcript, duration, chunk_time)
|
328
|
+
self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
|
329
|
+
|
330
|
+
data_rtype = AudioDataReturnType(data_rtype)
|
331
|
+
output_format = AudioOutputFormat(output_format)
|
304
332
|
|
305
333
|
body = self._generate_request_body(
|
306
|
-
transcript=transcript,
|
307
|
-
voice=voice,
|
334
|
+
transcript=transcript,
|
335
|
+
voice=voice,
|
308
336
|
model_id=model_id,
|
309
|
-
duration=duration,
|
337
|
+
duration=duration,
|
310
338
|
chunk_time=chunk_time,
|
311
339
|
output_format=output_format,
|
312
340
|
)
|
@@ -316,6 +344,9 @@ class CartesiaTTS:
|
|
316
344
|
else:
|
317
345
|
generator = self._generate_http_wrapper(body)
|
318
346
|
|
347
|
+
generator = self._postprocess_audio(
|
348
|
+
generator, data_rtype=data_rtype, output_format=output_format
|
349
|
+
)
|
319
350
|
if stream:
|
320
351
|
return generator
|
321
352
|
|
@@ -326,9 +357,45 @@ class CartesiaTTS:
|
|
326
357
|
sampling_rate = chunk["sampling_rate"]
|
327
358
|
chunks.append(chunk["audio"])
|
328
359
|
|
329
|
-
|
360
|
+
if data_rtype == AudioDataReturnType.ARRAY:
|
361
|
+
cat = np.concatenate
|
362
|
+
else:
|
363
|
+
cat = b"".join
|
364
|
+
|
365
|
+
return {"audio": cat(chunks), "sampling_rate": sampling_rate}
|
366
|
+
|
367
|
+
def _postprocess_audio(
|
368
|
+
self,
|
369
|
+
generator: Generator[AudioOutput, None, None],
|
370
|
+
*,
|
371
|
+
data_rtype: AudioDataReturnType,
|
372
|
+
output_format: AudioOutputFormat,
|
373
|
+
) -> Generator[AudioOutput, None, None]:
|
374
|
+
"""Perform postprocessing on the generator outputs.
|
375
|
+
|
376
|
+
The postprocessing should be minimal (e.g. converting to array, casting dtype).
|
377
|
+
This code should not perform heavy operations like changing the sampling rate.
|
378
|
+
|
379
|
+
Args:
|
380
|
+
generator: A generator that yields audio chunks.
|
381
|
+
data_rtype: The data return type.
|
382
|
+
output_format: The output format for the audio.
|
383
|
+
|
384
|
+
Returns:
|
385
|
+
A generator that yields audio chunks.
|
386
|
+
"""
|
387
|
+
dtype = None
|
388
|
+
if data_rtype == AudioDataReturnType.ARRAY:
|
389
|
+
dtype = np.float32 if "fp32" in output_format.value else np.int16
|
390
|
+
|
391
|
+
for chunk in generator:
|
392
|
+
if dtype is not None:
|
393
|
+
chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
|
394
|
+
yield chunk
|
330
395
|
|
331
|
-
@retry_on_connection_error(
|
396
|
+
@retry_on_connection_error(
|
397
|
+
max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
|
398
|
+
)
|
332
399
|
def _generate_http_wrapper(self, body: Dict[str, Any]):
|
333
400
|
"""Need to wrap the http generator in a function for the retry decorator to work."""
|
334
401
|
try:
|
@@ -389,15 +456,6 @@ class CartesiaTTS:
|
|
389
456
|
break
|
390
457
|
|
391
458
|
yield convert_response(response, include_context_id)
|
392
|
-
|
393
|
-
if self.experimental_ws_handle_interrupts:
|
394
|
-
self.websocket.send(json.dumps({"context_id": context_id}))
|
395
|
-
except GeneratorExit:
|
396
|
-
# The exit is only called when the generator is garbage collected.
|
397
|
-
# It may not be called directly after a break statement.
|
398
|
-
# However, the generator will be automatically cancelled on the next request.
|
399
|
-
if self.experimental_ws_handle_interrupts:
|
400
|
-
self.websocket.send(json.dumps({"context_id": context_id, "action": "cancel"}))
|
401
459
|
except Exception as e:
|
402
460
|
# Close the websocket connection if an error occurs.
|
403
461
|
if self.websocket and not self._is_websocket_closed():
|
@@ -408,23 +466,6 @@ class CartesiaTTS:
|
|
408
466
|
if self.websocket and not self._is_websocket_closed():
|
409
467
|
self.websocket.close()
|
410
468
|
|
411
|
-
@retry_on_connection_error(max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger)
|
412
|
-
def transcribe(self, raw_audio: Union[bytes, str]) -> str:
|
413
|
-
raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
|
414
|
-
response = httpx.post(
|
415
|
-
f"{self._http_url()}/audio/transcriptions",
|
416
|
-
headers=headers,
|
417
|
-
files={"clip": ("input.wav", raw_audio_bytes)},
|
418
|
-
timeout=DEFAULT_TIMEOUT,
|
419
|
-
)
|
420
|
-
|
421
|
-
if not response.is_success:
|
422
|
-
raise ValueError(f"Failed to transcribe audio. Error: {response.text()}")
|
423
|
-
|
424
|
-
transcript = response.json()
|
425
|
-
return transcript["text"]
|
426
|
-
|
427
|
-
|
428
469
|
def prepare_audio_and_headers(
|
429
470
|
self, raw_audio: Union[bytes, str]
|
430
471
|
) -> Tuple[bytes, Dict[str, Any]]:
|
@@ -466,13 +507,11 @@ class CartesiaTTS:
|
|
466
507
|
|
467
508
|
|
468
509
|
class AsyncCartesiaTTS(CartesiaTTS):
|
469
|
-
def __init__(self, *, api_key: str = None
|
510
|
+
def __init__(self, *, api_key: str = None):
|
470
511
|
self._session = None
|
471
512
|
self._loop = None
|
472
|
-
super().__init__(
|
473
|
-
|
474
|
-
)
|
475
|
-
|
513
|
+
super().__init__(api_key=api_key)
|
514
|
+
|
476
515
|
async def _get_session(self):
|
477
516
|
current_loop = asyncio.get_event_loop()
|
478
517
|
if self._loop is not current_loop:
|
@@ -481,29 +520,25 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
481
520
|
if self._session is None or self._session.closed:
|
482
521
|
timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
|
483
522
|
connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
|
484
|
-
self._session = aiohttp.ClientSession(
|
485
|
-
timeout=timeout, connector=connector
|
486
|
-
)
|
523
|
+
self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
|
487
524
|
self._loop = current_loop
|
488
525
|
return self._session
|
489
|
-
|
526
|
+
|
490
527
|
async def refresh_websocket(self):
|
491
528
|
"""Refresh the websocket connection."""
|
492
529
|
if self.websocket is None or self._is_websocket_closed():
|
493
530
|
route = "audio/websocket"
|
494
|
-
if self.experimental_ws_handle_interrupts:
|
495
|
-
route = f"experimental/{route}"
|
496
531
|
session = await self._get_session()
|
497
532
|
self.websocket = await session.ws_connect(
|
498
533
|
f"{self._ws_url()}/{route}?api_key={self.api_key}"
|
499
534
|
)
|
500
|
-
|
535
|
+
|
501
536
|
def _is_websocket_closed(self):
|
502
537
|
return self.websocket.closed
|
503
538
|
|
504
539
|
async def close(self):
|
505
540
|
"""This method closes the websocket and the session.
|
506
|
-
|
541
|
+
|
507
542
|
It is *strongly* recommended to call this method when you are done using the client.
|
508
543
|
"""
|
509
544
|
if self.websocket is not None and not self._is_websocket_closed():
|
@@ -521,35 +556,22 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
521
556
|
chunk_time: float = None,
|
522
557
|
stream: bool = False,
|
523
558
|
websocket: bool = True,
|
524
|
-
output_format: str = "fp32"
|
559
|
+
output_format: Union[str, AudioOutputFormat] = "fp32",
|
560
|
+
data_rtype: Union[str, AudioDataReturnType] = "bytes",
|
525
561
|
) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
|
526
562
|
"""Asynchronously generate audio from a transcript.
|
527
|
-
NOTE: This overrides the non-asynchronous generate method from the base class.
|
528
|
-
|
529
|
-
Args:
|
530
|
-
transcript (str): The text to generate audio for.
|
531
|
-
voice (Embedding (List[float])): The voice to use for generating audio.
|
532
|
-
duration (int, optional): The maximum duration of the audio in seconds.
|
533
|
-
chunk_time (float, optional): How long each audio segment should be in seconds.
|
534
|
-
This should not need to be adjusted.
|
535
|
-
stream (bool, optional): Whether to stream the audio or not.
|
536
|
-
If True this function returns a generator. False by default.
|
537
|
-
websocket (bool, optional): Whether to use a websocket for streaming audio.
|
538
|
-
Using the websocket reduces latency by pre-poning the handshake. True by default.
|
539
563
|
|
540
|
-
|
541
|
-
A generator if `stream` is True, otherwise a dictionary.
|
542
|
-
Dictionary from both generator and non-generator return types have the following keys:
|
543
|
-
* "audio": The audio as a bytes buffer.
|
544
|
-
* "sampling_rate": The sampling rate of the audio.
|
564
|
+
For more information on the arguments, see the synchronous :meth:`CartesiaTTS.generate`.
|
545
565
|
"""
|
546
|
-
self._check_inputs(transcript, duration, chunk_time)
|
566
|
+
self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
|
567
|
+
data_rtype = AudioDataReturnType(data_rtype)
|
568
|
+
output_format = AudioOutputFormat(output_format)
|
547
569
|
|
548
570
|
body = self._generate_request_body(
|
549
|
-
transcript=transcript,
|
571
|
+
transcript=transcript,
|
550
572
|
voice=voice,
|
551
573
|
model_id=model_id,
|
552
|
-
duration=duration,
|
574
|
+
duration=duration,
|
553
575
|
chunk_time=chunk_time,
|
554
576
|
output_format=output_format,
|
555
577
|
)
|
@@ -558,7 +580,9 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
558
580
|
generator = self._generate_ws(body)
|
559
581
|
else:
|
560
582
|
generator = self._generate_http_wrapper(body)
|
561
|
-
|
583
|
+
generator = self._postprocess_audio(
|
584
|
+
generator, data_rtype=data_rtype, output_format=output_format
|
585
|
+
)
|
562
586
|
if stream:
|
563
587
|
return generator
|
564
588
|
|
@@ -569,14 +593,38 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
569
593
|
sampling_rate = chunk["sampling_rate"]
|
570
594
|
chunks.append(chunk["audio"])
|
571
595
|
|
572
|
-
|
596
|
+
if data_rtype == AudioDataReturnType.ARRAY:
|
597
|
+
cat = np.concatenate
|
598
|
+
else:
|
599
|
+
cat = b"".join
|
600
|
+
|
601
|
+
return {"audio": cat(chunks), "sampling_rate": sampling_rate}
|
602
|
+
|
603
|
+
async def _postprocess_audio(
|
604
|
+
self,
|
605
|
+
generator: AsyncGenerator[AudioOutput, None],
|
606
|
+
*,
|
607
|
+
data_rtype: AudioDataReturnType,
|
608
|
+
output_format: AudioOutputFormat,
|
609
|
+
) -> AsyncGenerator[AudioOutput, None]:
|
610
|
+
"""See :meth:`CartesiaTTS._postprocess_audio`."""
|
611
|
+
dtype = None
|
612
|
+
if data_rtype == AudioDataReturnType.ARRAY:
|
613
|
+
dtype = np.float32 if "fp32" in output_format.value else np.int16
|
614
|
+
|
615
|
+
async for chunk in generator:
|
616
|
+
if dtype is not None:
|
617
|
+
chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
|
618
|
+
yield chunk
|
573
619
|
|
574
|
-
@retry_on_connection_error_async(
|
620
|
+
@retry_on_connection_error_async(
|
621
|
+
max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
|
622
|
+
)
|
575
623
|
async def _generate_http_wrapper(self, body: Dict[str, Any]):
|
576
624
|
"""Need to wrap the http generator in a function for the retry decorator to work."""
|
577
625
|
try:
|
578
|
-
|
579
|
-
|
626
|
+
async for chunk in self._generate_http(body):
|
627
|
+
yield chunk
|
580
628
|
except Exception as e:
|
581
629
|
logger.error(f"Failed to generate audio. {e}")
|
582
630
|
raise e
|
@@ -605,10 +653,6 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
605
653
|
|
606
654
|
async def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
|
607
655
|
include_context_id = bool(context_id)
|
608
|
-
route = "audio/websocket"
|
609
|
-
if self.experimental_ws_handle_interrupts:
|
610
|
-
route = f"experimental/{route}"
|
611
|
-
|
612
656
|
if not self.websocket or self._is_websocket_closed():
|
613
657
|
await self.refresh_websocket()
|
614
658
|
|
@@ -624,15 +668,6 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
624
668
|
break
|
625
669
|
|
626
670
|
yield convert_response(response, include_context_id)
|
627
|
-
|
628
|
-
if self.experimental_ws_handle_interrupts:
|
629
|
-
await ws.send_json({"context_id": context_id})
|
630
|
-
except GeneratorExit:
|
631
|
-
# The exit is only called when the generator is garbage collected.
|
632
|
-
# It may not be called directly after a break statement.
|
633
|
-
# However, the generator will be automatically cancelled on the next request.
|
634
|
-
if self.experimental_ws_handle_interrupts:
|
635
|
-
await ws.send_json({"context_id": context_id, "action": "cancel"})
|
636
671
|
except Exception as e:
|
637
672
|
if self.websocket and not self._is_websocket_closed():
|
638
673
|
await self.websocket.close()
|
@@ -642,21 +677,6 @@ class AsyncCartesiaTTS(CartesiaTTS):
|
|
642
677
|
if self.websocket and not self._is_websocket_closed():
|
643
678
|
await self.websocket.close()
|
644
679
|
|
645
|
-
async def transcribe(self, raw_audio: Union[bytes, str]) -> str:
|
646
|
-
raw_audio_bytes, headers = self.prepare_audio_and_headers(raw_audio)
|
647
|
-
data = aiohttp.FormData()
|
648
|
-
data.add_field("clip", raw_audio_bytes, filename="input.wav", content_type="audio/wav")
|
649
|
-
session = await self._get_session()
|
650
|
-
|
651
|
-
async with session.post(
|
652
|
-
f"{self._http_url()}/audio/transcriptions", headers=headers, data=data
|
653
|
-
) as response:
|
654
|
-
if not response.ok:
|
655
|
-
raise ValueError(f"Failed to transcribe audio. Error: {await response.text()}")
|
656
|
-
|
657
|
-
transcript = await response.json()
|
658
|
-
return transcript["text"]
|
659
|
-
|
660
680
|
def __del__(self):
|
661
681
|
try:
|
662
682
|
loop = asyncio.get_running_loop()
|
@@ -7,6 +7,7 @@ from http.client import RemoteDisconnected
|
|
7
7
|
from httpx import TimeoutException
|
8
8
|
from requests.exceptions import ConnectionError
|
9
9
|
|
10
|
+
|
10
11
|
def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
|
11
12
|
"""Retry a function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
|
12
13
|
|
@@ -15,6 +16,7 @@ def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
|
|
15
16
|
backoff_factor (int): The factor to increase the delay between retries.
|
16
17
|
logger (logging.Logger): The logger to use for logging.
|
17
18
|
"""
|
19
|
+
|
18
20
|
def decorator(func):
|
19
21
|
@wraps(func)
|
20
22
|
def wrapper(*args, **kwargs):
|
@@ -22,18 +24,28 @@ def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
|
|
22
24
|
while retry_count < max_retries:
|
23
25
|
try:
|
24
26
|
return func(*args, **kwargs)
|
25
|
-
except (
|
27
|
+
except (
|
28
|
+
ConnectionError,
|
29
|
+
RemoteDisconnected,
|
30
|
+
ServerDisconnectedError,
|
31
|
+
TimeoutException,
|
32
|
+
) as e:
|
26
33
|
logger.info(f"Retrying after exception: {e}")
|
27
34
|
retry_count += 1
|
28
35
|
if retry_count < max_retries:
|
29
36
|
delay = backoff_factor * (2 ** (retry_count - 1))
|
30
|
-
logger.warn(
|
37
|
+
logger.warn(
|
38
|
+
f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
|
39
|
+
)
|
31
40
|
time.sleep(delay)
|
32
41
|
else:
|
33
42
|
raise Exception(f"Exception occurred after {max_retries} tries.") from e
|
43
|
+
|
34
44
|
return wrapper
|
45
|
+
|
35
46
|
return decorator
|
36
47
|
|
48
|
+
|
37
49
|
def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None):
|
38
50
|
"""Retry an asynchronous function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
|
39
51
|
|
@@ -42,6 +54,7 @@ def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None
|
|
42
54
|
backoff_factor (int): The factor to increase the delay between retries.
|
43
55
|
logger (logging.Logger): The logger to use for logging.
|
44
56
|
"""
|
57
|
+
|
45
58
|
def decorator(func):
|
46
59
|
@wraps(func)
|
47
60
|
async def wrapper(*args, **kwargs):
|
@@ -52,14 +65,23 @@ def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None
|
|
52
65
|
yield chunk
|
53
66
|
# If the function completes without raising an exception return
|
54
67
|
return
|
55
|
-
except (
|
68
|
+
except (
|
69
|
+
ConnectionError,
|
70
|
+
RemoteDisconnected,
|
71
|
+
ServerDisconnectedError,
|
72
|
+
TimeoutException,
|
73
|
+
) as e:
|
56
74
|
logger.info(f"Retrying after exception: {e}")
|
57
75
|
retry_count += 1
|
58
76
|
if retry_count < max_retries:
|
59
77
|
delay = backoff_factor * (2 ** (retry_count - 1))
|
60
|
-
logger.warn(
|
78
|
+
logger.warn(
|
79
|
+
f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
|
80
|
+
)
|
61
81
|
await asyncio.sleep(delay)
|
62
82
|
else:
|
63
83
|
raise Exception(f"Exception occurred after {max_retries} tries.") from e
|
84
|
+
|
64
85
|
return wrapper
|
65
|
-
|
86
|
+
|
87
|
+
return decorator
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "0.1.0"
|
@@ -8,15 +8,16 @@ general correctness.
|
|
8
8
|
import logging
|
9
9
|
import os
|
10
10
|
import sys
|
11
|
-
from cartesia.tts import DEFAULT_MODEL_ID, AsyncCartesiaTTS, CartesiaTTS
|
12
|
-
from
|
13
|
-
|
11
|
+
from cartesia.tts import DEFAULT_MODEL_ID, AsyncCartesiaTTS, CartesiaTTS
|
12
|
+
from cartesia._types import AudioDataReturnType, AudioOutputFormat, VoiceMetadata
|
13
|
+
from typing import AsyncGenerator, Dict, Generator, List, Optional, Union
|
14
|
+
import numpy as np
|
14
15
|
import pytest
|
15
16
|
|
16
17
|
THISDIR = os.path.dirname(__file__)
|
17
18
|
sys.path.insert(0, os.path.dirname(THISDIR))
|
18
19
|
|
19
|
-
SAMPLE_VOICE = "
|
20
|
+
SAMPLE_VOICE = "Newsman"
|
20
21
|
|
21
22
|
logger = logging.getLogger(__name__)
|
22
23
|
|
@@ -166,6 +167,78 @@ def test_generate_context_manager_with_err():
|
|
166
167
|
assert websocket.socket.fileno() == -1 # check socket is now closed
|
167
168
|
|
168
169
|
|
170
|
+
@pytest.mark.parametrize("output_format", [_fmt for _fmt in AudioOutputFormat])
|
171
|
+
@pytest.mark.parametrize("as_str", [True, False])
|
172
|
+
@pytest.mark.parametrize("stream", [True, False])
|
173
|
+
@pytest.mark.parametrize("websocket", [True, False])
|
174
|
+
def test_generate_with_output_format(
|
175
|
+
resources: _Resources,
|
176
|
+
output_format: AudioOutputFormat,
|
177
|
+
as_str: bool,
|
178
|
+
stream: bool,
|
179
|
+
websocket: bool,
|
180
|
+
):
|
181
|
+
value = output_format.value
|
182
|
+
|
183
|
+
client = resources.client
|
184
|
+
voices = resources.voices
|
185
|
+
embedding = voices[SAMPLE_VOICE]["embedding"]
|
186
|
+
transcript = "Hello, world!"
|
187
|
+
|
188
|
+
split = value.split("_")
|
189
|
+
expected_sampling_rate = int(split[1]) if len(split) == 2 else 44_100
|
190
|
+
|
191
|
+
# Easy way to get around iterating over stream=True / False.
|
192
|
+
output_generate = client.generate(
|
193
|
+
transcript=transcript,
|
194
|
+
voice=embedding,
|
195
|
+
websocket=websocket,
|
196
|
+
stream=stream,
|
197
|
+
output_format=output_format.value if as_str else output_format,
|
198
|
+
)
|
199
|
+
if not stream:
|
200
|
+
output_generate = [output_generate]
|
201
|
+
|
202
|
+
for out in output_generate:
|
203
|
+
assert isinstance(out["audio"], bytes)
|
204
|
+
assert out["sampling_rate"] == expected_sampling_rate
|
205
|
+
|
206
|
+
|
207
|
+
@pytest.mark.parametrize("data_rtype", [_fmt for _fmt in AudioDataReturnType])
|
208
|
+
@pytest.mark.parametrize("as_str", [True, False])
|
209
|
+
@pytest.mark.parametrize("stream", [True, False])
|
210
|
+
@pytest.mark.parametrize("websocket", [True, False])
|
211
|
+
def test_generate_with_data_rtype(
|
212
|
+
resources: _Resources,
|
213
|
+
data_rtype: AudioDataReturnType,
|
214
|
+
as_str: bool,
|
215
|
+
stream: bool,
|
216
|
+
websocket: bool,
|
217
|
+
):
|
218
|
+
client = resources.client
|
219
|
+
voices = resources.voices
|
220
|
+
embedding = voices[SAMPLE_VOICE]["embedding"]
|
221
|
+
transcript = "Hello, world!"
|
222
|
+
|
223
|
+
# Easy way to get around iterating over stream=True / False.
|
224
|
+
output_generate = client.generate(
|
225
|
+
transcript=transcript,
|
226
|
+
voice=embedding,
|
227
|
+
websocket=websocket,
|
228
|
+
stream=stream,
|
229
|
+
data_rtype=data_rtype.value if as_str else data_rtype,
|
230
|
+
)
|
231
|
+
if not stream:
|
232
|
+
output_generate = [output_generate]
|
233
|
+
|
234
|
+
for out in output_generate:
|
235
|
+
if data_rtype == AudioDataReturnType.BYTES:
|
236
|
+
assert isinstance(out["audio"], bytes)
|
237
|
+
elif data_rtype == AudioDataReturnType.ARRAY:
|
238
|
+
assert isinstance(out["audio"], np.ndarray)
|
239
|
+
assert out["audio"].dtype == np.float32
|
240
|
+
|
241
|
+
|
169
242
|
@pytest.mark.parametrize("websocket", [True, False])
|
170
243
|
@pytest.mark.asyncio
|
171
244
|
async def test_async_generate(resources: _Resources, websocket: bool):
|
@@ -199,7 +272,9 @@ async def test_async_generate_stream(resources: _Resources, websocket: bool):
|
|
199
272
|
async_client = create_async_client()
|
200
273
|
|
201
274
|
try:
|
202
|
-
generator = await async_client.generate(
|
275
|
+
generator = await async_client.generate(
|
276
|
+
transcript=transcript, voice=embedding, websocket=websocket, stream=True
|
277
|
+
)
|
203
278
|
assert isinstance(generator, AsyncGenerator)
|
204
279
|
async for output in generator:
|
205
280
|
assert output.keys() == {"audio", "sampling_rate"}
|
@@ -248,18 +323,111 @@ async def test_generate_async_context_manager_with_err():
|
|
248
323
|
assert websocket.closed # check websocket is now closed
|
249
324
|
|
250
325
|
|
326
|
+
@pytest.mark.parametrize("output_format", list(AudioOutputFormat))
@pytest.mark.parametrize("as_str", [True, False])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_generate_async_with_output_format(
    resources: _Resources, output_format: AudioOutputFormat, as_str: bool, stream: bool
):
    """Check async `generate` (non-websocket) for every `AudioOutputFormat` member.

    The format is passed either as the enum member or as its string value
    (`as_str`), and the call is made both streaming and non-streaming. Every
    returned output must carry `bytes` audio and the sampling rate implied by
    the format name.
    """
    logger.info(
        f"Testing async generate stream with output_format={output_format}, as_str={as_str}, stream={stream}"
    )
    embedding = resources.voices[SAMPLE_VOICE]["embedding"]

    # A value shaped like "<encoding>_<rate>" names its rate explicitly;
    # a bare encoding name is expected to come back at 44.1 kHz.
    name_parts = output_format.value.split("_")
    if len(name_parts) == 2:
        expected_sampling_rate = int(name_parts[1])
    else:
        expected_sampling_rate = 44_100

    def _validate(output):
        assert isinstance(output["audio"], bytes)
        assert output["sampling_rate"] == expected_sampling_rate

    requested_format = output_format.value if as_str else output_format
    async_client = create_async_client()

    try:
        result = await async_client.generate(
            transcript="Hello, world!",
            voice=embedding,
            websocket=False,
            stream=stream,
            output_format=requested_format,
        )
        if not stream:
            _validate(result)
        else:
            assert isinstance(result, AsyncGenerator)
            async for chunk in result:
                _validate(chunk)
    finally:
        # Always release the client, even when an assertion above fails.
        await async_client.close()
|
367
|
+
|
368
|
+
|
369
|
+
@pytest.mark.parametrize("data_rtype", list(AudioDataReturnType))
@pytest.mark.parametrize("as_str", [True, False])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("websocket", [True, False])
@pytest.mark.asyncio
async def test_generate_async_with_data_rtype(
    resources: _Resources,
    data_rtype: AudioDataReturnType,
    as_str: bool,
    stream: bool,
    websocket: bool,
):
    """Check that async `generate` returns audio in the requested container type.

    Runs over every `AudioDataReturnType` member, passed either as the enum
    member or as its string value (`as_str`), for streaming and non-streaming
    calls, with and without the websocket transport.

    Expectations per return type:
      * BYTES -> `output["audio"]` is `bytes`
      * ARRAY -> `output["audio"]` is a `np.ndarray` of dtype `float32`
    """
    embedding = resources.voices[SAMPLE_VOICE]["embedding"]
    transcript = "Hello, world!"

    async_client = create_async_client()

    def _validate(output):
        if data_rtype == AudioDataReturnType.BYTES:
            assert isinstance(output["audio"], bytes)
        elif data_rtype == AudioDataReturnType.ARRAY:
            assert isinstance(output["audio"], np.ndarray)
            assert output["audio"].dtype == np.float32
        else:
            # A newly added enum member must get an explicit expectation
            # instead of passing this test without any check.
            pytest.fail(f"No validation defined for data_rtype={data_rtype}")

    try:
        output_generate = await async_client.generate(
            transcript=transcript,
            voice=embedding,
            # Forward the parametrized flag so the websocket=True cases
            # actually exercise the websocket transport.
            websocket=websocket,
            stream=stream,
            data_rtype=data_rtype.value if as_str else data_rtype,
        )
        if stream:
            assert isinstance(output_generate, AsyncGenerator)
            async for output in output_generate:
                _validate(output)
        else:
            _validate(output_generate)
    finally:
        # Close the client (and its websocket, if one was opened).
        await async_client.close()
|
405
|
+
|
406
|
+
|
251
407
|
@pytest.mark.parametrize("chunk_time", [0.05, 0.6])
def test_check_inputs_invalid_chunk_time(client: CartesiaTTS, chunk_time):
    """Expect `_check_inputs` to raise ValueError for an out-of-range `chunk_time`."""
    logger.info(f"Testing invalid chunk_time: {chunk_time}")
    format_kwargs = {
        "output_format": AudioOutputFormat.FP32,
        "data_rtype": AudioDataReturnType.BYTES,
    }
    expected_message = "`chunk_time` must be between 0.1 and 0.5"
    with pytest.raises(ValueError, match=expected_message):
        client._check_inputs("Test", None, chunk_time, **format_kwargs)
|
256
418
|
|
257
419
|
|
258
420
|
@pytest.mark.parametrize("chunk_time", [0.1, 0.3, 0.5])
def test_check_inputs_valid_chunk_time(client: CartesiaTTS, chunk_time):
    """Expect `_check_inputs` to accept `chunk_time` values of 0.1, 0.3 and 0.5.

    The test fails explicitly if `_check_inputs` raises ValueError for any of
    these values.
    """
    # f-string so the log shows the parametrized value, not the placeholder text.
    logger.info(f"Testing valid chunk_time: {chunk_time}")
    try:
        client._check_inputs(
            "Test",
            None,
            chunk_time,
            output_format=AudioOutputFormat.FP32,
            data_rtype=AudioDataReturnType.BYTES,
        )
    except ValueError:
        pytest.fail("Unexpected ValueError raised")
|
265
433
|
|
@@ -267,14 +435,26 @@ def test_check_inputs_valid_chunk_time(client, chunk_time):
|
|
267
435
|
def test_check_inputs_duration_less_than_chunk_time(client: CartesiaTTS):
    """Expect `_check_inputs` to raise ValueError when duration (0.2) is below chunk_time (0.3)."""
    logger.info("Testing duration less than chunk_time")
    format_kwargs = {
        "output_format": AudioOutputFormat.FP32,
        "data_rtype": AudioDataReturnType.BYTES,
    }
    expected_message = "`duration` must be greater than chunk_time"
    with pytest.raises(ValueError, match=expected_message):
        client._check_inputs("Test", 0.2, 0.3, **format_kwargs)
|
271
445
|
|
272
446
|
|
273
447
|
@pytest.mark.parametrize("duration,chunk_time", [(0.5, 0.2), (1.0, 0.5), (2.0, 0.1)])
def test_check_inputs_valid_duration_and_chunk_time(client: CartesiaTTS, duration, chunk_time):
    """Expect `_check_inputs` to accept each parametrized (duration, chunk_time) pair."""
    logger.info(f"Testing valid duration: {duration} and chunk_time: {chunk_time}")
    format_kwargs = {
        "output_format": AudioOutputFormat.FP32,
        "data_rtype": AudioDataReturnType.BYTES,
    }
    try:
        client._check_inputs("Test", duration, chunk_time, **format_kwargs)
    except ValueError:
        # Turn the unexpected rejection into an explicit test failure.
        pytest.fail("Unexpected ValueError raised")
|
280
460
|
|
@@ -282,13 +462,99 @@ def test_check_inputs_valid_duration_and_chunk_time(client: CartesiaTTS, duratio
|
|
282
462
|
def test_check_inputs_empty_transcript(client: CartesiaTTS):
    """Expect `_check_inputs` to raise ValueError for an empty transcript string."""
    logger.info("Testing empty transcript")
    format_kwargs = {
        "output_format": AudioOutputFormat.FP32,
        "data_rtype": AudioDataReturnType.BYTES,
    }
    expected_message = "`transcript` must be non empty"
    with pytest.raises(ValueError, match=expected_message):
        client._check_inputs("", None, None, **format_kwargs)
|
286
472
|
|
287
473
|
|
288
474
|
@pytest.mark.parametrize("transcript", ["Hello", "Test transcript", "Lorem ipsum dolor sit amet"])
def test_check_inputs_valid_transcript(client: CartesiaTTS, transcript):
    """Expect `_check_inputs` to accept each parametrized non-empty transcript."""
    logger.info(f"Testing valid transcript: {transcript}")
    format_kwargs = {
        "output_format": AudioOutputFormat.FP32,
        "data_rtype": AudioDataReturnType.BYTES,
    }
    try:
        client._check_inputs(transcript, None, None, **format_kwargs)
    except ValueError:
        # Turn the unexpected rejection into an explicit test failure.
        pytest.fail("Unexpected ValueError raised")
|
487
|
+
|
488
|
+
|
489
|
+
@pytest.mark.parametrize(
    "output_format,error",
    [
        # Valid output formats.
        ("fp32", None),
        ("pcm", None),
        ("fp32_16000", None),
        ("fp32_22050", None),
        ("fp32_44100", None),
        ("pcm_16000", None),
        ("pcm_22050", None),
        ("pcm_44100", None),
        # Invalid output formats.
        ("invalid", ValueError),
        ("pcm_1234", ValueError),  # cannot specify arbitrary sampling rate
        ("fp32_1234", ValueError),  # cannot specify arbitrary sampling rate
        ("fp16_44100", ValueError),  # fp16 not supported.
    ],
)
def test_check_inputs_output_format(
    client: CartesiaTTS, output_format: Union[str, AudioOutputFormat], error: Optional[type]
):
    """Check how `_check_inputs` treats each `output_format` string.

    Args:
        client: Client fixture whose `_check_inputs` is exercised.
        output_format: Output format under test.
        error: Exception class that `_check_inputs` is expected to raise, or
            None when the format must be accepted.
    """

    def _check():
        # Single definition of the call shared by the accept and reject paths.
        client._check_inputs(
            "Test",
            None,
            None,
            output_format=output_format,
            data_rtype=AudioDataReturnType.BYTES,
        )

    if error is None:
        _check()
    else:
        with pytest.raises(error):
            _check()
|
528
|
+
|
529
|
+
|
530
|
+
@pytest.mark.parametrize(
    "data_rtype,error",
    [
        # Valid data return types.
        ("bytes", None),
        ("array", None),
        # Invalid data return types.
        ("invalid", ValueError),
        ("tensor", ValueError),
    ],
)
def test_check_inputs_data_rtype(
    client: CartesiaTTS, data_rtype: Union[str, AudioDataReturnType], error: Optional[type]
):
    """Check how `_check_inputs` treats each `data_rtype` string.

    Args:
        client: Client fixture whose `_check_inputs` is exercised.
        data_rtype: Data return type under test.
        error: Exception class that `_check_inputs` is expected to raise, or
            None when the return type must be accepted.
    """

    def _check():
        # Single definition of the call shared by the accept and reject paths.
        client._check_inputs(
            "Test",
            None,
            None,
            output_format=AudioOutputFormat.FP32,
            data_rtype=data_rtype,
        )

    if error is None:
        _check()
    else:
        with pytest.raises(error):
            _check()
|
@@ -1 +0,0 @@
|
|
1
|
-
__version__ = "0.0.6"
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|