google-adk: google_adk-0.5.0-py3-none-any.whl → google_adk-1.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (139):
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/callback_context.py +2 -6
  3. google/adk/agents/llm_agent.py +122 -30
  4. google/adk/agents/loop_agent.py +1 -1
  5. google/adk/agents/parallel_agent.py +7 -0
  6. google/adk/agents/readonly_context.py +8 -0
  7. google/adk/agents/run_config.py +1 -1
  8. google/adk/agents/sequential_agent.py +31 -0
  9. google/adk/agents/transcription_entry.py +4 -2
  10. google/adk/artifacts/gcs_artifact_service.py +1 -1
  11. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  12. google/adk/auth/auth_credential.py +10 -2
  13. google/adk/auth/auth_preprocessor.py +7 -1
  14. google/adk/auth/auth_tool.py +3 -4
  15. google/adk/cli/agent_graph.py +5 -5
  16. google/adk/cli/browser/index.html +4 -4
  17. google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
  18. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  19. google/adk/cli/cli.py +10 -9
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +109 -115
  22. google/adk/cli/cli_tools_click.py +179 -67
  23. google/adk/cli/fast_api.py +248 -197
  24. google/adk/cli/utils/agent_loader.py +137 -0
  25. google/adk/cli/utils/cleanup.py +40 -0
  26. google/adk/cli/utils/common.py +23 -0
  27. google/adk/cli/utils/evals.py +83 -0
  28. google/adk/cli/utils/logs.py +8 -5
  29. google/adk/code_executors/__init__.py +3 -1
  30. google/adk/code_executors/built_in_code_executor.py +52 -0
  31. google/adk/code_executors/code_execution_utils.py +2 -1
  32. google/adk/code_executors/container_code_executor.py +0 -1
  33. google/adk/code_executors/vertex_ai_code_executor.py +6 -8
  34. google/adk/evaluation/__init__.py +1 -1
  35. google/adk/evaluation/agent_evaluator.py +168 -128
  36. google/adk/evaluation/eval_case.py +104 -0
  37. google/adk/evaluation/eval_metrics.py +74 -0
  38. google/adk/evaluation/eval_result.py +86 -0
  39. google/adk/evaluation/eval_set.py +39 -0
  40. google/adk/evaluation/eval_set_results_manager.py +47 -0
  41. google/adk/evaluation/eval_sets_manager.py +43 -0
  42. google/adk/evaluation/evaluation_generator.py +88 -113
  43. google/adk/evaluation/evaluator.py +58 -0
  44. google/adk/evaluation/local_eval_set_results_manager.py +113 -0
  45. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  46. google/adk/evaluation/response_evaluator.py +106 -1
  47. google/adk/evaluation/trajectory_evaluator.py +84 -2
  48. google/adk/events/event.py +6 -1
  49. google/adk/events/event_actions.py +6 -1
  50. google/adk/examples/base_example_provider.py +1 -0
  51. google/adk/examples/example_util.py +3 -2
  52. google/adk/flows/llm_flows/_code_execution.py +9 -1
  53. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  54. google/adk/flows/llm_flows/base_llm_flow.py +58 -21
  55. google/adk/flows/llm_flows/contents.py +3 -1
  56. google/adk/flows/llm_flows/functions.py +9 -8
  57. google/adk/flows/llm_flows/instructions.py +18 -80
  58. google/adk/flows/llm_flows/single_flow.py +2 -2
  59. google/adk/memory/__init__.py +1 -1
  60. google/adk/memory/_utils.py +23 -0
  61. google/adk/memory/base_memory_service.py +23 -21
  62. google/adk/memory/in_memory_memory_service.py +57 -25
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  65. google/adk/models/anthropic_llm.py +16 -9
  66. google/adk/models/base_llm.py +2 -1
  67. google/adk/models/base_llm_connection.py +2 -0
  68. google/adk/models/gemini_llm_connection.py +11 -11
  69. google/adk/models/google_llm.py +12 -2
  70. google/adk/models/lite_llm.py +80 -23
  71. google/adk/models/llm_response.py +16 -3
  72. google/adk/models/registry.py +1 -1
  73. google/adk/runners.py +98 -42
  74. google/adk/sessions/__init__.py +1 -1
  75. google/adk/sessions/_session_util.py +2 -1
  76. google/adk/sessions/base_session_service.py +6 -33
  77. google/adk/sessions/database_session_service.py +57 -67
  78. google/adk/sessions/in_memory_session_service.py +106 -24
  79. google/adk/sessions/session.py +3 -0
  80. google/adk/sessions/vertex_ai_session_service.py +44 -51
  81. google/adk/telemetry.py +7 -2
  82. google/adk/tools/__init__.py +4 -7
  83. google/adk/tools/_memory_entry_utils.py +30 -0
  84. google/adk/tools/agent_tool.py +10 -10
  85. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  86. google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
  87. google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +96 -0
  93. google/adk/tools/bigquery/__init__.py +28 -0
  94. google/adk/tools/bigquery/bigquery_credentials.py +216 -0
  95. google/adk/tools/bigquery/bigquery_tool.py +116 -0
  96. google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
  97. google/adk/tools/function_parameter_parse_util.py +9 -2
  98. google/adk/tools/function_tool.py +33 -3
  99. google/adk/tools/get_user_choice_tool.py +1 -0
  100. google/adk/tools/google_api_tool/__init__.py +24 -70
  101. google/adk/tools/google_api_tool/google_api_tool.py +12 -6
  102. google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
  103. google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
  104. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  105. google/adk/tools/google_search_tool.py +2 -2
  106. google/adk/tools/langchain_tool.py +96 -49
  107. google/adk/tools/load_memory_tool.py +14 -5
  108. google/adk/tools/mcp_tool/__init__.py +3 -2
  109. google/adk/tools/mcp_tool/conversion_utils.py +6 -2
  110. google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
  111. google/adk/tools/mcp_tool/mcp_tool.py +35 -32
  112. google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
  113. google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
  114. google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
  115. google/adk/tools/openapi_tool/common/common.py +5 -1
  116. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
  117. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
  118. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
  119. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
  120. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  121. google/adk/tools/preload_memory_tool.py +27 -18
  122. google/adk/tools/retrieval/__init__.py +1 -1
  123. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  124. google/adk/tools/toolbox_toolset.py +107 -0
  125. google/adk/tools/transfer_to_agent_tool.py +0 -1
  126. google/adk/utils/__init__.py +13 -0
  127. google/adk/utils/instructions_utils.py +131 -0
  128. google/adk/version.py +1 -1
  129. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
  130. google_adk-1.1.0.dist-info/RECORD +200 -0
  131. google/adk/agents/remote_agent.py +0 -50
  132. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
  133. google/adk/cli/fast_api.py.orig +0 -728
  134. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  135. google/adk/tools/toolbox_tool.py +0 -46
  136. google_adk-0.5.0.dist-info/RECORD +0 -180
  137. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
  138. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
  139. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,20 +12,28 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
