cartesia 0.1.1__py2.py3-none-any.whl → 1.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/__init__.py +2 -2
- cartesia/_types.py +33 -39
- cartesia/client.py +874 -0
- cartesia/version.py +1 -1
- cartesia-1.0.0.dist-info/METADATA +364 -0
- cartesia-1.0.0.dist-info/RECORD +9 -0
- cartesia/tts.py +0 -702
- cartesia-0.1.1.dist-info/METADATA +0 -189
- cartesia-0.1.1.dist-info/RECORD +0 -9
- {cartesia-0.1.1.dist-info → cartesia-1.0.0.dist-info}/WHEEL +0 -0
- {cartesia-0.1.1.dist-info → cartesia-1.0.0.dist-info}/top_level.txt +0 -0
cartesia/tts.py
DELETED
@@ -1,702 +0,0 @@
|
|
1
|
-
# Standard-library imports.
import asyncio
import base64
import json
import os
import uuid
from types import TracebackType
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

# HTTP/websocket clients: `httpx`/`requests` for sync calls, `aiohttp` for
# async calls, `websockets` for the synchronous websocket stream.
import aiohttp
import httpx
import logging
import requests
from websockets.sync.client import connect

from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
from cartesia._types import (
    AudioDataReturnType,
    AudioOutputFormat,
    AudioOutput,
    Embedding,
    VoiceMetadata,
)

# numpy is optional: it is only required when a caller asks for the
# "array" data return type.
try:
    import numpy as np

    _NUMPY_AVAILABLE = True
except ImportError:
    _NUMPY_AVAILABLE = False


# Defaults. Base URL and API version can be overridden via the
# CARTESIA_BASE_URL / CARTESIA_API_VERSION environment variables (see
# CartesiaTTS.__init__).
DEFAULT_MODEL_ID = ""
DEFAULT_BASE_URL = "api.cartesia.ai"
DEFAULT_API_VERSION = "v0"
DEFAULT_TIMEOUT = 30  # seconds
DEFAULT_NUM_CONNECTIONS = 10  # connections per client

# Retry policy for transient connection errors.
BACKOFF_FACTOR = 1
MAX_RETRIES = 3

