livekit-plugins-google 0.11.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,45 +15,43 @@
 
  from __future__ import annotations
 
- import asyncio
  import json
  import os
  from dataclasses import dataclass
- from typing import Any, Literal, MutableSet, Union, cast
-
- from livekit.agents import (
-     APIConnectionError,
-     APIStatusError,
-     llm,
-     utils,
- )
- from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
- from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
+ from typing import Any, cast
 
  from google import genai
  from google.auth._default_async import default_async
  from google.genai import types
  from google.genai.errors import APIError, ClientError, ServerError
+ from livekit.agents import APIConnectionError, APIStatusError, llm, utils
+ from livekit.agents.llm import FunctionTool, ToolChoice, utils as llm_utils
+ from livekit.agents.types import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     NOT_GIVEN,
+     APIConnectOptions,
+     NotGivenOr,
+ )
+ from livekit.agents.utils import is_given
 
- from ._utils import _build_gemini_ctx, _build_tools
  from .log import logger
  from .models import ChatModels
+ from .utils import to_chat_ctx, to_fnc_ctx, to_response_format
 
 
  @dataclass
- class LLMOptions:
+ class _LLMOptions:
      model: ChatModels | str
-     temperature: float | None
-     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
-     vertexai: bool = False
-     project: str | None = None
-     location: str | None = None
-     candidate_count: int = 1
-     max_output_tokens: int | None = None
-     top_p: float | None = None
-     top_k: float | None = None
-     presence_penalty: float | None = None
-     frequency_penalty: float | None = None
+     temperature: NotGivenOr[float]
+     tool_choice: NotGivenOr[ToolChoice]
+     vertexai: NotGivenOr[bool]
+     project: NotGivenOr[str]
+     location: NotGivenOr[str]
+     max_output_tokens: NotGivenOr[int]
+     top_p: NotGivenOr[float]
+     top_k: NotGivenOr[float]
+     presence_penalty: NotGivenOr[float]
+     frequency_penalty: NotGivenOr[float]
 
 
  class LLM(llm.LLM):
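
The 0.11 options used plain `None` defaults; 1.0 switches every optional field to the `NotGivenOr`/`NOT_GIVEN` sentinel so an omitted argument can be distinguished from an explicitly passed value. A minimal sketch of the pattern, reusing only helpers this module already imports (the `resolve_project` function is illustrative, not part of the package):

```python
from __future__ import annotations

import os

from livekit.agents.types import NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given


def resolve_project(project: NotGivenOr[str] = NOT_GIVEN) -> str | None:
    # NOT_GIVEN means "the caller passed nothing", so the environment fallback
    # applies only when the argument was truly omitted -- the same pattern the
    # new __init__ uses for project, location, and api_key below.
    return project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
```
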
@@ -61,18 +59,17 @@ class LLM(llm.LLM):
          self,
          *,
          model: ChatModels | str = "gemini-2.0-flash-001",
-         api_key: str | None = None,
-         vertexai: bool = False,
-         project: str | None = None,
-         location: str | None = None,
-         candidate_count: int = 1,
-         temperature: float = 0.8,
-         max_output_tokens: int | None = None,
-         top_p: float | None = None,
-         top_k: float | None = None,
-         presence_penalty: float | None = None,
-         frequency_penalty: float | None = None,
-         tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
+         api_key: NotGivenOr[str] = NOT_GIVEN,
+         vertexai: NotGivenOr[bool] = False,
+         project: NotGivenOr[str] = NOT_GIVEN,
+         location: NotGivenOr[str] = NOT_GIVEN,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+         top_p: NotGivenOr[float] = NOT_GIVEN,
+         top_k: NotGivenOr[float] = NOT_GIVEN,
+         presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+         frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
      ) -> None:
          """
          Create a new instance of Google GenAI LLM.
@@ -90,55 +87,46 @@ class LLM(llm.LLM):
              vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
              project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
              location (str, optional): The location to use for VertexAI API requests. Defaults value is "us-central1".
-             candidate_count (int, optional): Number of candidate responses to generate. Defaults to 1.
              temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
              max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
              top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
              top_k (int, optional): The top-k sampling value for response generation. Defaults to None.
              presence_penalty (float, optional): Penalizes the model for generating previously mentioned concepts. Defaults to None.
              frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
-             tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
-         """
-         super().__init__(
-             capabilities=LLMCapabilities(
-                 supports_choices_on_int=False,
-                 requires_persistent_functions=False,
-             )
-         )
-         self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
-         self._location = location or os.environ.get(
-             "GOOGLE_CLOUD_LOCATION", "us-central1"
-         )
-         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY", None)
+             tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
+         """  # noqa: E501
+         super().__init__()
+         gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+         gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
+         gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
          _gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
          if _gac is None:
              logger.warning(
-                 "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file. Otherwise, use any of the other Google Cloud auth methods."
+                 "`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file. Otherwise, use any of the other Google Cloud auth methods."  # noqa: E501
              )
 
-         if vertexai:
-             if not self._project_id:
-                 _, self._project_id = default_async(
+         if is_given(vertexai) and vertexai:
+             if not gcp_project:
+                 _, gcp_project = default_async(
                      scopes=["https://www.googleapis.com/auth/cloud-platform"]
                  )
-             self._api_key = None  # VertexAI does not require an API key
+             gemini_api_key = None  # VertexAI does not require an API key
 
          else:
-             self._project_id = None
-             self._location = None
-             if not self._api_key:
+             gcp_project = None
+             gcp_location = None
+             if not gemini_api_key:
                  raise ValueError(
-                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                  )
 
-         self._opts = LLMOptions(
+         self._opts = _LLMOptions(
              model=model,
              temperature=temperature,
              tool_choice=tool_choice,
              vertexai=vertexai,
              project=project,
              location=location,
-             candidate_count=candidate_count,
              max_output_tokens=max_output_tokens,
              top_p=top_p,
              top_k=top_k,
@@ -146,46 +134,89 @@ class LLM(llm.LLM):
              frequency_penalty=frequency_penalty,
          )
          self._client = genai.Client(
-             api_key=self._api_key,
-             vertexai=vertexai,
-             project=self._project_id,
-             location=self._location,
+             api_key=gemini_api_key,
+             vertexai=is_given(vertexai) and vertexai,
+             project=gcp_project,
+             location=gcp_location,
          )
-         self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
 
      def chat(
          self,
          *,
          chat_ctx: llm.ChatContext,
+         tools: list[FunctionTool] | None = None,
          conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-         fnc_ctx: llm.FunctionContext | None = None,
-         temperature: float | None = None,
-         n: int | None = 1,
-         parallel_tool_calls: bool | None = None,
-         tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
-         | None = None,
-     ) -> "LLMStream":
-         if tool_choice is None:
-             tool_choice = self._opts.tool_choice
-
-         if temperature is None:
-             temperature = self._opts.temperature
+         parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
+         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+         response_format: NotGivenOr[
+             types.SchemaUnion | type[llm_utils.ResponseFormatT]
+         ] = NOT_GIVEN,
+         extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+     ) -> LLMStream:
+         extra = {}
+
+         if is_given(extra_kwargs):
+             extra.update(extra_kwargs)
+
+         tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
+         if is_given(tool_choice):
+             gemini_tool_choice: types.ToolConfig
+             if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode="ANY",
+                         allowed_function_names=[tool_choice["function"]["name"]],
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "required":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode="ANY",
+                         allowed_function_names=[fnc.name for fnc in tools],
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "auto":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode="AUTO",
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+             elif tool_choice == "none":
+                 gemini_tool_choice = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(
+                         mode="NONE",
+                     )
+                 )
+                 extra["tool_config"] = gemini_tool_choice
+
+         if is_given(response_format):
+             extra["response_schema"] = to_response_format(response_format)
+             extra["response_mime_type"] = "application/json"
+
+         if is_given(self._opts.temperature):
+             extra["temperature"] = self._opts.temperature
+         if is_given(self._opts.max_output_tokens):
+             extra["max_output_tokens"] = self._opts.max_output_tokens
+         if is_given(self._opts.top_p):
+             extra["top_p"] = self._opts.top_p
+         if is_given(self._opts.top_k):
+             extra["top_k"] = self._opts.top_k
+         if is_given(self._opts.presence_penalty):
+             extra["presence_penalty"] = self._opts.presence_penalty
+         if is_given(self._opts.frequency_penalty):
+             extra["frequency_penalty"] = self._opts.frequency_penalty
 
          return LLMStream(
              self,
              client=self._client,
              model=self._opts.model,
-             max_output_tokens=self._opts.max_output_tokens,
-             top_p=self._opts.top_p,
-             top_k=self._opts.top_k,
-             presence_penalty=self._opts.presence_penalty,
-             frequency_penalty=self._opts.frequency_penalty,
              chat_ctx=chat_ctx,
-             fnc_ctx=fnc_ctx,
+             tools=tools,
              conn_options=conn_options,
-             n=n,
-             temperature=temperature,
-             tool_choice=tool_choice,
+             extra_kwargs=extra,
          )
 
 
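
With 1.0, the tool_choice handling moves out of `LLMStream._run` and into `chat()`, where it is translated into a google-genai `ToolConfig` and stashed in the `extra` kwargs that are later splatted into `types.GenerateContentConfig`. A short sketch of the "call this specific function" case, using only the types shown in the diff (the `get_weather` name is made up for illustration):

```python
from google.genai import types

# A named-function tool_choice, in the dict form chat() checks for above.
tool_choice = {"type": "function", "function": {"name": "get_weather"}}

tool_config = types.ToolConfig(
    function_calling_config=types.FunctionCallingConfig(
        mode="ANY",  # force the model to call one of the allowed functions
        allowed_function_names=[tool_choice["function"]["name"]],
    )
)
# chat() stores this under extra["tool_config"]; _run() later expands the dict
# with types.GenerateContentConfig(system_instruction=..., **extra).
```
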
@@ -198,96 +229,38 @@ class LLMStream(llm.LLMStream):
          model: str | ChatModels,
          chat_ctx: llm.ChatContext,
          conn_options: APIConnectOptions,
-         fnc_ctx: llm.FunctionContext | None,
-         temperature: float | None,
-         n: int | None,
-         max_output_tokens: int | None,
-         top_p: float | None,
-         top_k: float | None,
-         presence_penalty: float | None,
-         frequency_penalty: float | None,
-         tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
+         tools: list[FunctionTool] | None,
+         extra_kwargs: dict[str, Any],
      ) -> None:
-         super().__init__(
-             llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
-         )
+         super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
          self._client = client
          self._model = model
          self._llm: LLM = llm
-         self._max_output_tokens = max_output_tokens
-         self._top_p = top_p
-         self._top_k = top_k
-         self._presence_penalty = presence_penalty
-         self._frequency_penalty = frequency_penalty
-         self._temperature = temperature
-         self._n = n
-         self._tool_choice = tool_choice
+         self._extra_kwargs = extra_kwargs
 
      async def _run(self) -> None:
          retryable = True
          request_id = utils.shortuuid()
 
          try:
-             opts: dict[str, Any] = dict()
-             turns, system_instruction = _build_gemini_ctx(self._chat_ctx, id(self))
-
-             if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
-                 functions = _build_tools(self._fnc_ctx)
-                 opts["tools"] = [types.Tool(function_declarations=functions)]
-
-             if self._tool_choice is not None:
-                 if isinstance(self._tool_choice, ToolChoice):
-                     # specific function
-                     tool_config = types.ToolConfig(
-                         function_calling_config=types.FunctionCallingConfig(
-                             mode=types.FunctionCallingConfigMode.ANY,
-                             allowed_function_names=[self._tool_choice.name],
-                         )
-                     )
-                 elif self._tool_choice == "required":
-                     # model must call any function
-                     tool_config = types.ToolConfig(
-                         function_calling_config=types.FunctionCallingConfig(
-                             mode=types.FunctionCallingConfigMode.ANY,
-                             allowed_function_names=[
-                                 fnc.name
-                                 for fnc in self._fnc_ctx.ai_functions.values()
-                             ],
-                         )
-                     )
-                 elif self._tool_choice == "auto":
-                     # model can call any function
-                     tool_config = types.ToolConfig(
-                         function_calling_config=types.FunctionCallingConfig(
-                             mode=types.FunctionCallingConfigMode.AUTO
-                         )
-                     )
-                 elif self._tool_choice == "none":
-                     # model cannot call any function
-                     tool_config = types.ToolConfig(
-                         function_calling_config=types.FunctionCallingConfig(
-                             mode=types.FunctionCallingConfigMode.NONE,
-                         )
-                     )
-                 opts["tool_config"] = tool_config
-
+             turns, system_instruction = to_chat_ctx(self._chat_ctx, id(self._llm))
+             function_declarations = to_fnc_ctx(self._tools)
+             if function_declarations:
+                 self._extra_kwargs["tools"] = [
+                     types.Tool(function_declarations=function_declarations)
+                 ]
              config = types.GenerateContentConfig(
-                 candidate_count=self._n,
-                 temperature=self._temperature,
-                 max_output_tokens=self._max_output_tokens,
-                 top_p=self._top_p,
-                 top_k=self._top_k,
-                 presence_penalty=self._presence_penalty,
-                 frequency_penalty=self._frequency_penalty,
                  system_instruction=system_instruction,
-                 **opts,
+                 **self._extra_kwargs,
              )
+
              stream = await self._client.aio.models.generate_content_stream(
                  model=self._model,
                  contents=cast(types.ContentListUnion, turns),
                  config=config,
              )
-             async for response in stream:  # type: ignore
+
+             async for response in stream:
                  if response.prompt_feedback:
                      raise APIStatusError(
                          response.prompt_feedback.json(),
@@ -308,11 +281,11 @@ class LLMStream(llm.LLMStream):
 
                  if len(response.candidates) > 1:
                      logger.warning(
-                         "gemini llm: there are multiple candidates in the response, returning response from the first one."
+                         "gemini llm: there are multiple candidates in the response, returning response from the first one."  # noqa: E501
                      )
 
-                 for index, part in enumerate(response.candidates[0].content.parts):
-                     chat_chunk = self._parse_part(request_id, index, part)
+                 for part in response.candidates[0].content.parts:
+                     chat_chunk = self._parse_part(request_id, part)
                      if chat_chunk is not None:
                          retryable = False
                          self._event_ch.send_nowait(chat_chunk)
@@ -321,7 +294,7 @@ class LLMStream(llm.LLMStream):
                  usage = response.usage_metadata
                  self._event_ch.send_nowait(
                      llm.ChatChunk(
-                         request_id=request_id,
+                         id=request_id,
                          usage=llm.CompletionUsage(
                              completion_tokens=usage.candidates_token_count or 0,
                              prompt_tokens=usage.prompt_token_count or 0,
@@ -329,11 +302,12 @@ class LLMStream(llm.LLMStream):
                          ),
                      )
                  )
+
          except ClientError as e:
              raise APIStatusError(
                  "gemini llm: client error",
                  status_code=e.code,
-                 body=e.message,
+                 body=e.message + e.status,
                  request_id=request_id,
                  retryable=False if e.code != 429 else True,
              ) from e
@@ -341,7 +315,7 @@ class LLMStream(llm.LLMStream):
              raise APIStatusError(
                  "gemini llm: server error",
                  status_code=e.code,
-                 body=e.message,
+                 body=e.message + e.status,
                  request_id=request_id,
                  retryable=retryable,
              ) from e
@@ -349,71 +323,35 @@ class LLMStream(llm.LLMStream):
              raise APIStatusError(
                  "gemini llm: api error",
                  status_code=e.code,
-                 body=e.message,
+                 body=e.message + e.status,
                  request_id=request_id,
                  retryable=retryable,
              ) from e
          except Exception as e:
              raise APIConnectionError(
-                 "gemini llm: error generating content",
+                 f"gemini llm: error generating content {str(e)}",
                  retryable=retryable,
              ) from e
 
-     def _parse_part(
-         self, id: str, index: int, part: types.Part
-     ) -> llm.ChatChunk | None:
+     def _parse_part(self, id: str, part: types.Part) -> llm.ChatChunk | None:
          if part.function_call:
-             return self._try_build_function(id, index, part)
-
-         return llm.ChatChunk(
-             request_id=id,
-             choices=[
-                 llm.Choice(
-                     delta=llm.ChoiceDelta(content=part.text, role="assistant"),
-                     index=index,
-                 )
-             ],
-         )
-
-     def _try_build_function(
-         self, id: str, index: int, part: types.Part
-     ) -> llm.ChatChunk | None:
-         if part.function_call is None:
-             logger.warning("gemini llm: no function call in the response")
-             return None
-
-         if part.function_call.name is None:
-             logger.warning("gemini llm: no function name in the response")
-             return None
-
-         if part.function_call.id is None:
-             part.function_call.id = utils.shortuuid()
-
-         if self._fnc_ctx is None:
-             logger.warning(
-                 "google stream tried to run function without function context"
+             chat_chunk = llm.ChatChunk(
+                 id=id,
+                 delta=llm.ChoiceDelta(
+                     role="assistant",
+                     tool_calls=[
+                         llm.FunctionToolCall(
+                             arguments=json.dumps(part.function_call.args),
+                             name=part.function_call.name,
+                             call_id=part.function_call.id or utils.shortuuid("function_call_"),
+                         )
+                     ],
+                     content=part.text,
+                 ),
              )
-             return None
-
-         fnc_info = _create_ai_function_info(
-             self._fnc_ctx,
-             part.function_call.id,
-             part.function_call.name,
-             json.dumps(part.function_call.args),
-         )
-
-         self._function_calls_info.append(fnc_info)
+             return chat_chunk
 
          return llm.ChatChunk(
-             request_id=id,
-             choices=[
-                 llm.Choice(
-                     delta=llm.ChoiceDelta(
-                         role="assistant",
-                         tool_calls=[fnc_info],
-                         content=part.text,
-                     ),
-                     index=index,
-                 )
-             ],
+             id=id,
+             delta=llm.ChoiceDelta(content=part.text, role="assistant"),
          )
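
The chunk shape changes with 1.0: `request_id` and the `choices[...].delta`/`index` nesting are replaced by a top-level `id` and a single `delta`, and function calls arrive as `llm.FunctionToolCall` entries on that delta. A hedged sketch of what a consumer of the rewritten `_parse_part` output might look like (it assumes `LLMStream` is async-iterable and that usage-only chunks carry no delta, which this diff does not itself show):

```python
async def collect_text(stream) -> str:
    # Iterate the 1.0-style chunks emitted via LLMStream._event_ch.
    parts: list[str] = []
    async for chunk in stream:
        if chunk.delta is None:
            continue  # e.g. the trailing usage-only chunk
        if chunk.delta.content:
            parts.append(chunk.delta.content)
        for call in chunk.delta.tool_calls or []:
            # Fields set by _parse_part above: name, arguments (JSON), call_id.
            print(f"tool call {call.call_id}: {call.name}({call.arguments})")
    return "".join(parts)
```
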