livekit-plugins-cartesia 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,13 +18,12 @@ import asyncio
18
18
  import base64
19
19
  import json
20
20
  import os
21
+ import weakref
21
22
  from dataclasses import dataclass
22
- from typing import Any
23
+ from typing import Any, Optional
23
24
 
24
25
  import aiohttp
25
- from livekit import rtc
26
26
  from livekit.agents import (
27
- DEFAULT_API_CONNECT_OPTIONS,
28
27
  APIConnectionError,
29
28
  APIConnectOptions,
30
29
  APIStatusError,
@@ -61,6 +60,13 @@ class _TTSOptions:
61
60
  emotion: list[TTSVoiceEmotion | str] | None
62
61
  api_key: str
63
62
  language: str
63
+ base_url: str
64
+
65
+ def get_http_url(self, path: str) -> str:
66
+ return f"{self.base_url}{path}"
67
+
68
+ def get_ws_url(self, path: str) -> str:
69
+ return f"{self.base_url.replace('http', 'ws', 1)}{path}"
64
70
 
65
71
 
66
72
  class TTS(tts.TTS):
@@ -76,6 +82,7 @@ class TTS(tts.TTS):
76
82
  sample_rate: int = 24000,
77
83
  api_key: str | None = None,
78
84
  http_session: aiohttp.ClientSession | None = None,
85
+ base_url: str = "https://api.cartesia.ai",
79
86
  ) -> None:
80
87
  """
81
88
  Create a new instance of Cartesia TTS.
@@ -92,6 +99,7 @@ class TTS(tts.TTS):
92
99
  sample_rate (int, optional): The audio sample rate in Hz. Defaults to 24000.
93
100
  api_key (str, optional): The Cartesia API key. If not provided, it will be read from the CARTESIA_API_KEY environment variable.
94
101
  http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
102
+ base_url (str, optional): The base URL for the Cartesia API. Defaults to "https://api.cartesia.ai".
95
103
  """
96
104
 
97
105
  super().__init__(
@@ -113,8 +121,28 @@ class TTS(tts.TTS):
113
121
  speed=speed,
114
122
  emotion=emotion,
115
123
  api_key=api_key,
124
+ base_url=base_url,
116
125
  )
117
126
  self._session = http_session
127
+ self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
128
+ connect_cb=self._connect_ws,
129
+ close_cb=self._close_ws,
130
+ max_session_duration=300,
131
+ mark_refreshed_on_get=True,
132
+ )
133
+ self._streams = weakref.WeakSet[SynthesizeStream]()
134
+
135
+ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
136
+ session = self._ensure_session()
137
+ url = self._opts.get_ws_url(
138
+ f"/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
139
+ )
140
+ return await asyncio.wait_for(
141
+ session.ws_connect(url), self._conn_options.timeout
142
+ )
143
+
144
+ async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
145
+ await ws.close()
118
146
 
119
147
  def _ensure_session(self) -> aiohttp.ClientSession:
120
148
  if not self._session:
@@ -122,6 +150,9 @@ class TTS(tts.TTS):
122
150
 
123
151
  return self._session
124
152
 
153
+ def prewarm(self) -> None:
154
+ self._pool.prewarm()
155
+
125
156
  def update_options(
126
157
  self,
127
158
  *,
@@ -155,7 +186,7 @@ class TTS(tts.TTS):
155
186
  self,
156
187
  text: str,
157
188
  *,
158
- conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
189
+ conn_options: Optional[APIConnectOptions] = None,
159
190
  ) -> ChunkedStream:
160
191
  return ChunkedStream(
161
192
  tts=self,
@@ -166,14 +197,22 @@ class TTS(tts.TTS):
166
197
  )
167
198
 
168
199
  def stream(
169
- self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
200
+ self, *, conn_options: Optional[APIConnectOptions] = None
170
201
  ) -> "SynthesizeStream":
