cartesia 1.3.1__py3-none-any.whl → 2.0.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cartesia/__init__.py +288 -3
- cartesia/api_status/__init__.py +6 -0
- cartesia/api_status/client.py +104 -0
- cartesia/api_status/requests/__init__.py +5 -0
- cartesia/api_status/requests/api_info.py +8 -0
- cartesia/api_status/types/__init__.py +5 -0
- cartesia/api_status/types/api_info.py +20 -0
- cartesia/base_client.py +160 -0
- cartesia/client.py +163 -40
- cartesia/core/__init__.py +47 -0
- cartesia/core/api_error.py +15 -0
- cartesia/core/client_wrapper.py +55 -0
- cartesia/core/datetime_utils.py +28 -0
- cartesia/core/file.py +67 -0
- cartesia/core/http_client.py +499 -0
- cartesia/core/jsonable_encoder.py +101 -0
- cartesia/core/pydantic_utilities.py +296 -0
- cartesia/core/query_encoder.py +58 -0
- cartesia/core/remove_none_from_dict.py +11 -0
- cartesia/core/request_options.py +35 -0
- cartesia/core/serialization.py +272 -0
- cartesia/datasets/__init__.py +24 -0
- cartesia/datasets/client.py +422 -0
- cartesia/datasets/requests/__init__.py +15 -0
- cartesia/datasets/requests/create_dataset_request.py +7 -0
- cartesia/datasets/requests/dataset.py +9 -0
- cartesia/datasets/requests/dataset_file.py +9 -0
- cartesia/datasets/requests/paginated_dataset_files.py +10 -0
- cartesia/datasets/requests/paginated_datasets.py +10 -0
- cartesia/datasets/types/__init__.py +17 -0
- cartesia/datasets/types/create_dataset_request.py +19 -0
- cartesia/datasets/types/dataset.py +21 -0
- cartesia/datasets/types/dataset_file.py +21 -0
- cartesia/datasets/types/file_purpose.py +5 -0
- cartesia/datasets/types/paginated_dataset_files.py +21 -0
- cartesia/datasets/types/paginated_datasets.py +21 -0
- cartesia/embedding/__init__.py +5 -0
- cartesia/embedding/types/__init__.py +5 -0
- cartesia/embedding/types/embedding.py +201 -0
- cartesia/environment.py +7 -0
- cartesia/infill/__init__.py +2 -0
- cartesia/infill/client.py +294 -0
- cartesia/tts/__init__.py +167 -0
- cartesia/{_async_websocket.py → tts/_async_websocket.py} +159 -84
- cartesia/tts/_websocket.py +430 -0
- cartesia/tts/client.py +407 -0
- cartesia/tts/requests/__init__.py +76 -0
- cartesia/tts/requests/cancel_context_request.py +17 -0
- cartesia/tts/requests/controls.py +11 -0
- cartesia/tts/requests/generation_request.py +53 -0
- cartesia/tts/requests/mp_3_output_format.py +11 -0
- cartesia/tts/requests/output_format.py +30 -0
- cartesia/tts/requests/phoneme_timestamps.py +10 -0
- cartesia/tts/requests/raw_output_format.py +11 -0
- cartesia/tts/requests/speed.py +7 -0
- cartesia/tts/requests/tts_request.py +24 -0
- cartesia/tts/requests/tts_request_embedding_specifier.py +16 -0
- cartesia/tts/requests/tts_request_id_specifier.py +16 -0
- cartesia/tts/requests/tts_request_voice_specifier.py +7 -0
- cartesia/tts/requests/wav_output_format.py +7 -0
- cartesia/tts/requests/web_socket_base_response.py +11 -0
- cartesia/tts/requests/web_socket_chunk_response.py +8 -0
- cartesia/tts/requests/web_socket_done_response.py +7 -0
- cartesia/tts/requests/web_socket_error_response.py +7 -0
- cartesia/tts/requests/web_socket_flush_done_response.py +9 -0
- cartesia/tts/requests/web_socket_phoneme_timestamps_response.py +9 -0
- cartesia/tts/requests/web_socket_raw_output_format.py +11 -0
- cartesia/tts/requests/web_socket_request.py +7 -0
- cartesia/tts/requests/web_socket_response.py +69 -0
- cartesia/tts/requests/web_socket_stream_options.py +8 -0
- cartesia/tts/requests/web_socket_timestamps_response.py +9 -0
- cartesia/tts/requests/web_socket_tts_output.py +18 -0
- cartesia/tts/requests/web_socket_tts_request.py +24 -0
- cartesia/tts/requests/word_timestamps.py +10 -0
- cartesia/tts/socket_client.py +302 -0
- cartesia/tts/types/__init__.py +90 -0
- cartesia/tts/types/cancel_context_request.py +28 -0
- cartesia/tts/types/context_id.py +3 -0
- cartesia/tts/types/controls.py +22 -0
- cartesia/tts/types/emotion.py +29 -0
- cartesia/tts/types/flush_id.py +3 -0
- cartesia/tts/types/generation_request.py +66 -0
- cartesia/tts/types/mp_3_output_format.py +23 -0
- cartesia/tts/types/natural_specifier.py +5 -0
- cartesia/tts/types/numerical_specifier.py +3 -0
- cartesia/tts/types/output_format.py +58 -0
- cartesia/tts/types/phoneme_timestamps.py +21 -0
- cartesia/tts/types/raw_encoding.py +5 -0
- cartesia/tts/types/raw_output_format.py +22 -0
- cartesia/tts/types/speed.py +7 -0
- cartesia/tts/types/supported_language.py +7 -0
- cartesia/tts/types/tts_request.py +35 -0
- cartesia/tts/types/tts_request_embedding_specifier.py +27 -0
- cartesia/tts/types/tts_request_id_specifier.py +27 -0
- cartesia/tts/types/tts_request_voice_specifier.py +7 -0
- cartesia/tts/types/wav_output_format.py +17 -0
- cartesia/tts/types/web_socket_base_response.py +22 -0
- cartesia/tts/types/web_socket_chunk_response.py +20 -0
- cartesia/tts/types/web_socket_done_response.py +17 -0
- cartesia/tts/types/web_socket_error_response.py +19 -0
- cartesia/tts/types/web_socket_flush_done_response.py +21 -0
- cartesia/tts/types/web_socket_phoneme_timestamps_response.py +20 -0
- cartesia/tts/types/web_socket_raw_output_format.py +22 -0
- cartesia/tts/types/web_socket_request.py +7 -0
- cartesia/tts/types/web_socket_response.py +124 -0
- cartesia/tts/types/web_socket_stream_options.py +19 -0
- cartesia/tts/types/web_socket_timestamps_response.py +20 -0
- cartesia/tts/types/web_socket_tts_output.py +27 -0
- cartesia/tts/types/web_socket_tts_request.py +36 -0
- cartesia/tts/types/word_timestamps.py +21 -0
- cartesia/tts/utils/tts.py +64 -0
- cartesia/tts/utils/types.py +70 -0
- cartesia/version.py +3 -1
- cartesia/voice_changer/__init__.py +27 -0
- cartesia/voice_changer/client.py +395 -0
- cartesia/voice_changer/requests/__init__.py +15 -0
- cartesia/voice_changer/requests/streaming_response.py +36 -0
- cartesia/voice_changer/types/__init__.py +17 -0
- cartesia/voice_changer/types/output_format_container.py +5 -0
- cartesia/voice_changer/types/streaming_response.py +62 -0
- cartesia/voices/__init__.py +67 -0
- cartesia/voices/client.py +1812 -0
- cartesia/voices/requests/__init__.py +27 -0
- cartesia/voices/requests/create_voice_request.py +21 -0
- cartesia/voices/requests/embedding_response.py +8 -0
- cartesia/voices/requests/embedding_specifier.py +10 -0
- cartesia/voices/requests/id_specifier.py +10 -0
- cartesia/voices/requests/localize_dialect.py +6 -0
- cartesia/voices/requests/localize_voice_request.py +15 -0
- cartesia/voices/requests/mix_voice_specifier.py +7 -0
- cartesia/voices/requests/mix_voices_request.py +9 -0
- cartesia/voices/requests/update_voice_request.py +15 -0
- cartesia/voices/requests/voice.py +39 -0
- cartesia/voices/requests/voice_metadata.py +36 -0
- cartesia/voices/types/__init__.py +41 -0
- cartesia/voices/types/base_voice_id.py +5 -0
- cartesia/voices/types/clone_mode.py +5 -0
- cartesia/voices/types/create_voice_request.py +32 -0
- cartesia/voices/types/embedding_response.py +20 -0
- cartesia/voices/types/embedding_specifier.py +22 -0
- cartesia/voices/types/gender.py +5 -0
- cartesia/voices/types/id_specifier.py +22 -0
- cartesia/voices/types/localize_dialect.py +6 -0
- cartesia/voices/types/localize_english_dialect.py +5 -0
- cartesia/voices/types/localize_target_language.py +7 -0
- cartesia/voices/types/localize_voice_request.py +26 -0
- cartesia/voices/types/mix_voice_specifier.py +7 -0
- cartesia/voices/types/mix_voices_request.py +20 -0
- cartesia/voices/types/update_voice_request.py +27 -0
- cartesia/voices/types/voice.py +50 -0
- cartesia/voices/types/voice_id.py +3 -0
- cartesia/voices/types/voice_metadata.py +48 -0
- cartesia/voices/types/weight.py +3 -0
- cartesia-2.0.0a0.dist-info/METADATA +306 -0
- cartesia-2.0.0a0.dist-info/RECORD +158 -0
- {cartesia-1.3.1.dist-info → cartesia-2.0.0a0.dist-info}/WHEEL +1 -1
- cartesia/_async_sse.py +0 -95
- cartesia/_logger.py +0 -3
- cartesia/_sse.py +0 -143
- cartesia/_types.py +0 -70
- cartesia/_websocket.py +0 -358
- cartesia/async_client.py +0 -82
- cartesia/async_tts.py +0 -63
- cartesia/resource.py +0 -44
- cartesia/tts.py +0 -137
- cartesia/utils/deprecated.py +0 -55
- cartesia/utils/retry.py +0 -87
- cartesia/utils/tts.py +0 -78
- cartesia/voices.py +0 -208
- cartesia-1.3.1.dist-info/METADATA +0 -661
- cartesia-1.3.1.dist-info/RECORD +0 -23
- cartesia-1.3.1.dist-info/licenses/LICENSE.md +0 -21
- /cartesia/{utils/__init__.py → py.typed} +0 -0
- /cartesia/{_constants.py → tts/utils/constants.py} +0 -0
@@ -0,0 +1,430 @@
|
|
1
|
+
import base64
|
2
|
+
import json
|
3
|
+
import typing
|
4
|
+
import uuid
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import Any, Dict, Generator, Optional, Set, Union
|
7
|
+
|
8
|
+
try:
|
9
|
+
from websockets.sync.client import connect
|
10
|
+
|
11
|
+
IS_WEBSOCKET_SYNC_AVAILABLE = True
|
12
|
+
except ImportError:
|
13
|
+
IS_WEBSOCKET_SYNC_AVAILABLE = False
|
14
|
+
|
15
|
+
from iterators import TimeoutIterator # type: ignore
|
16
|
+
|
17
|
+
from cartesia.tts.requests import TtsRequestVoiceSpecifierParams
|
18
|
+
from cartesia.tts.requests.output_format import OutputFormatParams
|
19
|
+
from cartesia.tts.types import (
|
20
|
+
WebSocketResponse,
|
21
|
+
WebSocketResponse_Chunk,
|
22
|
+
WebSocketResponse_Done,
|
23
|
+
WebSocketResponse_Error,
|
24
|
+
WebSocketResponse_FlushDone,
|
25
|
+
WebSocketResponse_PhonemeTimestamps,
|
26
|
+
WebSocketResponse_Timestamps,
|
27
|
+
WebSocketTtsOutput,
|
28
|
+
WordTimestamps,
|
29
|
+
)
|
30
|
+
|
31
|
+
from ..core.pydantic_utilities import parse_obj_as
|
32
|
+
from .types.generation_request import GenerationRequest
|
33
|
+
|
34
|
+
|
35
|
+
class _TTSContext:
    """Manage a single context over a WebSocket.

    This class can be used to stream inputs, as they become available, to a specific `context_id`. See README for usage.

    See :class:`_AsyncTTSContext` for asynchronous use cases.

    Each TTSContext will close automatically when a done message is received for that context. It also closes if there is an error.
    """

    def __init__(self, context_id: str, websocket: "TtsWebsocket"):
        self._context_id = context_id
        self._websocket = websocket
        self._error = None

    def __del__(self):
        self._close()

    @property
    def context_id(self) -> str:
        return self._context_id

    def send(
        self,
        *,
        model_id: str,
        transcript: Union[str, typing.Iterable[str]],
        output_format: OutputFormatParams,
        voice: TtsRequestVoiceSpecifierParams,
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
        add_timestamps: bool = False,
    ) -> Generator[WebSocketTtsOutput, None, None]:
        """Stream transcript chunks to this context and yield generated outputs.

        Text chunks are forwarded with ``continue=True`` as soon as the input
        iterator produces them, while responses are polled in between so audio
        is yielded before the input is exhausted. When the input is exhausted a
        final empty message with ``continue=False`` is sent and the remaining
        responses are drained until the server reports the context as done.

        Args:
            model_id: ID of the model to use.
            transcript: The text to synthesize, either as an iterable of text
                chunks (e.g. a generator fed by an LLM) or as a single string.
                A plain string is sent as one chunk.
            output_format: Output audio format.
            voice: Voice specifier.
            context_id: Optional; if given it must equal this context's ID.
            duration: Optional duration, only sent when not None.
            language: Optional language, only sent when not None.
            stream: Sent as ``stream`` only when truthy.
            add_timestamps: Sent as ``add_timestamps`` only when truthy.

        Yields:
            ``WebSocketTtsOutput`` objects for this context, e.g. with ``audio``
            (bytes) and ``context_id`` set.

        Raises:
            ValueError: If the provided context_id doesn't match the current context.
            RuntimeError: If there's an error generating audio. The underlying
                WebSocket is closed in that case.
        """
        # Validate before touching the network: a mismatching ID is a caller error.
        if context_id is not None and context_id != self._context_id:
            raise ValueError(
                "Context ID does not match the context ID of the current context."
            )

        self._websocket.connect()
        assert self._websocket.websocket is not None, "WebSocket is not connected"

        request_body: Dict[str, Any] = {
            "model_id": model_id,
            "transcript": "",  # replaced by each chunk before sending
            "output_format": output_format,
            "voice": voice,
            "context_id": self._context_id,
        }
        if duration is not None:
            request_body["duration"] = duration
        if language is not None:
            request_body["language"] = language
        if stream:
            request_body["stream"] = stream
        if add_timestamps:
            request_body["add_timestamps"] = add_timestamps

        # Iterating a plain str would yield single characters, so treat it as one chunk.
        chunks: typing.Iterable[str] = (
            [transcript] if isinstance(transcript, str) else transcript
        )

        try:
            # TimeoutIterator returns its sentinel when no chunk is ready within
            # the timeout, which lets us interleave sending and receiving.
            text_iterator = TimeoutIterator(chunks, timeout=0.001)
            sentinel = text_iterator.get_sentinel()
            next_chunk = next(text_iterator, None)

            while True:
                # Send the next text chunk to the WebSocket if available.
                if next_chunk is not None and next_chunk is not sentinel:
                    request_body["transcript"] = next_chunk
                    request_body["continue"] = True
                    self._websocket.websocket.send(json.dumps(request_body))
                    next_chunk = next(text_iterator, None)

                # Poll once with a small timeout so audio flows while input streams.
                done = yield from self._receive_once(timeout=0.001, include_all=False)
                if done:
                    return

                # Keep receiving until the next text chunk is available.
                while next_chunk is sentinel:
                    done = yield from self._receive_once(
                        timeout=0.001, include_all=False
                    )
                    if done:
                        return
                    next_chunk = next(text_iterator, None)

                # Send final message if all input text chunks are exhausted.
                if next_chunk is None:
                    request_body["transcript"] = ""
                    request_body["continue"] = False
                    self._websocket.websocket.send(json.dumps(request_body))
                    break

            # Receive remaining messages (blocking) until "done" is received.
            done = False
            while not done:
                done = yield from self._receive_once(timeout=None, include_all=True)

        except Exception as e:
            self._websocket.close()
            raise RuntimeError(f"Failed to generate audio. {e}") from e

    def _receive_once(
        self, *, timeout: Optional[float], include_all: bool
    ) -> Generator[WebSocketTtsOutput, None, bool]:
        """Receive at most one message for this context and yield its output.

        Args:
            timeout: Seconds to wait for a message; ``None`` blocks until one arrives.
            include_all: If False, only audio chunks, word timestamps and phoneme
                timestamps are yielded and other message kinds are dropped. If
                True, every non-terminal message for this context is yielded.

        Returns:
            True if a done message for this context was received (the context is
            closed at that point). False on timeout, when the message belongs to
            a different context, or after yielding/dropping a regular message.

        Raises:
            RuntimeError: If the server sent an error message for this context.
        """
        assert self._websocket.websocket is not None, "WebSocket is not connected"
        try:
            raw = self._websocket.websocket.recv(timeout=timeout)
        except TimeoutError:
            return False

        response_obj = typing.cast(
            WebSocketResponse,
            parse_obj_as(
                type_=WebSocketResponse,  # type: ignore
                object_=json.loads(raw),
            ),
        )
        # Several contexts can share one socket; ignore messages for other contexts.
        if response_obj.context_id != self._context_id:
            return False
        if isinstance(response_obj, WebSocketResponse_Error):
            raise RuntimeError(f"Error generating audio:\n{response_obj.error}")
        if isinstance(response_obj, WebSocketResponse_Done):
            self._close()
            return True
        if include_all or isinstance(
            response_obj,
            (
                WebSocketResponse_Chunk,
                WebSocketResponse_Timestamps,
                WebSocketResponse_PhonemeTimestamps,
            ),
        ):
            yield self._websocket._convert_response(
                response_obj, include_context_id=True
            )
        return False

    def _close(self):
        """Closes the context. Automatically called when a done message is received for this context."""
        self._websocket._remove_context(self._context_id)

    def is_closed(self):
        """Check if the context is closed or not. Returns True if closed."""
        return self._context_id not in self._websocket._contexts
