livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.8__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +25 -7
- livekit/plugins/google/beta/__init__.py +13 -0
- livekit/plugins/google/beta/gemini_tts.py +258 -0
- livekit/plugins/google/llm.py +501 -0
- livekit/plugins/google/log.py +3 -0
- livekit/plugins/google/models.py +145 -31
- livekit/plugins/google/realtime/__init__.py +9 -0
- livekit/plugins/google/realtime/api_proto.py +66 -0
- livekit/plugins/google/realtime/realtime_api.py +1252 -0
- livekit/plugins/google/stt.py +518 -272
- livekit/plugins/google/tools.py +11 -0
- livekit/plugins/google/tts.py +447 -0
- livekit/plugins/google/utils.py +286 -0
- livekit/plugins/google/version.py +1 -1
- livekit_plugins_google-1.3.8.dist-info/METADATA +63 -0
- livekit_plugins_google-1.3.8.dist-info/RECORD +18 -0
- {livekit_plugins_google-0.3.0.dist-info → livekit_plugins_google-1.3.8.dist-info}/WHEEL +1 -2
- livekit_plugins_google-0.3.0.dist-info/METADATA +0 -47
- livekit_plugins_google-0.3.0.dist-info/RECORD +0 -9
- livekit_plugins_google-0.3.0.dist-info/top_level.txt +0 -1
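
Before the per-file diff: the headline change is a full rewrite of stt.py, which grows from a single-shot recognizer into a pooled, self-reconnecting client with keyword boosting and runtime option updates. A minimal sketch of constructing the 1.3.8 STT, using only argument names and defaults visible in the diff below (the surrounding agent wiring is omitted, and the import path is assumed from the package name):

    from livekit.plugins import google

    # Sketch only - names and defaults are taken from the diffed __init__ below.
    stt = google.STT(
        languages=["en-US"],
        model="latest_long",            # default per the diff
        sample_rate=16000,              # default per the diff
        min_confidence_threshold=0.65,  # _default_min_confidence in the diff
        keywords=[("LiveKit", 15.0)],   # (phrase, boost) pairs, see build_adaptation()
    )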
livekit/plugins/google/stt.py
CHANGED
@@ -15,35 +15,82 @@
 from __future__ import annotations
 
 import asyncio
-import contextlib
 import dataclasses
-import logging
+import time
+import weakref
+from collections.abc import AsyncGenerator, AsyncIterable
 from dataclasses import dataclass
-from
+from datetime import timedelta
+from typing import Callable, Union, cast
 
-from
-from
-from
-
-from google.auth import credentials  # type: ignore
+from google.api_core.client_options import ClientOptions
+from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+from google.auth import default as gauth_default
+from google.auth.exceptions import DefaultCredentialsError
 from google.cloud.speech_v2 import SpeechAsyncClient
 from google.cloud.speech_v2.types import cloud_speech
-
+from google.protobuf.duration_pb2 import Duration
+from livekit import rtc
+from livekit.agents import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    APIConnectionError,
+    APIConnectOptions,
+    APIStatusError,
+    APITimeoutError,
+    stt,
+    utils,
+)
+from livekit.agents.types import (
+    NOT_GIVEN,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
+
+from .log import logger
 from .models import SpeechLanguages, SpeechModels
 
-LgType = SpeechLanguages
-LanguageCode = LgType
+LgType = Union[SpeechLanguages, str]
+LanguageCode = Union[LgType, list[LgType]]
+
+# Google STT has a timeout of 5 mins, we'll attempt to restart the session
+# before that timeout is reached
+_max_session_duration = 240
+
+# Google is very sensitive to background noise, so we'll ignore results with low confidence
+_default_min_confidence = 0.65
 
 
 # This class is only be used internally to encapsulate the options
 @dataclass
 class STTOptions:
-    languages:
+    languages: list[LgType]
     detect_language: bool
     interim_results: bool
     punctuate: bool
     spoken_punctuation: bool
-
+    enable_word_time_offsets: bool
+    enable_word_confidence: bool
+    enable_voice_activity_events: bool
+    model: SpeechModels | str
+    sample_rate: int
+    min_confidence_threshold: float
+    keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
+
+    def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
+        if is_given(self.keywords):
+            return cloud_speech.SpeechAdaptation(
+                phrase_sets=[
+                    cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
+                        inline_phrase_set=cloud_speech.PhraseSet(
+                            phrases=[
+                                cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
+                                for keyword, boost in self.keywords
+                            ]
+                        )
+                    )
+                ]
+            )
+        return None
 
 
 class STT(stt.STT):
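
Note on the STTOptions.build_adaptation() helper added above: each (keyword, boost) pair becomes one phrase in a single inline PhraseSet. For keywords=[("LiveKit", 15.0)], the helper is equivalent to building the adaptation by hand with the same cloud_speech types (a sketch mirroring the diffed code, not a separate API):

    from google.cloud.speech_v2.types import cloud_speech

    # What build_adaptation() returns for keywords=[("LiveKit", 15.0)].
    adaptation = cloud_speech.SpeechAdaptation(
        phrase_sets=[
            cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
                inline_phrase_set=cloud_speech.PhraseSet(
                    phrases=[cloud_speech.PhraseSet.Phrase(value="LiveKit", boost=15.0)]
                )
            )
        ]
    )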
@@ -54,23 +101,64 @@ class STT(stt.STT):
         detect_language: bool = True,
         interim_results: bool = True,
         punctuate: bool = True,
-        spoken_punctuation: bool =
-
-
-
+        spoken_punctuation: bool = False,
+        enable_word_time_offsets: bool = True,
+        enable_word_confidence: bool = False,
+        enable_voice_activity_events: bool = False,
+        model: SpeechModels | str = "latest_long",
+        location: str = "global",
+        sample_rate: int = 16000,
+        min_confidence_threshold: float = _default_min_confidence,
+        credentials_info: NotGivenOr[dict] = NOT_GIVEN,
+        credentials_file: NotGivenOr[str] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+        use_streaming: NotGivenOr[bool] = NOT_GIVEN,
     ):
         """
-
-
+        Create a new instance of Google STT.
+
+        Credentials must be provided, either by using the ``credentials_info`` dict, or reading
+        from the file specified in ``credentials_file`` or via Application Default Credentials as
+        described in https://cloud.google.com/docs/authentication/application-default-credentials
+
+        args:
+            languages(LanguageCode): list of language codes to recognize (default: "en-US")
+            detect_language(bool): whether to detect the language of the audio (default: True)
+            interim_results(bool): whether to return interim results (default: True)
+            punctuate(bool): whether to punctuate the audio (default: True)
+            spoken_punctuation(bool): whether to use spoken punctuation (default: False)
+            enable_word_time_offsets(bool): whether to enable word time offsets (default: True)
+            enable_word_confidence(bool): whether to enable word confidence (default: False)
+            enable_voice_activity_events(bool): whether to enable voice activity events (default: False)
+            model(SpeechModels): the model to use for recognition default: "latest_long"
+            location(str): the location to use for recognition default: "global"
+            sample_rate(int): the sample rate of the audio default: 16000
+            min_confidence_threshold(float): minimum confidence threshold for recognition
+                (default: 0.65)
+            credentials_info(dict): the credentials info to use for recognition (default: None)
+            credentials_file(str): the credentials file to use for recognition (default: None)
+            keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
+            use_streaming(bool): whether to use streaming for recognition (default: True)
         """
-
+        if not is_given(use_streaming):
+            use_streaming = True
+        super().__init__(
+            capabilities=stt.STTCapabilities(streaming=use_streaming, interim_results=True)
+        )
 
-
-
-
-
-
-
+        self._location = location
+        self._credentials_info = credentials_info
+        self._credentials_file = credentials_file
+
+        if not is_given(credentials_file) and not is_given(credentials_info):
+            try:
+                gauth_default()  # type: ignore
+            except DefaultCredentialsError:
+                raise ValueError(
+                    "Application default credentials must be available "
+                    "when using Google STT without explicitly passing "
+                    "credentials through credentials_info or credentials_file."
+                ) from None
 
         if isinstance(languages, str):
             languages = [languages]
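
Note on the credential handling added above: explicit credentials_info takes precedence, then credentials_file, then Application Default Credentials, and the constructor raises ValueError up front if none of the three is available. A hedged sketch of the three paths (the key-file path is illustrative):

    import json

    from livekit.plugins import google

    # 1. Inline service-account info (path is illustrative):
    with open("/path/to/key.json") as f:
        stt_a = google.STT(credentials_info=json.load(f))

    # 2. Service-account key file:
    stt_b = google.STT(credentials_file="/path/to/key.json")

    # 3. Application Default Credentials; raises ValueError if none are found:
    stt_c = google.STT()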
@@ -81,322 +169,480 @@ class STT(stt.STT):
             interim_results=interim_results,
             punctuate=punctuate,
             spoken_punctuation=spoken_punctuation,
+            enable_word_time_offsets=enable_word_time_offsets,
+            enable_word_confidence=enable_word_confidence,
+            enable_voice_activity_events=enable_voice_activity_events,
             model=model,
+            sample_rate=sample_rate,
+            min_confidence_threshold=min_confidence_threshold,
+            keywords=keywords,
         )
-        self.
+        self._streams = weakref.WeakSet[SpeechStream]()
+        self._pool = utils.ConnectionPool[SpeechAsyncClient](
+            max_session_duration=_max_session_duration,
+            connect_cb=self._create_client,
+        )
+
+    @property
+    def model(self) -> str:
+        return self._config.model
 
     @property
-    def
+    def provider(self) -> str:
+        return "Google Cloud Platform"
+
+    async def _create_client(self, timeout: float) -> SpeechAsyncClient:
+        # Add support for passing a specific location that matches recognizer
+        # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+        # TODO(long): how to set timeout?
+        client_options = None
+        client: SpeechAsyncClient | None = None
+        if self._location != "global":
+            client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
+        if is_given(self._credentials_info):
+            client = SpeechAsyncClient.from_service_account_info(
+                self._credentials_info, client_options=client_options
+            )
+        elif is_given(self._credentials_file):
+            client = SpeechAsyncClient.from_service_account_file(
+                self._credentials_file, client_options=client_options
+            )
+        else:
+            client = SpeechAsyncClient(client_options=client_options)
+        assert client is not None
+        return client
+
+    def _get_recognizer(self, client: SpeechAsyncClient) -> str:
         # TODO(theomonnom): should we use recognizers?
-        #
-        return f"projects/{self._creds.project_id}/locations/global/recognizers/_"  # type: ignore
+        # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
 
-
-
-
-
-
+        # TODO(theomonnom): find a better way to access the project_id
+        try:
+            project_id = client.transport._credentials.project_id  # type: ignore
+        except AttributeError:
+            from google.auth import default as ga_default
+
+            _, project_id = ga_default()  # type: ignore
+        return f"projects/{project_id}/locations/{self._location}/recognizers/_"
+
+    def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
         config = dataclasses.replace(self._config)
 
-        if language:
+        if is_given(language):
             config.languages = [language]
 
         if not isinstance(config.languages, list):
             config.languages = [config.languages]
         elif not config.detect_language:
             if len(config.languages) > 1:
-
-                    "multiple languages provided, but language detection is disabled"
-                )
+                logger.warning("multiple languages provided, but language detection is disabled")
                 config.languages = [config.languages[0]]
 
         return config
 
-    async def
+    async def _recognize_impl(
         self,
+        buffer: utils.AudioBuffer,
         *,
-
-
+        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+        conn_options: APIConnectOptions,
     ) -> stt.SpeechEvent:
         config = self._sanitize_options(language=language)
-
+        frame = rtc.combine_audio_frames(buffer)
 
         config = cloud_speech.RecognitionConfig(
             explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
                 encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                sample_rate_hertz=
-                audio_channel_count=
+                sample_rate_hertz=frame.sample_rate,
+                audio_channel_count=frame.num_channels,
             ),
+            adaptation=config.build_adaptation(),
             features=cloud_speech.RecognitionFeatures(
                 enable_automatic_punctuation=config.punctuate,
                 enable_spoken_punctuation=config.spoken_punctuation,
+                enable_word_time_offsets=config.enable_word_time_offsets,
+                enable_word_confidence=config.enable_word_confidence,
             ),
             model=config.model,
             language_codes=config.languages,
         )
 
-
-
-
-
-
-
+        try:
+            async with self._pool.connection(timeout=conn_options.timeout) as client:
+                raw = await client.recognize(
+                    cloud_speech.RecognizeRequest(
+                        recognizer=self._get_recognizer(client),
+                        config=config,
+                        content=frame.data.tobytes(),
+                    ),
+                    timeout=conn_options.timeout,
                 )
-
-
+
+            return _recognize_response_to_speech_event(raw)
+        except DeadlineExceeded:
+            raise APITimeoutError() from None
+        except GoogleAPICallError as e:
+            raise APIStatusError(f"{e.message} {e.details}", status_code=e.code or -1) from e
+        except Exception as e:
+            raise APIConnectionError() from e
 
     def stream(
         self,
         *,
-        language: SpeechLanguages | str
-
+        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> SpeechStream:
         config = self._sanitize_options(language=language)
-
-        self
-        self.
-        self.
-            config,
+        stream = SpeechStream(
+            stt=self,
+            pool=self._pool,
+            recognizer_cb=self._get_recognizer,
+            config=config,
+            conn_options=conn_options,
         )
+        self._streams.add(stream)
+        return stream
+
+    def update_options(
+        self,
+        *,
+        languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+        detect_language: NotGivenOr[bool] = NOT_GIVEN,
+        interim_results: NotGivenOr[bool] = NOT_GIVEN,
+        punctuate: NotGivenOr[bool] = NOT_GIVEN,
+        spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+        model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+    ) -> None:
+        if is_given(languages):
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = cast(list[LgType], languages)
+        if is_given(detect_language):
+            self._config.detect_language = detect_language
+        if is_given(interim_results):
+            self._config.interim_results = interim_results
+        if is_given(punctuate):
+            self._config.punctuate = punctuate
+        if is_given(spoken_punctuation):
+            self._config.spoken_punctuation = spoken_punctuation
+        if is_given(model):
+            self._config.model = model
+        if is_given(location):
+            self._location = location
+            # if location is changed, fetch a new client and recognizer as per the new location
+            self._pool.invalidate()
+        if is_given(keywords):
+            self._config.keywords = keywords
+
+        for stream in self._streams:
+            stream.update_options(
+                languages=languages,
+                detect_language=detect_language,
+                interim_results=interim_results,
+                punctuate=punctuate,
+                spoken_punctuation=spoken_punctuation,
+                model=model,
+                keywords=keywords,
+            )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
 
 class SpeechStream(stt.SpeechStream):
     def __init__(
         self,
-
-
-
+        *,
+        stt: STT,
+        conn_options: APIConnectOptions,
+        pool: utils.ConnectionPool[SpeechAsyncClient],
+        recognizer_cb: Callable[[SpeechAsyncClient], str],
         config: STTOptions,
-        sample_rate: int = 24000,
-        num_channels: int = 1,
-        max_retry: int = 32,
     ) -> None:
-        super().__init__()
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
 
-        self.
-        self.
-        self._recognizer = recognizer
+        self._pool = pool
+        self._recognizer_cb = recognizer_cb
         self._config = config
-        self.
-        self.
-
-        self._queue = asyncio.Queue[rtc.AudioFrame | None]()
-        self._event_queue = asyncio.Queue[stt.SpeechEvent | None]()
-        self._closed = False
-        self._main_task = asyncio.create_task(self._run(max_retry=max_retry))
-
-        self._final_events: List[stt.SpeechEvent] = []
-        self._speaking = False
-
-        self._streaming_config = cloud_speech.StreamingRecognitionConfig(
-            config=cloud_speech.RecognitionConfig(
-                explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
-                    encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                    sample_rate_hertz=self._sample_rate,
-                    audio_channel_count=self._num_channels,
-                ),
-                language_codes=self._config.languages,
-                model=self._config.model,
-                features=cloud_speech.RecognitionFeatures(
-                    enable_automatic_punctuation=self._config.punctuate,
-                ),
-            ),
-            streaming_features=cloud_speech.StreamingRecognitionFeatures(
-                enable_voice_activity_events=True,
-                interim_results=self._config.interim_results,
-            ),
-        )
+        self._reconnect_event = asyncio.Event()
+        self._session_connected_at: float = 0
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-
-
-
-
+    def update_options(
+        self,
+        *,
+        languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+        detect_language: NotGivenOr[bool] = NOT_GIVEN,
+        interim_results: NotGivenOr[bool] = NOT_GIVEN,
+        punctuate: NotGivenOr[bool] = NOT_GIVEN,
+        spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+        model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+        min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+    ) -> None:
+        if is_given(languages):
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = cast(list[LgType], languages)
+        if is_given(detect_language):
+            self._config.detect_language = detect_language
+        if is_given(interim_results):
+            self._config.interim_results = interim_results
+        if is_given(punctuate):
+            self._config.punctuate = punctuate
+        if is_given(spoken_punctuation):
+            self._config.spoken_punctuation = spoken_punctuation
+        if is_given(model):
+            self._config.model = model
+        if is_given(min_confidence_threshold):
+            self._config.min_confidence_threshold = min_confidence_threshold
+        if is_given(keywords):
+            self._config.keywords = keywords
+
+        self._reconnect_event.set()
+
+    async def _run(self) -> None:
+        audio_pushed = False
+
+        # google requires a async generator when calling streaming_recognize
+        # this function basically convert the queue into a async generator
+        async def input_generator(
+            client: SpeechAsyncClient, should_stop: asyncio.Event
+        ) -> AsyncGenerator[cloud_speech.StreamingRecognizeRequest, None]:
+            nonlocal audio_pushed
+            try:
+                # first request should contain the config
+                yield cloud_speech.StreamingRecognizeRequest(
+                    recognizer=self._recognizer_cb(client),
+                    streaming_config=self._streaming_config,
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+                async for frame in self._input_ch:
+                    # when the stream is aborted due to reconnect, this input_generator
+                    # needs to stop consuming frames
+                    # when the generator stops, the previous gRPC stream will close
+                    if should_stop.is_set():
+                        return
+
+                    if isinstance(frame, rtc.AudioFrame):
+                        yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())
+                        if not audio_pushed:
+                            audio_pushed = True
+
+            except Exception:
+                logger.exception("an error occurred while streaming input to google STT")
+
+        async def process_stream(
+            client: SpeechAsyncClient,
+            stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse],
+        ) -> None:
+            has_started = False
+            async for resp in stream:
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                    )
+                    has_started = True
+
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED  # noqa: E501
+                ):
+                    result = resp.results[0]
+                    speech_data = _streaming_recognize_response_to_speech_data(
+                        resp,
+                        min_confidence_threshold=self._config.min_confidence_threshold,
+                    )
+                    if speech_data is None:
+                        continue
+
+                    if not result.is_final:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                                alternatives=[speech_data],
                             )
-
-
-
-
-
-
-
-                            self._queue.task_done()
-                            frame = frame.remix_and_resample(
-                                self._sample_rate, self._num_channels
-                            )
-                            yield cloud_speech.StreamingRecognizeRequest(
-                                audio=frame.data.tobytes(),
-                            )
-            except Exception as e:
-                logging.error(
-                    f"an error occurred while streaming inputs: {e}"
+                        )
+                    else:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                                alternatives=[speech_data],
                             )
-
-                # try to connect
-                stream = await self._client.streaming_recognize(
-                    requests=input_generator()
-                )
-                retry_count = 0  # connection successful, reset retry count
-
-                await self._run_stream(stream)
-            except Exception as e:
-                if retry_count >= max_retry:
-                    logging.error(
-                        f"failed to connect to google stt after {max_retry} tries",
-                        exc_info=e,
                         )
-
-
-
-
-
-
-
+                        if time.time() - self._session_connected_at > _max_session_duration:
+                            logger.debug(
+                                "Google STT maximum connection time reached. Reconnecting..."
+                            )
+                            self._pool.remove(client)
+                            if has_started:
+                                self._event_ch.send_nowait(
+                                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
+                                )
+                                has_started = False
+                            self._reconnect_event.set()
+                            return
+
+                if (
+                    resp.speech_event_type
+                    == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                     )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                alternatives=streaming_recognize_response_to_speech_data(resp),
+                    has_started = False
+
+        while True:
+            audio_pushed = False
+            try:
+                async with self._pool.connection(timeout=self._conn_options.timeout) as client:
+                    self._streaming_config = cloud_speech.StreamingRecognitionConfig(
+                        config=cloud_speech.RecognitionConfig(
+                            explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                                encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                                sample_rate_hertz=self._config.sample_rate,
+                                audio_channel_count=1,
+                            ),
+                            adaptation=self._config.build_adaptation(),
+                            language_codes=self._config.languages,
+                            model=self._config.model,
+                            features=cloud_speech.RecognitionFeatures(
+                                enable_automatic_punctuation=self._config.punctuate,
+                                enable_word_time_offsets=self._config.enable_word_time_offsets,
+                                enable_spoken_punctuation=self._config.spoken_punctuation,
+                            ),
+                        ),
+                        streaming_features=cloud_speech.StreamingRecognitionFeatures(
+                            interim_results=self._config.interim_results,
+                            enable_voice_activity_events=self._config.enable_voice_activity_events,
+                        ),
                     )
-            self._event_queue.put_nowait(iterim_event)
 
-
-
-
-                alternatives=streaming_recognize_response_to_speech_data(resp),
+                    should_stop = asyncio.Event()
+                    stream = await client.streaming_recognize(
+                        requests=input_generator(client, should_stop),
                     )
-            self.
-            self._event_queue.put_nowait(final_event)
-
-            if not self._speaking:
-                # With Google STT, we receive the final event after the END_OF_SPEECH event
-                sentence = ""
-                confidence = 0.0
-                for alt in self._final_events:
-                    sentence += f"{alt.alternatives[0].text.strip()} "
-                    confidence += alt.alternatives[0].confidence
-
-                sentence = sentence.rstrip()
-                confidence /= len(self._final_events)  # avg. of confidence
-
-                end_event = stt.SpeechEvent(
-                    type=stt.SpeechEventType.END_OF_SPEECH,
-                    alternatives=[
-                        stt.SpeechData(
-                            language=result.language_code,
-                            start_time=self._final_events[0]
-                            .alternatives[0]
-                            .start_time,
-                            end_time=self._final_events[-1]
-                            .alternatives[0]
-                            .end_time,
-                            confidence=confidence,
-                            text=sentence,
-                        )
-                    ],
-                )
+                    self._session_connected_at = time.time()
 
-
-
+                    process_stream_task = asyncio.create_task(process_stream(client, stream))
+                    wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
 
-
-
-
-
-
+                    try:
+                        done, _ = await asyncio.wait(
+                            [process_stream_task, wait_reconnect_task],
+                            return_when=asyncio.FIRST_COMPLETED,
+                        )
+                        for task in done:
+                            if task != wait_reconnect_task:
+                                task.result()
+                        if wait_reconnect_task not in done:
+                            break
+                        self._reconnect_event.clear()
+                    finally:
+                        should_stop.set()
+                        if not process_stream_task.done() and not wait_reconnect_task.done():
+                            # try to gracefully stop the process_stream_task
+                            try:
+                                await asyncio.wait_for(process_stream_task, timeout=1.0)
+                            except asyncio.TimeoutError:
+                                pass
+
+                        await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
+            except DeadlineExceeded:
+                raise APITimeoutError() from None
+            except GoogleAPICallError as e:
+                if e.code == 409:
+                    if audio_pushed:
+                        logger.debug("stream timed out, restarting.")
+                else:
+                    raise APIStatusError(
+                        f"{e.message} {e.details}", status_code=e.code or -1
+                    ) from e
+            except Exception as e:
+                raise APIConnectionError() from e
 
-    async def __anext__(self) -> stt.SpeechEvent:
-        evt = await self._event_queue.get()
-        if evt is None:
-            raise StopAsyncIteration
 
-
+def _duration_to_seconds(duration: Duration | timedelta) -> float:
+    # Proto Plus may auto-convert Duration to timedelta; handle both.
+    # https://proto-plus-python.readthedocs.io/en/latest/marshal.html
+    if isinstance(duration, timedelta):
+        return duration.total_seconds()
+    return duration.seconds + duration.nanos / 1e9
 
 
-def
+def _recognize_response_to_speech_event(
     resp: cloud_speech.RecognizeResponse,
 ) -> stt.SpeechEvent:
-
-
-
-
-    alternatives
+    text = ""
+    confidence = 0.0
+    for result in resp.results:
+        text += result.alternatives[0].transcript
+        confidence += result.alternatives[0].confidence
+
+    alternatives = []
+
+    # Google STT may return empty results when spoken_lang != stt_lang
+    if resp.results:
+        try:
+            start_time = _duration_to_seconds(resp.results[0].alternatives[0].words[0].start_offset)
+            end_time = _duration_to_seconds(resp.results[-1].alternatives[0].words[-1].end_offset)
+        except IndexError:
+            # When enable_word_time_offsets=False, there are no "words" to access
+            start_time = end_time = 0
+
+        confidence /= len(resp.results)
+        lg = resp.results[0].language_code
+
+        alternatives = [
             stt.SpeechData(
-                language=
-                start_time=
-                end_time=
-                confidence=
-                text=
+                language=lg,
+                start_time=start_time,
+                end_time=end_time,
+                confidence=confidence,
+                text=text,
             )
-
-
-    )
+        ]
+
+    return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=alternatives)
 
 
-def
+def _streaming_recognize_response_to_speech_data(
     resp: cloud_speech.StreamingRecognizeResponse,
-
-
-
-
-
-
-
-
-
-
-
-
-
+    *,
+    min_confidence_threshold: float,
+) -> stt.SpeechData | None:
+    text = ""
+    confidence = 0.0
+    final_result = None
+    for result in resp.results:
+        if len(result.alternatives) == 0:
+            continue
+        else:
+            if result.is_final:
+                final_result = result
+                break
+            else:
+                text += result.alternatives[0].transcript
+                confidence += result.alternatives[0].confidence
+
+    if final_result is not None:
+        text = final_result.alternatives[0].transcript
+        confidence = final_result.alternatives[0].confidence
+        lg = final_result.language_code
+    else:
+        confidence /= len(resp.results)
+        if confidence < min_confidence_threshold:
+            return None
+        lg = resp.results[0].language_code
+
+    if text == "":
+        return None
+
+    data = stt.SpeechData(language=lg, start_time=0, end_time=0, confidence=confidence, text=text)
+
+    return data
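
The most intricate part of the new SpeechStream._run() is the session rotation: a process_stream task is raced against a reconnect event via asyncio.wait(FIRST_COMPLETED), and the gRPC stream is rebuilt before Google's five-minute cap (_max_session_duration = 240 above). The core pattern, reduced to a self-contained sketch with illustrative names (not the plugin's API):

    import asyncio

    async def run_with_rotation(make_worker, reconnect_event: asyncio.Event) -> None:
        # Race a worker task against a reconnect signal; rebuild the worker on
        # signal, exit when it completes on its own (illustrative names).
        while True:
            worker = asyncio.create_task(make_worker())
            waiter = asyncio.create_task(reconnect_event.wait())
            try:
                done, _ = await asyncio.wait(
                    {worker, waiter}, return_when=asyncio.FIRST_COMPLETED
                )
                if waiter not in done:
                    worker.result()  # re-raise any worker exception
                    break            # worker finished: no rotation needed
                reconnect_event.clear()  # rotate: the loop builds a fresh worker
            finally:
                for task in (worker, waiter):
                    if not task.done():
                        task.cancel()
                await asyncio.gather(worker, waiter, return_exceptions=True)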