livekit-plugins-deepgram 1.0.23__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -146,6 +146,9 @@ venv.bak/
146
146
  .dmypy.json
147
147
  dmypy.json
148
148
 
149
+ # trunk
150
+ .trunk/
151
+
149
152
  # Pyre type checker
150
153
  .pyre/
151
154
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: livekit-plugins-deepgram
3
- Version: 1.0.23
3
+ Version: 1.1.0
4
4
  Summary: Agent Framework plugin for services using Deepgram's API.
5
5
  Project-URL: Documentation, https://docs.livekit.io
6
6
  Project-URL: Website, https://livekit.io/
@@ -18,7 +18,7 @@ Classifier: Topic :: Multimedia :: Sound/Audio
18
18
  Classifier: Topic :: Multimedia :: Video
19
19
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9.0
21
- Requires-Dist: livekit-agents[codecs]>=1.0.23
21
+ Requires-Dist: livekit-agents[codecs]>=1.1.0
22
22
  Requires-Dist: numpy>=1.26
23
23
  Description-Content-Type: text/markdown
24
24
 
@@ -32,7 +32,7 @@ from .log import logger
32
32
 
33
33
 
34
34
  class DeepgramPlugin(Plugin):
35
- def __init__(self):
35
+ def __init__(self) -> None:
36
36
  super().__init__(__name__, __version__, __package__, logger)
37
37
 
38
38
 
@@ -94,7 +94,7 @@ class AudioEnergyFilter:
94
94
 
95
95
  @dataclass
96
96
  class STTOptions:
97
- language: DeepgramLanguages | str
97
+ language: DeepgramLanguages | str | None
98
98
  detect_language: bool
99
99
  interim_results: bool
100
100
  punctuate: bool
@@ -181,9 +181,10 @@ class STT(stt.STT):
181
181
  )
182
182
  self._base_url = base_url
183
183
 
184
- self._api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY")
185
- if not self._api_key:
184
+ deepgram_api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY")
185
+ if not deepgram_api_key:
186
186
  raise ValueError("Deepgram API key is required")
187
+ self._api_key = deepgram_api_key
187
188
 
188
189
  model = _validate_model(model, language)
189
190
  _validate_keyterms(model, language, keyterms, keywords)
@@ -305,7 +306,7 @@ class STT(stt.STT):
305
306
  numerals: NotGivenOr[bool] = NOT_GIVEN,
306
307
  mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
307
308
  tags: NotGivenOr[list[str]] = NOT_GIVEN,
308
- ):
309
+ ) -> None:
309
310
  if is_given(language):
310
311
  self._opts.language = language
311
312
  if is_given(model):
@@ -383,14 +384,13 @@ class SpeechStream(stt.SpeechStream):
383
384
  http_session: aiohttp.ClientSession,
384
385
  base_url: str,
385
386
  ) -> None:
386
- super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
387
-
388
387
  if opts.detect_language or opts.language is None:
389
388
  raise ValueError(
390
389
  "language detection is not supported in streaming mode, "
391
390
  "please disable it and specify a language"
392
391
  )
393
392
 
393
+ super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
394
394
  self._opts = opts
395
395
  self._api_key = api_key
396
396
  self._session = http_session
@@ -429,7 +429,7 @@ class SpeechStream(stt.SpeechStream):
429
429
  numerals: NotGivenOr[bool] = NOT_GIVEN,
430
430
  mip_opt_out: NotGivenOr[bool] = NOT_GIVEN,
431
431
  tags: NotGivenOr[list[str]] = NOT_GIVEN,
432
- ):
432
+ ) -> None:
433
433
  if is_given(language):
434
434
  self._opts.language = language
435
435
  if is_given(model):
@@ -466,7 +466,7 @@ class SpeechStream(stt.SpeechStream):
466
466
  async def _run(self) -> None:
467
467
  closing_ws = False
468
468
 
469
- async def keepalive_task(ws: aiohttp.ClientWebSocketResponse):
469
+ async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None:
470
470
  # if we want to keep the connection alive even if no audio is sent,
