livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff compares publicly available package versions as they appear in their respective public registries and is provided for informational purposes only.
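Before the diff itself, a minimal sketch of how the 1.x constructor shown below might be called. The ``from livekit.plugins import google`` import path and the boost value are assumptions; the parameter names and defaults are taken from the diff.

```python
from livekit.plugins import google  # import path assumed from the package name

# Defaults shown here mirror the 1.x signature in the diff below.
stt = google.STT(
    model="latest_long",            # new default model (was "long" in 0.3.0)
    languages="en-US",
    spoken_punctuation=False,       # default flipped from True to False
    min_confidence_threshold=0.65,  # results below this confidence are ignored
    keywords=[("LiveKit", 15.0)],   # (phrase, boost) pairs for speech adaptation
)
```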
@@ -15,35 +15,103 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import asyncio
18
- import contextlib
19
18
  import dataclasses
20
- import logging
19
+ import time
20
+ import weakref
21
+ from collections.abc import AsyncGenerator, AsyncIterable
21
22
  from dataclasses import dataclass
22
- from typing import Any, AsyncIterable, Dict, List
23
-
24
- from livekit import agents, rtc
25
- from livekit.agents import stt
26
- from livekit.agents.utils import AudioBuffer
27
-
28
- from google.auth import credentials # type: ignore
29
- from google.cloud.speech_v2 import SpeechAsyncClient
30
- from google.cloud.speech_v2.types import cloud_speech
31
-
32
- from .models import SpeechLanguages, SpeechModels
33
-
34
- LgType = SpeechLanguages | str
35
- LanguageCode = LgType | List[LgType]
23
+ from datetime import timedelta
24
+ from typing import Callable, Union, cast, get_args
25
+
26
+ from google.api_core.client_options import ClientOptions
27
+ from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
28
+ from google.auth import default as gauth_default
29
+ from google.auth.exceptions import DefaultCredentialsError
30
+ from google.cloud.speech_v1 import SpeechAsyncClient as SpeechAsyncClientV1
31
+ from google.cloud.speech_v1.types import cloud_speech as cloud_speech_v1, resource as resource_v1
32
+ from google.cloud.speech_v2 import SpeechAsyncClient as SpeechAsyncClientV2
33
+ from google.cloud.speech_v2.types import cloud_speech as cloud_speech_v2
34
+ from google.protobuf.duration_pb2 import Duration
35
+ from livekit import rtc
36
+ from livekit.agents import (
37
+ DEFAULT_API_CONNECT_OPTIONS,
38
+ APIConnectionError,
39
+ APIConnectOptions,
40
+ APIStatusError,
41
+ APITimeoutError,
42
+ stt,
43
+ utils,
44
+ )
45
+ from livekit.agents.types import (
46
+ NOT_GIVEN,
47
+ NotGivenOr,
48
+ )
49
+ from livekit.agents.utils import is_given
50
+ from livekit.agents.voice.io import TimedString
51
+
52
+ from .log import logger
53
+ from .models import SpeechLanguages, SpeechModels, SpeechModelsV2
54
+
55
+ LgType = Union[SpeechLanguages, str]
56
+ LanguageCode = Union[LgType, list[LgType]]
57
+
58
+ # Google STT has a timeout of 5 mins, so we'll attempt to restart the session
59
+ # before that timeout is reached
60
+ _max_session_duration = 240
61
+
62
+ # Google is very sensitive to background noise, so we'll ignore results with low confidence
63
+ _default_min_confidence = 0.65
36
64
 
37
65
 
38
66
  # This class is only used internally to encapsulate the options
39
67
  @dataclass
40
68
  class STTOptions:
41
- languages: List[LgType]
69
+ languages: list[LgType]
42
70
  detect_language: bool
43
71
  interim_results: bool
44
72
  punctuate: bool
45
73
  spoken_punctuation: bool
46
- model: SpeechModels
74
+ enable_word_time_offsets: bool
75
+ enable_word_confidence: bool
76
+ enable_voice_activity_events: bool
77
+ model: SpeechModels | str
78
+ sample_rate: int
79
+ min_confidence_threshold: float
80
+ keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN
81
+
82
+ @property
83
+ def version(self) -> int:
84
+ return 2 if self.model in get_args(SpeechModelsV2) else 1
85
+
86
+ def build_adaptation(
87
+ self,
88
+ ) -> cloud_speech_v2.SpeechAdaptation | resource_v1.SpeechAdaptation | None:
89
+ if is_given(self.keywords):
90
+ if self.version == 2:
91
+ return cloud_speech_v2.SpeechAdaptation(
92
+ phrase_sets=[
93
+ cloud_speech_v2.SpeechAdaptation.AdaptationPhraseSet(
94
+ inline_phrase_set=cloud_speech_v2.PhraseSet(
95
+ phrases=[
96
+ cloud_speech_v2.PhraseSet.Phrase(value=keyword, boost=boost)
97
+ for keyword, boost in self.keywords
98
+ ]
99
+ )
100
+ )
101
+ ]
102
+ )
103
+ return resource_v1.SpeechAdaptation(
104
+ phrase_sets=[
105
+ resource_v1.PhraseSet(
106
+ name="keywords",
107
+ phrases=[
108
+ resource_v1.PhraseSet.Phrase(value=keyword, boost=boost)
109
+ for keyword, boost in self.keywords
110
+ ],
111
+ )
112
+ ]
113
+ )
114
+ return None
47
115
 
48
116
 
49
117
  class STT(stt.STT):
@@ -54,23 +122,80 @@ class STT(stt.STT):
54
122
  detect_language: bool = True,
55
123
  interim_results: bool = True,
56
124
  punctuate: bool = True,
57
- spoken_punctuation: bool = True,
58
- model: SpeechModels = "long",
59
- credentials_info: Dict[str, Any] | None = None,
60
- credentials_file: str | None = None,
125
+ spoken_punctuation: bool = False,
126
+ enable_word_time_offsets: NotGivenOr[bool] = NOT_GIVEN,
127
+ enable_word_confidence: bool = False,
128
+ enable_voice_activity_events: bool = False,
129
+ model: SpeechModels | str = "latest_long",
130
+ location: str = "global",
131
+ sample_rate: int = 16000,
132
+ min_confidence_threshold: float = _default_min_confidence,
133
+ credentials_info: NotGivenOr[dict] = NOT_GIVEN,
134
+ credentials_file: NotGivenOr[str] = NOT_GIVEN,
135
+ keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
136
+ use_streaming: NotGivenOr[bool] = NOT_GIVEN,
61
137
  ):
62
138
  """
63
- if no credentials is provided, it will use the credentials on the environment
64
- GOOGLE_APPLICATION_CREDENTIALS (Default behavior of Google SpeechAsyncClient)
139
+ Create a new instance of Google STT.
140
+
141
+ Credentials must be provided either via the ``credentials_info`` dict, the file
142
+ specified in ``credentials_file``, or Application Default Credentials as described
143
+ in https://cloud.google.com/docs/authentication/application-default-credentials
144
+
145
+ args:
146
+ languages(LanguageCode): list of language codes to recognize (default: "en-US")
147
+ detect_language(bool): whether to detect the language of the audio (default: True)
148
+ interim_results(bool): whether to return interim results (default: True)
149
+ punctuate(bool): whether to punctuate the audio (default: True)
150
+ spoken_punctuation(bool): whether to use spoken punctuation (default: False)
151
+ enable_word_time_offsets(bool): whether to enable word time offsets (default: True, except for models that don't support word timestamps)
152
+ enable_word_confidence(bool): whether to enable word confidence (default: False)
153
+ enable_voice_activity_events(bool): whether to enable voice activity events (default: False)
154
+ model(SpeechModels): the model to use for recognition (default: "latest_long")
155
+ location(str): the location to use for recognition (default: "global")
156
+ sample_rate(int): the sample rate of the audio (default: 16000)
157
+ min_confidence_threshold(float): minimum confidence threshold for recognition
158
+ (default: 0.65)
159
+ credentials_info(dict): the credentials info to use for recognition (default: None)
160
+ credentials_file(str): the credentials file to use for recognition (default: None)
161
+ keywords(list[tuple[str, float]]): list of (keyword, boost) pairs for speech adaptation (default: None)
162
+ use_streaming(bool): whether to use streaming for recognition (default: True)
65
163
  """
66
- super().__init__(streaming_supported=True)
164
+ if not is_given(use_streaming):
165
+ use_streaming = True
67
166
 
68
- if credentials_info:
69
- self._client = SpeechAsyncClient.from_service_account_info(credentials_info)
70
- elif credentials_file:
71
- self._client = SpeechAsyncClient.from_service_account_file(credentials_file)
167
+ if model == "chirp_3":
168
+ if is_given(enable_word_time_offsets) and enable_word_time_offsets:
169
+ logger.warning(
170
+ "Chirp 3 does not support word timestamps, setting 'enable_word_time_offsets' to False."
171
+ )
172
+ enable_word_time_offsets = False
173
+ elif is_given(enable_word_time_offsets):
174
+ enable_word_time_offsets = enable_word_time_offsets
72
175
  else:
73
- self._client = SpeechAsyncClient()
176
+ enable_word_time_offsets = True
177
+
178
+ super().__init__(
179
+ capabilities=stt.STTCapabilities(
180
+ streaming=use_streaming,
181
+ interim_results=True,
182
+ aligned_transcript="word" if enable_word_time_offsets and use_streaming else False,
183
+ )
184
+ )
185
+
186
+ self._location = location
187
+ self._credentials_info = credentials_info
188
+ self._credentials_file = credentials_file
189
+
190
+ if not is_given(credentials_file) and not is_given(credentials_info):
191
+ try:
192
+ gauth_default() # type: ignore
193
+ except DefaultCredentialsError:
194
+ raise ValueError(
195
+ "Application default credentials must be available "
196
+ "when using Google STT without explicitly passing "
197
+ "credentials through credentials_info or credentials_file."
198
+ ) from None
74
199
 
75
200
  if isinstance(languages, str):
76
201
  languages = [languages]
@@ -81,322 +206,631 @@ class STT(stt.STT):
81
206
  interim_results=interim_results,
82
207
  punctuate=punctuate,
83
208
  spoken_punctuation=spoken_punctuation,
209
+ enable_word_time_offsets=enable_word_time_offsets,
210
+ enable_word_confidence=enable_word_confidence,
211
+ enable_voice_activity_events=enable_voice_activity_events,
84
212
  model=model,
213
+ sample_rate=sample_rate,
214
+ min_confidence_threshold=min_confidence_threshold,
215
+ keywords=keywords,
216
+ )
217
+ self._streams = weakref.WeakSet[SpeechStream]()
218
+ self._pool = utils.ConnectionPool[SpeechAsyncClientV2 | SpeechAsyncClientV1](
219
+ max_session_duration=_max_session_duration,
220
+ connect_cb=self._create_client,
85
221
  )
86
- self._creds = self._client.transport._credentials
87
222
 
88
223
  @property
89
- def _recognizer(self) -> str:
224
+ def model(self) -> str:
225
+ return self._config.model
226
+
227
+ @property
228
+ def provider(self) -> str:
229
+ return "Google Cloud Platform"
230
+
231
+ async def _create_client(self, timeout: float) -> SpeechAsyncClientV2 | SpeechAsyncClientV1:
232
+ # Add support for passing a specific location that matches recognizer
233
+ # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
234
+ # TODO(long): how to set timeout?
235
+ client_options = None
236
+ client: SpeechAsyncClientV2 | SpeechAsyncClientV1 | None = None
237
+ client_cls = SpeechAsyncClientV2 if self._config.version == 2 else SpeechAsyncClientV1
238
+ if self._location != "global":
239
+ client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
240
+ if is_given(self._credentials_info):
241
+ client = client_cls.from_service_account_info(
242
+ self._credentials_info, client_options=client_options
243
+ )
244
+ elif is_given(self._credentials_file):
245
+ client = client_cls.from_service_account_file(
246
+ self._credentials_file, client_options=client_options
247
+ )
248
+ else:
249
+ client = client_cls(client_options=client_options)
250
+ assert client is not None
251
+ return client
252
+
253
+ def _get_recognizer(self, client: SpeechAsyncClientV2) -> str:
90
254
  # TODO(theomonnom): should we use recognizers?
91
- # Recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
92
- return f"projects/{self._creds.project_id}/locations/global/recognizers/_" # type: ignore
255
+ # recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
93
256
 
94
- def _sanitize_options(
95
- self,
96
- *,
97
- language: str | None = None,
98
- ) -> STTOptions:
257
+ # TODO(theomonnom): find a better way to access the project_id
258
+ try:
259
+ project_id = client.transport._credentials.project_id # type: ignore
260
+ except AttributeError:
261
+ from google.auth import default as ga_default
262
+
263
+ _, project_id = ga_default() # type: ignore
264
+ return f"projects/{project_id}/locations/{self._location}/recognizers/_"
265
+
266
+ def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
99
267
  config = dataclasses.replace(self._config)
100
268
 
101
- if language:
269
+ if is_given(language):
102
270
  config.languages = [language]
103
271
 
104
272
  if not isinstance(config.languages, list):
105
273
  config.languages = [config.languages]
106
274
  elif not config.detect_language:
107
275
  if len(config.languages) > 1:
108
- logging.warning(
109
- "multiple languages provided, but language detection is disabled"
110
- )
276
+ logger.warning("multiple languages provided, but language detection is disabled")
111
277
  config.languages = [config.languages[0]]
112
278
 
113
279
  return config
114
280
 
115
- async def recognize(
281
+ def _build_recognition_config(
116
282
  self,
117
- *,
118
- buffer: AudioBuffer,
119
- language: SpeechLanguages | str | None = None,
120
- ) -> stt.SpeechEvent:
283
+ sample_rate: int,
284
+ num_channels: int,
285
+ language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
286
+ ) -> cloud_speech_v2.RecognitionConfig | cloud_speech_v1.RecognitionConfig:
121
287
  config = self._sanitize_options(language=language)
122
- buffer = agents.utils.merge_frames(buffer)
123
-
124
- config = cloud_speech.RecognitionConfig(
125
- explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
126
- encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
127
- sample_rate_hertz=buffer.sample_rate,
128
- audio_channel_count=buffer.num_channels,
129
- ),
130
- features=cloud_speech.RecognitionFeatures(
131
- enable_automatic_punctuation=config.punctuate,
132
- enable_spoken_punctuation=config.spoken_punctuation,
133
- ),
288
+ if self._config.version == 2:
289
+ return cloud_speech_v2.RecognitionConfig(
290
+ explicit_decoding_config=cloud_speech_v2.ExplicitDecodingConfig(
291
+ encoding=cloud_speech_v2.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
292
+ sample_rate_hertz=sample_rate,
293
+ audio_channel_count=num_channels,
294
+ ),
295
+ adaptation=config.build_adaptation(),
296
+ features=cloud_speech_v2.RecognitionFeatures(
297
+ enable_automatic_punctuation=config.punctuate,
298
+ enable_spoken_punctuation=config.spoken_punctuation,
299
+ enable_word_time_offsets=config.enable_word_time_offsets,
300
+ enable_word_confidence=config.enable_word_confidence,
301
+ ),
302
+ model=config.model,
303
+ language_codes=config.languages,
304
+ )
305
+ return cloud_speech_v1.RecognitionConfig(
306
+ encoding=cloud_speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
307
+ sample_rate_hertz=sample_rate,
308
+ audio_channel_count=num_channels,
309
+ adaptation=config.build_adaptation(),
310
+ language_code=config.languages[0],
311
+ alternative_language_codes=config.languages[1:],
312
+ enable_word_time_offsets=config.enable_word_time_offsets,
313
+ enable_word_confidence=config.enable_word_confidence,
314
+ enable_automatic_punctuation=config.punctuate,
315
+ enable_spoken_punctuation=config.spoken_punctuation,
134
316
  model=config.model,
135
- language_codes=config.languages,
136
317
  )
137
318
 
138
- return recognize_response_to_speech_event(
139
- await self._client.recognize(
140
- cloud_speech.RecognizeRequest(
141
- recognizer=self._recognizer,
142
- config=config,
143
- content=buffer.data.tobytes(),
144
- )
319
+ def _build_recognition_request(
320
+ self,
321
+ client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
322
+ config: cloud_speech_v2.RecognitionConfig | cloud_speech_v1.RecognitionConfig,
323
+ content: bytes,
324
+ ) -> cloud_speech_v2.RecognizeRequest | cloud_speech_v1.RecognizeRequest:
325
+ if self._config.version == 2:
326
+ return cloud_speech_v2.RecognizeRequest(
327
+ recognizer=self._get_recognizer(cast(SpeechAsyncClientV2, client)),
328
+ config=config,
329
+ content=content,
145
330
  )
331
+
332
+ return cloud_speech_v1.RecognizeRequest(
333
+ config=config,
334
+ audio=cloud_speech_v1.RecognitionAudio(content=content),
335
+ )
336
+
337
+ async def _recognize_impl(
338
+ self,
339
+ buffer: utils.AudioBuffer,
340
+ *,
341
+ language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
342
+ conn_options: APIConnectOptions,
343
+ ) -> stt.SpeechEvent:
344
+ frame = rtc.combine_audio_frames(buffer)
345
+
346
+ config = self._build_recognition_config(
347
+ sample_rate=frame.sample_rate,
348
+ num_channels=frame.num_channels,
349
+ language=language,
146
350
  )
147
351
 
352
+ try:
353
+ async with self._pool.connection(timeout=conn_options.timeout) as client:
354
+ raw = await client.recognize(
355
+ self._build_recognition_request(client, config, frame.data.tobytes()),
356
+ timeout=conn_options.timeout,
357
+ )
358
+ return _recognize_response_to_speech_event(raw)
359
+ except DeadlineExceeded:
360
+ raise APITimeoutError() from None
361
+ except GoogleAPICallError as e:
362
+ raise APIStatusError(f"{e.message} {e.details}", status_code=e.code or -1) from e
363
+ except Exception as e:
364
+ raise APIConnectionError() from e
365
+
148
366
  def stream(
149
367
  self,
150
368
  *,
151
- language: SpeechLanguages | str | None = None,
152
- ) -> "SpeechStream":
369
+ language: NotGivenOr[SpeechLanguages | str] = NOT_GIVEN,
370
+ conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
371
+ ) -> SpeechStream:
153
372
  config = self._sanitize_options(language=language)
154
- return SpeechStream(
155
- self._client,
156
- self._creds,
157
- self._recognizer,
158
- config,
373
+ stream = SpeechStream(
374
+ stt=self,
375
+ pool=self._pool,
376
+ recognizer_cb=self._get_recognizer,
377
+ config=config,
378
+ conn_options=conn_options,
159
379
  )
380
+ self._streams.add(stream)
381
+ return stream
382
+
383
+ def update_options(
384
+ self,
385
+ *,
386
+ languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
387
+ detect_language: NotGivenOr[bool] = NOT_GIVEN,
388
+ interim_results: NotGivenOr[bool] = NOT_GIVEN,
389
+ punctuate: NotGivenOr[bool] = NOT_GIVEN,
390
+ spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
391
+ model: NotGivenOr[SpeechModels] = NOT_GIVEN,
392
+ location: NotGivenOr[str] = NOT_GIVEN,
393
+ keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
394
+ ) -> None:
395
+ if is_given(languages):
396
+ if isinstance(languages, str):
397
+ languages = [languages]
398
+ self._config.languages = cast(list[LgType], languages)
399
+ if is_given(detect_language):
400
+ self._config.detect_language = detect_language
401
+ if is_given(interim_results):
402
+ self._config.interim_results = interim_results
403
+ if is_given(punctuate):
404
+ self._config.punctuate = punctuate
405
+ if is_given(spoken_punctuation):
406
+ self._config.spoken_punctuation = spoken_punctuation
407
+ if is_given(model):
408
+ old_version = self._config.version
409
+ self._config.model = model
410
+ if self._config.version != old_version:
411
+ self._pool.invalidate()
412
+
413
+ if is_given(location):
414
+ self._location = location
415
+ # if location is changed, fetch a new client and recognizer as per the new location
416
+ self._pool.invalidate()
417
+ if is_given(keywords):
418
+ self._config.keywords = keywords
419
+
420
+ for stream in self._streams:
421
+ stream.update_options(
422
+ languages=languages,
423
+ detect_language=detect_language,
424
+ interim_results=interim_results,
425
+ punctuate=punctuate,
426
+ spoken_punctuation=spoken_punctuation,
427
+ model=model,
428
+ keywords=keywords,
429
+ )
430
+
431
+ async def aclose(self) -> None:
432
+ await self._pool.aclose()
433
+ await super().aclose()
160
434
 
161
435
 
162
436
  class SpeechStream(stt.SpeechStream):
163
437
  def __init__(
164
438
  self,
165
- client: SpeechAsyncClient,
166
- creds: credentials.Credentials,
167
- recognizer: str,
439
+ *,
440
+ stt: STT,
441
+ conn_options: APIConnectOptions,
442
+ pool: utils.ConnectionPool[SpeechAsyncClientV2 | SpeechAsyncClientV1],
443
+ recognizer_cb: Callable[[SpeechAsyncClientV2], str],
168
444
  config: STTOptions,
169
- sample_rate: int = 24000,
170
- num_channels: int = 1,
171
- max_retry: int = 32,
172
445
  ) -> None:
173
- super().__init__()
446
+ super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
174
447
 
175
- self._client = client
176
- self._creds = creds
177
- self._recognizer = recognizer
448
+ self._pool = pool
449
+ self._recognizer_cb = recognizer_cb
178
450
  self._config = config
179
- self._sample_rate = sample_rate
180
- self._num_channels = num_channels
181
-
182
- self._queue = asyncio.Queue[rtc.AudioFrame | None]()
183
- self._event_queue = asyncio.Queue[stt.SpeechEvent | None]()
184
- self._closed = False
185
- self._main_task = asyncio.create_task(self._run(max_retry=max_retry))
186
-
187
- self._final_events: List[stt.SpeechEvent] = []
188
- self._speaking = False
189
-
190
- self._streaming_config = cloud_speech.StreamingRecognitionConfig(
191
- config=cloud_speech.RecognitionConfig(
192
- explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
193
- encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
194
- sample_rate_hertz=self._sample_rate,
195
- audio_channel_count=self._num_channels,
451
+ self._reconnect_event = asyncio.Event()
452
+ self._session_connected_at: float = 0
453
+
454
+ def update_options(
455
+ self,
456
+ *,
457
+ languages: NotGivenOr[LanguageCode] = NOT_GIVEN,
458
+ detect_language: NotGivenOr[bool] = NOT_GIVEN,
459
+ interim_results: NotGivenOr[bool] = NOT_GIVEN,
460
+ punctuate: NotGivenOr[bool] = NOT_GIVEN,
461
+ spoken_punctuation: NotGivenOr[bool] = NOT_GIVEN,
462
+ model: NotGivenOr[SpeechModels] = NOT_GIVEN,
463
+ min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
464
+ keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
465
+ ) -> None:
466
+ if is_given(languages):
467
+ if isinstance(languages, str):
468
+ languages = [languages]
469
+ self._config.languages = cast(list[LgType], languages)
470
+ if is_given(detect_language):
471
+ self._config.detect_language = detect_language
472
+ if is_given(interim_results):
473
+ self._config.interim_results = interim_results
474
+ if is_given(punctuate):
475
+ self._config.punctuate = punctuate
476
+ if is_given(spoken_punctuation):
477
+ self._config.spoken_punctuation = spoken_punctuation
478
+ if is_given(model):
479
+ old_version = self._config.version
480
+ self._config.model = model
481
+ if self._config.version != old_version:
482
+ self._pool.invalidate()
483
+ if is_given(min_confidence_threshold):
484
+ self._config.min_confidence_threshold = min_confidence_threshold
485
+ if is_given(keywords):
486
+ self._config.keywords = keywords
487
+
488
+ self._reconnect_event.set()
489
+
490
+ def _build_streaming_config(
491
+ self,
492
+ ) -> cloud_speech_v2.StreamingRecognitionConfig | cloud_speech_v1.StreamingRecognitionConfig:
493
+ if self._config.version == 2:
494
+ return cloud_speech_v2.StreamingRecognitionConfig(
495
+ config=cloud_speech_v2.RecognitionConfig(
496
+ explicit_decoding_config=cloud_speech_v2.ExplicitDecodingConfig(
497
+ encoding=cloud_speech_v2.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
498
+ sample_rate_hertz=self._config.sample_rate,
499
+ audio_channel_count=1,
500
+ ),
501
+ adaptation=self._config.build_adaptation(),
502
+ language_codes=self._config.languages,
503
+ model=self._config.model,
504
+ features=cloud_speech_v2.RecognitionFeatures(
505
+ enable_automatic_punctuation=self._config.punctuate,
506
+ enable_word_time_offsets=self._config.enable_word_time_offsets,
507
+ enable_spoken_punctuation=self._config.spoken_punctuation,
508
+ enable_word_confidence=self._config.enable_word_confidence,
509
+ ),
196
510
  ),
197
- language_codes=self._config.languages,
198
- model=self._config.model,
199
- features=cloud_speech.RecognitionFeatures(
200
- enable_automatic_punctuation=self._config.punctuate,
511
+ streaming_features=cloud_speech_v2.StreamingRecognitionFeatures(
512
+ interim_results=self._config.interim_results,
513
+ enable_voice_activity_events=self._config.enable_voice_activity_events,
201
514
  ),
515
+ )
516
+
517
+ return cloud_speech_v1.StreamingRecognitionConfig(
518
+ config=cloud_speech_v1.RecognitionConfig(
519
+ encoding=cloud_speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
520
+ sample_rate_hertz=self._config.sample_rate,
521
+ audio_channel_count=1,
522
+ adaptation=self._config.build_adaptation(),
523
+ language_code=self._config.languages[0],
524
+ alternative_language_codes=self._config.languages[1:],
525
+ enable_word_time_offsets=self._config.enable_word_time_offsets,
526
+ enable_word_confidence=self._config.enable_word_confidence,
527
+ enable_automatic_punctuation=self._config.punctuate,
528
+ enable_spoken_punctuation=self._config.spoken_punctuation,
529
+ model=self._config.model,
202
530
  ),
203
- streaming_features=cloud_speech.StreamingRecognitionFeatures(
204
- enable_voice_activity_events=True,
205
- interim_results=self._config.interim_results,
206
- ),
531
+ interim_results=self._config.interim_results,
532
+ enable_voice_activity_events=self._config.enable_voice_activity_events,
207
533
  )
208
534
 
209
- def log_exception(task: asyncio.Task) -> None:
210
- if not task.cancelled() and task.exception():
211
- logging.error(f"google stt task failed: {task.exception()}")
212
-
213
- self._main_task.add_done_callback(log_exception)
214
-
215
- def push_frame(self, frame: rtc.AudioFrame) -> None:
216
- if self._closed:
217
- raise ValueError("cannot push frame to closed stream")
218
-
219
- self._queue.put_nowait(frame)
220
-
221
- async def aclose(self, wait: bool = True) -> None:
222
- self._closed = True
223
- if not wait:
224
- self._main_task.cancel()
225
-
226
- self._queue.put_nowait(None)
227
- with contextlib.suppress(asyncio.CancelledError):
228
- await self._main_task
535
+ def _build_init_request(
536
+ self,
537
+ client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
538
+ ) -> cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest:
539
+ if self._config.version == 2:
540
+ return cloud_speech_v2.StreamingRecognizeRequest(
541
+ recognizer=self._recognizer_cb(cast(SpeechAsyncClientV2, client)),
542
+ streaming_config=self._streaming_config,
543
+ )
544
+ return cloud_speech_v1.StreamingRecognizeRequest(
545
+ streaming_config=self._streaming_config,
546
+ )
229
547
 
230
- async def _run(self, max_retry: int) -> None:
231
- retry_count = 0
232
- try:
233
- while not self._closed:
234
- try:
235
- # google requires a async generator when calling streaming_recognize
236
- # this function basically convert the queue into a async generator
237
- async def input_generator():
238
- try:
239
- # first request should contain the config
240
- yield cloud_speech.StreamingRecognizeRequest(
241
- recognizer=self._recognizer,
242
- streaming_config=self._streaming_config,
548
+ def _build_audio_request(
549
+ self,
550
+ frame: rtc.AudioFrame,
551
+ ) -> cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest:
552
+ if self._config.version == 2:
553
+ return cloud_speech_v2.StreamingRecognizeRequest(audio=frame.data.tobytes())
554
+ return cloud_speech_v1.StreamingRecognizeRequest(audio_content=frame.data.tobytes())
555
+
556
+ async def _run(self) -> None:
557
+ audio_pushed = False
558
+
559
+ # google requires an async generator when calling streaming_recognize
560
+ # this function basically converts the queue into an async generator
561
+ async def input_generator(
562
+ client: SpeechAsyncClientV2 | SpeechAsyncClientV1, should_stop: asyncio.Event
563
+ ) -> AsyncGenerator[
564
+ cloud_speech_v2.StreamingRecognizeRequest | cloud_speech_v1.StreamingRecognizeRequest,
565
+ None,
566
+ ]:
567
+ nonlocal audio_pushed
568
+ try:
569
+ yield self._build_init_request(client)
570
+
571
+ async for frame in self._input_ch:
572
+ # when the stream is aborted due to reconnect, this input_generator
573
+ # needs to stop consuming frames
574
+ # when the generator stops, the previous gRPC stream will close
575
+ if should_stop.is_set():
576
+ return
577
+
578
+ if isinstance(frame, rtc.AudioFrame):
579
+ yield self._build_audio_request(frame)
580
+ if not audio_pushed:
581
+ audio_pushed = True
582
+
583
+ except Exception:
584
+ logger.exception("an error occurred while streaming input to google STT")
585
+
586
+ async def process_stream(
587
+ client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
588
+ stream: AsyncIterable[
589
+ cloud_speech_v2.StreamingRecognizeResponse
590
+ | cloud_speech_v1.StreamingRecognizeResponse
591
+ ],
592
+ ) -> None:
593
+ has_started = False
594
+ async for resp in stream:
595
+ if resp.speech_event_type == (
596
+ cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
597
+ if self._config.version == 2
598
+ else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
599
+ ):
600
+ self._event_ch.send_nowait(
601
+ stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
602
+ )
603
+ has_started = True
604
+
605
+ if resp.speech_event_type == (
606
+ cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
607
+ if self._config.version == 2
608
+ else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_UNSPECIFIED
609
+ ):
610
+ result = resp.results[0]
611
+ speech_data = _streaming_recognize_response_to_speech_data(
612
+ resp,
613
+ min_confidence_threshold=self._config.min_confidence_threshold,
614
+ start_time_offset=self.start_time_offset,
615
+ )
616
+ if speech_data is None:
617
+ continue
618
+
619
+ if not result.is_final:
620
+ self._event_ch.send_nowait(
621
+ stt.SpeechEvent(
622
+ type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
623
+ alternatives=[speech_data],
243
624
  )
244
- while True:
245
- frame = (
246
- await self._queue.get()
247
- ) # wait for a new rtc.AudioFrame
248
- if frame is None:
249
- break # None is sent inside aclose
250
-
251
- self._queue.task_done()
252
- frame = frame.remix_and_resample(
253
- self._sample_rate, self._num_channels
254
- )
255
- yield cloud_speech.StreamingRecognizeRequest(
256
- audio=frame.data.tobytes(),
257
- )
258
- except Exception as e:
259
- logging.error(
260
- f"an error occurred while streaming inputs: {e}"
625
+ )
626
+ else:
627
+ self._event_ch.send_nowait(
628
+ stt.SpeechEvent(
629
+ type=stt.SpeechEventType.FINAL_TRANSCRIPT,
630
+ alternatives=[speech_data],
261
631
  )
262
-
263
- # try to connect
264
- stream = await self._client.streaming_recognize(
265
- requests=input_generator()
266
- )
267
- retry_count = 0 # connection successful, reset retry count
268
-
269
- await self._run_stream(stream)
270
- except Exception as e:
271
- if retry_count >= max_retry:
272
- logging.error(
273
- f"failed to connect to google stt after {max_retry} tries",
274
- exc_info=e,
275
632
  )
276
- break
277
-
278
- retry_delay = min(retry_count * 2, 10) # max 10s
279
- retry_count += 1
280
- logging.warning(
281
- f"google stt connection failed, retrying in {retry_delay}s",
282
- exc_info=e,
633
+ if time.time() - self._session_connected_at > _max_session_duration:
634
+ logger.debug(
635
+ "Google STT maximum connection time reached. Reconnecting..."
636
+ )
637
+ self._pool.remove(client)
638
+ if has_started:
639
+ self._event_ch.send_nowait(
640
+ stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
641
+ )
642
+ has_started = False
643
+ self._reconnect_event.set()
644
+ return
645
+
646
+ if resp.speech_event_type == (
647
+ cloud_speech_v2.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
648
+ if self._config.version == 2
649
+ else cloud_speech_v1.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
650
+ ):
651
+ self._event_ch.send_nowait(
652
+ stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
283
653
  )
284
- await asyncio.sleep(retry_delay)
285
- finally:
286
- self._event_queue.put_nowait(None)
654
+ has_started = False
287
655
 
288
- async def _run_stream(
289
- self, stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse]
290
- ):
291
- async for resp in stream:
292
- if (
293
- resp.speech_event_type
294
- == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
295
- ):
296
- self._speaking = True
297
- start_event = stt.SpeechEvent(
298
- type=stt.SpeechEventType.START_OF_SPEECH,
299
- )
300
- self._event_queue.put_nowait(start_event)
301
-
302
- if (
303
- resp.speech_event_type
304
- == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
305
- ):
306
- result = resp.results[0]
307
- if not result.is_final:
308
- # interim results
309
- iterim_event = stt.SpeechEvent(
310
- type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
311
- alternatives=streaming_recognize_response_to_speech_data(resp),
312
- )
313
- self._event_queue.put_nowait(iterim_event)
656
+ while True:
657
+ audio_pushed = False
658
+ try:
659
+ async with self._pool.connection(timeout=self._conn_options.timeout) as client:
660
+ self._streaming_config = self._build_streaming_config()
314
661
 
315
- else:
316
- final_event = stt.SpeechEvent(
317
- type=stt.SpeechEventType.FINAL_TRANSCRIPT,
318
- alternatives=streaming_recognize_response_to_speech_data(resp),
662
+ should_stop = asyncio.Event()
663
+ stream = await client.streaming_recognize(
664
+ requests=input_generator(client, should_stop),
319
665
  )
320
- self._final_events.append(final_event)
321
- self._event_queue.put_nowait(final_event)
322
-
323
- if not self._speaking:
324
- # With Google STT, we receive the final event after the END_OF_SPEECH event
325
- sentence = ""
326
- confidence = 0.0
327
- for alt in self._final_events:
328
- sentence += f"{alt.alternatives[0].text.strip()} "
329
- confidence += alt.alternatives[0].confidence
330
-
331
- sentence = sentence.rstrip()
332
- confidence /= len(self._final_events) # avg. of confidence
333
-
334
- end_event = stt.SpeechEvent(
335
- type=stt.SpeechEventType.END_OF_SPEECH,
336
- alternatives=[
337
- stt.SpeechData(
338
- language=result.language_code,
339
- start_time=self._final_events[0]
340
- .alternatives[0]
341
- .start_time,
342
- end_time=self._final_events[-1]
343
- .alternatives[0]
344
- .end_time,
345
- confidence=confidence,
346
- text=sentence,
347
- )
348
- ],
666
+ self._session_connected_at = time.time()
667
+
668
+ process_stream_task = asyncio.create_task(process_stream(client, stream))
669
+ wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
670
+
671
+ try:
672
+ done, _ = await asyncio.wait(
673
+ [process_stream_task, wait_reconnect_task],
674
+ return_when=asyncio.FIRST_COMPLETED,
349
675
  )
676
+ for task in done:
677
+ if task != wait_reconnect_task:
678
+ task.result()
679
+ if wait_reconnect_task not in done:
680
+ break
681
+ self._reconnect_event.clear()
682
+ finally:
683
+ should_stop.set()
684
+ if not process_stream_task.done() and not wait_reconnect_task.done():
685
+ # try to gracefully stop the process_stream_task
686
+ try:
687
+ await asyncio.wait_for(process_stream_task, timeout=1.0)
688
+ except asyncio.TimeoutError:
689
+ pass
690
+
691
+ await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
692
+ except DeadlineExceeded:
693
+ raise APITimeoutError() from None
694
+ except GoogleAPICallError as e:
695
+ if e.code == 409:
696
+ if audio_pushed:
697
+ logger.debug("stream timed out, restarting.")
698
+ else:
699
+ raise APIStatusError(
700
+ f"{e.message} {e.details}", status_code=e.code or -1
701
+ ) from e
702
+ except Exception as e:
703
+ raise APIConnectionError() from e
350
704
 
351
- self._final_events = []
352
- self._event_queue.put_nowait(end_event)
353
705
 
354
- if (
355
- resp.speech_event_type
356
- == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
357
- ):
358
- self._speaking = False
706
+ def _duration_to_seconds(duration: Duration | timedelta) -> float:
707
+ # Proto Plus may auto-convert Duration to timedelta; handle both.
708
+ # https://proto-plus-python.readthedocs.io/en/latest/marshal.html
709
+ if isinstance(duration, timedelta):
710
+ return duration.total_seconds()
711
+ return duration.seconds + duration.nanos / 1e9
359
712
 
360
- async def __anext__(self) -> stt.SpeechEvent:
361
- evt = await self._event_queue.get()
362
- if evt is None:
363
- raise StopAsyncIteration
364
713
 
365
- return evt
714
+ def _get_start_time(word: cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo) -> float:
715
+ if hasattr(word, "start_offset"):
716
+ return _duration_to_seconds(word.start_offset)
717
+ return _duration_to_seconds(word.start_time)
366
718
 
367
719
 
368
- def recognize_response_to_speech_event(
369
- resp: cloud_speech.RecognizeResponse,
720
+ def _get_end_time(word: cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo) -> float:
721
+ if hasattr(word, "end_offset"):
722
+ return _duration_to_seconds(word.end_offset)
723
+ return _duration_to_seconds(word.end_time)
724
+
725
+
726
+ def _recognize_response_to_speech_event(
727
+ resp: cloud_speech_v2.RecognizeResponse | cloud_speech_v1.RecognizeResponse,
370
728
  ) -> stt.SpeechEvent:
371
- result = resp.results[0]
372
- gg_alts = result.alternatives
373
- return stt.SpeechEvent(
374
- type=stt.SpeechEventType.FINAL_TRANSCRIPT,
375
- alternatives=[
729
+ text = ""
730
+ confidence = 0.0
731
+ for result in resp.results:
732
+ text += result.alternatives[0].transcript
733
+ confidence += result.alternatives[0].confidence
734
+
735
+ alternatives = []
736
+
737
+ # Google STT may return empty results when spoken_lang != stt_lang
738
+ if resp.results:
739
+ try:
740
+ start_time = _get_start_time(resp.results[0].alternatives[0].words[0])
741
+ end_time = _get_end_time(resp.results[-1].alternatives[0].words[-1])
742
+ except IndexError:
743
+ # When enable_word_time_offsets=False, there are no "words" to access
744
+ start_time = end_time = 0
745
+
746
+ confidence /= len(resp.results)
747
+ lg = resp.results[0].language_code
748
+
749
+ alternatives = [
376
750
  stt.SpeechData(
377
- language=result.language_code,
378
- start_time=alt.words[0].start_offset.seconds if alt.words else 0,
379
- end_time=alt.words[-1].end_offset.seconds if alt.words else 0,
380
- confidence=alt.confidence,
381
- text=alt.transcript,
751
+ language=lg,
752
+ start_time=start_time,
753
+ end_time=end_time,
754
+ confidence=confidence,
755
+ text=text,
756
+ words=[
757
+ TimedString(
758
+ text=word.word,
759
+ start_time=_get_start_time(word),
760
+ end_time=_get_end_time(word),
761
+ )
762
+ for word in resp.results[0].alternatives[0].words
763
+ ]
764
+ if resp.results[0].alternatives[0].words
765
+ else None,
766
+ )
767
+ ]
768
+
769
+ return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=alternatives)
770
+
771
+
772
+ @utils.log_exceptions(logger=logger)
773
+ def _streaming_recognize_response_to_speech_data(
774
+ resp: cloud_speech_v2.StreamingRecognizeResponse | cloud_speech_v1.StreamingRecognizeResponse,
775
+ *,
776
+ min_confidence_threshold: float,
777
+ start_time_offset: float,
778
+ ) -> stt.SpeechData | None:
779
+ text = ""
780
+ confidence = 0.0
781
+ final_result = None
782
+ words: list[cloud_speech_v2.WordInfo | cloud_speech_v1.WordInfo] = []
783
+ for result in resp.results:
784
+ if len(result.alternatives) == 0:
785
+ continue
786
+ else:
787
+ if result.is_final:
788
+ final_result = result
789
+ break
790
+ else:
791
+ text += result.alternatives[0].transcript
792
+ confidence += result.alternatives[0].confidence
793
+ words.extend(result.alternatives[0].words)
794
+
795
+ if final_result is not None:
796
+ text = final_result.alternatives[0].transcript
797
+ confidence = final_result.alternatives[0].confidence
798
+ words = list(final_result.alternatives[0].words)
799
+ lg = final_result.language_code
800
+ else:
801
+ confidence /= len(resp.results)
802
+ if confidence < min_confidence_threshold:
803
+ return None
804
+ lg = resp.results[0].language_code
805
+
806
+ if text == "" or not words:
807
+ if text and not words:
808
+ data = stt.SpeechData(
809
+ language=lg,
810
+ start_time=start_time_offset,
811
+ end_time=start_time_offset,
812
+ confidence=confidence,
813
+ text=text,
382
814
  )
383
- for alt in gg_alts
815
+ return data
816
+ return None
817
+
818
+ data = stt.SpeechData(
819
+ language=lg,
820
+ start_time=_get_start_time(words[0]) + start_time_offset,
821
+ end_time=_get_end_time(words[-1]) + start_time_offset,
822
+ confidence=confidence,
823
+ text=text,
824
+ words=[
825
+ TimedString(
826
+ text=word.word,
827
+ start_time=_get_start_time(word) + start_time_offset,
828
+ end_time=_get_end_time(word) + start_time_offset,
829
+ start_time_offset=start_time_offset,
830
+ confidence=word.confidence,
831
+ )
832
+ for word in words
384
833
  ],
385
834
  )
386
835
 
387
-
388
- def streaming_recognize_response_to_speech_data(
389
- resp: cloud_speech.StreamingRecognizeResponse,
390
- ) -> List[stt.SpeechData]:
391
- result = resp.results[0]
392
- gg_alts = result.alternatives
393
- return [
394
- stt.SpeechData(
395
- language=result.language_code,
396
- start_time=alt.words[0].start_offset.seconds if alt.words else 0,
397
- end_time=alt.words[-1].end_offset.seconds if alt.words else 0,
398
- confidence=alt.confidence,
399
- text=alt.transcript,
400
- )
401
- for alt in gg_alts
402
- ]
836
+ return data
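The `update_options()` method added in this diff reconfigures a live `STT` instance and propagates the changes to any open `SpeechStream`; changing the model family or location invalidates the pooled client so the next request reconnects. A hedged usage sketch follows (the location and model values are illustrative, and Application Default Credentials are assumed to be configured):

```python
from livekit.plugins import google  # import path assumed from the package name

stt = google.STT()  # 1.x defaults per this diff: model="latest_long", location="global"

# Runtime reconfiguration; active streams pick up the new options as well.
stt.update_options(
    languages=["en-US", "es-ES"],
    model="latest_short",
    location="us-central1",  # non-global locations route to a regional endpoint
)
```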