livekit-plugins-deepgram 1.0.23__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/.gitignore +3 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/PKG-INFO +2 -2
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/__init__.py +1 -1
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/stt.py +12 -12
- livekit_plugins_deepgram-1.1.1/livekit/plugins/deepgram/tts.py +318 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/version.py +1 -1
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/pyproject.toml +1 -1
- livekit_plugins_deepgram-1.0.23/livekit/plugins/deepgram/tts.py +0 -438
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/README.md +0 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/_utils.py +0 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/log.py +0 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/models.py +0 -0
- {livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/py.typed +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: livekit-plugins-deepgram
|
3
|
-
Version: 1.0.23
|
3
|
+
Version: 1.1.1
|
4
4
|
Summary: Agent Framework plugin for services using Deepgram's API.
|
5
5
|
Project-URL: Documentation, https://docs.livekit.io
|
6
6
|
Project-URL: Website, https://livekit.io/
|
@@ -18,7 +18,7 @@ Classifier: Topic :: Multimedia :: Sound/Audio
|
|
18
18
|
Classifier: Topic :: Multimedia :: Video
|
19
19
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
20
20
|
Requires-Python: >=3.9.0
|
21
|
-
Requires-Dist: livekit-agents[codecs]>=1.0.23
|
21
|
+
Requires-Dist: livekit-agents[codecs]>=1.1.1
|
22
22
|
Requires-Dist: numpy>=1.26
|
23
23
|
Description-Content-Type: text/markdown
|
24
24
|
|
{livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/stt.py
RENAMED
@@ -94,7 +94,7 @@ class AudioEnergyFilter:
|
|
94
94
|
|
95
95
|
@dataclass
|
96
96
|
class STTOptions:
|
97
|
-
language: DeepgramLanguages | str
|
97
|
+
language: DeepgramLanguages | str | None
|
98
98
|
detect_language: bool
|
99
99
|
interim_results: bool
|
100
100
|
punctuate: bool
|
@@ -181,9 +181,10 @@ class STT(stt.STT):
|
|
181
181
|
)
|
182
182
|
self._base_url = base_url
|
183
183
|
|
184
|
-
|
185
|
-
if not
|
184
|
+
deepgram_api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY")
|
185
|
+
if not deepgram_api_key:
|
186
186
|
raise ValueError("Deepgram API key is required")
|
187
|
+
self._api_key = deepgram_api_key
|
187
188
|
|
188
189
|
model = _validate_model(model, language)
|
189
190
|
_validate_keyterms(model, language, keyterms, keywords)
|
@@ -305,7 +306,7 @@ class STT(stt.STT):
|
|
305
306
|
numerals: NotGivenOr[bool] = NOT_GIVEN,
|
306
307
|
mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
|
307
308
|
tags: NotGivenOr[list[str]] = NOT_GIVEN,
|
308
|
-
):
|
309
|
+
) -> None:
|
309
310
|
if is_given(language):
|
310
311
|
self._opts.language = language
|
311
312
|
if is_given(model):
|
@@ -383,14 +384,13 @@ class SpeechStream(stt.SpeechStream):
|
|
383
384
|
http_session: aiohttp.ClientSession,
|
384
385
|
base_url: str,
|
385
386
|
) -> None:
|
386
|
-
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
|
387
|
-
|
388
387
|
if opts.detect_language or opts.language is None:
|
389
388
|
raise ValueError(
|
390
389
|
"language detection is not supported in streaming mode, "
|
391
390
|
"please disable it and specify a language"
|
392
391
|
)
|
393
392
|
|
393
|
+
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
|
394
394
|
self._opts = opts
|
395
395
|
self._api_key = api_key
|
396
396
|
self._session = http_session
|
@@ -429,7 +429,7 @@ class SpeechStream(stt.SpeechStream):
|
|
429
429
|
numerals: NotGivenOr[bool] = NOT_GIVEN,
|
430
430
|
mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
|
431
431
|
tags: NotGivenOr[list[str]] = NOT_GIVEN,
|
432
|
-
):
|
432
|
+
) -> None:
|
433
433
|
if is_given(language):
|
434
434
|
self._opts.language = language
|
435
435
|
if is_given(model):
|
@@ -466,7 +466,7 @@ class SpeechStream(stt.SpeechStream):
|
|
466
466
|
async def _run(self) -> None:
|
467
467
|
closing_ws = False
|
468
468
|
|
469
|
-
async def keepalive_task(ws: aiohttp.ClientWebSocketResponse):
|
469
|
+
async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None:
|
470
470
|
# if we want to keep the connection alive even if no audio is sent,
|
471
471
|
# Deepgram expects a keepalive message.
|
472
472
|
# https://developers.deepgram.com/reference/listen-live#stream-keepalive
|
@@ -478,7 +478,7 @@ class SpeechStream(stt.SpeechStream):
|
|
478
478
|
return
|
479
479
|
|
480
480
|
@utils.log_exceptions(logger=logger)
|
481
|
-
async def send_task(ws: aiohttp.ClientWebSocketResponse):
|
481
|
+
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
|
482
482
|
nonlocal closing_ws
|
483
483
|
|
484
484
|
# forward audio to deepgram in chunks of 50ms
|
@@ -529,7 +529,7 @@ class SpeechStream(stt.SpeechStream):
|
|
529
529
|
await ws.send_str(SpeechStream._CLOSE_MSG)
|
530
530
|
|
531
531
|
@utils.log_exceptions(logger=logger)
|
532
|
-
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
|
532
|
+
async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
|
533
533
|
nonlocal closing_ws
|
534
534
|
while True:
|
535
535
|
msg = await ws.receive()
|
@@ -569,9 +569,9 @@ class SpeechStream(stt.SpeechStream):
|
|
569
569
|
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
|
570
570
|
try:
|
571
571
|
done, _ = await asyncio.wait(
|
572
|
-
|
572
|
+
(tasks_group, wait_reconnect_task),
|
573
573
|
return_when=asyncio.FIRST_COMPLETED,
|
574
|
-
)
|
574
|
+
)
|
575
575
|
|
576
576
|
# propagate exceptions from completed tasks
|
577
577
|
for task in done:
|
@@ -0,0 +1,318 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
import weakref
|
7
|
+
from dataclasses import dataclass, replace
|
8
|
+
|
9
|
+
import aiohttp
|
10
|
+
|
11
|
+
from livekit.agents import (
|
12
|
+
APIConnectionError,
|
13
|
+
APIConnectOptions,
|
14
|
+
APIStatusError,
|
15
|
+
APITimeoutError,
|
16
|
+
tokenize,
|
17
|
+
tts,
|
18
|
+
utils,
|
19
|
+
)
|
20
|
+
from livekit.agents.types import (
|
21
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
22
|
+
NOT_GIVEN,
|
23
|
+
NotGivenOr,
|
24
|
+
)
|
25
|
+
from livekit.agents.utils import is_given
|
26
|
+
|
27
|
+
from ._utils import _to_deepgram_url
|
28
|
+
from .log import logger
|
29
|
+
|
30
|
+
BASE_URL = "https://api.deepgram.com/v1/speak"
|
31
|
+
NUM_CHANNELS = 1
|
32
|
+
|
33
|
+
|
34
|
+
@dataclass
|
35
|
+
class _TTSOptions:
|
36
|
+
model: str
|
37
|
+
encoding: str
|
38
|
+
sample_rate: int
|
39
|
+
word_tokenizer: tokenize.WordTokenizer
|
40
|
+
base_url: str
|
41
|
+
api_key: str
|
42
|
+
mip_opt_out: bool = False
|
43
|
+
|
44
|
+
|
45
|
+
class TTS(tts.TTS):
|
46
|
+
def __init__(
|
47
|
+
self,
|
48
|
+
*,
|
49
|
+
model: str = "aura-2-andromeda-en",
|
50
|
+
encoding: str = "linear16",
|
51
|
+
sample_rate: int = 24000,
|
52
|
+
api_key: str | None = None,
|
53
|
+
base_url: str = BASE_URL,
|
54
|
+
word_tokenizer: NotGivenOr[tokenize.WordTokenizer] = NOT_GIVEN,
|
55
|
+
http_session: aiohttp.ClientSession | None = None,
|
56
|
+
mip_opt_out: bool = False,
|
57
|
+
) -> None:
|
58
|
+
"""
|
59
|
+
Create a new instance of Deepgram TTS.
|
60
|
+
|
61
|
+
Args:
|
62
|
+
model (str): TTS model to use. Defaults to "aura-2-andromeda-en".
|
63
|
+
encoding (str): Audio encoding to use. Defaults to "linear16".
|
64
|
+
sample_rate (int): Sample rate of audio. Defaults to 24000.
|
65
|
+
api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment.
|
66
|
+
base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak"
|
67
|
+
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
|
68
|
+
http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests.
|
69
|
+
|
70
|
+
""" # noqa: E501
|
71
|
+
super().__init__(
|
72
|
+
capabilities=tts.TTSCapabilities(streaming=True),
|
73
|
+
sample_rate=sample_rate,
|
74
|
+
num_channels=NUM_CHANNELS,
|
75
|
+
)
|
76
|
+
|
77
|
+
api_key = api_key or os.environ.get("DEEPGRAM_API_KEY")
|
78
|
+
if not api_key:
|
79
|
+
raise ValueError("Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key.")
|
80
|
+
|
81
|
+
if not is_given(word_tokenizer):
|
82
|
+
word_tokenizer = tokenize.basic.WordTokenizer(ignore_punctuation=False)
|
83
|
+
|
84
|
+
self._opts = _TTSOptions(
|
85
|
+
model=model,
|
86
|
+
encoding=encoding,
|
87
|
+
sample_rate=sample_rate,
|
88
|
+
word_tokenizer=word_tokenizer,
|
89
|
+
base_url=base_url,
|
90
|
+
api_key=api_key,
|
91
|
+
mip_opt_out=mip_opt_out,
|
92
|
+
)
|
93
|
+
self._session = http_session
|
94
|
+
self._streams = weakref.WeakSet[SynthesizeStream]()
|
95
|
+
|
96
|
+
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
|
97
|
+
connect_cb=self._connect_ws,
|
98
|
+
close_cb=self._close_ws,
|
99
|
+
max_session_duration=3600, # 1 hour
|
100
|
+
mark_refreshed_on_get=False,
|
101
|
+
)
|
102
|
+
|
103
|
+
async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
|
104
|
+
session = self._ensure_session()
|
105
|
+
config = {
|
106
|
+
"encoding": self._opts.encoding,
|
107
|
+
"model": self._opts.model,
|
108
|
+
"sample_rate": self._opts.sample_rate,
|
109
|
+
"mip_opt_out": self._opts.mip_opt_out,
|
110
|
+
}
|
111
|
+
return await asyncio.wait_for(
|
112
|
+
session.ws_connect(
|
113
|
+
_to_deepgram_url(config, self._opts.base_url, websocket=True),
|
114
|
+
headers={"Authorization": f"Token {self._opts.api_key}"},
|
115
|
+
),
|
116
|
+
timeout,
|
117
|
+
)
|
118
|
+
|
119
|
+
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
|
120
|
+
await ws.close()
|
121
|
+
|
122
|
+
def _ensure_session(self) -> aiohttp.ClientSession:
|
123
|
+
if not self._session:
|
124
|
+
self._session = utils.http_context.http_session()
|
125
|
+
return self._session
|
126
|
+
|
127
|
+
def update_options(
|
128
|
+
self,
|
129
|
+
*,
|
130
|
+
model: NotGivenOr[str] = NOT_GIVEN,
|
131
|
+
) -> None:
|
132
|
+
"""
|
133
|
+
Args:
|
134
|
+
model (str): TTS model to use.
|
135
|
+
"""
|
136
|
+
if is_given(model):
|
137
|
+
self._opts.model = model
|
138
|
+
|
139
|
+
def synthesize(
|
140
|
+
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
|
141
|
+
) -> ChunkedStream:
|
142
|
+
return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)
|
143
|
+
|
144
|
+
def stream(
|
145
|
+
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
|
146
|
+
) -> SynthesizeStream:
|
147
|
+
stream = SynthesizeStream(tts=self, conn_options=conn_options)
|
148
|
+
self._streams.add(stream)
|
149
|
+
return stream
|
150
|
+
|
151
|
+
def prewarm(self) -> None:
|
152
|
+
self._pool.prewarm()
|
153
|
+
|
154
|
+
async def aclose(self) -> None:
|
155
|
+
for stream in list(self._streams):
|
156
|
+
await stream.aclose()
|
157
|
+
|
158
|
+
self._streams.clear()
|
159
|
+
|
160
|
+
await self._pool.aclose()
|
161
|
+
|
162
|
+
|
163
|
+
class ChunkedStream(tts.ChunkedStream):
|
164
|
+
def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
|
165
|
+
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
|
166
|
+
self._tts: TTS = tts
|
167
|
+
self._opts = replace(tts._opts)
|
168
|
+
|
169
|
+
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
170
|
+
try:
|
171
|
+
async with self._tts._ensure_session().post(
|
172
|
+
_to_deepgram_url(
|
173
|
+
{
|
174
|
+
"encoding": self._opts.encoding,
|
175
|
+
"container": "none",
|
176
|
+
"model": self._opts.model,
|
177
|
+
"sample_rate": self._opts.sample_rate,
|
178
|
+
"mip_opt_out": self._opts.mip_opt_out,
|
179
|
+
},
|
180
|
+
self._opts.base_url,
|
181
|
+
websocket=False,
|
182
|
+
),
|
183
|
+
headers={
|
184
|
+
"Authorization": f"Token {self._opts.api_key}",
|
185
|
+
"Content-Type": "application/json",
|
186
|
+
},
|
187
|
+
json={"text": self._input_text},
|
188
|
+
timeout=aiohttp.ClientTimeout(total=30, sock_connect=self._conn_options.timeout),
|
189
|
+
) as resp:
|
190
|
+
resp.raise_for_status()
|
191
|
+
|
192
|
+
output_emitter.initialize(
|
193
|
+
request_id=utils.shortuuid(),
|
194
|
+
sample_rate=self._opts.sample_rate,
|
195
|
+
num_channels=NUM_CHANNELS,
|
196
|
+
mime_type="audio/pcm",
|
197
|
+
)
|
198
|
+
|
199
|
+
async for data, _ in resp.content.iter_chunks():
|
200
|
+
output_emitter.push(data)
|
201
|
+
|
202
|
+
output_emitter.flush()
|
203
|
+
|
204
|
+
except asyncio.TimeoutError:
|
205
|
+
raise APITimeoutError() from None
|
206
|
+
except aiohttp.ClientResponseError as e:
|
207
|
+
raise APIStatusError(
|
208
|
+
message=e.message, status_code=e.status, request_id=None, body=None
|
209
|
+
) from None
|
210
|
+
except Exception as e:
|
211
|
+
raise APIConnectionError() from e
|
212
|
+
|
213
|
+
|
214
|
+
class SynthesizeStream(tts.SynthesizeStream):
|
215
|
+
def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
|
216
|
+
super().__init__(tts=tts, conn_options=conn_options)
|
217
|
+
self._tts: TTS = tts
|
218
|
+
self._opts = replace(tts._opts)
|
219
|
+
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
|
220
|
+
|
221
|
+
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
222
|
+
request_id = utils.shortuuid()
|
223
|
+
output_emitter.initialize(
|
224
|
+
request_id=request_id,
|
225
|
+
sample_rate=self._opts.sample_rate,
|
226
|
+
num_channels=1,
|
227
|
+
mime_type="audio/pcm",
|
228
|
+
stream=True,
|
229
|
+
)
|
230
|
+
|
231
|
+
async def _tokenize_input() -> None:
|
232
|
+
# Converts incoming text into WordStreams and sends them into _segments_ch
|
233
|
+
word_stream = None
|
234
|
+
async for input in self._input_ch:
|
235
|
+
if isinstance(input, str):
|
236
|
+
if word_stream is None:
|
237
|
+
word_stream = self._opts.word_tokenizer.stream()
|
238
|
+
self._segments_ch.send_nowait(word_stream)
|
239
|
+
word_stream.push_text(input)
|
240
|
+
elif isinstance(input, self._FlushSentinel):
|
241
|
+
if word_stream:
|
242
|
+
word_stream.end_input()
|
243
|
+
word_stream = None
|
244
|
+
|
245
|
+
self._segments_ch.close()
|
246
|
+
|
247
|
+
async def _run_segments() -> None:
|
248
|
+
async for word_stream in self._segments_ch:
|
249
|
+
await self._run_ws(word_stream, output_emitter)
|
250
|
+
|
251
|
+
tasks = [
|
252
|
+
asyncio.create_task(_tokenize_input()),
|
253
|
+
asyncio.create_task(_run_segments()),
|
254
|
+
]
|
255
|
+
try:
|
256
|
+
await asyncio.gather(*tasks)
|
257
|
+
except asyncio.TimeoutError:
|
258
|
+
raise APITimeoutError() from None
|
259
|
+
except aiohttp.ClientResponseError as e:
|
260
|
+
raise APIStatusError(
|
261
|
+
message=e.message, status_code=e.status, request_id=request_id, body=None
|
262
|
+
) from None
|
263
|
+
except Exception as e:
|
264
|
+
raise APIConnectionError() from e
|
265
|
+
finally:
|
266
|
+
await utils.aio.gracefully_cancel(*tasks)
|
267
|
+
|
268
|
+
async def _run_ws(
|
269
|
+
self, word_stream: tokenize.WordStream, output_emitter: tts.AudioEmitter
|
270
|
+
) -> None:
|
271
|
+
segment_id = utils.shortuuid()
|
272
|
+
output_emitter.start_segment(segment_id=segment_id)
|
273
|
+
|
274
|
+
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
|
275
|
+
async for word in word_stream:
|
276
|
+
speak_msg = {"type": "Speak", "text": f"{word.token} "}
|
277
|
+
self._mark_started()
|
278
|
+
await ws.send_str(json.dumps(speak_msg))
|
279
|
+
|
280
|
+
# Always flush after a segment
|
281
|
+
flush_msg = {"type": "Flush"}
|
282
|
+
await ws.send_str(json.dumps(flush_msg))
|
283
|
+
|
284
|
+
async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
|
285
|
+
while True:
|
286
|
+
msg = await ws.receive()
|
287
|
+
if msg.type in (
|
288
|
+
aiohttp.WSMsgType.CLOSE,
|
289
|
+
aiohttp.WSMsgType.CLOSED,
|
290
|
+
aiohttp.WSMsgType.CLOSING,
|
291
|
+
):
|
292
|
+
raise APIStatusError("Deepgram websocket connection closed unexpectedly")
|
293
|
+
|
294
|
+
if msg.type == aiohttp.WSMsgType.BINARY:
|
295
|
+
output_emitter.push(msg.data)
|
296
|
+
elif msg.type == aiohttp.WSMsgType.TEXT:
|
297
|
+
resp = json.loads(msg.data)
|
298
|
+
mtype = resp.get("type")
|
299
|
+
if mtype == "Flushed":
|
300
|
+
output_emitter.end_segment()
|
301
|
+
break
|
302
|
+
elif mtype == "Warning":
|
303
|
+
logger.warning("Deepgram warning: %s", resp.get("warn_msg"))
|
304
|
+
elif mtype == "Metadata":
|
305
|
+
pass
|
306
|
+
else:
|
307
|
+
logger.debug("Unknown message type: %s", resp)
|
308
|
+
|
309
|
+
async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
|
310
|
+
tasks = [
|
311
|
+
asyncio.create_task(send_task(ws)),
|
312
|
+
asyncio.create_task(recv_task(ws)),
|
313
|
+
]
|
314
|
+
|
315
|
+
try:
|
316
|
+
await asyncio.gather(*tasks)
|
317
|
+
finally:
|
318
|
+
await utils.aio.gracefully_cancel(*tasks)
|
@@ -22,7 +22,7 @@ classifiers = [
|
|
22
22
|
"Programming Language :: Python :: 3.10",
|
23
23
|
"Programming Language :: Python :: 3 :: Only",
|
24
24
|
]
|
25
|
-
dependencies = ["livekit-agents[codecs]>=1.0.23", "numpy>=1.26"]
|
25
|
+
dependencies = ["livekit-agents[codecs]>=1.1.1", "numpy>=1.26"]
|
26
26
|
|
27
27
|
[project.urls]
|
28
28
|
Documentation = "https://docs.livekit.io"
|
@@ -1,438 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import asyncio
|
4
|
-
import json
|
5
|
-
import os
|
6
|
-
import weakref
|
7
|
-
from dataclasses import dataclass
|
8
|
-
|
9
|
-
import aiohttp
|
10
|
-
|
11
|
-
from livekit.agents import (
|
12
|
-
APIConnectionError,
|
13
|
-
APIConnectOptions,
|
14
|
-
APIStatusError,
|
15
|
-
APITimeoutError,
|
16
|
-
tokenize,
|
17
|
-
tts,
|
18
|
-
utils,
|
19
|
-
)
|
20
|
-
from livekit.agents.types import (
|
21
|
-
DEFAULT_API_CONNECT_OPTIONS,
|
22
|
-
NOT_GIVEN,
|
23
|
-
NotGivenOr,
|
24
|
-
)
|
25
|
-
from livekit.agents.utils import is_given
|
26
|
-
|
27
|
-
from ._utils import _to_deepgram_url
|
28
|
-
from .log import logger
|
29
|
-
|
30
|
-
BASE_URL = "https://api.deepgram.com/v1/speak"
|
31
|
-
NUM_CHANNELS = 1
|
32
|
-
|
33
|
-
|
34
|
-
@dataclass
|
35
|
-
class _TTSOptions:
|
36
|
-
model: str
|
37
|
-
encoding: str
|
38
|
-
sample_rate: int
|
39
|
-
word_tokenizer: tokenize.WordTokenizer
|
40
|
-
mip_opt_out: bool = False
|
41
|
-
|
42
|
-
|
43
|
-
class TTS(tts.TTS):
|
44
|
-
def __init__(
|
45
|
-
self,
|
46
|
-
*,
|
47
|
-
model: str = "aura-2-andromeda-en",
|
48
|
-
encoding: str = "linear16",
|
49
|
-
sample_rate: int = 24000,
|
50
|
-
api_key: NotGivenOr[str] = NOT_GIVEN,
|
51
|
-
base_url: str = BASE_URL,
|
52
|
-
word_tokenizer: NotGivenOr[tokenize.WordTokenizer] = NOT_GIVEN,
|
53
|
-
http_session: aiohttp.ClientSession | None = None,
|
54
|
-
mip_opt_out: bool = False,
|
55
|
-
) -> None:
|
56
|
-
"""
|
57
|
-
Create a new instance of Deepgram TTS.
|
58
|
-
|
59
|
-
Args:
|
60
|
-
model (str): TTS model to use. Defaults to "aura-2-andromeda-en".
|
61
|
-
encoding (str): Audio encoding to use. Defaults to "linear16".
|
62
|
-
sample_rate (int): Sample rate of audio. Defaults to 24000.
|
63
|
-
api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment.
|
64
|
-
base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak"
|
65
|
-
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
|
66
|
-
http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests.
|
67
|
-
|
68
|
-
""" # noqa: E501
|
69
|
-
super().__init__(
|
70
|
-
capabilities=tts.TTSCapabilities(streaming=True),
|
71
|
-
sample_rate=sample_rate,
|
72
|
-
num_channels=NUM_CHANNELS,
|
73
|
-
)
|
74
|
-
|
75
|
-
self._api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY")
|
76
|
-
if not self._api_key:
|
77
|
-
raise ValueError("Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key.")
|
78
|
-
|
79
|
-
if not is_given(word_tokenizer):
|
80
|
-
word_tokenizer = tokenize.basic.WordTokenizer(ignore_punctuation=False)
|
81
|
-
|
82
|
-
self._opts = _TTSOptions(
|
83
|
-
model=model,
|
84
|
-
encoding=encoding,
|
85
|
-
sample_rate=sample_rate,
|
86
|
-
word_tokenizer=word_tokenizer,
|
87
|
-
mip_opt_out=mip_opt_out,
|
88
|
-
)
|
89
|
-
self._session = http_session
|
90
|
-
self._base_url = base_url
|
91
|
-
self._streams = weakref.WeakSet[SynthesizeStream]()
|
92
|
-
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
|
93
|
-
connect_cb=self._connect_ws,
|
94
|
-
close_cb=self._close_ws,
|
95
|
-
max_session_duration=3600, # 1 hour
|
96
|
-
mark_refreshed_on_get=False,
|
97
|
-
)
|
98
|
-
|
99
|
-
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
|
100
|
-
session = self._ensure_session()
|
101
|
-
config = {
|
102
|
-
"encoding": self._opts.encoding,
|
103
|
-
"model": self._opts.model,
|
104
|
-
"sample_rate": self._opts.sample_rate,
|
105
|
-
"mip_opt_out": self._opts.mip_opt_out,
|
106
|
-
}
|
107
|
-
return await asyncio.wait_for(
|
108
|
-
session.ws_connect(
|
109
|
-
_to_deepgram_url(config, self._base_url, websocket=True),
|
110
|
-
headers={"Authorization": f"Token {self._api_key}"},
|
111
|
-
),
|
112
|
-
self._conn_options.timeout,
|
113
|
-
)
|
114
|
-
|
115
|
-
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
|
116
|
-
await ws.close()
|
117
|
-
|
118
|
-
def _ensure_session(self) -> aiohttp.ClientSession:
|
119
|
-
if not self._session:
|
120
|
-
self._session = utils.http_context.http_session()
|
121
|
-
return self._session
|
122
|
-
|
123
|
-
def update_options(
|
124
|
-
self,
|
125
|
-
*,
|
126
|
-
model: NotGivenOr[str] = NOT_GIVEN,
|
127
|
-
sample_rate: NotGivenOr[int] = NOT_GIVEN,
|
128
|
-
) -> None:
|
129
|
-
"""
|
130
|
-
args:
|
131
|
-
model (str): TTS model to use.
|
132
|
-
sample_rate (int): Sample rate of audio.
|
133
|
-
"""
|
134
|
-
if is_given(model):
|
135
|
-
self._opts.model = model
|
136
|
-
if is_given(sample_rate):
|
137
|
-
self._opts.sample_rate = sample_rate
|
138
|
-
for stream in self._streams:
|
139
|
-
stream.update_options(
|
140
|
-
model=model,
|
141
|
-
sample_rate=sample_rate,
|
142
|
-
)
|
143
|
-
|
144
|
-
def synthesize(
|
145
|
-
self,
|
146
|
-
text: str,
|
147
|
-
*,
|
148
|
-
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
149
|
-
) -> ChunkedStream:
|
150
|
-
return ChunkedStream(
|
151
|
-
tts=self,
|
152
|
-
input_text=text,
|
153
|
-
base_url=self._base_url,
|
154
|
-
api_key=self._api_key,
|
155
|
-
conn_options=conn_options,
|
156
|
-
opts=self._opts,
|
157
|
-
session=self._ensure_session(),
|
158
|
-
)
|
159
|
-
|
160
|
-
def stream(
|
161
|
-
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
|
162
|
-
) -> SynthesizeStream:
|
163
|
-
stream = SynthesizeStream(
|
164
|
-
tts=self,
|
165
|
-
conn_options=conn_options,
|
166
|
-
base_url=self._base_url,
|
167
|
-
api_key=self._api_key,
|
168
|
-
opts=self._opts,
|
169
|
-
session=self._ensure_session(),
|
170
|
-
)
|
171
|
-
self._streams.add(stream)
|
172
|
-
return stream
|
173
|
-
|
174
|
-
def prewarm(self) -> None:
|
175
|
-
self._pool.prewarm()
|
176
|
-
|
177
|
-
async def aclose(self) -> None:
|
178
|
-
for stream in list(self._streams):
|
179
|
-
await stream.aclose()
|
180
|
-
self._streams.clear()
|
181
|
-
await self._pool.aclose()
|
182
|
-
await super().aclose()
|
183
|
-
|
184
|
-
|
185
|
-
class ChunkedStream(tts.ChunkedStream):
|
186
|
-
def __init__(
|
187
|
-
self,
|
188
|
-
*,
|
189
|
-
tts: TTS,
|
190
|
-
base_url: str,
|
191
|
-
api_key: str,
|
192
|
-
input_text: str,
|
193
|
-
opts: _TTSOptions,
|
194
|
-
session: aiohttp.ClientSession,
|
195
|
-
conn_options: APIConnectOptions,
|
196
|
-
) -> None:
|
197
|
-
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
|
198
|
-
self._opts = opts
|
199
|
-
self._session = session
|
200
|
-
self._base_url = base_url
|
201
|
-
self._api_key = api_key
|
202
|
-
|
203
|
-
async def _run(self) -> None:
|
204
|
-
request_id = utils.shortuuid()
|
205
|
-
audio_bstream = utils.audio.AudioByteStream(
|
206
|
-
sample_rate=self._opts.sample_rate,
|
207
|
-
num_channels=NUM_CHANNELS,
|
208
|
-
)
|
209
|
-
|
210
|
-
try:
|
211
|
-
config = {
|
212
|
-
"encoding": self._opts.encoding,
|
213
|
-
"model": self._opts.model,
|
214
|
-
"sample_rate": self._opts.sample_rate,
|
215
|
-
"mip_opt_out": self._opts.mip_opt_out,
|
216
|
-
}
|
217
|
-
async with self._session.post(
|
218
|
-
_to_deepgram_url(config, self._base_url, websocket=False),
|
219
|
-
headers={
|
220
|
-
"Authorization": f"Token {self._api_key}",
|
221
|
-
"Content-Type": "application/json",
|
222
|
-
},
|
223
|
-
json={"text": self._input_text},
|
224
|
-
timeout=aiohttp.ClientTimeout(connect=self._conn_options.timeout, total=30),
|
225
|
-
) as res:
|
226
|
-
if res.status != 200:
|
227
|
-
raise APIStatusError(
|
228
|
-
message=res.reason or "Unknown error occurred.",
|
229
|
-
status_code=res.status,
|
230
|
-
request_id=request_id,
|
231
|
-
body=await res.json(),
|
232
|
-
)
|
233
|
-
|
234
|
-
async for bytes_data, _ in res.content.iter_chunks():
|
235
|
-
for frame in audio_bstream.write(bytes_data):
|
236
|
-
self._event_ch.send_nowait(
|
237
|
-
tts.SynthesizedAudio(
|
238
|
-
request_id=request_id,
|
239
|
-
frame=frame,
|
240
|
-
)
|
241
|
-
)
|
242
|
-
|
243
|
-
for frame in audio_bstream.flush():
|
244
|
-
self._event_ch.send_nowait(
|
245
|
-
tts.SynthesizedAudio(request_id=request_id, frame=frame)
|
246
|
-
)
|
247
|
-
|
248
|
-
except asyncio.TimeoutError as e:
|
249
|
-
raise APITimeoutError() from e
|
250
|
-
except aiohttp.ClientResponseError as e:
|
251
|
-
raise APIStatusError(
|
252
|
-
message=e.message,
|
253
|
-
status_code=e.status,
|
254
|
-
request_id=request_id,
|
255
|
-
body=None,
|
256
|
-
) from e
|
257
|
-
except Exception as e:
|
258
|
-
raise APIConnectionError() from e
|
259
|
-
|
260
|
-
|
261
|
-
class SynthesizeStream(tts.SynthesizeStream):
|
262
|
-
def __init__(
|
263
|
-
self,
|
264
|
-
*,
|
265
|
-
tts: TTS,
|
266
|
-
base_url: str,
|
267
|
-
api_key: str,
|
268
|
-
opts: _TTSOptions,
|
269
|
-
session: aiohttp.ClientSession,
|
270
|
-
conn_options: APIConnectOptions,
|
271
|
-
):
|
272
|
-
super().__init__(tts=tts, conn_options=conn_options)
|
273
|
-
self._opts = opts
|
274
|
-
self._session = session
|
275
|
-
self._base_url = base_url
|
276
|
-
self._api_key = api_key
|
277
|
-
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
|
278
|
-
self._reconnect_event = asyncio.Event()
|
279
|
-
|
280
|
-
def update_options(
|
281
|
-
self,
|
282
|
-
*,
|
283
|
-
model: NotGivenOr[str] = NOT_GIVEN,
|
284
|
-
sample_rate: NotGivenOr[int] = NOT_GIVEN,
|
285
|
-
) -> None:
|
286
|
-
if is_given(model):
|
287
|
-
self._opts.model = model
|
288
|
-
if is_given(sample_rate):
|
289
|
-
self._opts.sample_rate = sample_rate
|
290
|
-
|
291
|
-
self._reconnect_event.set()
|
292
|
-
|
293
|
-
async def _run(self) -> None:
|
294
|
-
closing_ws = False
|
295
|
-
request_id = utils.shortuuid()
|
296
|
-
segment_id = utils.shortuuid()
|
297
|
-
audio_bstream = utils.audio.AudioByteStream(
|
298
|
-
sample_rate=self._opts.sample_rate,
|
299
|
-
num_channels=NUM_CHANNELS,
|
300
|
-
)
|
301
|
-
|
302
|
-
@utils.log_exceptions(logger=logger)
|
303
|
-
async def _tokenize_input():
|
304
|
-
# Converts incoming text into WordStreams and sends them into _segments_ch
|
305
|
-
word_stream = None
|
306
|
-
async for input in self._input_ch:
|
307
|
-
if isinstance(input, str):
|
308
|
-
if word_stream is None:
|
309
|
-
word_stream = self._opts.word_tokenizer.stream()
|
310
|
-
self._segments_ch.send_nowait(word_stream)
|
311
|
-
word_stream.push_text(input)
|
312
|
-
elif isinstance(input, self._FlushSentinel):
|
313
|
-
if word_stream:
|
314
|
-
word_stream.end_input()
|
315
|
-
word_stream = None
|
316
|
-
self._segments_ch.close()
|
317
|
-
|
318
|
-
@utils.log_exceptions(logger=logger)
|
319
|
-
async def _run_segments(ws: aiohttp.ClientWebSocketResponse):
|
320
|
-
nonlocal closing_ws
|
321
|
-
async for word_stream in self._segments_ch:
|
322
|
-
async for word in word_stream:
|
323
|
-
speak_msg = {"type": "Speak", "text": f"{word.token} "}
|
324
|
-
self._mark_started()
|
325
|
-
await ws.send_str(json.dumps(speak_msg))
|
326
|
-
|
327
|
-
# Always flush after a segment
|
328
|
-
flush_msg = {"type": "Flush"}
|
329
|
-
await ws.send_str(json.dumps(flush_msg))
|
330
|
-
|
331
|
-
# after all segments, close
|
332
|
-
close_msg = {"type": "Close"}
|
333
|
-
closing_ws = True
|
334
|
-
await ws.send_str(json.dumps(close_msg))
|
335
|
-
|
336
|
-
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
|
337
|
-
emitter = tts.SynthesizedAudioEmitter(
|
338
|
-
event_ch=self._event_ch,
|
339
|
-
request_id=request_id,
|
340
|
-
segment_id=segment_id,
|
341
|
-
)
|
342
|
-
|
343
|
-
while True:
|
344
|
-
msg = await ws.receive()
|
345
|
-
if msg.type in (
|
346
|
-
aiohttp.WSMsgType.CLOSE,
|
347
|
-
aiohttp.WSMsgType.CLOSED,
|
348
|
-
aiohttp.WSMsgType.CLOSING,
|
349
|
-
):
|
350
|
-
if not closing_ws:
|
351
|
-
raise APIStatusError(
|
352
|
-
"Deepgram websocket connection closed unexpectedly",
|
353
|
-
request_id=request_id,
|
354
|
-
)
|
355
|
-
return
|
356
|
-
|
357
|
-
if msg.type == aiohttp.WSMsgType.BINARY:
|
358
|
-
data = msg.data
|
359
|
-
for frame in audio_bstream.write(data):
|
360
|
-
emitter.push(frame)
|
361
|
-
elif msg.type == aiohttp.WSMsgType.TEXT:
|
362
|
-
resp = json.loads(msg.data)
|
363
|
-
mtype = resp.get("type")
|
364
|
-
if mtype == "Flushed":
|
365
|
-
for frame in audio_bstream.flush():
|
366
|
-
emitter.push(frame)
|
367
|
-
emitter.flush()
|
368
|
-
break
|
369
|
-
elif mtype == "Warning":
|
370
|
-
logger.warning("Deepgram warning: %s", resp.get("warn_msg"))
|
371
|
-
elif mtype == "Metadata":
|
372
|
-
pass
|
373
|
-
else:
|
374
|
-
logger.debug("Unknown message type: %s", resp)
|
375
|
-
|
376
|
-
async def _connection_timeout():
|
377
|
-
# Deepgram has a 60-minute timeout period for websocket connections
|
378
|
-
await asyncio.sleep(3300)
|
379
|
-
logger.warning("Deepgram TTS maximum connection time reached. Reconnecting...")
|
380
|
-
self._reconnect_event.set()
|
381
|
-
|
382
|
-
ws: aiohttp.ClientWebSocketResponse | None = None
|
383
|
-
while True:
|
384
|
-
try:
|
385
|
-
config = {
|
386
|
-
"encoding": self._opts.encoding,
|
387
|
-
"model": self._opts.model,
|
388
|
-
"sample_rate": self._opts.sample_rate,
|
389
|
-
"mip_opt_out": self._opts.mip_opt_out,
|
390
|
-
}
|
391
|
-
ws = await asyncio.wait_for(
|
392
|
-
self._session.ws_connect(
|
393
|
-
_to_deepgram_url(config, self._base_url, websocket=True),
|
394
|
-
headers={"Authorization": f"Token {self._api_key}"},
|
395
|
-
),
|
396
|
-
self._conn_options.timeout,
|
397
|
-
)
|
398
|
-
closing_ws = False
|
399
|
-
|
400
|
-
tasks = [
|
401
|
-
asyncio.create_task(_tokenize_input()),
|
402
|
-
asyncio.create_task(_run_segments(ws)),
|
403
|
-
asyncio.create_task(recv_task(ws)),
|
404
|
-
]
|
405
|
-
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
|
406
|
-
connection_timeout_task = asyncio.create_task(_connection_timeout())
|
407
|
-
|
408
|
-
try:
|
409
|
-
done, _ = await asyncio.wait(
|
410
|
-
[
|
411
|
-
asyncio.gather(*tasks),
|
412
|
-
wait_reconnect_task,
|
413
|
-
connection_timeout_task,
|
414
|
-
],
|
415
|
-
return_when=asyncio.FIRST_COMPLETED,
|
416
|
-
) # type: ignore
|
417
|
-
if wait_reconnect_task not in done:
|
418
|
-
break
|
419
|
-
self._reconnect_event.clear()
|
420
|
-
finally:
|
421
|
-
await utils.aio.gracefully_cancel(
|
422
|
-
*tasks, wait_reconnect_task, connection_timeout_task
|
423
|
-
)
|
424
|
-
|
425
|
-
except asyncio.TimeoutError as e:
|
426
|
-
raise APITimeoutError() from e
|
427
|
-
except aiohttp.ClientResponseError as e:
|
428
|
-
raise APIStatusError(
|
429
|
-
message=e.message,
|
430
|
-
status_code=e.status,
|
431
|
-
request_id=request_id,
|
432
|
-
body=None,
|
433
|
-
) from e
|
434
|
-
except Exception as e:
|
435
|
-
raise APIConnectionError() from e
|
436
|
-
finally:
|
437
|
-
if ws is not None and not ws.closed:
|
438
|
-
await ws.close()
|
File without changes
|
File without changes
|
{livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/log.py
RENAMED
File without changes
|
File without changes
|
{livekit_plugins_deepgram-1.0.23 → livekit_plugins_deepgram-1.1.1}/livekit/plugins/deepgram/py.typed
RENAMED
File without changes
|