cartesia 1.4.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. cartesia/__init__.py +302 -3
  2. cartesia/api_status/__init__.py +6 -0
  3. cartesia/api_status/client.py +104 -0
  4. cartesia/api_status/requests/__init__.py +5 -0
  5. cartesia/api_status/requests/api_info.py +8 -0
  6. cartesia/api_status/types/__init__.py +5 -0
  7. cartesia/api_status/types/api_info.py +20 -0
  8. cartesia/base_client.py +156 -0
  9. cartesia/client.py +163 -40
  10. cartesia/core/__init__.py +50 -0
  11. cartesia/core/api_error.py +15 -0
  12. cartesia/core/client_wrapper.py +55 -0
  13. cartesia/core/datetime_utils.py +28 -0
  14. cartesia/core/file.py +67 -0
  15. cartesia/core/http_client.py +499 -0
  16. cartesia/core/jsonable_encoder.py +101 -0
  17. cartesia/core/pagination.py +88 -0
  18. cartesia/core/pydantic_utilities.py +296 -0
  19. cartesia/core/query_encoder.py +58 -0
  20. cartesia/core/remove_none_from_dict.py +11 -0
  21. cartesia/core/request_options.py +35 -0
  22. cartesia/core/serialization.py +272 -0
  23. cartesia/datasets/__init__.py +24 -0
  24. cartesia/datasets/requests/__init__.py +15 -0
  25. cartesia/datasets/requests/create_dataset_request.py +7 -0
  26. cartesia/datasets/requests/dataset.py +9 -0
  27. cartesia/datasets/requests/dataset_file.py +9 -0
  28. cartesia/datasets/requests/paginated_dataset_files.py +10 -0
  29. cartesia/datasets/requests/paginated_datasets.py +10 -0
  30. cartesia/datasets/types/__init__.py +17 -0
  31. cartesia/datasets/types/create_dataset_request.py +19 -0
  32. cartesia/datasets/types/dataset.py +21 -0
  33. cartesia/datasets/types/dataset_file.py +21 -0
  34. cartesia/datasets/types/file_purpose.py +5 -0
  35. cartesia/datasets/types/paginated_dataset_files.py +21 -0
  36. cartesia/datasets/types/paginated_datasets.py +21 -0
  37. cartesia/embedding/__init__.py +5 -0
  38. cartesia/embedding/types/__init__.py +5 -0
  39. cartesia/embedding/types/embedding.py +201 -0
  40. cartesia/environment.py +7 -0
  41. cartesia/infill/__init__.py +2 -0
  42. cartesia/infill/client.py +318 -0
  43. cartesia/tts/__init__.py +167 -0
  44. cartesia/{_async_websocket.py → tts/_async_websocket.py} +212 -85
  45. cartesia/tts/_websocket.py +479 -0
  46. cartesia/tts/client.py +407 -0
  47. cartesia/tts/requests/__init__.py +76 -0
  48. cartesia/tts/requests/cancel_context_request.py +17 -0
  49. cartesia/tts/requests/controls.py +11 -0
  50. cartesia/tts/requests/generation_request.py +58 -0
  51. cartesia/tts/requests/mp_3_output_format.py +11 -0
  52. cartesia/tts/requests/output_format.py +30 -0
  53. cartesia/tts/requests/phoneme_timestamps.py +10 -0
  54. cartesia/tts/requests/raw_output_format.py +11 -0
  55. cartesia/tts/requests/speed.py +7 -0
  56. cartesia/tts/requests/tts_request.py +24 -0
  57. cartesia/tts/requests/tts_request_embedding_specifier.py +16 -0
  58. cartesia/tts/requests/tts_request_id_specifier.py +16 -0
  59. cartesia/tts/requests/tts_request_voice_specifier.py +7 -0
  60. cartesia/tts/requests/wav_output_format.py +7 -0
  61. cartesia/tts/requests/web_socket_base_response.py +11 -0
  62. cartesia/tts/requests/web_socket_chunk_response.py +11 -0
  63. cartesia/tts/requests/web_socket_done_response.py +7 -0
  64. cartesia/tts/requests/web_socket_error_response.py +7 -0
  65. cartesia/tts/requests/web_socket_flush_done_response.py +9 -0
  66. cartesia/tts/requests/web_socket_phoneme_timestamps_response.py +9 -0
  67. cartesia/tts/requests/web_socket_raw_output_format.py +11 -0
  68. cartesia/tts/requests/web_socket_request.py +7 -0
  69. cartesia/tts/requests/web_socket_response.py +70 -0
  70. cartesia/tts/requests/web_socket_stream_options.py +8 -0
  71. cartesia/tts/requests/web_socket_timestamps_response.py +9 -0
  72. cartesia/tts/requests/web_socket_tts_output.py +18 -0
  73. cartesia/tts/requests/web_socket_tts_request.py +25 -0
  74. cartesia/tts/requests/word_timestamps.py +10 -0
  75. cartesia/tts/socket_client.py +302 -0
  76. cartesia/tts/types/__init__.py +90 -0
  77. cartesia/tts/types/cancel_context_request.py +28 -0
  78. cartesia/tts/types/context_id.py +3 -0
  79. cartesia/tts/types/controls.py +22 -0
  80. cartesia/tts/types/emotion.py +34 -0
  81. cartesia/tts/types/flush_id.py +3 -0
  82. cartesia/tts/types/generation_request.py +71 -0
  83. cartesia/tts/types/mp_3_output_format.py +23 -0
  84. cartesia/tts/types/natural_specifier.py +5 -0
  85. cartesia/tts/types/numerical_specifier.py +3 -0
  86. cartesia/tts/types/output_format.py +58 -0
  87. cartesia/tts/types/phoneme_timestamps.py +21 -0
  88. cartesia/tts/types/raw_encoding.py +5 -0
  89. cartesia/tts/types/raw_output_format.py +22 -0
  90. cartesia/tts/types/speed.py +7 -0
  91. cartesia/tts/types/supported_language.py +7 -0
  92. cartesia/tts/types/tts_request.py +35 -0
  93. cartesia/tts/types/tts_request_embedding_specifier.py +27 -0
  94. cartesia/tts/types/tts_request_id_specifier.py +27 -0
  95. cartesia/tts/types/tts_request_voice_specifier.py +7 -0
  96. cartesia/tts/types/wav_output_format.py +17 -0
  97. cartesia/tts/types/web_socket_base_response.py +22 -0
  98. cartesia/tts/types/web_socket_chunk_response.py +22 -0
  99. cartesia/tts/types/web_socket_done_response.py +17 -0
  100. cartesia/tts/types/web_socket_error_response.py +19 -0
  101. cartesia/tts/types/web_socket_flush_done_response.py +21 -0
  102. cartesia/tts/types/web_socket_phoneme_timestamps_response.py +20 -0
  103. cartesia/tts/types/web_socket_raw_output_format.py +22 -0
  104. cartesia/tts/types/web_socket_request.py +7 -0
  105. cartesia/tts/types/web_socket_response.py +125 -0
  106. cartesia/tts/types/web_socket_stream_options.py +19 -0
  107. cartesia/tts/types/web_socket_timestamps_response.py +20 -0
  108. cartesia/tts/types/web_socket_tts_output.py +29 -0
  109. cartesia/tts/types/web_socket_tts_request.py +37 -0
  110. cartesia/tts/types/word_timestamps.py +21 -0
  111. cartesia/{_constants.py → tts/utils/constants.py} +2 -2
  112. cartesia/tts/utils/tts.py +64 -0
  113. cartesia/tts/utils/types.py +70 -0
  114. cartesia/version.py +3 -1
  115. cartesia/voice_changer/__init__.py +27 -0
  116. cartesia/voice_changer/client.py +395 -0
  117. cartesia/voice_changer/requests/__init__.py +15 -0
  118. cartesia/voice_changer/requests/streaming_response.py +38 -0
  119. cartesia/voice_changer/types/__init__.py +17 -0
  120. cartesia/voice_changer/types/output_format_container.py +5 -0
  121. cartesia/voice_changer/types/streaming_response.py +64 -0
  122. cartesia/voices/__init__.py +81 -0
  123. cartesia/voices/client.py +1218 -0
  124. cartesia/voices/requests/__init__.py +29 -0
  125. cartesia/voices/requests/create_voice_request.py +23 -0
  126. cartesia/voices/requests/embedding_response.py +8 -0
  127. cartesia/voices/requests/embedding_specifier.py +10 -0
  128. cartesia/voices/requests/get_voices_response.py +24 -0
  129. cartesia/voices/requests/id_specifier.py +10 -0
  130. cartesia/voices/requests/localize_dialect.py +11 -0
  131. cartesia/voices/requests/localize_voice_request.py +28 -0
  132. cartesia/voices/requests/mix_voice_specifier.py +7 -0
  133. cartesia/voices/requests/mix_voices_request.py +9 -0
  134. cartesia/voices/requests/update_voice_request.py +15 -0
  135. cartesia/voices/requests/voice.py +43 -0
  136. cartesia/voices/requests/voice_metadata.py +36 -0
  137. cartesia/voices/types/__init__.py +53 -0
  138. cartesia/voices/types/base_voice_id.py +5 -0
  139. cartesia/voices/types/clone_mode.py +5 -0
  140. cartesia/voices/types/create_voice_request.py +34 -0
  141. cartesia/voices/types/embedding_response.py +20 -0
  142. cartesia/voices/types/embedding_specifier.py +22 -0
  143. cartesia/voices/types/gender.py +5 -0
  144. cartesia/voices/types/gender_presentation.py +5 -0
  145. cartesia/voices/types/get_voices_response.py +34 -0
  146. cartesia/voices/types/id_specifier.py +22 -0
  147. cartesia/voices/types/localize_dialect.py +11 -0
  148. cartesia/voices/types/localize_english_dialect.py +5 -0
  149. cartesia/voices/types/localize_french_dialect.py +5 -0
  150. cartesia/voices/types/localize_portuguese_dialect.py +5 -0
  151. cartesia/voices/types/localize_spanish_dialect.py +5 -0
  152. cartesia/voices/types/localize_target_language.py +7 -0
  153. cartesia/voices/types/localize_voice_request.py +39 -0
  154. cartesia/voices/types/mix_voice_specifier.py +7 -0
  155. cartesia/voices/types/mix_voices_request.py +20 -0
  156. cartesia/voices/types/update_voice_request.py +27 -0
  157. cartesia/voices/types/voice.py +54 -0
  158. cartesia/voices/types/voice_expand_options.py +5 -0
  159. cartesia/voices/types/voice_id.py +3 -0
  160. cartesia/voices/types/voice_metadata.py +48 -0
  161. cartesia/voices/types/weight.py +3 -0
  162. cartesia-2.0.0.dist-info/METADATA +414 -0
  163. cartesia-2.0.0.dist-info/RECORD +165 -0
  164. {cartesia-1.4.0.dist-info → cartesia-2.0.0.dist-info}/WHEEL +1 -1
  165. cartesia/_async_sse.py +0 -95
  166. cartesia/_logger.py +0 -3
  167. cartesia/_sse.py +0 -143
  168. cartesia/_types.py +0 -70
  169. cartesia/_websocket.py +0 -358
  170. cartesia/async_client.py +0 -82
  171. cartesia/async_tts.py +0 -176
  172. cartesia/resource.py +0 -44
  173. cartesia/tts.py +0 -292
  174. cartesia/utils/deprecated.py +0 -55
  175. cartesia/utils/retry.py +0 -87
  176. cartesia/utils/tts.py +0 -78
  177. cartesia/voices.py +0 -204
  178. cartesia-1.4.0.dist-info/METADATA +0 -663
  179. cartesia-1.4.0.dist-info/RECORD +0 -23
  180. cartesia-1.4.0.dist-info/licenses/LICENSE.md +0 -21
  181. /cartesia/{utils/__init__.py → py.typed} +0 -0
