cartesia 1.4.0__py3-none-any.whl → 2.0.0a2__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (176)
  1. cartesia/__init__.py +292 -3
  2. cartesia/api_status/__init__.py +6 -0
  3. cartesia/api_status/client.py +104 -0
  4. cartesia/api_status/requests/__init__.py +5 -0
  5. cartesia/api_status/requests/api_info.py +8 -0
  6. cartesia/api_status/types/__init__.py +5 -0
  7. cartesia/api_status/types/api_info.py +20 -0
  8. cartesia/base_client.py +160 -0
  9. cartesia/client.py +163 -40
  10. cartesia/core/__init__.py +47 -0
  11. cartesia/core/api_error.py +15 -0
  12. cartesia/core/client_wrapper.py +55 -0
  13. cartesia/core/datetime_utils.py +28 -0
  14. cartesia/core/file.py +67 -0
  15. cartesia/core/http_client.py +499 -0
  16. cartesia/core/jsonable_encoder.py +101 -0
  17. cartesia/core/pydantic_utilities.py +296 -0
  18. cartesia/core/query_encoder.py +58 -0
  19. cartesia/core/remove_none_from_dict.py +11 -0
  20. cartesia/core/request_options.py +35 -0
  21. cartesia/core/serialization.py +272 -0
  22. cartesia/datasets/__init__.py +24 -0
  23. cartesia/datasets/client.py +392 -0
  24. cartesia/datasets/requests/__init__.py +15 -0
  25. cartesia/datasets/requests/create_dataset_request.py +7 -0
  26. cartesia/datasets/requests/dataset.py +9 -0
  27. cartesia/datasets/requests/dataset_file.py +9 -0
  28. cartesia/datasets/requests/paginated_dataset_files.py +10 -0
  29. cartesia/datasets/requests/paginated_datasets.py +10 -0
  30. cartesia/datasets/types/__init__.py +17 -0
  31. cartesia/datasets/types/create_dataset_request.py +19 -0
  32. cartesia/datasets/types/dataset.py +21 -0
  33. cartesia/datasets/types/dataset_file.py +21 -0
  34. cartesia/datasets/types/file_purpose.py +5 -0
  35. cartesia/datasets/types/paginated_dataset_files.py +21 -0
  36. cartesia/datasets/types/paginated_datasets.py +21 -0
  37. cartesia/embedding/__init__.py +5 -0
  38. cartesia/embedding/types/__init__.py +5 -0
  39. cartesia/embedding/types/embedding.py +201 -0
  40. cartesia/environment.py +7 -0
  41. cartesia/infill/__init__.py +2 -0
  42. cartesia/infill/client.py +318 -0
  43. cartesia/tts/__init__.py +167 -0
  44. cartesia/{_async_websocket.py → tts/_async_websocket.py} +159 -84
  45. cartesia/tts/_websocket.py +430 -0
  46. cartesia/tts/client.py +407 -0
  47. cartesia/tts/requests/__init__.py +76 -0
  48. cartesia/tts/requests/cancel_context_request.py +17 -0
  49. cartesia/tts/requests/controls.py +11 -0
  50. cartesia/tts/requests/generation_request.py +53 -0
  51. cartesia/tts/requests/mp_3_output_format.py +11 -0
  52. cartesia/tts/requests/output_format.py +30 -0
  53. cartesia/tts/requests/phoneme_timestamps.py +10 -0
  54. cartesia/tts/requests/raw_output_format.py +11 -0
  55. cartesia/tts/requests/speed.py +7 -0
  56. cartesia/tts/requests/tts_request.py +24 -0
  57. cartesia/tts/requests/tts_request_embedding_specifier.py +16 -0
  58. cartesia/tts/requests/tts_request_id_specifier.py +16 -0
  59. cartesia/tts/requests/tts_request_voice_specifier.py +7 -0
  60. cartesia/tts/requests/wav_output_format.py +7 -0
  61. cartesia/tts/requests/web_socket_base_response.py +11 -0
  62. cartesia/tts/requests/web_socket_chunk_response.py +8 -0
  63. cartesia/tts/requests/web_socket_done_response.py +7 -0
  64. cartesia/tts/requests/web_socket_error_response.py +7 -0
  65. cartesia/tts/requests/web_socket_flush_done_response.py +9 -0
  66. cartesia/tts/requests/web_socket_phoneme_timestamps_response.py +9 -0
  67. cartesia/tts/requests/web_socket_raw_output_format.py +11 -0
  68. cartesia/tts/requests/web_socket_request.py +7 -0
  69. cartesia/tts/requests/web_socket_response.py +69 -0
  70. cartesia/tts/requests/web_socket_stream_options.py +8 -0
  71. cartesia/tts/requests/web_socket_timestamps_response.py +9 -0
  72. cartesia/tts/requests/web_socket_tts_output.py +18 -0
  73. cartesia/tts/requests/web_socket_tts_request.py +24 -0
  74. cartesia/tts/requests/word_timestamps.py +10 -0
  75. cartesia/tts/socket_client.py +302 -0
  76. cartesia/tts/types/__init__.py +90 -0
  77. cartesia/tts/types/cancel_context_request.py +28 -0
  78. cartesia/tts/types/context_id.py +3 -0
  79. cartesia/tts/types/controls.py +22 -0
  80. cartesia/tts/types/emotion.py +29 -0
  81. cartesia/tts/types/flush_id.py +3 -0
  82. cartesia/tts/types/generation_request.py +66 -0
  83. cartesia/tts/types/mp_3_output_format.py +23 -0
  84. cartesia/tts/types/natural_specifier.py +5 -0
  85. cartesia/tts/types/numerical_specifier.py +3 -0
  86. cartesia/tts/types/output_format.py +58 -0
  87. cartesia/tts/types/phoneme_timestamps.py +21 -0
  88. cartesia/tts/types/raw_encoding.py +5 -0
  89. cartesia/tts/types/raw_output_format.py +22 -0
  90. cartesia/tts/types/speed.py +7 -0
  91. cartesia/tts/types/supported_language.py +7 -0
  92. cartesia/tts/types/tts_request.py +35 -0
  93. cartesia/tts/types/tts_request_embedding_specifier.py +27 -0
  94. cartesia/tts/types/tts_request_id_specifier.py +27 -0
  95. cartesia/tts/types/tts_request_voice_specifier.py +7 -0
  96. cartesia/tts/types/wav_output_format.py +17 -0
  97. cartesia/tts/types/web_socket_base_response.py +22 -0
  98. cartesia/tts/types/web_socket_chunk_response.py +20 -0
  99. cartesia/tts/types/web_socket_done_response.py +17 -0
  100. cartesia/tts/types/web_socket_error_response.py +19 -0
  101. cartesia/tts/types/web_socket_flush_done_response.py +21 -0
  102. cartesia/tts/types/web_socket_phoneme_timestamps_response.py +20 -0
  103. cartesia/tts/types/web_socket_raw_output_format.py +22 -0
  104. cartesia/tts/types/web_socket_request.py +7 -0
  105. cartesia/tts/types/web_socket_response.py +124 -0
  106. cartesia/tts/types/web_socket_stream_options.py +19 -0
  107. cartesia/tts/types/web_socket_timestamps_response.py +20 -0
  108. cartesia/tts/types/web_socket_tts_output.py +27 -0
  109. cartesia/tts/types/web_socket_tts_request.py +36 -0
  110. cartesia/tts/types/word_timestamps.py +21 -0
  111. cartesia/tts/utils/tts.py +64 -0
  112. cartesia/tts/utils/types.py +70 -0
  113. cartesia/version.py +3 -1
  114. cartesia/voice_changer/__init__.py +27 -0
  115. cartesia/voice_changer/client.py +395 -0
  116. cartesia/voice_changer/requests/__init__.py +15 -0
  117. cartesia/voice_changer/requests/streaming_response.py +36 -0
  118. cartesia/voice_changer/types/__init__.py +17 -0
  119. cartesia/voice_changer/types/output_format_container.py +5 -0
  120. cartesia/voice_changer/types/streaming_response.py +62 -0
  121. cartesia/voices/__init__.py +71 -0
  122. cartesia/voices/client.py +1053 -0
  123. cartesia/voices/requests/__init__.py +27 -0
  124. cartesia/voices/requests/create_voice_request.py +23 -0
  125. cartesia/voices/requests/embedding_response.py +8 -0
  126. cartesia/voices/requests/embedding_specifier.py +10 -0
  127. cartesia/voices/requests/id_specifier.py +10 -0
  128. cartesia/voices/requests/localize_dialect.py +8 -0
  129. cartesia/voices/requests/localize_voice_request.py +15 -0
  130. cartesia/voices/requests/mix_voice_specifier.py +7 -0
  131. cartesia/voices/requests/mix_voices_request.py +9 -0
  132. cartesia/voices/requests/update_voice_request.py +15 -0
  133. cartesia/voices/requests/voice.py +39 -0
  134. cartesia/voices/requests/voice_metadata.py +36 -0
  135. cartesia/voices/types/__init__.py +45 -0
  136. cartesia/voices/types/base_voice_id.py +5 -0
  137. cartesia/voices/types/clone_mode.py +5 -0
  138. cartesia/voices/types/create_voice_request.py +34 -0
  139. cartesia/voices/types/embedding_response.py +20 -0
  140. cartesia/voices/types/embedding_specifier.py +22 -0
  141. cartesia/voices/types/gender.py +5 -0
  142. cartesia/voices/types/id_specifier.py +22 -0
  143. cartesia/voices/types/localize_dialect.py +8 -0
  144. cartesia/voices/types/localize_english_dialect.py +5 -0
  145. cartesia/voices/types/localize_portuguese_dialect.py +5 -0
  146. cartesia/voices/types/localize_spanish_dialect.py +5 -0
  147. cartesia/voices/types/localize_target_language.py +7 -0
  148. cartesia/voices/types/localize_voice_request.py +26 -0
  149. cartesia/voices/types/mix_voice_specifier.py +7 -0
  150. cartesia/voices/types/mix_voices_request.py +20 -0
  151. cartesia/voices/types/update_voice_request.py +27 -0
  152. cartesia/voices/types/voice.py +50 -0
  153. cartesia/voices/types/voice_id.py +3 -0
  154. cartesia/voices/types/voice_metadata.py +48 -0
  155. cartesia/voices/types/weight.py +3 -0
  156. cartesia-2.0.0a2.dist-info/METADATA +307 -0
  157. cartesia-2.0.0a2.dist-info/RECORD +160 -0
  158. {cartesia-1.4.0.dist-info → cartesia-2.0.0a2.dist-info}/WHEEL +1 -1
  159. cartesia/_async_sse.py +0 -95
  160. cartesia/_logger.py +0 -3
  161. cartesia/_sse.py +0 -143
  162. cartesia/_types.py +0 -70
  163. cartesia/_websocket.py +0 -358
  164. cartesia/async_client.py +0 -82
  165. cartesia/async_tts.py +0 -176
  166. cartesia/resource.py +0 -44
  167. cartesia/tts.py +0 -292
  168. cartesia/utils/deprecated.py +0 -55
  169. cartesia/utils/retry.py +0 -87
  170. cartesia/utils/tts.py +0 -78
  171. cartesia/voices.py +0 -204
  172. cartesia-1.4.0.dist-info/METADATA +0 -663
  173. cartesia-1.4.0.dist-info/RECORD +0 -23
  174. cartesia-1.4.0.dist-info/licenses/LICENSE.md +0 -21
  175. /cartesia/{utils/__init__.py → py.typed} +0 -0
  176. /cartesia/{_constants.py → tts/utils/constants.py} +0 -0
