livekit-plugins-google 0.11.1__py3-none-any.whl → 1.0.0.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

livekit/plugins/google/llm.py

@@ -15,45 +15,43 @@
 
 from __future__ import annotations
 
-import asyncio
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Literal, MutableSet, Union, cast
-
-from livekit.agents import (
-    APIConnectionError,
-    APIStatusError,
-    llm,
-    utils,
-)
-from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
-from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
+from typing import Any, Literal, cast
 
 from google import genai
 from google.auth._default_async import default_async
 from google.genai import types
 from google.genai.errors import APIError, ClientError, ServerError
+from livekit.agents import APIConnectionError, APIStatusError, llm, utils
+from livekit.agents.llm import FunctionTool, ToolChoice
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    APIConnectOptions,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
 
-from ._utils import _build_gemini_ctx, _build_tools
 from .log import logger
 from .models import ChatModels
+from .utils import to_chat_ctx, to_fnc_ctx
 
 
 @dataclass
-class LLMOptions:
+class _LLMOptions:
     model: ChatModels | str
-    temperature: float | None
-    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
-    vertexai: bool = False
-    project: str | None = None
-    location: str | None = None
-    candidate_count: int = 1
-    max_output_tokens: int | None = None
-    top_p: float | None = None
-    top_k: float | None = None
-    presence_penalty: float | None = None
-    frequency_penalty: float | None = None
+    temperature: NotGivenOr[float]
+    tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]]
+    vertexai: NotGivenOr[bool]
+    project: NotGivenOr[str]
+    location: NotGivenOr[str]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[float]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
 
 
 class LLM(llm.LLM):
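
The `NotGivenOr`/`NOT_GIVEN` pattern above replaces `None` defaults so that `None` remains available as a real value and only explicitly provided options get forwarded to the request. A minimal sketch of how such a sentinel behaves; the real definitions live in `livekit.agents.types` and `livekit.agents.utils`, so this standalone version is illustrative, not the actual implementation:

```python
from typing import Literal, TypeVar, Union


class NotGiven:
    """Sentinel distinguishing 'option omitted' from 'option set to None'."""

    def __bool__(self) -> Literal[False]:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()

_T = TypeVar("_T")
NotGivenOr = Union[_T, NotGiven]


def is_given(value: NotGivenOr[_T]) -> bool:
    # an option counts as "given" unless it is the sentinel itself
    return not isinstance(value, NotGiven)


def build_opts(temperature: NotGivenOr[float] = NOT_GIVEN) -> dict:
    # forward only the options the caller actually passed
    opts: dict = {}
    if is_given(temperature):
        opts["temperature"] = temperature
    return opts


assert build_opts() == {}
assert build_opts(temperature=0.8) == {"temperature": 0.8}
```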
@@ -61,18 +59,17 @@ class LLM(llm.LLM):
         self,
         *,
         model: ChatModels | str = "gemini-2.0-flash-001",
-        api_key: str | None = None,
-        vertexai: bool = False,
-        project: str | None = None,
-        location: str | None = None,
-        candidate_count: int = 1,
-        temperature: float = 0.8,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: float | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        vertexai: NotGivenOr[bool] = False,
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[float] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
     ) -> None:
         """
         Create a new instance of Google GenAI LLM.
@@ -90,7 +87,6 @@ class LLM(llm.LLM):
             vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
             project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
             location (str, optional): The location to use for VertexAI API requests. Defaults value is "us-central1".
-            candidate_count (int, optional): Number of candidate responses to generate. Defaults to 1.
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
             top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
@@ -99,16 +95,9 @@ class LLM(llm.LLM):
             frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
             tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
         """
-        super().__init__(
-            capabilities=LLMCapabilities(
-                supports_choices_on_int=False,
-                requires_persistent_functions=False,
-            )
-        )
+        super().__init__()
         self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
-        self._location = location or os.environ.get(
-            "GOOGLE_CLOUD_LOCATION", "us-central1"
-        )
+        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")
         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY", None)
         _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
         if _gac is None:
@@ -131,14 +120,13 @@ class LLM(llm.LLM):
                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
                 )
 
-        self._opts = LLMOptions(
+        self._opts = _LLMOptions(
             model=model,
             temperature=temperature,
             tool_choice=tool_choice,
             vertexai=vertexai,
             project=project,
             location=location,
-            candidate_count=candidate_count,
             max_output_tokens=max_output_tokens,
             top_p=top_p,
             top_k=top_k,
@@ -151,41 +139,77 @@ class LLM(llm.LLM):
             project=self._project_id,
             location=self._location,
         )
-        self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
 
     def chat(
         self,
         *,
         chat_ctx: llm.ChatContext,
+        tools: list[FunctionTool] | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-        fnc_ctx: llm.FunctionContext | None = None,
-        temperature: float | None = None,
-        n: int | None = 1,
-        parallel_tool_calls: bool | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
-        | None = None,
-    ) -> "LLMStream":
-        if tool_choice is None:
-            tool_choice = self._opts.tool_choice
-
-        if temperature is None:
-            temperature = self._opts.temperature
+        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
+        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+    ) -> LLMStream:
+        extra = {}
+
+        if is_given(extra_kwargs):
+            extra.update(extra_kwargs)
+
+        tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
+        if is_given(tool_choice):
+            gemini_tool_choice: types.ToolConfig
+            if isinstance(tool_choice, ToolChoice):
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="ANY",
+                        allowed_function_names=[tool_choice["function"]["name"]],
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "required":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="ANY",
+                        allowed_function_names=[fnc.name for fnc in tools],
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "auto":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="AUTO",
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+            elif tool_choice == "none":
+                gemini_tool_choice = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode="NONE",
+                    )
+                )
+                extra["tool_config"] = gemini_tool_choice
+
+        if is_given(self._opts.temperature):
+            extra["temperature"] = self._opts.temperature
+        if is_given(self._opts.max_output_tokens):
+            extra["max_output_tokens"] = self._opts.max_output_tokens
+        if is_given(self._opts.top_p):
+            extra["top_p"] = self._opts.top_p
+        if is_given(self._opts.top_k):
+            extra["top_k"] = self._opts.top_k
+        if is_given(self._opts.presence_penalty):
+            extra["presence_penalty"] = self._opts.presence_penalty
+        if is_given(self._opts.frequency_penalty):
+            extra["frequency_penalty"] = self._opts.frequency_penalty
 
         return LLMStream(
             self,
             client=self._client,
             model=self._opts.model,
-            max_output_tokens=self._opts.max_output_tokens,
-            top_p=self._opts.top_p,
-            top_k=self._opts.top_k,
-            presence_penalty=self._opts.presence_penalty,
-            frequency_penalty=self._opts.frequency_penalty,
             chat_ctx=chat_ctx,
-            fnc_ctx=fnc_ctx,
+            tools=tools,
             conn_options=conn_options,
-            n=n,
-            temperature=temperature,
-            tool_choice=tool_choice,
+            extra_kwargs=extra,
         )
 
 
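Tool-choice resolution now happens eagerly in `chat()` rather than inside the stream, and the modes are passed to `google-genai` as plain strings (`"ANY"`, `"AUTO"`, `"NONE"`) instead of the `FunctionCallingConfigMode` enum. The same mapping, extracted as a standalone helper for illustration (the `function_names` parameter is a stand-in for the tools list):

```python
from typing import Literal, Union

from google.genai import types


def to_tool_config(
    tool_choice: Union[Literal["auto", "required", "none"], str],
    function_names: list[str],
) -> types.ToolConfig:
    # "required" forces a call to one of the listed functions (Gemini mode ANY),
    # "auto" lets the model decide, "none" disables function calling entirely
    if tool_choice == "required":
        cfg = types.FunctionCallingConfig(mode="ANY", allowed_function_names=function_names)
    elif tool_choice == "none":
        cfg = types.FunctionCallingConfig(mode="NONE")
    elif tool_choice == "auto":
        cfg = types.FunctionCallingConfig(mode="AUTO")
    else:
        # a specific function name: restrict the model to exactly that tool
        cfg = types.FunctionCallingConfig(mode="ANY", allowed_function_names=[tool_choice])
    return types.ToolConfig(function_calling_config=cfg)
```

Note that in the code shown here, `parallel_tool_calls` is accepted for interface compatibility but is not mapped to any Gemini option.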
@@ -198,96 +222,37 @@ class LLMStream(llm.LLMStream):
         model: str | ChatModels,
         chat_ctx: llm.ChatContext,
         conn_options: APIConnectOptions,
-        fnc_ctx: llm.FunctionContext | None,
-        temperature: float | None,
-        n: int | None,
-        max_output_tokens: int | None,
-        top_p: float | None,
-        top_k: float | None,
-        presence_penalty: float | None,
-        frequency_penalty: float | None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
+        tools: list[FunctionTool] | None,
+        extra_kwargs: dict[str, Any],
     ) -> None:
-        super().__init__(
-            llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
-        )
+        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
         self._client = client
         self._model = model
         self._llm: LLM = llm
-        self._max_output_tokens = max_output_tokens
-        self._top_p = top_p
-        self._top_k = top_k
-        self._presence_penalty = presence_penalty
-        self._frequency_penalty = frequency_penalty
-        self._temperature = temperature
-        self._n = n
-        self._tool_choice = tool_choice
+        self._extra_kwargs = extra_kwargs
 
     async def _run(self) -> None:
         retryable = True
         request_id = utils.shortuuid()
 
         try:
-            opts: dict[str, Any] = dict()
-            turns, system_instruction = _build_gemini_ctx(self._chat_ctx, id(self))
-
-            if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
-                functions = _build_tools(self._fnc_ctx)
-                opts["tools"] = [types.Tool(function_declarations=functions)]
-
-                if self._tool_choice is not None:
-                    if isinstance(self._tool_choice, ToolChoice):
-                        # specific function
-                        tool_config = types.ToolConfig(
-                            function_calling_config=types.FunctionCallingConfig(
-                                mode=types.FunctionCallingConfigMode.ANY,
-                                allowed_function_names=[self._tool_choice.name],
-                            )
-                        )
-                    elif self._tool_choice == "required":
-                        # model must call any function
-                        tool_config = types.ToolConfig(
-                            function_calling_config=types.FunctionCallingConfig(
-                                mode=types.FunctionCallingConfigMode.ANY,
-                                allowed_function_names=[
-                                    fnc.name
-                                    for fnc in self._fnc_ctx.ai_functions.values()
-                                ],
-                            )
-                        )
-                    elif self._tool_choice == "auto":
-                        # model can call any function
-                        tool_config = types.ToolConfig(
-                            function_calling_config=types.FunctionCallingConfig(
-                                mode=types.FunctionCallingConfigMode.AUTO
-                            )
-                        )
-                    elif self._tool_choice == "none":
-                        # model cannot call any function
-                        tool_config = types.ToolConfig(
-                            function_calling_config=types.FunctionCallingConfig(
-                                mode=types.FunctionCallingConfigMode.NONE,
-                            )
-                        )
-                    opts["tool_config"] = tool_config
+            turns, system_instruction = to_chat_ctx(self._chat_ctx, id(self._llm))
 
+            self._extra_kwargs["tools"] = [
+                types.Tool(function_declarations=to_fnc_ctx(self._tools))
+            ]
             config = types.GenerateContentConfig(
-                candidate_count=self._n,
-                temperature=self._temperature,
-                max_output_tokens=self._max_output_tokens,
-                top_p=self._top_p,
-                top_k=self._top_k,
-                presence_penalty=self._presence_penalty,
-                frequency_penalty=self._frequency_penalty,
                 system_instruction=system_instruction,
-                **opts,
+                **self._extra_kwargs,
             )
+
             stream = await self._client.aio.models.generate_content_stream(
                 model=self._model,
                 contents=cast(types.ContentListUnion, turns),
                 config=config,
             )
-            async for response in stream:  # type: ignore
+
+            async for response in stream:
                 if response.prompt_feedback:
                     raise APIStatusError(
                         response.prompt_feedback.json(),
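
With every sampling option funneled through `extra_kwargs`, request assembly in the stream reduces to a single dict splat. A sketch of that assembly, with example values standing in for whatever `chat()` collected:

```python
from google.genai import types

# built by chat(): contains only the options that were explicitly given
extra_kwargs = {"temperature": 0.8, "max_output_tokens": 256}

config = types.GenerateContentConfig(
    system_instruction="You are a helpful assistant.",
    **extra_kwargs,  # temperature, top_p, tool_config, ... if present
)
```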
@@ -311,8 +276,8 @@ class LLMStream(llm.LLMStream):
                         "gemini llm: there are multiple candidates in the response, returning response from the first one."
                     )
 
-                for index, part in enumerate(response.candidates[0].content.parts):
-                    chat_chunk = self._parse_part(request_id, index, part)
+                for part in response.candidates[0].content.parts:
+                    chat_chunk = self._parse_part(request_id, part)
                     if chat_chunk is not None:
                         retryable = False
                         self._event_ch.send_nowait(chat_chunk)
@@ -321,7 +286,7 @@ class LLMStream(llm.LLMStream):
                     usage = response.usage_metadata
                     self._event_ch.send_nowait(
                         llm.ChatChunk(
-                            request_id=request_id,
+                            id=request_id,
                             usage=llm.CompletionUsage(
                                 completion_tokens=usage.candidates_token_count or 0,
                                 prompt_tokens=usage.prompt_token_count or 0,
@@ -329,6 +294,7 @@ class LLMStream(llm.LLMStream):
                             ),
                         )
                     )
+
         except ClientError as e:
             raise APIStatusError(
                 "gemini llm: client error",
@@ -359,61 +325,25 @@ class LLMStream(llm.LLMStream):
                 retryable=retryable,
             ) from e
 
-    def _parse_part(
-        self, id: str, index: int, part: types.Part
-    ) -> llm.ChatChunk | None:
+    def _parse_part(self, id: str, part: types.Part) -> llm.ChatChunk | None:
         if part.function_call:
-            return self._try_build_function(id, index, part)
-
-        return llm.ChatChunk(
-            request_id=id,
-            choices=[
-                llm.Choice(
-                    delta=llm.ChoiceDelta(content=part.text, role="assistant"),
-                    index=index,
-                )
-            ],
-        )
-
-    def _try_build_function(
-        self, id: str, index: int, part: types.Part
-    ) -> llm.ChatChunk | None:
-        if part.function_call is None:
-            logger.warning("gemini llm: no function call in the response")
-            return None
-
-        if part.function_call.name is None:
-            logger.warning("gemini llm: no function name in the response")
-            return None
-
-        if part.function_call.id is None:
-            part.function_call.id = utils.shortuuid()
-
-        if self._fnc_ctx is None:
-            logger.warning(
-                "google stream tried to run function without function context"
+            chat_chunk = llm.ChatChunk(
+                id=id,
+                delta=llm.ChoiceDelta(
+                    role="assistant",
+                    tool_calls=[
+                        llm.FunctionToolCall(
+                            arguments=json.dumps(part.function_call.args),
+                            name=part.function_call.name,
+                            call_id=part.function_call.id or utils.shortuuid("function_call_"),
+                        )
+                    ],
+                    content=part.text,
+                ),
             )
-            return None
-
-        fnc_info = _create_ai_function_info(
-            self._fnc_ctx,
-            part.function_call.id,
-            part.function_call.name,
-            json.dumps(part.function_call.args),
-        )
-
-        self._function_calls_info.append(fnc_info)
+            return chat_chunk
 
         return llm.ChatChunk(
-            request_id=id,
-            choices=[
-                llm.Choice(
-                    delta=llm.ChoiceDelta(
-                        role="assistant",
-                        tool_calls=[fnc_info],
-                        content=part.text,
-                    ),
-                    index=index,
-                )
-            ],
+            id=id,
+            delta=llm.ChoiceDelta(content=part.text, role="assistant"),
         )
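
The 1.x chunk shape drops the `choices`/`index` list in favor of a single `delta`, with tool calls riding on the delta itself. A rough standalone equivalent of the new `_parse_part` logic, using plain dicts in place of the `livekit.agents.llm` models to stay self-contained:

```python
import json
import uuid


def parse_part(request_id: str, part) -> dict:
    """Map one Gemini response part to a chunk-shaped dict (illustrative only)."""
    if getattr(part, "function_call", None):
        call = part.function_call
        return {
            "id": request_id,
            "delta": {
                "role": "assistant",
                "content": part.text,
                "tool_calls": [
                    {
                        "name": call.name,
                        "arguments": json.dumps(call.args),
                        # Gemini may omit the call id; mint one so tool results
                        # can be correlated downstream
                        "call_id": call.id or f"function_call_{uuid.uuid4().hex[:12]}",
                    }
                ],
            },
        }
    return {"id": request_id, "delta": {"role": "assistant", "content": part.text}}
```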

livekit/plugins/google/stt.py

@@ -19,8 +19,14 @@ import dataclasses
 import time
 import weakref
 from dataclasses import dataclass
-from typing import Callable, List, Union
+from typing import Callable, Union
 
+from google.api_core.client_options import ClientOptions
+from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
+from google.auth import default as gauth_default
+from google.auth.exceptions import DefaultCredentialsError
+from google.cloud.speech_v2 import SpeechAsyncClient
+from google.cloud.speech_v2.types import cloud_speech
 from livekit import rtc
 from livekit.agents import (
     DEFAULT_API_CONNECT_OPTIONS,
@@ -32,18 +38,11 @@ from livekit.agents import (
     utils,
 )
 
-from google.api_core.client_options import ClientOptions
-from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
-from google.auth import default as gauth_default
-from google.auth.exceptions import DefaultCredentialsError
-from google.cloud.speech_v2 import SpeechAsyncClient
-from google.cloud.speech_v2.types import cloud_speech
-
 from .log import logger
 from .models import SpeechLanguages, SpeechModels
 
 LgType = Union[SpeechLanguages, str]
-LanguageCode = Union[LgType, List[LgType]]
+LanguageCode = Union[LgType, list[LgType]]
 
 # Google STT has a timeout of 5 mins, we'll attempt to restart the session
 # before that timeout is reached
@@ -56,14 +55,14 @@ _min_confidence = 0.65
 # This class is only be used internally to encapsulate the options
 @dataclass
 class STTOptions:
-    languages: List[LgType]
+    languages: list[LgType]
     detect_language: bool
     interim_results: bool
     punctuate: bool
     spoken_punctuation: bool
     model: SpeechModels | str
     sample_rate: int
-    keywords: List[tuple[str, float]] | None
+    keywords: list[tuple[str, float]] | None
 
     def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
         if self.keywords:
@@ -72,9 +71,7 @@ class STTOptions:
                     cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
                         inline_phrase_set=cloud_speech.PhraseSet(
                             phrases=[
-                                cloud_speech.PhraseSet.Phrase(
-                                    value=keyword, boost=boost
-                                )
+                                cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
                                 for keyword, boost in self.keywords
                             ]
                         )
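
`build_adaptation()` turns the `(keyword, boost)` pairs into an inline phrase set that biases recognition toward those phrases. A standalone sketch of the same construction (boost values are examples; Cloud Speech bounds the boost range, roughly 0-20):

```python
from google.cloud.speech_v2.types import cloud_speech

keywords = [("LiveKit", 10.0), ("Gemini", 5.0)]

adaptation = cloud_speech.SpeechAdaptation(
    phrase_sets=[
        cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
            inline_phrase_set=cloud_speech.PhraseSet(
                phrases=[
                    # a higher boost raises the likelihood the phrase is recognized
                    cloud_speech.PhraseSet.Phrase(value=keyword, boost=boost)
                    for keyword, boost in keywords
                ]
            )
        )
    ]
)
```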
@@ -98,7 +95,7 @@ class STT(stt.STT):
         sample_rate: int = 16000,
         credentials_info: dict | None = None,
         credentials_file: str | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         """
         Create a new instance of Google STT.
@@ -120,9 +117,7 @@ class STT(stt.STT):
             credentials_file(str): the credentials file to use for recognition (default: None)
             keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
         """
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
-        )
+        super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
 
         self._location = location
         self._credentials_info = credentials_info
@@ -163,9 +158,7 @@ class STT(stt.STT):
         client_options = None
         client: SpeechAsyncClient | None = None
         if self._location != "global":
-            client_options = ClientOptions(
-                api_endpoint=f"{self._location}-speech.googleapis.com"
-            )
+            client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
         if self._credentials_info:
             client = SpeechAsyncClient.from_service_account_info(
                 self._credentials_info,
@@ -206,9 +199,7 @@ class STT(stt.STT):
             config.languages = [config.languages]
         elif not config.detect_language:
             if len(config.languages) > 1:
-                logger.warning(
-                    "multiple languages provided, but language detection is disabled"
-                )
+                logger.warning("multiple languages provided, but language detection is disabled")
                 config.languages = [config.languages[0]]
 
         return config
@@ -266,7 +257,7 @@ class STT(stt.STT):
         *,
         language: SpeechLanguages | str | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> "SpeechStream":
+    ) -> SpeechStream:
         config = self._sanitize_options(language=language)
         stream = SpeechStream(
             stt=self,
@@ -288,7 +279,7 @@ class STT(stt.STT):
         spoken_punctuation: bool | None = None,
         model: SpeechModels | None = None,
         location: str | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         if languages is not None:
             if isinstance(languages, str):
@@ -337,9 +328,7 @@ class SpeechStream(stt.SpeechStream):
         recognizer_cb: Callable[[SpeechAsyncClient], str],
         config: STTOptions,
     ) -> None:
-        super().__init__(
-            stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
-        )
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=config.sample_rate)
 
         self._pool = pool
         self._recognizer_cb = recognizer_cb
@@ -356,7 +345,7 @@ class SpeechStream(stt.SpeechStream):
         punctuate: bool | None = None,
         spoken_punctuation: bool | None = None,
         model: SpeechModels | None = None,
-        keywords: List[tuple[str, float]] | None = None,
+        keywords: list[tuple[str, float]] | None = None,
     ):
         if languages is not None:
             if isinstance(languages, str):
@@ -380,9 +369,7 @@ class SpeechStream(stt.SpeechStream):
     async def _run(self) -> None:
         # google requires a async generator when calling streaming_recognize
         # this function basically convert the queue into a async generator
-        async def input_generator(
-            client: SpeechAsyncClient, should_stop: asyncio.Event
-        ):
+        async def input_generator(client: SpeechAsyncClient, should_stop: asyncio.Event):
             try:
                 # first request should contain the config
                 yield cloud_speech.StreamingRecognizeRequest(
@@ -398,14 +385,10 @@ class SpeechStream(stt.SpeechStream):
                         return
 
                     if isinstance(frame, rtc.AudioFrame):
-                        yield cloud_speech.StreamingRecognizeRequest(
-                            audio=frame.data.tobytes()
-                        )
+                        yield cloud_speech.StreamingRecognizeRequest(audio=frame.data.tobytes())
 
             except Exception:
-                logger.exception(
-                    "an error occurred while streaming input to google STT"
-                )
+                logger.exception("an error occurred while streaming input to google STT")
 
         async def process_stream(client: SpeechAsyncClient, stream):
             has_started = False
@@ -442,19 +425,14 @@ class SpeechStream(stt.SpeechStream):
                                 alternatives=[speech_data],
                             )
                         )
-                        if (
-                            time.time() - self._session_connected_at
-                            > _max_session_duration
-                        ):
+                        if time.time() - self._session_connected_at > _max_session_duration:
                             logger.debug(
                                 "Google STT maximum connection time reached. Reconnecting..."
                             )
                             self._pool.remove(client)
                             if has_started:
                                 self._event_ch.send_nowait(
-                                    stt.SpeechEvent(
-                                        type=stt.SpeechEventType.END_OF_SPEECH
-                                    )
+                                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                                 )
                                 has_started = False
                             self._reconnect_event.set()
@@ -499,12 +477,8 @@ class SpeechStream(stt.SpeechStream):
                     )
                     self._session_connected_at = time.time()
 
-                    process_stream_task = asyncio.create_task(
-                        process_stream(client, stream)
-                    )
-                    wait_reconnect_task = asyncio.create_task(
-                        self._reconnect_event.wait()
-                    )
+                    process_stream_task = asyncio.create_task(process_stream(client, stream))
+                    wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
 
                     try:
                         done, _ = await asyncio.wait(
@@ -518,9 +492,7 @@ class SpeechStream(stt.SpeechStream):
                                 break
                         self._reconnect_event.clear()
                     finally:
-                        await utils.aio.gracefully_cancel(
-                            process_stream_task, wait_reconnect_task
-                        )
+                        await utils.aio.gracefully_cancel(process_stream_task, wait_reconnect_task)
                         should_stop.set()
             except DeadlineExceeded:
                 raise APITimeoutError()
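
The hunks above implement session rotation: the stream-processing task is raced against a reconnect event, and whichever finishes first decides whether to exit or restart. A stripped-down sketch of the pattern (the hypothetical `worker` coroutine stands in for `process_stream` and sets the event when the session should rotate; the real code also recycles clients through a connection pool):

```python
import asyncio


async def run_with_reconnect(worker) -> None:
    reconnect_event = asyncio.Event()

    while True:
        worker_task = asyncio.create_task(worker(reconnect_event))
        wait_reconnect_task = asyncio.create_task(reconnect_event.wait())
        try:
            done, _ = await asyncio.wait(
                [worker_task, wait_reconnect_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
            if wait_reconnect_task not in done:
                worker_task.result()  # surface any error, then stop for good
                break
            reconnect_event.clear()  # reconnect requested: loop and start over
        finally:
            # cancel whichever task is still pending before the next iteration
            for task in (worker_task, wait_reconnect_task):
                task.cancel()
```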
@@ -581,8 +553,6 @@ def _streaming_recognize_response_to_speech_data(
     if text == "":
         return None
 
-    data = stt.SpeechData(
-        language=lg, start_time=0, end_time=0, confidence=confidence, text=text
-    )
+    data = stt.SpeechData(language=lg, start_time=0, end_time=0, confidence=confidence, text=text)
 
     return data