logger = logging.getLogger(__name__)
|
52
|
-
|
53
|
-
|
54
|
-
def update_buffer(buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
    """Append a raw SSE chunk to ``buffer`` and extract complete JSON messages.

    Args:
        buffer: Leftover (possibly partial) JSON text from previous chunks.
        chunk_bytes: Newly received UTF-8 encoded bytes.

    Returns:
        A tuple ``(remaining_buffer, outputs)`` where each output dict holds
        the base64-decoded ``audio`` bytes and the server-reported
        ``sampling_rate``.
    """
    buffer += chunk_bytes.decode("utf-8")
    outputs = []
    decoder = json.JSONDecoder()
    while True:
        start_index = buffer.find("{")
        if start_index == -1:
            break
        try:
            # raw_decode parses one complete JSON value starting at
            # start_index, handling nested objects correctly. The previous
            # scan for the first "}" truncated any payload containing a
            # nested object, which made json.loads fail on every retry and
            # stalled the stream without ever yielding.
            chunk_json, end_index = decoder.raw_decode(buffer, start_index)
        except json.JSONDecodeError:
            # Incomplete message; keep the buffer and wait for more bytes.
            break
        audio = base64.b64decode(chunk_json["data"])
        outputs.append({"audio": audio, "sampling_rate": chunk_json["sampling_rate"]})
        buffer = buffer[end_index:]
    return buffer, outputs
|
69
|
-
|
70
|
-
|
71
|
-
def convert_response(response: Dict[str, Any], include_context_id: bool) -> Dict[str, Any]:
    """Convert a decoded websocket message into an audio-output dict.

    Args:
        response: Decoded JSON message containing base64 ``data`` and
            ``sampling_rate`` (and ``context_id`` when requested).
        include_context_id: If True, copy the server's ``context_id`` into
            the returned dict.

    Returns:
        Dict with ``audio`` bytes, ``sampling_rate``, and optionally
        ``context_id``.
    """
    # Fix: the annotation previously used the builtin `any` instead of
    # `typing.Any`.
    audio = base64.b64decode(response["data"])

    optional_kwargs = {}
    if include_context_id:
        optional_kwargs["context_id"] = response["context_id"]

    return {
        "audio": audio,
        "sampling_rate": response["sampling_rate"],
        **optional_kwargs,
    }
|
83
|
-
|
84
|
-
|
85
|
-
class CartesiaTTS:
    """The client for Cartesia's text-to-speech library.

    This client contains methods to interact with the Cartesia text-to-speech API.
    The client can be used to retrieve available voices, compute new voice embeddings,
    and generate speech from text.

    The client also supports generating audio using a websocket for lower latency.

    Examples:
        >>> client = CartesiaTTS()

        # Load available voices and their metadata (excluding the embeddings).
        # Embeddings are fetched with `get_voice_embedding`. This avoids preloading
        # all of the embeddings, which can be expensive if there are a lot of voices.
        >>> voices = client.get_voices()
        >>> embedding = client.get_voice_embedding(voice_id=voices["Milo"]["id"])
        >>> audio = client.generate(transcript="Hello world!", voice=embedding)

        # Preload all available voices and their embeddings if you plan on reusing
        # all of the embeddings often.
        >>> voices = client.get_voices(skip_embeddings=False)
        >>> embedding = voices["Milo"]["embedding"]
        >>> audio = client.generate(transcript="Hello world!", voice=embedding)

        # Generate audio stream
        >>> for audio_chunk in client.generate(transcript="Hello world!", voice=embedding, stream=True):
        ...     audio, sr = audio_chunk["audio"], audio_chunk["sampling_rate"]
    """

    def __init__(self, *, api_key: str = None):
        """Args:
            api_key: The API key to use for authorization.
                If not specified, the API key will be read from the environment variable
                `CARTESIA_API_KEY`.
        """
        # Base URL and API version are overridable via environment variables.
        self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
        self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
        self.api_version = os.environ.get("CARTESIA_API_VERSION", DEFAULT_API_VERSION)
        # X-API-Key header authenticates all HTTP requests.
        self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
        # Lazily-created synchronous websocket (see refresh_websocket).
        self.websocket = None
|
126
|
-
|
127
|
-
def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
    """Fetch the available voices as a mapping from voice name to metadata.

    Args:
        skip_embeddings: When True (the default) the server is asked for only
            id/name/description, avoiding the cost of shipping every
            embedding. Fetch individual embeddings later with
            ``get_voice_embedding``.

    Returns:
        Dict mapping voice name -> voice metadata. If two voices share a
        name, which entry wins is undefined; prefer looking up ``voice_id``
        in the web client when names may collide.

    Raises:
        ValueError: If the HTTP request fails.
    """
    query_params = None
    if skip_embeddings:
        query_params = {"select": "id, name, description"}

    response = httpx.get(
        f"{self._http_url()}/voices",
        headers=self.headers,
        params=query_params,
        timeout=DEFAULT_TIMEOUT,
    )
    if not response.is_success:
        raise ValueError(f"Failed to get voices. Error: {response.text}")

    voices = response.json()
    # Embeddings may arrive JSON-encoded as strings; decode them in place.
    for entry in voices:
        embedding = entry.get("embedding")
        if isinstance(embedding, str):
            entry["embedding"] = json.loads(embedding)
    return {entry["name"]: entry for entry in voices}
|
173
|
-
|
174
|
-
@retry_on_connection_error(
    max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
)
def get_voice_embedding(
    self, *, voice_id: str = None, filepath: str = None, link: str = None
) -> Embedding:
    """Get a voice embedding from voice_id, a filepath or YouTube url.

    Args:
        voice_id: The voice id.
        filepath: Path to audio file from which to get the audio.
        link: The url to get the audio from. Currently only supports youtube shared urls.

    Returns:
        The voice embedding (a list of floats).

    Raises:
        ValueError: If more than one of `voice_id`, `filepath` or `link` is
            specified, or if the server request fails.
    """
    if sum(bool(x) for x in (voice_id, filepath, link)) != 1:
        # Fix: the message previously referred to a nonexistent `url` parameter.
        raise ValueError("Exactly one of `voice_id`, `filepath` or `link` should be specified.")

    if voice_id:
        url = f"{self._http_url()}/voices/embedding/{voice_id}"
        response = httpx.get(url, headers=self.headers, timeout=DEFAULT_TIMEOUT)
    elif filepath:
        url = f"{self._http_url()}/voices/clone/clip"
        headers = self.headers.copy()
        # The default content type of JSON is incorrect for file uploads
        headers.pop("Content-Type")
        # Fix: close the clip file deterministically instead of leaking the handle.
        with open(filepath, "rb") as clip:
            files = {"clip": clip}
            response = httpx.post(url, headers=headers, files=files, timeout=DEFAULT_TIMEOUT)
    else:
        # `link` is the only remaining possibility after the arity check above,
        # so `response` is always bound.
        url = f"{self._http_url()}/voices/clone/url"
        params = {"link": link}
        response = httpx.post(url, headers=self.headers, params=params, timeout=DEFAULT_TIMEOUT)

    if not response.is_success:
        raise ValueError(
            f"Failed to clone voice. Status Code: {response.status_code}\n"
            f"Error: {response.text}"
        )

    # Handle successful response; the embedding may arrive JSON-encoded.
    out = response.json()
    embedding = out["embedding"]
    if isinstance(embedding, str):
        embedding = json.loads(embedding)
    return embedding
|
224
|
-
|
225
|
-
def refresh_websocket(self):
    """Refresh the websocket connection.

    Note:
        The connection is synchronous.
    """
    # Only (re)connect when there is no live connection.
    if self.websocket is None or self._is_websocket_closed():
        route = "audio/websocket"
        # NOTE(review): the API key is sent as a query parameter for the
        # websocket handshake, unlike the X-API-Key header used for HTTP.
        self.websocket = connect(f"{self._ws_url()}/{route}?api_key={self.api_key}")
|
234
|
-
|
235
|
-
def _is_websocket_closed(self):
    # A closed socket reports fileno() == -1 (websockets sync client).
    return self.websocket.socket.fileno() == -1
|
237
|
-
|
238
|
-
def _check_inputs(
    self,
    transcript: str,
    duration: Optional[float],
    chunk_time: Optional[float],
    output_format: Union[str, AudioOutputFormat],
    data_rtype: Union[str, AudioDataReturnType],
):
    """Validate generation arguments before building a request.

    Raises:
        ValueError: If `output_format`/`data_rtype` are not valid enum
            values, `chunk_time` is outside [0.1, 0.5], `duration` is
            shorter than `chunk_time`, or `transcript` is empty.
        ImportError: If the 'array' return type is requested without numpy.
    """
    # This will try the casting and raise an error.
    _ = AudioOutputFormat(output_format)

    if AudioDataReturnType(data_rtype) == AudioDataReturnType.ARRAY and not _NUMPY_AVAILABLE:
        raise ImportError(
            "The 'numpy' package is required to use the 'array' return type. "
            "Please install 'numpy' or use 'bytes' as the return type."
        )

    if chunk_time is not None:
        if chunk_time < 0.1 or chunk_time > 0.5:
            raise ValueError("`chunk_time` must be between 0.1 and 0.5")

    if chunk_time is not None and duration is not None:
        if duration < chunk_time:
            raise ValueError("`duration` must be greater than chunk_time")

    if transcript.strip() == "":
        raise ValueError("`transcript` must be non empty")
|
265
|
-
|
266
|
-
def _generate_request_body(
    self,
    *,
    transcript: str,
    voice: Embedding,
    model_id: str,
    output_format: AudioOutputFormat,
    duration: int = None,
    chunk_time: float = None,
) -> Dict[str, Any]:
    """Build the JSON request body for a generation request.

    Required fields are always present; optional fields are included only
    when they are not None (the enum's string value is sent for
    `output_format`).
    """
    request_body = {
        "transcript": transcript,
        "model_id": model_id,
        "voice": voice,
    }
    optional_fields = {
        "duration": duration,
        "chunk_time": chunk_time,
        "output_format": output_format.value,
    }
    for key, value in optional_fields.items():
        if value is not None:
            request_body[key] = value
    return request_body
|
292
|
-
|
293
|
-
def generate(
    self,
    *,
    transcript: str,
    voice: Embedding,
    model_id: str = DEFAULT_MODEL_ID,
    duration: int = None,
    chunk_time: float = None,
    stream: bool = False,
    websocket: bool = True,
    output_format: Union[str, AudioOutputFormat] = "fp32",
    data_rtype: str = "bytes",
) -> Union[AudioOutput, Generator[AudioOutput, None, None]]:
    """Generate audio from a transcript.

    Args:
        transcript (str): The text to generate audio for.
        voice (Embedding (List[float])): The voice to use for generating audio.
        model_id (str, optional): The model to use; defaults to DEFAULT_MODEL_ID.
        duration (int, optional): The maximum duration of the audio in seconds.
        chunk_time (float, optional): How long each audio segment should be in seconds.
            This should not need to be adjusted.
        stream (bool, optional): Whether to stream the audio or not.
            If True this function returns a generator. False by default.
        websocket (bool, optional): Whether to use a websocket for streaming audio.
            Using the websocket reduces latency by pre-poning the handshake. True by default.
        output_format: The audio output format ("fp32" by default).
        data_rtype: The return type for the 'data' key in the dictionary.
            One of `'bytes' | 'array'`.
            Note this field is experimental and may be deprecated in the future.

    Returns:
        A generator if `stream` is True, otherwise a dictionary.
        Dictionary from both generator and non-generator return types have the following keys:
            * "audio": The audio as a bytes buffer (or numpy array for 'array').
            * "sampling_rate": The sampling rate of the audio.
    """
    # Validate before any casting so errors mention the caller's values.
    self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)

    data_rtype = AudioDataReturnType(data_rtype)
    output_format = AudioOutputFormat(output_format)

    body = self._generate_request_body(
        transcript=transcript,
        voice=voice,
        model_id=model_id,
        duration=duration,
        chunk_time=chunk_time,
        output_format=output_format,
    )

    # Websocket transport trades setup cost for lower per-request latency.
    if websocket:
        generator = self._generate_ws(body)
    else:
        generator = self._generate_http_wrapper(body)

    generator = self._postprocess_audio(
        generator, data_rtype=data_rtype, output_format=output_format
    )
    if stream:
        return generator

    # Non-streaming: drain the generator and concatenate the chunks.
    chunks = []
    sampling_rate = None
    for chunk in generator:
        if sampling_rate is None:
            sampling_rate = chunk["sampling_rate"]
        chunks.append(chunk["audio"])

    if data_rtype == AudioDataReturnType.ARRAY:
        cat = np.concatenate
    else:
        cat = b"".join

    return {"audio": cat(chunks), "sampling_rate": sampling_rate}
