livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,501 @@
+ # Copyright 2023 LiveKit, Inc.
+ #
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from dataclasses import dataclass
+ from typing import Any, cast
+
+ from google.auth._default_async import default_async
+ from google.genai import Client, types
+ from google.genai.errors import APIError, ClientError, ServerError
+ from livekit.agents import APIConnectionError, APIStatusError, llm, utils
+ from livekit.agents.llm import FunctionTool, RawFunctionTool, ToolChoice, utils as llm_utils
+ from livekit.agents.llm.tool_context import (
+     get_function_info,
+     get_raw_function_info,
+     is_function_tool,
+     is_raw_function_tool,
+ )
+ from livekit.agents.types import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     NOT_GIVEN,
+     APIConnectOptions,
+     NotGivenOr,
+ )
+ from livekit.agents.utils import is_given
+
+ from .log import logger
+ from .models import ChatModels
+ from .tools import _LLMTool
+ from .utils import create_tools_config, to_fnc_ctx, to_response_format
+ from .version import __version__
+
+
+ @dataclass
+ class _LLMOptions:
+     model: ChatModels | str
+     temperature: NotGivenOr[float]
+     tool_choice: NotGivenOr[ToolChoice]
+     vertexai: NotGivenOr[bool]
+     project: NotGivenOr[str]
+     location: NotGivenOr[str]
+     max_output_tokens: NotGivenOr[int]
+     top_p: NotGivenOr[float]
+     top_k: NotGivenOr[float]
+     presence_penalty: NotGivenOr[float]
+     frequency_penalty: NotGivenOr[float]
+     thinking_config: NotGivenOr[types.ThinkingConfigOrDict]
+     automatic_function_calling_config: NotGivenOr[types.AutomaticFunctionCallingConfigOrDict]
+     gemini_tools: NotGivenOr[list[_LLMTool]]
+     http_options: NotGivenOr[types.HttpOptions]
+     seed: NotGivenOr[int]
+     safety_settings: NotGivenOr[list[types.SafetySettingOrDict]]
+
+
+ BLOCKED_REASONS = [
+     types.FinishReason.SAFETY,
+     types.FinishReason.SPII,
+     types.FinishReason.PROHIBITED_CONTENT,
+     types.FinishReason.BLOCKLIST,
+     types.FinishReason.LANGUAGE,
+     types.FinishReason.RECITATION,
+ ]
+
+
+ class LLM(llm.LLM):
+     def __init__(
+         self,
+         *,
+         model: ChatModels | str = "gemini-2.0-flash-001",
+         api_key: NotGivenOr[str] = NOT_GIVEN,
+         vertexai: NotGivenOr[bool] = NOT_GIVEN,
+         project: NotGivenOr[str] = NOT_GIVEN,
+         location: NotGivenOr[str] = NOT_GIVEN,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+         top_p: NotGivenOr[float] = NOT_GIVEN,
+         top_k: NotGivenOr[float] = NOT_GIVEN,
+         presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+         frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+         thinking_config: NotGivenOr[types.ThinkingConfigOrDict] = NOT_GIVEN,
+         automatic_function_calling_config: NotGivenOr[
+             types.AutomaticFunctionCallingConfigOrDict
+         ] = NOT_GIVEN,
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
+         http_options: NotGivenOr[types.HttpOptions] = NOT_GIVEN,
+         seed: NotGivenOr[int] = NOT_GIVEN,
+         safety_settings: NotGivenOr[list[types.SafetySettingOrDict]] = NOT_GIVEN,
+     ) -> None:
+         """
+         Create a new instance of Google GenAI LLM.
+
+         Environment Requirements:
+         - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file or use any of the other Google Cloud auth methods.
+             The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
+             `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
+             and the location defaults to "us-central1".
+         - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
+
+         Args:
+             model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-001".
+             api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable.
+             vertexai (bool, optional): Whether to use VertexAI. If not provided, it attempts to read from the `GOOGLE_GENAI_USE_VERTEXAI` environment variable. Defaults to False.
+             project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
+             location (str, optional): The location to use for VertexAI API requests. Defaults value is "us-central1".
+             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
+             max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
+             top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
+             top_k (int, optional): The top-k sampling value for response generation. Defaults to None.
+             presence_penalty (float, optional): Penalizes the model for generating previously mentioned concepts. Defaults to None.
+             frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
+             tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
+             thinking_config (ThinkingConfigOrDict, optional): The thinking configuration for response generation. Defaults to None.
+             automatic_function_calling_config (AutomaticFunctionCallingConfigOrDict, optional): The automatic function calling configuration for response generation. Defaults to None.
+             gemini_tools (list[LLMTool], optional): The Gemini-specific tools to use for the session.
+             http_options (HttpOptions, optional): The HTTP options to use for the session.
+             seed (int, optional): Random seed for reproducible generation. Defaults to None.
+             safety_settings (list[SafetySettingOrDict], optional): Safety settings for content filtering. Defaults to None.
+         """  # noqa: E501
+         super().__init__()
+         gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+         gcp_location: str | None = (
+             location
+             if is_given(location)
+             else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
+         )
+         use_vertexai = (
+             vertexai
+             if is_given(vertexai)
+             else os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "0").lower() in ["true", "1"]
+         )
+         gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+
+         if use_vertexai:
+             if not gcp_project:
+                 _, gcp_project = default_async(  # type: ignore
+                     scopes=["https://www.googleapis.com/auth/cloud-platform"]
+                 )
+             if not gcp_project or not gcp_location:
+                 raise ValueError(
+                     "Project is required for VertexAI via project kwarg or GOOGLE_CLOUD_PROJECT environment variable"  # noqa: E501
+                 )
+             gemini_api_key = None  # VertexAI does not require an API key
+
+         else:
+             gcp_project = None
+             gcp_location = None
+             if not gemini_api_key:
+                 raise ValueError(
+                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
+                 )
+
+         # Validate thinking_config
+         if is_given(thinking_config):
+             _thinking_budget = None
+             if isinstance(thinking_config, dict):
+                 _thinking_budget = thinking_config.get("thinking_budget")
+             elif isinstance(thinking_config, types.ThinkingConfig):
+                 _thinking_budget = thinking_config.thinking_budget
+
+             if _thinking_budget is not None:
+                 if not isinstance(_thinking_budget, int):
+                     raise ValueError("thinking_budget inside thinking_config must be an integer")
+
+         self._opts = _LLMOptions(
+             model=model,
+             temperature=temperature,
+             tool_choice=tool_choice,
+             vertexai=use_vertexai,
+             project=project,
+             location=location,
+             max_output_tokens=max_output_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             presence_penalty=presence_penalty,
+             frequency_penalty=frequency_penalty,
+             thinking_config=thinking_config,
+             automatic_function_calling_config=automatic_function_calling_config,
+             gemini_tools=gemini_tools,
+             http_options=http_options,
+             seed=seed,
+             safety_settings=safety_settings,
+         )
+         self._client = Client(
+             api_key=gemini_api_key,
+             vertexai=use_vertexai,
+             project=gcp_project,
+             location=gcp_location,
+         )
+
+     @property
+     def model(self) -> str:
+         return self._opts.model
+
+     @property
+     def provider(self) -> str:
+         if self._client.vertexai:
+             return "Vertex AI"
+         else:
+             return "Gemini"
+
+     def chat(
+         self,
+         *,
+         chat_ctx: llm.ChatContext,
+         tools: list[FunctionTool | RawFunctionTool] | None = None,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+         parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
+         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+         response_format: NotGivenOr[
+             types.SchemaUnion | type[llm_utils.ResponseFormatT]
+         ] = NOT_GIVEN,
+         extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
+     ) -> LLMStream:
+         extra = {}
+
+         if is_given(extra_kwargs):
+             extra.update(extra_kwargs)
+
+         tool_choice = (
+             cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
+         )
+         if is_given(tool_choice):
+             gemini_tool_choice: types.ToolConfig
+             if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode=types.FunctionCallingConfigMode.ANY,
+                         allowed_function_names=[tool_choice["function"]["name"]],
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "required":
+                 tool_names = []
+                 for tool in tools or []:
+                     if is_function_tool(tool):
+                         tool_names.append(get_function_info(tool).name)
+                     elif is_raw_function_tool(tool):
+                         tool_names.append(get_raw_function_info(tool).name)
+
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode=types.FunctionCallingConfigMode.ANY,
+                         allowed_function_names=tool_names or None,
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "auto":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode=types.FunctionCallingConfigMode.AUTO,
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "none":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode=types.FunctionCallingConfigMode.NONE,
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+
+         if is_given(response_format):
+             extra["response_schema"] = to_response_format(response_format)  # type: ignore
+             extra["response_mime_type"] = "application/json"
+
+         if is_given(self._opts.temperature):
+             extra["temperature"] = self._opts.temperature
+         if is_given(self._opts.max_output_tokens):
+             extra["max_output_tokens"] = self._opts.max_output_tokens
+         if is_given(self._opts.top_p):
+             extra["top_p"] = self._opts.top_p
+         if is_given(self._opts.top_k):
+             extra["top_k"] = self._opts.top_k
+         if is_given(self._opts.presence_penalty):
+             extra["presence_penalty"] = self._opts.presence_penalty
+         if is_given(self._opts.frequency_penalty):
+             extra["frequency_penalty"] = self._opts.frequency_penalty
+         if is_given(self._opts.seed):
+             extra["seed"] = self._opts.seed
+
+         # Add thinking config if thinking_budget is provided
+         if is_given(self._opts.thinking_config):
+             extra["thinking_config"] = self._opts.thinking_config
+
+         if is_given(self._opts.automatic_function_calling_config):
+             extra["automatic_function_calling"] = self._opts.automatic_function_calling_config
+
+         if is_given(self._opts.safety_settings):
+             extra["safety_settings"] = self._opts.safety_settings
+
+         gemini_tools = gemini_tools if is_given(gemini_tools) else self._opts.gemini_tools
+
+         return LLMStream(
+             self,
+             client=self._client,
+             model=self._opts.model,
+             chat_ctx=chat_ctx,
+             tools=tools or [],
+             conn_options=conn_options,
+             gemini_tools=gemini_tools,
+             extra_kwargs=extra,
+         )
+
+
+ class LLMStream(llm.LLMStream):
+     def __init__(
+         self,
+         llm: LLM,
+         *,
+         client: Client,
+         model: str | ChatModels,
+         chat_ctx: llm.ChatContext,
+         conn_options: APIConnectOptions,
+         tools: list[FunctionTool | RawFunctionTool],
+         extra_kwargs: dict[str, Any],
+         gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
+     ) -> None:
+         super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
+         self._client = client
+         self._model = model
+         self._llm: LLM = llm
+         self._extra_kwargs = extra_kwargs
+         self._gemini_tools = gemini_tools
+
+     async def _run(self) -> None:
+         retryable = True
+         request_id = utils.shortuuid()
+
+         try:
+             turns_dict, extra_data = self._chat_ctx.to_provider_format(format="google")
+             turns = [types.Content.model_validate(turn) for turn in turns_dict]
+             function_declarations = to_fnc_ctx(self._tools)
+             tools_config = create_tools_config(
+                 function_tools=function_declarations,
+                 gemini_tools=self._gemini_tools if is_given(self._gemini_tools) else None,
+             )
+             if tools_config:
+                 self._extra_kwargs["tools"] = tools_config
+             http_options = self._llm._opts.http_options or types.HttpOptions(
+                 timeout=int(self._conn_options.timeout * 1000)
+             )
+             if not http_options.headers:
+                 http_options.headers = {}
+             http_options.headers["x-goog-api-client"] = f"livekit-agents/{__version__}"
+             config = types.GenerateContentConfig(
+                 system_instruction=(
+                     [types.Part(text=content) for content in extra_data.system_messages]
+                     if extra_data.system_messages
+                     else None
+                 ),
+                 http_options=http_options,
+                 **self._extra_kwargs,
+             )
+             stream = await self._client.aio.models.generate_content_stream(
+                 model=self._model,
+                 contents=cast(types.ContentListUnion, turns),
+                 config=config,
+             )
+
+             async for response in stream:
+                 if response.prompt_feedback:
+                     raise APIStatusError(
+                         response.prompt_feedback.json(),
+                         retryable=False,
+                         request_id=request_id,
+                     )
+
+                 if (
+                     not response.candidates
+                     or not response.candidates[0].content
+                     or not response.candidates[0].content.parts
+                 ):
+                     logger.warning(f"no content in the response: {response}")
+                     raise APIStatusError(
+                         "no content in the response",
+                         retryable=True,
+                         request_id=request_id,
+                     )
+
+                 if len(response.candidates) > 1:
+                     logger.warning(
+                         "gemini llm: there are multiple candidates in the response, returning response from the first one."  # noqa: E501
+                     )
+
+                 candidate = response.candidates[0]
+
+                 if candidate.finish_reason in BLOCKED_REASONS:
+                     raise APIStatusError(
+                         f"generation blocked by gemini: {candidate.finish_reason}",
+                         retryable=False,
+                         request_id=request_id,
+                     )
+
+                 if not candidate.content or not candidate.content.parts:
+                     raise APIStatusError(
+                         "no content in the response",
+                         retryable=retryable,
+                         request_id=request_id,
+                     )
+
+                 chunks_yielded = False
+                 for part in candidate.content.parts:
+                     chat_chunk = self._parse_part(request_id, part)
+                     if chat_chunk is not None:
+                         chunks_yielded = True
+                         retryable = False
+                         self._event_ch.send_nowait(chat_chunk)
+
+                 if candidate.finish_reason == types.FinishReason.STOP and not chunks_yielded:
+                     raise APIStatusError(
+                         "no response generated",
+                         retryable=retryable,
+                         request_id=request_id,
+                     )
+
+                 if response.usage_metadata is not None:
+                     usage = response.usage_metadata
+                     self._event_ch.send_nowait(
+                         llm.ChatChunk(
+                             id=request_id,
+                             usage=llm.CompletionUsage(
+                                 completion_tokens=usage.candidates_token_count or 0,
+                                 prompt_tokens=usage.prompt_token_count or 0,
+                                 prompt_cached_tokens=usage.cached_content_token_count or 0,
+                                 total_tokens=usage.total_token_count or 0,
+                             ),
+                         )
+                     )
+
+         except ClientError as e:
+             raise APIStatusError(
+                 "gemini llm: client error",
+                 status_code=e.code,
+                 body=f"{e.message} {e.status}",
+                 request_id=request_id,
+                 retryable=False if e.code != 429 else True,
+             ) from e
+         except ServerError as e:
+             raise APIStatusError(
+                 "gemini llm: server error",
+                 status_code=e.code,
+                 body=f"{e.message} {e.status}",
+                 request_id=request_id,
+                 retryable=retryable,
+             ) from e
+         except APIError as e:
+             raise APIStatusError(
+                 "gemini llm: api error",
+                 status_code=e.code,
+                 body=f"{e.message} {e.status}",
+                 request_id=request_id,
+                 retryable=retryable,
+             ) from e
+         except Exception as e:
+             raise APIConnectionError(
+                 f"gemini llm: error generating content {str(e)}",
+                 retryable=retryable,
+             ) from e
+
+     def _parse_part(self, id: str, part: types.Part) -> llm.ChatChunk | None:
+         if part.function_call:
+             chat_chunk = llm.ChatChunk(
+                 id=id,
+                 delta=llm.ChoiceDelta(
+                     role="assistant",
+                     tool_calls=[
+                         llm.FunctionToolCall(
+                             arguments=json.dumps(part.function_call.args),
+                             name=part.function_call.name,
+                             call_id=part.function_call.id or utils.shortuuid("function_call_"),
+                         )
+                     ],
+                     content=part.text,
+                 ),
+             )
+             return chat_chunk
+
+         if not part.text:
+             return None
+
+         return llm.ChatChunk(
+             id=id,
+             delta=llm.ChoiceDelta(content=part.text, role="assistant"),
+         )
@@ -0,0 +1,3 @@
+ import logging
+
+ logger = logging.getLogger("livekit.plugins.google")
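The new module above wires Google's google-genai client into the livekit-agents llm.LLM / llm.LLMStream interfaces, and its constructor docstring describes the two auth paths: a GOOGLE_API_KEY for the Gemini API, or Google Cloud credentials plus project/location for Vertex AI. The sketch below shows how this release's LLM class might be instantiated and streamed from directly. It is not part of the diff; the ChatContext.empty() / add_message() calls and the async-with / async-for streaming pattern are assumptions based on the livekit-agents 1.x API rather than anything this package ships — only LLM(...) and LLM.chat() come from the file above.

import asyncio

from livekit.agents import llm as agents_llm
from livekit.plugins.google import LLM


async def main() -> None:
    # Gemini API path: api_key falls back to the GOOGLE_API_KEY environment variable.
    # For Vertex AI, pass vertexai=True and a project/location (or set
    # GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION) instead of an API key.
    gemini_llm = LLM(model="gemini-2.0-flash-001", temperature=0.8)

    # ChatContext.empty() and add_message() are assumed from the livekit-agents 1.x API.
    chat_ctx = agents_llm.ChatContext.empty()
    chat_ctx.add_message(role="user", content="Say hello in one short sentence.")

    # chat() returns an LLMStream; iterating it yields llm.ChatChunk objects whose
    # delta carries incremental text (a final chunk carries usage metadata instead).
    async with gemini_llm.chat(chat_ctx=chat_ctx) as stream:
        async for chunk in stream:
            if chunk.delta and chunk.delta.content:
                print(chunk.delta.content, end="", flush=True)


asyncio.run(main())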