cartesia 1.0.3__py2.py3-none-any.whl → 1.0.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cartesia/_types.py CHANGED
@@ -45,15 +45,16 @@ class DeprecatedOutputFormatMapping:
45
45
  "mulaw_8000": {"container": "raw", "encoding": "pcm_mulaw", "sample_rate": 8000},
46
46
  "alaw_8000": {"container": "raw", "encoding": "pcm_alaw", "sample_rate": 8000},
47
47
  }
48
-
48
+
49
+ @classmethod
49
50
  @deprecated(
50
51
  vdeprecated="1.0.1",
51
52
  vremove="1.2.0",
52
53
  reason="Old output format names are being deprecated in favor of names aligned with the Cartesia API. Use names from `OutputFormatMapping` instead.",
53
54
  )
54
- def get_format_deprecated(self, format_name):
55
- if format_name in self._format_mapping:
56
- return self._format_mapping[format_name]
55
+ def get_format_deprecated(cls, format_name):
56
+ if format_name in cls._format_mapping:
57
+ return cls._format_mapping[format_name]
57
58
  else:
58
59
  raise ValueError(f"Unsupported format: {format_name}")
59
60
 
cartesia/client.py CHANGED
@@ -4,7 +4,17 @@ import json
4
4
  import os
5
5
  import uuid
6
6
  from types import TracebackType
7
- from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union, Callable
7
+ from typing import (
8
+ Any,
9
+ AsyncGenerator,
10
+ Dict,
11
+ Generator,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ Union,
16
+ Callable,
17
+ )
8
18
 
9
19
  import aiohttp
10
20
  import httpx
@@ -13,7 +23,6 @@ import requests
13
23
  from websockets.sync.client import connect
14
24
 
15
25
  from cartesia.utils.retry import retry_on_connection_error, retry_on_connection_error_async
