livekit-plugins-cartesia 1.0.22__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry they were published to. It is provided for informational purposes only.
@@ -17,10 +17,11 @@
  See https://docs.livekit.io/agents/integrations/tts/cartesia/ for more information.
  """

+ from .stt import STT
  from .tts import TTS, ChunkedStream
  from .version import __version__

- __all__ = ["TTS", "ChunkedStream", "__version__"]
+ __all__ = ["STT", "TTS", "ChunkedStream", "__version__"]

  from livekit.agents import Plugin

@@ -28,7 +29,7 @@ from .log import logger


  class CartesiaPlugin(Plugin):
-     def __init__(self):
+     def __init__(self) -> None:
          super().__init__(__name__, __version__, __package__, logger)


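With these `__init__.py` changes, `STT` is exported from the plugin package alongside `TTS`. A minimal usage sketch, assuming `CARTESIA_API_KEY` is set in the environment (the surrounding wiring is illustrative, not part of this diff):

    from livekit.plugins import cartesia

    # Both entry points now come straight from the package namespace;
    # each reads CARTESIA_API_KEY from the environment when api_key is omitted.
    speech_to_text = cartesia.STT()
    text_to_speech = cartesia.TTS()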
@@ -39,3 +39,53 @@ TTSVoiceEmotion = Literal[
      "curiosity:high",
      "curiosity:highest",
  ]
+
+ # STT model definitions
+ STTEncoding = Literal["pcm_s16le",]
+
+ STTModels = Literal["ink-whisper"]
+ STTLanguages = Literal[
+     "en",
+     "de",
+     "es",
+     "fr",
+     "ja",
+     "pt",
+     "zh",
+     "hi",
+     "ko",
+     "it",
+     "nl",
+     "pl",
+     "ru",
+     "sv",
+     "tr",
+     "tl",
+     "bg",
+     "ro",
+     "ar",
+     "cs",
+     "el",
+     "fi",
+     "hr",
+     "ms",
+     "sk",
+     "da",
+     "ta",
+     "uk",
+     "hu",
+     "no",
+     "vi",
+     "bn",
+     "th",
+     "he",
+     "ka",
+     "id",
+     "te",
+     "gu",
+     "kn",
+     "ml",
+     "mr",
+     "or",
+     "pa",
+ ]
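Because `STTModels`, `STTLanguages`, and `STTEncoding` are `typing.Literal` aliases, a static type checker can flag unsupported values, while the `STTModels | str` unions used in `stt.py` still let newer server-side values pass through. A small illustration (the helper function is hypothetical; the checking behavior is standard `Literal` semantics, enforced by tools like mypy or pyright rather than at runtime):

    from livekit.plugins.cartesia.models import STTLanguages

    def set_language(lang: STTLanguages) -> STTLanguages:
        # Only the literal members ("en", "de", ...) type-check here
        return lang

    set_language("de")  # accepted
    set_language("xx")  # rejected by a type checker; Python itself won't complain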
@@ -0,0 +1,474 @@
+ # Copyright 2023 LiveKit, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import os
+ import uuid
+ import weakref
+ from dataclasses import dataclass
+ from enum import Enum
+
+ import aiohttp
+ import numpy as np
+
+ from livekit import rtc
+ from livekit.agents import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     APIConnectOptions,
+     APIStatusError,
+     stt,
+     utils,
+ )
+ from livekit.agents.types import NOT_GIVEN, NotGivenOr
+ from livekit.agents.utils import is_given
+
+ from .log import logger
+ from .models import STTEncoding, STTLanguages, STTModels
+
+ API_AUTH_HEADER = "X-API-Key"
+ API_VERSION_HEADER = "Cartesia-Version"
+ API_VERSION = "2025-04-16"
+
+ # Audio energy threshold for speech detection
+ MAGIC_NUMBER_THRESHOLD = 0.004**2
+
+
+ class AudioEnergyFilter:
+     """Local voice activity detection based on audio energy levels."""
+
+     class State(Enum):
+         START = 0
+         SPEAKING = 1
+         SILENCE = 2
+         END = 3
+
+     def __init__(self, *, min_silence: float = 1.5, rms_threshold: float = MAGIC_NUMBER_THRESHOLD):
+         self._cooldown_seconds = min_silence
+         self._cooldown = min_silence
+         self._state = self.State.SILENCE
+         self._rms_threshold = rms_threshold
+
+     def update(self, frame: rtc.AudioFrame) -> State:
+         arr = np.frombuffer(frame.data, dtype=np.int16)
+         float_arr = arr.astype(np.float32) / 32768.0
+         rms = np.mean(np.square(float_arr))
+
+         if rms > self._rms_threshold:
+             self._cooldown = self._cooldown_seconds
+             if self._state in (self.State.SILENCE, self.State.END):
+                 self._state = self.State.START
+             else:
+                 self._state = self.State.SPEAKING
+         else:
+             if self._cooldown <= 0:
+                 if self._state in (self.State.SPEAKING, self.State.START):
+                     self._state = self.State.END
+                 elif self._state == self.State.END:
+                     self._state = self.State.SILENCE
+             else:
+                 # keep speaking during cooldown
+                 self._cooldown -= frame.duration
+                 self._state = self.State.SPEAKING
+
+         return self._state
+
+
+ @dataclass
+ class STTOptions:
+     model: STTModels | str
+     language: STTLanguages | str | None
+     encoding: STTEncoding
+     sample_rate: int
+     api_key: str
+     base_url: str
+     energy_filter: AudioEnergyFilter | bool
+
+     def get_http_url(self, path: str) -> str:
+         return f"{self.base_url}{path}"
+
+     def get_ws_url(self, path: str) -> str:
+         # If base_url already has a protocol, replace it, otherwise add wss://
+         if self.base_url.startswith(("http://", "https://")):
+             return f"{self.base_url.replace('http', 'ws', 1)}{path}"
+         else:
+             return f"wss://{self.base_url}{path}"
+
+
+ class STT(stt.STT):
+     def __init__(
+         self,
+         *,
+         model: STTModels | str = "ink-whisper",
+         language: STTLanguages | str = "en",
+         encoding: STTEncoding = "pcm_s16le",
+         sample_rate: int = 16000,
+         api_key: str | None = None,
+         http_session: aiohttp.ClientSession | None = None,
+         base_url: str = "https://api.cartesia.ai",
+         energy_filter: AudioEnergyFilter | bool = False,
+     ) -> None:
+         """
+         Create a new instance of Cartesia STT.
+
+         Args:
+             model: The Cartesia STT model to use. Defaults to "ink-whisper".
+             language: The language code for recognition. Defaults to "en".
+             encoding: The audio encoding format. Defaults to "pcm_s16le".
+             sample_rate: The sample rate of the audio in Hz. Defaults to 16000.
+             api_key: The Cartesia API key. If not provided, it will be read from
+                 the CARTESIA_API_KEY environment variable.
+             http_session: Optional aiohttp ClientSession to use for requests.
+             base_url: The base URL for the Cartesia API.
+                 Defaults to "https://api.cartesia.ai".
+             energy_filter: The energy filter to use for local voice activity
+                 detection. Defaults to False.
+
+         Raises:
+             ValueError: If no API key is provided or found in environment variables.
+         """
+         super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=False))
+
+         cartesia_api_key = api_key or os.environ.get("CARTESIA_API_KEY")
+         if not cartesia_api_key:
+             raise ValueError("CARTESIA_API_KEY must be set")
+
+         self._opts = STTOptions(
+             model=model,
+             language=language,
+             encoding=encoding,
+             sample_rate=sample_rate,
+             api_key=cartesia_api_key,
+             base_url=base_url,
+             energy_filter=AudioEnergyFilter() if energy_filter is True else energy_filter,
+         )
+         self._session = http_session
+         self._streams = weakref.WeakSet[SpeechStream]()
+
+     def _ensure_session(self) -> aiohttp.ClientSession:
+         if not self._session:
+             self._session = utils.http_context.http_session()
+         return self._session
+
+     async def _recognize_impl(
+         self,
+         buffer: utils.AudioBuffer,
+         *,
+         language: NotGivenOr[str] = NOT_GIVEN,
+         conn_options: APIConnectOptions,
+     ) -> stt.SpeechEvent:
+         raise NotImplementedError(
+             "Cartesia STT does not support batch recognition, use stream() instead"
+         )
+
+     def stream(
+         self,
+         *,
+         language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+     ) -> SpeechStream:
+         """Create a streaming transcription session."""
+         config = self._sanitize_options(language=language)
+         stream = SpeechStream(
+             stt=self,
+             opts=config,
+             conn_options=conn_options,
+         )
+         self._streams.add(stream)
+         return stream
+
+     def update_options(
+         self,
+         *,
+         model: NotGivenOr[STTModels | str] = NOT_GIVEN,
+         language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
+     ) -> None:
+         """Update STT configuration options."""
+         if is_given(model):
+             self._opts.model = model
+         if is_given(language):
+             self._opts.language = language
+
+         # Update all active streams
+         for stream in self._streams:
+             stream.update_options(
+                 model=model,
+                 language=language,
+             )
+
+     def _sanitize_options(
+         self, *, language: NotGivenOr[STTLanguages | str] = NOT_GIVEN
+     ) -> STTOptions:
+         """Create a sanitized copy of options with language override if provided."""
+         config = STTOptions(
+             model=self._opts.model,
+             language=self._opts.language,
+             encoding=self._opts.encoding,
+             sample_rate=self._opts.sample_rate,
+             api_key=self._opts.api_key,
+             base_url=self._opts.base_url,
+             energy_filter=self._opts.energy_filter,
+         )
+
+         if is_given(language):
+             config.language = language
+
+         return config
+
+
+ class SpeechStream(stt.SpeechStream):
+     def __init__(
+         self,
+         *,
+         stt: STT,
+         opts: STTOptions,
+         conn_options: APIConnectOptions,
+     ) -> None:
+         super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
+         self._opts = opts
+         self._session = stt._ensure_session()
+         self._request_id = str(uuid.uuid4())
+         self._reconnect_event = asyncio.Event()
+         self._speaking = False
+
+         # Set up audio energy filter for local VAD
+         self._audio_energy_filter: AudioEnergyFilter | None = None
+         if opts.energy_filter:
+             if isinstance(opts.energy_filter, AudioEnergyFilter):
+                 self._audio_energy_filter = opts.energy_filter
+             else:
+                 self._audio_energy_filter = AudioEnergyFilter()
+
+     def update_options(
+         self,
+         *,
+         model: NotGivenOr[STTModels | str] = NOT_GIVEN,
+         language: NotGivenOr[STTLanguages | str] = NOT_GIVEN,
+     ) -> None:
+         """Update streaming transcription options."""
+         if is_given(model):
+             self._opts.model = model
+         if is_given(language):
+             self._opts.language = language
+
+         self._reconnect_event.set()
+
+     def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State:
+         """Check the energy state of an audio frame for voice activity detection."""
+         if self._audio_energy_filter:
+             return self._audio_energy_filter.update(frame)
+         return AudioEnergyFilter.State.SPEAKING
+
+     async def _run(self) -> None:
+         """Main loop for streaming transcription."""
+         closing_ws = False
+
+         async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+             try:
+                 while True:
+                     await ws.ping()
+                     await asyncio.sleep(30)
+             except Exception:
+                 return
+
+         @utils.log_exceptions(logger=logger)
+         async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+             nonlocal closing_ws
+
+             # Forward audio to Cartesia in chunks
+             samples_50ms = self._opts.sample_rate // 20
+             audio_bstream = utils.audio.AudioByteStream(
+                 sample_rate=self._opts.sample_rate,
+                 num_channels=1,
+                 samples_per_channel=samples_50ms,
+             )
+
+             has_ended = False
+             last_frame: rtc.AudioFrame | None = None
+             async for data in self._input_ch:
+                 frames: list[rtc.AudioFrame] = []
+                 if isinstance(data, rtc.AudioFrame):
+                     state = self._check_energy_state(data)
+                     if state in (
+                         AudioEnergyFilter.State.START,
+                         AudioEnergyFilter.State.SPEAKING,
+                     ):
+                         # Send buffered silence frame if we have one
+                         if last_frame:
+                             frames.extend(audio_bstream.write(last_frame.data.tobytes()))
+                             last_frame = None
+                         frames.extend(audio_bstream.write(data.data.tobytes()))
+
+                         # Emit START_OF_SPEECH event if we just started speaking
+                         if state == AudioEnergyFilter.State.START and not self._speaking:
+                             self._speaking = True
+                             start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                             self._event_ch.send_nowait(start_event)
+
+                     elif state == AudioEnergyFilter.State.END:
+                         # Flush remaining audio and mark as ended
+                         frames.extend(audio_bstream.flush())
+                         has_ended = True
+                     elif state == AudioEnergyFilter.State.SILENCE:
+                         # Buffer the last silence frame in case it contains speech beginning
+                         last_frame = data
+                 elif isinstance(data, self._FlushSentinel):
+                     frames.extend(audio_bstream.flush())
+                     has_ended = True
+
+                 for frame in frames:
+                     await ws.send_bytes(frame.data.tobytes())
+
+                 if has_ended:
+                     has_ended = False
+
+             closing_ws = True
+             await ws.send_str("finalize")
+
+         @utils.log_exceptions(logger=logger)
+         async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+             nonlocal closing_ws
+             while True:
+                 msg = await ws.receive()
+                 if msg.type in (
+                     aiohttp.WSMsgType.CLOSED,
+                     aiohttp.WSMsgType.CLOSE,
+                     aiohttp.WSMsgType.CLOSING,
+                 ):
+                     if closing_ws or self._session.closed:
+                         return
+                     raise APIStatusError(message="Cartesia STT connection closed unexpectedly")
+
+                 if msg.type != aiohttp.WSMsgType.TEXT:
+                     logger.warning("unexpected Cartesia STT message type %s", msg.type)
+                     continue
+
+                 try:
+                     self._process_stream_event(json.loads(msg.data))
+                 except Exception:
+                     logger.exception("failed to process Cartesia STT message")
+
+         ws: aiohttp.ClientWebSocketResponse | None = None
+
+         while True:
+             try:
+                 ws = await self._connect_ws()
+                 tasks = [
+                     asyncio.create_task(send_task(ws)),
+                     asyncio.create_task(recv_task(ws)),
+                     asyncio.create_task(keepalive_task(ws)),
+                 ]
+                 tasks_group = asyncio.gather(*tasks)
+                 wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+
+                 try:
+                     done, _ = await asyncio.wait(
+                         (tasks_group, wait_reconnect_task),
+                         return_when=asyncio.FIRST_COMPLETED,
+                     )
+
+                     for task in done:
+                         if task != wait_reconnect_task:
+                             task.result()
+
+                     if wait_reconnect_task not in done:
+                         break
+
+                     self._reconnect_event.clear()
+                 finally:
+                     await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
+                     await tasks_group
+             finally:
+                 if ws is not None:
+                     await ws.close()
+
+     async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
+         """Connect to the Cartesia STT WebSocket."""
+         params = {
+             "model": self._opts.model,
+             "sample_rate": str(self._opts.sample_rate),
+             "encoding": self._opts.encoding,
+             "cartesia_version": API_VERSION,
+             "api_key": self._opts.api_key,
+         }
+
+         if self._opts.language:
+             params["language"] = self._opts.language
+
+         # Build URL
+         url = self._opts.get_ws_url("/stt/websocket")
+         query_string = "&".join(f"{k}={v}" for k, v in params.items())
+         ws_url = f"{url}?{query_string}"
+
+         ws = await asyncio.wait_for(
+             self._session.ws_connect(ws_url),
+             self._conn_options.timeout,
+         )
+         return ws
+
+     def _process_stream_event(self, data: dict) -> None:
+         """Process incoming WebSocket messages."""
+         message_type = data.get("type")
+
+         if message_type == "transcript":
+             request_id = data.get("request_id", self._request_id)
+             text = data.get("text", "")
+             is_final = data.get("is_final", False)
+             language = data.get("language", self._opts.language or "en")
+
+             if not text and not is_final:
+                 return
+
+             speech_data = stt.SpeechData(
+                 language=language,
+                 start_time=0,  # Cartesia doesn't provide word-level timestamps in this version
+                 end_time=data.get("duration", 0),
+                 confidence=data.get("probability", 1.0),
+                 text=text,
+             )
+
+             if is_final:
+                 event = stt.SpeechEvent(
+                     type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                     request_id=request_id,
+                     alternatives=[speech_data],
+                 )
+                 self._event_ch.send_nowait(event)
+
+                 if self._speaking:
+                     self._speaking = False
+                     end_event = stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
+                     self._event_ch.send_nowait(end_event)
+             else:
+                 event = stt.SpeechEvent(
+                     type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                     request_id=request_id,
+                     alternatives=[speech_data],
+                 )
+                 self._event_ch.send_nowait(event)
+
+         elif message_type == "flush_done":
+             logger.debug("Received flush_done acknowledgment from Cartesia STT")
+
+         elif message_type == "done":
+             logger.debug("Received done acknowledgment from Cartesia STT - session closing")
+
+         elif message_type == "error":
+             error_msg = data.get("message", "Unknown error")
+             logger.error("Cartesia STT error: %s", error_msg)
+             # We could emit an error event here if needed
+         else:
+             logger.warning("received unexpected message from Cartesia STT: %s", data)
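Two notes on the new module. First, `MAGIC_NUMBER_THRESHOLD = 0.004**2 = 1.6e-5` is compared against the mean of the squared samples (mean-square power, despite the `rms` variable name), so it corresponds to an RMS amplitude of 0.004 of full scale. Second, since `_recognize_impl` raises `NotImplementedError`, `stream()` is the only supported path; a hedged usage sketch (`frame_source` is a placeholder for any async iterator of 16 kHz mono `rtc.AudioFrame`s, not part of this diff):

    import asyncio

    from livekit.agents import stt as agents_stt
    from livekit.plugins.cartesia import STT

    async def transcribe(frame_source) -> None:
        cartesia_stt = STT(energy_filter=True)  # enable the local RMS-based VAD
        stream = cartesia_stt.stream()

        async def feed() -> None:
            async for frame in frame_source:
                stream.push_frame(frame)
            stream.end_input()

        feed_task = asyncio.create_task(feed())
        async for event in stream:
            if event.type == agents_stt.SpeechEventType.FINAL_TRANSCRIPT:
                print(event.alternatives[0].text)
        await feed_task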
@@ -19,8 +19,8 @@ import base64
  import json
  import os
  import weakref
- from dataclasses import dataclass
- from typing import Any
+ from dataclasses import dataclass, replace
+ from typing import Any, Optional, Union, cast

  import aiohttp

@@ -33,11 +33,7 @@ from livekit.agents import (
      tts,
      utils,
  )
- from livekit.agents.types import (
-     DEFAULT_API_CONNECT_OPTIONS,
-     NOT_GIVEN,
-     NotGivenOr,
- )
+ from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
  from livekit.agents.utils import is_given

  from .log import logger
@@ -53,7 +49,6 @@ API_AUTH_HEADER = "X-API-Key"
  API_VERSION_HEADER = "Cartesia-Version"
  API_VERSION = "2024-06-10"

- NUM_CHANNELS = 1
  BUFFERED_WORDS_COUNT = 10


@@ -63,8 +58,8 @@ class _TTSOptions:
      encoding: TTSEncoding
      sample_rate: int
      voice: str | list[float]
-     speed: NotGivenOr[TTSVoiceSpeed | float]
-     emotion: NotGivenOr[list[TTSVoiceEmotion | str]]
+     speed: TTSVoiceSpeed | float | None
+     emotion: list[TTSVoiceEmotion | str] | None
      api_key: str
      language: str
      base_url: str
@@ -80,14 +75,14 @@ class TTS(tts.TTS):
      def __init__(
          self,
          *,
+         api_key: str | None = None,
          model: TTSModels | str = "sonic-2",
          language: str = "en",
          encoding: TTSEncoding = "pcm_s16le",
          voice: str | list[float] = TTSDefaultVoiceId,
-         speed: NotGivenOr[TTSVoiceSpeed | float] = NOT_GIVEN,
-         emotion: NotGivenOr[list[TTSVoiceEmotion | str]] = NOT_GIVEN,
+         speed: TTSVoiceSpeed | float | None = None,
+         emotion: list[TTSVoiceEmotion | str] | None = None,
          sample_rate: int = 24000,
-         api_key: NotGivenOr[str] = NOT_GIVEN,
          http_session: aiohttp.ClientSession | None = None,
          base_url: str = "https://api.cartesia.ai",
      ) -> None:
@@ -112,9 +107,9 @@ class TTS(tts.TTS):
          super().__init__(
              capabilities=tts.TTSCapabilities(streaming=True),
              sample_rate=sample_rate,
-             num_channels=NUM_CHANNELS,
+             num_channels=1,
          )
-         cartesia_api_key = api_key if is_given(api_key) else os.environ.get("CARTESIA_API_KEY")
+         cartesia_api_key = api_key or os.environ.get("CARTESIA_API_KEY")
          if not cartesia_api_key:
              raise ValueError("CARTESIA_API_KEY must be set")

@@ -138,14 +133,14 @@ class TTS(tts.TTS):
          )
          self._streams = weakref.WeakSet[SynthesizeStream]()

-     async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
+     async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
          session = self._ensure_session()
          url = self._opts.get_ws_url(
              f"/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
          )
-         return await asyncio.wait_for(session.ws_connect(url), self._conn_options.timeout)
+         return await asyncio.wait_for(session.ws_connect(url), timeout)

-     async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
+     async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
          await ws.close()

      def _ensure_session(self) -> aiohttp.ClientSession:
@@ -163,8 +158,8 @@ class TTS(tts.TTS):
          model: NotGivenOr[TTSModels | str] = NOT_GIVEN,
          language: NotGivenOr[str] = NOT_GIVEN,
          voice: NotGivenOr[str | list[float]] = NOT_GIVEN,
-         speed: NotGivenOr[TTSVoiceSpeed | float] = NOT_GIVEN,
-         emotion: NotGivenOr[list[TTSVoiceEmotion | str]] = NOT_GIVEN,
+         speed: NotGivenOr[TTSVoiceSpeed | float | None] = NOT_GIVEN,
+         emotion: NotGivenOr[list[TTSVoiceEmotion | str] | None] = NOT_GIVEN,
      ) -> None:
          """
          Update the Text-to-Speech (TTS) configuration options.
@@ -184,158 +179,123 @@ class TTS(tts.TTS):
          if is_given(language):
              self._opts.language = language
          if is_given(voice):
-             self._opts.voice = voice
+             self._opts.voice = cast(Union[str, list[float]], voice)
          if is_given(speed):
-             self._opts.speed = speed
+             self._opts.speed = cast(Optional[Union[TTSVoiceSpeed, float]], speed)
          if is_given(emotion):
              self._opts.emotion = emotion

      def synthesize(
-         self,
-         text: str,
-         *,
-         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+         self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
      ) -> ChunkedStream:
-         return ChunkedStream(
-             tts=self,
-             input_text=text,
-             conn_options=conn_options,
-             opts=self._opts,
-             session=self._ensure_session(),
-         )
+         return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

      def stream(
          self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
      ) -> SynthesizeStream:
-         return SynthesizeStream(
-             tts=self,
-             pool=self._pool,
-             opts=self._opts,
-         )
+         return SynthesizeStream(tts=self, conn_options=conn_options)

      async def aclose(self) -> None:
          for stream in list(self._streams):
              await stream.aclose()
+
          self._streams.clear()
          await self._pool.aclose()
-         await super().aclose()


  class ChunkedStream(tts.ChunkedStream):
      """Synthesize chunked text using the bytes endpoint"""

-     def __init__(
-         self,
-         *,
-         tts: TTS,
-         input_text: str,
-         opts: _TTSOptions,
-         session: aiohttp.ClientSession,
-         conn_options: APIConnectOptions,
-     ) -> None:
+     def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
          super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
-         self._opts, self._session = opts, session
-
-     async def _run(self) -> None:
-         request_id = utils.shortuuid()
-         bstream = utils.audio.AudioByteStream(
-             sample_rate=self._opts.sample_rate, num_channels=NUM_CHANNELS
-         )
+         self._tts: TTS = tts
+         self._opts = replace(tts._opts)

+     async def _run(self, output_emitter: tts.AudioEmitter) -> None:
          json = _to_cartesia_options(self._opts)
          json["transcript"] = self._input_text

-         headers = {
-             API_AUTH_HEADER: self._opts.api_key,
-             API_VERSION_HEADER: API_VERSION,
-         }
-
          try:
-             async with self._session.post(
+             async with self._tts._ensure_session().post(
                  self._opts.get_http_url("/tts/bytes"),
-                 headers=headers,
+                 headers={
+                     API_AUTH_HEADER: self._opts.api_key,
+                     API_VERSION_HEADER: API_VERSION,
+                 },
                  json=json,
-                 timeout=aiohttp.ClientTimeout(
-                     total=30,
-                     sock_connect=self._conn_options.timeout,
-                 ),
+                 timeout=aiohttp.ClientTimeout(total=30, sock_connect=self._conn_options.timeout),
              ) as resp:
                  resp.raise_for_status()
-                 emitter = tts.SynthesizedAudioEmitter(
-                     event_ch=self._event_ch,
-                     request_id=request_id,
+
+                 output_emitter.initialize(
+                     request_id=utils.shortuuid(),
+                     sample_rate=self._opts.sample_rate,
+                     num_channels=1,
+                     mime_type="audio/pcm",
                  )
+
                  async for data, _ in resp.content.iter_chunks():
-                     for frame in bstream.write(data):
-                         emitter.push(frame)
+                     output_emitter.push(data)

-                 for frame in bstream.flush():
-                     emitter.push(frame)
-                 emitter.flush()
+                 output_emitter.flush()
          except asyncio.TimeoutError:
              raise APITimeoutError() from None
          except aiohttp.ClientResponseError as e:
              raise APIStatusError(
-                 message=e.message,
-                 status_code=e.status,
-                 request_id=None,
-                 body=None,
+                 message=e.message, status_code=e.status, request_id=None, body=None
              ) from None
          except Exception as e:
              raise APIConnectionError() from e


  class SynthesizeStream(tts.SynthesizeStream):
-     def __init__(
-         self,
-         *,
-         tts: TTS,
-         opts: _TTSOptions,
-         pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
-     ):
-         super().__init__(tts=tts)
-         self._opts, self._pool = opts, pool
+     def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
+         super().__init__(tts=tts, conn_options=conn_options)
+         self._tts: TTS = tts
          self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
              min_sentence_len=BUFFERED_WORDS_COUNT
          ).stream()
+         self._opts = replace(tts._opts)

-     async def _run(self) -> None:
+     async def _run(self, output_emitter: tts.AudioEmitter) -> None:
          request_id = utils.shortuuid()
+         output_emitter.initialize(
+             request_id=request_id,
+             sample_rate=self._opts.sample_rate,
+             num_channels=1,
+             mime_type="audio/pcm",
+             stream=True,
+         )

-         async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
+         async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+             context_id = utils.shortuuid()
              base_pkt = _to_cartesia_options(self._opts)
              async for ev in self._sent_tokenizer_stream:
                  token_pkt = base_pkt.copy()
-                 token_pkt["context_id"] = request_id
+                 token_pkt["context_id"] = context_id
                  token_pkt["transcript"] = ev.token + " "
                  token_pkt["continue"] = True
                  self._mark_started()
                  await ws.send_str(json.dumps(token_pkt))

              end_pkt = base_pkt.copy()
-             end_pkt["context_id"] = request_id
+             end_pkt["context_id"] = context_id
              end_pkt["transcript"] = " "
              end_pkt["continue"] = False
              await ws.send_str(json.dumps(end_pkt))

-         async def _input_task():
+         async def _input_task() -> None:
              async for data in self._input_ch:
                  if isinstance(data, self._FlushSentinel):
                      self._sent_tokenizer_stream.flush()
                      continue
+
                  self._sent_tokenizer_stream.push_text(data)
-             self._sent_tokenizer_stream.end_input()

-         async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
-             audio_bstream = utils.audio.AudioByteStream(
-                 sample_rate=self._opts.sample_rate,
-                 num_channels=NUM_CHANNELS,
-             )
-             emitter = tts.SynthesizedAudioEmitter(
-                 event_ch=self._event_ch,
-                 request_id=request_id,
-             )
+             self._sent_tokenizer_stream.end_input()

+         async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+             current_segment_id: str | None = None
              while True:
                  msg = await ws.receive()
                  if msg.type in (
@@ -344,8 +304,7 @@ class SynthesizeStream(tts.SynthesizeStream):
                      aiohttp.WSMsgType.CLOSING,
                  ):
                      raise APIStatusError(
-                         "Cartesia connection closed unexpectedly",
-                         request_id=request_id,
+                         "Cartesia connection closed unexpectedly", request_id=request_id
                      )

                  if msg.type != aiohttp.WSMsgType.TEXT:
@@ -354,49 +313,54 @@ class SynthesizeStream(tts.SynthesizeStream):

                  data = json.loads(msg.data)
                  segment_id = data.get("context_id")
-                 emitter._segment_id = segment_id
-
+                 if current_segment_id is None:
+                     current_segment_id = segment_id
+                     output_emitter.start_segment(segment_id=segment_id)
                  if data.get("data"):
                      b64data = base64.b64decode(data["data"])
-                     for frame in audio_bstream.write(b64data):
-                         emitter.push(frame)
+                     output_emitter.push(b64data)
                  elif data.get("done"):
-                     for frame in audio_bstream.flush():
-                         emitter.push(frame)
-                     emitter.flush()
-                     if segment_id == request_id:
-                         # we're not going to receive more frames, end stream
-                         break
+                     output_emitter.end_input()
+                     break
                  else:
-                     logger.error("unexpected Cartesia message %s", data)
+                     logger.warning("unexpected message %s", data)

-         async with self._pool.connection() as ws:
-             tasks = [
-                 asyncio.create_task(_input_task()),
-                 asyncio.create_task(_sentence_stream_task(ws)),
-                 asyncio.create_task(_recv_task(ws)),
-             ]
-
-             try:
-                 await asyncio.gather(*tasks)
-             finally:
-                 await utils.aio.gracefully_cancel(*tasks)
+         try:
+             async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
+                 tasks = [
+                     asyncio.create_task(_input_task()),
+                     asyncio.create_task(_sentence_stream_task(ws)),
+                     asyncio.create_task(_recv_task(ws)),
+                 ]
+
+                 try:
+                     await asyncio.gather(*tasks)
+                 finally:
+                     await utils.aio.gracefully_cancel(*tasks)
+         except asyncio.TimeoutError:
+             raise APITimeoutError() from None
+         except aiohttp.ClientResponseError as e:
+             raise APIStatusError(
+                 message=e.message, status_code=e.status, request_id=None, body=None
+             ) from None
+         except Exception as e:
+             raise APIConnectionError() from e


  def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
      voice: dict[str, Any] = {}
-     if is_given(opts.voice):
-         if isinstance(opts.voice, str):
-             voice["mode"] = "id"
-             voice["id"] = opts.voice
-         else:
-             voice["mode"] = "embedding"
-             voice["embedding"] = opts.voice
+     if isinstance(opts.voice, str):
+         voice["mode"] = "id"
+         voice["id"] = opts.voice
+     else:
+         voice["mode"] = "embedding"
+         voice["embedding"] = opts.voice

      voice_controls: dict = {}
-     if is_given(opts.speed):
+     if opts.speed:
          voice_controls["speed"] = opts.speed
-     if is_given(opts.emotion):
+
+     if opts.emotion:
          voice_controls["emotion"] = opts.emotion

      if voice_controls:
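One behavioral nuance of the move from `is_given(...)` to plain truthiness in `_to_cartesia_options`: a value that is provided but falsy is now omitted from `voice_controls`. A standalone sketch mirroring the logic above (whether a falsy value such as a numeric speed of `0.0` is ever meaningful to the API is not shown in this diff):

    def build_voice_controls(speed, emotion):
        # Mirrors the voice_controls construction in _to_cartesia_options
        voice_controls: dict = {}
        if speed:  # previously is_given(speed), which kept any explicitly passed value
            voice_controls["speed"] = speed
        if emotion:
            voice_controls["emotion"] = emotion
        return voice_controls

    assert build_voice_controls("fast", None) == {"speed": "fast"}
    assert build_voice_controls(0.0, []) == {}  # falsy values are now dropped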
@@ -12,4 +12,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "1.0.22"
+ __version__ = "1.1.0"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: livekit-plugins-cartesia
- Version: 1.0.22
+ Version: 1.1.0
  Summary: LiveKit Agents Plugin for Cartesia
  Project-URL: Documentation, https://docs.livekit.io
  Project-URL: Website, https://livekit.io/
@@ -18,7 +18,7 @@ Classifier: Topic :: Multimedia :: Sound/Audio
  Classifier: Topic :: Multimedia :: Video
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9.0
- Requires-Dist: livekit-agents>=1.0.22
+ Requires-Dist: livekit-agents>=1.1.0
  Description-Content-Type: text/markdown

  # Cartesia plugin for LiveKit Agents
@@ -0,0 +1,10 @@
+ livekit/plugins/cartesia/__init__.py,sha256=n8BvjZSpYiYFxOg3Hyh-UuyG7XeQw9uP48_OPDSBWdE,1259
+ livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
+ livekit/plugins/cartesia/models.py,sha256=TIJQa9gNKj_1t09XUjXN5hIrp6_xG1O7YZfVrr0KG4M,1530
+ livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ livekit/plugins/cartesia/stt.py,sha256=2GY2o90s-Vp0E8UX89maJsY6r0D-I225L8Etv714OJs,17211
+ livekit/plugins/cartesia/tts.py,sha256=gyTJIVmlA8HsWe51LCvSTLVKyO66eQZRGDZjQOOlU1E,14060
+ livekit/plugins/cartesia/version.py,sha256=7SjyflIFTjH0djSotKGIRoRykPCqMpVYetIlvHMFuh0,600
+ livekit_plugins_cartesia-1.1.0.dist-info/METADATA,sha256=FxSF1dGRP7fLTEOT27IXgY3Eu-3nbpTdt8JCoGdFsPg,1329
+ livekit_plugins_cartesia-1.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ livekit_plugins_cartesia-1.1.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
- livekit/plugins/cartesia/__init__.py,sha256=DFnl1khtyLstonZ6-FzIItl6ob9132SbZDLFRfremVs,1223
- livekit/plugins/cartesia/log.py,sha256=4Mnhjng_DU1dIWP9IWjIQGZ67EV3LnQhWMWCHVudJbo,71
- livekit/plugins/cartesia/models.py,sha256=KGY-r2luJuUNY6a3nnB0Rx-5Td12hikk-GtYLnqvysE,977
- livekit/plugins/cartesia/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- livekit/plugins/cartesia/tts.py,sha256=g3RPmTGyMjL0sG6lS-1zaq4Pa1DO2DmKfAnFeJwnHtY,14445
- livekit/plugins/cartesia/version.py,sha256=-8dkOE2vDSF9WN8VoBrSwU2sb5YBGFuwPnSQXQ-uaYM,601
- livekit_plugins_cartesia-1.0.22.dist-info/METADATA,sha256=9qFxQqS_sHBnR1i30Qx17_Ura2azcO6W8RaWKSqTaIU,1331
- livekit_plugins_cartesia-1.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- livekit_plugins_cartesia-1.0.22.dist-info/RECORD,,