google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +0 -5
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +6 -1
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +172 -99
- google/adk/cli/cli_tools_click.py +147 -64
- google/adk/cli/fast_api.py +330 -148
- google/adk/cli/fast_api.py.orig +174 -80
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -2
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +54 -15
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +13 -5
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +9 -2
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +77 -21
- google/adk/models/llm_response.py +14 -2
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +65 -41
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +58 -65
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +9 -9
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
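Several modules are renamed or removed between the two wheels (for example, toolbox_tool.py is superseded by toolbox_toolset.py, google_api_tool_set.py and google_api_tool_sets.py by google_api_toolset.py and google_api_toolsets.py, and built_in_code_execution_tool.py becomes the private _built_in_code_execution_tool.py), so imports that target the old module paths stop resolving after an upgrade. The snippet below is a hypothetical helper, not part of either wheel; it only uses module paths taken from the listing above to report which side of each rename is importable in the current environment.

# Hypothetical upgrade check (not shipped with google-adk): reports which of
# the renamed module paths from the listing above resolve in the installed
# environment. Requires google-adk to be installed; the public symbols inside
# the new modules may differ from the old ones, so this only checks presence.
import importlib.util

RENAMED_MODULES = {
    "google.adk.tools.toolbox_tool": "google.adk.tools.toolbox_toolset",
    "google.adk.tools.google_api_tool.google_api_tool_set": (
        "google.adk.tools.google_api_tool.google_api_toolset"
    ),
    "google.adk.tools.built_in_code_execution_tool": (
        "google.adk.tools._built_in_code_execution_tool"
    ),
}

for old, new in RENAMED_MODULES.items():
  old_state = "present" if importlib.util.find_spec(old) else "missing"
  new_state = "present" if importlib.util.find_spec(new) else "missing"
  print(f"{old}: {old_state} -> {new}: {new_state}")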
google/adk/models/google_llm.py.orig
ADDED
@@ -0,0 +1,305 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import contextlib
+from functools import cached_property
+import logging
+import sys
+from typing import AsyncGenerator
+from typing import cast
+from typing import TYPE_CHECKING
+
+from google.genai import Client
+from google.genai import types
+from typing_extensions import override
+
+from .. import version
+from .base_llm import BaseLlm
+from .base_llm_connection import BaseLlmConnection
+from .gemini_llm_connection import GeminiLlmConnection
+from .llm_response import LlmResponse
+
+if TYPE_CHECKING:
+  from .llm_request import LlmRequest
+
+logger = None
+
+_NEW_LINE = '\n'
+_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
+
+
+class Gemini(BaseLlm):
+  """Integration for Gemini models.
+
+  Attributes:
+    model: The name of the Gemini model.
+  """
+
+  model: str = 'gemini-1.5-flash'
+
+  @staticmethod
+  @override
+  def supported_models() -> list[str]:
+    """Provides the list of supported models.
+
+    Returns:
+      A list of supported models.
+    """
+
+    return [
+        r'gemini-.*',
+        # fine-tuned vertex endpoint pattern
+        r'projects\/.+\/locations\/.+\/endpoints\/.+',
+        # vertex gemini long name
+        r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
+    ]
+
+  async def generate_content_async(
+      self, llm_request: LlmRequest, stream: bool = False
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Sends a request to the Gemini model.
+
+    Args:
+      llm_request: LlmRequest, the request to send to the Gemini model.
+      stream: bool = False, whether to do streaming call.
+
+    Yields:
+      LlmResponse: The model response.
+    """
+
+    self._maybe_append_user_content(llm_request)
+
+    global logger
+    if not logger:
+      logger = logging.getLogger(__name__)
+
+    logger.info(
+        'Sending out request, model: %s, backend: %s, stream: %s',
+        llm_request.model,
+        self._api_backend,
+        stream,
+    )
+    logger.info(_build_request_log(llm_request))
+
+    print('********* Jack --> ')
+    for hh in logging.root.handlers:
+      print(hh, hh.level)
+    for hh in logger.handlers:
+      print(hh, hh.level)
+    print('********* Jack <-- ')
+
+    if stream:
+      responses = await self.api_client.aio.models.generate_content_stream(
+          model=llm_request.model,
+          contents=llm_request.contents,
+          config=llm_request.config,
+      )
+      response = None
+      text = ''
+      # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
+      # we need to mark those text content as partial and after all partial
+      # contents are sent, we send an accumulated event which contains all the
+      # previous partial content. The only difference is bidi rely on
+      # complete_turn flag to detect end while sse depends on finish_reason.
+      async for response in responses:
+        logger.info(_build_response_log(response))
+        llm_response = LlmResponse.create(response)
+        if (
+            llm_response.content
+            and llm_response.content.parts
+            and llm_response.content.parts[0].text
+        ):
+          text += llm_response.content.parts[0].text
+          llm_response.partial = True
+        elif text and (
+            not llm_response.content
+            or not llm_response.content.parts
+            # don't yield the merged text event when receiving audio data
+            or not llm_response.content.parts[0].inline_data
+        ):
+          yield LlmResponse(
+              content=types.ModelContent(
+                  parts=[types.Part.from_text(text=text)],
+              ),
+              usage_metadata=llm_response.usage_metadata,
+          )
+          text = ''
+        yield llm_response
+      if (
+          text
+          and response
+          and response.candidates
+          and response.candidates[0].finish_reason == types.FinishReason.STOP
+      ):
+        yield LlmResponse(
+            content=types.ModelContent(
+                parts=[types.Part.from_text(text=text)],
+            ),
+        )
+
+    else:
+      response = await self.api_client.aio.models.generate_content(
+          model=llm_request.model,
+          contents=llm_request.contents,
+          config=llm_request.config,
+      )
+      logger.info(_build_response_log(response))
+      yield LlmResponse.create(response)
+
+  @cached_property
+  def api_client(self) -> Client:
+    """Provides the api client.
+
+    Returns:
+      The api client.
+    """
+    return Client(
+        http_options=types.HttpOptions(headers=self._tracking_headers)
+    )
+
+  @cached_property
+  def _api_backend(self) -> str:
+    return 'vertex' if self.api_client.vertexai else 'ml_dev'
+
+  @cached_property
+  def _tracking_headers(self) -> dict[str, str]:
+    framework_label = f'google-adk/{version.__version__}'
+    language_label = 'gl-python/' + sys.version.split()[0]
+    version_header_value = f'{framework_label} {language_label}'
+    tracking_headers = {
+        'x-goog-api-client': version_header_value,
+        'user-agent': version_header_value,
+    }
+    return tracking_headers
+
+  @cached_property
+  def _live_api_client(self) -> Client:
+    if self._api_backend == 'vertex':
+      # use beta version for vertex api
+      api_version = 'v1beta1'
+      # use default api version for vertex
+      return Client(
+          http_options=types.HttpOptions(
+              headers=self._tracking_headers, api_version=api_version
+          )
+      )
+    else:
+      # use v1alpha for ml_dev
+      api_version = 'v1alpha'
+      return Client(
+          http_options=types.HttpOptions(
+              headers=self._tracking_headers, api_version=api_version
+          )
+      )
+
+  @contextlib.asynccontextmanager
+  async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
+    """Connects to the Gemini model and returns an llm connection.
+
+    Args:
+      llm_request: LlmRequest, the request to send to the Gemini model.
+
+    Yields:
+      BaseLlmConnection, the connection to the Gemini model.
+    """
+
+    llm_request.live_connect_config.system_instruction = types.Content(
+        role='system',
+        parts=[
+            types.Part.from_text(text=llm_request.config.system_instruction)
+        ],
+    )
+    llm_request.live_connect_config.tools = llm_request.config.tools
+    async with self._live_api_client.aio.live.connect(
+        model=llm_request.model, config=llm_request.live_connect_config
+    ) as live_session:
+      yield GeminiLlmConnection(live_session)
+
+
+def _build_function_declaration_log(
+    func_decl: types.FunctionDeclaration,
+) -> str:
+  param_str = '{}'
+  if func_decl.parameters and func_decl.parameters.properties:
+    param_str = str({
+        k: v.model_dump(exclude_none=True)
+        for k, v in func_decl.parameters.properties.items()
+    })
+  return_str = 'None'
+  if func_decl.response:
+    return_str = str(func_decl.response.model_dump(exclude_none=True))
+  return f'{func_decl.name}: {param_str} -> {return_str}'
+
+
+def _build_request_log(req: LlmRequest) -> str:
+  function_decls: list[types.FunctionDeclaration] = cast(
+      list[types.FunctionDeclaration],
+      req.config.tools[0].function_declarations if req.config.tools else [],
+  )
+  function_logs = (
+      [
+          _build_function_declaration_log(func_decl)
+          for func_decl in function_decls
+      ]
+      if function_decls
+      else []
+  )
+  contents_logs = [
+      content.model_dump_json(
+          exclude_none=True,
+          exclude={
+              'parts': {
+                  i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
+              }
+          },
+      )
+      for content in req.contents
+  ]
+
+  return f"""
+LLM Request:
+-----------------------------------------------------------
+System Instruction:
+{req.config.system_instruction}
+-----------------------------------------------------------
+Contents:
+{_NEW_LINE.join(contents_logs)}
+-----------------------------------------------------------
+Functions:
+{_NEW_LINE.join(function_logs)}
+-----------------------------------------------------------
+"""
+
+
+def _build_response_log(resp: types.GenerateContentResponse) -> str:
+  function_calls_text = []
+  if function_calls := resp.function_calls:
+    for func_call in function_calls:
+      function_calls_text.append(
+          f'name: {func_call.name}, args: {func_call.args}'
+      )
+  return f"""
+LLM Response:
+-----------------------------------------------------------
+Text:
+{resp.text}
+-----------------------------------------------------------
+Function calls:
+{_NEW_LINE.join(function_calls_text)}
+-----------------------------------------------------------
+Raw response:
+{resp.model_dump_json(exclude_none=True)}
+-----------------------------------------------------------
+"""
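The comment in the streaming branch of generate_content_async above describes the SSE aggregation scheme: each text chunk is surfaced immediately as a partial LlmResponse, and once the stream reports a STOP finish reason a merged response carrying the accumulated text is emitted. A minimal standalone sketch of that accumulate-then-flush pattern (illustrative only, not ADK code) follows.

# Standalone sketch of the accumulate-then-flush pattern described above
# (not ADK code): partial text chunks are surfaced as they arrive, and one
# merged event is emitted once the stream signals completion.
import asyncio
from dataclasses import dataclass


@dataclass
class Chunk:
  text: str
  done: bool = False  # stands in for finish_reason == STOP in the real flow


async def fake_stream():
  for piece in ("Hel", "lo, ", "world"):
    yield Chunk(text=piece)
  yield Chunk(text="", done=True)


async def main():
  accumulated = ""
  async for chunk in fake_stream():
    if chunk.text:
      accumulated += chunk.text
      print("partial:", chunk.text)      # analogous to LlmResponse(partial=True)
    if chunk.done and accumulated:
      print("aggregated:", accumulated)  # analogous to the merged LlmResponse
      accumulated = ""


asyncio.run(main())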
google/adk/models/lite_llm.py
CHANGED
@@ -51,7 +51,7 @@ from .base_llm import BaseLlm
 from .llm_request import LlmRequest
 from .llm_response import LlmResponse

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("google_adk." + __name__)

 _NEW_LINE = "\n"
 _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
   text: str


+class UsageMetadataChunk(BaseModel):
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LiteLLMClient:
   """Provides acompletion method (for better testability)."""

@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
 def _model_response_to_chunk(
     response: ModelResponse,
 ) -> Generator[
-    Tuple[
+    Tuple[
+        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
+        Optional[str],
+    ],
+    None,
+    None,
 ]:
-  """Converts a litellm message to text or
+  """Converts a litellm message to text, function or usage metadata chunk.

   Args:
     response: The response from the model.

   Yields:
-    A tuple of text or function chunk and finish reason.
+    A tuple of text or function or usage metadata chunk and finish reason.
   """

   message = None
@@ -384,11 +395,21 @@ def _model_response_to_chunk(
   if not message:
     yield None, None

+  # Ideally usage would be expected with the last ModelResponseStream with a
+  # finish_reason set. But this is not the case we are observing from litellm.
+  # So we are sending it as a separate chunk to be set on the llm_response.
+  if response.get("usage", None):
+    yield UsageMetadataChunk(
+        prompt_tokens=response["usage"].get("prompt_tokens", 0),
+        completion_tokens=response["usage"].get("completion_tokens", 0),
+        total_tokens=response["usage"].get("total_tokens", 0),
+    ), None
+

 def _model_response_to_generate_content_response(
     response: ModelResponse,
 ) -> LlmResponse:
-  """Converts a litellm response to LlmResponse.
+  """Converts a litellm response to LlmResponse. Also adds usage metadata.

   Args:
     response: The model response.
@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(

   if not message:
     raise ValueError("No message in response")
-
+
+  llm_response = _message_to_generate_content_response(message)
+  if response.get("usage", None):
+    llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
+        prompt_token_count=response["usage"].get("prompt_tokens", 0),
+        candidates_token_count=response["usage"].get("completion_tokens", 0),
+        total_token_count=response["usage"].get("total_tokens", 0),
+    )
+  return llm_response


 def _message_to_generate_content_response(
@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
       function_args = ""
       function_id = None
       completion_args["stream"] = True
+      aggregated_llm_response = None
+      aggregated_llm_response_with_tool_call = None
+      usage_metadata = None
+
      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
@@ -645,32 +678,55 @@ class LiteLlm(BaseLlm):
                 ),
                 is_partial=True,
             )
+          elif isinstance(chunk, UsageMetadataChunk):
+            usage_metadata = types.GenerateContentResponseUsageMetadata(
+                prompt_token_count=chunk.prompt_tokens,
+                candidates_token_count=chunk.completion_tokens,
+                total_token_count=chunk.total_tokens,
+            )
+
           if finish_reason == "tool_calls" and function_id:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            aggregated_llm_response_with_tool_call = (
+                _message_to_generate_content_response(
+                    ChatCompletionAssistantMessage(
+                        role="assistant",
+                        content="",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                type="function",
+                                id=function_id,
+                                function=Function(
+                                    name=function_name,
+                                    arguments=function_args,
+                                ),
+                            )
+                        ],
+                    )
                 )
             )
             function_name = ""
             function_args = ""
             function_id = None
           elif finish_reason == "stop" and text:
-
+            aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
             )
             text = ""

+      # waiting until streaming ends to yield the llm_response as litellm tends
+      # to send chunk that contains usage_metadata after the chunk with
+      # finish_reason set to tool_calls or stop.
+      if aggregated_llm_response:
+        if usage_metadata:
+          aggregated_llm_response.usage_metadata = usage_metadata
+          usage_metadata = None
+        yield aggregated_llm_response
+
+      if aggregated_llm_response_with_tool_call:
+        if usage_metadata:
+          aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
+        yield aggregated_llm_response_with_tool_call
+
     else:
       response = await self.llm_client.acompletion(**completion_args)
       yield _model_response_to_generate_content_response(response)
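The lite_llm.py changes above introduce UsageMetadataChunk to carry litellm's token counts and translate them into google.genai's GenerateContentResponseUsageMetadata, mapping prompt_tokens to prompt_token_count and completion_tokens to candidates_token_count. A small sketch of that mapping, assuming the google-genai package is installed:

# Sketch of the field mapping the diff above introduces (assumes google-genai
# is installed); litellm-style usage counts are carried in UsageMetadataChunk
# and translated to genai's usage metadata type. Values are made up.
from google.genai import types
from pydantic import BaseModel


class UsageMetadataChunk(BaseModel):  # mirrors the class added in lite_llm.py
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


chunk = UsageMetadataChunk(prompt_tokens=12, completion_tokens=34, total_tokens=46)
usage = types.GenerateContentResponseUsageMetadata(
    prompt_token_count=chunk.prompt_tokens,          # litellm "prompt_tokens"
    candidates_token_count=chunk.completion_tokens,  # litellm "completion_tokens"
    total_token_count=chunk.total_tokens,
)
print(usage.total_token_count)  # 46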
google/adk/models/llm_response.py
CHANGED
@@ -17,6 +17,7 @@ from __future__ import annotations
 from typing import Any, Optional

 from google.genai import types
+from pydantic import alias_generators
 from pydantic import BaseModel
 from pydantic import ConfigDict

@@ -40,7 +41,11 @@ class LlmResponse(BaseModel):
     custom_metadata: The custom metadata of the LlmResponse.
   """

-  model_config = ConfigDict(
+  model_config = ConfigDict(
+      extra='forbid',
+      alias_generator=alias_generators.to_camel,
+      populate_by_name=True,
+  )
   """The pydantic model config."""

   content: Optional[types.Content] = None
@@ -80,6 +85,9 @@ class LlmResponse(BaseModel):
   NOTE: the entire dict must be JSON serializable.
   """

+  usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
+  """The usage metadata of the LlmResponse"""
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
@@ -93,18 +101,20 @@ class LlmResponse(BaseModel):
     Returns:
       The LlmResponse.
     """
-
+    usage_metadata = generate_content_response.usage_metadata
     if generate_content_response.candidates:
       candidate = generate_content_response.candidates[0]
       if candidate.content and candidate.content.parts:
         return LlmResponse(
             content=candidate.content,
             grounding_metadata=candidate.grounding_metadata,
+            usage_metadata=usage_metadata,
         )
       else:
         return LlmResponse(
             error_code=candidate.finish_reason,
             error_message=candidate.finish_message,
+            usage_metadata=usage_metadata,
         )
     else:
       if generate_content_response.prompt_feedback:
@@ -112,9 +122,11 @@ class LlmResponse(BaseModel):
         return LlmResponse(
             error_code=prompt_feedback.block_reason,
             error_message=prompt_feedback.block_reason_message,
+            usage_metadata=usage_metadata,
         )
       else:
         return LlmResponse(
             error_code='UNKNOWN_ERROR',
             error_message='Unknown error.',
+            usage_metadata=usage_metadata,
         )
google/adk/models/registry.py
CHANGED