livekit-plugins-cartesia 0.2.0.dev2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -19,13 +19,12 @@ __all__ = ["TTS", "ChunkedStream", "__version__"]
19
19
 
20
20
  from livekit.agents import Plugin
21
21
 
22
+ from .log import logger
23
+
22
24
 
23
25
  class CartesiaPlugin(Plugin):
24
26
  def __init__(self):
25
- super().__init__(__name__, __version__, __package__)
26
-
27
- def download_files(self):
28
- pass
27
+ super().__init__(__name__, __version__, __package__, logger)
29
28
 
30
29
 
31
30
  Plugin.register_plugin(CartesiaPlugin())
@@ -11,4 +11,4 @@ TTSEncoding = Literal[
11
11
 
12
12
  TTSModels = Literal["sonic-english", "sonic-multilingual"]
13
13
  TTSLanguages = Literal["en", "es", "fr", "de", "pt", "zh", "ja"]
14
- TTSDefaultVoiceId = "248be419-c632-4f23-adf1-5324ed7dbf1d"
14
+ TTSDefaultVoiceId = "b7d50908-b17c-442d-ad8d-810c63997ed9"
@@ -14,11 +14,14 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import asyncio
18
+ import base64
19
+ import json
17
20
  import os
18
21
  from dataclasses import dataclass
19
22
 
20
23
  import aiohttp
21
- from livekit.agents import tts, utils
24
+ from livekit.agents import tokenize, tts, utils
22
25
 
23
26
  from .log import logger
24
27
  from .models import TTSDefaultVoiceId, TTSEncoding, TTSModels
@@ -27,6 +30,9 @@ API_AUTH_HEADER = "X-API-Key"
27
30
  API_VERSION_HEADER = "Cartesia-Version"
28
31
  API_VERSION = "2024-06-10"
29
32
 
33
+ NUM_CHANNELS = 1
34
+ BUFFERED_WORDS_COUNT = 8
35
+
30
36
 
31
37
  @dataclass
32
38
  class _TTSOptions:
@@ -36,6 +42,7 @@ class _TTSOptions:
36
42
  voice: str | list[float]
37
43
  api_key: str
38
44
  language: str
45
+ word_tokenizer: tokenize.WordTokenizer
39
46
 
40
47
 
41
48
  class TTS(tts.TTS):
@@ -49,11 +56,14 @@ class TTS(tts.TTS):
49
56
  sample_rate: int = 24000,
50
57
  api_key: str | None = None,
51
58
  http_session: aiohttp.ClientSession | None = None,
59
+ word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
60
+ ignore_punctuation=False
61
+ ),
52
62
  ) -> None:
53
63
  super().__init__(
54
- capabilities=tts.TTSCapabilities(streaming=False),
64
+ capabilities=tts.TTSCapabilities(streaming=True),
55
65
  sample_rate=sample_rate,
56
- num_channels=1,
66
+ num_channels=NUM_CHANNELS,
57
67
  )
58
68
 
59
69
  api_key = api_key or os.environ.get("CARTESIA_API_KEY")
@@ -67,6 +77,7 @@ class TTS(tts.TTS):
67
77
  sample_rate=sample_rate,
68
78
  voice=voice,
69
79
  api_key=api_key,
80
+ word_tokenizer=word_tokenizer,
70
81
  )
71
82
  self._session = http_session
72
83
 
@@ -79,8 +90,13 @@ class TTS(tts.TTS):
79
90
  def synthesize(self, text: str) -> "ChunkedStream":
80
91
  return ChunkedStream(text, self._opts, self._ensure_session())
81
92
 
93
+ def stream(self) -> "SynthesizeStream":
94
+ return SynthesizeStream(self._opts, self._ensure_session())
95
+
82
96
 
83
97
  class ChunkedStream(tts.ChunkedStream):
98
+ """Synthesize chunked text using the bytes endpoint"""
99
+
84
100
  def __init__(
85
101
  self, text: str, opts: _TTSOptions, session: aiohttp.ClientSession
86
102
  ) -> None:
@@ -90,35 +106,17 @@ class ChunkedStream(tts.ChunkedStream):
90
106
  @utils.log_exceptions(logger=logger)
91
107
  async def _main_task(self):
92
108
  bstream = utils.audio.AudioByteStream(
93
- sample_rate=self._opts.sample_rate, num_channels=1
109
+ sample_rate=self._opts.sample_rate, num_channels=NUM_CHANNELS
94
110
  )
95
- request_id = utils.shortuuid()
96
- segment_id = utils.shortuuid()
97
-
98
- voice = {}
99
- if isinstance(self._opts.voice, str):
100
- voice["mode"] = "id"
101
- voice["id"] = self._opts.voice
102
- else:
103
- voice["mode"] = "embedding"
104
- voice["embedding"] = self._opts.voice
105
-
106
- data = {
107
- "model_id": self._opts.model,
108
- "transcript": self._text,
109
- "voice": voice,
110
- "output_format": {
111
- "container": "raw",
112
- "encoding": self._opts.encoding,
113
- "sample_rate": self._opts.sample_rate,
114
- },
115
- "language": self._opts.language,
116
- }
111
+ request_id, segment_id = utils.shortuuid(), utils.shortuuid()
112
+
113
+ data = _to_cartesia_options(self._opts)
114
+ data["transcript"] = self._text
117
115
 
