livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +25 -7
- livekit/plugins/google/beta/__init__.py +13 -0
- livekit/plugins/google/beta/gemini_tts.py +258 -0
- livekit/plugins/google/llm.py +501 -0
- livekit/plugins/google/log.py +3 -0
- livekit/plugins/google/models.py +145 -31
- livekit/plugins/google/realtime/__init__.py +9 -0
- livekit/plugins/google/realtime/api_proto.py +66 -0
- livekit/plugins/google/realtime/realtime_api.py +1252 -0
- livekit/plugins/google/stt.py +518 -272
- livekit/plugins/google/tools.py +11 -0
- livekit/plugins/google/tts.py +447 -0
- livekit/plugins/google/utils.py +286 -0
- livekit/plugins/google/version.py +1 -1
- livekit_plugins_google-1.3.8.dist-info/METADATA +63 -0
- livekit_plugins_google-1.3.8.dist-info/RECORD +18 -0
- {livekit_plugins_google-0.3.0.dist-info → livekit_plugins_google-1.3.8.dist-info}/WHEEL +1 -2
- livekit_plugins_google-0.3.0.dist-info/METADATA +0 -47
- livekit_plugins_google-0.3.0.dist-info/RECORD +0 -9
- livekit_plugins_google-0.3.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from typing import Any, cast
|
|
22
|
+
|
|
23
|
+
from google.auth._default_async import default_async
|
|
24
|
+
from google.genai import Client, types
|
|
25
|
+
from google.genai.errors import APIError, ClientError, ServerError
|
|
26
|
+
from livekit.agents import APIConnectionError, APIStatusError, llm, utils
|
|
27
|
+
from livekit.agents.llm import FunctionTool, RawFunctionTool, ToolChoice, utils as llm_utils
|
|
28
|
+
from livekit.agents.llm.tool_context import (
|
|
29
|
+
get_function_info,
|
|
30
|
+
get_raw_function_info,
|
|
31
|
+
is_function_tool,
|
|
32
|
+
is_raw_function_tool,
|
|
33
|
+
)
|
|
34
|
+
from livekit.agents.types import (
|
|
35
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
|
36
|
+
NOT_GIVEN,
|
|
37
|
+
APIConnectOptions,
|
|
38
|
+
NotGivenOr,
|
|
39
|
+
)
|
|
40
|
+
from livekit.agents.utils import is_given
|
|
41
|
+
|
|
42
|
+
from .log import logger
|
|
43
|
+
from .models import ChatModels
|
|
44
|
+
from .tools import _LLMTool
|
|
45
|
+
from .utils import create_tools_config, to_fnc_ctx, to_response_format
|
|
46
|
+
from .version import __version__
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class _LLMOptions:
    """Settings captured by ``LLM.__init__`` and read by ``LLM.chat`` / ``LLMStream``.

    Optional fields hold ``NOT_GIVEN`` when the caller did not set them; ``LLM.chat`` only
    forwards values for which ``is_given`` is true.
    """

    model: ChatModels | str
    temperature: NotGivenOr[float]
    # Instance default; a per-call tool_choice passed to LLM.chat takes precedence.
    tool_choice: NotGivenOr[ToolChoice]
    # LLM.__init__ stores the resolved flag (argument or GOOGLE_GENAI_USE_VERTEXAI env var).
    vertexai: NotGivenOr[bool]
    # LLM.__init__ stores the raw `project` argument here, not the env/ADC-resolved value
    # that is handed to the genai Client.
    project: NotGivenOr[str]
    # Likewise the raw `location` argument, not the resolved/defaulted location.
    location: NotGivenOr[str]
    max_output_tokens: NotGivenOr[int]
    top_p: NotGivenOr[float]
    top_k: NotGivenOr[float]
    presence_penalty: NotGivenOr[float]
    frequency_penalty: NotGivenOr[float]
    thinking_config: NotGivenOr[types.ThinkingConfigOrDict]
    # Forwarded to GenerateContentConfig under the key "automatic_function_calling".
    automatic_function_calling_config: NotGivenOr[types.AutomaticFunctionCallingConfigOrDict]
    # Used by LLM.chat when the call does not supply its own gemini_tools.
    gemini_tools: NotGivenOr[list[_LLMTool]]
    # When set, LLMStream uses it as the base HttpOptions for each request.
    http_options: NotGivenOr[types.HttpOptions]
    seed: NotGivenOr[int]
    safety_settings: NotGivenOr[list[types.SafetySettingOrDict]]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Candidate finish reasons that LLMStream._run reports as a non-retryable
