livekit-plugins-google 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -20,13 +20,17 @@ import os
  from dataclasses import dataclass
  from typing import Any, cast

- from google import genai
  from google.auth._default_async import default_async
- from google.genai import types
+ from google.genai import Client, types
  from google.genai.errors import APIError, ClientError, ServerError
  from livekit.agents import APIConnectionError, APIStatusError, llm, utils
- from livekit.agents.llm import FunctionTool, ToolChoice, utils as llm_utils
- from livekit.agents.llm.tool_context import get_function_info
+ from livekit.agents.llm import FunctionTool, RawFunctionTool, ToolChoice, utils as llm_utils
+ from livekit.agents.llm.tool_context import (
+     get_function_info,
+     get_raw_function_info,
+     is_function_tool,
+     is_raw_function_tool,
+ )
  from livekit.agents.types import (
      DEFAULT_API_CONNECT_OPTIONS,
      NOT_GIVEN,
@@ -37,7 +41,8 @@ from livekit.agents.utils import is_given

  from .log import logger
  from .models import ChatModels
- from .utils import to_chat_ctx, to_fnc_ctx, to_response_format
+ from .tools import _LLMTool
+ from .utils import create_tools_config, to_fnc_ctx, to_response_format


  @dataclass
@@ -54,6 +59,7 @@ class _LLMOptions:
      presence_penalty: NotGivenOr[float]
      frequency_penalty: NotGivenOr[float]
      thinking_config: NotGivenOr[types.ThinkingConfigOrDict]
+     gemini_tools: NotGivenOr[list[_LLMTool]]


  class LLM(llm.LLM):
@@ -73,6 +79,7 @@ class LLM(llm.LLM):
          frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
          tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
          thinking_config: NotGivenOr[types.ThinkingConfigOrDict] = NOT_GIVEN,
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
      ) -> None:
          """
          Create a new instance of Google GenAI LLM.
@@ -98,10 +105,11 @@ class LLM(llm.LLM):
              frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
              tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
              thinking_config (ThinkingConfigOrDict, optional): The thinking configuration for response generation. Defaults to None.
+             gemini_tools (list[LLMTool], optional): The Gemini-specific tools to use for the session.
          """ # noqa: E501
          super().__init__()
          gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
-         gcp_location = (
+         gcp_location: str | None = (
              location
              if is_given(location)
              else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
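The headline addition in 1.1.0 is the `gemini_tools` option, which attaches Gemini's built-in tools (search grounding, code execution, URL context, maps) to every request in the session. A minimal constructor sketch, assuming the plugin is imported as `livekit.plugins.google`; the model name is illustrative:

    from google.genai import types

    from livekit.plugins import google

    # Session-level built-in tools: GoogleSearch turns on search grounding
    # for every chat request issued through this LLM instance.
    session_llm = google.LLM(
        model="gemini-2.0-flash-001",  # illustrative model name
        gemini_tools=[types.GoogleSearch()],
    )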
@@ -115,7 +123,7 @@ class LLM(llm.LLM):

          if use_vertexai:
              if not gcp_project:
-                 _, gcp_project = default_async(
+                 _, gcp_project = default_async(  # type: ignore
                      scopes=["https://www.googleapis.com/auth/cloud-platform"]
                  )
              gemini_api_key = None  # VertexAI does not require an API key
@@ -157,8 +165,9 @@ class LLM(llm.LLM):
              presence_penalty=presence_penalty,
              frequency_penalty=frequency_penalty,
              thinking_config=thinking_config,
+             gemini_tools=gemini_tools,
          )
-         self._client = genai.Client(
+         self._client = Client(
              api_key=gemini_api_key,
              vertexai=use_vertexai,
              project=gcp_project,
@@ -169,7 +178,7 @@ class LLM(llm.LLM):
          self,
          *,
          chat_ctx: llm.ChatContext,
-         tools: list[FunctionTool] | None = None,
+         tools: list[FunctionTool | RawFunctionTool] | None = None,
          conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
          parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
          tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
@@ -177,50 +186,58 @@ class LLM(llm.LLM):
              types.SchemaUnion | type[llm_utils.ResponseFormatT]
          ] = NOT_GIVEN,
          extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
      ) -> LLMStream:
          extra = {}

          if is_given(extra_kwargs):
              extra.update(extra_kwargs)

-         tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
+         tool_choice = (
+             cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
+         )
          if is_given(tool_choice):
              gemini_tool_choice: types.ToolConfig
              if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
                  gemini_tool_choice = types.ToolConfig(
                      function_calling_config=types.FunctionCallingConfig(
-                         mode="ANY",
+                         mode=types.FunctionCallingConfigMode.ANY,
                          allowed_function_names=[tool_choice["function"]["name"]],
                      )
                  )
                  extra["tool_config"] = gemini_tool_choice
              elif tool_choice == "required":
+                 tool_names = []
+                 for tool in tools or []:
+                     if is_function_tool(tool):
+                         tool_names.append(get_function_info(tool).name)
+                     elif is_raw_function_tool(tool):
+                         tool_names.append(get_raw_function_info(tool).name)
+
                  gemini_tool_choice = types.ToolConfig(
                      function_calling_config=types.FunctionCallingConfig(
-                         mode="ANY",
-                         allowed_function_names=[get_function_info(fnc).name for fnc in tools]
-                         if tools
-                         else None,
+                         mode=types.FunctionCallingConfigMode.ANY,
+                         allowed_function_names=tool_names or None,
                      )
                  )
                  extra["tool_config"] = gemini_tool_choice
              elif tool_choice == "auto":
                  gemini_tool_choice = types.ToolConfig(
                      function_calling_config=types.FunctionCallingConfig(
-                         mode="AUTO",
+                         mode=types.FunctionCallingConfigMode.AUTO,
                      )
                  )
                  extra["tool_config"] = gemini_tool_choice
              elif tool_choice == "none":
                  gemini_tool_choice = types.ToolConfig(
                      function_calling_config=types.FunctionCallingConfig(
-                         mode="NONE",
+                         mode=types.FunctionCallingConfigMode.NONE,
                      )
                  )
                  extra["tool_config"] = gemini_tool_choice

          if is_given(response_format):
-             extra["response_schema"] = to_response_format(response_format)
+             extra["response_schema"] = to_response_format(response_format)  # type: ignore
              extra["response_mime_type"] = "application/json"

          if is_given(self._opts.temperature):
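Every branch above now builds its `types.ToolConfig` with the `FunctionCallingConfigMode` enum instead of bare strings, and `tool_choice="required"` collects names from both typed and raw tools. What the "required" branch produces, sketched with illustrative tool names:

    from google.genai import types

    # Equivalent of tool_choice="required" when the request exposes two
    # tools named "get_weather" and "lookup_order" (names illustrative):
    tool_config = types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            mode=types.FunctionCallingConfigMode.ANY,  # model must call a tool
            allowed_function_names=["get_weather", "lookup_order"],
        )
    )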
@@ -240,6 +257,8 @@
          if is_given(self._opts.thinking_config):
              extra["thinking_config"] = self._opts.thinking_config

+         gemini_tools = gemini_tools if is_given(gemini_tools) else self._opts.gemini_tools
+
          return LLMStream(
              self,
              client=self._client,
@@ -247,6 +266,7 @@
              chat_ctx=chat_ctx,
              tools=tools or [],
              conn_options=conn_options,
+             gemini_tools=gemini_tools,
              extra_kwargs=extra,
          )

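`chat()` now also takes a per-call `gemini_tools` override, falling back to the session-level value when not given. A usage sketch, assuming the `ChatContext.empty()` and `add_message()` helpers from livekit-agents 1.x and the `session_llm` instance from the earlier sketch:

    from google.genai import types
    from livekit.agents import llm as agents_llm

    ctx = agents_llm.ChatContext.empty()
    ctx.add_message(role="user", content="Summarize https://example.com")

    # Per-call override: this one request gets URL-context grounding even
    # if the session was constructed without any gemini_tools.
    stream = session_llm.chat(chat_ctx=ctx, gemini_tools=[types.UrlContext()])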
@@ -256,32 +276,45 @@ class LLMStream(llm.LLMStream):
          self,
          llm: LLM,
          *,
-         client: genai.Client,
+         client: Client,
          model: str | ChatModels,
          chat_ctx: llm.ChatContext,
          conn_options: APIConnectOptions,
-         tools: list[FunctionTool],
+         tools: list[FunctionTool | RawFunctionTool],
          extra_kwargs: dict[str, Any],
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
      ) -> None:
          super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
          self._client = client
          self._model = model
          self._llm: LLM = llm
          self._extra_kwargs = extra_kwargs
+         self._gemini_tools = gemini_tools

      async def _run(self) -> None:
          retryable = True
          request_id = utils.shortuuid()

          try:
-             turns, system_instruction = to_chat_ctx(self._chat_ctx, id(self._llm), generate=True)
+             turns_dict, extra_data = self._chat_ctx.to_provider_format(format="google")
+             turns = [types.Content.model_validate(turn) for turn in turns_dict]
              function_declarations = to_fnc_ctx(self._tools)
-             if function_declarations:
-                 self._extra_kwargs["tools"] = [
-                     types.Tool(function_declarations=function_declarations)
-                 ]
+             tools_config = create_tools_config(
+                 function_tools=function_declarations,
+                 gemini_tools=self._gemini_tools if is_given(self._gemini_tools) else None,
+             )
+             if tools_config:
+                 self._extra_kwargs["tools"] = tools_config
+
              config = types.GenerateContentConfig(
-                 system_instruction=system_instruction,
+                 system_instruction=(
+                     [types.Part(text=content) for content in extra_data.system_messages]
+                     if extra_data.system_messages
+                     else None
+                 ),
+                 http_options=types.HttpOptions(
+                     timeout=int(self._conn_options.timeout * 1000),
+                 ),
                  **self._extra_kwargs,
              )

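Two behavioral changes land in `_run`: the chat context now serializes itself through `to_provider_format(format="google")`, returning plain dicts plus extracted system messages that are re-validated into `types.Content` pydantic models, and the connection timeout is forwarded to Gemini as `HttpOptions.timeout` in milliseconds. The validation step in isolation:

    from google.genai import types

    # to_provider_format(format="google") yields dicts of this shape;
    # model_validate turns each one back into a typed Content object.
    turn = types.Content.model_validate(
        {"role": "user", "parts": [{"text": "Hello, Gemini"}]}
    )
    assert turn.parts[0].text == "Hello, Gemini"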
@@ -371,7 +404,7 @@
                  tool_calls=[
                      llm.FunctionToolCall(
                          arguments=json.dumps(part.function_call.args),
-                         name=part.function_call.name,
+                         name=part.function_call.name,  # type: ignore
                          call_id=part.function_call.id or utils.shortuuid("function_call_"),
                      )
                  ],
@@ -18,8 +18,9 @@ import asyncio
  import dataclasses
  import time
  import weakref
+ from collections.abc import AsyncGenerator, AsyncIterable
  from dataclasses import dataclass
- from typing import Callable, Union
+ from typing import Callable, Union, cast

  from google.api_core.client_options import ClientOptions
  from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
@@ -140,7 +141,7 @@ class STT(stt.STT):

          if not is_given(credentials_file) and not is_given(credentials_info):
              try:
-                 gauth_default()
+                 gauth_default()  # type: ignore
              except DefaultCredentialsError:
                  raise ValueError(
                      "Application default credentials must be available "
@@ -168,9 +169,10 @@
              connect_cb=self._create_client,
          )

-     async def _create_client(self) -> SpeechAsyncClient:
+     async def _create_client(self, timeout: float) -> SpeechAsyncClient:
          # Add support for passing a specific location that matches recognizer
          # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+         # TODO(long): how to set timeout?
          client_options = None
          client: SpeechAsyncClient | None = None
          if self._location != "global":
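The connection-pool contract changed here: the `connect_cb` now receives the per-attempt timeout (threaded through from the caller's connection options, as the later `pool.connection(timeout=...)` calls show), even though the TODO admits `SpeechAsyncClient` has no obvious knob to apply it to yet. A toy sketch of the new callback shape, not the plugin's actual pool implementation:

    import asyncio

    async def open_connection() -> str:
        await asyncio.sleep(0.01)  # stands in for real client setup
        return "client"

    # The connect callback now accepts the timeout so slow connects can be
    # bounded by the caller's connection options.
    async def connect_cb(timeout: float) -> str:
        return await asyncio.wait_for(open_connection(), timeout)

    asyncio.run(connect_cb(timeout=10.0))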
@@ -198,7 +200,7 @@
              except AttributeError:
                  from google.auth import default as ga_default

-                 _, project_id = ga_default()
+                 _, project_id = ga_default()  # type: ignore
          return f"projects/{project_id}/locations/{self._location}/recognizers/_"

      def _sanitize_options(self, *, language: NotGivenOr[str] = NOT_GIVEN) -> STTOptions:
@@ -243,7 +245,7 @@
          )

          try:
-             async with self._pool.connection() as client:
+             async with self._pool.connection(timeout=conn_options.timeout) as client:
                  raw = await client.recognize(
                      cloud_speech.RecognizeRequest(
                          recognizer=self._get_recognizer(client),
@@ -289,11 +291,11 @@
          model: NotGivenOr[SpeechModels] = NOT_GIVEN,
          location: NotGivenOr[str] = NOT_GIVEN,
          keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
-     ):
+     ) -> None:
          if is_given(languages):
              if isinstance(languages, str):
                  languages = [languages]
-             self._config.languages = languages
+             self._config.languages = cast(list[LgType], languages)
          if is_given(detect_language):
              self._config.detect_language = detect_language
          if is_given(interim_results):
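The `cast` here is purely for the type checker: at runtime it returns its argument unchanged, so the normalized list is stored as-is. A standalone illustration, with `LgType` standing in for the plugin's language alias:

    from typing import cast

    LgType = str  # stand-in for the plugin's language type alias

    languages: str | list[str] = "en-US"
    if isinstance(languages, str):
        languages = [languages]
    # No runtime conversion: cast() only narrows the type for the checker.
    normalized = cast(list[LgType], languages)
    assert normalized == ["en-US"]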
@@ -356,11 +358,11 @@ class SpeechStream(stt.SpeechStream):
          model: NotGivenOr[SpeechModels] = NOT_GIVEN,
          min_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
          keywords: NotGivenOr[list[tuple[str, float]]] = NOT_GIVEN,
-     ):
+     ) -> None:
          if is_given(languages):
              if isinstance(languages, str):
                  languages = [languages]
-             self._config.languages = languages
+             self._config.languages = cast(list[LgType], languages)
          if is_given(detect_language):
              self._config.detect_language = detect_language
          if is_given(interim_results):
@@ -381,7 +383,9 @@
      async def _run(self) -> None:
          # google requires a async generator when calling streaming_recognize
          # this function basically convert the queue into a async generator
-         async def input_generator(client: SpeechAsyncClient, should_stop: asyncio.Event):
+         async def input_generator(
+             client: SpeechAsyncClient, should_stop: asyncio.Event
+         ) -> AsyncGenerator[cloud_speech.StreamingRecognizeRequest, None]:
              try:
                  # first request should contain the config
                  yield cloud_speech.StreamingRecognizeRequest(
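The inner generator's new annotation uses `AsyncGenerator[YieldType, SendType]`; the send type is `None` because nothing is ever sent back into the generator. The same pattern in a standalone form:

    import asyncio
    from collections.abc import AsyncGenerator

    async def numbers(limit: int) -> AsyncGenerator[int, None]:
        # Yields ints; the second type parameter is the send type.
        for i in range(limit):
            yield i

    async def main() -> None:
        async for n in numbers(3):
            print(n)

    asyncio.run(main())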
@@ -402,7 +406,10 @@
              except Exception:
                  logger.exception("an error occurred while streaming input to google STT")

-         async def process_stream(client: SpeechAsyncClient, stream):
+         async def process_stream(
+             client: SpeechAsyncClient,
+             stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse],
+         ) -> None:
              has_started = False
              async for resp in stream:
                  if (
@@ -464,7 +471,7 @@

          while True:
              try:
-                 async with self._pool.connection(timeout=self._conn_options.timeout) as client:
                      self._streaming_config = cloud_speech.StreamingRecognitionConfig(
                          config=cloud_speech.RecognitionConfig(
                              explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
@@ -0,0 +1,11 @@
+ from typing import Union
+
+ from google.genai.types import (
+     GoogleMaps,
+     GoogleSearch,
+     GoogleSearchRetrieval,
+     ToolCodeExecution,
+     UrlContext,
+ )
+
+ _LLMTool = Union[GoogleSearchRetrieval, ToolCodeExecution, GoogleSearch, UrlContext, GoogleMaps]
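This new module (imported above as `from .tools import _LLMTool`) defines the union accepted by the `gemini_tools` parameters. Downstream, `create_tools_config` presumably wraps each entry in a `types.Tool`; built by hand, those wrappers would look like this, using the google-genai `Tool` model's dedicated fields:

    from google.genai import types

    # Each built-in maps onto its own field of types.Tool:
    search_tool = types.Tool(google_search=types.GoogleSearch())
    code_tool = types.Tool(code_execution=types.ToolCodeExecution())
    url_tool = types.Tool(url_context=types.UrlContext())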