google-adk 1.6.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/converters/event_converter.py +5 -85
- google/adk/a2a/converters/request_converter.py +1 -2
- google/adk/a2a/executor/a2a_agent_executor.py +45 -16
- google/adk/a2a/logs/log_utils.py +1 -2
- google/adk/a2a/utils/__init__.py +0 -0
- google/adk/a2a/utils/agent_card_builder.py +544 -0
- google/adk/a2a/utils/agent_to_a2a.py +118 -0
- google/adk/agents/__init__.py +5 -0
- google/adk/agents/agent_config.py +46 -0
- google/adk/agents/base_agent.py +239 -41
- google/adk/agents/callback_context.py +41 -0
- google/adk/agents/common_configs.py +79 -0
- google/adk/agents/config_agent_utils.py +184 -0
- google/adk/agents/config_schemas/AgentConfig.json +566 -0
- google/adk/agents/invocation_context.py +5 -1
- google/adk/agents/live_request_queue.py +15 -0
- google/adk/agents/llm_agent.py +201 -9
- google/adk/agents/loop_agent.py +35 -1
- google/adk/agents/parallel_agent.py +24 -3
- google/adk/agents/remote_a2a_agent.py +17 -5
- google/adk/agents/sequential_agent.py +22 -1
- google/adk/artifacts/gcs_artifact_service.py +110 -20
- google/adk/auth/auth_handler.py +3 -3
- google/adk/auth/credential_manager.py +23 -23
- google/adk/auth/credential_service/base_credential_service.py +6 -6
- google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
- google/adk/auth/credential_service/session_state_credential_service.py +8 -8
- google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
- google/adk/auth/oauth2_credential_util.py +2 -2
- google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
- google/adk/cli/agent_graph.py +3 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/main-W7QZBYAR.js +3914 -0
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli_eval.py +87 -12
- google/adk/cli/cli_tools_click.py +143 -82
- google/adk/cli/fast_api.py +150 -69
- google/adk/cli/utils/agent_loader.py +35 -1
- google/adk/code_executors/base_code_executor.py +14 -19
- google/adk/code_executors/built_in_code_executor.py +4 -1
- google/adk/evaluation/base_eval_service.py +46 -2
- google/adk/evaluation/eval_metrics.py +4 -0
- google/adk/evaluation/eval_sets_manager.py +5 -1
- google/adk/evaluation/evaluation_generator.py +1 -1
- google/adk/evaluation/final_response_match_v2.py +2 -2
- google/adk/evaluation/gcs_eval_sets_manager.py +2 -1
- google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
- google/adk/evaluation/local_eval_service.py +389 -0
- google/adk/evaluation/local_eval_set_results_manager.py +2 -2
- google/adk/evaluation/local_eval_sets_manager.py +24 -9
- google/adk/evaluation/metric_evaluator_registry.py +16 -6
- google/adk/evaluation/vertex_ai_eval_facade.py +7 -1
- google/adk/events/event.py +7 -2
- google/adk/flows/llm_flows/auto_flow.py +6 -11
- google/adk/flows/llm_flows/base_llm_flow.py +66 -29
- google/adk/flows/llm_flows/contents.py +16 -10
- google/adk/flows/llm_flows/functions.py +89 -52
- google/adk/memory/in_memory_memory_service.py +21 -15
- google/adk/memory/vertex_ai_memory_bank_service.py +12 -10
- google/adk/models/anthropic_llm.py +46 -6
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/gemini_llm_connection.py +17 -6
- google/adk/models/google_llm.py +46 -11
- google/adk/models/lite_llm.py +52 -22
- google/adk/plugins/__init__.py +17 -0
- google/adk/plugins/base_plugin.py +317 -0
- google/adk/plugins/plugin_manager.py +265 -0
- google/adk/runners.py +122 -18
- google/adk/sessions/database_session_service.py +51 -52
- google/adk/sessions/vertex_ai_session_service.py +27 -12
- google/adk/tools/__init__.py +2 -0
- google/adk/tools/_automatic_function_calling_util.py +20 -2
- google/adk/tools/agent_tool.py +15 -3
- google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
- google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
- google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
- google/adk/tools/base_tool.py +9 -9
- google/adk/tools/base_toolset.py +29 -5
- google/adk/tools/bigquery/__init__.py +3 -3
- google/adk/tools/bigquery/metadata_tool.py +2 -0
- google/adk/tools/bigquery/query_tool.py +15 -1
- google/adk/tools/computer_use/__init__.py +13 -0
- google/adk/tools/computer_use/base_computer.py +265 -0
- google/adk/tools/computer_use/computer_use_tool.py +166 -0
- google/adk/tools/computer_use/computer_use_toolset.py +220 -0
- google/adk/tools/enterprise_search_tool.py +4 -2
- google/adk/tools/exit_loop_tool.py +1 -0
- google/adk/tools/google_api_tool/google_api_tool.py +16 -1
- google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
- google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
- google/adk/tools/google_search_tool.py +4 -2
- google/adk/tools/langchain_tool.py +16 -6
- google/adk/tools/long_running_tool.py +21 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +5 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
- google/adk/tools/tool_context.py +0 -10
- google/adk/tools/url_context_tool.py +4 -2
- google/adk/tools/vertex_ai_search_tool.py +4 -2
- google/adk/utils/model_name_utils.py +90 -0
- google/adk/version.py +1 -1
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/METADATA +3 -2
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/RECORD +108 -91
- google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/WHEEL +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
import math
|
17
18
|
import os
|
18
19
|
from typing import Optional
|
19
20
|
|
@@ -112,7 +113,12 @@ class _VertexAiEvalFacade(Evaluator):
|
|
112
113
|
return ""
|
113
114
|
|
114
115
|
def _get_score(self, eval_result) -> Optional[float]:
|
115
|
-
if
|
116
|
+
if (
|
117
|
+
eval_result
|
118
|
+
and eval_result.summary_metrics
|
119
|
+
and isinstance(eval_result.summary_metrics[0].mean_score, float)
|
120
|
+
and not math.isnan(eval_result.summary_metrics[0].mean_score)
|
121
|
+
):
|
116
122
|
return eval_result.summary_metrics[0].mean_score
|
117
123
|
|
118
124
|
return None
|
google/adk/events/event.py
CHANGED
@@ -42,7 +42,6 @@ class Event(LlmResponse):
|
|
42
42
|
branch: The branch of the event.
|
43
43
|
id: The unique identifier of the event.
|
44
44
|
timestamp: The timestamp of the event.
|
45
|
-
is_final_response: Whether the event is the final response of the agent.
|
46
45
|
get_function_calls: Returns the function calls in the event.
|
47
46
|
"""
|
48
47
|
|
@@ -92,7 +91,13 @@ class Event(LlmResponse):
|
|
92
91
|
self.id = Event.new_id()
|
93
92
|
|
94
93
|
def is_final_response(self) -> bool:
|
95
|
-
"""Returns whether the event is the final response of
|
94
|
+
"""Returns whether the event is the final response of an agent.
|
95
|
+
|
96
|
+
NOTE: This method is ONLY for use by Agent Development Kit.
|
97
|
+
|
98
|
+
Note that when multiple agents participage in one invocation, there could be
|
99
|
+
one event has `is_final_response()` as True for each participating agent.
|
100
|
+
"""
|
96
101
|
if self.actions.skip_summarization or self.long_running_tool_ids:
|
97
102
|
return True
|
98
103
|
return (
|
@@ -14,6 +14,8 @@
|
|
14
14
|
|
15
15
|
"""Implementation of AutoFlow."""
|
16
16
|
|
17
|
+
from __future__ import annotations
|
18
|
+
|
17
19
|
from . import agent_transfer
|
18
20
|
from .single_flow import SingleFlow
|
19
21
|
|
@@ -29,19 +31,12 @@ class AutoFlow(SingleFlow):
|
|
29
31
|
|
30
32
|
For peer-agent transfers, it's only enabled when all below conditions are met:
|
31
33
|
|
32
|
-
- The parent agent is also
|
34
|
+
- The parent agent is also an LlmAgent.
|
33
35
|
- `disallow_transfer_to_peer` option of this agent is False (default).
|
34
36
|
|
35
|
-
Depending on the target agent
|
36
|
-
reversed.
|
37
|
-
|
38
|
-
- If the flow type of the tranferee agent is also auto, transfee agent will
|
39
|
-
remain as the active agent. The transfee agent will respond to the user's
|
40
|
-
next message directly.
|
41
|
-
- If the flow type of the transfere agent is not auto, the active agent will
|
42
|
-
be reversed back to previous agent.
|
43
|
-
|
44
|
-
TODO: allow user to config auto-reverse function.
|
37
|
+
Depending on the target agent type, the transfer may be automatically
|
38
|
+
reversed. (see Runner._find_agent_to_run method for which agent will remain
|
39
|
+
active to handle next user message.)
|
45
40
|
"""
|
46
41
|
|
47
42
|
def __init__(self):
|
@@ -42,6 +42,7 @@ from ...models.llm_response import LlmResponse
|
|
42
42
|
from ...telemetry import trace_call_llm
|
43
43
|
from ...telemetry import trace_send_data
|
44
44
|
from ...telemetry import tracer
|
45
|
+
from ...tools.base_toolset import BaseToolset
|
45
46
|
from ...tools.tool_context import ToolContext
|
46
47
|
|
47
48
|
if TYPE_CHECKING:
|
@@ -194,7 +195,12 @@ class BaseLlmFlow(ABC):
|
|
194
195
|
if live_request.close:
|
195
196
|
await llm_connection.close()
|
196
197
|
return
|
197
|
-
|
198
|
+
|
199
|
+
if live_request.activity_start:
|
200
|
+
await llm_connection.send_realtime(types.ActivityStart())
|
201
|
+
elif live_request.activity_end:
|
202
|
+
await llm_connection.send_realtime(types.ActivityEnd())
|
203
|
+
elif live_request.blob:
|
198
204
|
# Cache audio data here for transcription
|
199
205
|
if not invocation_context.transcription_cache:
|
200
206
|
invocation_context.transcription_cache = []
|
@@ -205,6 +211,7 @@ class BaseLlmFlow(ABC):
|
|
205
211
|
TranscriptionEntry(role='user', data=live_request.blob)
|
206
212
|
)
|
207
213
|
await llm_connection.send_realtime(live_request.blob)
|
214
|
+
|
208
215
|
if live_request.content:
|
209
216
|
await llm_connection.send_content(live_request.content)
|
210
217
|
|
@@ -283,14 +290,10 @@ class BaseLlmFlow(ABC):
|
|
283
290
|
async for event in self._run_one_step_async(invocation_context):
|
284
291
|
last_event = event
|
285
292
|
yield event
|
286
|
-
if not last_event or last_event.is_final_response():
|
293
|
+
if not last_event or last_event.is_final_response() or last_event.partial:
|
294
|
+
if last_event and last_event.partial:
|
295
|
+
logger.warning('The last event is partial, which is not expected.')
|
287
296
|
break
|
288
|
-
if last_event.partial:
|
289
|
-
# TODO: handle this in BaseLlm level.
|
290
|
-
raise ValueError(
|
291
|
-
f"Last event shouldn't be partial. LLM max output limit may be"
|
292
|
-
f' reached.'
|
293
|
-
)
|
294
297
|
|
295
298
|
async def _run_one_step_async(
|
296
299
|
self,
|
@@ -339,13 +342,25 @@ class BaseLlmFlow(ABC):
|
|
339
342
|
yield event
|
340
343
|
|
341
344
|
# Run processors for tools.
|
342
|
-
for
|
343
|
-
ReadonlyContext(invocation_context)
|
344
|
-
):
|
345
|
+
for tool_union in agent.tools:
|
345
346
|
tool_context = ToolContext(invocation_context)
|
346
|
-
|
347
|
-
|
347
|
+
|
348
|
+
# If it's a toolset, process it first
|
349
|
+
if isinstance(tool_union, BaseToolset):
|
350
|
+
await tool_union.process_llm_request(
|
351
|
+
tool_context=tool_context, llm_request=llm_request
|
352
|
+
)
|
353
|
+
|
354
|
+
from ...agents.llm_agent import _convert_tool_union_to_tools
|
355
|
+
|
356
|
+
# Then process all tools from this tool union
|
357
|
+
tools = await _convert_tool_union_to_tools(
|
358
|
+
tool_union, ReadonlyContext(invocation_context)
|
348
359
|
)
|
360
|
+
for tool in tools:
|
361
|
+
await tool.process_llm_request(
|
362
|
+
tool_context=tool_context, llm_request=llm_request
|
363
|
+
)
|
349
364
|
|
350
365
|
async def _postprocess_async(
|
351
366
|
self,
|
@@ -569,21 +584,32 @@ class BaseLlmFlow(ABC):
|
|
569
584
|
if not isinstance(agent, LlmAgent):
|
570
585
|
return
|
571
586
|
|
572
|
-
if not agent.canonical_before_model_callbacks:
|
573
|
-
return
|
574
|
-
|
575
587
|
callback_context = CallbackContext(
|
576
588
|
invocation_context, event_actions=model_response_event.actions
|
577
589
|
)
|
578
590
|
|
591
|
+
# First run callbacks from the plugins.
|
592
|
+
callback_response = (
|
593
|
+
await invocation_context.plugin_manager.run_before_model_callback(
|
594
|
+
callback_context=callback_context,
|
595
|
+
llm_request=llm_request,
|
596
|
+
)
|
597
|
+
)
|
598
|
+
if callback_response:
|
599
|
+
return callback_response
|
600
|
+
|
601
|
+
# If no overrides are provided from the plugins, further run the canonical
|
602
|
+
# callbacks.
|
603
|
+
if not agent.canonical_before_model_callbacks:
|
604
|
+
return
|
579
605
|
for callback in agent.canonical_before_model_callbacks:
|
580
|
-
|
606
|
+
callback_response = callback(
|
581
607
|
callback_context=callback_context, llm_request=llm_request
|
582
608
|
)
|
583
|
-
if inspect.isawaitable(
|
584
|
-
|
585
|
-
if
|
586
|
-
return
|
609
|
+
if inspect.isawaitable(callback_response):
|
610
|
+
callback_response = await callback_response
|
611
|
+
if callback_response:
|
612
|
+
return callback_response
|
587
613
|
|
588
614
|
async def _handle_after_model_callback(
|
589
615
|
self,
|
@@ -597,21 +623,32 @@ class BaseLlmFlow(ABC):
|
|
597
623
|
if not isinstance(agent, LlmAgent):
|
598
624
|
return
|
599
625
|
|
600
|
-
if not agent.canonical_after_model_callbacks:
|
601
|
-
return
|
602
|
-
|
603
626
|
callback_context = CallbackContext(
|
604
627
|
invocation_context, event_actions=model_response_event.actions
|
605
628
|
)
|
606
629
|
|
630
|
+
# First run callbacks from the plugins.
|
631
|
+
callback_response = (
|
632
|
+
await invocation_context.plugin_manager.run_after_model_callback(
|
633
|
+
callback_context=CallbackContext(invocation_context),
|
634
|
+
llm_response=llm_response,
|
635
|
+
)
|
636
|
+
)
|
637
|
+
if callback_response:
|
638
|
+
return callback_response
|
639
|
+
|
640
|
+
# If no overrides are provided from the plugins, further run the canonical
|
641
|
+
# callbacks.
|
642
|
+
if not agent.canonical_after_model_callbacks:
|
643
|
+
return
|
607
644
|
for callback in agent.canonical_after_model_callbacks:
|
608
|
-
|
645
|
+
callback_response = callback(
|
609
646
|
callback_context=callback_context, llm_response=llm_response
|
610
647
|
)
|
611
|
-
if inspect.isawaitable(
|
612
|
-
|
613
|
-
if
|
614
|
-
return
|
648
|
+
if inspect.isawaitable(callback_response):
|
649
|
+
callback_response = await callback_response
|
650
|
+
if callback_response:
|
651
|
+
return callback_response
|
615
652
|
|
616
653
|
def _finalize_model_response_event(
|
617
654
|
self,
|
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
|
|
157
157
|
for function_call in function_calls:
|
158
158
|
if function_call.id in function_responses_ids:
|
159
159
|
function_call_event_idx = idx
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
#
|
164
|
-
|
165
|
-
|
160
|
+
function_call_ids = {
|
161
|
+
function_call.id for function_call in function_calls
|
162
|
+
}
|
163
|
+
# last response event should only contain the responses for the
|
164
|
+
# function calls in the same function call event
|
165
|
+
if not function_responses_ids.issubset(function_call_ids):
|
166
|
+
raise ValueError(
|
167
|
+
'Last response event should only contain the responses for the'
|
168
|
+
' function calls in the same function call event. Function'
|
169
|
+
f' call ids found : {function_call_ids}, function response'
|
170
|
+
f' ids provided: {function_responses_ids}'
|
171
|
+
)
|
172
|
+
# collect all function responses from the function call event to
|
173
|
+
# the last response event
|
174
|
+
function_responses_ids = function_call_ids
|
166
175
|
break
|
167
176
|
|
168
177
|
if function_call_event_idx == -1:
|
@@ -363,10 +372,7 @@ def _merge_function_response_events(
|
|
363
372
|
list is in increasing order of timestamp; 2. the first event is the
|
364
373
|
initial function_response event; 3. all later events should contain at
|
365
374
|
least one function_response part that related to the function_call
|
366
|
-
event.
|
367
|
-
intermediate response, there could also be some intermediate model
|
368
|
-
response event without any function_response and such event will be
|
369
|
-
ignored.)
|
375
|
+
event.
|
370
376
|
Caveat: This implementation doesn't support when a parallel function_call
|
371
377
|
event contains async function_call of the same name.
|
372
378
|
|
@@ -153,37 +153,67 @@ async def handle_function_calls_async(
|
|
153
153
|
# do not use "args" as the variable name, because it is a reserved keyword
|
154
154
|
# in python debugger.
|
155
155
|
function_args = function_call.args or {}
|
156
|
-
function_response: Optional[dict] = None
|
157
156
|
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
break
|
157
|
+
# Step 1: Check if plugin before_tool_callback overrides the function
|
158
|
+
# response.
|
159
|
+
function_response = (
|
160
|
+
await invocation_context.plugin_manager.run_before_tool_callback(
|
161
|
+
tool=tool, tool_args=function_args, tool_context=tool_context
|
162
|
+
)
|
163
|
+
)
|
166
164
|
|
167
|
-
|
165
|
+
# Step 2: If no overrides are provided from the plugins, further run the
|
166
|
+
# canonical callback.
|
167
|
+
if function_response is None:
|
168
|
+
for callback in agent.canonical_before_tool_callbacks:
|
169
|
+
function_response = callback(
|
170
|
+
tool=tool, args=function_args, tool_context=tool_context
|
171
|
+
)
|
172
|
+
if inspect.isawaitable(function_response):
|
173
|
+
function_response = await function_response
|
174
|
+
if function_response:
|
175
|
+
break
|
176
|
+
|
177
|
+
# Step 3: Otherwise, proceed calling the tool normally.
|
178
|
+
if function_response is None:
|
168
179
|
function_response = await __call_tool_async(
|
169
180
|
tool, args=function_args, tool_context=tool_context
|
170
181
|
)
|
171
182
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
183
|
+
# Step 4: Check if plugin after_tool_callback overrides the function
|
184
|
+
# response.
|
185
|
+
altered_function_response = (
|
186
|
+
await invocation_context.plugin_manager.run_after_tool_callback(
|
187
|
+
tool=tool,
|
188
|
+
tool_args=function_args,
|
189
|
+
tool_context=tool_context,
|
190
|
+
result=function_response,
|
191
|
+
)
|
192
|
+
)
|
193
|
+
|
194
|
+
# Step 5: If no overrides are provided from the plugins, further run the
|
195
|
+
# canonical after_tool_callbacks.
|
196
|
+
if altered_function_response is None:
|
197
|
+
for callback in agent.canonical_after_tool_callbacks:
|
198
|
+
altered_function_response = callback(
|
199
|
+
tool=tool,
|
200
|
+
args=function_args,
|
201
|
+
tool_context=tool_context,
|
202
|
+
tool_response=function_response,
|
203
|
+
)
|
204
|
+
if inspect.isawaitable(altered_function_response):
|
205
|
+
altered_function_response = await altered_function_response
|
206
|
+
if altered_function_response:
|
207
|
+
break
|
208
|
+
|
209
|
+
# Step 6: If alternative response exists from after_tool_callback, use it
|
210
|
+
# instead of the original function response.
|
211
|
+
if altered_function_response is not None:
|
212
|
+
function_response = altered_function_response
|
184
213
|
|
185
214
|
if tool.is_long_running:
|
186
|
-
# Allow long running function to return None to not provide function
|
215
|
+
# Allow long running function to return None to not provide function
|
216
|
+
# response.
|
187
217
|
if not function_response:
|
188
218
|
continue
|
189
219
|
|
@@ -237,35 +267,27 @@ async def handle_function_calls_live(
|
|
237
267
|
# in python debugger.
|
238
268
|
function_args = function_call.args or {}
|
239
269
|
function_response = None
|
240
|
-
|
241
|
-
#
|
242
|
-
#
|
243
|
-
|
244
|
-
|
245
|
-
if agent.before_tool_callback:
|
246
|
-
function_response = agent.before_tool_callback(
|
270
|
+
|
271
|
+
# Handle before_tool_callbacks - iterate through the canonical callback
|
272
|
+
# list
|
273
|
+
for callback in agent.canonical_before_tool_callbacks:
|
274
|
+
function_response = callback(
|
247
275
|
tool=tool, args=function_args, tool_context=tool_context
|
248
276
|
)
|
249
277
|
if inspect.isawaitable(function_response):
|
250
278
|
function_response = await function_response
|
279
|
+
if function_response:
|
280
|
+
break
|
251
281
|
|
252
|
-
if
|
282
|
+
if function_response is None:
|
253
283
|
function_response = await _process_function_live_helper(
|
254
284
|
tool, tool_context, function_call, function_args, invocation_context
|
255
285
|
)
|
256
286
|
|
257
287
|
# Calls after_tool_callback if it exists.
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
# function_args,
|
262
|
-
# tool_context,
|
263
|
-
# function_response,
|
264
|
-
# )
|
265
|
-
# if new_response:
|
266
|
-
# function_response = new_response
|
267
|
-
if agent.after_tool_callback:
|
268
|
-
altered_function_response = agent.after_tool_callback(
|
288
|
+
altered_function_response = None
|
289
|
+
for callback in agent.canonical_after_tool_callbacks:
|
290
|
+
altered_function_response = callback(
|
269
291
|
tool=tool,
|
270
292
|
args=function_args,
|
271
293
|
tool_context=tool_context,
|
@@ -273,8 +295,11 @@ async def handle_function_calls_live(
|
|
273
295
|
)
|
274
296
|
if inspect.isawaitable(altered_function_response):
|
275
297
|
altered_function_response = await altered_function_response
|
276
|
-
if altered_function_response
|
277
|
-
|
298
|
+
if altered_function_response:
|
299
|
+
break
|
300
|
+
|
301
|
+
if altered_function_response is not None:
|
302
|
+
function_response = altered_function_response
|
278
303
|
|
279
304
|
if tool.is_long_running:
|
280
305
|
# Allow async function to return None to not provide function response.
|
@@ -480,6 +505,16 @@ def __build_response_event(
|
|
480
505
|
return function_response_event
|
481
506
|
|
482
507
|
|
508
|
+
def deep_merge_dicts(d1: dict, d2: dict) -> dict:
|
509
|
+
"""Recursively merges d2 into d1."""
|
510
|
+
for key, value in d2.items():
|
511
|
+
if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict):
|
512
|
+
d1[key] = deep_merge_dicts(d1[key], value)
|
513
|
+
else:
|
514
|
+
d1[key] = value
|
515
|
+
return d1
|
516
|
+
|
517
|
+
|
483
518
|
def merge_parallel_function_response_events(
|
484
519
|
function_response_events: list['Event'],
|
485
520
|
) -> 'Event':
|
@@ -498,15 +533,17 @@ def merge_parallel_function_response_events(
|
|
498
533
|
base_event = function_response_events[0]
|
499
534
|
|
500
535
|
# Merge actions from all events
|
501
|
-
|
502
|
-
merged_actions = EventActions()
|
503
|
-
merged_requested_auth_configs = {}
|
536
|
+
merged_actions_data = {}
|
504
537
|
for event in function_response_events:
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
538
|
+
if event.actions:
|
539
|
+
# Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate.
|
540
|
+
merged_actions_data = deep_merge_dicts(
|
541
|
+
merged_actions_data,
|
542
|
+
event.actions.model_dump(exclude_none=True, by_alias=True),
|
543
|
+
)
|
544
|
+
|
545
|
+
merged_actions = EventActions.model_validate(merged_actions_data)
|
546
|
+
|
510
547
|
# Create the new merged event
|
511
548
|
merged_event = Event(
|
512
549
|
invocation_id=Event.new_id(),
|
@@ -14,6 +14,7 @@
|
|
14
14
|
from __future__ import annotations
|
15
15
|
|
16
16
|
import re
|
17
|
+
import threading
|
17
18
|
from typing import TYPE_CHECKING
|
18
19
|
|
19
20
|
from typing_extensions import override
|
@@ -42,38 +43,43 @@ class InMemoryMemoryService(BaseMemoryService):
|
|
42
43
|
|
43
44
|
Uses keyword matching instead of semantic search.
|
44
45
|
|
45
|
-
|
46
|
-
|
46
|
+
This class is thread-safe, however, it should be used for testing and
|
47
|
+
development only.
|
47
48
|
"""
|
48
49
|
|
49
50
|
def __init__(self):
|
51
|
+
self._lock = threading.Lock()
|
52
|
+
|
50
53
|
self._session_events: dict[str, dict[str, list[Event]]] = {}
|
51
|
-
"""Keys are app_name/user_id
|
54
|
+
"""Keys are "{app_name}/{user_id}". Values are dicts of session_id to
|
55
|
+
session event lists.
|
56
|
+
"""
|
52
57
|
|
53
58
|
@override
|
54
59
|
async def add_session_to_memory(self, session: Session):
|
55
60
|
user_key = _user_key(session.app_name, session.user_id)
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
61
|
+
|
62
|
+
with self._lock:
|
63
|
+
self._session_events[user_key] = self._session_events.get(user_key, {})
|
64
|
+
self._session_events[user_key][session.id] = [
|
65
|
+
event
|
66
|
+
for event in session.events
|
67
|
+
if event.content and event.content.parts
|
68
|
+
]
|
64
69
|
|
65
70
|
@override
|
66
71
|
async def search_memory(
|
67
72
|
self, *, app_name: str, user_id: str, query: str
|
68
73
|
) -> SearchMemoryResponse:
|
69
74
|
user_key = _user_key(app_name, user_id)
|
70
|
-
if user_key not in self._session_events:
|
71
|
-
return SearchMemoryResponse()
|
72
75
|
|
73
|
-
|
76
|
+
with self._lock:
|
77
|
+
session_event_lists = self._session_events.get(user_key, {})
|
78
|
+
|
79
|
+
words_in_query = _extract_words_lower(query)
|
74
80
|
response = SearchMemoryResponse()
|
75
81
|
|
76
|
-
for session_events in
|
82
|
+
for session_events in session_event_lists.values():
|
77
83
|
for event in session_events:
|
78
84
|
if not event.content or not event.content.parts:
|
79
85
|
continue
|
@@ -16,13 +16,15 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import json
|
18
18
|
import logging
|
19
|
+
from typing import Any
|
20
|
+
from typing import Dict
|
19
21
|
from typing import Optional
|
20
22
|
from typing import TYPE_CHECKING
|
21
23
|
|
24
|
+
from google.genai import Client
|
25
|
+
from google.genai import types
|
22
26
|
from typing_extensions import override
|
23
27
|
|
24
|
-
from google import genai
|
25
|
-
|
26
28
|
from .base_memory_service import BaseMemoryService
|
27
29
|
from .base_memory_service import SearchMemoryResponse
|
28
30
|
from .memory_entry import MemoryEntry
|
@@ -84,7 +86,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
|
84
86
|
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
|
85
87
|
request_dict=request_dict,
|
86
88
|
)
|
87
|
-
logger.info(
|
89
|
+
logger.info('Generate memory response received.')
|
90
|
+
logger.debug('Generate memory response: %s', api_response)
|
88
91
|
else:
|
89
92
|
logger.info('No events to add to memory.')
|
90
93
|
|
@@ -106,7 +109,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
|
106
109
|
},
|
107
110
|
)
|
108
111
|
api_response = _convert_api_response(api_response)
|
109
|
-
logger.info(
|
112
|
+
logger.info('Search memory response received.')
|
113
|
+
logger.debug('Search memory response: %s', api_response)
|
110
114
|
|
111
115
|
if not api_response or not api_response.get('retrievedMemories', None):
|
112
116
|
return SearchMemoryResponse()
|
@@ -117,10 +121,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
|
117
121
|
memory_events.append(
|
118
122
|
MemoryEntry(
|
119
123
|
author='user',
|
120
|
-
content=
|
121
|
-
parts=[
|
122
|
-
genai.types.Part(text=memory.get('memory').get('fact'))
|
123
|
-
],
|
124
|
+
content=types.Content(
|
125
|
+
parts=[types.Part(text=memory.get('memory').get('fact'))],
|
124
126
|
role='user',
|
125
127
|
),
|
126
128
|
timestamp=memory.get('updateTime'),
|
@@ -137,13 +139,13 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
|
137
139
|
Returns:
|
138
140
|
An API client for the given project and location.
|
139
141
|
"""
|
140
|
-
client =
|
142
|
+
client = Client(
|
141
143
|
vertexai=True, project=self._project, location=self._location
|
142
144
|
)
|
143
145
|
return client._api_client
|
144
146
|
|
145
147
|
|
146
|
-
def _convert_api_response(api_response):
|
148
|
+
def _convert_api_response(api_response) -> Dict[str, Any]:
|
147
149
|
"""Converts the API response to a JSON object based on the type."""
|
148
150
|
if hasattr(api_response, 'body'):
|
149
151
|
return json.loads(api_response.body)
|