|
366
|
-
|
367
|
-
def _postprocess_audio(
    self,
    generator: Generator[AudioOutput, None, None],
    *,
    data_rtype: AudioDataReturnType,
    output_format: AudioOutputFormat,
) -> Generator[AudioOutput, None, None]:
    """Perform postprocessing on the generator outputs.

    The postprocessing should be minimal (e.g. converting to array, casting dtype).
    This code should not perform heavy operations like changing the sampling rate.

    Args:
        generator: A generator that yields audio chunks.
        data_rtype: The data return type.
        output_format: The output format for the audio.

    Returns:
        A generator that yields audio chunks.
    """
    dtype = None
    if data_rtype == AudioDataReturnType.ARRAY:
        # fp32 formats decode to float32; everything else is 16-bit PCM.
        dtype = np.float32 if "fp32" in output_format.value else np.int16

    for chunk in generator:
        if dtype is not None:
            chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
        yield chunk
|
395
|
-
|
396
|
-
@retry_on_connection_error(
    max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
)
def _generate_http_wrapper(self, body: Dict[str, Any]):
    """Need to wrap the http generator in a function for the retry decorator to work."""
    try:
        for chunk in self._generate_http(body):
            yield chunk
    except Exception as e:
        # Log before re-raising so retries leave a trace of each failure.
        logger.error(f"Failed to generate audio. {e}")
        raise e