471
471
  # Deepgram expects a keepalive message.
472
472
  # https://developers.deepgram.com/reference/listen-live#stream-keepalive
@@ -478,7 +478,7 @@ class SpeechStream(stt.SpeechStream):
478
478
  return
479
479
 
480
480
  @utils.log_exceptions(logger=logger)
481
- async def send_task(ws: aiohttp.ClientWebSocketResponse):
481
+ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
482
482
  nonlocal closing_ws
483
483
 
484
484
  # forward audio to deepgram in chunks of 50ms
@@ -529,7 +529,7 @@ class SpeechStream(stt.SpeechStream):
529
529
  await ws.send_str(SpeechStream._CLOSE_MSG)
530
530
 
531
531
  @utils.log_exceptions(logger=logger)
532
- async def recv_task(ws: aiohttp.ClientWebSocketResponse):
532
+ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
533
533
  nonlocal closing_ws
534
534
  while True:
535
535
  msg = await ws.receive()
@@ -569,9 +569,9 @@ class SpeechStream(stt.SpeechStream):
569
569
  wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
570
570
  try:
571
571
  done, _ = await asyncio.wait(
572
- [tasks_group, wait_reconnect_task],
572
+ (tasks_group, wait_reconnect_task),
573
573
  return_when=asyncio.FIRST_COMPLETED,
574
- ) # type: ignore
574
+ )
575
575
 
576
576
  # propagate exceptions from completed tasks
577
577
  for task in done:
@@ -0,0 +1,318 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import weakref
7
+ from dataclasses import dataclass, replace
8
+
9
+ import aiohttp
10
+
11
+ from livekit.agents import (
12
+ APIConnectionError,
13
+ APIConnectOptions,
14
+ APIStatusError,
15
+ APITimeoutError,
16
+ tokenize,
17
+ tts,
18
+ utils,
19
+ )
20
+ from livekit.agents.types import (
21
+ DEFAULT_API_CONNECT_OPTIONS,
22
+ NOT_GIVEN,
23
+ NotGivenOr,
24
+ )
25
+ from livekit.agents.utils import is_given
26
+
27
+ from ._utils import _to_deepgram_url
28
+ from .log import logger
29
+
30
+ BASE_URL = "https://api.deepgram.com/v1/speak"
31
+ NUM_CHANNELS = 1
32
+
33
+
34
+ @dataclass
35
+ class _TTSOptions:
36
+ model: str
37
+ encoding: str
38
+ sample_rate: int
39
+ word_tokenizer: tokenize.WordTokenizer
40
+ base_url: str
41
+ api_key: str
42
+ mip_opt_out: bool = False
43
+
44
+
45
+ class TTS(tts.TTS):
46
+ def __init__(
47
+ self,
48
+ *,
49
+ model: str = "aura-2-andromeda-en",
50
+ encoding: str = "linear16",
51
+ sample_rate: int = 24000,
52
+ api_key: str | None = None,
53
+ base_url: str = BASE_URL,
54
+ word_tokenizer: NotGivenOr[tokenize.WordTokenizer] = NOT_GIVEN,
55
+ http_session: aiohttp.ClientSession | None = None,
56
+ mip_opt_out: bool = False,
57
+ ) -> None:
58
+ """
59
+ Create a new instance of Deepgram TTS.
60
+
61
+ Args:
62
+ model (str): TTS model to use. Defaults to "aura-2-andromeda-en".
63
+ encoding (str): Audio encoding to use. Defaults to "linear16".
64
+ sample_rate (int): Sample rate of audio. Defaults to 24000.
65
+ api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment.
66
+ base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak"
67
+ word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
68
+ http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests.
69
+
70
+ """ # noqa: E501
71
+ super().__init__(
72
+ capabilities=tts.TTSCapabilities(streaming=True),
73
+ sample_rate=sample_rate,
74
+ num_channels=NUM_CHANNELS,
75
+ )
76
+
77
+ api_key = api_key or os.environ.get("DEEPGRAM_API_KEY")
78
+ if not api_key:
79
+ raise ValueError("Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key.")
80
+
81
+ if not is_given(word_tokenizer):
82
+ word_tokenizer = tokenize.basic.WordTokenizer(ignore_punctuation=False)
83
+
84
+ self._opts = _TTSOptions(
85
+ model=model,
86
+ encoding=encoding,
87
+ sample_rate=sample_rate,
88
+ word_tokenizer=word_tokenizer,
89
+ base_url=base_url,
90
+ api_key=api_key,
91
+ mip_opt_out=mip_opt_out,
92
+ )
93
+ self._session = http_session
94
+ self._streams = weakref.WeakSet[SynthesizeStream]()
95
+
96
+ self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
97
+ connect_cb=self._connect_ws,
98
+ close_cb=self._close_ws,
99
+ max_session_duration=3600, # 1 hour
100
+ mark_refreshed_on_get=False,
101
+ )
102
+
103
+ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
104
+ session = self._ensure_session()
105
+ config = {
106
+ "encoding": self._opts.encoding,
107
+ "model": self._opts.model,
108
+ "sample_rate": self._opts.sample_rate,
109
+ "mip_opt_out": self._opts.mip_opt_out,
110
+ }
111
+ return await asyncio.wait_for(
112
+ session.ws_connect(
113
+ _to_deepgram_url(config, self._opts.base_url, websocket=True),
114
+ headers={"Authorization": f"Token {self._opts.api_key}"},
115
+ ),
116
+ timeout,
117
+ )
118
+
119
+ async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
120
+ await ws.close()
121
+
122
+ def _ensure_session(self) -> aiohttp.ClientSession:
123
+ if not self._session:
124
+ self._session = utils.http_context.http_session()
125
+ return self._session
126
+
127
+ def update_options(
128
+ self,
129
+ *,
130
+ model: NotGivenOr[str] = NOT_GIVEN,
131
+ ) -> None:
132
+ """
133
+ Args:
134
+ model (str): TTS model to use.
135
+ """
136
+ if is_given(model):
137
+ self._opts.model = model
138
+
139
+ def synthesize(
140
+ self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
141
+ ) -> ChunkedStream:
142
+ return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)
143
+
144
+ def stream(
145
+ self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
146
+ ) -> SynthesizeStream:
147
+ stream = SynthesizeStream(tts=self, conn_options=conn_options)
148
+ self._streams.add(stream)
149
+ return stream
150
+
151
+ def prewarm(self) -> None:
152
+ self._pool.prewarm()
153
+
154
+ async def aclose(self) -> None:
155
+ for stream in list(self._streams):
156
+ await stream.aclose()
157
+
158
+ self._streams.clear()
159
+
160
+ await self._pool.aclose()
161
+
162
+
163
+ class ChunkedStream(tts.ChunkedStream):
164
+ def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
165
+ super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
166
+ self._tts: TTS = tts
167
+ self._opts = replace(tts._opts)
168
+
169
+ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
170
+ try:
171
+ async with self._tts._ensure_session().post(
172
+ _to_deepgram_url(
173
+ {
174
+ "encoding": self._opts.encoding,
175
+ "container": "none",
176
+ "model": self._opts.model,
177
+ "sample_rate": self._opts.sample_rate,
178
+ "mip_opt_out": self._opts.mip_opt_out,
179
+ },
180
+ self._opts.base_url,
181
+ websocket=False,
182
+ ),
183
+ headers={
184
+ "Authorization": f"Token {self._opts.api_key}",
185
+ "Content-Type": "application/json",
186
+ },
187
+ json={"text": self._input_text},
188
+ timeout=aiohttp.ClientTimeout(total=30, sock_connect=self._conn_options.timeout),
189
+ ) as resp:
190
+ resp.raise_for_status()
191
+
192
+ output_emitter.initialize(
193
+ request_id=utils.shortuuid(),
194
+ sample_rate=self._opts.sample_rate,
195
+ num_channels=NUM_CHANNELS,
196
+ mime_type="audio/pcm",
197
+ )
198
+
199
+ async for data, _ in resp.content.iter_chunks():
200
+ output_emitter.push(data)
201
+
202
+ output_emitter.flush()
203
+
204
+ except asyncio.TimeoutError:
205
+ raise APITimeoutError() from None
206
+ except aiohttp.ClientResponseError as e:
207
+ raise APIStatusError(
208
+ message=e.message, status_code=e.status, request_id=None, body=None
209
+ ) from None
210
+ except Exception as e:
211
+ raise APIConnectionError() from e
212
+
213
+
214
+ class SynthesizeStream(tts.SynthesizeStream):
215
+ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
216
+ super().__init__(tts=tts, conn_options=conn_options)
217
+ self._tts: TTS = tts
218
+ self._opts = replace(tts._opts)
219
+ self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
220
+
221
+ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
222
+ request_id = utils.shortuuid()
223
+ output_emitter.initialize(
224
+ request_id=request_id,
225
+ sample_rate=self._opts.sample_rate,
226
+ num_channels=1,
227
+ mime_type="audio/pcm",
228
+ stream=True,
229
+ )
230
+
231
+ async def _tokenize_input() -> None:
232
+ # Converts incoming text into WordStreams and sends them into _segments_ch
233
+ word_stream = None
234
+ async for input in self._input_ch:
235
+ if isinstance(input, str):
236
+ if word_stream is None:
237
+ word_stream = self._opts.word_tokenizer.stream()
238
+ self._segments_ch.send_nowait(word_stream)
239
+ word_stream.push_text(input)
240
+ elif isinstance(input, self._FlushSentinel):
241
+ if word_stream:
242
+ word_stream.end_input()
243
+ word_stream = None
244
+
245
+ self._segments_ch.close()
246
+
247
+ async def _run_segments() -> None:
248
+ async for word_stream in self._segments_ch:
249
+ await self._run_ws(word_stream, output_emitter)
250
+
251
+ tasks = [
252
+ asyncio.create_task(_tokenize_input()),
253
+ asyncio.create_task(_run_segments()),
254
+ ]
255
+ try:
256
+ await asyncio.gather(*tasks)
257
+ except asyncio.TimeoutError:
258
+ raise APITimeoutError() from None
259
+ except aiohttp.ClientResponseError as e:
260
+ raise APIStatusError(
261
+ message=e.message, status_code=e.status, request_id=request_id, body=None
262
+ ) from None
263
+ except Exception as e:
264
+ raise APIConnectionError() from e
265
+ finally:
266
+ await utils.aio.gracefully_cancel(*tasks)
267
+
268
+ async def _run_ws(
269
+ self, word_stream: tokenize.WordStream, output_emitter: tts.AudioEmitter
270
+ ) -> None:
271
+ segment_id = utils.shortuuid()
272
+ output_emitter.start_segment(segment_id=segment_id)
273
+
274
+ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
275
+ async for word in word_stream:
276
+ speak_msg = {"type": "Speak", "text": f"{word.token} "}
277
+ self._mark_started()
278
+ await ws.send_str(json.dumps(speak_msg))
279
+
280
+ # Always flush after a segment
281
+ flush_msg = {"type": "Flush"}
282
+ await ws.send_str(json.dumps(flush_msg))
283
+
284
+ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
285
+ while True:
286
+ msg = await ws.receive()
287
+ if msg.type in (
288
+ aiohttp.WSMsgType.CLOSE,
289
+ aiohttp.WSMsgType.CLOSED,
290
+ aiohttp.WSMsgType.CLOSING,
291
+ ):
292
+ raise APIStatusError("Deepgram websocket connection closed unexpectedly")
293
+
294
+ if msg.type == aiohttp.WSMsgType.BINARY:
295
+ output_emitter.push(msg.data)
296
+ elif msg.type == aiohttp.WSMsgType.TEXT:
297
+ resp = json.loads(msg.data)
298
+ mtype = resp.get("type")
299
+ if mtype == "Flushed":
300
+ output_emitter.end_segment()
301
+ break
302
+ elif mtype == "Warning":
303
+ logger.warning("Deepgram warning: %s", resp.get("warn_msg"))
304
+ elif mtype == "Metadata":
305
+ pass
306
+ else:
307
+ logger.debug("Unknown message type: %s", resp)
308
+
309
+ async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
310
+ tasks = [
311
+ asyncio.create_task(send_task(ws)),
312
+ asyncio.create_task(recv_task(ws)),
313
+ ]
314
+
315
+ try:
316
+ await asyncio.gather(*tasks)
317
+ finally:
318
+ await utils.aio.gracefully_cancel(*tasks)
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "1.0.23"
15
+ __version__ = "1.1.0"
@@ -22,7 +22,7 @@ classifiers = [
22
22
  "Programming Language :: Python :: 3.10",
23
23
  "Programming Language :: Python :: 3 :: Only",
24
24
  ]
25
- dependencies = ["livekit-agents[codecs]>=1.0.23", "numpy>=1.26"]
25
+ dependencies = ["livekit-agents[codecs]>=1.1.0", "numpy>=1.26"]
26
26
 
27
27
  [project.urls]
28
28
  Documentation = "https://docs.livekit.io"
@@ -1,438 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import json
5
- import os
6
- import weakref
7
- from dataclasses import dataclass
8
-
9
- import aiohttp
10
-
11
- from livekit.agents import (
12
- APIConnectionError,
13
- APIConnectOptions,
14
- APIStatusError,
15
- APITimeoutError,
16
- tokenize,
17
- tts,
18
- utils,
19
- )
20
- from livekit.agents.types import (
21
- DEFAULT_API_CONNECT_OPTIONS,
22
- NOT_GIVEN,
23
- NotGivenOr,
24
- )
25
- from livekit.agents.utils import is_given
26
-
27
- from ._utils import _to_deepgram_url
28
- from .log import logger
29
-
30
- BASE_URL = "https://api.deepgram.com/v1/speak"
31
- NUM_CHANNELS = 1
32
-
33
-
34
- @dataclass
35
- class _TTSOptions:
36
- model: str
37
- encoding: str
38
- sample_rate: int
39
- word_tokenizer: tokenize.WordTokenizer
40
- mip_opt_out: bool = False
41
-
42
-
43
- class TTS(tts.TTS):
44
- def __init__(
45
- self,
46
- *,
47
- model: str = "aura-2-andromeda-en",
48
- encoding: str = "linear16",
49
- sample_rate: int = 24000,
50
- api_key: NotGivenOr[str] = NOT_GIVEN,
51
- base_url: str = BASE_URL,
52
- word_tokenizer: NotGivenOr[tokenize.WordTokenizer] = NOT_GIVEN,
53
- http_session: aiohttp.ClientSession | None = None,
54
- mip_opt_out: bool = False,
55
- ) -> None:
56
- """
57
- Create a new instance of Deepgram TTS.
58
-
59
- Args:
60
- model (str): TTS model to use. Defaults to "aura-2-andromeda-en".
61
- encoding (str): Audio encoding to use. Defaults to "linear16".
62
- sample_rate (int): Sample rate of audio. Defaults to 24000.
63
- api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment.
64
- base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak"
65
- word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
66
- http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests.
67
-
68
- """ # noqa: E501
69
- super().__init__(
70
- capabilities=tts.TTSCapabilities(streaming=True),
71
- sample_rate=sample_rate,
72
- num_channels=NUM_CHANNELS,
73
- )
74
-
75
- self._api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY")
76
- if not self._api_key:
77
- raise ValueError("Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key.")
78
-
79
- if not is_given(word_tokenizer):
80
- word_tokenizer = tokenize.basic.WordTokenizer(ignore_punctuation=False)
81
-
82
- self._opts = _TTSOptions(
83
- model=model,
84
- encoding=encoding,
85
- sample_rate=sample_rate,
86
- word_tokenizer=word_tokenizer,
87
- mip_opt_out=mip_opt_out,
88
- )
89
- self._session = http_session
90
- self._base_url = base_url
91
- self._streams = weakref.WeakSet[SynthesizeStream]()
92
- self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
93
- connect_cb=self._connect_ws,
94
- close_cb=self._close_ws,
95
- max_session_duration=3600, # 1 hour
96
- mark_refreshed_on_get=False,
97
- )
98
-
99
- async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
100
- session = self._ensure_session()
101
- config = {
102
- "encoding": self._opts.encoding,
103
- "model": self._opts.model,
104
- "sample_rate": self._opts.sample_rate,
105
- "mip_opt_out": self._opts.mip_opt_out,
106
- }
107
- return await asyncio.wait_for(
108
- session.ws_connect(
109
- _to_deepgram_url(config, self._base_url, websocket=True),
110
- headers={"Authorization": f"Token {self._api_key}"},
111
- ),
112
- self._conn_options.timeout,
113
- )
114
-
115
- async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
116
- await ws.close()
117
-
118
- def _ensure_session(self) -> aiohttp.ClientSession:
119
- if not self._session:
120
- self._session = utils.http_context.http_session()
121
- return self._session
122
-
123
- def update_options(
124
- self,
125
- *,
126
- model: NotGivenOr[str] = NOT_GIVEN,
127
- sample_rate: NotGivenOr[int] = NOT_GIVEN,
128
- ) -> None:
129
- """
130
- args:
131
- model (str): TTS model to use.
132
- sample_rate (int): Sample rate of audio.
133
- """
134
- if is_given(model):
135
- self._opts.model = model
136
- if is_given(sample_rate):
137
- self._opts.sample_rate = sample_rate
138
- for stream in self._streams:
139
- stream.update_options(
140
- model=model,
141
- sample_rate=sample_rate,
142
- )
143
-
144
- def synthesize(
145
- self,
146
- text: str,
147
- *,
148
- conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
149
- ) -> ChunkedStream:
150
- return ChunkedStream(
151
- tts=self,
152
- input_text=text,
153
- base_url=self._base_url,
154
- api_key=self._api_key,
155
- conn_options=conn_options,
156
- opts=self._opts,
157
- session=self._ensure_session(),
158
- )
159
-
160
- def stream(
161
- self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
162
- ) -> SynthesizeStream:
163
- stream = SynthesizeStream(
164
- tts=self,
165
- conn_options=conn_options,
166
- base_url=self._base_url,
167
- api_key=self._api_key,
168
- opts=self._opts,
169
- session=self._ensure_session(),
170
- )
171
- self._streams.add(stream)
172
- return stream
173
-
174
- def prewarm(self) -> None:
175
- self._pool.prewarm()
176
-
177
- async def aclose(self) -> None:
178
- for stream in list(self._streams):
179
- await stream.aclose()
180
- self._streams.clear()
181
- await self._pool.aclose()
182
- await super().aclose()
183
-
184
-
185
- class ChunkedStream(tts.ChunkedStream):
186
- def __init__(
187
- self,
188
- *,
189
- tts: TTS,
190
- base_url: str,
191
- api_key: str,
192
- input_text: str,
193
- opts: _TTSOptions,
194
- session: aiohttp.ClientSession,
195
- conn_options: APIConnectOptions,
196
- ) -> None:
197
- super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
198
- self._opts = opts
199
- self._session = session
200
- self._base_url = base_url
201
- self._api_key = api_key
202
-
203
- async def _run(self) -> None:
204
- request_id = utils.shortuuid()
205
- audio_bstream = utils.audio.AudioByteStream(
206
- sample_rate=self._opts.sample_rate,
207
- num_channels=NUM_CHANNELS,
208
- )
209
-
210
- try:
211
- config = {
212
- "encoding": self._opts.encoding,
213
- "model": self._opts.model,
214
- "sample_rate": self._opts.sample_rate,
215
- "mip_opt_out": self._opts.mip_opt_out,
216
- }
217
- async with self._session.post(
218
- _to_deepgram_url(config, self._base_url, websocket=False),
219
- headers={
220
- "Authorization": f"Token {self._api_key}",
221
- "Content-Type": "application/json",
222
- },
223
- json={"text": self._input_text},
224
- timeout=aiohttp.ClientTimeout(connect=self._conn_options.timeout, total=30),
225
- ) as res:
226
- if res.status != 200:
227
- raise APIStatusError(
228
- message=res.reason or "Unknown error occurred.",
229
- status_code=res.status,
230
- request_id=request_id,
231
- body=await res.json(),
232
- )
233
-
234
- async for bytes_data, _ in res.content.iter_chunks():
235
- for frame in audio_bstream.write(bytes_data):
236
- self._event_ch.send_nowait(
237
- tts.SynthesizedAudio(
238
- request_id=request_id,
239
- frame=frame,
240
- )
241
- )
242
-
243
- for frame in audio_bstream.flush():
244
- self._event_ch.send_nowait(
245
- tts.SynthesizedAudio(request_id=request_id, frame=frame)
246
- )
247
-
248
- except asyncio.TimeoutError as e:
249
- raise APITimeoutError() from e
250
- except aiohttp.ClientResponseError as e:
251
- raise APIStatusError(
252
- message=e.message,
253
- status_code=e.status,
254
- request_id=request_id,
255
- body=None,
256
- ) from e
257
- except Exception as e:
258
- raise APIConnectionError() from e
259
-
260
-
261
- class SynthesizeStream(tts.SynthesizeStream):
262
- def __init__(
263
- self,
264
- *,
265
- tts: TTS,
266
- base_url: str,
267
- api_key: str,
268
- opts: _TTSOptions,
269
- session: aiohttp.ClientSession,
270
- conn_options: APIConnectOptions,
271
- ):
272
- super().__init__(tts=tts, conn_options=conn_options)
273
- self._opts = opts
274
- self._session = session
275
- self._base_url = base_url
276
- self._api_key = api_key
277
- self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
278
- self._reconnect_event = asyncio.Event()
279
-
280
- def update_options(
281
- self,
282
- *,
283
- model: NotGivenOr[str] = NOT_GIVEN,
284
- sample_rate: NotGivenOr[int] = NOT_GIVEN,
285
- ) -> None:
286
- if is_given(model):
287
- self._opts.model = model
288
- if is_given(sample_rate):
289
- self._opts.sample_rate = sample_rate
290
-
291
- self._reconnect_event.set()
292
-
293
- async def _run(self) -> None:
294
- closing_ws = False
295
- request_id = utils.shortuuid()
296
- segment_id = utils.shortuuid()
297
- audio_bstream = utils.audio.AudioByteStream(
298
- sample_rate=self._opts.sample_rate,
299
- num_channels=NUM_CHANNELS,
300
- )
301
-
302
- @utils.log_exceptions(logger=logger)
303
- async def _tokenize_input():
304
- # Converts incoming text into WordStreams and sends them into _segments_ch
305
- word_stream = None
306
- async for input in self._input_ch:
307
- if isinstance(input, str):
308
- if word_stream is None:
309
- word_stream = self._opts.word_tokenizer.stream()
310
- self._segments_ch.send_nowait(word_stream)
311
- word_stream.push_text(input)
312
- elif isinstance(input, self._FlushSentinel):
313
- if word_stream:
314
- word_stream.end_input()
315
- word_stream = None
316
- self._segments_ch.close()
317
-
318
- @utils.log_exceptions(logger=logger)
319
- async def _run_segments(ws: aiohttp.ClientWebSocketResponse):
320
- nonlocal closing_ws
321
- async for word_stream in self._segments_ch:
322
- async for word in word_stream:
323
- speak_msg = {"type": "Speak", "text": f"{word.token} "}
324
- self._mark_started()
325
- await ws.send_str(json.dumps(speak_msg))
326
-
327
- # Always flush after a segment
328
- flush_msg = {"type": "Flush"}
329
- await ws.send_str(json.dumps(flush_msg))
330
-
331
- # after all segments, close
332
- close_msg = {"type": "Close"}
333
- closing_ws = True
334
- await ws.send_str(json.dumps(close_msg))
335
-
336
- async def recv_task(ws: aiohttp.ClientWebSocketResponse):
337
- emitter = tts.SynthesizedAudioEmitter(
338
- event_ch=self._event_ch,
339
- request_id=request_id,
340
- segment_id=segment_id,
341
- )
342
-
343
- while True:
344
- msg = await ws.receive()
345
- if msg.type in (
346
- aiohttp.WSMsgType.CLOSE,
347
- aiohttp.WSMsgType.CLOSED,
348
- aiohttp.WSMsgType.CLOSING,
349
- ):
350
- if not closing_ws:
351
- raise APIStatusError(
352
- "Deepgram websocket connection closed unexpectedly",
353
- request_id=request_id,
354
- )
355
- return
356
-
357
- if msg.type == aiohttp.WSMsgType.BINARY:
358
- data = msg.data
359
- for frame in audio_bstream.write(data):
360
- emitter.push(frame)
361
- elif msg.type == aiohttp.WSMsgType.TEXT:
362
- resp = json.loads(msg.data)
363
- mtype = resp.get("type")
364
- if mtype == "Flushed":
365
- for frame in audio_bstream.flush():
366
- emitter.push(frame)
367
- emitter.flush()
368
- break
369
- elif mtype == "Warning":
370
- logger.warning("Deepgram warning: %s", resp.get("warn_msg"))
371
- elif mtype == "Metadata":
372
- pass
373
- else:
374
- logger.debug("Unknown message type: %s", resp)
375
-
376
- async def _connection_timeout():
377
- # Deepgram has a 60-minute timeout period for websocket connections
378
- await asyncio.sleep(3300)
379
- logger.warning("Deepgram TTS maximum connection time reached. Reconnecting...")
380
- self._reconnect_event.set()
381
-
382
- ws: aiohttp.ClientWebSocketResponse | None = None
383
- while True:
384
- try:
385
- config = {
386
- "encoding": self._opts.encoding,
387
- "model": self._opts.model,
388
- "sample_rate": self._opts.sample_rate,
389
- "mip_opt_out": self._opts.mip_opt_out,
390
- }
391
- ws = await asyncio.wait_for(
392
- self._session.ws_connect(
393
- _to_deepgram_url(config, self._base_url, websocket=True),
394
- headers={"Authorization": f"Token {self._api_key}"},
395
- ),
396
- self._conn_options.timeout,
397
- )
398
- closing_ws = False
399
-
400
- tasks = [
401
- asyncio.create_task(_tokenize_input()),
402
- asyncio.create_task(_run_segments(ws)),
403
- asyncio.create_task(recv_task(ws)),
404
- ]
405
- wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
406
- connection_timeout_task = asyncio.create_task(_connection_timeout())
407
-
408
- try:
409
- done, _ = await asyncio.wait(
410
- [
411
- asyncio.gather(*tasks),
412
- wait_reconnect_task,
413
- connection_timeout_task,
414
- ],
415
- return_when=asyncio.FIRST_COMPLETED,
416
- ) # type: ignore
417
- if wait_reconnect_task not in done:
418
- break
419
- self._reconnect_event.clear()
420
- finally:
421
- await utils.aio.gracefully_cancel(
422
- *tasks, wait_reconnect_task, connection_timeout_task
423
- )
424
-
425
- except asyncio.TimeoutError as e:
426
- raise APITimeoutError() from e
427
- except aiohttp.ClientResponseError as e:
428
- raise APIStatusError(
429
- message=e.message,
430
- status_code=e.status,
431
- request_id=request_id,
432
- body=None,
433
- ) from e
434
- except Exception as e:
435
- raise APIConnectionError() from e
436
- finally:
437
- if ws is not None and not ws.closed:
438
- await ws.close()