16
+ from __future__ import annotations
17
+
15
18
  from collections import OrderedDict
16
19
  import json
17
20
  import os
18
21
  import tempfile
22
+ from typing import Optional
23
+ from typing import TYPE_CHECKING
19
24
 
20
25
  from google.genai import types
21
26
  from typing_extensions import override
22
27
  from vertexai.preview import rag
23
28
 
24
- from ..events.event import Event
25
- from ..sessions.session import Session
29
+ from . import _utils
26
30
  from .base_memory_service import BaseMemoryService
27
- from .base_memory_service import MemoryResult
28
31
  from .base_memory_service import SearchMemoryResponse
32
+ from .memory_entry import MemoryEntry
33
+
34
+ if TYPE_CHECKING:
35
+ from ..events.event import Event
36
+ from ..sessions.session import Session
29
37
 
30
38
 
31
39
  class VertexAiRagMemoryService(BaseMemoryService):
@@ -33,8 +41,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
33
41
 
34
42
  def __init__(
35
43
  self,
36
- rag_corpus: str = None,
37
- similarity_top_k: int = None,
44
+ rag_corpus: Optional[str] = None,
45
+ similarity_top_k: Optional[int] = None,
38
46
  vector_distance_threshold: float = 10,
39
47
  ):
40
48
  """Initializes a VertexAiRagMemoryService.
@@ -47,8 +55,10 @@ class VertexAiRagMemoryService(BaseMemoryService):
47
55
  vector_distance_threshold: Only returns contexts with vector distance
48
56
  smaller than the threshold..
49
57
  """