171
- return SynthesizeStream(
202
+ stream = SynthesizeStream(
172
203
  tts=self,
173
- conn_options=conn_options,
204
+ pool=self._pool,
174
205
  opts=self._opts,
175
- session=self._ensure_session(),
176
206
  )
207
+ self._streams.add(stream)
208
+ return stream
209
+
210
+ async def aclose(self) -> None:
211
+ for stream in list(self._streams):
212
+ await stream.aclose()
213
+ self._streams.clear()
214
+ await self._pool.aclose()
215
+ await super().aclose()
177
216
 
178
217
 
179
218
  class ChunkedStream(tts.ChunkedStream):
@@ -184,9 +223,9 @@ class ChunkedStream(tts.ChunkedStream):
184
223
  *,
185
224
  tts: TTS,
186
225
  input_text: str,
187
- conn_options: APIConnectOptions,
188
226
  opts: _TTSOptions,
189
227
  session: aiohttp.ClientSession,
228
+ conn_options: Optional[APIConnectOptions] = None,
190
229
  ) -> None:
191
230
  super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
192
231
  self._opts, self._session = opts, session
@@ -207,7 +246,7 @@ class ChunkedStream(tts.ChunkedStream):
207
246
 
208
247
  try:
209
248
  async with self._session.post(
210
- "https://api.cartesia.ai/tts/bytes",
249
+ self._opts.get_http_url("/tts/bytes"),
211
250
  headers=headers,
212
251
  json=json,
213
252
  timeout=aiohttp.ClientTimeout(
@@ -216,19 +255,17 @@ class ChunkedStream(tts.ChunkedStream):
216
255
  ),
217
256
  ) as resp:
218
257
  resp.raise_for_status()
258
+ emitter = tts.SynthesizedAudioEmitter(
259
+ event_ch=self._event_ch,
260
+ request_id=request_id,
261
+ )
219
262
  async for data, _ in resp.content.iter_chunks():
220
263
  for frame in bstream.write(data):
221
- self._event_ch.send_nowait(
222
- tts.SynthesizedAudio(
223
- request_id=request_id,
224
- frame=frame,
225
- )
226
- )
264
+ emitter.push(frame)
227
265
 
228
266
  for frame in bstream.flush():
229
- self._event_ch.send_nowait(
230
- tts.SynthesizedAudio(request_id=request_id, frame=frame)
231
- )
267
+ emitter.push(frame)
268
+ emitter.flush()
232
269
  except asyncio.TimeoutError as e:
233
270
  raise APITimeoutError() from e
234
271
  except aiohttp.ClientResponseError as e:
@@ -247,12 +284,11 @@ class SynthesizeStream(tts.SynthesizeStream):
247
284
  self,
248
285
  *,
249
286
  tts: TTS,
250
- conn_options: APIConnectOptions,
251
287
  opts: _TTSOptions,
252
- session: aiohttp.ClientSession,
288
+ pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
253
289
  ):
254
- super().__init__(tts=tts, conn_options=conn_options)
255
- self._opts, self._session = opts, session
290
+ super().__init__(tts=tts)
291
+ self._opts, self._pool = opts, pool
256
292
  self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
257
293
  min_sentence_len=BUFFERED_WORDS_COUNT
258
294
  ).stream()
@@ -289,22 +325,10 @@ class SynthesizeStream(tts.SynthesizeStream):
289
325
  sample_rate=self._opts.sample_rate,
290
326
  num_channels=NUM_CHANNELS,
291
327
  )
292
-
293
- last_frame: rtc.AudioFrame | None = None
294
-
295
- def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
296
- nonlocal last_frame
297
- if last_frame is not None:
298
- self._event_ch.send_nowait(
299
- tts.SynthesizedAudio(
300
- request_id=request_id,
301
- segment_id=segment_id,
302
- frame=last_frame,
303
- is_final=is_final,
304
- )
305
- )
306
-
307
- last_frame = None
328
+ emitter = tts.SynthesizedAudioEmitter(
329
+ event_ch=self._event_ch,
330
+ request_id=request_id,
331
+ )
308
332
 
309
333
  while True:
310
334
  msg = await ws.receive()
@@ -324,35 +348,23 @@ class SynthesizeStream(tts.SynthesizeStream):
324
348
 
325
349
  data = json.loads(msg.data)
326
350
  segment_id = data.get("context_id")
351
+ emitter._segment_id = segment_id
327
352
 
328
353
  if data.get("data"):
329
354
  b64data = base64.b64decode(data["data"])
330
355
  for frame in audio_bstream.write(b64data):
331
- _send_last_frame(segment_id=segment_id, is_final=False)
332
- last_frame = frame
356
+ emitter.push(frame)
333
357
  elif data.get("done"):
334
358
  for frame in audio_bstream.flush():