@@ -0,0 +1,479 @@
1
+ import base64
2
+ import json
3
+ import typing
4
+ import uuid
5
+ from collections import defaultdict
6
+ from typing import Any, Dict, Generator, Optional, Set, Union
7
+
8
+ try:
9
+ from websockets.sync.client import connect
10
+
11
+ IS_WEBSOCKET_SYNC_AVAILABLE = True
12
+ except ImportError:
13
+ IS_WEBSOCKET_SYNC_AVAILABLE = False
14
+
15
+ from iterators import TimeoutIterator # type: ignore
16
+
17
+ from cartesia.tts.requests import TtsRequestVoiceSpecifierParams
18
+ from cartesia.tts.requests.output_format import OutputFormatParams
19
+ from cartesia.tts.types import (
20
+ WebSocketResponse,
21
+ WebSocketResponse_Chunk,
22
+ WebSocketResponse_Done,
23
+ WebSocketResponse_Error,
24
+ WebSocketResponse_FlushDone,
25
+ WebSocketResponse_PhonemeTimestamps,
26
+ WebSocketResponse_Timestamps,
27
+ WebSocketTtsOutput,
28
+ WordTimestamps,
29
+ PhonemeTimestamps,
30
+ )
31
+
32
+ from ..core.pydantic_utilities import parse_obj_as
33
+ from .types.generation_request import GenerationRequest
34
+
35
+
36
class _TTSContext:
    """Stream text into, and audio out of, one ``context_id`` on a shared WebSocket.

    Text chunks are sent to the server as they become available from the
    ``transcript`` generator passed to :meth:`send`, while responses for this
    context are read from the same connection and yielded to the caller.

    See :class:`_AsyncTTSContext` for asynchronous use cases.

    The context ID is removed from the owning ``TtsWebsocket`` registry when
    :meth:`_close` is called; within this class that only happens from
    ``__del__``. If streaming fails, the whole WebSocket is closed.
    """

    # Seconds to wait for the next text chunk / the next server message while
    # interleaving sends and receives. Kept tiny so neither side starves.
    _POLL_TIMEOUT = 0.001

    def __init__(self, context_id: str, websocket: "TtsWebsocket"):
        self._context_id = context_id
        self._websocket = websocket
        # Not read anywhere in this class; kept for compatibility.
        self._error = None

    def __del__(self):
        # Deregister the context ID from the owning websocket on collection.
        self._close()

    @property
    def context_id(self) -> str:
        """The context ID this object is bound to."""
        return self._context_id

    def send(
        self,
        *,
        model_id: str,
        transcript: typing.Generator[str, None, None],
        output_format: "OutputFormatParams",
        voice: "TtsRequestVoiceSpecifierParams",
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
        add_timestamps: bool = False,
        add_phoneme_timestamps: bool = False,
        use_original_timestamps: bool = False,
    ) -> "Generator[WebSocketTtsOutput, None, None]":
        """Stream transcript chunks to the WebSocket and yield the responses.

        This is a generator function: nothing is connected or sent until the
        returned generator is first advanced.

        Args:
            model_id: Forwarded as ``model_id`` in every request message.
            transcript: Iterable of text chunks. Each chunk is sent as its own
                message with ``continue=True``; once it is exhausted a final
                message with an empty transcript and ``continue=False`` is sent.
            output_format: Forwarded as ``output_format``.
            voice: Forwarded as ``voice``.
            context_id: Optional; if given it must equal this context's ID.
            duration: Forwarded as ``duration`` when not None.
            language: Forwarded as ``language`` when not None.
            stream: Forwarded as ``stream`` only when truthy.
            add_timestamps: Forwarded only when truthy.
            add_phoneme_timestamps: Forwarded only when truthy.
            use_original_timestamps: Forwarded only when truthy.

        Yields:
            ``WebSocketTtsOutput`` objects built by
            ``TtsWebsocket._convert_response`` (with the context ID included)
            for responses that belong to this context.

        Raises:
            ValueError: If ``context_id`` is given and differs from this
                context's ID.
            RuntimeError: If anything fails while streaming (including an
                error response from the server). The WebSocket is closed
                first and the original exception is attached as the cause.
        """
        self._websocket.connect()
        assert self._websocket.websocket is not None, "WebSocket is not connected"

        if context_id is not None and context_id != self._context_id:
            raise ValueError(
                "Context ID does not match the context ID of the current context."
            )

        # "transcript" and "continue" are overwritten for every message sent.
        request_body: Dict[str, Any] = {
            "model_id": model_id,
            "transcript": "",
            "output_format": output_format,
            "voice": voice,
            "context_id": self._context_id,
        }
        if duration is not None:
            request_body["duration"] = duration
        if language is not None:
            request_body["language"] = language
        if stream:
            request_body["stream"] = stream
        if add_timestamps:
            request_body["add_timestamps"] = add_timestamps
        if add_phoneme_timestamps:
            request_body["add_phoneme_timestamps"] = add_phoneme_timestamps
        if use_original_timestamps:
            request_body["use_original_timestamps"] = use_original_timestamps

        try:
            # Response kinds that are surfaced while text is still being sent.
            streamed_kinds = (
                WebSocketResponse_Chunk,
                WebSocketResponse_Timestamps,
                WebSocketResponse_PhonemeTimestamps,
            )
            connection = self._websocket.websocket

            text_iterator = TimeoutIterator(transcript, timeout=self._POLL_TIMEOUT)
            sentinel = text_iterator.get_sentinel()
            next_chunk = next(text_iterator, None)

            while True:
                # Send the next text chunk to the WebSocket if one is ready.
                if next_chunk is not None and next_chunk != sentinel:
                    request_body["transcript"] = next_chunk
                    request_body["continue"] = True
                    connection.send(json.dumps(request_body))
                    next_chunk = next(text_iterator, None)

                try:
                    response_obj = self._recv_response(timeout=self._POLL_TIMEOUT)
                    # Only act on responses for this context; messages that
                    # belong to other contexts on the shared socket are
                    # ignored, matching the two receive loops below.
                    if response_obj.context_id == self._context_id:
                        if isinstance(response_obj, WebSocketResponse_Error):
                            raise RuntimeError(
                                f"Error generating audio:\n{response_obj.error}"
                            )
                        if isinstance(response_obj, WebSocketResponse_Done):
                            break
                        if isinstance(response_obj, streamed_kinds):
                            yield self._websocket._convert_response(
                                response_obj, include_context_id=True
                            )
                except TimeoutError:
                    # Nothing arrived within the poll window; keep going.
                    pass

                # Keep receiving until the next text chunk becomes available.
                while next_chunk == sentinel:
                    try:
                        response_obj = self._recv_response(
                            timeout=self._POLL_TIMEOUT
                        )
                        if response_obj.context_id != self._context_id:
                            continue
                        if isinstance(response_obj, WebSocketResponse_Error):
                            raise RuntimeError(
                                f"Error generating audio:\n{response_obj.error}"
                            )
                        if isinstance(response_obj, WebSocketResponse_Done):
                            break
                        if isinstance(response_obj, streamed_kinds):
                            yield self._websocket._convert_response(
                                response_obj, include_context_id=True
                            )
                    except TimeoutError:
                        pass
                    next_chunk = next(text_iterator, None)

                # All input text is exhausted: send the closing message.
                if next_chunk is None:
                    request_body["transcript"] = ""
                    request_body["continue"] = False
                    connection.send(json.dumps(request_body))
                    break

            # Drain the remaining messages until "done" is received.
            while True:
                response_obj = self._recv_response()
                if response_obj.context_id != self._context_id:
                    continue
                if isinstance(response_obj, WebSocketResponse_Error):
                    raise RuntimeError(f"Error generating audio:\n{response_obj.error}")
                if isinstance(response_obj, WebSocketResponse_Done):
                    break
                yield self._websocket._convert_response(
                    response_obj, include_context_id=True
                )

        except Exception as e:
            self._websocket.close()
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to generate audio. {e}") from e

    def _recv_response(self, timeout: Optional[float] = None) -> "WebSocketResponse":
        """Receive one message from the socket and parse it as a ``WebSocketResponse``.

        With ``timeout=None`` the receive call is made without a timeout
        argument; otherwise ``timeout`` is passed through to ``recv``.
        """
        connection = self._websocket.websocket
        if timeout is None:
            raw = connection.recv()
        else:
            raw = connection.recv(timeout=timeout)
        return typing.cast(
            "WebSocketResponse",
            parse_obj_as(
                type_=WebSocketResponse,  # type: ignore
                object_=json.loads(raw),
            ),
        )

    def _close(self):
        """Remove this context's ID from the owning websocket's registry."""
        self._websocket._remove_context(self._context_id)

    def is_closed(self):
        """Return True if this context's ID is no longer registered on the websocket."""
        return self._context_id not in self._websocket._contexts