@@ -1,4 +1,6 @@
1
1
  import asyncio
2
+ import json
3
+ import typing
2
4
  import uuid
3
5
  from collections import defaultdict
4
6
  from types import TracebackType
@@ -6,11 +8,26 @@ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
6
8
 
7
9
  import aiohttp
8
10
 
9
- from cartesia._constants import DEFAULT_MODEL_ID, DEFAULT_OUTPUT_FORMAT, DEFAULT_VOICE_EMBEDDING
10
- from cartesia._types import OutputFormat, VoiceControls
11
- from cartesia._websocket import _WebSocket
12
- from cartesia.tts import TTS
13
- from cartesia.utils.tts import _construct_tts_request
11
+ from cartesia.tts.requests import TtsRequestVoiceSpecifierParams
12
+ from cartesia.tts.requests.output_format import OutputFormatParams
13
+ from cartesia.tts.types import (
14
+ WebSocketResponse,
15
+ WebSocketResponse_Done,
16
+ WebSocketResponse_Error,
17
+ WebSocketResponse_FlushDone,
18
+ WebSocketTtsOutput,
19
+ WordTimestamps,
20
+ )
21
+
22
+ from ..core.pydantic_utilities import parse_obj_as
23
+ from ._websocket import TtsWebsocket
24
+ from .types.generation_request import GenerationRequest
25
+ from .utils.constants import (
26
+ DEFAULT_MODEL_ID,
27
+ DEFAULT_OUTPUT_FORMAT,
28
+ DEFAULT_VOICE_EMBEDDING,
29
+ )
30
+ from .utils.tts import get_output_format
14
31
 