335
- _send_last_frame(segment_id=segment_id, is_final=False)
336
- last_frame = frame
337
-
338
- _send_last_frame(segment_id=segment_id, is_final=True)
339
-
359
+ emitter.push(frame)
360
+ emitter.flush()
340
361
  if segment_id == request_id:
341
- # we're not going to receive more frames, close the connection
342
- await ws.close()
362
+ # we're not going to receive more frames, end stream
343
363
  break
344
364
  else:
345
365
  logger.error("unexpected Cartesia message %s", data)
346
366
 
347
- url = f"wss://api.cartesia.ai/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
348
-
349
- ws: aiohttp.ClientWebSocketResponse | None = None
350
-
351
- try:
352
- ws = await asyncio.wait_for(
353
- self._session.ws_connect(url), self._conn_options.timeout
354
- )
355
-
367
+ async with self._pool.connection() as ws:
356
368
  tasks = [
357
369
  asyncio.create_task(_input_task()),
358
370
  asyncio.create_task(_sentence_stream_task(ws)),
@@ -363,9 +375,6 @@ class SynthesizeStream(tts.SynthesizeStream):
363
375
  await asyncio.gather(*tasks)
364
376
  finally:
365
377
  await utils.aio.gracefully_cancel(*tasks)
366
- finally:
367
- if ws is not None:
368
- await ws.close()
369
378
 
370
379
 
371
380
  def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "0.4.7"
15
+ __version__ = "0.4.9"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: livekit-plugins-cartesia
3
- Version: 0.4.7
3
+ Version: 0.4.9
4
4
  Summary: LiveKit Agents Plugin for Cartesia
5
5
  Home-page: https://github.com/livekit/agents
6
6
  License: Apache-2.0
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: 3.10
19
19
  Classifier: Programming Language :: Python :: 3 :: Only
20
20
  Requires-Python: >=3.9.0
21
21
  Description-Content-Type: text/markdown
22
- Requires-Dist: livekit-agents>=0.12.11
22
+ Requires-Dist: livekit-agents<1.0.0,>=0.12.16
23
23
  Dynamic: classifier
24
24
  Dynamic: description
25
25
  Dynamic: description-content-type
@@ -0,0 +1,10 @@
1
+ livekit/plugins/cartesia/__init__.py,sha256=UTa6Q7IxhRBCwPftowHEUDvmBg99J_UjGS_yxTzKD7g,1095
2
+ livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
3
+ livekit/plugins/cartesia/models.py,sha256=56CJgo7my-w-vpedir_ImV_aqKASeLihE5DbcCCgGJI,950
4
+ livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ livekit/plugins/cartesia/tts.py,sha256=C7R_Z2nmWK3V0bQ30Nl5lgRgIZTfCzb7hiYWfBohGQs,14138
6
+ livekit/plugins/cartesia/version.py,sha256=IgraxfFuS50xVM2gugcS0QqKE5YXL-FzFLco3Ffq6BY,600
7
+ livekit_plugins_cartesia-0.4.9.dist-info/METADATA,sha256=P-63JHYrrl1DtEUBVVCnSPeS2vnws2dIdTPclQNG4ow,1470
8
+ livekit_plugins_cartesia-0.4.9.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
9
+ livekit_plugins_cartesia-0.4.9.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
10
+ livekit_plugins_cartesia-0.4.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,10 +0,0 @@
1
- livekit/plugins/cartesia/__init__.py,sha256=UTa6Q7IxhRBCwPftowHEUDvmBg99J_UjGS_yxTzKD7g,1095
2
- livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
3
- livekit/plugins/cartesia/models.py,sha256=56CJgo7my-w-vpedir_ImV_aqKASeLihE5DbcCCgGJI,950
4
- livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- livekit/plugins/cartesia/tts.py,sha256=MKMLQBLc9A9XBkfwbFFt7PogUiWUHl-HozUR9mBmtMI,13906
6
- livekit/plugins/cartesia/version.py,sha256=85hVEOZ--XQ1Y7ngd1qGTZPpeywK2do8-2uhP_kdeyA,600
7
- livekit_plugins_cartesia-0.4.7.dist-info/METADATA,sha256=gZFU-QvodV6JmJiLDZ_OluuXviq8Zef-pYxhNEvxHho,1463
8
- livekit_plugins_cartesia-0.4.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
9
- livekit_plugins_cartesia-0.4.7.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
10
- livekit_plugins_cartesia-0.4.7.dist-info/RECORD,,