16
- from cartesia.utils.deprecated import deprecated
17
26
  from cartesia._types import (
18
27
  OutputFormat,
19
28
  OutputFormatMapping,
@@ -36,22 +45,34 @@ logger = logging.getLogger(__name__)
36
45
 
37
46
 
38
47
  class BaseClient:
39
- def __init__(self, *, api_key: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT):
48
+ def __init__(
49
+ self,
50
+ *,
51
+ api_key: Optional[str] = None,
52
+ base_url: Optional[str] = None,
53
+ timeout: float = DEFAULT_TIMEOUT,
54
+ ):
40
55
  """Constructor for the BaseClient. Used by the Cartesia and AsyncCartesia clients."""
41
56
  self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
57
+ self._base_url = base_url or os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
42
58
  self.timeout = timeout
43
59
 
60
+ @property
61
+ def base_url(self):
62
+ return self._base_url
63
+
44
64
 
45
65
  class Resource:
46
66
  def __init__(
47
67
  self,
48
68
  api_key: str,
69
+ base_url: str,
49
70
  timeout: float,
50
71
  ):
51
72
  """Constructor for the Resource class. Used by the Voices and TTS classes."""
52
73
  self.api_key = api_key
53
74
  self.timeout = timeout
54
- self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
75
+ self._base_url = base_url
55
76
  self.cartesia_version = DEFAULT_CARTESIA_VERSION
56
77
  self.headers = {
57
78
  "X-API-Key": self.api_key,
@@ -59,25 +80,29 @@ class Resource:
59
80
  "Content-Type": "application/json",
60
81
  }
61
82
 
83
+ @property
84
+ def base_url(self):
85
+ return self._base_url
86
+
62
87
  def _http_url(self):
63
88
  """Returns the HTTP URL for the Cartesia API.
64
89
  If the base URL is localhost, the URL will start with 'http'. Otherwise, it will start with 'https'.
65
90
  """
66
- if self.base_url.startswith("http://") or self.base_url.startswith("https://"):
67
- return self.base_url
91
+ if self._base_url.startswith("http://") or self._base_url.startswith("https://"):
92
+ return self._base_url
68
93
  else:
69
- prefix = "http" if "localhost" in self.base_url else "https"
70
- return f"{prefix}://{self.base_url}"
94
+ prefix = "http" if "localhost" in self._base_url else "https"
95
+ return f"{prefix}://{self._base_url}"
71
96
 
72
97
  def _ws_url(self):
73
98
  """Returns the WebSocket URL for the Cartesia API.
74
99
  If the base URL is localhost, the URL will start with 'ws'. Otherwise, it will start with 'wss'.
75
100
  """
76
- if self.base_url.startswith("ws://") or self.base_url.startswith("wss://"):
77
- return self.base_url
101
+ if self._base_url.startswith("ws://") or self._base_url.startswith("wss://"):
102
+ return self._base_url
78
103
  else:
79
- prefix = "ws" if "localhost" in self.base_url else "wss"
80
- return f"{prefix}://{self.base_url}"
104
+ prefix = "ws" if "localhost" in self._base_url else "wss"
105
+ return f"{prefix}://{self._base_url}"
81
106
 
82
107
 
83
108
  class Cartesia(BaseClient):
@@ -90,18 +115,27 @@ class Cartesia(BaseClient):
90
115
  The client supports generating audio using both Server-Sent Events and WebSocket for lower latency.
91
116
  """
92
117
 
93
- def __init__(self, *, api_key: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT):
118
+ def __init__(
119
+ self,
120
+ *,
121
+ api_key: Optional[str] = None,
122
+ base_url: Optional[str] = None,
123
+ timeout: float = DEFAULT_TIMEOUT,
124
+ ):
94
125
  """Constructor for the Cartesia client.
95
126
 
96
127
  Args:
97
128
  api_key: The API key to use for authorization.
98
129
  If not specified, the API key will be read from the environment variable
99
130
  `CARTESIA_API_KEY`.
100
- timeout: The timeout for the HTTP requests in seconds. Defaults to 30 seconds.
131
+ base_url: The base URL for the Cartesia API.
132
+ If not specified, the base URL will be read from the enviroment variable
133
+ `CARTESIA_BASE_URL`. Defaults to `api.cartesia.ai`.
134
+ timeout: The timeout for HTTP and WebSocket requests in seconds. Defaults to 30 seconds.
101
135
  """
102
- super().__init__(api_key=api_key, timeout=timeout)
103
- self.voices = Voices(api_key=self.api_key, timeout=self.timeout)
104
- self.tts = TTS(api_key=self.api_key, timeout=self.timeout)
136
+ super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
137
+ self.voices = Voices(api_key=self.api_key, base_url=self._base_url, timeout=self.timeout)
138
+ self.tts = TTS(api_key=self.api_key, base_url=self._base_url, timeout=self.timeout)
105
139
 
106
140
  def __enter__(self):
107
141
  return self
@@ -188,7 +222,6 @@ class Voices(Resource):
188
222
  files = {"clip": file}
189
223
  headers = self.headers.copy()
190
224
  headers.pop("Content-Type", None)
191
- headers["Content-Type"] = "multipart/form-data"
192
225
  response = httpx.post(url, headers=headers, files=files, timeout=self.timeout)
193
226
  if not response.is_success:
194
227
  raise ValueError(f"Failed to clone voice from clip. Error: {response.text}")
@@ -233,8 +266,9 @@ class _WebSocket:
233
266
  Usage:
234
267
  >>> ws = client.tts.websocket()
235
268
  >>> for audio_chunk in ws.send(
236
- ... model_id="upbeat-moon", transcript="Hello world!", voice_embedding=embedding,
237
- ... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, stream=True
269
+ ... model_id="sonic-english", transcript="Hello world!", voice_embedding=embedding,
270
+ ... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
271
+ ... context_id=context_id, stream=True
238
272
  ... ):
239
273
  ... audio = audio_chunk["audio"]
240
274
  """
@@ -251,12 +285,19 @@ class _WebSocket:
251
285
  self.websocket = None
252
286
 
253
287
  def connect(self):
254
- """This method connects to the WebSocket if it is not already connected."""
288
+ """This method connects to the WebSocket if it is not already connected.
289
+
290
+ Raises:
291
+ RuntimeError: If the connection to the WebSocket fails.
292
+ """
255
293
  if self.websocket is None or self._is_websocket_closed():
256
294
  route = "tts/websocket"
257
- self.websocket = connect(
258
- f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
259
- )
295
+ try:
296
+ self.websocket = connect(
297
+ f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
298
+ )
299
+ except Exception as e:
300
+ raise RuntimeError(f"Failed to connect to WebSocket. {e}")
260
301
 
261
302
  def _is_websocket_closed(self):
262
303
  return self.websocket.socket.fileno() == -1
@@ -329,7 +370,7 @@ class _WebSocket:
329
370
  context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
330
371
  duration: The duration of the audio in seconds.
331
372
  language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
332
- stream: Whether to stream the audio or not. (Default is True)
373
+ stream: Whether to stream the audio or not.
333
374
 
334
375
  Returns:
335
376
  If `stream` is True, the method returns a generator that yields chunks. Each chunk is a dictionary.
@@ -341,7 +382,7 @@ class _WebSocket:
341
382
  self.connect()
342
383
 
343
384
  if context_id is None:
344
- context_id = uuid.uuid4().hex
385
+ context_id = str(uuid.uuid4())
345
386
 
346
387
  voice = self._validate_and_construct_voice(voice_id, voice_embedding)
347
388
 
@@ -395,7 +436,7 @@ class _SSE:
395
436
 
396
437
  Usage:
397
438
  >>> for audio_chunk in client.tts.sse(
398
- ... model_id="upbeat-moon", transcript="Hello world!", voice_embedding=embedding,
439
+ ... model_id="sonic-english", transcript="Hello world!", voice_embedding=embedding,
399
440
  ... output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100}, stream=True
400
441
  ... ):