|
407
|
-
|
408
|
-
def _generate_http(self, body: Dict[str, Any]):
    """Stream audio chunks from the server-sent-events endpoint over HTTP.

    Yields dicts with "audio" bytes and "sampling_rate" as complete JSON
    messages arrive. Raises ValueError if the request is rejected.
    """
    response = requests.post(
        f"{self._http_url()}/audio/sse",
        stream=True,
        data=json.dumps(body),
        headers=self.headers,
        # (connect timeout, read timeout)
        timeout=(DEFAULT_TIMEOUT, DEFAULT_TIMEOUT),
    )
    if not response.ok:
        raise ValueError(f"Failed to generate audio. {response.text}")

    # Accumulate partial JSON messages across network chunks.
    buffer = ""
    for chunk_bytes in response.iter_content(chunk_size=None):
        buffer, outputs = update_buffer(buffer, chunk_bytes)
        for output in outputs:
            yield output

    # Flush any trailing complete message left in the buffer.
    if buffer:
        try:
            chunk_json = json.loads(buffer)
            audio = base64.b64decode(chunk_json["data"])
            yield {"audio": audio, "sampling_rate": chunk_json["sampling_rate"]}
        except json.JSONDecodeError:
            pass
|
432
|
-
|
433
|
-
def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
    """Generate audio using the websocket connection.

    Args:
        body: The request body.
        context_id: The context id for the request.
            The context id must be globally unique for the duration this client exists.
            If this is provided, the context id that is in the response will
            also be returned as part of the dict. This is helpful for testing.

    Yields:
        Audio-output dicts (see :func:`convert_response`).

    Raises:
        RuntimeError: If the server reports an error or the stream fails.
    """
    if not self.websocket or self._is_websocket_closed():
        self.refresh_websocket()

    include_context_id = bool(context_id)
    if context_id is None:
        context_id = uuid.uuid4().hex
    self.websocket.send(json.dumps({"data": body, "context_id": context_id}))
    # Fix: initialize `response` so the except-branch below cannot hit a
    # NameError (masking the real error) when the very first recv() raises.
    response = None
    try:
        while True:
            response = json.loads(self.websocket.recv())
            if "error" in response:
                raise RuntimeError(f"Error generating audio:\n{response['error']}")
            if response["done"]:
                break

            yield convert_response(response, include_context_id)
    except Exception as e:
        # Close the websocket connection if an error occurs.
        if self.websocket and not self._is_websocket_closed():
            self.websocket.close()
        error_msg_end = "" if response is None else f" {response}"
        raise RuntimeError(f"Failed to generate audio.{error_msg_end}") from e
    finally:
        # Ensure the websocket is ultimately closed.
        if self.websocket and not self._is_websocket_closed():
            self.websocket.close()