237
+
238
+
239
class TtsWebsocket:
    """Generate audio over a synchronous WebSocket connection.

    The connection is opened lazily by :meth:`connect` (which :meth:`send`
    calls for you) and should be released with :meth:`close` when done.

    Usage:
        >>> ws = client.tts.websocket()
        >>> for output in ws.send(
        ...     model_id="sonic-2",
        ...     transcript="Hello world!",
        ...     voice=voice,
        ...     output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
        ...     stream=True,
        ... ):
        ...     audio = output.audio
        >>> ws.close()
    """

    # Friendlier messages for connection failures that expose a `status`.
    _STATUS_MESSAGES = {
        402: "Payment required. Your API key may have insufficient credits or permissions.",
        401: "Unauthorized. Please check your API key.",
        403: "Forbidden. You don't have permission to access this resource.",
        404: "Not found. The requested resource doesn't exist.",
    }

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
    ):
        self.ws_url = ws_url
        self.api_key = api_key
        self.cartesia_version = cartesia_version
        # Set by connect(); None until the first connection is made.
        self.websocket = None
        # IDs of contexts created through context() and not yet removed.
        self._contexts: Set[str] = set()

    def __del__(self):
        # Best-effort cleanup only: a finalizer must never raise, and it can
        # run on a partially initialised instance (e.g. if __init__ failed),
        # where close() would hit missing attributes.
        try:
            self.close()
        except Exception:
            pass

    def connect(self):
        """Connect to the WebSocket if it is not already connected.

        Raises:
            ImportError: If the synchronous ``websockets`` client could not be
                imported.
            RuntimeError: If the connection to the WebSocket fails. The
                original exception is attached as the cause.
        """
        if not IS_WEBSOCKET_SYNC_AVAILABLE:
            raise ImportError(
                "The synchronous WebSocket client is not available. Please ensure that you have 'websockets>=12.0' or compatible version installed."
            )
        if self.websocket is not None and not self._is_websocket_closed():
            return

        route = "tts/websocket"
        try:
            self.websocket = connect(
                f"{self.ws_url}/{route}?api_key={self.api_key}&cartesia_version={self.cartesia_version}"
            )
        except Exception as e:
            # Use the status code, when the exception carries one, to build a
            # more meaningful error message.
            status_code = getattr(e, "status", None)
            if status_code is None:
                raise RuntimeError(f"Failed to connect to WebSocket. {e}") from e
            error_message = self._STATUS_MESSAGES.get(status_code, str(e))
            raise RuntimeError(
                f"Failed to connect to WebSocket.\nStatus: {status_code}. Error message: {error_message}"
            ) from e

    def _is_websocket_closed(self):
        """Return True if the underlying socket's file descriptor is -1."""
        return self.websocket.socket.fileno() == -1

    def close(self):
        """Close the WebSocket connection and forget all registered contexts.

        *Highly* recommended to call this method when done using the WebSocket.
        """
        if self.websocket and not self._is_websocket_closed():
            self.websocket.close()

        self._contexts.clear()

    def _convert_response(
        self,
        response: "typing.Union[WebSocketResponse_Chunk, WebSocketResponse_Timestamps, WebSocketResponse_PhonemeTimestamps, WebSocketResponse_FlushDone]",
        include_context_id: bool,
        include_flush_id: bool = False,
    ) -> "WebSocketTtsOutput":
        """Convert a parsed WebSocket response into a ``WebSocketTtsOutput``.

        Args:
            response: The parsed response to convert.
            include_context_id: Copy ``response.context_id`` to the output
                when it is truthy.
            include_flush_id: Copy ``flush_done`` / ``flush_id`` from a
                flush-done response to the output.

        Returns:
            A ``WebSocketTtsOutput`` holding the base64-decoded audio, the
            word timestamps, the phoneme timestamps or the flush fields,
            depending on the response type.
        """
        out: Dict[str, Any] = {}
        if isinstance(response, WebSocketResponse_Chunk):
            out["audio"] = base64.b64decode(response.data)
        elif isinstance(response, WebSocketResponse_Timestamps):
            out["word_timestamps"] = response.word_timestamps  # type: ignore
        elif isinstance(response, WebSocketResponse_PhonemeTimestamps):
            out["phoneme_timestamps"] = response.phoneme_timestamps  # type: ignore
        elif include_flush_id and isinstance(response, WebSocketResponse_FlushDone):
            out["flush_done"] = response.flush_done  # type: ignore
            out["flush_id"] = response.flush_id  # type: ignore

        if include_context_id and response.context_id:
            out["context_id"] = response.context_id  # type: ignore

        return WebSocketTtsOutput(**out)  # type: ignore

    def send(
        self,
        *,
        model_id: str,
        transcript: str,
        output_format: "OutputFormatParams",
        voice: "TtsRequestVoiceSpecifierParams",
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
        add_timestamps: bool = False,
        add_phoneme_timestamps: bool = False,
        use_original_timestamps: bool = False,
    ):
        """Send a request to the WebSocket to generate audio.

        Args:
            model_id: Forwarded as ``model_id``.
            transcript: The text to synthesise, forwarded as ``transcript``.
            output_format: Forwarded as ``output_format``.
            voice: Forwarded as ``voice``.
            context_id: Context ID for the request; a random UUID4 string is
                generated when omitted.
            duration: Forwarded as ``duration``.
            language: Forwarded as ``language``.
            stream: Whether to return a generator of outputs (True) or a
                single aggregated output (False).
            add_timestamps: Forwarded; when not streaming, also controls
                whether word timestamps are aggregated into the result.
            add_phoneme_timestamps: Forwarded; when not streaming, also
                controls whether phoneme timestamps are aggregated.
            use_original_timestamps: Forwarded as ``use_original_timestamps``.

        Returns:
            If ``stream`` is True, a generator yielding ``WebSocketTtsOutput``
            objects; the request itself is sent when the generator is first
            advanced. If ``stream`` is False, one ``WebSocketTtsOutput`` with
            all audio concatenated, the context ID, and the merged word /
            phoneme timestamps when they were requested.

        Raises:
            RuntimeError: If connecting or generating audio fails.
        """
        self.connect()

        if context_id is None:
            context_id = str(uuid.uuid4())

        request_body = {
            "model_id": model_id,
            "transcript": transcript,
            "output_format": output_format,
            "voice": voice,
            "context_id": context_id,
            "duration": duration,
            "language": language,
            "stream": stream,
            "add_timestamps": add_timestamps,
            "add_phoneme_timestamps": add_phoneme_timestamps,
            "use_original_timestamps": use_original_timestamps,
        }
        generator = self._websocket_generator(request_body)

        if stream:
            return generator

        # Non-streaming: drain the generator and merge everything into one output.
        chunks: typing.List[bytes] = []
        words: typing.List[str] = []
        start: typing.List[float] = []
        end: typing.List[float] = []
        phonemes: typing.List[str] = []
        phoneme_start: typing.List[float] = []
        phoneme_end: typing.List[float] = []
        for chunk in generator:
            if chunk.audio is not None:
                chunks.append(chunk.audio)
            if add_timestamps and chunk.word_timestamps is not None:
                words.extend(chunk.word_timestamps.words)
                start.extend(chunk.word_timestamps.start)
                end.extend(chunk.word_timestamps.end)
            if add_phoneme_timestamps and chunk.phoneme_timestamps is not None:
                phonemes.extend(chunk.phoneme_timestamps.phonemes)
                phoneme_start.extend(chunk.phoneme_timestamps.start)
                phoneme_end.extend(chunk.phoneme_timestamps.end)

        return WebSocketTtsOutput(
            audio=b"".join(chunks),  # type: ignore
            context_id=context_id,
            word_timestamps=(
                WordTimestamps(
                    words=words,
                    start=start,
                    end=end,
                )
                if add_timestamps
                else None
            ),
            phoneme_timestamps=(
                PhonemeTimestamps(
                    phonemes=phonemes,
                    start=phoneme_start,
                    end=phoneme_end,
                )
                if add_phoneme_timestamps
                else None
            ),
        )

    def _websocket_generator(self, request_body: Dict[str, Any]):
        """Send ``request_body`` and yield converted responses until "done".

        Raises:
            RuntimeError: If sending succeeded but receiving, parsing or the
                server reported an error; the WebSocket is closed first.
        """
        assert self.websocket is not None, "WebSocket is not connected"
        self.websocket.send(json.dumps(request_body))

        try:
            while True:
                response_obj = typing.cast(
                    "WebSocketResponse",
                    parse_obj_as(
                        type_=WebSocketResponse,  # type: ignore
                        object_=json.loads(self.websocket.recv()),
                    ),
                )
                if isinstance(response_obj, WebSocketResponse_Error):
                    raise RuntimeError(f"Error generating audio:\n{response_obj.error}")
                if isinstance(response_obj, WebSocketResponse_Done):
                    break
                yield self._convert_response(response_obj, include_context_id=True)
        except Exception as e:
            # Close the websocket connection if an error occurs.
            self.close()
            raise RuntimeError(f"Failed to generate audio. {e}") from e

    def _remove_context(self, context_id: str):
        """Forget ``context_id``; a no-op if it is not registered."""
        self._contexts.discard(context_id)

    def context(self, context_id: Optional[str] = None) -> "_TTSContext":
        """Create and register a new context bound to this WebSocket.

        Args:
            context_id: ID for the new context; a random UUID4 string is
                generated when omitted.

        Returns:
            A ``_TTSContext`` for the (possibly generated) context ID.

        Raises:
            ValueError: If a context with ``context_id`` is already registered.
        """
        if context_id in self._contexts:
            raise ValueError(f"Context for context ID {context_id} already exists.")
        if context_id is None:
            context_id = str(uuid.uuid4())
        self._contexts.add(context_id)
        return _TTSContext(context_id, self)