50
- self.vertex_rag_store = types.VertexRagStore(
51
- rag_resources=[rag.RagResource(rag_corpus=rag_corpus)],
58
+ self._vertex_rag_store = types.VertexRagStore(
59
+ rag_resources=[
60
+ types.VertexRagStoreRagResource(rag_corpus=rag_corpus),
61
+ ],
52
62
  similarity_top_k=similarity_top_k,
53
63
  vector_distance_threshold=vector_distance_threshold,
54
64
  )
@@ -79,7 +89,11 @@ class VertexAiRagMemoryService(BaseMemoryService):
79
89
  output_string = "\n".join(output_lines)
80
90
  temp_file.write(output_string)
81
91
  temp_file_path = temp_file.name
82
- for rag_resource in self.vertex_rag_store.rag_resources:
92
+
93
+ if not self._vertex_rag_store.rag_resources:
94
+ raise ValueError("Rag resources must be set.")
95
+
96
+ for rag_resource in self._vertex_rag_store.rag_resources:
83
97
  rag.upload_file(
84
98
  corpus_name=rag_resource.rag_corpus,
85
99
  path=temp_file_path,
@@ -95,12 +109,14 @@ class VertexAiRagMemoryService(BaseMemoryService):
95
109
  self, *, app_name: str, user_id: str, query: str
96
110
  ) -> SearchMemoryResponse:
97
111
  """Searches for sessions that match the query using rag.retrieval_query."""
112
+ from ..events.event import Event
113
+
98
114
  response = rag.retrieval_query(
99
115
  text=query,
100
- rag_resources=self.vertex_rag_store.rag_resources,
101
- rag_corpora=self.vertex_rag_store.rag_corpora,
102
- similarity_top_k=self.vertex_rag_store.similarity_top_k,
103
- vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
116
+ rag_resources=self._vertex_rag_store.rag_resources,
117
+ rag_corpora=self._vertex_rag_store.rag_corpora,
118
+ similarity_top_k=self._vertex_rag_store.similarity_top_k,
119
+ vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold,
104
120
  )
105
121
 
106
122
  memory_results = []
@@ -144,9 +160,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
144
160
  for session_id, event_lists in session_events_map.items():
145
161
  for events in _merge_event_lists(event_lists):
146
162
  sorted_events = sorted(events, key=lambda e: e.timestamp)
147
- memory_results.append(
148
- MemoryResult(session_id=session_id, events=sorted_events)
149
- )
163
+
164
+ memory_results.extend([
165
+ MemoryEntry(
166
+ author=event.author,
167
+ content=event.content,
168
+ timestamp=_utils.format_timestamp(event.timestamp),
169
+ )
170
+ for event in sorted_events
171
+ if event.content
172
+ ])
150
173
  return SearchMemoryResponse(memories=memory_results)
151
174
 
152
175
 
@@ -24,8 +24,9 @@ from typing import AsyncGenerator
24
24
  from typing import Generator
25
25
  from typing import Iterable
26
26
  from typing import Literal
27
- from typing import Optional, Union
27
+ from typing import Optional
28
28
  from typing import TYPE_CHECKING
29
+ from typing import Union
29
30
 
30
31
  from anthropic import AnthropicVertex
31
32
  from anthropic import NOT_GIVEN
@@ -42,7 +43,7 @@ if TYPE_CHECKING:
42
43
 
43
44
  __all__ = ["Claude"]
44
45
 
45
- logger = logging.getLogger(__name__)
46
+ logger = logging.getLogger("google_adk." + __name__)
46
47
 
47
48
  MAX_TOKEN = 1024
48
49
 
@@ -140,15 +141,15 @@ def message_to_generate_content_response(
140
141
  role="model",
141
142
  parts=[content_block_to_part(cb) for cb in message.content],
142
143
  ),
144
+ usage_metadata=types.GenerateContentResponseUsageMetadata(
145
+ prompt_token_count=message.usage.input_tokens,
146
+ candidates_token_count=message.usage.output_tokens,
147
+ total_token_count=(
148
+ message.usage.input_tokens + message.usage.output_tokens
149
+ ),
150
+ ),
143
151
  # TODO: Deal with these later.
144
152
  # finish_reason=to_google_genai_finish_reason(message.stop_reason),
145
- # usage_metadata=types.GenerateContentResponseUsageMetadata(
146
- # prompt_token_count=message.usage.input_tokens,
147
- # candidates_token_count=message.usage.output_tokens,
148
- # total_token_count=(
149
- # message.usage.input_tokens + message.usage.output_tokens
150
- # ),
151
- # ),
152
153
  )
