livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -15,35 +15,82 @@
  from __future__ import annotations

  import asyncio
- import contextlib
  import dataclasses
- import logging
+ import time
+ import weakref
+ from collections.abc import AsyncGenerator, AsyncIterable
  from dataclasses import dataclass
- from typing import Any, AsyncIterable, Dict, List
+ from datetime import timedelta
+ from typing import Callable, Union, cast

- from livekit import agents, rtc
- from livekit.agents import stt
- from livekit.agents.utils import AudioBuffer
-
- from google.auth import credentials  # type: ignore
+ from google.api_core.client_options import ClientOptions
+ from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+ from google.auth import default as gauth_default
+ from google.auth.exceptions import DefaultCredentialsError
  from google.cloud.speech_v2 import SpeechAsyncClient
  from google.cloud.speech_v2.types import cloud_speech
-
+ from google.protobuf.duration_pb2 import Duration
+ from livekit import rtc
+ from livekit.agents import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     APIConnectionError,
+     APIConnectOptions,
+     APIStatusError,
+     APITimeoutError,
+     stt,
+     utils,
+ )
+ from livekit.agents.types import (
+     NOT_GIVEN,
+     NotGivenOr,
+ )
+ from livekit.agents.utils import is_given
+
+ from .log import logger
  from .models import SpeechLanguages, SpeechModels

- LgType = SpeechLanguages | str
- LanguageCode = LgType | List[LgType]
+ LgType = Union[SpeechLanguages, str]
+ LanguageCode = Union[LgType, list[LgType]]
+
+ # Google STT has a timeout of 5 mins; we'll attempt to restart the session
+ # before that timeout is reached
+ _max_session_duration = 240
+
+ # Google is very sensitive to background noise, so we'll ignore results with low confidence
+ _default_min_confidence = 0.65


  # This class is only used internally to encapsulate the options
  @dataclass
  class STTOptions:
-     languages: List[LgType]
+     languages: list[LgType]
      detect_language: bool
      interim_results: bool
      punctuate: bool
      spoken_punctuation: bool
-     model: SpeechModels
+     enable_word_time_offsets: bool
+     enable_word_confidence: bool
+     enable_voice_activity_events: bool
+     model: SpeechModels | str
+     sample_rate: int
+     min_confidence_threshold: float
+     keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
+
+     def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
+         if is_given(self.keywords):
+             return cloud_speech.SpeechAdaptation(
+                 phrase_sets=[
+                     cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
+                         inline_phrase_set=cloud_speech.PhraseSet(
+                             phrases=[
+                                 cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
+                                 for keyword, boost in self.keywords
+                             ]
+                         )
+                     )
+                 ]
+             )
+         return None


  class STT(stt.STT):
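
A note on the keyword-boosting option introduced above: STTOptions.build_adaptation() wraps the (keyword, boost) pairs in an inline phrase set for the Speech-to-Text v2 request. A minimal sketch of how it might be exercised (the option values below are illustrative, not taken from the package):

    # Illustrative only; boost values in the v2 API typically range from 0 to 20.
    opts = STTOptions(
        languages=["en-US"],
        detect_language=False,
        interim_results=True,
        punctuate=True,
        spoken_punctuation=False,
        enable_word_time_offsets=True,
        enable_word_confidence=False,
        enable_voice_activity_events=False,
        model="latest_long",
        sample_rate=16000,
        min_confidence_threshold=_default_min_confidence,
        keywords=[("LiveKit", 10.0)],
    )
    adaptation = opts.build_adaptation()  # cloud_speech.SpeechAdaptation with one inline PhraseSet
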
@@ -54,23 +101,64 @@ class STT(stt.STT):
          detect_language: bool = True,
          interim_results: bool = True,
          punctuate: bool = True,
-         spoken_punctuation: bool = True,
-         model: SpeechModels = "long",
-         credentials_info: Dict[str, Any] | None = None,
-         credentials_file: str | None = None,
+         spoken_punctuation: bool = False,
+         enable_word_time_offsets: bool = True,
+         enable_word_confidence: bool = False,
+         enable_voice_activity_events: bool = False,
+         model: SpeechModels | str = "latest_long",
+         location: str = "global",
+         sample_rate: int = 16000,
+         min_confidence_threshold: float = _default_min_confidence,
+         credentials_info: NotGivenOr[dict] = NOT_GIVEN,
+         credentials_file: NotGivenOr[str] = NOT_GIVEN,
+         keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+         use_streaming: NotGivenOr[bool] = NOT_GIVEN,
      ):
          """
-         if no credentials is provided, it will use the credentials on the environment
-         GOOGLE_APPLICATION_CREDENTIALS (Default behavior of Google SpeechAsyncClient)
+         Create a new instance of Google STT.
+
+         Credentials must be provided, either via the ``credentials_info`` dict, via the file
+         specified in ``credentials_file``, or via Application Default Credentials as
+         described in https://cloud.google.com/docs/authentication/application-default-credentials
+
+         args:
+             languages(LanguageCode): list of language codes to recognize (default: "en-US")
+             detect_language(bool): whether to detect the language of the audio (default: True)
+             interim_results(bool): whether to return interim results (default: True)
+             punctuate(bool): whether to punctuate the audio (default: True)
+             spoken_punctuation(bool): whether to use spoken punctuation (default: False)
+             enable_word_time_offsets(bool): whether to enable word time offsets (default: True)
+             enable_word_confidence(bool): whether to enable word confidence (default: False)
+             enable_voice_activity_events(bool): whether to enable voice activity events (default: False)
+             model(SpeechModels): the model to use for recognition (default: "latest_long")
+             location(str): the location to use for recognition (default: "global")
+             sample_rate(int): the sample rate of the audio (default: 16000)
+             min_confidence_threshold(float): minimum confidence threshold for recognition
+                 (default: 0.65)
+             credentials_info(dict): the credentials info to use for recognition (default: None)
+             credentials_file(str): the credentials file to use for recognition (default: None)
+             keywords(list[tuple[str, float]]): list of keywords to recognize (default: None)
+             use_streaming(bool): whether to use streaming for recognition (default: True)
          """
-         super().__init__(streaming_supported=True)
+         if not is_given(use_streaming):
+             use_streaming = True
+         super().__init__(
+             capabilities=stt.STTCapabilities(streaming=use_streaming, interim_results=True)
+         )

-         if credentials_info:
-             self._client = SpeechAsyncClient.from_service_account_info(credentials_info)
-         elif credentials_file:
-             self._client = SpeechAsyncClient.from_service_account_file(credentials_file)
-         else:
-             self._client = SpeechAsyncClient()
+         self._location = location
+         self._credentials_info = credentials_info
+         self._credentials_file = credentials_file
+
+         if not is_given(credentials_file) and not is_given(credentials_info):
+             try:
+                 gauth_default()  # type: ignore
+             except DefaultCredentialsError:
+                 raise ValueError(
+                     "Application default credentials must be available "
+                     "when using Google STT without explicitly passing "
+                     "credentials through credentials_info or credentials_file."
+                 ) from None

          if isinstance(languages, str):
              languages = [languages]
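
The constructor now fails fast on missing credentials: when neither credentials_info nor credentials_file is given, it probes Application Default Credentials and raises ValueError if none are found. A hedged usage sketch (argument values are illustrative; assumes ADC is configured, e.g. via GOOGLE_APPLICATION_CREDENTIALS):

    google_stt = STT(
        model="latest_long",
        location="global",
        sample_rate=16000,
        min_confidence_threshold=0.65,
        keywords=[("LiveKit", 10.0)],
    )
    # stream() is keyword-only and, per the next hunk, now also accepts conn_options
    stream = google_stt.stream(language="en-US")
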
@@ -81,322 +169,480 @@ class STT(stt.STT):
              interim_results=interim_results,
              punctuate=punctuate,
              spoken_punctuation=spoken_punctuation,
+             enable_word_time_offsets=enable_word_time_offsets,
+             enable_word_confidence=enable_word_confidence,
+             enable_voice_activity_events=enable_voice_activity_events,
              model=model,
+             sample_rate=sample_rate,
+             min_confidence_threshold=min_confidence_threshold,
+             keywords=keywords,
          )
-         self._creds = self._client.transport._credentials
+         self._streams = weakref.WeakSet[SpeechStream]()
+         self._pool = utils.ConnectionPool[SpeechAsyncClient](
+             max_session_duration=_max_session_duration,
+             connect_cb=self._create_client,
+         )
+
+     @property
+     def model(self) -> str:
+         return self._config.model

      @property
-     def _recognizer(self) -> str:
+     def provider(self) -> str:
+         return "Google Cloud Platform"
+
+     async def _create_client(self, timeout: float) -> SpeechAsyncClient:
+         # Add support for passing a specific location that matches recognizer
+         # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+         # TODO(long): how to set timeout?
+         client_options = None
+         client: SpeechAsyncClient | None = None
+         if self._location != "global":
+             client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
+         if is_given(self._credentials_info):
+             client = SpeechAsyncClient.from_service_account_info(
+                 self._credentials_info, client_options=client_options
+             )
+         elif is_given(self._credentials_file):
+             client = SpeechAsyncClient.from_service_account_file(
+                 self._credentials_file, client_options=client_options
+             )
+         else:
+             client = SpeechAsyncClient(client_options=client_options)
+         assert client is not None
+         return client
+
+     def _get_recognizer(self, client: SpeechAsyncClient) -> str:
          # TODO(theomonnom): should we use recognizers?
-         # Recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
-         return f"projects/{self._creds.project_id}/locations/global/recognizers/_"  # type: ignore
+         # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers

-     def _sanitize_options(
-         self,
-         *,
-         language: str | None = None,
-     ) -> STTOptions:
+         # TODO(theomonnom): find a better way to access the project_id
+         try:
+             project_id = client.transport._credentials.project_id  # type: ignore
+         except AttributeError:
+             from google.auth import default as ga_default
+
+             _, project_id = ga_default()  # type: ignore
+         return f"projects/{project_id}/locations/{self._location}/recognizers/_"
+
+     def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
          config = dataclasses.replace(self._config)

-         if language:
+         if is_given(language):
              config.languages = [language]

          if not isinstance(config.languages, list):
              config.languages = [config.languages]
          elif not config.detect_language:
              if len(config.languages) > 1:
-                 logging.warning(
-                     "multiple languages provided, but language detection is disabled"
-                 )
+                 logger.warning("multiple languages provided, but language detection is disabled")
              config.languages = [config.languages[0]]

          return config

-     async def recognize(
+     async def _recognize_impl(
          self,
+         buffer: utils.AudioBuffer,
          *,
-         buffer: AudioBuffer,
-         language: SpeechLanguages | str | None = None,
+         language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+         conn_options: APIConnectOptions,
      ) -> stt.SpeechEvent:
          config = self._sanitize_options(language=language)
-         buffer = agents.utils.merge_frames(buffer)
+         frame = rtc.combine_audio_frames(buffer)

          config = cloud_speech.RecognitionConfig(
              explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
                  encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                 sample_rate_hertz=buffer.sample_rate,
-                 audio_channel_count=buffer.num_channels,
+                 sample_rate_hertz=frame.sample_rate,
+                 audio_channel_count=frame.num_channels,
              ),
+             adaptation=config.build_adaptation(),
              features=cloud_speech.RecognitionFeatures(
                  enable_automatic_punctuation=config.punctuate,
                  enable_spoken_punctuation=config.spoken_punctuation,
+                 enable_word_time_offsets=config.enable_word_time_offsets,
+                 enable_word_confidence=config.enable_word_confidence,
              ),
              model=config.model,
              language_codes=config.languages,
          )

-         return recognize_response_to_speech_event(
-             await self._client.recognize(
-                 cloud_speech.RecognizeRequest(
-                     recognizer=self._recognizer,
-                     config=config,
-                     content=buffer.data.tobytes(),
+         try:
+             async with self._pool.connection(timeout=conn_options.timeout) as client:
+                 raw = await client.recognize(
+                     cloud_speech.RecognizeRequest(
+                         recognizer=self._get_recognizer(client),
+                         config=config,
+                         content=frame.data.tobytes(),
+                     ),
+                     timeout=conn_options.timeout,
                  )
-             )
-         )
+
+                 return _recognize_response_to_speech_event(raw)
+         except DeadlineExceeded:
+             raise APITimeoutError() from None
+         except GoogleAPICallError as e:
+             raise APIStatusError(f"{e.message} {e.details}", status_code=e.code or -1) from e
+         except Exception as e:
+             raise APIConnectionError() from e

      def stream(
          self,
          *,
-         language: SpeechLanguages | str | None = None,
-     ) -> "SpeechStream":
+         language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+     ) -> SpeechStream:
          config = self._sanitize_options(language=language)
-         return SpeechStream(
-             self._client,
-             self._creds,
-             self._recognizer,
-             config,
+         stream = SpeechStream(
+             stt=self,
+             pool=self._pool,
+             recognizer_cb=self._get_recognizer,
+             config=config,
+             conn_options=conn_options,
          )
+         self._streams.add(stream)
+         return stream
+
+     def update_options(
+         self,
+         *,
+         languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+         detect_language: NotGivenOr[bool] = NOT_GIVEN,
+         interim_results: NotGivenOr[bool] = NOT_GIVEN,
+         punctuate: NotGivenOr[bool] = NOT_GIVEN,
+         spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+         model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+         location: NotGivenOr[str] = NOT_GIVEN,
+         keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+     ) -> None:
+         if is_given(languages):
+             if isinstance(languages, str):
+                 languages = [languages]
+             self._config.languages = cast(list[LgType], languages)
+         if is_given(detect_language):
+             self._config.detect_language = detect_language
+         if is_given(interim_results):
+             self._config.interim_results = interim_results
+         if is_given(punctuate):
+             self._config.punctuate = punctuate
+         if is_given(spoken_punctuation):
+             self._config.spoken_punctuation = spoken_punctuation
+         if is_given(model):
+             self._config.model = model
+         if is_given(location):
+             self._location = location
+             # if location is changed, fetch a new client and recognizer as per the new location
+             self._pool.invalidate()
+         if is_given(keywords):
+             self._config.keywords = keywords
+
+         for stream in self._streams:
+             stream.update_options(
+                 languages=languages,
+                 detect_language=detect_language,
+                 interim_results=interim_results,
+                 punctuate=punctuate,
+                 spoken_punctuation=spoken_punctuation,
+                 model=model,
+                 keywords=keywords,
+             )
+
+     async def aclose(self) -> None:
+         await self._pool.aclose()
+         await super().aclose()


  class SpeechStream(stt.SpeechStream):
      def __init__(
          self,
-         client: SpeechAsyncClient,
-         creds: credentials.Credentials,
-         recognizer: str,
+         *,
+         stt: STT,
+         conn_options: APIConnectOptions,
+         pool: utils.ConnectionPool[SpeechAsyncClient],
+         recognizer_cb: Callable[[SpeechAsyncClient], str],
          config: STTOptions,
-         sample_rate: int = 24000,
-         num_channels: int = 1,
-         max_retry: int = 32,
      ) -> None:
-         super().__init__()
+         super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)

-         self._client = client
-         self._creds = creds
-         self._recognizer = recognizer
+         self._pool = pool
+         self._recognizer_cb = recognizer_cb
          self._config = config
-         self._sample_rate = sample_rate
-         self._num_channels = num_channels
-
-         self._queue = asyncio.Queue[rtc.AudioFrame | None]()
-         self._event_queue = asyncio.Queue[stt.SpeechEvent | None]()
-         self._closed = False
-         self._main_task = asyncio.create_task(self._run(max_retry=max_retry))
-
-         self._final_events: List[stt.SpeechEvent] = []
-         self._speaking = False
-
-         self._streaming_config = cloud_speech.StreamingRecognitionConfig(
-             config=cloud_speech.RecognitionConfig(
-                 explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
-                     encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                     sample_rate_hertz=self._sample_rate,
-                     audio_channel_count=self._num_channels,
-                 ),
-                 language_codes=self._config.languages,
-                 model=self._config.model,
-                 features=cloud_speech.RecognitionFeatures(
-                     enable_automatic_punctuation=self._config.punctuate,
-                 ),
-             ),
-             streaming_features=cloud_speech.StreamingRecognitionFeatures(
-                 enable_voice_activity_events=True,
-                 interim_results=self._config.interim_results,
-             ),
-         )
+         self._reconnect_event = asyncio.Event()
+         self._session_connected_at: float = 0

-         def log_exception(task: asyncio.Task) -> None:
-             if not task.cancelled() and task.exception():
-                 logging.error(f"google stt task failed: {task.exception()}")
-
-         self._main_task.add_done_callback(log_exception)
-
-     def push_frame(self, frame: rtc.AudioFrame) -> None:
-         if self._closed:
-             raise ValueError("cannot push frame to closed stream")
-
-         self._queue.put_nowait(frame)
-
-     async def aclose(self, wait: bool = True) -> None:
-         self._closed = True
-         if not wait:
-             self._main_task.cancel()
-
-         self._queue.put_nowait(None)
-         with contextlib.suppress(asyncio.CancelledError):
-             await self._main_task
+     def update_options(
+         self,
+         *,
+         languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
+         detect_language: NotGivenOr[bool] = NOT_GIVEN,
+         interim_results: NotGivenOr[bool] = NOT_GIVEN,
+         punctuate: NotGivenOr[bool] = NOT_GIVEN,
+         spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
+         model: NotGivenOr[SpeechModels] = NOT_GIVEN,
+         min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
+         keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
+     ) -> None:
+         if is_given(languages):
+             if isinstance(languages, str):
+                 languages = [languages]
+             self._config.languages = cast(list[LgType], languages)
+         if is_given(detect_language):
+             self._config.detect_language = detect_language
+         if is_given(interim_results):
+             self._config.interim_results = interim_results
+         if is_given(punctuate):
+             self._config.punctuate = punctuate
+         if is_given(spoken_punctuation):
+             self._config.spoken_punctuation = spoken_punctuation
+         if is_given(model):
+             self._config.model = model
+         if is_given(min_confidence_threshold):
+             self._config.min_confidence_threshold = min_confidence_threshold
+         if is_given(keywords):
+             self._config.keywords = keywords
+
+         self._reconnect_event.set()
+
+     async def _run(self) -> None:
+         audio_pushed = False
+
+         # google requires an async generator when calling streaming_recognize;
+         # this function basically converts the input channel into an async generator
+         async def input_generator(
+             client: SpeechAsyncClient, should_stop: asyncio.Event
+         ) -> AsyncGenerator[cloud_speech.StreamingRecognizeRequest, None]:
+             nonlocal audio_pushed
+             try:
+                 # first request should contain the config
+                 yield cloud_speech.StreamingRecognizeRequest(
+                     recognizer=self._recognizer_cb(client),
+                     streaming_config=self._streaming_config,
+                 )

-     async def _run(self, max_retry: int) -> None:
-         retry_count = 0
-         try:
-             while not self._closed:
-                 try:
-                     # google requires a async generator when calling streaming_recognize
-                     # this function basically convert the queue into a async generator
-                     async def input_generator():
-                         try:
-                             # first request should contain the config
-                             yield cloud_speech.StreamingRecognizeRequest(
-                                 recognizer=self._recognizer,
-                                 streaming_config=self._streaming_config,
+                 async for frame in self._input_ch:
+                     # when the stream is aborted due to reconnect, this input_generator
+                     # needs to stop consuming frames
+                     # when the generator stops, the previous gRPC stream will close
+                     if should_stop.is_set():
+                         return
+
+                     if isinstance(frame, rtc.AudioFrame):
+                         yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())
+                         if not audio_pushed:
+                             audio_pushed = True
+
+             except Exception:
+                 logger.exception("an error occurred while streaming input to google STT")
+
+         async def process_stream(
+             client: SpeechAsyncClient,
+             stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse],
+         ) -> None:
+             has_started = False
+             async for resp in stream:
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                 ):
+                     self._event_ch.send_nowait(
+                         stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                     )
+                     has_started = True
+
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED  # noqa: E501
+                 ):
+                     result = resp.results[0]
+                     speech_data = _streaming_recognize_response_to_speech_data(
+                         resp,
+                         min_confidence_threshold=self._config.min_confidence_threshold,
+                     )
+                     if speech_data is None:
+                         continue
+
+                     if not result.is_final:
+                         self._event_ch.send_nowait(
+                             stt.SpeechEvent(
+                                 type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                                 alternatives=[speech_data],
                              )
-                             while True:
-                                 frame = (
-                                     await self._queue.get()
-                                 )  # wait for a new rtc.AudioFrame
-                                 if frame is None:
-                                     break  # None is sent inside aclose
-
-                                 self._queue.task_done()
-                                 frame = frame.remix_and_resample(
-                                     self._sample_rate, self._num_channels
-                                 )
-                                 yield cloud_speech.StreamingRecognizeRequest(
-                                     audio=frame.data.tobytes(),
-                                 )
-                         except Exception as e:
-                             logging.error(
-                                 f"an error occurred while streaming inputs: {e}"
+                         )
+                     else:
+                         self._event_ch.send_nowait(
+                             stt.SpeechEvent(
+                                 type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                                 alternatives=[speech_data],
                              )
-
-                     # try to connect
-                     stream = await self._client.streaming_recognize(
-                         requests=input_generator()
-                     )
-                     retry_count = 0  # connection successful, reset retry count
-
-                     await self._run_stream(stream)
-                 except Exception as e:
-                     if retry_count >= max_retry:
-                         logging.error(
-                             f"failed to connect to google stt after {max_retry} tries",
-                             exc_info=e,
                          )
-                         break
-
-                     retry_delay = min(retry_count * 2, 10)  # max 10s
-                     retry_count += 1
-                     logging.warning(
-                         f"google stt connection failed, retrying in {retry_delay}s",
-                         exc_info=e,
+                         if time.time() - self._session_connected_at > _max_session_duration:
+                             logger.debug(
+                                 "Google STT maximum connection time reached. Reconnecting..."
+                             )
+                             self._pool.remove(client)
+                             if has_started:
+                                 self._event_ch.send_nowait(
+                                     stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
+                                 )
+                                 has_started = False
+                             self._reconnect_event.set()
+                             return
+
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                 ):
+                     self._event_ch.send_nowait(
+                         stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                      )
-                     await asyncio.sleep(retry_delay)
-         finally:
-             self._event_queue.put_nowait(None)
-
-     async def _run_stream(
-         self, stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse]
-     ):
-         async for resp in stream:
-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
-             ):
-                 self._speaking = True
-                 start_event = stt.SpeechEvent(
-                     type=stt.SpeechEventType.START_OF_SPEECH,
-                 )
-                 self._event_queue.put_nowait(start_event)
-
-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
-             ):
-                 result = resp.results[0]
-                 if not result.is_final:
-                     # interim results
-                     iterim_event = stt.SpeechEvent(
-                         type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                         alternatives=streaming_recognize_response_to_speech_data(resp),
+                     has_started = False
+
+         while True:
+             audio_pushed = False
+             try:
+                 async with self._pool.connection(timeout=self._conn_options.timeout) as client:
+                     self._streaming_config = cloud_speech.StreamingRecognitionConfig(
+                         config=cloud_speech.RecognitionConfig(
+                             explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                                 encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                                 sample_rate_hertz=self._config.sample_rate,
+                                 audio_channel_count=1,
+                             ),
+                             adaptation=self._config.build_adaptation(),
+                             language_codes=self._config.languages,
+                             model=self._config.model,
+                             features=cloud_speech.RecognitionFeatures(
+                                 enable_automatic_punctuation=self._config.punctuate,
+                                 enable_word_time_offsets=self._config.enable_word_time_offsets,
+                                 enable_spoken_punctuation=self._config.spoken_punctuation,
+                             ),
+                         ),
+                         streaming_features=cloud_speech.StreamingRecognitionFeatures(
+                             interim_results=self._config.interim_results,
+                             enable_voice_activity_events=self._config.enable_voice_activity_events,
+                         ),
                      )
-                     self._event_queue.put_nowait(iterim_event)

-                 else:
-                     final_event = stt.SpeechEvent(
-                         type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                         alternatives=streaming_recognize_response_to_speech_data(resp),
+                     should_stop = asyncio.Event()
+                     stream = await client.streaming_recognize(
+                         requests=input_generator(client, should_stop),
                      )
-                     self._final_events.append(final_event)
-                     self._event_queue.put_nowait(final_event)
-
-                     if not self._speaking:
-                         # With Google STT, we receive the final event after the END_OF_SPEECH event
-                         sentence = ""
-                         confidence = 0.0
-                         for alt in self._final_events:
-                             sentence += f"{alt.alternatives[0].text.strip()} "
-                             confidence += alt.alternatives[0].confidence
-
-                         sentence = sentence.rstrip()
-                         confidence /= len(self._final_events)  # avg. of confidence
-
-                         end_event = stt.SpeechEvent(
-                             type=stt.SpeechEventType.END_OF_SPEECH,
-                             alternatives=[
-                                 stt.SpeechData(
-                                     language=result.language_code,
-                                     start_time=self._final_events[0]
-                                     .alternatives[0]
-                                     .start_time,
-                                     end_time=self._final_events[-1]
-                                     .alternatives[0]
-                                     .end_time,
-                                     confidence=confidence,
-                                     text=sentence,
-                                 )
-                             ],
-                         )
+                     self._session_connected_at = time.time()

-                         self._final_events = []
-                         self._event_queue.put_nowait(end_event)
+                     process_stream_task = asyncio.create_task(process_stream(client, stream))
+                     wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())

-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
-             ):
-                 self._speaking = False
+                     try:
+                         done, _ = await asyncio.wait(
+                             [process_stream_task, wait_reconnect_task],
+                             return_when=asyncio.FIRST_COMPLETED,
+                         )
+                         for task in done:
+                             if task != wait_reconnect_task:
+                                 task.result()
+                         if wait_reconnect_task not in done:
+                             break
+                         self._reconnect_event.clear()
+                     finally:
+                         should_stop.set()
+                         if not process_stream_task.done() and not wait_reconnect_task.done():
+                             # try to gracefully stop the process_stream_task
+                             try:
+                                 await asyncio.wait_for(process_stream_task, timeout=1.0)
+                             except asyncio.TimeoutError:
+                                 pass
+
+                         await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
+             except DeadlineExceeded:
+                 raise APITimeoutError() from None
+             except GoogleAPICallError as e:
+                 if e.code == 409:
+                     if audio_pushed:
+                         logger.debug("stream timed out, restarting.")
+                 else:
+                     raise APIStatusError(
+                         f"{e.message} {e.details}", status_code=e.code or -1
+                     ) from e
+             except Exception as e:
+                 raise APIConnectionError() from e

-     async def __anext__(self) -> stt.SpeechEvent:
-         evt = await self._event_queue.get()
-         if evt is None:
-             raise StopAsyncIteration

-         return evt
+ def _duration_to_seconds(duration: Duration | timedelta) -> float:
+     # Proto Plus may auto-convert Duration to timedelta; handle both.
+     # https://proto-plus-python.readthedocs.io/en/latest/marshal.html
+     if isinstance(duration, timedelta):
+         return duration.total_seconds()
+     return duration.seconds + duration.nanos / 1e9


- def recognize_response_to_speech_event(
+ def _recognize_response_to_speech_event(
      resp: cloud_speech.RecognizeResponse,
  ) -> stt.SpeechEvent:
-     result = resp.results[0]
-     gg_alts = result.alternatives
-     return stt.SpeechEvent(
-         type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-         alternatives=[
+     text = ""
+     confidence = 0.0
+     for result in resp.results:
+         text += result.alternatives[0].transcript
+         confidence += result.alternatives[0].confidence
+
+     alternatives = []
+
+     # Google STT may return empty results when spoken_lang != stt_lang
+     if resp.results:
+         try:
+             start_time = _duration_to_seconds(resp.results[0].alternatives[0].words[0].start_offset)
+             end_time = _duration_to_seconds(resp.results[-1].alternatives[0].words[-1].end_offset)
+         except IndexError:
+             # When enable_word_time_offsets=False, there are no "words" to access
+             start_time = end_time = 0
+
+         confidence /= len(resp.results)
+         lg = resp.results[0].language_code
+
+         alternatives = [
              stt.SpeechData(
-                 language=result.language_code,
-                 start_time=alt.words[0].start_offset.seconds if alt.words else 0,
-                 end_time=alt.words[-1].end_offset.seconds if alt.words else 0,
-                 confidence=alt.confidence,
-                 text=alt.transcript,
+                 language=lg,
+                 start_time=start_time,
+                 end_time=end_time,
+                 confidence=confidence,
+                 text=text,
              )
-             for alt in gg_alts
-         ],
-     )
+         ]
+
+     return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=alternatives)


- def streaming_recognize_response_to_speech_data(
+ def _streaming_recognize_response_to_speech_data(
      resp: cloud_speech.StreamingRecognizeResponse,
- ) -> List[stt.SpeechData]:
-     result = resp.results[0]
-     gg_alts = result.alternatives
-     return [
-         stt.SpeechData(
-             language=result.language_code,
-             start_time=alt.words[0].start_offset.seconds if alt.words else 0,
-             end_time=alt.words[-1].end_offset.seconds if alt.words else 0,
-             confidence=alt.confidence,
-             text=alt.transcript,
-         )
-         for alt in gg_alts
-     ]
+     *,
+     min_confidence_threshold: float,
+ ) -> stt.SpeechData | None:
+     text = ""
+     confidence = 0.0
+     final_result = None
+     for result in resp.results:
+         if len(result.alternatives) == 0:
+             continue
+         else:
+             if result.is_final:
+                 final_result = result
+                 break
+             else:
+                 text += result.alternatives[0].transcript
+                 confidence += result.alternatives[0].confidence
+
+     if final_result is not None:
+         text = final_result.alternatives[0].transcript
+         confidence = final_result.alternatives[0].confidence
+         lg = final_result.language_code
+     else:
+         confidence /= len(resp.results)
+         if confidence < min_confidence_threshold:
+             return None
+         lg = resp.results[0].language_code
+
+     if text == "":
+         return None
+
+     data = stt.SpeechData(language=lg, start_time=0, end_time=0, confidence=confidence, text=text)
+
+     return data
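
For the _duration_to_seconds helper added above, a quick illustrative check of the dual Duration/timedelta handling (not part of the package; it only uses names already imported at the top of the file):

    from datetime import timedelta
    from google.protobuf.duration_pb2 import Duration

    # Proto Plus may hand back either representation; both yield the same float.
    assert _duration_to_seconds(timedelta(seconds=1, microseconds=500_000)) == 1.5
    assert _duration_to_seconds(Duration(seconds=1, nanos=500_000_000)) == 1.5

Note also the design of _streaming_recognize_response_to_speech_data: non-final results whose averaged confidence falls below min_confidence_threshold (0.65 by default) yield None, so the stream emits no event for low-confidence interim transcripts, which is how the plugin filters background noise.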