cartesia 1.0.2__py2.py3-none-any.whl → 1.0.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/_types.py +5 -4
- cartesia/client.py +316 -80
- cartesia/utils/__init__.py +0 -0
- cartesia/utils/deprecated.py +55 -0
- cartesia/utils/retry.py +87 -0
- cartesia/version.py +1 -1
- {cartesia-1.0.2.dist-info → cartesia-1.0.4.dist-info}/METADATA +96 -1
- cartesia-1.0.4.dist-info/RECORD +11 -0
- cartesia-1.0.2.dist-info/RECORD +0 -8
- {cartesia-1.0.2.dist-info → cartesia-1.0.4.dist-info}/WHEEL +0 -0
- {cartesia-1.0.2.dist-info → cartesia-1.0.4.dist-info}/top_level.txt +0 -0
cartesia/_types.py
CHANGED
@@ -45,15 +45,16 @@ class DeprecatedOutputFormatMapping:
|
|
45
45
|
"mulaw_8000": {"container": "raw", "encoding": "pcm_mulaw", "sample_rate": 8000},
|
46
46
|
"alaw_8000": {"container": "raw", "encoding": "pcm_alaw", "sample_rate": 8000},
|
47
47
|
}
|
48
|
-
|
48
|
+
|
49
|
+
@classmethod
|
49
50
|
@deprecated(
|
50
51
|
vdeprecated="1.0.1",
|
51
52
|
vremove="1.2.0",
|
52
53
|
reason="Old output format names are being deprecated in favor of names aligned with the Cartesia API. Use names from `OutputFormatMapping` instead.",
|
53
54
|
)
|
54
|
-
def get_format_deprecated(
|
55
|
-
if format_name in
|
56
|
-
return
|
55
|
+
def get_format_deprecated(cls, format_name):
|
56
|
+
if format_name in cls._format_mapping:
|
57
|
+
return cls._format_mapping[format_name]
|
57
58
|
else:
|
58
59
|
raise ValueError(f"Unsupported format: {format_name}")
|
59
60
|
|
cartesia/client.py
CHANGED
@@ -4,7 +4,17 @@ import json
|
|
4
4
|
import os
|
5
5
|
import uuid
|
6
6
|
from types import TracebackType
|
7
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
Any,
|
9
|
+
AsyncGenerator,
|
10
|
+
Dict,
|
11
|
+
Generator,
|
12
|
+
List,
|
13
|
+
Optional,
|
14
|
+
Tuple,
|
15
|
+
Union,
|
16
|
+
Callable,
|
17
|
+
)
|
8
18
|
|
9
19
|
import aiohttp
|
10
20
|
import httpx
|
@@ -13,7 +23,6 @@ import requests
|
|
13
23
|
from websockets.sync.client import connect
|
14
24
|
|
15
25
|
from cartesia.utils.retry import retry_on_connection_error, retry_on_connection_error_async
|
16
|
-
from cartesia.utils.deprecated import deprecated
|
17
26
|
from cartesia._types import (
|
18
27
|
OutputFormat,
|
19
28
|
OutputFormatMapping,
|
@@ -36,22 +45,34 @@ logger = logging.getLogger(__name__)
|
|
36
45
|
|
37
46
|
|
38
47
|
class BaseClient:
|
39
|
-
def __init__(
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
*,
|
51
|
+
api_key: Optional[str] = None,
|
52
|
+
base_url: Optional[str] = None,
|
53
|
+
timeout: float = DEFAULT_TIMEOUT,
|
54
|
+
):
|
40
55
|
"""Constructor for the BaseClient. Used by the Cartesia and AsyncCartesia clients."""
|
41
56
|
self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
|
57
|
+
self._base_url = base_url or os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
|
42
58
|
self.timeout = timeout
|
43
59
|
|
60
|
+
@property
|
61
|
+
def base_url(self):
|
62
|
+
return self._base_url
|
63
|
+
|
44
64
|
|
45
65
|
class Resource:
|
46
66
|
def __init__(
|
47
67
|
self,
|
48
68
|
api_key: str,
|
69
|
+
base_url: str,
|
49
70
|
timeout: float,
|
50
71
|
):
|
51
72
|
"""Constructor for the Resource class. Used by the Voices and TTS classes."""
|
52
73
|
self.api_key = api_key
|
53
74
|
self.timeout = timeout
|
54
|
-
self.
|
75
|
+
self._base_url = base_url
|
55
76
|
self.cartesia_version = DEFAULT_CARTESIA_VERSION
|
56
77
|
self.headers = {
|
57
78
|
"X-API-Key": self.api_key,
|
@@ -59,25 +80,29 @@ class Resource:
|
|
59
80
|
"Content-Type": "application/json",
|
60
81
|
}
|
61
82
|
|
83
|
+
@property
|
84
|
+
def base_url(self):
|
85
|
+
return self._base_url
|
86
|
+
|
62
87
|
def _http_url(self):
|
63
88
|
"""Returns the HTTP URL for the Cartesia API.
|
64
89
|
If the base URL is localhost, the URL will start with 'http'. Otherwise, it will start with 'https'.
|
65
90
|
"""
|
66
|
-
if self.
|
67
|
-
return self.
|
91
|
+
if self._base_url.startswith("http://") or self._base_url.startswith("https://"):
|
92
|
+
return self._base_url
|
68
93
|
else:
|
69
|
-
prefix = "http" if "localhost" in self.
|
70
|
-
return f"{prefix}://{self.
|
94
|
+
prefix = "http" if "localhost" in self._base_url else "https"
|
95
|
+
return f"{prefix}://{self._base_url}"
|
71
96
|
|
72
97
|
def _ws_url(self):
|
73
98
|
"""Returns the WebSocket URL for the Cartesia API.
|
74
99
|
If the base URL is localhost, the URL will start with 'ws'. Otherwise, it will start with 'wss'.
|
75
100
|
"""
|
76
|
-
if self.
|
77
|
-
return self.
|
101
|
+
if self._base_url.startswith("ws://") or self._base_url.startswith("wss://"):
|
102
|
+
return self._base_url
|
78
103
|
else:
|
79
|
-
prefix = "ws" if "localhost" in self.
|
80
|
-
return f"{prefix}://{self.
|
104
|
+
prefix = "ws" if "localhost" in self._base_url else "wss"
|
105
|
+
return f"{prefix}://{self._base_url}"
|
81
106
|
|
82
107
|
|
83
108
|
class Cartesia(BaseClient):
|
@@ -90,18 +115,27 @@ class Cartesia(BaseClient):
|
|
90
115
|
The client supports generating audio using both Server-Sent Events and WebSocket for lower latency.
|
91
116
|
"""
|
92
117
|
|
93
|
-
def __init__(
|
118
|
+
def __init__(
|
119
|
+
self,
|
120
|
+
*,
|
121
|
+
api_key: Optional[str] = None,
|
122
|
+
base_url: Optional[str] = None,
|
123
|
+
timeout: float = DEFAULT_TIMEOUT,
|
124
|
+
):
|
94
125
|
"""Constructor for the Cartesia client.
|
95
126
|
|
96
127
|
Args:
|
97
128
|
api_key: The API key to use for authorization.
|
98
129
|
If not specified, the API key will be read from the environment variable
|
99
130
|
`CARTESIA_API_KEY`.
|
100
|
-
|
131
|
+
base_url: The base URL for the Cartesia API.
|
132
|
+
If not specified, the base URL will be read from the enviroment variable
|
133
|
+
`CARTESIA_BASE_URL`. Defaults to `api.cartesia.ai`.
|
134
|
+
timeout: The timeout for HTTP and WebSocket requests in seconds. Defaults to 30 seconds.
|
101
135
|
"""
|
102
|
-
super().__init__(api_key=api_key, timeout=timeout)
|
103
|
-
self.voices = Voices(api_key=self.api_key, timeout=self.timeout)
|
104
|
-
self.tts = TTS(api_key=self.api_key, timeout=self.timeout)
|
136
|
+
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
137
|
+
self.voices = Voices(api_key=self.api_key, base_url=self._base_url, timeout=self.timeout)
|
138
|
+
self.tts = TTS(api_key=self.api_key, base_url=self._base_url, timeout=self.timeout)
|
105
139
|
|
106
140
|
def __enter__(self):
|
107
141
|
return self
|
@@ -188,7 +222,6 @@ class Voices(Resource):
|
|
188
222
|
files = {"clip": file}
|
189
223
|
headers = self.headers.copy()
|
190
224
|
headers.pop("Content-Type", None)
|
191
|
-
headers["Content-Type"] = "multipart/form-data"
|
192
225
|
response = httpx.post(url, headers=headers, files=files, timeout=self.timeout)
|
193
226
|
if not response.is_success:
|
194
227
|
raise ValueError(f"Failed to clone voice from clip. Error: {response.text}")
|
@@ -233,8 +266,9 @@ class _WebSocket:
|
|
233
266
|
Usage:
|
234
267
|
>>> ws = client.tts.websocket()
|
235
268
|
>>> for audio_chunk in ws.send(
|
236
|
-
... model_id="
|
237
|
-
... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
|
269
|
+
... model_id="sonic-english", transcript="Hello world!", voice_embedding=embedding,
|
270
|
+
... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
|
271
|
+
... context_id=context_id, stream=True
|
238
272
|
... ):
|
239
273
|
... audio = audio_chunk["audio"]
|
240
274
|
"""
|
@@ -251,12 +285,19 @@ class _WebSocket:
|
|
251
285
|
self.websocket = None
|
252
286
|
|
253
287
|
def connect(self):
|
254
|
-
"""This method connects to the WebSocket if it is not already connected.
|
288
|
+
"""This method connects to the WebSocket if it is not already connected.
|
289
|
+
|
290
|
+
Raises:
|
291
|
+
RuntimeError: If the connection to the WebSocket fails.
|
292
|
+
"""
|
255
293
|
if self.websocket is None or self._is_websocket_closed():
|
256
294
|
route = "tts/websocket"
|
257
|
-
|
258
|
-
|
259
|
-
|
295
|
+
try:
|
296
|
+
self.websocket = connect(
|
297
|
+
f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
|
298
|
+
)
|
299
|
+
except Exception as e:
|
300
|
+
raise RuntimeError(f"Failed to connect to WebSocket. {e}")
|
260
301
|
|
261
302
|
def _is_websocket_closed(self):
|
262
303
|
return self.websocket.socket.fileno() == -1
|
@@ -329,7 +370,7 @@ class _WebSocket:
|
|
329
370
|
context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
|
330
371
|
duration: The duration of the audio in seconds.
|
331
372
|
language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
|
332
|
-
stream: Whether to stream the audio or not.
|
373
|
+
stream: Whether to stream the audio or not.
|
333
374
|
|
334
375
|
Returns:
|
335
376
|
If `stream` is True, the method returns a generator that yields chunks. Each chunk is a dictionary.
|
@@ -341,7 +382,7 @@ class _WebSocket:
|
|
341
382
|
self.connect()
|
342
383
|
|
343
384
|
if context_id is None:
|
344
|
-
context_id = uuid.uuid4()
|
385
|
+
context_id = str(uuid.uuid4())
|
345
386
|
|
346
387
|
voice = self._validate_and_construct_voice(voice_id, voice_embedding)
|
347
388
|
|
@@ -395,7 +436,7 @@ class _SSE:
|
|
395
436
|
|
396
437
|
Usage:
|
397
438
|
>>> for audio_chunk in client.tts.sse(
|
398
|
-
... model_id="
|
439
|
+
... model_id="sonic-english", transcript="Hello world!", voice_embedding=embedding,
|
399
440
|
... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, stream=True
|
400
441
|
... ):
|
401
442
|
... audio = audio_chunk["audio"]
|
@@ -523,8 +564,7 @@ class _SSE:
|
|
523
564
|
for chunk in self._sse_generator(request_body):
|
524
565
|
yield chunk
|
525
566
|
except Exception as e:
|
526
|
-
|
527
|
-
raise e
|
567
|
+
raise RuntimeError(f"Error generating audio. {e}")
|
528
568
|
|
529
569
|
def _sse_generator(self, request_body: Dict[str, Any]):
|
530
570
|
response = requests.post(
|
@@ -555,9 +595,10 @@ class _SSE:
|
|
555
595
|
class TTS(Resource):
|
556
596
|
"""This resource contains methods to generate audio using Cartesia's text-to-speech API."""
|
557
597
|
|
558
|
-
def __init__(self, api_key, timeout):
|
598
|
+
def __init__(self, api_key: str, base_url: str, timeout: float):
|
559
599
|
super().__init__(
|
560
600
|
api_key=api_key,
|
601
|
+
base_url=base_url,
|
561
602
|
timeout=timeout,
|
562
603
|
)
|
563
604
|
self._sse_class = _SSE(self._http_url(), self.headers, self.timeout)
|
@@ -573,7 +614,8 @@ class TTS(Resource):
|
|
573
614
|
ws.connect()
|
574
615
|
return ws
|
575
616
|
|
576
|
-
|
617
|
+
@staticmethod
|
618
|
+
def get_output_format(output_format_name: str) -> OutputFormat:
|
577
619
|
"""Convenience method to get the output_format dictionary from a given output format name.
|
578
620
|
|
579
621
|
Args:
|
@@ -631,22 +673,27 @@ class AsyncCartesia(Cartesia):
|
|
631
673
|
self,
|
632
674
|
*,
|
633
675
|
api_key: Optional[str] = None,
|
676
|
+
base_url: Optional[str] = None,
|
634
677
|
timeout: float = DEFAULT_TIMEOUT,
|
635
678
|
max_num_connections: int = DEFAULT_NUM_CONNECTIONS,
|
636
679
|
):
|
637
680
|
"""
|
638
681
|
Args:
|
639
682
|
api_key: See :class:`Cartesia`.
|
683
|
+
base_url: See :class:`Cartesia`.
|
640
684
|
timeout: See :class:`Cartesia`.
|
641
685
|
max_num_connections: The maximum number of concurrent connections to use for the client.
|
642
686
|
This is used to limit the number of connections that can be made to the server.
|
643
687
|
"""
|
644
688
|
self._session = None
|
645
689
|
self._loop = None
|
646
|
-
super().__init__(api_key=api_key, timeout=timeout)
|
690
|
+
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
647
691
|
self.max_num_connections = max_num_connections
|
648
692
|
self.tts = AsyncTTS(
|
649
|
-
api_key=self.api_key,
|
693
|
+
api_key=self.api_key,
|
694
|
+
base_url=self._base_url,
|
695
|
+
timeout=self.timeout,
|
696
|
+
get_session=self._get_session,
|
650
697
|
)
|
651
698
|
|
652
699
|
async def _get_session(self):
|
@@ -677,7 +724,7 @@ class AsyncCartesia(Cartesia):
|
|
677
724
|
|
678
725
|
if loop is None:
|
679
726
|
asyncio.run(self.close())
|
680
|
-
|
727
|
+
elif loop.is_running():
|
681
728
|
loop.create_task(self.close())
|
682
729
|
|
683
730
|
async def __aenter__(self):
|
@@ -753,8 +800,7 @@ class _AsyncSSE(_SSE):
|
|
753
800
|
async for chunk in self._sse_generator(request_body):
|
754
801
|
yield chunk
|
755
802
|
except Exception as e:
|
756
|
-
|
757
|
-
raise e
|
803
|
+
raise RuntimeError(f"Error generating audio. {e}")
|
758
804
|
|
759
805
|
async def _sse_generator(self, request_body: Dict[str, Any]):
|
760
806
|
session = await self._get_session()
|
@@ -779,6 +825,141 @@ class _AsyncSSE(_SSE):
|
|
779
825
|
pass
|
780
826
|
|
781
827
|
|
828
|
+
class _AsyncTTSContext:
|
829
|
+
"""Manage a single context over a WebSocket.
|
830
|
+
|
831
|
+
This class separates sending requests and receiving responses into two separate methods.
|
832
|
+
This can be used for sending multiple requests without awaiting the response.
|
833
|
+
Then you can listen to the responses in the order they were sent. See README for usage.
|
834
|
+
|
835
|
+
Each AsyncTTSContext will close automatically when a done message is received for that context.
|
836
|
+
This happens when the no_more_inputs method is called (equivalent to sending a request with `continue_ = False`),
|
837
|
+
or if no requests have been sent for 5 seconds on the same context. It also closes if there is an error.
|
838
|
+
|
839
|
+
"""
|
840
|
+
|
841
|
+
def __init__(self, context_id: str, websocket: "_AsyncWebSocket", timeout: float):
|
842
|
+
self._context_id = context_id
|
843
|
+
self._websocket = websocket
|
844
|
+
self.timeout = timeout
|
845
|
+
self._error = None
|
846
|
+
|
847
|
+
@property
|
848
|
+
def context_id(self) -> str:
|
849
|
+
return self._context_id
|
850
|
+
|
851
|
+
async def send(
|
852
|
+
self,
|
853
|
+
model_id: str,
|
854
|
+
transcript: str,
|
855
|
+
output_format: OutputFormat,
|
856
|
+
voice_id: Optional[str] = None,
|
857
|
+
voice_embedding: Optional[List[float]] = None,
|
858
|
+
context_id: Optional[str] = None,
|
859
|
+
continue_: bool = False,
|
860
|
+
duration: Optional[int] = None,
|
861
|
+
language: Optional[str] = None,
|
862
|
+
) -> None:
|
863
|
+
"""Send audio generation requests to the WebSocket. The response can be received using the `receive` method.
|
864
|
+
|
865
|
+
Args:
|
866
|
+
model_id: The ID of the model to use for generating audio.
|
867
|
+
transcript: The text to convert to speech.
|
868
|
+
output_format: A dictionary containing the details of the output format.
|
869
|
+
voice_id: The ID of the voice to use for generating audio.
|
870
|
+
voice_embedding: The embedding of the voice to use for generating audio.
|
871
|
+
context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
|
872
|
+
continue_: Whether to continue the audio generation from the previous transcript or not.
|
873
|
+
duration: The duration of the audio in seconds.
|
874
|
+
language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
|
875
|
+
|
876
|
+
Returns:
|
877
|
+
None.
|
878
|
+
"""
|
879
|
+
if context_id is not None and context_id != self._context_id:
|
880
|
+
raise ValueError("Context ID does not match the context ID of the current context.")
|
881
|
+
if continue_ and transcript == "":
|
882
|
+
raise ValueError("Transcript cannot be empty when continue_ is True.")
|
883
|
+
|
884
|
+
await self._websocket.connect()
|
885
|
+
|
886
|
+
voice = self._websocket._validate_and_construct_voice(voice_id, voice_embedding)
|
887
|
+
|
888
|
+
request_body = {
|
889
|
+
"model_id": model_id,
|
890
|
+
"transcript": transcript,
|
891
|
+
"voice": voice,
|
892
|
+
"output_format": {
|
893
|
+
"container": output_format["container"],
|
894
|
+
"encoding": output_format["encoding"],
|
895
|
+
"sample_rate": output_format["sample_rate"],
|
896
|
+
},
|
897
|
+
"context_id": self._context_id,
|
898
|
+
"continue": continue_,
|
899
|
+
"language": language,
|
900
|
+
}
|
901
|
+
|
902
|
+
if duration is not None:
|
903
|
+
request_body["duration"] = duration
|
904
|
+
|
905
|
+
await self._websocket.websocket.send_json(request_body)
|
906
|
+
|
907
|
+
# Start listening for responses on the WebSocket
|
908
|
+
self._websocket._dispatch_listener()
|
909
|
+
|
910
|
+
async def no_more_inputs(self) -> None:
|
911
|
+
"""Send a request to the WebSocket to indicate that no more requests will be sent."""
|
912
|
+
await self.send(
|
913
|
+
model_id=DEFAULT_MODEL_ID,
|
914
|
+
transcript="",
|
915
|
+
output_format=TTS.get_output_format("raw_pcm_f32le_44100"),
|
916
|
+
voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Default voice ID since it's a required input for now
|
917
|
+
context_id=self._context_id,
|
918
|
+
continue_=False,
|
919
|
+
)
|
920
|
+
|
921
|
+
async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:
|
922
|
+
"""Receive the audio chunks from the WebSocket. This method is a generator that yields audio chunks.
|
923
|
+
|
924
|
+
Returns:
|
925
|
+
An async generator that yields audio chunks. Each chunk is a dictionary containing the audio as bytes.
|
926
|
+
"""
|
927
|
+
try:
|
928
|
+
while True:
|
929
|
+
response = await self._websocket._get_message(
|
930
|
+
self._context_id, timeout=self.timeout
|
931
|
+
)
|
932
|
+
if "error" in response:
|
933
|
+
raise RuntimeError(f"Error generating audio:\n{response['error']}")
|
934
|
+
if response["done"]:
|
935
|
+
break
|
936
|
+
yield self._websocket._convert_response(response, include_context_id=True)
|
937
|
+
except Exception as e:
|
938
|
+
if isinstance(e, asyncio.TimeoutError):
|
939
|
+
raise RuntimeError("Timeout while waiting for audio chunk")
|
940
|
+
raise RuntimeError(f"Failed to generate audio:\n{e}")
|
941
|
+
finally:
|
942
|
+
self._close()
|
943
|
+
|
944
|
+
def _close(self) -> None:
|
945
|
+
"""Closes the context. Automatically called when a done message is received for this context."""
|
946
|
+
self._websocket._remove_context(self._context_id)
|
947
|
+
|
948
|
+
async def __aenter__(self):
|
949
|
+
return self
|
950
|
+
|
951
|
+
async def __aexit__(
|
952
|
+
self,
|
953
|
+
exc_type: Union[type, None],
|
954
|
+
exc: Union[BaseException, None],
|
955
|
+
exc_tb: Union[TracebackType, None],
|
956
|
+
):
|
957
|
+
self._close()
|
958
|
+
|
959
|
+
def __del__(self):
|
960
|
+
self._close()
|
961
|
+
|
962
|
+
|
782
963
|
class _AsyncWebSocket(_WebSocket):
|
783
964
|
"""This class contains methods to generate audio using WebSocket asynchronously."""
|
784
965
|
|
@@ -787,19 +968,45 @@ class _AsyncWebSocket(_WebSocket):
|
|
787
968
|
ws_url: str,
|
788
969
|
api_key: str,
|
789
970
|
cartesia_version: str,
|
971
|
+
timeout: float,
|
790
972
|
get_session: Callable[[], Optional[aiohttp.ClientSession]],
|
791
973
|
):
|
974
|
+
"""
|
975
|
+
Args:
|
976
|
+
ws_url: The WebSocket URL for the Cartesia API.
|
977
|
+
api_key: The API key to use for authorization.
|
978
|
+
cartesia_version: The version of the Cartesia API to use.
|
979
|
+
timeout: The timeout for responses on the WebSocket in seconds.
|
980
|
+
get_session: A function that returns an aiohttp.ClientSession object.
|
981
|
+
"""
|
792
982
|
super().__init__(ws_url, api_key, cartesia_version)
|
983
|
+
self.timeout = timeout
|
793
984
|
self._get_session = get_session
|
794
985
|
self.websocket = None
|
986
|
+
self._context_queues: Dict[str, asyncio.Queue] = {}
|
987
|
+
self._processing_task: asyncio.Task = None
|
988
|
+
|
989
|
+
def __del__(self):
|
990
|
+
try:
|
991
|
+
loop = asyncio.get_running_loop()
|
992
|
+
except RuntimeError:
|
993
|
+
loop = None
|
994
|
+
|
995
|
+
if loop is None:
|
996
|
+
asyncio.run(self.close())
|
997
|
+
elif loop.is_running():
|
998
|
+
loop.create_task(self.close())
|
795
999
|
|
796
1000
|
async def connect(self):
|
797
1001
|
if self.websocket is None or self._is_websocket_closed():
|
798
1002
|
route = "tts/websocket"
|
799
1003
|
session = await self._get_session()
|
800
|
-
|
801
|
-
|
802
|
-
|
1004
|
+
try:
|
1005
|
+
self.websocket = await session.ws_connect(
|
1006
|
+
f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
|
1007
|
+
)
|
1008
|
+
except Exception as e:
|
1009
|
+
raise RuntimeError(f"Failed to connect to WebSocket. {e}")
|
803
1010
|
|
804
1011
|
def _is_websocket_closed(self):
|
805
1012
|
return self.websocket.closed
|
@@ -808,6 +1015,25 @@ class _AsyncWebSocket(_WebSocket):
|
|
808
1015
|
"""This method closes the websocket connection. *Highly* recommended to call this method when done."""
|
809
1016
|
if self.websocket is not None and not self._is_websocket_closed():
|
810
1017
|
await self.websocket.close()
|
1018
|
+
if self._processing_task:
|
1019
|
+
self._processing_task.cancel()
|
1020
|
+
try:
|
1021
|
+
self._processing_task = None
|
1022
|
+
except asyncio.CancelledError:
|
1023
|
+
pass
|
1024
|
+
except TypeError as e:
|
1025
|
+
# Ignore the error if the task is already cancelled
|
1026
|
+
# For some reason we are getting None responses
|
1027
|
+
# TODO: This needs to be fixed - we need to think about why we are getting None responses.
|
1028
|
+
if "Received message 256:None" not in str(e):
|
1029
|
+
raise e
|
1030
|
+
|
1031
|
+
for context_id in list(self._context_queues.keys()):
|
1032
|
+
self._remove_context(context_id)
|
1033
|
+
|
1034
|
+
self._context_queues.clear()
|
1035
|
+
self._processing_task = None
|
1036
|
+
self.websocket = None
|
811
1037
|
|
812
1038
|
async def send(
|
813
1039
|
self,
|
@@ -819,32 +1045,26 @@ class _AsyncWebSocket(_WebSocket):
|
|
819
1045
|
context_id: Optional[str] = None,
|
820
1046
|
duration: Optional[int] = None,
|
821
1047
|
language: Optional[str] = None,
|
822
|
-
stream:
|
1048
|
+
stream: bool = True,
|
823
1049
|
) -> Union[bytes, AsyncGenerator[bytes, None]]:
|
824
|
-
await self.connect()
|
825
|
-
|
826
1050
|
if context_id is None:
|
827
|
-
context_id = uuid.uuid4()
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
}
|
843
|
-
|
844
|
-
if duration is not None:
|
845
|
-
request_body["duration"] = duration
|
1051
|
+
context_id = str(uuid.uuid4())
|
1052
|
+
|
1053
|
+
ctx = self.context(context_id)
|
1054
|
+
|
1055
|
+
await ctx.send(
|
1056
|
+
model_id=model_id,
|
1057
|
+
transcript=transcript,
|
1058
|
+
output_format=output_format,
|
1059
|
+
voice_id=voice_id,
|
1060
|
+
voice_embedding=voice_embedding,
|
1061
|
+
context_id=context_id,
|
1062
|
+
duration=duration,
|
1063
|
+
language=language,
|
1064
|
+
continue_=False,
|
1065
|
+
)
|
846
1066
|
|
847
|
-
generator =
|
1067
|
+
generator = ctx.receive()
|
848
1068
|
|
849
1069
|
if stream:
|
850
1070
|
return generator
|
@@ -855,35 +1075,51 @@ class _AsyncWebSocket(_WebSocket):
|
|
855
1075
|
|
856
1076
|
return {"audio": b"".join(chunks), "context_id": context_id}
|
857
1077
|
|
858
|
-
async def
|
859
|
-
await self.websocket.send_json(request_body)
|
860
|
-
|
1078
|
+
async def _process_responses(self):
|
861
1079
|
try:
|
862
|
-
response = None
|
863
1080
|
while True:
|
864
1081
|
response = await self.websocket.receive_json()
|
865
|
-
if "
|
866
|
-
|
867
|
-
if
|
868
|
-
|
869
|
-
|
870
|
-
yield self._convert_response(response=response, include_context_id=True)
|
1082
|
+
if response["context_id"]:
|
1083
|
+
context_id = response["context_id"]
|
1084
|
+
if context_id in self._context_queues:
|
1085
|
+
await self._context_queues[context_id].put(response)
|
871
1086
|
except Exception as e:
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
1087
|
+
self._error = e
|
1088
|
+
raise e
|
1089
|
+
|
1090
|
+
async def _get_message(self, context_id: str, timeout: float) -> Dict[str, Any]:
|
1091
|
+
if context_id not in self._context_queues:
|
1092
|
+
raise ValueError(f"Context ID {context_id} not found.")
|
1093
|
+
return await asyncio.wait_for(self._context_queues[context_id].get(), timeout=timeout)
|
1094
|
+
|
1095
|
+
def _remove_context(self, context_id: str):
|
1096
|
+
if context_id in self._context_queues:
|
1097
|
+
del self._context_queues[context_id]
|
1098
|
+
|
1099
|
+
def _dispatch_listener(self):
|
1100
|
+
if self._processing_task is None or self._processing_task.done():
|
1101
|
+
self._processing_task = asyncio.create_task(self._process_responses())
|
1102
|
+
|
1103
|
+
def context(self, context_id: Optional[str] = None) -> _AsyncTTSContext:
|
1104
|
+
if context_id in self._context_queues:
|
1105
|
+
raise ValueError(f"AsyncContext for context ID {context_id} already exists.")
|
1106
|
+
if context_id is None:
|
1107
|
+
context_id = str(uuid.uuid4())
|
1108
|
+
if context_id not in self._context_queues:
|
1109
|
+
self._context_queues[context_id] = asyncio.Queue()
|
1110
|
+
return _AsyncTTSContext(context_id, self, self.timeout)
|
877
1111
|
|
878
1112
|
|
879
1113
|
class AsyncTTS(TTS):
|
880
|
-
def __init__(self, api_key, timeout, get_session):
|
881
|
-
super().__init__(api_key, timeout)
|
1114
|
+
def __init__(self, api_key, base_url, timeout, get_session):
|
1115
|
+
super().__init__(api_key, base_url, timeout)
|
882
1116
|
self._get_session = get_session
|
883
1117
|
self._sse_class = _AsyncSSE(self._http_url(), self.headers, self.timeout, get_session)
|
884
1118
|
self.sse = self._sse_class.send
|
885
1119
|
|
886
1120
|
async def websocket(self) -> _AsyncWebSocket:
|
887
|
-
ws = _AsyncWebSocket(
|
1121
|
+
ws = _AsyncWebSocket(
|
1122
|
+
self._ws_url(), self.api_key, self.cartesia_version, self.timeout, self._get_session
|
1123
|
+
)
|
888
1124
|
await ws.connect()
|
889
1125
|
return ws
|
cartesia/utils/__init__.py
ADDED
File without changes
|
cartesia/utils/deprecated.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
import os
|
2
|
+
import warnings
|
3
|
+
from typing import Any, Callable, TypeVar
|
4
|
+
|
5
|
+
TCallable = TypeVar("TCallable", bound=Callable[..., Any])
|
6
|
+
|
7
|
+
# List of statistics of deprecated functions.
|
8
|
+
# This should only be used by the test suite to find any deprecated functions
|
9
|
+
# that should be removed for this version.
|
10
|
+
_TRACK_DEPRECATED_FUNCTION_STATS = os.environ.get("CARTESIA_TEST_DEPRECATED", "").lower() == "true"
|
11
|
+
_DEPRECATED_FUNCTION_STATS = []
|
12
|
+
|
13
|
+
|
14
|
+
def deprecated(
|
15
|
+
reason=None, vdeprecated=None, vremove=None, replacement=None
|
16
|
+
) -> Callable[[TCallable], TCallable]:
|
17
|
+
local_vars = locals()
|
18
|
+
|
19
|
+
def fn(func: TCallable) -> TCallable:
|
20
|
+
if isinstance(func, classmethod):
|
21
|
+
func = func.__func__
|
22
|
+
msg = _get_deprecated_msg(func, reason, vdeprecated, vremove, replacement)
|
23
|
+
warnings.warn(msg, DeprecationWarning)
|
24
|
+
return func
|
25
|
+
|
26
|
+
if _TRACK_DEPRECATED_FUNCTION_STATS: # pragma: no cover
|
27
|
+
_DEPRECATED_FUNCTION_STATS.append(local_vars)
|
28
|
+
|
29
|
+
return fn
|
30
|
+
|
31
|
+
|
32
|
+
def _get_deprecated_msg(wrapped, reason, vdeprecated, vremoved, replacement=None):
|
33
|
+
fmt = "{name} is deprecated"
|
34
|
+
if vdeprecated:
|
35
|
+
fmt += " since v{vdeprecated}"
|
36
|
+
if vremoved:
|
37
|
+
fmt += " and will be removed in v{vremoved}"
|
38
|
+
fmt += "."
|
39
|
+
|
40
|
+
if reason:
|
41
|
+
fmt += " ({reason})"
|
42
|
+
if replacement:
|
43
|
+
fmt += " -- Use {replacement} instead."
|
44
|
+
|
45
|
+
return fmt.format(
|
46
|
+
name=wrapped.__name__,
|
47
|
+
reason=reason or "",
|
48
|
+
vdeprecated=vdeprecated or "",
|
49
|
+
vremoved=vremoved or "",
|
50
|
+
replacement=replacement or "",
|
51
|
+
)
|
52
|
+
|
53
|
+
|
54
|
+
# This method is taken from the following source:
|
55
|
+
# https://github.com/ad12/meddlr/blob/main/meddlr/utils/deprecated.py
|
cartesia/utils/retry.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
from aiohttp.client_exceptions import ServerDisconnectedError
|
4
|
+
import asyncio
|
5
|
+
from functools import wraps
|
6
|
+
from http.client import RemoteDisconnected
|
7
|
+
from httpx import TimeoutException
|
8
|
+
from requests.exceptions import ConnectionError
|
9
|
+
|
10
|
+
|
11
|
+
def retry_on_connection_error(max_retries=3, backoff_factor=1, logger=None):
|
12
|
+
"""Retry a function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
|
13
|
+
|
14
|
+
Args:
|
15
|
+
max_retries (int): The maximum number of retries.
|
16
|
+
backoff_factor (int): The factor to increase the delay between retries.
|
17
|
+
logger (logging.Logger): The logger to use for logging.
|
18
|
+
"""
|
19
|
+
|
20
|
+
def decorator(func):
|
21
|
+
@wraps(func)
|
22
|
+
def wrapper(*args, **kwargs):
|
23
|
+
retry_count = 0
|
24
|
+
while retry_count < max_retries:
|
25
|
+
try:
|
26
|
+
return func(*args, **kwargs)
|
27
|
+
except (
|
28
|
+
ConnectionError,
|
29
|
+
RemoteDisconnected,
|
30
|
+
ServerDisconnectedError,
|
31
|
+
TimeoutException,
|
32
|
+
) as e:
|
33
|
+
logger.info(f"Retrying after exception: {e}")
|
34
|
+
retry_count += 1
|
35
|
+
if retry_count < max_retries:
|
36
|
+
delay = backoff_factor * (2 ** (retry_count - 1))
|
37
|
+
logger.warn(
|
38
|
+
f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
|
39
|
+
)
|
40
|
+
time.sleep(delay)
|
41
|
+
else:
|
42
|
+
raise Exception(f"Exception occurred after {max_retries} tries.") from e
|
43
|
+
|
44
|
+
return wrapper
|
45
|
+
|
46
|
+
return decorator
|
47
|
+
|
48
|
+
|
49
|
+
def retry_on_connection_error_async(max_retries=3, backoff_factor=1, logger=None):
|
50
|
+
"""Retry an asynchronous function if a ConnectionError, RemoteDisconnected, ServerDisconnectedError, or TimeoutException occurs.
|
51
|
+
|
52
|
+
Args:
|
53
|
+
max_retries (int): The maximum number of retries.
|
54
|
+
backoff_factor (int): The factor to increase the delay between retries.
|
55
|
+
logger (logging.Logger): The logger to use for logging.
|
56
|
+
"""
|
57
|
+
|
58
|
+
def decorator(func):
|
59
|
+
@wraps(func)
|
60
|
+
async def wrapper(*args, **kwargs):
|
61
|
+
retry_count = 0
|
62
|
+
while retry_count < max_retries:
|
63
|
+
try:
|
64
|
+
async for chunk in func(*args, **kwargs):
|
65
|
+
yield chunk
|
66
|
+
# If the function completes without raising an exception return
|
67
|
+
return
|
68
|
+
except (
|
69
|
+
ConnectionError,
|
70
|
+
RemoteDisconnected,
|
71
|
+
ServerDisconnectedError,
|
72
|
+
TimeoutException,
|
73
|
+
) as e:
|
74
|
+
logger.info(f"Retrying after exception: {e}")
|
75
|
+
retry_count += 1
|
76
|
+
if retry_count < max_retries:
|
77
|
+
delay = backoff_factor * (2 ** (retry_count - 1))
|
78
|
+
logger.warn(
|
79
|
+
f"Attempt {retry_count + 1}/{max_retries} in {delay} seconds..."
|
80
|
+
)
|
81
|
+
await asyncio.sleep(delay)
|
82
|
+
else:
|
83
|
+
raise Exception(f"Exception occurred after {max_retries} tries.") from e
|
84
|
+
|
85
|
+
return wrapper
|
86
|
+
|
87
|
+
return decorator
|
cartesia/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.0.
|
1
|
+
__version__ = "1.0.4"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: cartesia
|
3
|
-
Version: 1.0.2
|
3
|
+
Version: 1.0.4
|
4
4
|
Summary: The official Python library for the Cartesia API.
|
5
5
|
Home-page:
|
6
6
|
Author: Cartesia, Inc.
|
@@ -244,6 +244,101 @@ p.terminate()
|
|
244
244
|
ws.close() # Close the websocket connection
|
245
245
|
```
|
246
246
|
|
247
|
+
#### Conditioning speech on previous generations using WebSocket
|
248
|
+
|
249
|
+
In some cases, input text may need to be streamed in. In these cases, it would be slow to wait for all the text to buffer before sending it to Cartesia's TTS service.
|
250
|
+
|
251
|
+
To mitigate this, Cartesia offers audio continuations. In this setting, users can send input text, as it becomes available, over a websocket connection.
|
252
|
+
|
253
|
+
To do this, create a `context` and send multiple requests over it without awaiting the responses. You can then listen to the responses in the order the requests were sent.
|
254
|
+
|
255
|
+
Each `context` will be closed automatically after 5 seconds of inactivity or when the `no_more_inputs` method is called. `no_more_inputs` sends a request with `continue_=False`, which indicates that no more inputs will be sent over this context.
|
256
|
+
|
257
|
+
```python
|
258
|
+
import asyncio
|
259
|
+
import os
|
260
|
+
import pyaudio
|
261
|
+
from cartesia import AsyncCartesia
|
262
|
+
|
263
|
+
async def send_transcripts(ctx):
    """Send each transcript fragment over `ctx`, then mark the context's input as finished."""
    # Settings shared by every request on this context:
    # - Voice IDs: call `client.voices.list()` or browse [play.cartesia.ai](https://play.cartesia.ai/)
    # - Models: [docs.cartesia.ai](https://docs.cartesia.ai/getting-started/available-models)
    # - Output formats: [API Reference](https://docs.cartesia.ai/api-reference/endpoints/stream-speech-server-sent-events)
    shared_request = {
        "model_id": "sonic-english",
        "voice_id": "87748186-23bb-4158-a1eb-332911b0b708",
        "output_format": {
            "container": "raw",
            "encoding": "pcm_f32le",
            "sample_rate": 44100,
        },
        # continue_=True signals that more input may follow on this context.
        "continue_": True,
    }

    fragments = (
        "Sonic and Yoshi team up in a dimension-hopping adventure! ",
        "Racing through twisting zones, they dodge Eggman's badniks and solve ancient puzzles. ",
        "In the Echoing Caverns, they find the Harmonic Crystal, unlocking new powers. ",
        "Sonic's speed creates sound waves, while Yoshi's eggs become sonic bolts. ",
        "As they near Eggman's lair, our heroes charge their abilities for an epic boss battle. ",
        "Get ready to spin, jump, and sound-blast your way to victory in this high-octane crossover!",
    )

    # Push each piece of text as soon as it is available.
    for fragment in fragments:
        await ctx.send(transcript=fragment, **shared_request)

    # Without this call, the context closes only after 5 seconds of inactivity.
    await ctx.no_more_inputs()
|
298
|
+
|
299
|
+
async def receive_and_play_audio(ctx):
    """Play every audio chunk received on `ctx` through the default output device.

    The output stream is opened lazily on the first chunk. PyAudio resources are
    released even if receiving fails or no audio arrives at all.
    """
    p = pyaudio.PyAudio()
    stream = None
    # Must match `sample_rate` in the `output_format` the requests were sent with;
    # `paFloat32` below is chosen to match the `pcm_f32le` encoding.
    rate = 44100

    try:
        async for output in ctx.receive():
            buffer = output["audio"]

            if stream is None:
                stream = p.open(
                    format=pyaudio.paFloat32,
                    channels=1,
                    rate=rate,
                    output=True
                )

            # NOTE(review): this write blocks the event loop until the chunk is played.
            stream.write(buffer)
    finally:
        # `stream` is still None if no audio was received; only close it if it was opened.
        if stream is not None:
            stream.stop_stream()
            stream.close()
        p.terminate()
|
320
|
+
|
321
|
+
async def stream_and_listen():
    """Open a websocket context, stream transcripts into it, and play the audio back.

    The websocket and the client are closed even if sending or playback raises.
    """
    client = AsyncCartesia(api_key=os.environ.get("CARTESIA_API_KEY"))

    try:
        # Set up the websocket connection
        ws = await client.tts.websocket()
        try:
            # Create a context to send and receive audio
            ctx = ws.context()  # Generates a random context ID if not provided

            send_task = asyncio.create_task(send_transcripts(ctx))
            listen_task = asyncio.create_task(receive_and_play_audio(ctx))

            # Run the sender and the listener concurrently
            await asyncio.gather(send_task, listen_task)
        finally:
            await ws.close()
    finally:
        await client.close()
|
338
|
+
|
339
|
+
# Only start streaming when run as a script, not when this module is imported.
if __name__ == "__main__":
    asyncio.run(stream_and_listen())
|
340
|
+
```
|
341
|
+
|
247
342
|
### Multilingual Text-to-Speech [Alpha]
|
248
343
|
|
249
344
|
You can use our `sonic-multilingual` model to generate audio in multiple languages. The languages supported are available at [docs.cartesia.ai](https://docs.cartesia.ai/getting-started/available-models).
|
@@ -0,0 +1,11 @@
|
|
1
|
+
cartesia/__init__.py,sha256=jMIf2O7dTGxvTA5AfXtmh1H_EGfMtQseR5wXrjNRbLs,93
|
2
|
+
cartesia/_types.py,sha256=tO3Nef_V78TDMKDuIv_wsQLkxoSvYG4bdzFkMGXUFho,3765
|
3
|
+
cartesia/client.py,sha256=UCNTAU8eVzb-o-bygxfQQXWTDov_FX8dbAQdn7a8Hr0,41458
|
4
|
+
cartesia/version.py,sha256=acuR_XSJzp4OrQ5T8-Ac5gYe48mUwObuwjRmisFmZ7k,22
|
5
|
+
cartesia/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
cartesia/utils/deprecated.py,sha256=2cXvGtrxhPeUZA5LWy2n_U5OFLDv7SHeFtzqhjSJGyk,1674
|
7
|
+
cartesia/utils/retry.py,sha256=nuwWRfu3MOVTxIQMLjYf6WLaxSlnu_GdE3QjTV0zisQ,3339
|
8
|
+
cartesia-1.0.4.dist-info/METADATA,sha256=N7NoGr6XBtmLI6EHsG3efw0QNJ7uhV_E9HV8uqTYfQM,15991
|
9
|
+
cartesia-1.0.4.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
|
10
|
+
cartesia-1.0.4.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
|
11
|
+
cartesia-1.0.4.dist-info/RECORD,,
|
cartesia-1.0.2.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
cartesia/__init__.py,sha256=jMIf2O7dTGxvTA5AfXtmh1H_EGfMtQseR5wXrjNRbLs,93
|
2
|
-
cartesia/_types.py,sha256=msXRqNwVx_mbcLIQgRJYEl8U-hO9LRPWmscnX89cBCY,3747
|
3
|
-
cartesia/client.py,sha256=jMlFDPRtKVDelqevHlv7YZJgOES3ws9BFN_6uUyN0W8,32720
|
4
|
-
cartesia/version.py,sha256=Y3LSfRioSl2xch70pq_ULlvyECXyEtN3krVaWeGyaxk,22
|
5
|
-
cartesia-1.0.2.dist-info/METADATA,sha256=LPW7f4297S2DQ_uRtc3A-7-5WUCXsIgGcr1M3XelXys,12394
|
6
|
-
cartesia-1.0.2.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
|
7
|
-
cartesia-1.0.2.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
|
8
|
-
cartesia-1.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|