153
154
 
154
155
 
@@ -196,6 +197,12 @@ def function_declaration_to_tool_param(
196
197
 
197
198
 
198
199
  class Claude(BaseLlm):
200
+ """ "Integration with Claude models served from Vertex AI.
201
+
202
+ Attributes:
203
+ model: The name of the Claude model.
204
+ """
205
+
199
206
  model: str = "claude-3-5-sonnet-v2@20241022"
200
207
 
201
208
  @staticmethod
@@ -14,7 +14,8 @@
14
14
  from __future__ import annotations
15
15
 
16
16
  from abc import abstractmethod
17
- from typing import AsyncGenerator, TYPE_CHECKING
17
+ from typing import AsyncGenerator
18
+ from typing import TYPE_CHECKING
18
19
 
19
20
  from google.genai import types
20
21
  from pydantic import BaseModel
@@ -14,7 +14,9 @@
14
14
 
15
15
  from abc import abstractmethod
16
16
  from typing import AsyncGenerator
17
+
17
18
  from google.genai import types
19
+
18
20
  from .llm_response import LlmResponse
19
21
 
20
22
 
@@ -21,7 +21,7 @@ from google.genai import types
21
21
  from .base_llm_connection import BaseLlmConnection
22
22
  from .llm_response import LlmResponse
23
23
 
24
- logger = logging.getLogger(__name__)
24
+ logger = logging.getLogger('google_adk.' + __name__)
25
25
 
26
26
 
27
27
  class GeminiLlmConnection(BaseLlmConnection):
@@ -149,16 +149,16 @@ class GeminiLlmConnection(BaseLlmConnection):
149
149
  message.server_content.input_transcription
150
150
  and message.server_content.input_transcription.text
151
151
  ):
152
- user_text = message.server_content.input_transcription.text
153
- parts = [
154
- types.Part.from_text(
155
- text=user_text,
156
- )
157
- ]
158
- llm_response = LlmResponse(
159
- content=types.Content(role='user', parts=parts)
160
- )
161
- yield llm_response
152
+ user_text = message.server_content.input_transcription.text
153
+ parts = [
154
+ types.Part.from_text(
155
+ text=user_text,
156
+ )
157
+ ]
158
+ llm_response = LlmResponse(
159
+ content=types.Content(role='user', parts=parts)
160
+ )
161
+ yield llm_response
162
162
  if (
163
163
  message.server_content.output_transcription
164
164
  and message.server_content.output_transcription.text
@@ -11,6 +11,8 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
15
+
14
16
  from __future__ import annotations
15
17
 
16
18
  import contextlib
@@ -34,7 +36,7 @@ from .llm_response import LlmResponse
34
36
  if TYPE_CHECKING:
35
37
  from .llm_request import LlmRequest
36
38
 
37
- logger = logging.getLogger(__name__)
39
+ logger = logging.getLogger('google_adk.' + __name__)
38
40
 
39
41
  _NEW_LINE = '\n'
40
42
  _EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
@@ -96,6 +98,7 @@ class Gemini(BaseLlm):
96
98
  )