15
32
 
16
33
  class _AsyncTTSContext:
@@ -26,7 +43,9 @@ class _AsyncTTSContext:
26
43
 
27
44
  """
28
45
 
29
- def __init__(self, context_id: str, websocket: "_AsyncWebSocket", timeout: float):
46
+ def __init__(
47
+ self, context_id: str, websocket: "AsyncTtsWebsocket", timeout: float = 30
48
+ ):
30
49
  self._context_id = context_id
31
50
  self._websocket = websocket
32
51
  self.timeout = timeout
@@ -38,60 +57,73 @@ class _AsyncTTSContext:
38
57
 
39
58
  async def send(
40
59
  self,
60
+ *,
41
61
  model_id: str,
42
62
  transcript: str,
43
- output_format: OutputFormat,
44
- voice_id: Optional[str] = None,
45
- voice_embedding: Optional[List[float]] = None,
63
+ output_format: OutputFormatParams,
64
+ voice: TtsRequestVoiceSpecifierParams,
46
65
  context_id: Optional[str] = None,
47
- continue_: bool = False,
48
- flush: bool = False,
49
66
  duration: Optional[int] = None,
50
67
  language: Optional[str] = None,
68
+ stream: bool = True,
51
69
  add_timestamps: bool = False,
52
- _experimental_voice_controls: Optional[VoiceControls] = None,
70
+ continue_: bool = False,
71
+ flush: bool = False,
53
72
  ) -> None:
54
73
  """Send audio generation requests to the WebSocket. The response can be received using the `receive` method.