401
442
  ... audio = audio_chunk["audio"]
@@ -523,8 +564,7 @@ class _SSE:
523
564
  for chunk in self._sse_generator(request_body):
524
565
  yield chunk
525
566
  except Exception as e:
526
- logger.error(f"Failed to generate audio. {e}")
527
- raise e
567
+ raise RuntimeError(f"Error generating audio. {e}")
528
568
 
529
569
  def _sse_generator(self, request_body: Dict[str, Any]):
530
570
  response = requests.post(
@@ -555,9 +595,10 @@ class _SSE:
555
595
  class TTS(Resource):
556
596
  """This resource contains methods to generate audio using Cartesia's text-to-speech API."""
557
597
 
558
- def __init__(self, api_key, timeout):
598
+ def __init__(self, api_key: str, base_url: str, timeout: float):
559
599
  super().__init__(
560
600
  api_key=api_key,
601
+ base_url=base_url,
561
602
  timeout=timeout,
562
603
  )
563
604
  self._sse_class = _SSE(self._http_url(), self.headers, self.timeout)
@@ -573,7 +614,8 @@ class TTS(Resource):
573
614
  ws.connect()
574
615
  return ws
575
616
 
576
- def get_output_format(self, output_format_name: str) -> OutputFormat:
617
+ @staticmethod
618
+ def get_output_format(output_format_name: str) -> OutputFormat:
577
619
  """Convenience method to get the output_format dictionary from a given output format name.
578
620
 
579
621
  Args:
@@ -631,22 +673,27 @@ class AsyncCartesia(Cartesia):
631
673
  self,
632
674
  *,
633
675
  api_key: Optional[str] = None,
676
+ base_url: Optional[str] = None,
634
677
  timeout: float = DEFAULT_TIMEOUT,
635
678
  max_num_connections: int = DEFAULT_NUM_CONNECTIONS,
636
679
  ):
637
680
  """
638
681
  Args:
639
682
  api_key: See :class:`Cartesia`.
683
+ base_url: See :class:`Cartesia`.
640
684
  timeout: See :class:`Cartesia`.
641
685
  max_num_connections: The maximum number of concurrent connections to use for the client.
642
686
  This is used to limit the number of connections that can be made to the server.
643
687
  """
644
688
  self._session = None
645
689
  self._loop = None
646
- super().__init__(api_key=api_key, timeout=timeout)
690
+ super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
647
691
  self.max_num_connections = max_num_connections
648
692
  self.tts = AsyncTTS(
649
- api_key=self.api_key, timeout=self.timeout, get_session=self._get_session
693
+ api_key=self.api_key,
694
+ base_url=self._base_url,
695
+ timeout=self.timeout,
696
+ get_session=self._get_session,
650
697
  )
651
698
 
652
699
  async def _get_session(self):
@@ -677,7 +724,7 @@ class AsyncCartesia(Cartesia):
677
724
 
678
725
  if loop is None:
679
726
  asyncio.run(self.close())
680
- else:
727
+ elif loop.is_running():
681
728
  loop.create_task(self.close())
682
729
 
683
730
  async def __aenter__(self):
@@ -753,8 +800,7 @@ class _AsyncSSE(_SSE):
753
800
  async for chunk in self._sse_generator(request_body):
754
801
  yield chunk
755
802
  except Exception as e:
756
- logger.error(f"Failed to generate audio. {e}")
757
- raise e
803
+ raise RuntimeError(f"Error generating audio. {e}")
758
804
 
759
805
  async def _sse_generator(self, request_body: Dict[str, Any]):
760
806
  session = await self._get_session()
@@ -779,6 +825,141 @@ class _AsyncSSE(_SSE):
779
825
  pass
780
826
 
781
827
 
828
+ class _AsyncTTSContext:
829
+ """Manage a single context over a WebSocket.
830
+
831
+ This class separates sending requests and receiving responses into two separate methods.
832
+ This can be used for sending multiple requests without awaiting the response.
833
+ Then you can listen to the responses in the order they were sent. See README for usage.
834
+
835
+ Each AsyncTTSContext will close automatically when a done message is received for that context.
836
+ This happens when the no_more_inputs method is called (equivalent to sending a request with `continue_ = False`),
837
+ or if no requests have been sent for 5 seconds on the same context. It also closes if there is an error.
838
+
839
+ """
840
+
841
+ def __init__(self, context_id: str, websocket: "_AsyncWebSocket", timeout: float):
842
+ self._context_id = context_id
843
+ self._websocket = websocket
844
+ self.timeout = timeout
845
+ self._error = None
846
+
847
+ @property
848
+ def context_id(self) -> str:
849
+ return self._context_id
850
+
851
+ async def send(
852
+ self,
853
+ model_id: str,
854
+ transcript: str,
855
+ output_format: OutputFormat,
856
+ voice_id: Optional[str] = None,
857
+ voice_embedding: Optional[List[float]] = None,
858
+ context_id: Optional[str] = None,
859
+ continue_: bool = False,
860
+ duration: Optional[int] = None,
861
+ language: Optional[str] = None,
862
+ ) -> None:
863
+ """Send audio generation requests to the WebSocket. The response can be received using the `receive` method.
864
+
865
+ Args:
866
+ model_id: The ID of the model to use for generating audio.
867
+ transcript: The text to convert to speech.
868
+ output_format: A dictionary containing the details of the output format.
869
+ voice_id: The ID of the voice to use for generating audio.
870
+ voice_embedding: The embedding of the voice to use for generating audio.
871
+ context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
872
+ continue_: Whether to continue the audio generation from the previous transcript or not.
873
+ duration: The duration of the audio in seconds.
874
+ language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`
875
+
876
+ Returns:
877
+ None.
878
+ """
879
+ if context_id is not None and context_id != self._context_id:
880
+ raise ValueError("Context ID does not match the context ID of the current context.")
881
+ if continue_ and transcript == "":
882
+ raise ValueError("Transcript cannot be empty when continue_ is True.")
883
+
884
+ await self._websocket.connect()
885
+
886
+ voice = self._websocket._validate_and_construct_voice(voice_id, voice_embedding)
887
+
888
+ request_body = {
889
+ "model_id": model_id,
890
+ "transcript": transcript,
891
+ "voice": voice,
892
+ "output_format": {
893
+ "container": output_format["container"],
894
+ "encoding": output_format["encoding"],
895
+ "sample_rate": output_format["sample_rate"],
896
+ },
897
+ "context_id": self._context_id,
898
+ "continue": continue_,
899
+ "language": language,
900
+ }
901
+
902
+ if duration is not None:
903
+ request_body["duration"] = duration
904
+
905
+ await self._websocket.websocket.send_json(request_body)
906
+
907
+ # Start listening for responses on the WebSocket
908
+ self._websocket._dispatch_listener()
909
+
910
+ async def no_more_inputs(self) -> None:
911
+ """Send a request to the WebSocket to indicate that no more requests will be sent."""
912
+ await self.send(
913
+ model_id=DEFAULT_MODEL_ID,
914
+ transcript="",
915
+ output_format=TTS.get_output_format("raw_pcm_f32le_44100"),
916
+ voice_id="a0e99841-438c-4a64-b679-ae501e7d6091", # Default voice ID since it's a required input for now
917
+ context_id=self._context_id,
918
+ continue_=False,
919
+ )
920
+
921
+ async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:
922
+ """Receive the audio chunks from the WebSocket. This method is a generator that yields audio chunks.
923
+
924
+ Returns:
925
+ An async generator that yields audio chunks. Each chunk is a dictionary containing the audio as bytes.
926
+ """
927
+ try:
928
+ while True:
929
+ response = await self._websocket._get_message(
930
+ self._context_id, timeout=self.timeout
931
+ )
932
+ if "error" in response:
933
+ raise RuntimeError(f"Error generating audio:\n{response['error']}")
934
+ if response["done"]:
935
+ break
936
+ yield self._websocket._convert_response(response, include_context_id=True)
937
+ except Exception as e:
938
+ if isinstance(e, asyncio.TimeoutError):
939
+ raise RuntimeError("Timeout while waiting for audio chunk")
940
+ raise RuntimeError(f"Failed to generate audio:\n{e}")
941
+ finally:
942
+ self._close()
943
+
944
+ def _close(self) -> None:
945
+ """Closes the context. Automatically called when a done message is received for this context."""
946
+ self._websocket._remove_context(self._context_id)
947
+
948
+ async def __aenter__(self):
949
+ return self
950
+
951
+ async def __aexit__(
952
+ self,
953
+ exc_type: Union[type, None],
954
+ exc: Union[BaseException, None],
955
+ exc_tb: Union[TracebackType, None],
956
+ ):
957
+ self._close()
958
+
959
+ def __del__(self):
960
+ self._close()
961
+
962
+
782
963
  class _AsyncWebSocket(_WebSocket):
783
964
  """This class contains methods to generate audio using WebSocket asynchronously."""
784
965
 
@@ -787,19 +968,45 @@ class _AsyncWebSocket(_WebSocket):
787
968
  ws_url: str,
788
969
  api_key: str,
789
970
  cartesia_version: str,
971
+ timeout: float,
790
972
  get_session: Callable[[], Optional[aiohttp.ClientSession]],
791
973
  ):
974
+ """
975
+ Args:
976
+ ws_url: The WebSocket URL for the Cartesia API.
977
+ api_key: The API key to use for authorization.
978
+ cartesia_version: The version of the Cartesia API to use.
979
+ timeout: The timeout for responses on the WebSocket in seconds.
980
+ get_session: A function that returns an aiohttp.ClientSession object.
981
+ """
792
982
  super().__init__(ws_url, api_key, cartesia_version)
983
+ self.timeout = timeout
793
984
  self._get_session = get_session
794
985
  self.websocket = None
986
+ self._context_queues: Dict[str, asyncio.Queue] = {}
987
+ self._processing_task: asyncio.Task = None
988
+
989
+ def __del__(self):
990
+ try:
991
+ loop = asyncio.get_running_loop()
992
+ except RuntimeError:
993
+ loop = None
994
+
995
+ if loop is None:
996
+ asyncio.run(self.close())
997
+ elif loop.is_running():
998
+ loop.create_task(self.close())
795
999
 
796
1000
  async def connect(self):
797
1001
  if self.websocket is None or self._is_websocket_closed():
798
1002
  route = "tts/websocket"
799
1003
  session = await self._get_session()
800
- self.websocket = await session.ws_connect(
801
- f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
802
- )
1004
+ try:
1005
+ self.websocket = await session.ws_connect(
1006
+ f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
1007
+ )
1008
+ except Exception as e:
1009
+ raise RuntimeError(f"Failed to connect to WebSocket. {e}")
803
1010
 
804
1011
  def _is_websocket_closed(self):
805
1012
  return self.websocket.closed
@@ -808,6 +1015,25 @@ class _AsyncWebSocket(_WebSocket):
808
1015
  """This method closes the websocket connection. *Highly* recommended to call this method when done."""
809
1016
  if self.websocket is not None and not self._is_websocket_closed():
810
1017
  await self.websocket.close()
1018
+ if self._processing_task:
1019
+ self._processing_task.cancel()
1020
+ try:
1021
+ self._processing_task = None
1022
+ except asyncio.CancelledError:
1023
+ pass
1024
+ except TypeError as e:
1025
+ # Ignore the error if the task is already cancelled
1026
+ # For some reason we are getting None responses
1027
+ # TODO: This needs to be fixed - we need to think about why we are getting None responses.
1028
+ if "Received message 256:None" not in str(e):
1029
+ raise e
1030
+
1031
+ for context_id in list(self._context_queues.keys()):
1032
+ self._remove_context(context_id)
1033
+
1034
+ self._context_queues.clear()
1035
+ self._processing_task = None
1036
+ self.websocket = None
811
1037
 
812
1038
  async def send(
813
1039
  self,
@@ -819,32 +1045,26 @@ class _AsyncWebSocket(_WebSocket):
819
1045
  context_id: Optional[str] = None,
820
1046
  duration: Optional[int] = None,
821
1047
  language: Optional[str] = None,
822
- stream: Optional[bool] = True,
1048
+ stream: bool = True,
823
1049
  ) -> Union[bytes, AsyncGenerator[bytes, None]]:
824
- await self.connect()
825
-
826
1050
  if context_id is None:
827
- context_id = uuid.uuid4().hex
828
-
829
- voice = self._validate_and_construct_voice(voice_id, voice_embedding)
830
-
831
- request_body = {
832
- "model_id": model_id,
833
- "transcript": transcript,
834
- "voice": voice,
835
- "output_format": {
836
- "container": output_format["container"],
837
- "encoding": output_format["encoding"],
838
- "sample_rate": output_format["sample_rate"],
839
- },
840
- "context_id": context_id,
841
- "language": language,
842
- }
843
-
844
- if duration is not None:
845
- request_body["duration"] = duration
1051
+ context_id = str(uuid.uuid4())
1052
+
1053
+ ctx = self.context(context_id)
1054
+
1055
+ await ctx.send(
1056
+ model_id=model_id,
1057
+ transcript=transcript,
1058
+ output_format=output_format,
1059
+ voice_id=voice_id,
1060
+ voice_embedding=voice_embedding,
1061
+ context_id=context_id,
1062
+ duration=duration,
1063
+ language=language,
1064
+ continue_=False,
1065
+ )
846
1066
 
847
- generator = self._websocket_generator(request_body)
1067
+ generator = ctx.receive()
848
1068
 
849
1069
  if stream:
850
1070
  return generator
@@ -855,35 +1075,51 @@ class _AsyncWebSocket(_WebSocket):
855
1075
 
856
1076
  return {"audio": b"".join(chunks), "context_id": context_id}
857
1077
 
858
- async def _websocket_generator(self, request_body: Dict[str, Any]):
859
- await self.websocket.send_json(request_body)
860
-
1078
+ async def _process_responses(self):
861
1079
  try:
862
- response = None
863
1080
  while True:
864
1081
  response = await self.websocket.receive_json()
865
- if "error" in response:
866
- raise RuntimeError(f"Error generating audio:\n{response['error']}")
867
- if response["done"]:
868
- break
869
-
870
- yield self._convert_response(response=response, include_context_id=True)
1082
+ if response["context_id"]:
1083
+ context_id = response["context_id"]
1084
+ if context_id in self._context_queues:
1085
+ await self._context_queues[context_id].put(response)
871
1086
  except Exception as e:
872
- # Close the websocket connection if an error occurs.
873
- if self.websocket and not self._is_websocket_closed():
874
- await self.websocket.close()
875
- error_msg_end = "" if response is None else f": {await response.text()}"
876
- raise RuntimeError(f"Failed to generate audio. {error_msg_end}") from e
1087
+ self._error = e
1088
+ raise e
1089
+
1090
+ async def _get_message(self, context_id: str, timeout: float) -> Dict[str, Any]:
1091
+ if context_id not in self._context_queues:
1092
+ raise ValueError(f"Context ID {context_id} not found.")
1093
+ return await asyncio.wait_for(self._context_queues[context_id].get(), timeout=timeout)
1094
+
1095
+ def _remove_context(self, context_id: str):
1096
+ if context_id in self._context_queues:
1097
+ del self._context_queues[context_id]
1098
+
1099
+ def _dispatch_listener(self):
1100
+ if self._processing_task is None or self._processing_task.done():
1101
+ self._processing_task = asyncio.create_task(self._process_responses())
1102
+
1103
+ def context(self, context_id: Optional[str] = None) -> _AsyncTTSContext:
1104
+ if context_id in self._context_queues:
1105
+ raise ValueError(f"AsyncContext for context ID {context_id} already exists.")
1106
+ if context_id is None:
1107
+ context_id = str(uuid.uuid4())
1108
+ if context_id not in self._context_queues:
1109
+ self._context_queues[context_id] = asyncio.Queue()
1110
+ return _AsyncTTSContext(context_id, self, self.timeout)
877
1111
 
878
1112
 
879
1113
  class AsyncTTS(TTS):
880
- def __init__(self, api_key, timeout, get_session):
881
- super().__init__(api_key, timeout)
1114
+ def __init__(self, api_key, base_url, timeout, get_session):
1115
+ super().__init__(api_key, base_url, timeout)
882
1116
  self._get_session = get_session
883
1117
  self._sse_class = _AsyncSSE(self._http_url(), self.headers, self.timeout, get_session)
884
1118
  self.sse = self._sse_class.send
885
1119
 
886
1120
  async def websocket(self) -> _AsyncWebSocket:
887
- ws = _AsyncWebSocket(self._ws_url(), self.api_key, self.cartesia_version, self._get_session)
1121
+ ws = _AsyncWebSocket(
1122
+ self._ws_url(), self.api_key, self.cartesia_version, self.timeout, self._get_session
1123
+ )
888
1124
  await ws.connect()
889
1125
  return ws
@@ -17,6 +17,8 @@ def deprecated(
17
17
  local_vars = locals()
18
18
 
19
19
  def fn(func: TCallable) -> TCallable:
20
+ if isinstance(func, classmethod):
21
+ func = func.__func__
20
22
  msg = _get_deprecated_msg(func, reason, vdeprecated, vremove, replacement)
21
23
  warnings.warn(msg, DeprecationWarning)
22
24
  return func
cartesia/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.0.3"
1
+ __version__ = "1.0.4"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 1.0.3
3
+ Version: 1.0.4
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -244,6 +244,101 @@ p.terminate()
244
244
  ws.close() # Close the websocket connection
245
245
  ```
246
246
 
247
+ #### Conditioning speech on previous generations using WebSocket
248
+
249
+ In some cases, input text may need to be streamed in. In these cases, it would be slow to wait for all the text to buffer before sending it to Cartesia's TTS service.
250
+
251
+ To mitigate this, Cartesia offers audio continuations. In this setting, users can send input text, as it becomes available, over a WebSocket connection.
252
+
253
+ To do this, we will create a `context` and send multiple requests without awaiting the responses. We can then listen to the responses in the order they were sent.
254
+
255
+ Each `context` will be closed automatically after 5 seconds of inactivity or when the `no_more_inputs` method is called. `no_more_inputs` sends a request with `continue_=False`, which indicates that no more inputs will be sent over this context.
256
+
257
+ ```python
258
+ import asyncio
259
+ import os
260
+ import pyaudio
261
+ from cartesia import AsyncCartesia
262
+
263
+ async def send_transcripts(ctx):
264
+ # Check out voice IDs by calling `client.voices.list()` or on [play.cartesia.ai](https://play.cartesia.ai/)
265
+ voice_id = "87748186-23bb-4158-a1eb-332911b0b708"
266
+
267
+ # You can check out our models at [docs.cartesia.ai](https://docs.cartesia.ai/getting-started/available-models).
268
+ model_id = "sonic-english"
269
+
270
+ # You can find the supported `output_format`s in our [API Reference](https://docs.cartesia.ai/api-reference/endpoints/stream-speech-server-sent-events).
271
+ output_format = {
272
+ "container": "raw",
273
+ "encoding": "pcm_f32le",
274
+ "sample_rate": 44100,
275
+ }
276
+
277
+ transcripts = [
278
+ "Sonic and Yoshi team up in a dimension-hopping adventure! ",
279
+ "Racing through twisting zones, they dodge Eggman's badniks and solve ancient puzzles. ",
280
+ "In the Echoing Caverns, they find the Harmonic Crystal, unlocking new powers. ",
281
+ "Sonic's speed creates sound waves, while Yoshi's eggs become sonic bolts. ",
282
+ "As they near Eggman's lair, our heroes charge their abilities for an epic boss battle. ",
283
+ "Get ready to spin, jump, and sound-blast your way to victory in this high-octane crossover!"
284
+ ]
285
+
286
+ for transcript in transcripts:
287
+ # Send text inputs as they become available
288
+ await ctx.send(
289
+ model_id=model_id,
290
+ transcript=transcript,
291
+ voice_id=voice_id,
292
+ continue_=True,
293
+ output_format=output_format,
294
+ )
295
+
296
+ # Indicate that no more inputs will be sent. Otherwise, the context will close after 5 seconds of inactivity.
297
+ await ctx.no_more_inputs()
298
+
299
+ async def receive_and_play_audio(ctx):
300
+ p = pyaudio.PyAudio()
301
+ stream = None
302
+ rate = 44100
303
+
304
+ async for output in ctx.receive():
305
+ buffer = output["audio"]
306
+
307
+ if not stream:
308
+ stream = p.open(
309
+ format=pyaudio.paFloat32,
310
+ channels=1,
311
+ rate=rate,
312
+ output=True
313
+ )
314
+
315
+ stream.write(buffer)
316
+
317
+ stream.stop_stream()
318
+ stream.close()
319
+ p.terminate()
320
+
321
+ async def stream_and_listen():
322
+ client = AsyncCartesia(api_key=os.environ.get("CARTESIA_API_KEY"))
323
+
324
+ # Set up the websocket connection
325
+ ws = await client.tts.websocket()
326
+
327
+ # Create a context to send and receive audio
328
+ ctx = ws.context() # Generates a random context ID if not provided
329
+
330
+ send_task = asyncio.create_task(send_transcripts(ctx))
331
+ listen_task = asyncio.create_task(receive_and_play_audio(ctx))
332
+
333
+ # Call the two coroutine tasks concurrently
334
+ await asyncio.gather(send_task, listen_task)
335
+
336
+ await ws.close()
337
+ await client.close()
338
+
339
+ asyncio.run(stream_and_listen())
340
+ ```
341
+
247
342
  ### Multilingual Text-to-Speech [Alpha]
248
343
 
249
344
  You can use our `sonic-multilingual` model to generate audio in multiple languages. The languages supported are available at [docs.cartesia.ai](https://docs.cartesia.ai/getting-started/available-models).
@@ -0,0 +1,11 @@
1
+ cartesia/__init__.py,sha256=jMIf2O7dTGxvTA5AfXtmh1H_EGfMtQseR5wXrjNRbLs,93
2
+ cartesia/_types.py,sha256=tO3Nef_V78TDMKDuIv_wsQLkxoSvYG4bdzFkMGXUFho,3765
3
+ cartesia/client.py,sha256=UCNTAU8eVzb-o-bygxfQQXWTDov_FX8dbAQdn7a8Hr0,41458
4
+ cartesia/version.py,sha256=acuR_XSJzp4OrQ5T8-Ac5gYe48mUwObuwjRmisFmZ7k,22
5
+ cartesia/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ cartesia/utils/deprecated.py,sha256=2cXvGtrxhPeUZA5LWy2n_U5OFLDv7SHeFtzqhjSJGyk,1674
7
+ cartesia/utils/retry.py,sha256=nuwWRfu3MOVTxIQMLjYf6WLaxSlnu_GdE3QjTV0zisQ,3339
8
+ cartesia-1.0.4.dist-info/METADATA,sha256=N7NoGr6XBtmLI6EHsG3efw0QNJ7uhV_E9HV8uqTYfQM,15991
9
+ cartesia-1.0.4.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
10
+ cartesia-1.0.4.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
11
+ cartesia-1.0.4.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- cartesia/__init__.py,sha256=jMIf2O7dTGxvTA5AfXtmh1H_EGfMtQseR5wXrjNRbLs,93
2
- cartesia/_types.py,sha256=msXRqNwVx_mbcLIQgRJYEl8U-hO9LRPWmscnX89cBCY,3747
3
- cartesia/client.py,sha256=jMlFDPRtKVDelqevHlv7YZJgOES3ws9BFN_6uUyN0W8,32720
4
- cartesia/version.py,sha256=2plzdEEb24FLjE2I2XyBBcJEPYWHccNL4SgtLC_6erg,22
5
- cartesia/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- cartesia/utils/deprecated.py,sha256=QSM-ld_g1r-JlT571SSt6-w650jVm9mJmmI2MSScLPw,1599
7
- cartesia/utils/retry.py,sha256=nuwWRfu3MOVTxIQMLjYf6WLaxSlnu_GdE3QjTV0zisQ,3339
8
- cartesia-1.0.3.dist-info/METADATA,sha256=y5_HREGB417qL69qMFYtKrIRQAQJ1WDxqObaAg6-V6U,12394
9
- cartesia-1.0.3.dist-info/WHEEL,sha256=DZajD4pwLWue70CAfc7YaxT1wLUciNBvN_TTcvXpltE,110
10
- cartesia-1.0.3.dist-info/top_level.txt,sha256=rTX4HnnCegMxl1FK9czpVC7GAvf3SwDzPG65qP-BS4w,9
11
- cartesia-1.0.3.dist-info/RECORD,,