|
230
|
+
|
231
|
+
|
232
|
+
class TtsWebsocket:
    """This class contains methods to generate audio using WebSocket. Ideal for low-latency audio generation.

    Usage:
        >>> ws = client.tts.websocket()
        >>> for output in ws.send(
        ...     model_id="sonic-english",
        ...     transcript="Hello world!",
        ...     voice={"mode": "id", "id": voice_id},
        ...     output_format={"container": "raw", "encoding": "pcm_f32le", "sample_rate": 44100},
        ...     stream=True,
        ... ):
        ...     audio = output.audio
    """

    def __init__(
        self,
        ws_url: str,
        api_key: str,
        cartesia_version: str,
    ):
        self.ws_url = ws_url
        self.api_key = api_key
        self.cartesia_version = cartesia_version
        # Lazily created by connect(); None until the first request.
        self.websocket = None
        # IDs of contexts created via context() that have not been closed yet.
        self._contexts: Set[str] = set()

    def __del__(self):
        try:
            self.close()
        except Exception as e:
            raise RuntimeError(f"Failed to close WebSocket: {e}") from e

    def connect(self):
        """This method connects to the WebSocket if it is not already connected.

        Raises:
            ImportError: If the synchronous websockets client is not installed.
            RuntimeError: If the connection to the WebSocket fails.
        """
        if not IS_WEBSOCKET_SYNC_AVAILABLE:
            raise ImportError(
                "The synchronous WebSocket client is not available. Please ensure that you have 'websockets>=12.0' or compatible version installed."
            )
        if self.websocket is None or self._is_websocket_closed():
            from urllib.parse import urlencode

            route = "tts/websocket"
            # Encode the credentials so special characters cannot corrupt the query string.
            query = urlencode(
                {"api_key": self.api_key, "cartesia_version": self.cartesia_version}
            )
            try:
                self.websocket = connect(f"{self.ws_url}/{route}?{query}")
            except Exception as e:
                raise RuntimeError(f"Failed to connect to WebSocket. {e}") from e

    def _is_websocket_closed(self):
        # A closed socket reports a file descriptor of -1.
        return self.websocket.socket.fileno() == -1

    def close(self):
        """This method closes the WebSocket connection. *Highly* recommended to call this method when done using the WebSocket."""
        if self.websocket and not self._is_websocket_closed():
            self.websocket.close()

        if self._contexts:
            self._contexts.clear()

    def _convert_response(
        self,
        response: typing.Union[
            WebSocketResponse_Chunk,
            WebSocketResponse_Timestamps,
            WebSocketResponse_PhonemeTimestamps,
            WebSocketResponse_FlushDone,
        ],
        include_context_id: bool,
        include_flush_id: bool = False,
    ) -> WebSocketTtsOutput:
        """Convert a parsed server response into a ``WebSocketTtsOutput``.

        Audio chunks are base64-decoded into bytes. Word and phoneme timestamps
        are copied over. Flush fields are only copied when ``include_flush_id``
        is set. The context ID is copied when requested and present.
        """
        out: Dict[str, Any] = {}
        if isinstance(response, WebSocketResponse_Chunk):
            out["audio"] = base64.b64decode(response.data)
        elif isinstance(response, WebSocketResponse_Timestamps):
            out["word_timestamps"] = response.word_timestamps  # type: ignore
        elif isinstance(response, WebSocketResponse_PhonemeTimestamps):
            out["phoneme_timestamps"] = response.phoneme_timestamps  # type: ignore
        elif include_flush_id and isinstance(response, WebSocketResponse_FlushDone):
            out["flush_done"] = response.flush_done  # type: ignore
            out["flush_id"] = response.flush_id  # type: ignore

        if include_context_id and response.context_id:
            out["context_id"] = response.context_id  # type: ignore

        return WebSocketTtsOutput(**out)  # type: ignore

    def send(
        self,
        *,
        model_id: str,
        transcript: str,
        output_format: OutputFormatParams,
        voice: TtsRequestVoiceSpecifierParams,
        context_id: Optional[str] = None,
        duration: Optional[int] = None,
        language: Optional[str] = None,
        stream: bool = True,
        add_timestamps: bool = False,
    ) -> Union[WebSocketTtsOutput, Generator[WebSocketTtsOutput, None, None]]:
        """Send a request to the WebSocket to generate audio.

        Args:
            model_id: ID of the model to use.
            transcript: The text to synthesize.
            output_format: Output audio format.
            voice: Voice specifier.
            context_id: Context ID for the request; a random UUID4 is used if None.
            duration: Optional duration, only sent when not None.
            language: Optional language, only sent when not None.
            stream: Whether to stream the audio or not.
            add_timestamps: Whether to request word timestamps.

        Returns:
            If `stream` is True, a generator that yields ``WebSocketTtsOutput`` chunks.
            If `stream` is False, a single ``WebSocketTtsOutput`` with all audio
            concatenated and, when `add_timestamps` is set, the merged word timestamps.
            Outputs expose e.g. ``audio`` (bytes) and ``context_id``.
        """
        self.connect()

        if context_id is None:
            context_id = str(uuid.uuid4())

        request_body: Dict[str, Any] = {
            "model_id": model_id,
            "transcript": transcript,
            "output_format": output_format,
            "voice": voice,
            "context_id": context_id,
            "stream": stream,
            "add_timestamps": add_timestamps,
        }
        # Omit unset optional fields instead of sending explicit nulls.
        if duration is not None:
            request_body["duration"] = duration
        if language is not None:
            request_body["language"] = language

        generator = self._websocket_generator(request_body)

        if stream:
            return generator

        audio_chunks: typing.List[bytes] = []
        words: typing.List[str] = []
        start: typing.List[float] = []
        end: typing.List[float] = []
        for chunk in generator:
            if chunk.audio is not None:
                audio_chunks.append(chunk.audio)
            if add_timestamps and chunk.word_timestamps is not None:
                words.extend(chunk.word_timestamps.words)
                start.extend(chunk.word_timestamps.start)
                end.extend(chunk.word_timestamps.end)

        return WebSocketTtsOutput(
            audio=b"".join(audio_chunks),  # type: ignore
            context_id=context_id,
            word_timestamps=(
                WordTimestamps(
                    words=words,
                    start=start,
                    end=end,
                )
                if add_timestamps
                else None
            ),
        )

    def _websocket_generator(
        self, request_body: Dict[str, Any]
    ) -> Generator[WebSocketTtsOutput, None, None]:
        """Send ``request_body`` and yield converted responses until done.

        Raises:
            RuntimeError: On a server error or any failure while receiving. The
                WebSocket is closed before re-raising.
        """
        assert self.websocket is not None, "WebSocket is not connected"
        self.websocket.send(json.dumps(request_body))
        expected_context_id = request_body.get("context_id")

        try:
            while True:
                response_obj = typing.cast(
                    WebSocketResponse,
                    parse_obj_as(
                        type_=WebSocketResponse,  # type: ignore
                        object_=json.loads(self.websocket.recv()),
                    ),
                )
                # Skip messages explicitly addressed to another context, e.g.
                # leftovers from a previously abandoned stream on this socket.
                if (
                    response_obj.context_id is not None
                    and response_obj.context_id != expected_context_id
                ):
                    continue
                if isinstance(response_obj, WebSocketResponse_Error):
                    raise RuntimeError(f"Error generating audio:\n{response_obj.error}")
                if isinstance(response_obj, WebSocketResponse_Done):
                    break
                yield self._convert_response(response_obj, include_context_id=True)
        except Exception as e:
            # Close the websocket connection if an error occurs.
            self.close()
            raise RuntimeError(f"Failed to generate audio. {e}") from e

    def _remove_context(self, context_id: str):
        """Forget ``context_id``; a no-op if it is not tracked."""
        self._contexts.discard(context_id)

    def context(self, context_id: Optional[str] = None) -> "_TTSContext":
        """Create a new context on this WebSocket.

        Args:
            context_id: ID for the new context; a random UUID4 is used if None.

        Raises:
            ValueError: If a context with ``context_id`` is already open.
        """
        if context_id in self._contexts:
            raise ValueError(f"Context for context ID {context_id} already exists.")
        if context_id is None:
            context_id = str(uuid.uuid4())
        self._contexts.add(context_id)
        return _TTSContext(context_id, self)
|