55
74
 
56
75
  Args:
57
- model_id: The ID of the model to use for generating audio.
58
- transcript: The text to convert to speech.
59
- output_format: A dictionary containing the details of the output format.
60
- voice_id: The ID of the voice to use for generating audio.
61
- voice_embedding: The embedding of the voice to use for generating audio.
62
- context_id: The context ID to use for the request. If not specified, a random context ID will be generated.
63
- continue_: Whether to continue the audio generation from the previous transcript or not.
64
- flush: Whether to trigger a manual flush for the current context's generation.
65
- duration: The duration of the audio in seconds.
66
- language: The language code for the audio request. This can only be used with `model_id = sonic-multilingual`.
67
- add_timestamps: Whether to return word-level timestamps.
68
- _experimental_voice_controls: Experimental voice controls for controlling speed and emotion.
69
- Note: This is an experimental feature and may change rapidly in future releases.
76
+ request: The request to generate audio.
70
77
 
71
78
  Returns:
72
79
  None.
73
80
  """
74
- if context_id is not None and context_id != self._context_id:
75
- raise ValueError("Context ID does not match the context ID of the current context.")
76
- if continue_ and transcript == "" and not flush:
77
- raise ValueError("Transcript cannot be empty when continue_ is True.")
78
-
79
81
  await self._websocket.connect()
80
-
81
- request_body = _construct_tts_request(
82
- model_id=model_id,
83
- transcript=transcript,
84
- output_format=output_format,
85
- voice_id=voice_id,
86
- voice_embedding=voice_embedding,
87
- duration=duration,
88
- language=language,
89
- context_id=self._context_id,
90
- add_timestamps=add_timestamps,
91
- continue_=continue_,
92
- flush=flush,
93
- _experimental_voice_controls=_experimental_voice_controls,
94
- )
82
+ assert self._websocket.websocket is not None, "WebSocket is not connected"
83
+
84
+ request_body = {
85
+ "model_id": model_id,
86
+ "transcript": transcript,
87
+ "output_format": (
88
+ output_format
89
+ if isinstance(output_format, dict)
90
+ else output_format.dict()
91
+ ),
92
+ "voice": (voice if isinstance(voice, dict) else voice.dict()),
93
+ "context_id": self._context_id,
94
+ }
95
+ if context_id is not None:
96
+ request_body["context_id"] = context_id
97
+ if duration is not None:
98
+ request_body["duration"] = duration
99
+ if language is not None:
100
+ request_body["language"] = language
101
+ if stream:
102
+ request_body["stream"] = stream
103
+ if add_timestamps:
104
+ request_body["add_timestamps"] = add_timestamps
105
+ if continue_:
106
+ request_body["continue"] = continue_
107
+ if flush:
108
+ request_body["flush"] = flush
109
+
110
+ if (
111
+ "context_id" in request_body
112
+ and request_body["context_id"] is not None
113
+ and request_body["context_id"] != self._context_id
114
+ ):
115
+ raise ValueError(
116
+ "Context ID does not match the context ID of the current context."
117
+ )
118
+ request_body["context_id"] = self._context_id
119
+
120
+ if (
121
+ "continue" in request_body
122
+ and request_body["continue"]
123
+ and request_body["transcript"] == ""
124
+ and ("flush" in request_body and not request_body["flush"])
125
+ ):
126
+ raise ValueError("Transcript cannot be empty when continue_ is True.")
95
127
 
96
128
  await self._websocket.websocket.send_json(request_body)
97
129
 
@@ -103,8 +135,11 @@ class _AsyncTTSContext:
103
135
  await self.send(
104
136
  model_id=DEFAULT_MODEL_ID,
105
137
  transcript="",
106
- output_format=TTS.get_output_format(DEFAULT_OUTPUT_FORMAT),
107
- voice_embedding=DEFAULT_VOICE_EMBEDDING, # Default voice embedding since it's a required input for now.
138
+ output_format=get_output_format(DEFAULT_OUTPUT_FORMAT),
139
+ voice={
140
+ "mode": "embedding",
141
+ "embedding": DEFAULT_VOICE_EMBEDDING,
142
+ },
108
143
  context_id=self._context_id,
109
144
  continue_=False,
110
145
  )
@@ -114,8 +149,11 @@ class _AsyncTTSContext:
114
149
  await self.send(
115
150
  model_id=DEFAULT_MODEL_ID,
116
151
  transcript="",
117
- output_format=TTS.get_output_format(DEFAULT_OUTPUT_FORMAT),
118
- voice_embedding=DEFAULT_VOICE_EMBEDDING, # Default voice embedding since it's a required input for now.
152
+ output_format=get_output_format(DEFAULT_OUTPUT_FORMAT),
153
+ voice={
154
+ "mode": "embedding",
155
+ "embedding": DEFAULT_VOICE_EMBEDDING,
156
+ },
119
157
  context_id=self._context_id,
120
158
  continue_=True,
121
159
  flush=True,
@@ -134,11 +172,23 @@ class _AsyncTTSContext:
134
172
  response = await self._websocket._get_message(
135
173
  self._context_id, timeout=self.timeout, flush_id=flush_id
136
174
  )
137
- if "error" in response:
138
- raise RuntimeError(f"Error generating audio:\n{response['error']}")
139
- if response.get("flush_done") or response["done"]:
175
+ response_obj = typing.cast(
176
+ WebSocketResponse,
177
+ parse_obj_as(
178
+ type_=WebSocketResponse, object_=response # type: ignore
179
+ ),
180
+ )
181
+ if isinstance(response_obj, WebSocketResponse_Error):
182
+ raise RuntimeError(
183
+ f"Error generating audio:\n{response_obj.error}"
184
+ )
185
+ if isinstance(response_obj, WebSocketResponse_Done) or isinstance(
186
+ response_obj, WebSocketResponse_FlushDone
187
+ ):
140
188
  break
141
- yield self._websocket._convert_response(response, include_context_id=True)
189
+ yield self._websocket._convert_response(
190
+ response_obj, include_context_id=True
191
+ )
142
192
  except Exception as e:
143
193
  if isinstance(e, asyncio.TimeoutError):
144
194
  raise RuntimeError("Timeout while waiting for audio chunk")
@@ -146,7 +196,7 @@ class _AsyncTTSContext:
146
196
 
147
197
  return generator
148
198
 
149
- async def receive(self) -> AsyncGenerator[Dict[str, Any], None]:
199
+ async def receive(self) -> AsyncGenerator[WebSocketTtsOutput, None]:
150
200
  """Receive the audio chunks from the WebSocket. This method is a generator that yields audio chunks.