118
116
  async with self._session.post(
119
117
  "https://api.cartesia.ai/tts/bytes",
120
118
  headers={
121
- API_AUTH_HEADER: f"{self._opts.api_key}",
119
+ API_AUTH_HEADER: self._opts.api_key,
122
120
  API_VERSION_HEADER: API_VERSION,
123
121
  },
124
122
  json=data,
@@ -137,3 +135,169 @@ class ChunkedStream(tts.ChunkedStream):
137
135
  request_id=request_id, segment_id=segment_id, frame=frame
138
136
  )
139
137
  )
138
+
139
+
140
+ class SynthesizeStream(tts.SynthesizeStream):
141
+ def __init__(
142
+ self,
143
+ opts: _TTSOptions,
144
+ session: aiohttp.ClientSession,
145
+ ):
146
+ super().__init__()
147
+ self._opts, self._session = opts, session
148
+ self._buf = ""
149
+
150
+ @utils.log_exceptions(logger=logger)
151
+ async def _main_task(self) -> None:
152
+ retry_count = 0
153
+ max_retry = 3
154
+ while self._input_ch.qsize() or not self._input_ch.closed:
155
+ try:
156
+ url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
157
+ ws = await self._session.ws_connect(url)
158
+ retry_count = 0 # connected successfully, reset the retry_count
159
+
160
+ await self._run_ws(ws)
161
+ except Exception as e:
162
+ if retry_count >= max_retry:
163
+ logger.exception(
164
+ f"failed to connect to Cartesia after {max_retry} tries"
165
+ )
166
+ break
167
+
168
+ retry_delay = min(retry_count * 2, 10) # max 10s
169
+ retry_count += 1
170
+
171
+ logger.warning(
172
+ f"Cartesia connection failed, retrying in {retry_delay}s",
173
+ exc_info=e,
174
+ )
175
+ await asyncio.sleep(retry_delay)
176
+
177
+ async def _run_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
178
+ request_id = utils.shortuuid()
179
+ pending_segments = []
180
+
181
+ async def send_task():
182
+ base_pkt = _to_cartesia_options(self._opts)
183
+
184
+ def _new_segment():
185
+ segment_id = utils.shortuuid()
186
+ pending_segments.append(segment_id)
187
+ return segment_id
188
+
189
+ current_segment_id: str | None = _new_segment()
190
+
191
+ async for data in self._input_ch:
192
+ if isinstance(data, self._FlushSentinel):
193
+ if current_segment_id is None:
194
+ continue
195
+
196
+ end_pkt = base_pkt.copy()
197
+ end_pkt["context_id"] = current_segment_id
198
+ end_pkt["transcript"] = self._buf + " "
199
+ end_pkt["continue"] = False
200
+ await ws.send_str(json.dumps(end_pkt))
201
+
202
+ current_segment_id = None
203
+ self._buf = ""
204
+ elif data:
205
+ if current_segment_id is None:
206
+ current_segment_id = _new_segment()
207
+
208
+ self._buf += data
209
+ words = self._opts.word_tokenizer.tokenize(text=self._buf)
210
+ if len(words) < BUFFERED_WORDS_COUNT + 1:
211
+ continue
212
+
213
+ data = self._opts.word_tokenizer.format_words(words[:-1]) + " "
214
+ self._buf = words[-1]
215
+
216
+ token_pkt = base_pkt.copy()
217
+ token_pkt["context_id"] = current_segment_id
218
+ token_pkt["transcript"] = data
219
+ token_pkt["continue"] = True
220
+ await ws.send_str(json.dumps(token_pkt))
221
+
222
+ if len(pending_segments) == 0:
223
+ await ws.close()
224
+
225
+ async def recv_task():
226
+ audio_bstream = utils.audio.AudioByteStream(
227
+ sample_rate=self._opts.sample_rate,
228
+ num_channels=NUM_CHANNELS,
229
+ )
230
+
231
+ while True:
232
+ msg = await ws.receive()
233
+ if msg.type in (
234
+ aiohttp.WSMsgType.CLOSED,
235
+ aiohttp.WSMsgType.CLOSE,
236
+ aiohttp.WSMsgType.CLOSING,
237
+ ):
238
+ raise Exception("Cartesia connection closed unexpectedly")
239
+
240
+ if msg.type != aiohttp.WSMsgType.TEXT:
241
+ logger.warning("unexpected Cartesia message type %s", msg.type)
242
+ continue
243
+
244
+ data = json.loads(msg.data)
245
+ segment_id = data.get("context_id")
246
+ if data.get("data"):
247
+ b64data = base64.b64decode(data["data"])
248
+ for frame in audio_bstream.write(b64data):
249
+ self._event_ch.send_nowait(
250
+ tts.SynthesizedAudio(
251
+ request_id=request_id,
252
+ segment_id=segment_id,
253
+ frame=frame,
254
+ )
255
+ )
256
+ elif data.get("done"):
257
+ for frame in audio_bstream.flush():
258
+ self._event_ch.send_nowait(
259
+ tts.SynthesizedAudio(
260
+ request_id=request_id,
261
+ segment_id=segment_id,
262
+ frame=frame,
263
+ )
264
+ )
265
+
266
+ pending_segments.remove(segment_id)
267
+ if len(pending_segments) == 0 and self._input_ch.closed:
268
+ # we're not going to receive more frames, close the connection
269
+ await ws.close()
270
+ break
271
+ else:
272
+ logger.error("unexpected Cartesia message %s", data)
273
+
274
+ tasks = [
275
+ asyncio.create_task(send_task()),
276
+ asyncio.create_task(recv_task()),
277
+ ]
278
+
279
+ try:
280
+ await asyncio.gather(*tasks)
281
+ finally:
282
+ await utils.aio.gracefully_cancel(*tasks)
283
+
284
+
285
+ def _to_cartesia_options(opts: _TTSOptions) -> dict:
286
+ voice: dict = {}
287
+ if isinstance(opts.voice, str):
288
+ voice["mode"] = "id"
289
+ voice["id"] = opts.voice
290
+ else:
291
+ voice["mode"] = "embedding"
292
+ voice["embedding"] = opts.voice
293
+
294
+ return {
295
+ "model_id": opts.model,
296
+ "voice": voice,
297
+ "output_format": {
298
+ "container": "raw",
299
+ "encoding": opts.encoding,
300
+ "sample_rate": opts.sample_rate,
301
+ },
302
+ "language": opts.language,
303
+ }
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "0.2.0-dev.2"
15
+ __version__ = "0.3.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: livekit-plugins-cartesia
3
- Version: 0.2.0.dev2
3
+ Version: 0.3.0
4
4
  Summary: LiveKit Agents Plugin for Cartesia