# "generation blocked by gemini" APIStatusError.
BLOCKED_REASONS = [
    types.FinishReason.SAFETY,
    types.FinishReason.SPII,
    types.FinishReason.PROHIBITED_CONTENT,
    types.FinishReason.BLOCKLIST,
    types.FinishReason.LANGUAGE,
    types.FinishReason.RECITATION,
]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class LLM(llm.LLM):
    def __init__(
        self,
        *,
        model: ChatModels | str = "gemini-2.0-flash-001",
        api_key: NotGivenOr[str] = NOT_GIVEN,
        vertexai: NotGivenOr[bool] = NOT_GIVEN,
        project: NotGivenOr[str] = NOT_GIVEN,
        location: NotGivenOr[str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        top_k: NotGivenOr[float] = NOT_GIVEN,
        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        thinking_config: NotGivenOr[types.ThinkingConfigOrDict] = NOT_GIVEN,
        automatic_function_calling_config: NotGivenOr[
            types.AutomaticFunctionCallingConfigOrDict
        ] = NOT_GIVEN,
        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
        http_options: NotGivenOr[types.HttpOptions] = NOT_GIVEN,
        seed: NotGivenOr[int] = NOT_GIVEN,
        safety_settings: NotGivenOr[list[types.SafetySettingOrDict]] = NOT_GIVEN,
    ) -> None:
        """
        Create a new instance of Google GenAI LLM.

        Environment Requirements:
        - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file or use any of the other Google Cloud auth methods.
        The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
        `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. If neither provides a project, it is taken from the Application Default Credentials,
        and the location defaults to "us-central1".
        - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.

        Generation settings below that are left unset are not sent to the API at all, so the service-side defaults apply.

        Args:
            model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-001".
            api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable.
            vertexai (bool, optional): Whether to use VertexAI. If not provided, it attempts to read from the `GOOGLE_GENAI_USE_VERTEXAI` environment variable ("true"/"1"). Defaults to False.
            project (str, optional): The Google Cloud project to use (only for VertexAI). Falls back to `GOOGLE_CLOUD_PROJECT`, then to the Application Default Credentials project.
            location (str, optional): The location to use for VertexAI API requests. Falls back to `GOOGLE_CLOUD_LOCATION`, then "us-central1".
            temperature (float, optional): Sampling temperature for response generation. Not sent when unset.
            max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Not sent when unset.
            top_p (float, optional): The nucleus sampling probability for response generation. Not sent when unset.
            top_k (float, optional): The top-k sampling value for response generation. Not sent when unset.
            presence_penalty (float, optional): Penalizes the model for generating previously mentioned concepts. Not sent when unset.
            frequency_penalty (float, optional): Penalizes the model for repeating words. Not sent when unset.
            tool_choice (ToolChoice, optional): Default tool choice ("auto", "required", "none" or a specific function). When unset, no `tool_config` is sent. Can be overridden per `chat()` call.
            thinking_config (ThinkingConfigOrDict, optional): The thinking configuration for response generation. Its `thinking_budget`, if present, must be an integer. Not sent when unset.
            automatic_function_calling_config (AutomaticFunctionCallingConfigOrDict, optional): The automatic function calling configuration for response generation. Not sent when unset.
            gemini_tools (list[LLMTool], optional): The Gemini-specific tools to use when a `chat()` call does not pass its own.
            http_options (HttpOptions, optional): Base HTTP options for each request. When unset, a request timeout is derived from the call's `conn_options`.
            seed (int, optional): Random seed for reproducible generation. Not sent when unset.
            safety_settings (list[SafetySettingOrDict], optional): Safety settings for content filtering. Not sent when unset.

        Raises:
            ValueError: If VertexAI is selected and no project can be resolved, if the Gemini API is selected and no API key is available, or if `thinking_config.thinking_budget` is not an integer.
        """  # noqa: E501
        super().__init__()
        gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
        gcp_location: str | None = (
            location
            if is_given(location)
            else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
        )
        use_vertexai = (
            vertexai
            if is_given(vertexai)
            else os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "0").lower() in ["true", "1"]
        )
        gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")

        if use_vertexai:
            if not gcp_project:
                # Fall back to the project bound to Application Default Credentials.
                _, gcp_project = default_async(  # type: ignore
                    scopes=["https://www.googleapis.com/auth/cloud-platform"]
                )
            if not gcp_project or not gcp_location:
                raise ValueError(
                    "Project is required for VertexAI via project kwarg or GOOGLE_CLOUD_PROJECT environment variable"  # noqa: E501
                )
            gemini_api_key = None  # VertexAI does not require an API key

        else:
            gcp_project = None
            gcp_location = None
            if not gemini_api_key:
                raise ValueError(
                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                )

        # Validate thinking_config
        if is_given(thinking_config):
            _thinking_budget = None
            if isinstance(thinking_config, dict):
                _thinking_budget = thinking_config.get("thinking_budget")
            elif isinstance(thinking_config, types.ThinkingConfig):
                _thinking_budget = thinking_config.thinking_budget

            # bool is a subclass of int, so reject it explicitly (True is not a token budget).
            if _thinking_budget is not None and (
                isinstance(_thinking_budget, bool) or not isinstance(_thinking_budget, int)
            ):
                raise ValueError("thinking_budget inside thinking_config must be an integer")

        self._opts = _LLMOptions(
            model=model,
            temperature=temperature,
            tool_choice=tool_choice,
            vertexai=use_vertexai,
            project=project,
            location=location,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            thinking_config=thinking_config,
            automatic_function_calling_config=automatic_function_calling_config,
            gemini_tools=gemini_tools,
            http_options=http_options,
            seed=seed,
            safety_settings=safety_settings,
        )
        self._client = Client(
            api_key=gemini_api_key,
            vertexai=use_vertexai,
            project=gcp_project,
            location=gcp_location,
        )

    @property
    def model(self) -> str:
        return self._opts.model

    @property
    def provider(self) -> str:
        return "Vertex AI" if self._client.vertexai else "Gemini"

    @staticmethod
    def _build_tool_config(
        tool_choice: ToolChoice,
        tools: list[FunctionTool | RawFunctionTool] | None,
    ) -> types.ToolConfig | None:
        """Translate a LiveKit ``ToolChoice`` into a Gemini ``ToolConfig``.

        Returns None for values with no Gemini mapping; the caller then sends no tool_config.
        """
        if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
            # Force a call to exactly the named function.
            return types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.ANY,
                    allowed_function_names=[tool_choice["function"]["name"]],
                )
            )
        if tool_choice == "required":
            # Force a call to any of the supplied tools (no restriction if none are named).
            tool_names = []
            for tool in tools or []:
                if is_function_tool(tool):
                    tool_names.append(get_function_info(tool).name)
                elif is_raw_function_tool(tool):
                    tool_names.append(get_raw_function_info(tool).name)
            return types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.ANY,
                    allowed_function_names=tool_names or None,
                )
            )
        if tool_choice == "auto":
            return types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.AUTO,
                )
            )
        if tool_choice == "none":
            return types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.NONE,
                )
            )
        return None

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[FunctionTool | RawFunctionTool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
        response_format: NotGivenOr[
            types.SchemaUnion | type[llm_utils.ResponseFormatT]
        ] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
    ) -> LLMStream:
        """Start a streaming completion for ``chat_ctx``.

        ``extra_kwargs`` become ``GenerateContentConfig`` fields; instance-level generation
        settings that are set overwrite same-named keys from it. Per-call ``tool_choice`` and
        ``gemini_tools`` override the instance defaults. ``parallel_tool_calls`` is accepted
        for interface compatibility but is not forwarded to Gemini.
        """
        extra: dict[str, Any] = {}

        if is_given(extra_kwargs):
            extra.update(extra_kwargs)

        tool_choice = (
            cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
        )
        if is_given(tool_choice):
            tool_config = self._build_tool_config(tool_choice, tools)
            if tool_config is not None:
                extra["tool_config"] = tool_config

        if is_given(response_format):
            extra["response_schema"] = to_response_format(response_format)  # type: ignore
            extra["response_mime_type"] = "application/json"

        # Instance options whose _LLMOptions name matches the GenerateContentConfig field.
        for key in (
            "temperature",
            "max_output_tokens",
            "top_p",
            "top_k",
            "presence_penalty",
            "frequency_penalty",
            "seed",
            "thinking_config",
            "safety_settings",
        ):
            value = getattr(self._opts, key)
            if is_given(value):
                extra[key] = value

        # The config field is named differently from the option.
        if is_given(self._opts.automatic_function_calling_config):
            extra["automatic_function_calling"] = self._opts.automatic_function_calling_config

        gemini_tools = gemini_tools if is_given(gemini_tools) else self._opts.gemini_tools

        return LLMStream(
            self,
            client=self._client,
            model=self._opts.model,
            chat_ctx=chat_ctx,
            tools=tools or [],
            conn_options=conn_options,
            gemini_tools=gemini_tools,
            extra_kwargs=extra,
        )
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class LLMStream(llm.LLMStream):
    def __init__(
        self,
        llm: LLM,
        *,
        client: Client,
        model: str | ChatModels,
        chat_ctx: llm.ChatContext,
        conn_options: APIConnectOptions,
        tools: list[FunctionTool | RawFunctionTool],
        extra_kwargs: dict[str, Any],
        gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
    ) -> None:
        """Streaming Gemini completion.

        ``extra_kwargs`` are passed as fields of ``types.GenerateContentConfig``; ``_run``
        adds the ``tools`` entry to it when any tools are configured.
        """
        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
        self._client = client
        self._model = model
        self._llm: LLM = llm
        self._extra_kwargs = extra_kwargs
        self._gemini_tools = gemini_tools

    async def _run(self) -> None:
        """Stream one Gemini completion into ``self._event_ch``.

        Failures are mapped to ``APIStatusError`` / ``APIConnectionError``. Once any content
        or tool-call chunk has been emitted, every error is reported as non-retryable, because
        a retry would replay output the consumer has already received.
        """
        request_id = utils.shortuuid()
        # Becomes True once a content/tool-call chunk has been sent downstream (for the
        # whole stream, not per streamed response).
        output_emitted = False

        try:
            turns_dict, extra_data = self._chat_ctx.to_provider_format(format="google")
            turns = [types.Content.model_validate(turn) for turn in turns_dict]
            function_declarations = to_fnc_ctx(self._tools)
            tools_config = create_tools_config(
                function_tools=function_declarations,
                gemini_tools=self._gemini_tools if is_given(self._gemini_tools) else None,
            )
            if tools_config:
                self._extra_kwargs["tools"] = tools_config

            base_http_options = self._llm._opts.http_options
            if base_http_options:
                # Shallow copy so the caller's HttpOptions object is not mutated per request.
                http_options = base_http_options.model_copy()
            else:
                http_options = types.HttpOptions(
                    timeout=int(self._conn_options.timeout * 1000)
                )
            # Build a new headers dict rather than writing into one that may be shared.
            http_options.headers = {
                **(http_options.headers or {}),
                "x-goog-api-client": f"livekit-agents/{__version__}",
            }

            config = types.GenerateContentConfig(
                system_instruction=(
                    [types.Part(text=content) for content in extra_data.system_messages]
                    if extra_data.system_messages
                    else None
                ),
                http_options=http_options,
                **self._extra_kwargs,
            )
            stream = await self._client.aio.models.generate_content_stream(
                model=self._model,
                contents=cast(types.ContentListUnion, turns),
                config=config,
            )

            async for response in stream:
                if response.prompt_feedback:
                    raise APIStatusError(
                        response.prompt_feedback.model_dump_json(),
                        retryable=False,
                        request_id=request_id,
                    )

                candidates = response.candidates or []
                if len(candidates) > 1:
                    logger.warning(
                        "gemini llm: there are multiple candidates in the response, returning response from the first one."  # noqa: E501
                    )
                candidate = candidates[0] if candidates else None

                # Must run before the empty-content check: blocked candidates usually have
                # no parts, and a block must never be retried.
                if candidate is not None and candidate.finish_reason in BLOCKED_REASONS:
                    raise APIStatusError(
                        f"generation blocked by gemini: {candidate.finish_reason}",
                        retryable=False,
                        request_id=request_id,
                    )

                parts = (
                    candidate.content.parts
                    if candidate is not None and candidate.content is not None
                    else None
                )
                if not parts:
                    if not output_emitted:
                        logger.warning("no content in the response: %s", response)
                        raise APIStatusError(
                            "no content in the response",
                            retryable=True,
                            request_id=request_id,
                        )
                    # Trailing chunk after output was already streamed (e.g. only a finish
                    # reason or usage): nothing to parse, but usage is still reported below.
                    parts = []

                for part in parts:
                    chat_chunk = self._parse_part(request_id, part)
                    if chat_chunk is not None:
                        output_emitted = True
                        self._event_ch.send_nowait(chat_chunk)

                if (
                    candidate is not None
                    and candidate.finish_reason == types.FinishReason.STOP
                    and not output_emitted
                ):
                    # The model finished without producing anything usable in the whole stream.
                    raise APIStatusError(
                        "no response generated",
                        retryable=True,
                        request_id=request_id,
                    )

                usage = response.usage_metadata
                if usage is not None:
                    self._event_ch.send_nowait(
                        llm.ChatChunk(
                            id=request_id,
                            usage=llm.CompletionUsage(
                                completion_tokens=usage.candidates_token_count or 0,
                                prompt_tokens=usage.prompt_token_count or 0,
                                prompt_cached_tokens=usage.cached_content_token_count or 0,
                                total_tokens=usage.total_token_count or 0,
                            ),
                        )
                    )

        except (APIStatusError, APIConnectionError):
            # Already classified above; without this, the generic handler at the bottom would
            # re-wrap e.g. a non-retryable block as a retryable connection error.
            raise
        except ClientError as e:
            raise APIStatusError(
                "gemini llm: client error",
                status_code=e.code,
                body=f"{e.message} {e.status}",
                request_id=request_id,
                # Only rate limiting is worth retrying, and only if nothing was emitted yet.
                retryable=e.code == 429 and not output_emitted,
            ) from e
        except ServerError as e:
            raise APIStatusError(
                "gemini llm: server error",
                status_code=e.code,
                body=f"{e.message} {e.status}",
                request_id=request_id,
                retryable=not output_emitted,
            ) from e
        except APIError as e:
            raise APIStatusError(
                "gemini llm: api error",
                status_code=e.code,
                body=f"{e.message} {e.status}",
                request_id=request_id,
                retryable=not output_emitted,
            ) from e
        except Exception as e:
            raise APIConnectionError(
                f"gemini llm: error generating content {str(e)}",
                retryable=not output_emitted,
            ) from e

    def _parse_part(self, id: str, part: types.Part) -> llm.ChatChunk | None:
        """Convert one Gemini ``Part`` into a ``ChatChunk``.

        Returns a tool-call chunk for function calls, a text chunk for non-empty text, and
        None when the part carries nothing to emit.
        """
        function_call = part.function_call
        if function_call:
            return llm.ChatChunk(
                id=id,
                delta=llm.ChoiceDelta(
                    role="assistant",
                    tool_calls=[
                        llm.FunctionToolCall(
                            # args may be None for zero-argument calls; json.dumps(None) would
                            # produce "null", which does not decode to an arguments object.
                            arguments=json.dumps(function_call.args or {}),
                            name=function_call.name,
                            call_id=function_call.id or utils.shortuuid("function_call_"),
                        )
                    ],
                    content=part.text,
                ),
            )

        if not part.text:
            return None

        return llm.ChatChunk(
            id=id,
            delta=llm.ChoiceDelta(content=part.text, role="assistant"),
        )
|