151
201
 
152
202
  Returns:
@@ -157,11 +207,21 @@ class _AsyncTTSContext:
157
207
  response = await self._websocket._get_message(
158
208
  self._context_id, timeout=self.timeout
159
209
  )
160
- if "error" in response:
161
- raise RuntimeError(f"Error generating audio:\n{response['error']}")
162
- if response["done"]:
210
+ response_obj = typing.cast(
211
+ WebSocketResponse,
212
+ parse_obj_as(
213
+ type_=WebSocketResponse, # type: ignore
214
+ object_=response,
215
+ ),
216
+ )
217
+
218
+ if isinstance(response_obj, WebSocketResponse_Error):
219
+ raise RuntimeError(f"Error generating audio:\n{response_obj.error}")
220
+ if isinstance(response_obj, WebSocketResponse_Done):
163
221
  break
164
- yield self._websocket._convert_response(response, include_context_id=True)
222
+ yield self._websocket._convert_response(
223
+ response_obj, include_context_id=True
224
+ )
165
225
  except Exception as e:
166
226
  if isinstance(e, asyncio.TimeoutError):
167
227
  raise RuntimeError("Timeout while waiting for audio chunk")
@@ -192,7 +252,7 @@ class _AsyncTTSContext:
192
252
  self._close()
193
253
 