|
468
|
-
|
469
|
-
def prepare_audio_and_headers(
    self, raw_audio: Union[bytes, str]
) -> Tuple[bytes, Dict[str, Any]]:
    """Load raw audio bytes and build headers suitable for a file upload.

    A ``str`` argument is treated as a filesystem path and read as binary;
    ``bytes`` are passed through unchanged. The returned headers are the
    client headers minus ``Content-Type`` (JSON is wrong for uploads).
    """
    if isinstance(raw_audio, str):
        with open(raw_audio, "rb") as audio_file:
            raw_audio_bytes = audio_file.read()
    else:
        raw_audio_bytes = raw_audio
    # application/json is not the right content type for this request
    headers = dict(self.headers)
    headers.pop("Content-Type", None)
    return raw_audio_bytes, headers
|
480
|
-
|
481
|
-
def _http_url(self):
    """Return the versioned base HTTP URL; plain http only for localhost."""
    scheme = "http" if "localhost" in self.base_url else "https"
    return f"{scheme}://{self.base_url}/{self.api_version}"
|
484
|
-
|
485
|
-
def _ws_url(self):
    """Return the versioned base websocket URL; plain ws only for localhost."""
    scheme = "ws" if "localhost" in self.base_url else "wss"
    return f"{scheme}://{self.base_url}/{self.api_version}"
|
488
|
-
|
489
|
-
def close(self):
    """Close the websocket connection if one is open."""
    ws = self.websocket
    if ws and not self._is_websocket_closed():
        ws.close()
|
492
|
-
|
493
|
-
def __del__(self):
    """Best-effort cleanup of the websocket at garbage collection.

    Guards against a partially-initialized instance: if __init__ raised
    before `websocket` was assigned, the old code crashed with
    AttributeError inside the finalizer.
    """
    try:
        self.close()
    except Exception:
        # Never propagate from a finalizer; there is nothing the caller
        # could do with the error at this point.
        pass
|
495
|
-
|
496
|
-
def __enter__(self):
    # Eagerly open the websocket so the first generate() call skips the
    # connection handshake.
    self.refresh_websocket()
    return self
|
499
|
-
|
500
|
-
def __exit__(
    self,
    exc_type: Union[type, None],
    exc: Union[BaseException, None],
    exc_tb: Union[TracebackType, None],
):
    # Always release the websocket, regardless of whether the block raised.
    self.close()
|
507
|
-
|
508
|
-
|
509
|
-
class AsyncCartesiaTTS(CartesiaTTS):
    """Asynchronous variant of :class:`CartesiaTTS` built on aiohttp."""

    def __init__(self, *, api_key: str = None):
        # aiohttp session and its owning event loop. Both are created lazily
        # in _get_session so the client can be constructed outside a loop.
        self._session = None
        self._loop = None
        super().__init__(api_key=api_key)
|
514
|
-
|
515
|
-
async def _get_session(self):
    """Return a live aiohttp session bound to the current event loop."""
    current_loop = asyncio.get_event_loop()
    if self._loop is not current_loop:
        # If the loop has changed, close the session and create a new one.
        await self.close()
    if self._session is None or self._session.closed:
        timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
        connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
        self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
        self._loop = current_loop
    return self._session
|
526
|
-
|
527
|
-
async def refresh_websocket(self):
    """Refresh the websocket connection."""
    # Only (re)connect when there is no live connection.
    if self.websocket is None or self._is_websocket_closed():
        route = "audio/websocket"
        session = await self._get_session()
        # API key travels as a query parameter for the websocket handshake.
        self.websocket = await session.ws_connect(
            f"{self._ws_url()}/{route}?api_key={self.api_key}"
        )
|
535
|
-
|
536
|
-
def _is_websocket_closed(self):
    # aiohttp's ClientWebSocketResponse exposes a `closed` flag.
    return self.websocket.closed
|
538
|
-
|
539
|
-
async def close(self):
    """This method closes the websocket and the session.

    It is *strongly* recommended to call this method when you are done using the client.
    """
    if self.websocket is not None and not self._is_websocket_closed():
        await self.websocket.close()
    if self._session is not None and not self._session.closed:
        await self._session.close()
|
548
|
-
|
549
|
-
async def generate(
    self,
    *,
    transcript: str,
    voice: Embedding,
    model_id: str = DEFAULT_MODEL_ID,
    duration: int = None,
    chunk_time: float = None,
    stream: bool = False,
    websocket: bool = True,
    output_format: Union[str, AudioOutputFormat] = "fp32",
    data_rtype: Union[str, AudioDataReturnType] = "bytes",
) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
    """Asynchronously generate audio from a transcript.

    For more information on the arguments, see the synchronous :meth:`CartesiaTTS.generate`.
    """
    self._check_inputs(transcript, duration, chunk_time, output_format, data_rtype)
    data_rtype = AudioDataReturnType(data_rtype)
    output_format = AudioOutputFormat(output_format)

    body = self._generate_request_body(
        transcript=transcript,
        voice=voice,
        model_id=model_id,
        duration=duration,
        chunk_time=chunk_time,
        output_format=output_format,
    )

    # Websocket transport trades setup cost for lower per-request latency.
    if websocket:
        generator = self._generate_ws(body)
    else:
        generator = self._generate_http_wrapper(body)
    generator = self._postprocess_audio(
        generator, data_rtype=data_rtype, output_format=output_format
    )
    if stream:
        return generator

    # Non-streaming: drain the async generator and concatenate the chunks.
    chunks = []
    sampling_rate = None
    async for chunk in generator:
        if sampling_rate is None:
            sampling_rate = chunk["sampling_rate"]
        chunks.append(chunk["audio"])

    if data_rtype == AudioDataReturnType.ARRAY:
        cat = np.concatenate
    else:
        cat = b"".join

    return {"audio": cat(chunks), "sampling_rate": sampling_rate}
|
602
|
-
|
603
|
-
async def _postprocess_audio(
    self,
    generator: AsyncGenerator[AudioOutput, None],
    *,
    data_rtype: AudioDataReturnType,
    output_format: AudioOutputFormat,
) -> AsyncGenerator[AudioOutput, None]:
    """See :meth:`CartesiaTTS._postprocess_audio`."""
    dtype = None
    if data_rtype == AudioDataReturnType.ARRAY:
        # fp32 formats decode to float32; everything else is 16-bit PCM.
        dtype = np.float32 if "fp32" in output_format.value else np.int16

    async for chunk in generator:
        if dtype is not None:
            chunk["audio"] = np.frombuffer(chunk["audio"], dtype=dtype)
        yield chunk
|
619
|
-
|
620
|
-
@retry_on_connection_error_async(
    max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
)
async def _generate_http_wrapper(self, body: Dict[str, Any]):
    """Need to wrap the http generator in a function for the retry decorator to work."""
    try:
        async for chunk in self._generate_http(body):
            yield chunk
    except Exception as e:
        # Log before re-raising so retries leave a trace of each failure.
        logger.error(f"Failed to generate audio. {e}")
        raise e
|
631
|
-
|
632
|
-
async def _generate_http(self, body: Dict[str, Any]):
    """Stream audio chunks from the SSE endpoint using aiohttp.

    The request timeout comes from the session created in _get_session.
    """
    session = await self._get_session()
    async with session.post(
        f"{self._http_url()}/audio/sse", data=json.dumps(body), headers=self.headers
    ) as response:
        if not response.ok:
            raise ValueError(f"Failed to generate audio. {await response.text()}")

        # Accumulate partial JSON messages across network chunks.
        buffer = ""
        async for chunk_bytes in response.content.iter_any():
            buffer, outputs = update_buffer(buffer, chunk_bytes)
            for output in outputs:
                yield output

        # Flush any trailing complete message left in the buffer.
        if buffer:
            try:
                chunk_json = json.loads(buffer)
                audio = base64.b64decode(chunk_json["data"])
                yield {"audio": audio, "sampling_rate": chunk_json["sampling_rate"]}
            except json.JSONDecodeError:
                pass
|
653
|
-
|
654
|
-
async def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
    """Generate audio over the aiohttp websocket connection.

    Args:
        body: The request body.
        context_id: The context id for the request; must be globally unique
            for the lifetime of this client. When provided, the server's
            context_id is included in each yielded dict.

    Yields:
        Audio-output dicts (see :func:`convert_response`).

    Raises:
        RuntimeError: If the server reports an error or the stream fails.
    """
    include_context_id = bool(context_id)
    if not self.websocket or self._is_websocket_closed():
        await self.refresh_websocket()

    ws = self.websocket
    if context_id is None:
        context_id = uuid.uuid4().hex
    await ws.send_json({"data": body, "context_id": context_id})
    try:
        response = None
        while True:
            response = await ws.receive_json()
            # Surface server-side errors, mirroring the sync client.
            if "error" in response:
                raise RuntimeError(f"Error generating audio:\n{response['error']}")
            if response["done"]:
                break

            yield convert_response(response, include_context_id)
    except Exception as e:
        if self.websocket and not self._is_websocket_closed():
            await self.websocket.close()
        # Fix: `response` is a decoded dict (receive_json), not an HTTP
        # response; the old code awaited `response.text()` here, which raised
        # AttributeError inside the handler and masked the real error.
        error_msg_end = "" if response is None else f": {response}"
        raise RuntimeError(f"Failed to generate audio{error_msg_end}") from e
    finally:
        # Ensure the websocket is ultimately closed.
        if self.websocket and not self._is_websocket_closed():
            await self.websocket.close()
|
680
|
-
|
681
|
-
def __del__(self):
    # Best-effort async cleanup at garbage collection: if an event loop is
    # running we can only schedule the close; otherwise spin one up.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None:
        asyncio.run(self.close())
    else:
        # NOTE(review): the task is not awaited, so cleanup may still be
        # pending at interpreter shutdown — confirm this is acceptable.
        loop.create_task(self.close())
|
691
|
-
|
692
|
-
async def __aenter__(self):
    # Eagerly open the websocket so the first generate() call skips the
    # connection handshake.
    await self.refresh_websocket()
    return self
|
695
|
-
|
696
|
-
async def __aexit__(
    self,
    exc_type: Union[type, None],
    exc: Union[BaseException, None],
    exc_tb: Union[TracebackType, None],
):
    # Always release the websocket and session, even if the block raised.
    await self.close()
|