google-adk 0.5.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/callback_context.py +2 -6
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +8 -0
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +10 -2
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +4 -4
- google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +10 -9
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +109 -115
- google/adk/cli/cli_tools_click.py +179 -67
- google/adk/cli/fast_api.py +248 -197
- google/adk/cli/utils/agent_loader.py +137 -0
- google/adk/cli/utils/cleanup.py +40 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -0
- google/adk/cli/utils/logs.py +8 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/code_executors/code_execution_utils.py +2 -1
- google/adk/code_executors/container_code_executor.py +0 -1
- google/adk/code_executors/vertex_ai_code_executor.py +6 -8
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +104 -0
- google/adk/evaluation/eval_metrics.py +74 -0
- google/adk/evaluation/eval_result.py +86 -0
- google/adk/evaluation/eval_set.py +39 -0
- google/adk/evaluation/eval_set_results_manager.py +47 -0
- google/adk/evaluation/eval_sets_manager.py +43 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +58 -0
- google/adk/evaluation/local_eval_set_results_manager.py +113 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -1
- google/adk/evaluation/trajectory_evaluator.py +84 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/base_example_provider.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +58 -21
- google/adk/flows/llm_flows/contents.py +3 -1
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +18 -80
- google/adk/flows/llm_flows/single_flow.py +2 -2
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/base_llm.py +2 -1
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +12 -2
- google/adk/models/lite_llm.py +80 -23
- google/adk/models/llm_response.py +16 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +98 -42
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +2 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +57 -67
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +44 -51
- google/adk/telemetry.py +7 -2
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +10 -10
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
- google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
- google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +96 -0
- google/adk/tools/bigquery/__init__.py +28 -0
- google/adk/tools/bigquery/bigquery_credentials.py +216 -0
- google/adk/tools/bigquery/bigquery_tool.py +116 -0
- google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
- google/adk/tools/function_parameter_parse_util.py +9 -2
- google/adk/tools/function_tool.py +33 -3
- google/adk/tools/get_user_choice_tool.py +1 -0
- google/adk/tools/google_api_tool/__init__.py +24 -70
- google/adk/tools/google_api_tool/google_api_tool.py +12 -6
- google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
- google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/google_search_tool.py +2 -2
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +6 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
- google/adk/tools/mcp_tool/mcp_tool.py +35 -32
- google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
- google/adk/tools/openapi_tool/common/common.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +107 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/utils/__init__.py +13 -0
- google/adk/utils/instructions_utils.py +131 -0
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
- google_adk-1.1.0.dist-info/RECORD +200 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
- google/adk/cli/fast_api.py.orig +0 -728
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
google/adk/memory/vertex_ai_rag_memory_service.py
CHANGED
@@ -12,20 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+from __future__ import annotations
+
 from collections import OrderedDict
 import json
 import os
 import tempfile
+from typing import Optional
+from typing import TYPE_CHECKING
 
 from google.genai import types
 from typing_extensions import override
 from vertexai.preview import rag
 
-from ..events.event import Event
-from ..sessions.session import Session
+from . import _utils
 from .base_memory_service import BaseMemoryService
-from .base_memory_service import MemoryResult
 from .base_memory_service import SearchMemoryResponse
+from .memory_entry import MemoryEntry
+
+if TYPE_CHECKING:
+  from ..events.event import Event
+  from ..sessions.session import Session
 
 
 class VertexAiRagMemoryService(BaseMemoryService):
@@ -33,8 +41,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
 
   def __init__(
       self,
-      rag_corpus: str = None,
-      similarity_top_k: int = None,
+      rag_corpus: Optional[str] = None,
+      similarity_top_k: Optional[int] = None,
       vector_distance_threshold: float = 10,
   ):
     """Initializes a VertexAiRagMemoryService.
@@ -47,8 +55,10 @@ class VertexAiRagMemoryService(BaseMemoryService):
       vector_distance_threshold: Only returns contexts with vector distance
         smaller than the threshold..
     """
-    self.
-        rag_resources=[
+    self._vertex_rag_store = types.VertexRagStore(
+        rag_resources=[
+            types.VertexRagStoreRagResource(rag_corpus=rag_corpus),
+        ],
         similarity_top_k=similarity_top_k,
         vector_distance_threshold=vector_distance_threshold,
     )
@@ -79,7 +89,11 @@ class VertexAiRagMemoryService(BaseMemoryService):
       output_string = "\n".join(output_lines)
       temp_file.write(output_string)
       temp_file_path = temp_file.name
-
+
+    if not self._vertex_rag_store.rag_resources:
+      raise ValueError("Rag resources must be set.")
+
+    for rag_resource in self._vertex_rag_store.rag_resources:
       rag.upload_file(
           corpus_name=rag_resource.rag_corpus,
           path=temp_file_path,
@@ -95,12 +109,14 @@ class VertexAiRagMemoryService(BaseMemoryService):
       self, *, app_name: str, user_id: str, query: str
   ) -> SearchMemoryResponse:
     """Searches for sessions that match the query using rag.retrieval_query."""
+    from ..events.event import Event
+
     response = rag.retrieval_query(
         text=query,
-        rag_resources=self.
-        rag_corpora=self.
-        similarity_top_k=self.
-        vector_distance_threshold=self.
+        rag_resources=self._vertex_rag_store.rag_resources,
+        rag_corpora=self._vertex_rag_store.rag_corpora,
+        similarity_top_k=self._vertex_rag_store.similarity_top_k,
+        vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold,
     )
 
     memory_results = []
@@ -144,9 +160,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
     for session_id, event_lists in session_events_map.items():
       for events in _merge_event_lists(event_lists):
         sorted_events = sorted(events, key=lambda e: e.timestamp)
-
-
-
+
+        memory_results.extend([
+            MemoryEntry(
+                author=event.author,
+                content=event.content,
+                timestamp=_utils.format_timestamp(event.timestamp),
+            )
+            for event in sorted_events
+            if event.content
+        ])
     return SearchMemoryResponse(memories=memory_results)
 
 
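The net effect of these hunks: search results are now flat, per-event MemoryEntry objects rather than whole-session groupings, and the RAG store settings live in a single types.VertexRagStore. A minimal consumer sketch, assuming the async service interface 1.1.0's BaseMemoryService defines; the corpus resource name and query below are placeholders:

# Sketch only; the corpus path is a placeholder, not a real resource.
from google.adk.memory import VertexAiRagMemoryService

memory_service = VertexAiRagMemoryService(
    rag_corpus="projects/my-project/locations/us-central1/ragCorpora/123",
    similarity_top_k=5,
)

async def show_matches() -> None:
  response = await memory_service.search_memory(
      app_name="my_app", user_id="user_1", query="previous travel plans"
  )
  for entry in response.memories:
    # Each memory is one authored event, not a whole session.
    print(entry.author, entry.timestamp, entry.content)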
google/adk/models/anthropic_llm.py
CHANGED
@@ -24,8 +24,9 @@ from typing import AsyncGenerator
 from typing import Generator
 from typing import Iterable
 from typing import Literal
-from typing import Optional
+from typing import Optional
 from typing import TYPE_CHECKING
+from typing import Union
 
 from anthropic import AnthropicVertex
 from anthropic import NOT_GIVEN
@@ -42,7 +43,7 @@ if TYPE_CHECKING:
 
 __all__ = ["Claude"]
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("google_adk." + __name__)
 
 MAX_TOKEN = 1024
 
@@ -140,15 +141,15 @@ def message_to_generate_content_response(
           role="model",
           parts=[content_block_to_part(cb) for cb in message.content],
       ),
+      usage_metadata=types.GenerateContentResponseUsageMetadata(
+          prompt_token_count=message.usage.input_tokens,
+          candidates_token_count=message.usage.output_tokens,
+          total_token_count=(
+              message.usage.input_tokens + message.usage.output_tokens
+          ),
+      ),
       # TODO: Deal with these later.
       # finish_reason=to_google_genai_finish_reason(message.stop_reason),
-      # usage_metadata=types.GenerateContentResponseUsageMetadata(
-      #     prompt_token_count=message.usage.input_tokens,
-      #     candidates_token_count=message.usage.output_tokens,
-      #     total_token_count=(
-      #         message.usage.input_tokens + message.usage.output_tokens
-      #     ),
-      # ),
   )
 
 
@@ -196,6 +197,12 @@ def function_declaration_to_tool_param(
 
 
 class Claude(BaseLlm):
+  """ "Integration with Claude models served from Vertex AI.
+
+  Attributes:
+    model: The name of the Claude model.
+  """
+
   model: str = "claude-3-5-sonnet-v2@20241022"
 
   @staticmethod
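With the usage block promoted out of the TODO comment, Claude responses now carry token counts on llm_response.usage_metadata. A hedged sketch of reading them; the field names are exactly those set in the hunk, while the helper itself is illustrative:

from google.adk.models.llm_response import LlmResponse

def log_token_usage(llm_response: LlmResponse) -> None:
  # usage_metadata is a types.GenerateContentResponseUsageMetadata,
  # or None on code paths that do not populate it.
  usage = llm_response.usage_metadata
  if usage is None:
    return
  print(
      f"prompt={usage.prompt_token_count}",
      f"candidates={usage.candidates_token_count}",
      f"total={usage.total_token_count}",
  )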
google/adk/models/base_llm.py
CHANGED
@@ -14,7 +14,8 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from typing import AsyncGenerator
+from typing import AsyncGenerator
+from typing import TYPE_CHECKING
 
 from google.genai import types
 from pydantic import BaseModel
google/adk/models/gemini_llm_connection.py
CHANGED
@@ -21,7 +21,7 @@ from google.genai import types
 from .base_llm_connection import BaseLlmConnection
 from .llm_response import LlmResponse
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger('google_adk.' + __name__)
 
 
 class GeminiLlmConnection(BaseLlmConnection):
@@ -149,16 +149,16 @@ class GeminiLlmConnection(BaseLlmConnection):
             message.server_content.input_transcription
             and message.server_content.input_transcription.text
         ):
-
-
-
-
-
-
-
-
-
-
+          user_text = message.server_content.input_transcription.text
+          parts = [
+              types.Part.from_text(
+                  text=user_text,
+              )
+          ]
+          llm_response = LlmResponse(
+              content=types.Content(role='user', parts=parts)
+          )
+          yield llm_response
         if (
             message.server_content.output_transcription
             and message.server_content.output_transcription.text
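The recurring logger rename across these files ("google_adk." + __name__) puts every ADK module under one logging namespace, so an application can tune ADK verbosity without touching its own loggers. A minimal sketch of the standard-library pattern this enables:

import logging

# ADK loggers are now named "google_adk.<module>", so configuring the
# "google_adk" parent logger covers all of them through normal propagation.
logging.basicConfig(level=logging.WARNING)               # quiet default for the app
logging.getLogger("google_adk").setLevel(logging.DEBUG)  # verbose ADK only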
google/adk/models/google_llm.py
CHANGED
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
 from __future__ import annotations
 
 import contextlib
@@ -34,7 +36,7 @@ from .llm_response import LlmResponse
 if TYPE_CHECKING:
   from .llm_request import LlmRequest
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger('google_adk.' + __name__)
 
 _NEW_LINE = '\n'
 _EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
@@ -96,6 +98,7 @@ class Gemini(BaseLlm):
       )
       response = None
       text = ''
+      usage_metadata = None
       # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
       # we need to mark those text content as partial and after all partial
       # contents are sent, we send an accumulated event which contains all the
@@ -104,6 +107,7 @@ class Gemini(BaseLlm):
       async for response in responses:
         logger.info(_build_response_log(response))
         llm_response = LlmResponse.create(response)
+        usage_metadata = llm_response.usage_metadata
         if (
             llm_response.content
             and llm_response.content.parts
@@ -121,6 +125,7 @@ class Gemini(BaseLlm):
               content=types.ModelContent(
                   parts=[types.Part.from_text(text=text)],
               ),
+              usage_metadata=usage_metadata,
           )
           text = ''
         yield llm_response
@@ -134,6 +139,7 @@ class Gemini(BaseLlm):
           content=types.ModelContent(
               parts=[types.Part.from_text(text=text)],
           ),
+          usage_metadata=usage_metadata,
       )
 
     else:
@@ -174,9 +180,13 @@ class Gemini(BaseLlm):
   @cached_property
   def _live_api_client(self) -> Client:
     if self._api_backend == 'vertex':
+      # use beta version for vertex api
+      api_version = 'v1beta1'
       # use default api version for vertex
       return Client(
-          http_options=types.HttpOptions(
+          http_options=types.HttpOptions(
+              headers=self._tracking_headers, api_version=api_version
+          )
       )
     else:
       # use v1alpha for ml_dev
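The streaming changes above remember the usage metadata seen on each chunk and attach it to the aggregated, non-partial response that follows the partial text chunks. Reduced to a standalone sketch; the names and record shape are illustrative, not ADK API:

def aggregate(chunks):
  """chunks: iterable of (text_fragment, usage_or_none) pairs."""
  text, usage = "", None
  for fragment, chunk_usage in chunks:
    usage = chunk_usage or usage  # keep the most recent usage metadata
    text += fragment
    yield {"text": fragment, "partial": True}  # streamed out immediately
  if text:
    # One merged, non-partial record carries the full text plus usage.
    yield {"text": text, "partial": False, "usage": usage}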
google/adk/models/lite_llm.py
CHANGED
@@ -30,6 +30,7 @@ from typing import Union
 from google.genai import types
 from litellm import acompletion
 from litellm import ChatCompletionAssistantMessage
+from litellm import ChatCompletionAssistantToolCall
 from litellm import ChatCompletionDeveloperMessage
 from litellm import ChatCompletionImageUrlObject
 from litellm import ChatCompletionMessageToolCall
@@ -51,7 +52,7 @@ from .base_llm import BaseLlm
 from .llm_request import LlmRequest
 from .llm_response import LlmResponse
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("google_adk." + __name__)
 
 _NEW_LINE = "\n"
 _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
@@ -67,6 +68,12 @@ class TextChunk(BaseModel):
   text: str
 
 
+class UsageMetadataChunk(BaseModel):
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LiteLLMClient:
   """Provides acompletion method (for better testability)."""
 
@@ -174,12 +181,12 @@ def _content_to_message_param(
     for part in content.parts:
       if part.function_call:
         tool_calls.append(
-
+            ChatCompletionAssistantToolCall(
                 type="function",
                 id=part.function_call.id,
                 function=Function(
                     name=part.function_call.name,
-                    arguments=part.function_call.args,
+                    arguments=json.dumps(part.function_call.args),
                 ),
             )
         )
@@ -344,15 +351,20 @@ def _function_declaration_to_tool_param(
 def _model_response_to_chunk(
     response: ModelResponse,
 ) -> Generator[
-    Tuple[
+    Tuple[
+        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
+        Optional[str],
+    ],
+    None,
+    None,
 ]:
-  """Converts a litellm message to text or
+  """Converts a litellm message to text, function or usage metadata chunk.
 
   Args:
     response: The response from the model.
 
   Yields:
-    A tuple of text or function chunk and finish reason.
+    A tuple of text or function or usage metadata chunk and finish reason.
   """
 
   message = None
@@ -384,11 +396,21 @@
   if not message:
     yield None, None
 
+  # Ideally usage would be expected with the last ModelResponseStream with a
+  # finish_reason set. But this is not the case we are observing from litellm.
+  # So we are sending it as a separate chunk to be set on the llm_response.
+  if response.get("usage", None):
+    yield UsageMetadataChunk(
+        prompt_tokens=response["usage"].get("prompt_tokens", 0),
+        completion_tokens=response["usage"].get("completion_tokens", 0),
+        total_tokens=response["usage"].get("total_tokens", 0),
+    ), None
+
 
 def _model_response_to_generate_content_response(
     response: ModelResponse,
 ) -> LlmResponse:
-  """Converts a litellm response to LlmResponse.
+  """Converts a litellm response to LlmResponse. Also adds usage metadata.
 
   Args:
     response: The model response.
@@ -403,7 +425,15 @@ def _model_response_to_generate_content_response(
 
   if not message:
     raise ValueError("No message in response")
-
+
+  llm_response = _message_to_generate_content_response(message)
+  if response.get("usage", None):
+    llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
+        prompt_token_count=response["usage"].get("prompt_tokens", 0),
+        candidates_token_count=response["usage"].get("completion_tokens", 0),
+        total_token_count=response["usage"].get("total_tokens", 0),
+    )
+  return llm_response
 
 
 def _message_to_generate_content_response(
@@ -628,6 +658,10 @@ class LiteLlm(BaseLlm):
       function_args = ""
       function_id = None
      completion_args["stream"] = True
+      aggregated_llm_response = None
+      aggregated_llm_response_with_tool_call = None
+      usage_metadata = None
+
      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
@@ -645,32 +679,55 @@ class LiteLlm(BaseLlm):
                ),
                is_partial=True,
            )
+          elif isinstance(chunk, UsageMetadataChunk):
+            usage_metadata = types.GenerateContentResponseUsageMetadata(
+                prompt_token_count=chunk.prompt_tokens,
+                candidates_token_count=chunk.completion_tokens,
+                total_token_count=chunk.total_tokens,
+            )
+
          if finish_reason == "tool_calls" and function_id:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            aggregated_llm_response_with_tool_call = (
+                _message_to_generate_content_response(
+                    ChatCompletionAssistantMessage(
+                        role="assistant",
+                        content="",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                type="function",
+                                id=function_id,
+                                function=Function(
+                                    name=function_name,
+                                    arguments=function_args,
+                                ),
+                            )
+                        ],
+                    )
                )
            )
            function_name = ""
            function_args = ""
            function_id = None
          elif finish_reason == "stop" and text:
-
+            aggregated_llm_response = _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
            )
            text = ""
 
+      # waiting until streaming ends to yield the llm_response as litellm tends
+      # to send chunk that contains usage_metadata after the chunk with
+      # finish_reason set to tool_calls or stop.
+      if aggregated_llm_response:
+        if usage_metadata:
+          aggregated_llm_response.usage_metadata = usage_metadata
+          usage_metadata = None
+        yield aggregated_llm_response
+
+      if aggregated_llm_response_with_tool_call:
+        if usage_metadata:
+          aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
+        yield aggregated_llm_response_with_tool_call
+
    else:
      response = await self.llm_client.acompletion(**completion_args)
      yield _model_response_to_generate_content_response(response)
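Two details in this file are easy to miss: tool-call arguments must go through json.dumps because OpenAI-style chat APIs transport function arguments as a JSON string rather than a dict, and litellm delivers usage in its own trailing chunk, which is why the stream is fully drained before the aggregated responses are yielded. A small sketch of the arguments round-trip, using only the standard library:

import json

# Outbound: a Gemini function_call carries args as a dict, but the
# OpenAI-style wire format expects a JSON string.
args = {"city": "Paris", "days": 3}
wire_arguments = json.dumps(args)  # '{"city": "Paris", "days": 3}'

# Inbound: parse the string back into a dict before invoking the tool.
assert json.loads(wire_arguments) == args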
google/adk/models/llm_response.py
CHANGED
@@ -14,9 +14,11 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any
+from typing import Optional
 
 from google.genai import types
+from pydantic import alias_generators
 from pydantic import BaseModel
 from pydantic import ConfigDict
 
@@ -40,7 +42,11 @@ class LlmResponse(BaseModel):
     custom_metadata: The custom metadata of the LlmResponse.
   """
 
-  model_config = ConfigDict(
+  model_config = ConfigDict(
+      extra='forbid',
+      alias_generator=alias_generators.to_camel,
+      populate_by_name=True,
+  )
   """The pydantic model config."""
 
   content: Optional[types.Content] = None
@@ -80,6 +86,9 @@ class LlmResponse(BaseModel):
   NOTE: the entire dict must be JSON serializable.
   """
 
+  usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
+  """The usage metadata of the LlmResponse"""
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
@@ -93,18 +102,20 @@ class LlmResponse(BaseModel):
     Returns:
       The LlmResponse.
     """
-
+    usage_metadata = generate_content_response.usage_metadata
     if generate_content_response.candidates:
       candidate = generate_content_response.candidates[0]
       if candidate.content and candidate.content.parts:
         return LlmResponse(
             content=candidate.content,
             grounding_metadata=candidate.grounding_metadata,
+            usage_metadata=usage_metadata,
         )
       else:
         return LlmResponse(
             error_code=candidate.finish_reason,
             error_message=candidate.finish_message,
+            usage_metadata=usage_metadata,
         )
     else:
       if generate_content_response.prompt_feedback:
@@ -112,9 +123,11 @@ class LlmResponse(BaseModel):
         return LlmResponse(
             error_code=prompt_feedback.block_reason,
             error_message=prompt_feedback.block_reason_message,
+            usage_metadata=usage_metadata,
         )
       else:
         return LlmResponse(
            error_code='UNKNOWN_ERROR',
            error_message='Unknown error.',
+            usage_metadata=usage_metadata,
         )
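The new model_config makes LlmResponse reject unknown fields and serialize with camelCase aliases while still accepting snake_case names in code. A hedged sketch of the resulting behavior, using plain pydantic v2 semantics rather than ADK-specific code:

from typing import Optional

from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict

class Example(BaseModel):
  model_config = ConfigDict(
      extra='forbid',
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
  )
  error_code: Optional[str] = None

# populate_by_name=True lets callers keep using snake_case field names...
e = Example(error_code='UNKNOWN_ERROR')
# ...while dumps with by_alias=True emit camelCase keys.
assert e.model_dump(by_alias=True) == {'errorCode': 'UNKNOWN_ERROR'}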
google/adk/models/registry.py
CHANGED