97
99
  response = None
98
100
  text = ''
101
+ usage_metadata = None
99
102
  # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
100
103
  # we need to mark those text content as partial and after all partial
101
104
  # contents are sent, we send an accumulated event which contains all the
@@ -104,6 +107,7 @@ class Gemini(BaseLlm):
104
107
  async for response in responses:
105
108
  logger.info(_build_response_log(response))
106
109
  llm_response = LlmResponse.create(response)
110
+ usage_metadata = llm_response.usage_metadata
107
111
  if (
108
112
  llm_response.content
109
113
  and llm_response.content.parts
@@ -121,6 +125,7 @@ class Gemini(BaseLlm):
121
125
  content=types.ModelContent(
122
126
  parts=[types.Part.from_text(text=text)],
123
127
  ),
128
+ usage_metadata=usage_metadata,
124
129
  )
125
130
  text = ''
126
131
  yield llm_response
@@ -134,6 +139,7 @@ class Gemini(BaseLlm):
134
139
  content=types.ModelContent(
135
140
  parts=[types.Part.from_text(text=text)],
136
141
  ),
142
+ usage_metadata=usage_metadata,
137
143
  )
138
144
 
139
145
  else:
@@ -174,9 +180,13 @@ class Gemini(BaseLlm):
174
180
  @cached_property
175
181
  def _live_api_client(self) -> Client:
176
182
  if self._api_backend == 'vertex':
183
+ # use beta version for vertex api
184
+ api_version = 'v1beta1'
177
185
  # use default api version for vertex
178
186
  return Client(
179
- http_options=types.HttpOptions(headers=self._tracking_headers)
187
+ http_options=types.HttpOptions(
188
+ headers=self._tracking_headers, api_version=api_version
189
+ )
180
190
  )
181
191
  else:
182
192
  # use v1alpha for ml_dev
@@ -30,6 +30,7 @@ from typing import Union
30
30
  from google.genai import types
31
31
  from litellm import acompletion
32
32
  from litellm import ChatCompletionAssistantMessage
33
+ from litellm import ChatCompletionAssistantToolCall
33
34
  from litellm import ChatCompletionDeveloperMessage
34
35
  from litellm import ChatCompletionImageUrlObject
35
36
  from litellm import ChatCompletionMessageToolCall
@@ -51,7 +52,7 @@ from .base_llm import BaseLlm
51
52
  from .llm_request import LlmRequest
52
53
  from .llm_response import LlmResponse
53
54
 
54
- logger = logging.getLogger(__name__)
55
+ logger = logging.getLogger("google_adk." + __name__)
55
56
 
56
57
  _NEW_LINE = "\n"
57
58
  _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
@@ -67,6 +68,12 @@ class TextChunk(BaseModel):
67
68
  text: str
68
69
 
69
70
 
71
+ class UsageMetadataChunk(BaseModel):
72
+ prompt_tokens: int
73
+ completion_tokens: int
74
+ total_tokens: int
75
+
76
+
70
77
  class LiteLLMClient:
71
78
  """Provides acompletion method (for better testability)."""
72
79
 