5
5
  Home-page: https://github.com/livekit/agents
6
6
  License: Apache-2.0
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: 3.10
19
19
  Classifier: Programming Language :: Python :: 3 :: Only
20
20
  Requires-Python: >=3.9.0
21
21
  Description-Content-Type: text/markdown
22
- Requires-Dist: livekit-agents ~=0.7
22
+ Requires-Dist: livekit-agents >=0.8.0.dev0
23
23
 
24
24
  # LiveKit Plugins Cartesia
25
25
 
@@ -0,0 +1,10 @@
1
+ livekit/plugins/cartesia/__init__.py,sha256=BUfWY_evL5dUHn9hBDQVor6ssctDKQfbQfZy5SWndN8,926
2
+ livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
3
+ livekit/plugins/cartesia/models.py,sha256=T1iPQ18h4-o5rgSW236PDc73qp5zR9k4r_qNCl3XPWc,335
4
+ livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ livekit/plugins/cartesia/tts.py,sha256=uklD9fIYL8QWUSiyypFDgflkie9VhTu1C-x4YwJcDCU,10283
6
+ livekit/plugins/cartesia/version.py,sha256=G5iYozum4q7UpHwW43F7QfhzUfwcncPxBZ0gmUGsd5I,600
7
+ livekit_plugins_cartesia-0.3.0.dist-info/METADATA,sha256=iJcOyrkQ-0yPK_lYtR-eEbIDav84xlN7DUvwncx7OpQ,1252
8
+ livekit_plugins_cartesia-0.3.0.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
9
+ livekit_plugins_cartesia-0.3.0.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
10
+ livekit_plugins_cartesia-0.3.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.1.0)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,10 +0,0 @@
1
- livekit/plugins/cartesia/__init__.py,sha256=_a8u7qqya1pjZTV19gNOpMKTO7ccAVZAeCukiDKAG-U,937
2
- livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
3
- livekit/plugins/cartesia/models.py,sha256=06S-Z-M90kB-kEOQsQk70xfQUD-TztU4ZIU_AfAyUMc,335
4
- livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- livekit/plugins/cartesia/tts.py,sha256=S5BMSVtsbNI_c2PpgyFK6wvleudmJZLTUt3ZmGNKlRI,4319
6
- livekit/plugins/cartesia/version.py,sha256=j435tq2PUytJ_b4XnAN8SFy_5GzOk4LHUz-zUdDCYIE,606
7
- livekit_plugins_cartesia-0.2.0.dev2.dist-info/METADATA,sha256=Y63CUmdnZq9MrEXokSN6OkhFvnwGX5FAmPLhwhV9920,1250
8
- livekit_plugins_cartesia-0.2.0.dev2.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
9
- livekit_plugins_cartesia-0.2.0.dev2.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
10
- livekit_plugins_cartesia-0.2.0.dev2.dist-info/RECORD,,