livekit-plugins-google 0.7.2__tar.gz → 0.8.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/PKG-INFO +1 -1
  2. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/models.py +7 -1
  3. livekit_plugins_google-0.8.1/livekit/plugins/google/stt.py +505 -0
  4. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/tts.py +24 -9
  5. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/version.py +1 -1
  6. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/PKG-INFO +1 -1
  7. livekit_plugins_google-0.7.2/livekit/plugins/google/stt.py +0 -394
  8. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/README.md +0 -0
  9. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/__init__.py +0 -0
  10. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/log.py +0 -0
  11. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/py.typed +0 -0
  12. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/SOURCES.txt +0 -0
  13. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/dependency_links.txt +0 -0
  14. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/requires.txt +0 -0
  15. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/top_level.txt +0 -0
  16. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/pyproject.toml +0 -0
  17. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/setup.cfg +0 -0
  18. {livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/setup.py +0 -0
{livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: livekit-plugins-google
-Version: 0.7.2
+Version: 0.8.1
 Summary: Agent Framework plugin for services from Google Cloud
 Home-page: https://github.com/livekit/agents
 License: Apache-2.0
{livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/models.py
@@ -3,7 +3,13 @@ from typing import Literal
 # Speech to Text v2
 
 SpeechModels = Literal[
-    "long", "short", "telephony", "medical_dictation", "medical_conversation", "chirp"
+    "long",
+    "short",
+    "telephony",
+    "medical_dictation",
+    "medical_conversation",
+    "chirp",
+    "chirp_2",
 ]
 
 SpeechLanguages = Literal[
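
The only functional change in models.py is the new "chirp_2" entry; the rest of the hunk reflows the existing values onto separate lines. A minimal sketch of selecting the new model, assuming the plugin is imported as google from livekit.plugins (the import path is not shown in this diff):

from livekit.plugins import google

# "chirp_2" is now a valid value for the SpeechModels literal above
stt = google.STT(model="chirp_2", languages="en-US")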
livekit_plugins_google-0.8.1/livekit/plugins/google/stt.py
@@ -0,0 +1,505 @@
+# Copyright 2023 LiveKit, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import dataclasses
+import weakref
+from dataclasses import dataclass
+from typing import List, Union
+
+from livekit import rtc
+from livekit.agents import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    APIConnectionError,
+    APIConnectOptions,
+    APIStatusError,
+    APITimeoutError,
+    stt,
+    utils,
+)
+
+from google.api_core.client_options import ClientOptions
+from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+from google.auth import default as gauth_default
+from google.auth.exceptions import DefaultCredentialsError
+from google.cloud.speech_v2 import SpeechAsyncClient
+from google.cloud.speech_v2.types import cloud_speech
+
+from .log import logger
+from .models import SpeechLanguages, SpeechModels
+
+LgType = Union[SpeechLanguages, str]
+LanguageCode = Union[LgType, List[LgType]]
+
+
+# This class is only be used internally to encapsulate the options
+@dataclass
+class STTOptions:
+    languages: List[LgType]
+    detect_language: bool
+    interim_results: bool
+    punctuate: bool
+    spoken_punctuation: bool
+    model: SpeechModels
+    sample_rate: int
+    keywords: List[tuple[str, float]] | None
+
+    def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
+        if self.keywords:
+            return cloud_speech.SpeechAdaptation(
+                phrase_sets=[
+                    cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
+                        inline_phrase_set=cloud_speech.PhraseSet(
+                            phrases=[
+                                cloud_speech.PhraseSet.Phrase(
+                                    value=keyword, boost=boost
+                                )
+                                for keyword, boost in self.keywords
+                            ]
+                        )
+                    )
+                ]
+            )
+        return None
+
+
+class STT(stt.STT):
+    def __init__(
+        self,
+        *,
+        languages: LanguageCode = "en-US",  # Google STT can accept multiple languages
+        detect_language: bool = True,
+        interim_results: bool = True,
+        punctuate: bool = True,
+        spoken_punctuation: bool = True,
+        model: SpeechModels = "long",
+        location: str = "global",
+        sample_rate: int = 16000,
+        credentials_info: dict | None = None,
+        credentials_file: str | None = None,
+        keywords: List[tuple[str, float]] | None = None,
+    ):
+        """
+        Create a new instance of Google STT.
+
+        Credentials must be provided, either by using the ``credentials_info`` dict, or reading
+        from the file specified in ``credentials_file`` or via Application Default Credentials as
+        described in https://cloud.google.com/docs/authentication/application-default-credentials
+        """
+        super().__init__(
+            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
+        )
+
+        self._client: SpeechAsyncClient | None = None
+        self._location = location
+        self._credentials_info = credentials_info
+        self._credentials_file = credentials_file
+
+        if credentials_file is None and credentials_info is None:
+            try:
+                gauth_default()
+            except DefaultCredentialsError:
+                raise ValueError(
+                    "Application default credentials must be available "
+                    "when using Google STT without explicitly passing "
+                    "credentials through credentials_info or credentials_file."
+                )
+
+        if isinstance(languages, str):
+            languages = [languages]
+
+        self._config = STTOptions(
+            languages=languages,
+            detect_language=detect_language,
+            interim_results=interim_results,
+            punctuate=punctuate,
+            spoken_punctuation=spoken_punctuation,
+            model=model,
+            sample_rate=sample_rate,
+            keywords=keywords,
+        )
+        self._streams = weakref.WeakSet[SpeechStream]()
+
+    def _ensure_client(self) -> SpeechAsyncClient:
+        if self._credentials_info:
+            self._client = SpeechAsyncClient.from_service_account_info(
+                self._credentials_info
+            )
+        elif self._credentials_file:
+            self._client = SpeechAsyncClient.from_service_account_file(
+                self._credentials_file
+            )
+        elif self._location == "global":
+            self._client = SpeechAsyncClient()
+        else:
+            # Add support for passing a specific location that matches recognizer
+            # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+            self._client = SpeechAsyncClient(
+                client_options=ClientOptions(
+                    api_endpoint=f"{self._location}-speech.googleapis.com"
+                )
+            )
+        assert self._client is not None
+        return self._client
+
+    @property
+    def _recognizer(self) -> str:
+        # TODO(theomonnom): should we use recognizers?
+        # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
+
+        # TODO(theomonnom): find a better way to access the project_id
+        try:
+            project_id = self._ensure_client().transport._credentials.project_id  # type: ignore
+        except AttributeError:
+            from google.auth import default as ga_default
+
+            _, project_id = ga_default()
+        return f"projects/{project_id}/locations/{self._location}/recognizers/_"
+
+    def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
+        config = dataclasses.replace(self._config)
+
+        if language:
+            config.languages = [language]
+
+        if not isinstance(config.languages, list):
+            config.languages = [config.languages]
+        elif not config.detect_language:
+            if len(config.languages) > 1:
+                logger.warning(
+                    "multiple languages provided, but language detection is disabled"
+                )
+            config.languages = [config.languages[0]]
+
+        return config
+
+    async def _recognize_impl(
+        self,
+        buffer: utils.AudioBuffer,
+        *,
+        language: SpeechLanguages | str | None,
+        conn_options: APIConnectOptions,
+    ) -> stt.SpeechEvent:
+        config = self._sanitize_options(language=language)
+        frame = rtc.combine_audio_frames(buffer)
+
+        config = cloud_speech.RecognitionConfig(
+            explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                sample_rate_hertz=frame.sample_rate,
+                audio_channel_count=frame.num_channels,
+            ),
+            adaptation=config.build_adaptation(),
+            features=cloud_speech.RecognitionFeatures(
+                enable_automatic_punctuation=config.punctuate,
+                enable_spoken_punctuation=config.spoken_punctuation,
+                enable_word_time_offsets=True,
+            ),
+            model=config.model,
+            language_codes=config.languages,
+        )
+
+        try:
+            raw = await self._ensure_client().recognize(
+                cloud_speech.RecognizeRequest(
+                    recognizer=self._recognizer,
+                    config=config,
+                    content=frame.data.tobytes(),
+                ),
+                timeout=conn_options.timeout,
+            )
+
+            return _recognize_response_to_speech_event(raw)
+        except DeadlineExceeded:
+            raise APITimeoutError()
+        except GoogleAPICallError as e:
+            raise APIStatusError(
+                e.message,
+                status_code=e.code or -1,
+                request_id=None,
+                body=None,
+            )
+        except Exception as e:
+            raise APIConnectionError() from e
+
+    def stream(
+        self,
+        *,
+        language: SpeechLanguages | str | None = None,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> "SpeechStream":
+        config = self._sanitize_options(language=language)
+        stream = SpeechStream(
+            stt=self,
+            client=self._ensure_client(),
+            recognizer=self._recognizer,
+            config=config,
+            conn_options=conn_options,
+        )
+        self._streams.add(stream)
+        return stream
+
+    def update_options(
+        self,
+        *,
+        languages: LanguageCode | None = None,
+        detect_language: bool | None = None,
+        interim_results: bool | None = None,
+        punctuate: bool | None = None,
+        spoken_punctuation: bool | None = None,
+        model: SpeechModels | None = None,
+        location: str | None = None,
+        keywords: List[tuple[str, float]] | None = None,
+    ):
+        if languages is not None:
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = languages
+        if detect_language is not None:
+            self._config.detect_language = detect_language
+        if interim_results is not None:
+            self._config.interim_results = interim_results
+        if punctuate is not None:
+            self._config.punctuate = punctuate
+        if spoken_punctuation is not None:
+            self._config.spoken_punctuation = spoken_punctuation
+        if model is not None:
+            self._config.model = model
+        if keywords is not None:
+            self._config.keywords = keywords
+
+        for stream in self._streams:
+            stream.update_options(
+                languages=languages,
+                detect_language=detect_language,
+                interim_results=interim_results,
+                punctuate=punctuate,
+                spoken_punctuation=spoken_punctuation,
+                model=model,
+                location=location,
+                keywords=keywords,
+            )
+
+
+class SpeechStream(stt.SpeechStream):
+    def __init__(
+        self,
+        *,
+        stt: STT,
+        conn_options: APIConnectOptions,
+        client: SpeechAsyncClient,
+        recognizer: str,
+        config: STTOptions,
+    ) -> None:
+        super().__init__(
+            stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
+        )
+
+        self._client = client
+        self._recognizer = recognizer
+        self._config = config
+        self._reconnect_event = asyncio.Event()
+
+    def update_options(
+        self,
+        *,
+        languages: LanguageCode | None = None,
+        detect_language: bool | None = None,
+        interim_results: bool | None = None,
+        punctuate: bool | None = None,
+        spoken_punctuation: bool | None = None,
+        model: SpeechModels | None = None,
+        location: str | None = None,
+        keywords: List[tuple[str, float]] | None = None,
+    ):
+        if languages is not None:
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = languages
+        if detect_language is not None:
+            self._config.detect_language = detect_language
+        if interim_results is not None:
+            self._config.interim_results = interim_results
+        if punctuate is not None:
+            self._config.punctuate = punctuate
+        if spoken_punctuation is not None:
+            self._config.spoken_punctuation = spoken_punctuation
+        if model is not None:
+            self._config.model = model
+        if keywords is not None:
+            self._config.keywords = keywords
+
+        self._reconnect_event.set()
+
+    async def _run(self) -> None:
+        # google requires a async generator when calling streaming_recognize
+        # this function basically convert the queue into a async generator
+        async def input_generator():
+            try:
+                # first request should contain the config
+                yield cloud_speech.StreamingRecognizeRequest(
+                    recognizer=self._recognizer,
+                    streaming_config=self._streaming_config,
+                )
+
+                async for frame in self._input_ch:
+                    if isinstance(frame, rtc.AudioFrame):
+                        yield cloud_speech.StreamingRecognizeRequest(
+                            audio=frame.data.tobytes()
+                        )
+
+            except Exception:
+                logger.exception(
+                    "an error occurred while streaming input to google STT"
+                )
+
+        async def process_stream(stream):
+            async for resp in stream:
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                    )
+
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
+                ):
+                    result = resp.results[0]
+                    speech_data = _streaming_recognize_response_to_speech_data(resp)
+                    if speech_data is None:
+                        continue
+
+                    if not result.is_final:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                                alternatives=[speech_data],
+                            )
+                        )
+                    else:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                                alternatives=[speech_data],
+                            )
+                        )
+
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
+                    )
+
+        while True:
+            try:
+                self._streaming_config = cloud_speech.StreamingRecognitionConfig(
+                    config=cloud_speech.RecognitionConfig(
+                        explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                            encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                            sample_rate_hertz=self._config.sample_rate,
+                            audio_channel_count=1,
+                        ),
+                        adaptation=self._config.build_adaptation(),
+                        language_codes=self._config.languages,
+                        model=self._config.model,
+                        features=cloud_speech.RecognitionFeatures(
+                            enable_automatic_punctuation=self._config.punctuate,
+                            enable_word_time_offsets=True,
+                        ),
+                    ),
+                    streaming_features=cloud_speech.StreamingRecognitionFeatures(
+                        enable_voice_activity_events=True,
+                        interim_results=self._config.interim_results,
+                    ),
+                )
+
+                stream = await self._client.streaming_recognize(
+                    requests=input_generator(),
+                )
+
+                process_stream_task = asyncio.create_task(process_stream(stream))
+                wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+                try:
+                    await asyncio.wait(
+                        [process_stream_task, wait_reconnect_task],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                finally:
+                    await utils.aio.gracefully_cancel(
+                        process_stream_task, wait_reconnect_task
+                    )
+            finally:
+                if not self._reconnect_event.is_set():
+                    break
+                self._reconnect_event.clear()
+
+
+def _recognize_response_to_speech_event(
+    resp: cloud_speech.RecognizeResponse,
+) -> stt.SpeechEvent:
+    text = ""
+    confidence = 0.0
+    for result in resp.results:
+        text += result.alternatives[0].transcript
+        confidence += result.alternatives[0].confidence
+
+    # not sure why start_offset and end_offset returns a timedelta
+    start_offset = resp.results[0].alternatives[0].words[0].start_offset
+    end_offset = resp.results[-1].alternatives[0].words[-1].end_offset
+
+    confidence /= len(resp.results)
+    lg = resp.results[0].language_code
+    return stt.SpeechEvent(
+        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+        alternatives=[
+            stt.SpeechData(
+                language=lg,
+                start_time=start_offset.total_seconds(),  # type: ignore
+                end_time=end_offset.total_seconds(),  # type: ignore
+                confidence=confidence,
+                text=text,
+            )
+        ],
+    )
+
+
+def _streaming_recognize_response_to_speech_data(
+    resp: cloud_speech.StreamingRecognizeResponse,
+) -> stt.SpeechData | None:
+    text = ""
+    confidence = 0.0
+    for result in resp.results:
+        if len(result.alternatives) == 0:
+            continue
+        text += result.alternatives[0].transcript
+        confidence += result.alternatives[0].confidence
+
+    confidence /= len(resp.results)
+    lg = resp.results[0].language_code
+
+    if text == "":
+        return None
+
+    data = stt.SpeechData(
+        language=lg, start_time=0, end_time=0, confidence=confidence, text=text
+    )
+
+    return data
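
The rewritten stt.py carries most of the new API surface: a keywords argument for boosting (each (phrase, boost) pair becomes a cloud_speech.PhraseSet.Phrase via build_adaptation), a location argument that routes non-global regions to "{location}-speech.googleapis.com", a configurable sample_rate, and update_options, which rewrites the config and sets the reconnect event so open streams rebuild their StreamingRecognitionConfig. A usage sketch based only on the constructor and methods above (the import path is assumed, not shown in this diff):

from livekit.plugins import google

# keyword boosting plus a regional endpoint, per the 0.8.1 constructor
stt = google.STT(
    model="long",
    location="us-central1",  # non-global: us-central1-speech.googleapis.com
    sample_rate=16000,
    keywords=[("LiveKit", 15.0), ("WebRTC", 10.0)],
)

stream = stt.stream()

# switching models later reconfigures every open stream via the
# reconnect event set in SpeechStream.update_options
stt.update_options(model="chirp_2")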
{livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/tts.py
@@ -18,7 +18,9 @@ from dataclasses import dataclass
 
 from livekit import rtc
 from livekit.agents import (
+    DEFAULT_API_CONNECT_OPTIONS,
     APIConnectionError,
+    APIConnectOptions,
     APIStatusError,
     APITimeoutError,
     tts,
@@ -134,7 +136,7 @@ class TTS(tts.TTS):
         self._opts.audio_config.speaking_rate = speaking_rate
 
     def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
-        if not self._client:
+        if self._client is None:
             if self._credentials_info:
                 self._client = (
                     texttospeech.TextToSpeechAsyncClient.from_service_account_info(
@@ -154,22 +156,35 @@ class TTS(tts.TTS):
         assert self._client is not None
         return self._client
 
-    def synthesize(self, text: str) -> "ChunkedStream":
-        return ChunkedStream(self, text, self._opts, self._ensure_client())
+    def synthesize(
+        self,
+        text: str,
+        *,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> "ChunkedStream":
+        return ChunkedStream(
+            tts=self,
+            input_text=text,
+            conn_options=conn_options,
+            opts=self._opts,
+            client=self._ensure_client(),
+        )
 
 
 class ChunkedStream(tts.ChunkedStream):
     def __init__(
         self,
+        *,
         tts: TTS,
-        text: str,
+        input_text: str,
+        conn_options: APIConnectOptions,
         opts: _TTSOptions,
         client: texttospeech.TextToSpeechAsyncClient,
     ) -> None:
-        super().__init__(tts, text)
+        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._opts, self._client = opts, client
 
-    async def _main_task(self) -> None:
+    async def _run(self) -> None:
         request_id = utils.shortuuid()
 
         try:
@@ -177,16 +192,16 @@ class ChunkedStream(tts.ChunkedStream):
                 input=texttospeech.SynthesisInput(text=self._input_text),
                 voice=self._opts.voice,
                 audio_config=self._opts.audio_config,
+                timeout=self._conn_options.timeout,
             )
 
-            data = response.audio_content
             if self._opts.audio_config.audio_encoding == "mp3":
                 decoder = utils.codecs.Mp3StreamDecoder()
                 bstream = utils.audio.AudioByteStream(
                     sample_rate=self._opts.audio_config.sample_rate_hertz,
                     num_channels=1,
                 )
-                for frame in decoder.decode_chunk(data):
+                for frame in decoder.decode_chunk(response.audio_content):
                     for frame in bstream.write(frame.data.tobytes()):
                         self._event_ch.send_nowait(
                             tts.SynthesizedAudio(request_id=request_id, frame=frame)
@@ -197,7 +212,7 @@ class ChunkedStream(tts.ChunkedStream):
                     tts.SynthesizedAudio(request_id=request_id, frame=frame)
                 )
             else:
-                data = data[44:]  # skip WAV header
+                data = response.audio_content[44:]  # skip WAV header
                 self._event_ch.send_nowait(
                     tts.SynthesizedAudio(
                         request_id=request_id,
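
In tts.py, synthesize gains a keyword-only conn_options argument (its timeout is forwarded to synthesize_speech), ChunkedStream switches to keyword-only construction with input_text, and _main_task is renamed _run. A sketch of the updated call site; APIConnectOptions is assumed to expose max_retry and timeout as in livekit-agents, which this diff does not show:

from livekit.agents import APIConnectOptions
from livekit.plugins import google

tts = google.TTS()

# conn_options is new in 0.8.x; its timeout is passed through to the
# synthesize_speech request in the hunk above
stream = tts.synthesize(
    "Hello from LiveKit",
    conn_options=APIConnectOptions(max_retry=3, timeout=30.0),
)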
{livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit/plugins/google/version.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.7.2"
+__version__ = "0.8.1"
{livekit_plugins_google-0.7.2 → livekit_plugins_google-0.8.1}/livekit_plugins_google.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: livekit-plugins-google
-Version: 0.7.2
+Version: 0.8.1
 Summary: Agent Framework plugin for services from Google Cloud
 Home-page: https://github.com/livekit/agents
 License: Apache-2.0
livekit_plugins_google-0.7.2/livekit/plugins/google/stt.py
@@ -1,394 +0,0 @@
-# Copyright 2023 LiveKit, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import asyncio
-import dataclasses
-from dataclasses import dataclass
-from typing import AsyncIterable, List, Union
-
-from livekit import agents, rtc
-from livekit.agents import (
-    APIConnectionError,
-    APIStatusError,
-    APITimeoutError,
-    stt,
-    utils,
-)
-
-from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
-from google.auth import default as gauth_default
-from google.auth.exceptions import DefaultCredentialsError
-from google.cloud.speech_v2 import SpeechAsyncClient
-from google.cloud.speech_v2.types import cloud_speech
-
-from .log import logger
-from .models import SpeechLanguages, SpeechModels
-
-LgType = Union[SpeechLanguages, str]
-LanguageCode = Union[LgType, List[LgType]]
-
-
-# This class is only be used internally to encapsulate the options
-@dataclass
-class STTOptions:
-    languages: List[LgType]
-    detect_language: bool
-    interim_results: bool
-    punctuate: bool
-    spoken_punctuation: bool
-    model: SpeechModels
-
-
-class STT(stt.STT):
-    def __init__(
-        self,
-        *,
-        languages: LanguageCode = "en-US",  # Google STT can accept multiple languages
-        detect_language: bool = True,
-        interim_results: bool = True,
-        punctuate: bool = True,
-        spoken_punctuation: bool = True,
-        model: SpeechModels = "long",
-        credentials_info: dict | None = None,
-        credentials_file: str | None = None,
-    ):
-        """
-        Create a new instance of Google STT.
-
-        Credentials must be provided, either by using the ``credentials_info`` dict, or reading
-        from the file specified in ``credentials_file`` or via Application Default Credentials as
-        described in https://cloud.google.com/docs/authentication/application-default-credentials
-        """
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
-        )
-
-        self._client: SpeechAsyncClient | None = None
-        self._credentials_info = credentials_info
-        self._credentials_file = credentials_file
-
-        if credentials_file is None and credentials_info is None:
-            try:
-                gauth_default()
-            except DefaultCredentialsError:
-                raise ValueError(
-                    "Application default credentials must be available "
-                    "when using Google STT without explicitly passing "
-                    "credentials through credentials_info or credentials_file."
-                )
-
-        if isinstance(languages, str):
-            languages = [languages]
-
-        self._config = STTOptions(
-            languages=languages,
-            detect_language=detect_language,
-            interim_results=interim_results,
-            punctuate=punctuate,
-            spoken_punctuation=spoken_punctuation,
-            model=model,
-        )
-
-    def _ensure_client(self) -> SpeechAsyncClient:
-        if self._credentials_info:
-            self._client = SpeechAsyncClient.from_service_account_info(
-                self._credentials_info
-            )
-        elif self._credentials_file:
-            self._client = SpeechAsyncClient.from_service_account_file(
-                self._credentials_file
-            )
-        else:
-            self._client = SpeechAsyncClient()
-
-        assert self._client is not None
-        return self._client
-
-    @property
-    def _recognizer(self) -> str:
-        # TODO(theomonnom): should we use recognizers?
-        # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
-
-        # TODO(theomonnom): find a better way to access the project_id
-        try:
-            project_id = self._ensure_client().transport._credentials.project_id  # type: ignore
-        except AttributeError:
-            from google.auth import default as ga_default
-
-            _, project_id = ga_default()
-        return f"projects/{project_id}/locations/global/recognizers/_"
-
-    def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
-        config = dataclasses.replace(self._config)
-
-        if language:
-            config.languages = [language]
-
-        if not isinstance(config.languages, list):
-            config.languages = [config.languages]
-        elif not config.detect_language:
-            if len(config.languages) > 1:
-                logger.warning(
-                    "multiple languages provided, but language detection is disabled"
-                )
-            config.languages = [config.languages[0]]
-
-        return config
-
-    async def _recognize_impl(
-        self,
-        buffer: utils.AudioBuffer,
-        *,
-        language: SpeechLanguages | str | None = None,
-    ) -> stt.SpeechEvent:
-        config = self._sanitize_options(language=language)
-        frame = agents.utils.merge_frames(buffer)
-
-        config = cloud_speech.RecognitionConfig(
-            explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
-                encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                sample_rate_hertz=frame.sample_rate,
-                audio_channel_count=frame.num_channels,
-            ),
-            features=cloud_speech.RecognitionFeatures(
-                enable_automatic_punctuation=config.punctuate,
-                enable_spoken_punctuation=config.spoken_punctuation,
-                enable_word_time_offsets=True,
-            ),
-            model=config.model,
-            language_codes=config.languages,
-        )
-
-        try:
-            raw = await self._ensure_client().recognize(
-                cloud_speech.RecognizeRequest(
-                    recognizer=self._recognizer,
-                    config=config,
-                    content=frame.data.tobytes(),
-                )
-            )
-
-            return _recognize_response_to_speech_event(raw)
-        except DeadlineExceeded:
-            raise APITimeoutError()
-        except GoogleAPICallError as e:
-            raise APIStatusError(
-                e.message,
-                status_code=e.code or -1,
-                request_id=None,
-                body=None,
-            )
-        except Exception as e:
-            raise APIConnectionError() from e
-
-    def stream(
-        self, *, language: SpeechLanguages | str | None = None
-    ) -> "SpeechStream":
-        config = self._sanitize_options(language=language)
-        return SpeechStream(self, self._ensure_client(), self._recognizer, config)
-
-
-class SpeechStream(stt.SpeechStream):
-    def __init__(
-        self,
-        stt: STT,
-        client: SpeechAsyncClient,
-        recognizer: str,
-        config: STTOptions,
-        sample_rate: int = 48000,
-        num_channels: int = 1,
-        max_retry: int = 32,
-    ) -> None:
-        super().__init__(stt)
-
-        self._client = client
-        self._recognizer = recognizer
-        self._config = config
-        self._sample_rate = sample_rate
-        self._num_channels = num_channels
-        self._max_retry = max_retry
-
-        self._streaming_config = cloud_speech.StreamingRecognitionConfig(
-            config=cloud_speech.RecognitionConfig(
-                explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
-                    encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                    sample_rate_hertz=self._sample_rate,
-                    audio_channel_count=self._num_channels,
-                ),
-                language_codes=self._config.languages,
-                model=self._config.model,
-                features=cloud_speech.RecognitionFeatures(
-                    enable_automatic_punctuation=self._config.punctuate,
-                    enable_word_time_offsets=True,
-                ),
-            ),
-            streaming_features=cloud_speech.StreamingRecognitionFeatures(
-                enable_voice_activity_events=True,
-                interim_results=self._config.interim_results,
-            ),
-        )
-
-    @utils.log_exceptions(logger=logger)
-    async def _main_task(self) -> None:
-        await self._run(self._max_retry)
-
-    async def _run(self, max_retry: int) -> None:
-        retry_count = 0
-        while self._input_ch.qsize() or not self._input_ch.closed:
-            try:
-                # google requires a async generator when calling streaming_recognize
-                # this function basically convert the queue into a async generator
-                async def input_generator():
-                    try:
-                        # first request should contain the config
-                        yield cloud_speech.StreamingRecognizeRequest(
-                            recognizer=self._recognizer,
-                            streaming_config=self._streaming_config,
-                        )
-
-                        async for frame in self._input_ch:
-                            if isinstance(frame, rtc.AudioFrame):
-                                frame = frame.remix_and_resample(
-                                    self._sample_rate, self._num_channels
-                                )
-                                yield cloud_speech.StreamingRecognizeRequest(
-                                    audio=frame.data.tobytes()
-                                )
-
-                    except Exception:
-                        logger.exception(
-                            "an error occurred while streaming input to google STT"
-                        )
-
-                # try to connect
-                stream = await self._client.streaming_recognize(
-                    requests=input_generator()
-                )
-                retry_count = 0  # connection successful, reset retry count
-
-                await self._run_stream(stream)
-            except Exception as e:
-                if retry_count >= max_retry:
-                    logger.error(
-                        f"failed to connect to google stt after {max_retry} tries",
-                        exc_info=e,
-                    )
-                    break
-
-                retry_delay = min(retry_count * 2, 5)  # max 5s
-                retry_count += 1
-                logger.warning(
-                    f"google stt connection failed, retrying in {retry_delay}s",
-                    exc_info=e,
-                )
-                await asyncio.sleep(retry_delay)
-
-    async def _run_stream(
-        self, stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse]
-    ):
-        async for resp in stream:
-            if (
-                resp.speech_event_type
-                == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
-            ):
-                self._event_ch.send_nowait(
-                    stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
-                )
-
-            if (
-                resp.speech_event_type
-                == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
-            ):
-                result = resp.results[0]
-                speech_data = _streaming_recognize_response_to_speech_data(resp)
-                if speech_data is None:
-                    continue
-
-                if not result.is_final:
-                    self._event_ch.send_nowait(
-                        stt.SpeechEvent(
-                            type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                            alternatives=[speech_data],
-                        )
-                    )
-                else:
-                    self._event_ch.send_nowait(
-                        stt.SpeechEvent(
-                            type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                            alternatives=[speech_data],
-                        )
-                    )
-
-            if (
-                resp.speech_event_type
-                == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
-            ):
-                self._event_ch.send_nowait(
-                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
-                )
-
-
-def _recognize_response_to_speech_event(
-    resp: cloud_speech.RecognizeResponse,
-) -> stt.SpeechEvent:
-    text = ""
-    confidence = 0.0
-    for result in resp.results:
-        text += result.alternatives[0].transcript
-        confidence += result.alternatives[0].confidence
-
-    # not sure why start_offset and end_offset returns a timedelta
-    start_offset = resp.results[0].alternatives[0].words[0].start_offset
-    end_offset = resp.results[-1].alternatives[0].words[-1].end_offset
-
-    confidence /= len(resp.results)
-    lg = resp.results[0].language_code
-    return stt.SpeechEvent(
-        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-        alternatives=[
-            stt.SpeechData(
-                language=lg,
-                start_time=start_offset.total_seconds(),  # type: ignore
-                end_time=end_offset.total_seconds(),  # type: ignore
-                confidence=confidence,
-                text=text,
-            )
-        ],
-    )
-
-
-def _streaming_recognize_response_to_speech_data(
-    resp: cloud_speech.StreamingRecognizeResponse,
-) -> stt.SpeechData | None:
-    text = ""
-    confidence = 0.0
-    for result in resp.results:
-        if len(result.alternatives) == 0:
-            continue
-        text += result.alternatives[0].transcript
-        confidence += result.alternatives[0].confidence
-
-    confidence /= len(resp.results)
-    lg = resp.results[0].language_code
-
-    if text == "":
-        return None
-
-    data = stt.SpeechData(
-        language=lg, start_time=0, end_time=0, confidence=confidence, text=text
-    )
-
-    return data