@@ -174,12 +181,12 @@ def _content_to_message_param(
174
181
  for part in content.parts:
175
182
  if part.function_call:
176
183
  tool_calls.append(
177
- ChatCompletionMessageToolCall(
184
+ ChatCompletionAssistantToolCall(
178
185
  type="function",
179
186
  id=part.function_call.id,
180
187
  function=Function(
181
188
  name=part.function_call.name,
182
- arguments=part.function_call.args,
189
+ arguments=json.dumps(part.function_call.args),
183
190
  ),
184
191
  )
185
192
  )
@@ -344,15 +351,20 @@ def _function_declaration_to_tool_param(
344
351
  def _model_response_to_chunk(
345
352
  response: ModelResponse,
346
353
  ) -> Generator[
347
- Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
354
+ Tuple[
355
+ Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
356
+ Optional[str],
357
+ ],
358
+ None,
359
+ None,
348
360
  ]:
349
- """Converts a litellm message to text or function chunk.
361
+ """Converts a litellm message to text, function or usage metadata chunk.
350
362
 
351
363
  Args:
352
364
  response: The response from the model.
353
365
 
354
366
  Yields:
355
- A tuple of text or function chunk and finish reason.
367
+ A tuple of text or function or usage metadata chunk and finish reason.
356
368
  """
357
369
 
358
370
  message = None
@@ -384,11 +396,21 @@ def _model_response_to_chunk(
384
396
  if not message:
385
397
  yield None, None
386
398
 
399
+ # Ideally usage would be expected with the last ModelResponseStream with a
400
+ # finish_reason set. But this is not the case we are observing from litellm.
401
+ # So we are sending it as a separate chunk to be set on the llm_response.
402
+ if response.get("usage", None):
403
+ yield UsageMetadataChunk(
404
+ prompt_tokens=response["usage"].get("prompt_tokens", 0),
405
+ completion_tokens=response["usage"].get("completion_tokens", 0),
406
+ total_tokens=response["usage"].get("total_tokens", 0),
407
+ ), None
408
+
387
409
 
388
410
  def _model_response_to_generate_content_response(
389
411
  response: ModelResponse,
390
412
  ) -> LlmResponse:
391
- """Converts a litellm response to LlmResponse.
413
+ """Converts a litellm response to LlmResponse. Also adds usage metadata.
392
414
 
393
415
  Args:
394
416
  response: The model response.
@@ -403,7 +425,15 @@ def _model_response_to_generate_content_response(
403
425
 
404
426
  if not message:
405
427
  raise ValueError("No message in response")
406
- return _message_to_generate_content_response(message)
428
+
429
+ llm_response = _message_to_generate_content_response(message)
430
+ if response.get("usage", None):
431
+ llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
432
+ prompt_token_count=response["usage"].get("prompt_tokens", 0),
433
+ candidates_token_count=response["usage"].get("completion_tokens", 0),
434
+ total_token_count=response["usage"].get("total_tokens", 0),
435
+ )
436
+ return llm_response
407
437
 
408
438
 
409
439
  def _message_to_generate_content_response(
@@ -628,6 +658,10 @@ class LiteLlm(BaseLlm):
628
658
  function_args = ""
629
659
  function_id = None
630
660
  completion_args["stream"] = True
661
+ aggregated_llm_response = None
662
+ aggregated_llm_response_with_tool_call = None
663
+ usage_metadata = None
664
+
631
665
  for part in self.llm_client.completion(**completion_args):
632
666
  for chunk, finish_reason in _model_response_to_chunk(part):
633
667
  if isinstance(chunk, FunctionChunk):
@@ -645,32 +679,55 @@ class LiteLlm(BaseLlm):
645
679
  ),
646
680
  is_partial=True,
647
681
  )
682
+ elif isinstance(chunk, UsageMetadataChunk):
683
+ usage_metadata = types.GenerateContentResponseUsageMetadata(
684
+ prompt_token_count=chunk.prompt_tokens,
685
+ candidates_token_count=chunk.completion_tokens,
686
+ total_token_count=chunk.total_tokens,
687
+ )
688
+
648
689
  if finish_reason == "tool_calls" and function_id:
649
- yield _message_to_generate_content_response(
650
- ChatCompletionAssistantMessage(
651
- role="assistant",
652
- content="",
653
- tool_calls=[
654
- ChatCompletionMessageToolCall(
655
- type="function",
656
- id=function_id,
657
- function=Function(
658
- name=function_name,
659
- arguments=function_args,
660
- ),
661
- )
662
- ],
690
+ aggregated_llm_response_with_tool_call = (
691
+ _message_to_generate_content_response(
692
+ ChatCompletionAssistantMessage(
693
+ role="assistant",
694
+ content="",
695
+ tool_calls=[
696
+ ChatCompletionMessageToolCall(
697
+ type="function",
698
+ id=function_id,
699
+ function=Function(
700
+ name=function_name,
701
+ arguments=function_args,
702
+ ),
703
+ )
704
+ ],
705
+ )
663
706
  )
664
707
  )
665
708
  function_name = ""
666
709
  function_args = ""
667
710
  function_id = None
668
711
  elif finish_reason == "stop" and text:
669
- yield _message_to_generate_content_response(
712
+ aggregated_llm_response = _message_to_generate_content_response(
670
713
  ChatCompletionAssistantMessage(role="assistant", content=text)
671
714
  )
672
715
  text = ""
673
716
 
717
+ # waiting until streaming ends to yield the llm_response as litellm tends
718
+ # to send chunk that contains usage_metadata after the chunk with
719
+ # finish_reason set to tool_calls or stop.
720
+ if aggregated_llm_response:
721
+ if usage_metadata:
722
+ aggregated_llm_response.usage_metadata = usage_metadata
723
+ usage_metadata = None
724
+ yield aggregated_llm_response
725
+
726
+ if aggregated_llm_response_with_tool_call:
727
+ if usage_metadata:
728
+ aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
729
+ yield aggregated_llm_response_with_tool_call
730
+
674
731
  else:
675
732
  response = await self.llm_client.acompletion(**completion_args)
676
733
  yield _model_response_to_generate_content_response(response)
@@ -14,9 +14,11 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Any, Optional
17
+ from typing import Any
18
+ from typing import Optional
18
19
 
19
20
  from google.genai import types
21
+ from pydantic import alias_generators
20
22
  from pydantic import BaseModel
21
23
  from pydantic import ConfigDict
22
24
 
@@ -40,7 +42,11 @@ class LlmResponse(BaseModel):
40
42
  custom_metadata: The custom metadata of the LlmResponse.
41
43
  """
42
44
 
43
- model_config = ConfigDict(extra='forbid')
45
+ model_config = ConfigDict(
46
+ extra='forbid',
47
+ alias_generator=alias_generators.to_camel,
48
+ populate_by_name=True,
49
+ )
44
50
  """The pydantic model config."""
45
51
 
46
52
  content: Optional[types.Content] = None
@@ -80,6 +86,9 @@ class LlmResponse(BaseModel):
80
86
  NOTE: the entire dict must be JSON serializable.
81
87
  """
82
88
 
89
+ usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
90
+ """The usage metadata of the LlmResponse"""
91
+
83
92
  @staticmethod
84
93
  def create(
85
94
  generate_content_response: types.GenerateContentResponse,
@@ -93,18 +102,20 @@ class LlmResponse(BaseModel):
93
102
  Returns:
94
103
  The LlmResponse.
95
104
  """
96
-
105
+ usage_metadata = generate_content_response.usage_metadata
97
106
  if generate_content_response.candidates:
98
107
  candidate = generate_content_response.candidates[0]
99
108
  if candidate.content and candidate.content.parts:
100
109
  return LlmResponse(
101
110
  content=candidate.content,
102
111
  grounding_metadata=candidate.grounding_metadata,
112
+ usage_metadata=usage_metadata,
103
113
  )
104
114
  else:
105
115
  return LlmResponse(
106
116
  error_code=candidate.finish_reason,
107
117
  error_message=candidate.finish_message,
118
+ usage_metadata=usage_metadata,
108
119
  )
109
120
  else:
110
121
  if generate_content_response.prompt_feedback:
@@ -112,9 +123,11 @@ class LlmResponse(BaseModel):
112
123
  return LlmResponse(
113
124
  error_code=prompt_feedback.block_reason,
114
125
  error_message=prompt_feedback.block_reason_message,
126
+ usage_metadata=usage_metadata,
115
127
  )
116
128
  else:
117
129
  return LlmResponse(
118
130
  error_code='UNKNOWN_ERROR',
119
131
  error_message='Unknown error.',
132
+ usage_metadata=usage_metadata,
120
133
  )
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
24
24
  if TYPE_CHECKING:
25
25
  from .base_llm import BaseLlm
26
26
 
27
- logger = logging.getLogger(__name__)
27
+ logger = logging.getLogger('google_adk.' + __name__)
28
28
 
29
29
 
30
30
  _llm_registry_dict: dict[str, type[BaseLlm]] = {}