livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +33 -7
- livekit/plugins/google/beta/__init__.py +13 -0
- livekit/plugins/google/beta/gemini_tts.py +258 -0
- livekit/plugins/google/llm.py +562 -0
- livekit/plugins/google/log.py +3 -0
- livekit/plugins/google/models.py +160 -32
- livekit/plugins/google/realtime/__init__.py +9 -0
- livekit/plugins/google/realtime/api_proto.py +68 -0
- livekit/plugins/google/realtime/realtime_api.py +1249 -0
- livekit/plugins/google/stt.py +717 -283
- livekit/plugins/google/tools.py +71 -0
- livekit/plugins/google/tts.py +455 -0
- livekit/plugins/google/utils.py +220 -0
- livekit/plugins/google/version.py +1 -1
- livekit_plugins_google-1.3.11.dist-info/METADATA +63 -0
- livekit_plugins_google-1.3.11.dist-info/RECORD +18 -0
- {livekit_plugins_google-0.3.0.dist-info → livekit_plugins_google-1.3.11.dist-info}/WHEEL +1 -2
- livekit_plugins_google-0.3.0.dist-info/METADATA +0 -47
- livekit_plugins_google-0.3.0.dist-info/RECORD +0 -9
- livekit_plugins_google-0.3.0.dist-info/top_level.txt +0 -1
livekit/plugins/google/stt.py
CHANGED
@@ -15,35 +15,103 @@
 from __future__ import annotations
 
 import asyncio
-import contextlib
 import dataclasses
-import …
+import time
+import weakref
+from collections.abc import AsyncGenerator, AsyncIterable
 from dataclasses import dataclass
-from …
-
-
-from …
-from …
-
-from google.auth import …
-from google.cloud. …
-from google.cloud. …
-
-from . …
-
-
-
+from datetime import timedelta
+from typing import Callable, Union, cast, get_args
+
+from google.api_core.client_options import ClientOptions
+from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+from google.auth import default as gauth_default
+from google.auth.exceptions import DefaultCredentialsError
+from google.cloud.speech_v1 import SpeechAsyncClient as SpeechAsyncClientV1
+from google.cloud.speech_v1.types import cloud_speech as cloud_speech_v1, resource as resource_v1
+from google.cloud.speech_v2 import SpeechAsyncClient as SpeechAsyncClientV2
+from google.cloud.speech_v2.types import cloud_speech as cloud_speech_v2
+from google.protobuf.duration_pb2 import Duration
+from livekit import rtc
+from livekit.agents import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    APIConnectionError,
+    APIConnectOptions,
+    APIStatusError,
+    APITimeoutError,
+    stt,
+    utils,
+)
+from livekit.agents.types import (
+    NOT_GIVEN,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
+from livekit.agents.voice.io import TimedString
+
+from .log import logger
+from .models import SpeechLanguages, SpeechModels, SpeechModelsV2
+
+LgType = Union[SpeechLanguages, str]
+LanguageCode = Union[LgType, list[LgType]]
+
+# Google STT has a timeout of 5 mins, we'll attempt to restart the session
+# before that timeout is reached
+_max_session_duration = 240
+
+# Google is very sensitive to background noise, so we'll ignore results with low confidence
+_default_min_confidence = 0.65
 
 
 # This class is only be used internally to encapsulate the options
 @dataclass
 class STTOptions:
-    languages: …
+    languages: list[LgType]
     detect_language: bool
     interim_results: bool
     punctuate: bool
     spoken_punctuation: bool
-
+    enable_word_time_offsets: bool
+    enable_word_confidence: bool
+    enable_voice_activity_events: bool
+    model: SpeechModels | str
+    sample_rate: int
+    min_confidence_threshold: float
+    keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
+
+    @property
+    def version(self) -> int:
+        return 2 if self.model in get_args(SpeechModelsV2) else 1
+
+    def build_adaptation(
+        self,
+    ) -> cloud_speech_v2.SpeechAdaptation | resource_v1.SpeechAdaptation | None:
+        if is_given(self.keywords):
+            if self.version == 2:
+                return cloud_speech_v2.SpeechAdaptation(
+                    phrase_sets=[
+                        cloud_speech_v2.SpeechAdaptation.AdaptationPhraseSet(
+                            inline_phrase_set=cloud_speech_v2.PhraseSet(
+                                phrases=[
+                                    cloud_speech_v2.PhraseSet.Phrase(value=keyword, boost=boost)
+                                    for keyword, boost in self.keywords
+                                ]
+                            )
+                        )
+                    ]
+                )
+            return resource_v1.SpeechAdaptation(
+                phrase_sets=[
+                    resource_v1.PhraseSet(
+                        name="keywords",
+                        phrases=[
+                            resource_v1.PhraseSet.Phrase(value=keyword, boost=boost)
+                            for keyword, boost in self.keywords
+                        ],
+                    )
+                ]
+            )
+        return None
 
 
 class STT(stt.STT):
@@ -54,23 +122,80 @@ class STT(stt.STT):
         detect_language: bool = True,
         interim_results: bool = True,
         punctuate: bool = True,
-        spoken_punctuation: bool = …
-
-
-
+        spoken_punctuation: bool = False,
+        enable_word_time_offsets: NotGivenOr[bool] = NOT_GIVEN,
+        enable_word_confidence: bool = False,
+        enable_voice_activity_events: bool = False,
+        model: SpeechModels | str = "latest_long",
+        location: str = "global",
+        sample_rate: int = 16000,
+        min_confidence_threshold: float = _default_min_confidence,
+        credentials_info: NotGivenOr[dict] = NOT_GIVEN,
+        credentials_file: NotGivenOr[str] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+        use_streaming: NotGivenOr[bool] = NOT_GIVEN,
     ):
         """
-
-
+        Create a new instance of Google STT.
+
+        Credentials must be provided, either by using the ``credentials_info`` dict, or reading
+        from the file specified in ``credentials_file`` or via Application Default Credentials as
+        described in https://cloud.google.com/docs/authentication/application-default-credentials
+
+        args:
+            languages(LanguageCode): list of language codes to recognize (default: "en-US")
+            detect_language(bool): whether to detect the language of the audio (default: True)
+            interim_results(bool): whether to return interim results (default: True)
+            punctuate(bool): whether to punctuate the audio (default: True)
+            spoken_punctuation(bool): whether to use spoken punctuation (default: False)
+            enable_word_time_offsets(bool): whether to enable word time offsets (default: None)
+            enable_word_confidence(bool): whether to enable word confidence (default: False)
+            enable_voice_activity_events(bool): whether to enable voice activity events (default: False)
+            model(SpeechModels): the model to use for recognition default: "latest_long"
+            location(str): the location to use for recognition default: "global"
+            sample_rate(int): the sample rate of the audio default: 16000
+            min_confidence_threshold(float): minimum confidence threshold for recognition
+                (default: 0.65)
+            credentials_info(dict): the credentials info to use for recognition (default: None)
+            credentials_file(str): the credentials file to use for recognition (default: None)
+            keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
+            use_streaming(bool): whether to use streaming for recognition (default: True)
         """
-
+        if not is_given(use_streaming):
+            use_streaming = True
 
-        if …
-
-
-
+        if model == "chirp_3":
+            if is_given(enable_word_time_offsets) and enable_word_time_offsets:
+                logger.warning(
+                    "Chirp 3 does not support word timestamps, setting 'enable_word_time_offsets' to False."
+                )
+            enable_word_time_offsets = False
+        elif is_given(enable_word_time_offsets):
+            enable_word_time_offsets = enable_word_time_offsets
         else:
-
+            enable_word_time_offsets = True
+
+        super().__init__(
+            capabilities=stt.STTCapabilities(
+                streaming=use_streaming,
+                interim_results=True,
+                aligned_transcript="word" if enable_word_time_offsets and use_streaming else False,
+            )
+        )
+
+        self._location = location
+        self._credentials_info = credentials_info
+        self._credentials_file = credentials_file
+
+        if not is_given(credentials_file) and not is_given(credentials_info):
+            try:
+                gauth_default()  # type: ignore
+            except DefaultCredentialsError:
+                raise ValueError(
+                    "Application default credentials must be available "
+                    "when using Google STT without explicitly passing "
+                    "credentials through credentials_info or credentials_file."
+                ) from None
 
         if isinstance(languages, str):
             languages = [languages]
@@ -81,322 +206,631 @@ class STT(stt.STT):
             interim_results=interim_results,
             punctuate=punctuate,
             spoken_punctuation=spoken_punctuation,
+            enable_word_time_offsets=enable_word_time_offsets,
+            enable_word_confidence=enable_word_confidence,
+            enable_voice_activity_events=enable_voice_activity_events,
             model=model,
+            sample_rate=sample_rate,
+            min_confidence_threshold=min_confidence_threshold,
+            keywords=keywords,
+        )
+        self._streams = weakref.WeakSet[SpeechStream]()
+        self._pool = utils.ConnectionPool[SpeechAsyncClientV2 | SpeechAsyncClientV1](
+            max_session_duration=_max_session_duration,
+            connect_cb=self._create_client,
         )
-        self._creds = self._client.transport._credentials
 
     @property
-    def …
+    def model(self) -> str:
+        return self._config.model
+
+    @property
+    def provider(self) -> str:
+        return "Google Cloud Platform"
+
+    async def _create_client(self, timeout: float) -> SpeechAsyncClientV2 | SpeechAsyncClientV1:
+        # Add support for passing a specific location that matches recognizer
+        # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+        # TODO(long): how to set timeout?
+        client_options = None
+        client: SpeechAsyncClientV2 | SpeechAsyncClientV1 | None = None
+        client_cls = SpeechAsyncClientV2 if self._config.version == 2 else SpeechAsyncClientV1
+        if self._location != "global":
+            client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
+        if is_given(self._credentials_info):
+            client = client_cls.from_service_account_info(
+                self._credentials_info, client_options=client_options
+            )
+        elif is_given(self._credentials_file):
+            client = client_cls.from_service_account_file(
+                self._credentials_file, client_options=client_options
+            )
+        else:
+            client = client_cls(client_options=client_options)
+        assert client is not None
+        return client
+
+    def _get_recognizer(self, client: SpeechAsyncClientV2) -> str:
         # TODO(theomonnom): should we use recognizers?
-        # …
-        return f"projects/{self._creds.project_id}/locations/global/recognizers/_"  # type: ignore
+        # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
 
-
-
-
-
-
+        # TODO(theomonnom): find a better way to access the project_id
+        try:
+            project_id = client.transport._credentials.project_id  # type: ignore
+        except AttributeError:
+            from google.auth import default as ga_default
+
+            _, project_id = ga_default()  # type: ignore
+        return f"projects/{project_id}/locations/{self._location}/recognizers/_"
+
+    def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
         config = dataclasses.replace(self._config)
 
-        if language:
+        if is_given(language):
             config.languages = [language]
 
         if not isinstance(config.languages, list):
             config.languages = [config.languages]
         elif not config.detect_language:
             if len(config.languages) > 1:
-                …
-                    "multiple languages provided, but language detection is disabled"
-                )
+                logger.warning("multiple languages provided, but language detection is disabled")
                 config.languages = [config.languages[0]]
 
         return config
 
-
+    def _build_recognition_config(
         self,
-
-
-        language: SpeechLanguages | str …
-    ) -> …
+        sample_rate: int,
+        num_channels: int,
+        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+    ) -> cloud_speech_v2.RecognitionConfig | cloud_speech_v1.RecognitionConfig:
         config = self._sanitize_options(language=language)
-
-
-
-
-
-
-
-
-
-
-
-
+        if self._config.version == 2:
+            return cloud_speech_v2.RecognitionConfig(
+                explicit_decoding_config=cloud_speech_v2.ExplicitDecodingConfig(
+                    encoding=cloud_speech_v2.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                    sample_rate_hertz=sample_rate,
+                    audio_channel_count=num_channels,
+                ),
+                adaptation=config.build_adaptation(),
+                features=cloud_speech_v2.RecognitionFeatures(
+                    enable_automatic_punctuation=config.punctuate,
+                    enable_spoken_punctuation=config.spoken_punctuation,
+                    enable_word_time_offsets=config.enable_word_time_offsets,
+                    enable_word_confidence=config.enable_word_confidence,
+                ),
+                model=config.model,
+                language_codes=config.languages,
+            )
+        return cloud_speech_v1.RecognitionConfig(
+            encoding=cloud_speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
+            sample_rate_hertz=sample_rate,
+            audio_channel_count=num_channels,
+            adaptation=config.build_adaptation(),
+            language_code=config.languages[0],
+            alternative_language_codes=config.languages[1:],
+            enable_word_time_offsets=config.enable_word_time_offsets,
+            enable_word_confidence=config.enable_word_confidence,
+            enable_automatic_punctuation=config.punctuate,
+            enable_spoken_punctuation=config.spoken_punctuation,
             model=config.model,
-            language_codes=config.languages,
         )
 
-
-
-
-
-
-
-
+    def _build_recognition_request(
+        self,
+        client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
+        config: cloud_speech_v2.RecognitionConfig | cloud_speech_v1.RecognitionConfig,
+        content: bytes,
+    ) -> cloud_speech_v2.RecognizeRequest | cloud_speech_v1.RecognizeRequest:
+        if self._config.version == 2:
+            return cloud_speech_v2.RecognizeRequest(
+                recognizer=self._get_recognizer(cast(SpeechAsyncClientV2, client)),
+                config=config,
+                content=content,
             )
+
+        return cloud_speech_v1.RecognizeRequest(
+            config=config,
+            audio=cloud_speech_v1.RecognitionAudio(content=content),
+        )
+
+    async def _recognize_impl(
+        self,
+        buffer: utils.AudioBuffer,
+        *,
+        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+        conn_options: APIConnectOptions,
+    ) -> stt.SpeechEvent:
+        frame = rtc.combine_audio_frames(buffer)
+
+        config = self._build_recognition_config(
+            sample_rate=frame.sample_rate,
+            num_channels=frame.num_channels,
+            language=language,
         )
 
+        try:
+            async with self._pool.connection(timeout=conn_options.timeout) as client:
+                raw = await client.recognize(
+                    self._build_recognition_request(client, config, frame.data.tobytes()),
+                    timeout=conn_options.timeout,
+                )
+                return _recognize_response_to_speech_event(raw)
+        except DeadlineExceeded:
+            raise APITimeoutError() from None
+        except GoogleAPICallError as e:
+            raise APIStatusError(f"{e.message} {e.details}", status_code=e.code or -1) from e
+        except Exception as e:
+            raise APIConnectionError() from e
+
     def stream(
         self,
         *,
-        language: SpeechLanguages | str …
-
+        language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> SpeechStream:
         config = self._sanitize_options(language=language)
-
-        self …
-        self. …
-        self. …
-            config,
+        stream = SpeechStream(
+            stt=self,
+            pool=self._pool,
+            recognizer_cb=self._get_recognizer,
+            config=config,
+            conn_options=conn_options,
         )
+        self._streams.add(stream)
+        return stream
+
+    def update_options(
+        self,
+        *,
+        languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+        detect_language: NotGivenOr[bool] = NOT_GIVEN,
+        interim_results: NotGivenOr[bool] = NOT_GIVEN,
+        punctuate: NotGivenOr[bool] = NOT_GIVEN,
+        spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+        model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+    ) -> None:
+        if is_given(languages):
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = cast(list[LgType], languages)
+        if is_given(detect_language):
+            self._config.detect_language = detect_language
+        if is_given(interim_results):
+            self._config.interim_results = interim_results
+        if is_given(punctuate):
+            self._config.punctuate = punctuate
+        if is_given(spoken_punctuation):
+            self._config.spoken_punctuation = spoken_punctuation
+        if is_given(model):
+            old_version = self._config.version
+            self._config.model = model
+            if self._config.version != old_version:
+                self._pool.invalidate()
+
+        if is_given(location):
+            self._location = location
+            # if location is changed, fetch a new client and recognizer as per the new location
+            self._pool.invalidate()
+        if is_given(keywords):
+            self._config.keywords = keywords
+
+        for stream in self._streams:
+            stream.update_options(
+                languages=languages,
+                detect_language=detect_language,
+                interim_results=interim_results,
+                punctuate=punctuate,
+                spoken_punctuation=spoken_punctuation,
+                model=model,
+                keywords=keywords,
+            )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
 
 class SpeechStream(stt.SpeechStream):
     def __init__(
         self,
-
-
-
+        *,
+        stt: STT,
+        conn_options: APIConnectOptions,
+        pool: utils.ConnectionPool[SpeechAsyncClientV2 | SpeechAsyncClientV1],
+        recognizer_cb: Callable[[SpeechAsyncClientV2], str],
         config: STTOptions,
-        sample_rate: int = 24000,
-        num_channels: int = 1,
-        max_retry: int = 32,
     ) -> None:
-        super().__init__()
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
 
-        self. …
-        self. …
-        self._recognizer = recognizer
+        self._pool = pool
+        self._recognizer_cb = recognizer_cb
         self._config = config
-        self. …
-        self. …
-
-
-        self …
-
-
-
-
-
-
-
-
-
-
-
-
+        self._reconnect_event = asyncio.Event()
+        self._session_connected_at: float = 0
+
+    def update_options(
+        self,
+        *,
+        languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+        detect_language: NotGivenOr[bool] = NOT_GIVEN,
+        interim_results: NotGivenOr[bool] = NOT_GIVEN,
+        punctuate: NotGivenOr[bool] = NOT_GIVEN,
+        spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+        model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+        min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
+        keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+    ) -> None:
+        if is_given(languages):
+            if isinstance(languages, str):
+                languages = [languages]
+            self._config.languages = cast(list[LgType], languages)
+        if is_given(detect_language):
+            self._config.detect_language = detect_language
+        if is_given(interim_results):
+            self._config.interim_results = interim_results
+        if is_given(punctuate):
+            self._config.punctuate = punctuate
+        if is_given(spoken_punctuation):
+            self._config.spoken_punctuation = spoken_punctuation
+        if is_given(model):
+            old_version = self._config.version
+            self._config.model = model
+            if self._config.version != old_version:
+                self._pool.invalidate()
+        if is_given(min_confidence_threshold):
+            self._config.min_confidence_threshold = min_confidence_threshold
+        if is_given(keywords):
+            self._config.keywords = keywords
+
+        self._reconnect_event.set()
+
+    def _build_streaming_config(
+        self,
+    ) -> cloud_speech_v2.StreamingRecognitionConfig | cloud_speech_v1.StreamingRecognitionConfig:
+        if self._config.version == 2:
+            return cloud_speech_v2.StreamingRecognitionConfig(
+                config=cloud_speech_v2.RecognitionConfig(
+                    explicit_decoding_config=cloud_speech_v2.ExplicitDecodingConfig(
+                        encoding=cloud_speech_v2.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                        sample_rate_hertz=self._config.sample_rate,
+                        audio_channel_count=1,
+                    ),
+                    adaptation=self._config.build_adaptation(),
+                    language_codes=self._config.languages,
+                    model=self._config.model,
+                    features=cloud_speech_v2.RecognitionFeatures(
+                        enable_automatic_punctuation=self._config.punctuate,
+                        enable_word_time_offsets=self._config.enable_word_time_offsets,
+                        enable_spoken_punctuation=self._config.spoken_punctuation,
+                        enable_word_confidence=self._config.enable_word_confidence,
+                    ),
                 ),
-
-
-
-                enable_automatic_punctuation=self._config.punctuate,
+                streaming_features=cloud_speech_v2.StreamingRecognitionFeatures(
+                    interim_results=self._config.interim_results,
+                    enable_voice_activity_events=self._config.enable_voice_activity_events,
                 ),
+            )
+
+        return cloud_speech_v1.StreamingRecognitionConfig(
+            config=cloud_speech_v1.RecognitionConfig(
+                encoding=cloud_speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
+                sample_rate_hertz=self._config.sample_rate,
+                audio_channel_count=1,
+                adaptation=self._config.build_adaptation(),
+                language_code=self._config.languages[0],
+                alternative_language_codes=self._config.languages[1:],
+                enable_word_time_offsets=self._config.enable_word_time_offsets,
+                enable_word_confidence=self._config.enable_word_confidence,
+                enable_automatic_punctuation=self._config.punctuate,
+                enable_spoken_punctuation=self._config.spoken_punctuation,
+                model=self._config.model,
             ),
-
-
-            interim_results=self._config.interim_results,
-        ),
+            interim_results=self._config.interim_results,
+            enable_voice_activity_events=self._config.enable_voice_activity_events,
         )
 
-
-
-
-
-        self. …
-
-
-
-
-
-
-
-    async def aclose(self, wait: bool = True) -> None:
-        self._closed = True
-        if not wait:
-            self._main_task.cancel()
-
-        self._queue.put_nowait(None)
-        with contextlib.suppress(asyncio.CancelledError):
-            await self._main_task
+    def _build_init_request(
+        self,
+        client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
+    ) -> cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest:
+        if self._config.version == 2:
+            return cloud_speech_v2.StreamingRecognizeRequest(
+                recognizer=self._recognizer_cb(cast(SpeechAsyncClientV2, client)),
+                streaming_config=self._streaming_config,
+            )
+        return cloud_speech_v1.StreamingRecognizeRequest(
+            streaming_config=self._streaming_config,
+        )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def _build_audio_request(
+        self,
+        frame: rtc.AudioFrame,
+    ) -> cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest:
+        if self._config.version == 2:
+            return cloud_speech_v2.StreamingRecognizeRequest(audio=frame.data.tobytes())
+        return cloud_speech_v1.StreamingRecognizeRequest(audio_content=frame.data.tobytes())
+
+    async def _run(self) -> None:
+        audio_pushed = False
+
+        # google requires a async generator when calling streaming_recognize
+        # this function basically convert the queue into a async generator
+        async def input_generator(
+            client: SpeechAsyncClientV2 | SpeechAsyncClientV1, should_stop: asyncio.Event
+        ) -> AsyncGenerator[
+            cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest,
+            None,
+        ]:
+            nonlocal audio_pushed
+            try:
+                yield self._build_init_request(client)
+
+                async for frame in self._input_ch:
+                    # when the stream is aborted due to reconnect, this input_generator
+                    # needs to stop consuming frames
+                    # when the generator stops, the previous gRPC stream will close
+                    if should_stop.is_set():
+                        return
+
+                    if isinstance(frame, rtc.AudioFrame):
+                        yield self._build_audio_request(frame)
+                        if not audio_pushed:
+                            audio_pushed = True
+
+            except Exception:
+                logger.exception("an error occurred while streaming input to google STT")
+
+        async def process_stream(
+            client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
+            stream: AsyncIterable[
+                cloud_speech_v2.StreamingRecognizeResponse
+                | cloud_speech_v1.StreamingRecognizeResponse
+            ],
+        ) -> None:
+            has_started = False
+            async for resp in stream:
+                if resp.speech_event_type == (
+                    cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                    if self._config.version == 2
+                    else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                    )
+                    has_started = True
+
+                if resp.speech_event_type == (
+                    cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
+                    if self._config.version == 2
+                    else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_UNSPECIFIED
+                ):
+                    result = resp.results[0]
+                    speech_data = _streaming_recognize_response_to_speech_data(
+                        resp,
+                        min_confidence_threshold=self._config.min_confidence_threshold,
+                        start_time_offset=self.start_time_offset,
+                    )
+                    if speech_data is None:
+                        continue
+
+                    if not result.is_final:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                                alternatives=[speech_data],
                             )
-
-
-
-
-
-
-
-                self._queue.task_done()
-                frame = frame.remix_and_resample(
-                    self._sample_rate, self._num_channels
-                )
-                yield cloud_speech.StreamingRecognizeRequest(
-                    audio=frame.data.tobytes(),
-                )
-            except Exception as e:
-                logging.error(
-                    f"an error occurred while streaming inputs: {e}"
+                        )
+                    else:
+                        self._event_ch.send_nowait(
+                            stt.SpeechEvent(
+                                type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                                alternatives=[speech_data],
                             )
-
-        # try to connect
-        stream = await self._client.streaming_recognize(
-            requests=input_generator()
-        )
-        retry_count = 0  # connection successful, reset retry count
-
-        await self._run_stream(stream)
-    except Exception as e:
-        if retry_count >= max_retry:
-            logging.error(
-                f"failed to connect to google stt after {max_retry} tries",
-                exc_info=e,
                         )
-
-
-
-
-
-
-
+                        if time.time() - self._session_connected_at > _max_session_duration:
+                            logger.debug(
+                                "Google STT maximum connection time reached. Reconnecting..."
+                            )
+                            self._pool.remove(client)
+                            if has_started:
+                                self._event_ch.send_nowait(
+                                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
+                                )
+                                has_started = False
+                            self._reconnect_event.set()
+                            return
+
+                if resp.speech_event_type == (
+                    cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                    if self._config.version == 2
+                    else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                ):
+                    self._event_ch.send_nowait(
+                        stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                     )
-
-            finally:
-                self._event_queue.put_nowait(None)
+                    has_started = False
 
-
-
-
-
-                resp.speech_event_type
-                == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
-            ):
-                self._speaking = True
-                start_event = stt.SpeechEvent(
-                    type=stt.SpeechEventType.START_OF_SPEECH,
-                )
-                self._event_queue.put_nowait(start_event)
-
-            if (
-                resp.speech_event_type
-                == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
-            ):
-                result = resp.results[0]
-                if not result.is_final:
-                    # interim results
-                    iterim_event = stt.SpeechEvent(
-                        type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                        alternatives=streaming_recognize_response_to_speech_data(resp),
-                    )
-                    self._event_queue.put_nowait(iterim_event)
+        while True:
+            audio_pushed = False
+            try:
+                async with self._pool.connection(timeout=self._conn_options.timeout) as client:
+                    self._streaming_config = self._build_streaming_config()
 
-
-
-
-                        alternatives=streaming_recognize_response_to_speech_data(resp),
+                    should_stop = asyncio.Event()
+                    stream = await client.streaming_recognize(
+                        requests=input_generator(client, should_stop),
                     )
-                    self. …
-
-
-
-
-
-
-
-
-            confidence += alt.alternatives[0].confidence
-
-        sentence = sentence.rstrip()
-        confidence /= len(self._final_events)  # avg. of confidence
-
-        end_event = stt.SpeechEvent(
-            type=stt.SpeechEventType.END_OF_SPEECH,
-            alternatives=[
-                stt.SpeechData(
-                    language=result.language_code,
-                    start_time=self._final_events[0]
-                    .alternatives[0]
-                    .start_time,
-                    end_time=self._final_events[-1]
-                    .alternatives[0]
-                    .end_time,
-                    confidence=confidence,
-                    text=sentence,
-                )
-            ],
+                    self._session_connected_at = time.time()
+
+                    process_stream_task = asyncio.create_task(process_stream(client, stream))
+                    wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+
+                    try:
+                        done, _ = await asyncio.wait(
+                            [process_stream_task, wait_reconnect_task],
+                            return_when=asyncio.FIRST_COMPLETED,
                         )
+                        for task in done:
+                            if task != wait_reconnect_task:
+                                task.result()
+                        if wait_reconnect_task not in done:
+                            break
+                        self._reconnect_event.clear()
+                    finally:
+                        should_stop.set()
+                        if not process_stream_task.done() and not wait_reconnect_task.done():
+                            # try to gracefully stop the process_stream_task
+                            try:
+                                await asyncio.wait_for(process_stream_task, timeout=1.0)
+                            except asyncio.TimeoutError:
+                                pass
+
+                        await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
+            except DeadlineExceeded:
+                raise APITimeoutError() from None
+            except GoogleAPICallError as e:
+                if e.code == 409:
+                    if audio_pushed:
+                        logger.debug("stream timed out, restarting.")
+                else:
+                    raise APIStatusError(
+                        f"{e.message} {e.details}", status_code=e.code or -1
+                    ) from e
+            except Exception as e:
+                raise APIConnectionError() from e
 
-        self._final_events = []
-        self._event_queue.put_nowait(end_event)
 
-
-
-
-
-
+def _duration_to_seconds(duration: Duration | timedelta) -> float:
+    # Proto Plus may auto-convert Duration to timedelta; handle both.
+    # https://proto-plus-python.readthedocs.io/en/latest/marshal.html
+    if isinstance(duration, timedelta):
+        return duration.total_seconds()
+    return duration.seconds + duration.nanos / 1e9
 
-    async def __anext__(self) -> stt.SpeechEvent:
-        evt = await self._event_queue.get()
-        if evt is None:
-            raise StopAsyncIteration
 
-
+def _get_start_time(word: cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo) -> float:
+    if hasattr(word, "start_offset"):
+        return _duration_to_seconds(word.start_offset)
+    return _duration_to_seconds(word.start_time)
 
 
-def …
-
+def _get_end_time(word: cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo) -> float:
+    if hasattr(word, "end_offset"):
+        return _duration_to_seconds(word.end_offset)
+    return _duration_to_seconds(word.end_time)
+
+
+def _recognize_response_to_speech_event(
+    resp: cloud_speech_v2.RecognizeResponse | cloud_speech_v1.RecognizeResponse,
 ) -> stt.SpeechEvent:
-
-
-
-
-    alternatives …
+    text = ""
+    confidence = 0.0
+    for result in resp.results:
+        text += result.alternatives[0].transcript
+        confidence += result.alternatives[0].confidence
+
+    alternatives = []
+
+    # Google STT may return empty results when spoken_lang != stt_lang
+    if resp.results:
+        try:
+            start_time = _get_start_time(resp.results[0].alternatives[0].words[0])
+            end_time = _get_end_time(resp.results[-1].alternatives[0].words[-1])
+        except IndexError:
+            # When enable_word_time_offsets=False, there are no "words" to access
+            start_time = end_time = 0
+
+        confidence /= len(resp.results)
+        lg = resp.results[0].language_code
+
+        alternatives = [
             stt.SpeechData(
-                language= …
-                start_time= …
-                end_time= …
-                confidence= …
-                text= …
+                language=lg,
+                start_time=start_time,
+                end_time=end_time,
+                confidence=confidence,
+                text=text,
+                words=[
+                    TimedString(
+                        text=word.word,
+                        start_time=_get_start_time(word),
+                        end_time=_get_end_time(word),
+                    )
+                    for word in resp.results[0].alternatives[0].words
+                ]
+                if resp.results[0].alternatives[0].words
+                else None,
+            )
+        ]
+
+    return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=alternatives)
+
+
+@utils.log_exceptions(logger=logger)
+def _streaming_recognize_response_to_speech_data(
+    resp: cloud_speech_v2.StreamingRecognizeResponse | cloud_speech_v1.StreamingRecognizeResponse,
+    *,
+    min_confidence_threshold: float,
+    start_time_offset: float,
+) -> stt.SpeechData | None:
+    text = ""
+    confidence = 0.0
+    final_result = None
+    words: list[cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo] = []
+    for result in resp.results:
+        if len(result.alternatives) == 0:
+            continue
+        else:
+            if result.is_final:
+                final_result = result
+                break
+            else:
+                text += result.alternatives[0].transcript
+                confidence += result.alternatives[0].confidence
+                words.extend(result.alternatives[0].words)
+
+    if final_result is not None:
+        text = final_result.alternatives[0].transcript
+        confidence = final_result.alternatives[0].confidence
+        words = list(final_result.alternatives[0].words)
+        lg = final_result.language_code
+    else:
+        confidence /= len(resp.results)
+        if confidence < min_confidence_threshold:
+            return None
+        lg = resp.results[0].language_code
+
+    if text == "" or not words:
+        if text and not words:
+            data = stt.SpeechData(
+                language=lg,
+                start_time=start_time_offset,
+                end_time=start_time_offset,
+                confidence=confidence,
+                text=text,
             )
-
+-
++            return data
++        return None
++
++    data = stt.SpeechData(
++        language=lg,
++        start_time=_get_start_time(words[0]) + start_time_offset,
++        end_time=_get_end_time(words[-1]) + start_time_offset,
++        confidence=confidence,
++        text=text,
++        words=[
++            TimedString(
++                text=word.word,
++                start_time=_get_start_time(word) + start_time_offset,
++                end_time=_get_end_time(word) + start_time_offset,
++                start_time_offset=start_time_offset,
++                confidence=word.confidence,
++            )
++            for word in words
         ],
     )
 
-
-def streaming_recognize_response_to_speech_data(
-    resp: cloud_speech.StreamingRecognizeResponse,
-) -> List[stt.SpeechData]:
-    result = resp.results[0]
-    gg_alts = result.alternatives
-    return [
-        stt.SpeechData(
-            language=result.language_code,
-            start_time=alt.words[0].start_offset.seconds if alt.words else 0,
-            end_time=alt.words[-1].end_offset.seconds if alt.words else 0,
-            confidence=alt.confidence,
-            text=alt.transcript,
-        )
-        for alt in gg_alts
-    ]
+    return data
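
For orientation, a minimal usage sketch of the rewritten STT plugin, based only on the constructor and update_options signatures visible in the diff above. The import path and variable names are assumptions and not part of this diff; per the version property shown above, models listed in SpeechModelsV2 are routed to the v2 SpeechAsyncClient and all other models to v1.

# Hypothetical usage sketch; the import path is assumed, not taken from this diff.
from livekit.plugins import google

stt_engine = google.STT(
    model="latest_long",            # models in SpeechModelsV2 select the v2 client, others v1
    languages="en-US",
    sample_rate=16000,
    min_confidence_threshold=0.65,  # interim results below this confidence are discarded
    keywords=[("LiveKit", 15.0)],   # phrase adaptation: (keyword, boost) pairs
)

# Options can be changed at runtime; open streams receive the update and reconnect.
stt_engine.update_options(languages=["en-US", "hi-IN"], detect_language=True)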