194
254
 
195
- class _AsyncWebSocket(_WebSocket):
255
+ class AsyncTtsWebsocket(TtsWebsocket):
196
256
  """This class contains methods to generate audio using WebSocket asynchronously."""
197
257
 
198
258
  def __init__(
@@ -200,8 +260,8 @@ class _AsyncWebSocket(_WebSocket):
200
260
  ws_url: str,
201
261
  api_key: str,
202
262
  cartesia_version: str,
203
- timeout: float,
204
263
  get_session: Callable[[], Optional[aiohttp.ClientSession]],
264
+ timeout: float = 30,
205
265
  ):
206
266
  """
207
267
  Args:
@@ -216,7 +276,7 @@ class _AsyncWebSocket(_WebSocket):
216
276
  self._get_session = get_session
217
277
  self.websocket = None
218
278
  self._context_queues: Dict[str, List[asyncio.Queue]] = {}
219
- self._processing_task: asyncio.Task = None
279
+ self._processing_task: Optional[asyncio.Task] = None
220
280
 
221
281
  def __del__(self):
222
282
  try:
@@ -268,18 +328,17 @@ class _AsyncWebSocket(_WebSocket):
268
328
 
269
329
  async def send(
270
330
  self,
331
+ *,
271
332
  model_id: str,
272
333
  transcript: str,
273
- output_format: OutputFormat,
274
- voice_id: Optional[str] = None,
275
- voice_embedding: Optional[List[float]] = None,
334
+ output_format: OutputFormatParams,
335
+ voice: TtsRequestVoiceSpecifierParams,
276
336
  context_id: Optional[str] = None,
277
337
  duration: Optional[int] = None,
278
338
  language: Optional[str] = None,
279
339
  stream: bool = True,
280
340
  add_timestamps: bool = False,
281
- _experimental_voice_controls: Optional[VoiceControls] = None,
282
- ) -> Union[bytes, AsyncGenerator[bytes, None]]:
341
+ ):
283
342
  """See :meth:`_WebSocket.send` for details."""
284
343
  if context_id is None:
285
344
  context_id = str(uuid.uuid4())
@@ -290,14 +349,12 @@ class _AsyncWebSocket(_WebSocket):
290
349
  model_id=model_id,
291
350
  transcript=transcript,
292
351
  output_format=output_format,
293
- voice_id=voice_id,
294
- voice_embedding=voice_embedding,
352
+ voice=voice,
295
353
  context_id=context_id,
296
354
  duration=duration,
297
355
  language=language,
298
356
  continue_=False,
299
357
  add_timestamps=add_timestamps,
300
- _experimental_voice_controls=_experimental_voice_controls,
301
358
  )
302
359
 
303
360
  generator = ctx.receive()
@@ -305,18 +362,32 @@ class _AsyncWebSocket(_WebSocket):
305
362
  if stream:
306
363
  return generator
307
364
 
308
- chunks = []
309
- word_timestamps = defaultdict(list)
365
+ chunks: typing.List[str] = []
366
+ words: typing.List[str] = []
367
+ start: typing.List[float] = []
368
+ end: typing.List[float] = []
310
369
  async for chunk in generator:
311
- if "audio" in chunk:
312
- chunks.append(chunk["audio"])
313
- if add_timestamps and "word_timestamps" in chunk:
314
- for k, v in chunk["word_timestamps"].items():
315
- word_timestamps[k].extend(v)
316
- out = {"audio": b"".join(chunks), "context_id": context_id}
317
- if add_timestamps:
318
- out["word_timestamps"] = word_timestamps
319
- return out
370
+ if chunk.audio is not None:
371
+ chunks.append(chunk.audio)
372
+ if add_timestamps and chunk.word_timestamps is not None:
373
+ if chunk.word_timestamps is not None:
374
+ words.extend(chunk.word_timestamps.words)
375
+ start.extend(chunk.word_timestamps.start)
376
+ end.extend(chunk.word_timestamps.end)
377
+
378
+ return WebSocketTtsOutput(
379
+ audio=b"".join(chunks), # type: ignore
380
+ context_id=context_id,
381
+ word_timestamps=(
382
+ WordTimestamps(
383
+ words=words,
384
+ start=start,
385
+ end=end,
386
+ )
387
+ if add_timestamps
388
+ else None
389
+ ),
390
+ )
320
391
 
321
392
  async def _process_responses(self):
322
393
  try:
@@ -332,12 +403,14 @@ class _AsyncWebSocket(_WebSocket):
332
403
  raise e
333
404
 
334
405
  async def _get_message(
335
- self, context_id: str, timeout: float, flush_id: Optional[int] = -1
406
+ self, context_id: str, timeout: float, flush_id: int = -1
336
407
  ) -> Dict[str, Any]:
337
408
  if context_id not in self._context_queues:
338
409
  raise ValueError(f"Context ID {context_id} not found.")
339
410
  if len(self._context_queues[context_id]) <= flush_id:
340
- raise ValueError(f"Flush ID {flush_id} not found for context ID {context_id}.")
411
+ raise ValueError(
412
+ f"Flush ID {flush_id} not found for context ID {context_id}."
413
+ )
341
414
  return await asyncio.wait_for(
342
415
  self._context_queues[context_id][flush_id].get(), timeout=timeout
343
416
  )
@@ -350,9 +423,11 @@ class _AsyncWebSocket(_WebSocket):
350
423
  if self._processing_task is None or self._processing_task.done():
351
424
  self._processing_task = asyncio.create_task(self._process_responses())
352
425
 
353
- def context(self, context_id: Optional[str] = None) -> _AsyncTTSContext:
426
+ def context(self, context_id: Optional[str] = None):
354
427
  if context_id in self._context_queues:
355
- raise ValueError(f"AsyncContext for context ID {context_id} already exists.")
428
+ raise ValueError(
429
+ f"AsyncContext for context ID {context_id} already exists."
430
+ )
356
431
  if context_id is None:
357
432
  context_id = str(uuid.uuid4())
358
433
  if context_id not in self._context_queues: