cartesia-0.1.1-py2.py3-none-any.whl → cartesia-1.0.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/__init__.py +2 -2
- cartesia/_types.py +33 -39
- cartesia/client.py +874 -0
- cartesia/version.py +1 -1
- cartesia-1.0.0.dist-info/METADATA +364 -0
- cartesia-1.0.0.dist-info/RECORD +9 -0
- cartesia/tts.py +0 -702
- cartesia-0.1.1.dist-info/METADATA +0 -189
- cartesia-0.1.1.dist-info/RECORD +0 -9
- {cartesia-0.1.1.dist-info → cartesia-1.0.0.dist-info}/WHEEL +0 -0
- {cartesia-0.1.1.dist-info → cartesia-1.0.0.dist-info}/top_level.txt +0 -0
cartesia/client.py
ADDED
@@ -0,0 +1,874 @@
import asyncio
import base64
import json
import os
import uuid
from types import TracebackType
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union, Callable

import aiohttp
import httpx
import logging
import requests
from websockets.sync.client import connect

from cartesia.utils import retry_on_connection_error, retry_on_connection_error_async
from cartesia._types import (
    OutputFormat,
    OutputFormatMapping,
    VoiceMetadata,
)


DEFAULT_MODEL_ID = "sonic-english"  # latest default model
MULTILINGUAL_MODEL_ID = "sonic-multilingual"  # latest multilingual model
DEFAULT_BASE_URL = "api.cartesia.ai"
DEFAULT_CARTESIA_VERSION = "2024-06-10"  # latest version
DEFAULT_TIMEOUT = 30  # seconds
DEFAULT_NUM_CONNECTIONS = 10  # connections per client

BACKOFF_FACTOR = 1
MAX_RETRIES = 3

logger = logging.getLogger(__name__)

class BaseClient:
    def __init__(self, *, api_key: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT):
        """Constructor for the BaseClient. Used by the Cartesia and AsyncCartesia clients."""
        self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
        self.timeout = timeout


class Resource:
    def __init__(
        self,
        api_key: str,
        timeout: float,
    ):
        """Constructor for the Resource class. Used by the Voices and TTS classes."""
        self.api_key = api_key
        self.timeout = timeout
        self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
        self.cartesia_version = DEFAULT_CARTESIA_VERSION
        self.headers = {
            "X-API-Key": self.api_key,
            "Cartesia-Version": self.cartesia_version,
            "Content-Type": "application/json",
        }

    def _http_url(self):
        """Returns the HTTP URL for the Cartesia API.
        If the base URL is localhost, the URL will start with 'http'. Otherwise, it will start with 'https'.
        """
        if self.base_url.startswith("http://") or self.base_url.startswith("https://"):
            return self.base_url
        else:
            prefix = "http" if "localhost" in self.base_url else "https"
            return f"{prefix}://{self.base_url}"

    def _ws_url(self):
        """Returns the WebSocket URL for the Cartesia API.
        If the base URL is localhost, the URL will start with 'ws'. Otherwise, it will start with 'wss'.
        """
        if self.base_url.startswith("ws://") or self.base_url.startswith("wss://"):
            return self.base_url
        else:
            prefix = "ws" if "localhost" in self.base_url else "wss"
            return f"{prefix}://{self.base_url}"

class Cartesia(BaseClient):
    """
    The client for Cartesia's text-to-speech library.

    This client contains methods to interact with the Cartesia text-to-speech API.
    The client can be used to manage your voice library and generate speech from text.

    The client supports generating audio using both Server-Sent Events and WebSocket for lower latency.
    """

    def __init__(self, *, api_key: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT):
        """Constructor for the Cartesia client.

        Args:
            api_key: The API key to use for authorization.
                If not specified, the API key will be read from the environment variable
                `CARTESIA_API_KEY`.
            timeout: The timeout for the HTTP requests in seconds. Defaults to 30 seconds.
        """
        super().__init__(api_key=api_key, timeout=timeout)
        self.voices = Voices(api_key=self.api_key, timeout=self.timeout)
        self.tts = TTS(api_key=self.api_key, timeout=self.timeout)

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: Union[type, None],
        exc: Union[BaseException, None],
        exc_tb: Union[TracebackType, None],
    ):
        pass

class Voices(Resource):
    """This resource contains methods to list, get, clone, and create voices in your Cartesia voice library.

    Usage:
        >>> client = Cartesia(api_key="your_api_key")
        >>> voices = client.voices.list()
        >>> voice = client.voices.get(id="a0e99841-438c-4a64-b679-ae501e7d6091")
        >>> print("Voice Name:", voice["name"], "Voice Description:", voice["description"])
        >>> embedding = client.voices.clone(filepath="path/to/clip.wav")
        >>> new_voice = client.voices.create(
        ...     name="My Voice", description="A new voice", embedding=embedding
        ... )
    """

    def list(self) -> List[VoiceMetadata]:
        """List all voices in your voice library.

        Returns:
            This method returns a list of VoiceMetadata objects with the following keys:
            - id: The ID of the voice.
            - name: The name of the voice.
            - description: The description of the voice.
            - embedding: The embedding of the voice.
            - is_public: Whether the voice is public.
            - user_id: The ID of the user who created the voice.
            - created_at: The timestamp (str) when the voice was created.
        """
        response = httpx.get(
            f"{self._http_url()}/voices",
            headers=self.headers,
            timeout=self.timeout,
        )

        if not response.is_success:
            raise ValueError(f"Failed to get voices. Error: {response.text}")

        voices = response.json()
        return voices

    def get(self, id: str) -> VoiceMetadata:
        """Get a voice by its ID.

        Args:
            id: The ID of the voice.

        Returns:
            A dictionary containing the voice metadata with the following keys:
            - id: The ID of the voice.
            - name: The name of the voice.
            - description: The description of the voice.
            - embedding: The embedding of the voice as a list of floats.
            - is_public: Whether the voice is public.
            - user_id: The ID of the user who created the voice.
            - created_at: The timestamp when the voice was created.
        """
        url = f"{self._http_url()}/voices/{id}"
        response = httpx.get(url, headers=self.headers, timeout=self.timeout)

        if not response.is_success:
            raise ValueError(
                f"Failed to get voice. Status Code: {response.status_code}\n"
                f"Error: {response.text}"
            )

        return response.json()

    def clone(self, filepath: Optional[str] = None, link: Optional[str] = None) -> List[float]:
        """Clone a voice from a clip or a URL.

        Args:
            filepath: The path to the clip file.
            link: The URL to the clip

        Returns:
            The embedding of the cloned voice as a list of floats.
        """
        # TODO: Python has a bytes object, use that instead of a filepath
        if not filepath and not link:
            raise ValueError("At least one of 'filepath' or 'link' must be specified.")
        if filepath and link:
            raise ValueError("Only one of 'filepath' or 'link' should be specified.")
        if filepath:
            url = f"{self._http_url()}/voices/clone/clip"
            with open(filepath, "rb") as file:
                files = {"clip": file}
                headers = self.headers.copy()
                headers.pop("Content-Type", None)
                headers["Content-Type"] = "multipart/form-data"
                response = httpx.post(url, headers=headers, files=files, timeout=self.timeout)
                if not response.is_success:
                    raise ValueError(f"Failed to clone voice from clip. Error: {response.text}")
        elif link:
            url = f"{self._http_url()}/voices/clone/url"
            params = {"link": link}
            headers = self.headers.copy()
            headers.pop("Content-Type")  # The content type header is not required for URLs
            response = httpx.post(url, headers=self.headers, params=params, timeout=self.timeout)
            if not response.is_success:
                raise ValueError(f"Failed to clone voice from URL. Error: {response.text}")

        return response.json()["embedding"]

    def create(self, name: str, description: str, embedding: List[float]) -> VoiceMetadata:
        """Create a new voice.

        Args:
            name: The name of the voice.
            description: The description of the voice.
            embedding: The embedding of the voice. This should be generated with :meth:`clone`.

        Returns:
            A dictionary containing the voice metadata.
        """
        response = httpx.post(
            f"{self._http_url()}/voices",
            headers=self.headers,
            json={"name": name, "description": description, "embedding": embedding},
            timeout=self.timeout,
        )

        if not response.is_success:
            raise ValueError(f"Failed to create voice. Error: {response.text}")

        return response.json()

class _WebSocket:
    """This class contains methods to generate audio using WebSocket. Ideal for low-latency audio generation.

    Usage:
        >>> ws = client.tts.websocket()
        >>> for audio_chunk in ws.send(
        ...     model_id="upbeat-moon", transcript="Hello world!", voice_embedding=embedding,
        ...     output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, stream=True
        ... ):
        ...     audio = audio_chunk["audio"]
    """

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
    ):
        self.ws_url = ws_url
        self.api_key = api_key
        self.cartesia_version = cartesia_version
        self.websocket = None

    def connect(self):
        """This method connects to the WebSocket if it is not already connected."""
        if self.websocket is None or self._is_websocket_closed():
            route = "tts/websocket"
            self.websocket = connect(
                f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
            )

    def _is_websocket_closed(self):
        return self.websocket.socket.fileno() == -1

    def close(self):
        """This method closes the WebSocket connection. *Highly* recommended to call this method when done using the WebSocket."""
        if self.websocket is not None and not self._is_websocket_closed():
            self.websocket.close()

    def _convert_response(
        self, response: Dict[str, any], include_context_id: bool
    ) -> Dict[str, Any]:
        audio = base64.b64decode(response["data"])

        optional_kwargs = {}
        if include_context_id:
            optional_kwargs["context_id"] = response["context_id"]

        return {
            "audio": audio,
            **optional_kwargs,
        }

    def _validate_and_construct_voice(
        self, voice_id: Optional[str] = None, voice_embedding: Optional[List[float]] = None
    ) -> dict:
        """Validate and construct the voice dictionary for the request.

        Args:
            voice_id: The ID of the voice to use for generating audio.
            voice_embedding: The embedding of the voice to use for generating audio.

        Returns:
            A dictionary representing the voice configuration.

        Raises:
            ValueError: If neither or both voice_id and voice_embedding are specified.
        """
        if voice_id is None and voice_embedding is None:
            raise ValueError("Either voice_id or voice_embedding must be specified.")

        if voice_id is not None and voice_embedding is not None:
            raise ValueError("Only one of voice_id or voice_embedding should be specified.")

        if voice_id:
            return {"mode": "id", "id": voice_id}

        return {"mode": "embedding", "embedding": voice_embedding}

    def send(
        self,
        model_id: str,
        transcript: str,
        output_format: dict,
        voice_id: Optional[str] = None,
        voice_embedding: Optional[List[float]] = None,
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
    ) -> Union[bytes, Generator[bytes, None, None]]:
        """Send a request to the WebSocket to generate audio.

        Args:
            model_id: The ID of the model to use for generating audio.
            transcript: The text to convert to speech.
            output_format: A dictionary containing the details of the output format.
            voice_id: The ID of the voice to use for generating audio.
            voice_embedding: The embedding of the voice to use for generating audio.
            context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
            duration: The duration of the audio in seconds.
            language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
            stream: Whether to stream the audio or not. (Default is True)

        Returns:
            If `stream` is True, the method returns a generator that yields chunks of audio as bytes.
            If `stream` is False, the method returns a dictionary containing the concatenated audio as bytes and the context ID.
        """
        self.connect()

        if context_id is None:
            context_id = uuid.uuid4().hex

        voice = self._validate_and_construct_voice(voice_id, voice_embedding)

        request_body = {
            "model_id": model_id,
            "transcript": transcript,
            "voice": voice,
            "output_format": {
                "container": output_format["container"],
                "encoding": output_format["encoding"],
                "sample_rate": output_format["sample_rate"],
            },
            "context_id": context_id,
            "language": language,
        }

        if duration is not None:
            request_body["duration"] = duration

        generator = self._websocket_generator(request_body)

        if stream:
            return generator

        chunks = []
        for chunk in generator:
            chunks.append(chunk["audio"])

        return {"audio": b"".join(chunks), "context_id": context_id}

    def _websocket_generator(self, request_body: Dict[str, Any]):
        self.websocket.send(json.dumps(request_body))

        try:
            while True:
                response = json.loads(self.websocket.recv())
                if "error" in response:
                    raise RuntimeError(f"Error generating audio:\n{response['error']}")
                if response["done"]:
                    break
                yield self._convert_response(response=response, include_context_id=True)
        except Exception as e:
            # Close the websocket connection if an error occurs.
            if self.websocket and not self._is_websocket_closed():
                self.websocket.close()
            raise RuntimeError(f"Failed to generate audio. {response}") from e

class _SSE:
    """This class contains methods to generate audio using Server-Sent Events.

    Usage:
        >>> for audio_chunk in client.tts.sse(
        ...     model_id="upbeat-moon", transcript="Hello world!", voice_embedding=embedding,
        ...     output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, stream=True
        ... ):
        ...     audio = audio_chunk["audio"]
    """

    def __init__(
        self,
        http_url: str,
        headers: Dict[str, str],
        timeout: float,
    ):
        self.http_url = http_url
        self.headers = headers
        self.timeout = timeout

    def _update_buffer(self, buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
        buffer += chunk_bytes.decode("utf-8")
        outputs = []
        while "{" in buffer and "}" in buffer:
            start_index = buffer.find("{")
            end_index = buffer.find("}", start_index)
            if start_index != -1 and end_index != -1:
                try:
                    chunk_json = json.loads(buffer[start_index : end_index + 1])
                    if "error" in chunk_json:
                        raise RuntimeError(f"Error generating audio:\n{chunk_json['error']}")
                    if chunk_json["done"]:
                        break
                    audio = base64.b64decode(chunk_json["data"])
                    outputs.append({"audio": audio})
                    buffer = buffer[end_index + 1 :]
                except json.JSONDecodeError:
                    break
        return buffer, outputs

    def _validate_and_construct_voice(
        self, voice_id: Optional[str] = None, voice_embedding: Optional[List[float]] = None
    ) -> dict:
        """Validate and construct the voice dictionary for the request.

        Args:
            voice_id: The ID of the voice to use for generating audio.
            voice_embedding: The embedding of the voice to use for generating audio.

        Returns:
            A dictionary representing the voice configuration.

        Raises:
            ValueError: If neither or both voice_id and voice_embedding are specified.
        """
        if voice_id is None and voice_embedding is None:
            raise ValueError("Either voice_id or voice_embedding must be specified.")

        if voice_id is not None and voice_embedding is not None:
            raise ValueError("Only one of voice_id or voice_embedding should be specified.")

        if voice_id:
            return {"mode": "id", "id": voice_id}

        return {"mode": "embedding", "embedding": voice_embedding}

    def send(
        self,
        model_id: str,
        transcript: str,
        output_format: OutputFormat,
        voice_id: Optional[str] = None,
        voice_embedding: Optional[List[float]] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
    ) -> Union[bytes, Generator[bytes, None, None]]:
        """Send a request to the server to generate audio using Server-Sent Events.

        Args:
            model_id: The ID of the model to use for generating audio.
            transcript: The text to convert to speech.
            voice_id: The ID of the voice to use for generating audio.
            voice_embedding: The embedding of the voice to use for generating audio.
            output_format: A dictionary containing the details of the output format.
            duration: The duration of the audio in seconds.
            language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
            stream: Whether to stream the audio or not.

        Returns:
            If `stream` is True, the method returns a generator that yields chunks. Each chunk is a dictionary containing the audio as bytes.
            If `stream` is False, the method returns a dictionary containing the audio as bytes.
        """
        voice = self._validate_and_construct_voice(voice_id, voice_embedding)

        request_body = {
            "model_id": model_id,
            "transcript": transcript,
            "voice": voice,
            "output_format": {
                "container": output_format["container"],
                "encoding": output_format["encoding"],
                "sample_rate": output_format["sample_rate"],
            },
            "language": language,
        }

        if duration is not None:
            request_body["duration"] = duration

        generator = self._sse_generator_wrapper(request_body)

        if stream:
            return generator

        chunks = []
        for chunk in generator:
            chunks.append(chunk["audio"])

        return {"audio": b"".join(chunks)}

    @retry_on_connection_error(
        max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
    )
    def _sse_generator_wrapper(self, request_body: Dict[str, Any]):
        """Need to wrap the sse generator in a function for the retry decorator to work."""
        try:
            for chunk in self._sse_generator(request_body):
                yield chunk
        except Exception as e:
            logger.error(f"Failed to generate audio. {e}")
            raise e

    def _sse_generator(self, request_body: Dict[str, Any]):
        response = requests.post(
            f"{self.http_url}/tts/sse",
            stream=True,
            data=json.dumps(request_body),
            headers=self.headers,
            timeout=(self.timeout, self.timeout),
        )
        if not response.ok:
            raise ValueError(f"Failed to generate audio. {response.text}")

        buffer = ""
        for chunk_bytes in response.iter_content(chunk_size=None):
            buffer, outputs = self._update_buffer(buffer=buffer, chunk_bytes=chunk_bytes)
            for output in outputs:
                yield output

        if buffer:
            try:
                chunk_json = json.loads(buffer)
                audio = base64.b64decode(chunk_json["data"])
                yield {"audio": audio}
            except json.JSONDecodeError:
                pass

class TTS(Resource):
    """This resource contains methods to generate audio using Cartesia's text-to-speech API."""

    def __init__(self, api_key, timeout):
        super().__init__(
            api_key=api_key,
            timeout=timeout,
        )
        self._sse_class = _SSE(self._http_url(), self.headers, self.timeout)
        self.sse = self._sse_class.send

    def websocket(self) -> _WebSocket:
        """This method returns a WebSocket object that can be used to generate audio using WebSocket.

        Returns:
            _WebSocket: A WebSocket object that can be used to generate audio using WebSocket.
        """
        ws = _WebSocket(self._ws_url(), self.api_key, self.cartesia_version)
        ws.connect()
        return ws

    def get_output_format(self, output_format_name: str) -> OutputFormat:
        """Convenience method to get the output_format object from a given output format name.

        Args:
            output_format_name (str): The name of the output format.

        Returns:
            OutputFormat: A dictionary containing the details of the output format to be passed into tts.sse() or tts.websocket().send()
        """
        output_format_obj = OutputFormatMapping.get_format(output_format_name)
        return OutputFormat(
            container=output_format_obj["container"],
            encoding=output_format_obj["encoding"],
            sample_rate=output_format_obj["sample_rate"],
        )

    def get_sample_rate(self, output_format_name: str) -> int:
        """Convenience method to get the sample rate for a given output format.

        Args:
            output_format_name (str): The name of the output format.

        Returns:
            int: The sample rate for the output format.
        """
        output_format_obj = OutputFormatMapping.get_format(output_format_name)
        return output_format_obj["sample_rate"]

class AsyncCartesia(Cartesia):
    """The asynchronous version of the Cartesia client."""

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        timeout: float = DEFAULT_TIMEOUT,
        max_num_connections: int = DEFAULT_NUM_CONNECTIONS,
    ):
        """
        Args:
            api_key: See :class:`Cartesia`.
            timeout: See :class:`Cartesia`.
            max_num_connections: The maximum number of concurrent connections to use for the client.
                This is used to limit the number of connections that can be made to the server.
        """
        self._session = None
        self._loop = None
        super().__init__(api_key=api_key, timeout=timeout)
        self.max_num_connections = max_num_connections
        self.tts = AsyncTTS(
            api_key=self.api_key, timeout=self.timeout, get_session=self._get_session
        )

    async def _get_session(self):
        current_loop = asyncio.get_event_loop()
        if self._loop is not current_loop:
            # If the loop has changed, close the session and create a new one.
            await self.close()
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            connector = aiohttp.TCPConnector(limit=self.max_num_connections)
            self._session = aiohttp.ClientSession(timeout=timeout, connector=connector)
            self._loop = current_loop
        return self._session

    async def close(self):
        """This method closes the session.

        It is *strongly* recommended to call this method when you are done using the client.
        """
        if self._session is not None and not self._session.closed:
            await self._session.close()

    def __del__(self):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            asyncio.run(self.close())
        else:
            loop.create_task(self.close())

    async def __aenter__(self):
        return self

    async def __aexit__(
        self,
        exc_type: Union[type, None],
        exc: Union[BaseException, None],
        exc_tb: Union[TracebackType, None],
    ):
        await self.close()

class _AsyncSSE(_SSE):
    """This class contains methods to generate audio using Server-Sent Events asynchronously."""

    def __init__(
        self,
        http_url: str,
        headers: Dict[str, str],
        timeout: float,
        get_session: Callable[[], Optional[aiohttp.ClientSession]],
    ):
        super().__init__(http_url, headers, timeout)
        self._get_session = get_session

    async def send(
        self,
        model_id: str,
        transcript: str,
        output_format: OutputFormat,
        voice_id: Optional[str] = None,
        voice_embedding: Optional[List[float]] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
    ) -> Union[bytes, AsyncGenerator[bytes, None]]:
        voice = self._validate_and_construct_voice(voice_id, voice_embedding)

        request_body = {
            "model_id": model_id,
            "transcript": transcript,
            "voice": voice,
            "output_format": {
                "container": output_format["container"],
                "encoding": output_format["encoding"],
                "sample_rate": output_format["sample_rate"],
            },
            "language": language,
        }

        if duration is not None:
            request_body["duration"] = duration

        generator = self._sse_generator_wrapper(request_body)

        if stream:
            return generator

        chunks = []
        async for chunk in generator:
            chunks.append(chunk["audio"])

        return {"audio": b"".join(chunks)}

    @retry_on_connection_error_async(
        max_retries=MAX_RETRIES, backoff_factor=BACKOFF_FACTOR, logger=logger
    )
    async def _sse_generator_wrapper(self, request_body: Dict[str, Any]):
        """Need to wrap the sse generator in a function for the retry decorator to work."""
        try:
            async for chunk in self._sse_generator(request_body):
                yield chunk
        except Exception as e:
            logger.error(f"Failed to generate audio. {e}")
            raise e

    async def _sse_generator(self, request_body: Dict[str, Any]):
        session = await self._get_session()
        async with session.post(
            f"{self.http_url}/tts/sse", data=json.dumps(request_body), headers=self.headers
        ) as response:
            if not response.ok:
                raise ValueError(f"Failed to generate audio. {await response.text()}")

            buffer = ""
            async for chunk_bytes in response.content.iter_any():
                buffer, outputs = self._update_buffer(buffer=buffer, chunk_bytes=chunk_bytes)
                for output in outputs:
                    yield output

            if buffer:
                try:
                    chunk_json = json.loads(buffer)
                    audio = base64.b64decode(chunk_json["data"])
                    yield {"audio": audio}
                except json.JSONDecodeError:
                    pass

class _AsyncWebSocket(_WebSocket):
    """This class contains methods to generate audio using WebSocket asynchronously."""

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
        get_session: Callable[[], Optional[aiohttp.ClientSession]],
    ):
        super().__init__(ws_url, api_key, cartesia_version)
        self._get_session = get_session
        self.websocket = None

    async def connect(self):
        if self.websocket is None or self._is_websocket_closed():
            route = "tts/websocket"
            session = await self._get_session()
            self.websocket = await session.ws_connect(
                f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
            )

    def _is_websocket_closed(self):
        return self.websocket.closed

    async def close(self):
        """This method closes the websocket connection. *Highly* recommended to call this method when done."""
        if self.websocket is not None and not self._is_websocket_closed():
            await self.websocket.close()

    async def send(
        self,
        model_id: str,
        transcript: str,
        output_format: OutputFormat,
        voice_id: Optional[str] = None,
        voice_embedding: Optional[List[float]] = None,
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: Optional[bool] = True,
    ) -> Union[bytes, AsyncGenerator[bytes, None]]:
        await self.connect()

        if context_id is None:
            context_id = uuid.uuid4().hex

        voice = self._validate_and_construct_voice(voice_id, voice_embedding)

        request_body = {
            "model_id": model_id,
            "transcript": transcript,
            "voice": voice,
            "output_format": {
                "container": output_format["container"],
                "encoding": output_format["encoding"],
                "sample_rate": output_format["sample_rate"],
            },
            "context_id": context_id,
            "language": language,
        }

        if duration is not None:
            request_body["duration"] = duration

        generator = self._websocket_generator(request_body)

        if stream:
            return generator

        chunks = []
        async for chunk in generator:
            chunks.append(chunk["audio"])

        return {"audio": b"".join(chunks), "context_id": context_id}

    async def _websocket_generator(self, request_body: Dict[str, Any]):
        await self.websocket.send_json(request_body)

        try:
            response = None
            while True:
                response = await self.websocket.receive_json()
                if "error" in response:
                    raise RuntimeError(f"Error generating audio:\n{response['error']}")
                if response["done"]:
                    break

                yield self._convert_response(response=response, include_context_id=True)
        except Exception as e:
            # Close the websocket connection if an error occurs.
            if self.websocket and not self._is_websocket_closed():
                await self.websocket.close()
            error_msg_end = "" if response is None else f": {await response.text()}"
            raise RuntimeError(f"Failed to generate audio. {error_msg_end}") from e

class AsyncTTS(TTS):
    def __init__(self, api_key, timeout, get_session):
        super().__init__(api_key, timeout)
        self._get_session = get_session
        self._sse_class = _AsyncSSE(self._http_url(), self.headers, self.timeout, get_session)
        self.sse = self._sse_class.send

    async def websocket(self) -> _AsyncWebSocket:
        ws = _AsyncWebSocket(self._ws_url(), self.api_key, self.cartesia_version, self._get_session)
        await ws.connect()
        return ws
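For orientation, here is a minimal usage sketch that ties the new module together. It is based only on the classes and docstrings shown in this diff; the API key, voice ID, and transcripts are placeholders, and the output format dictionary is the one used in the docstring examples above.

from cartesia.client import Cartesia

client = Cartesia(api_key="your_api_key")  # or set the CARTESIA_API_KEY environment variable
output_format = {"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}
voice_id = "a0e99841-438c-4a64-b679-ae501e7d6091"  # placeholder voice ID

# Server-Sent Events: iterate over streamed chunks and concatenate the raw audio bytes.
audio = b""
for chunk in client.tts.sse(
    model_id="sonic-english",
    transcript="Hello from Cartesia!",
    voice_id=voice_id,
    output_format=output_format,
    stream=True,
):
    audio += chunk["audio"]

# WebSocket: lower-latency path; stream=False returns the concatenated audio and context ID.
ws = client.tts.websocket()
result = ws.send(
    model_id="sonic-english",
    transcript="Hello again!",
    voice_id=voice_id,
    output_format=output_format,
    stream=False,
)
ws.close()  # recommended once you are done with the WebSocket
audio_bytes, context_id = result["audio"], result["context_id"]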