google-adk 1.6.1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. google/adk/a2a/converters/event_converter.py +5 -85
  2. google/adk/a2a/converters/request_converter.py +1 -2
  3. google/adk/a2a/executor/a2a_agent_executor.py +45 -16
  4. google/adk/a2a/logs/log_utils.py +1 -2
  5. google/adk/a2a/utils/__init__.py +0 -0
  6. google/adk/a2a/utils/agent_card_builder.py +544 -0
  7. google/adk/a2a/utils/agent_to_a2a.py +118 -0
  8. google/adk/agents/__init__.py +5 -0
  9. google/adk/agents/agent_config.py +46 -0
  10. google/adk/agents/base_agent.py +239 -41
  11. google/adk/agents/callback_context.py +41 -0
  12. google/adk/agents/common_configs.py +79 -0
  13. google/adk/agents/config_agent_utils.py +184 -0
  14. google/adk/agents/config_schemas/AgentConfig.json +566 -0
  15. google/adk/agents/invocation_context.py +5 -1
  16. google/adk/agents/live_request_queue.py +15 -0
  17. google/adk/agents/llm_agent.py +201 -9
  18. google/adk/agents/loop_agent.py +35 -1
  19. google/adk/agents/parallel_agent.py +24 -3
  20. google/adk/agents/remote_a2a_agent.py +17 -5
  21. google/adk/agents/sequential_agent.py +22 -1
  22. google/adk/artifacts/gcs_artifact_service.py +110 -20
  23. google/adk/auth/auth_handler.py +3 -3
  24. google/adk/auth/credential_manager.py +23 -23
  25. google/adk/auth/credential_service/base_credential_service.py +6 -6
  26. google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
  27. google/adk/auth/credential_service/session_state_credential_service.py +8 -8
  28. google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
  29. google/adk/auth/oauth2_credential_util.py +2 -2
  30. google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
  31. google/adk/cli/agent_graph.py +3 -1
  32. google/adk/cli/browser/index.html +2 -2
  33. google/adk/cli/browser/main-W7QZBYAR.js +3914 -0
  34. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  35. google/adk/cli/cli_eval.py +87 -12
  36. google/adk/cli/cli_tools_click.py +143 -82
  37. google/adk/cli/fast_api.py +150 -69
  38. google/adk/cli/utils/agent_loader.py +35 -1
  39. google/adk/code_executors/base_code_executor.py +14 -19
  40. google/adk/code_executors/built_in_code_executor.py +4 -1
  41. google/adk/evaluation/base_eval_service.py +46 -2
  42. google/adk/evaluation/eval_metrics.py +4 -0
  43. google/adk/evaluation/eval_sets_manager.py +5 -1
  44. google/adk/evaluation/evaluation_generator.py +1 -1
  45. google/adk/evaluation/final_response_match_v2.py +2 -2
  46. google/adk/evaluation/gcs_eval_sets_manager.py +2 -1
  47. google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
  48. google/adk/evaluation/local_eval_service.py +389 -0
  49. google/adk/evaluation/local_eval_set_results_manager.py +2 -2
  50. google/adk/evaluation/local_eval_sets_manager.py +24 -9
  51. google/adk/evaluation/metric_evaluator_registry.py +16 -6
  52. google/adk/evaluation/vertex_ai_eval_facade.py +7 -1
  53. google/adk/events/event.py +7 -2
  54. google/adk/flows/llm_flows/auto_flow.py +6 -11
  55. google/adk/flows/llm_flows/base_llm_flow.py +66 -29
  56. google/adk/flows/llm_flows/contents.py +16 -10
  57. google/adk/flows/llm_flows/functions.py +89 -52
  58. google/adk/memory/in_memory_memory_service.py +21 -15
  59. google/adk/memory/vertex_ai_memory_bank_service.py +12 -10
  60. google/adk/models/anthropic_llm.py +46 -6
  61. google/adk/models/base_llm_connection.py +2 -0
  62. google/adk/models/gemini_llm_connection.py +17 -6
  63. google/adk/models/google_llm.py +46 -11
  64. google/adk/models/lite_llm.py +52 -22
  65. google/adk/plugins/__init__.py +17 -0
  66. google/adk/plugins/base_plugin.py +317 -0
  67. google/adk/plugins/plugin_manager.py +265 -0
  68. google/adk/runners.py +122 -18
  69. google/adk/sessions/database_session_service.py +51 -52
  70. google/adk/sessions/vertex_ai_session_service.py +27 -12
  71. google/adk/tools/__init__.py +2 -0
  72. google/adk/tools/_automatic_function_calling_util.py +20 -2
  73. google/adk/tools/agent_tool.py +15 -3
  74. google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
  75. google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
  76. google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
  77. google/adk/tools/base_tool.py +9 -9
  78. google/adk/tools/base_toolset.py +29 -5
  79. google/adk/tools/bigquery/__init__.py +3 -3
  80. google/adk/tools/bigquery/metadata_tool.py +2 -0
  81. google/adk/tools/bigquery/query_tool.py +15 -1
  82. google/adk/tools/computer_use/__init__.py +13 -0
  83. google/adk/tools/computer_use/base_computer.py +265 -0
  84. google/adk/tools/computer_use/computer_use_tool.py +166 -0
  85. google/adk/tools/computer_use/computer_use_toolset.py +220 -0
  86. google/adk/tools/enterprise_search_tool.py +4 -2
  87. google/adk/tools/exit_loop_tool.py +1 -0
  88. google/adk/tools/google_api_tool/google_api_tool.py +16 -1
  89. google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
  90. google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
  91. google/adk/tools/google_search_tool.py +4 -2
  92. google/adk/tools/langchain_tool.py +16 -6
  93. google/adk/tools/long_running_tool.py +21 -0
  94. google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
  95. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +5 -0
  96. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
  97. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
  98. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
  99. google/adk/tools/tool_context.py +0 -10
  100. google/adk/tools/url_context_tool.py +4 -2
  101. google/adk/tools/vertex_ai_search_tool.py +4 -2
  102. google/adk/utils/model_name_utils.py +90 -0
  103. google/adk/version.py +1 -1
  104. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/METADATA +3 -2
  105. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/RECORD +108 -91
  106. google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
  107. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
  108. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/WHEEL +0 -0
  109. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/entry_points.txt +0 -0
  110. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,6 +14,7 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import math
17
18
  import os
18
19
  from typing import Optional
19
20
 
@@ -112,7 +113,12 @@ class _VertexAiEvalFacade(Evaluator):
112
113
  return ""
113
114
 
114
115
  def _get_score(self, eval_result) -> Optional[float]:
115
- if eval_result and eval_result.summary_metrics:
116
+ if (
117
+ eval_result
118
+ and eval_result.summary_metrics
119
+ and isinstance(eval_result.summary_metrics[0].mean_score, float)
120
+ and not math.isnan(eval_result.summary_metrics[0].mean_score)
121
+ ):
116
122
  return eval_result.summary_metrics[0].mean_score
117
123
 
118
124
  return None
@@ -42,7 +42,6 @@ class Event(LlmResponse):
42
42
  branch: The branch of the event.
43
43
  id: The unique identifier of the event.
44
44
  timestamp: The timestamp of the event.
45
- is_final_response: Whether the event is the final response of the agent.
46
45
  get_function_calls: Returns the function calls in the event.
47
46
  """
48
47
 
@@ -92,7 +91,13 @@ class Event(LlmResponse):
92
91
  self.id = Event.new_id()
93
92
 
94
93
  def is_final_response(self) -> bool:
95
- """Returns whether the event is the final response of the agent."""
94
+ """Returns whether the event is the final response of an agent.
95
+
96
+ NOTE: This method is ONLY for use by Agent Development Kit.
97
+
98
+ Note that when multiple agents participate in one invocation, there could be
99
+ one event with `is_final_response()` as True for each participating agent.
100
+ """
96
101
  if self.actions.skip_summarization or self.long_running_tool_ids:
97
102
  return True
98
103
  return (
@@ -14,6 +14,8 @@
14
14
 
15
15
  """Implementation of AutoFlow."""
16
16
 
17
+ from __future__ import annotations
18
+
17
19
  from . import agent_transfer
18
20
  from .single_flow import SingleFlow
19
21
 
@@ -29,19 +31,12 @@ class AutoFlow(SingleFlow):
29
31
 
30
32
  For peer-agent transfers, it's only enabled when all below conditions are met:
31
33
 
32
- - The parent agent is also of AutoFlow;
34
+ - The parent agent is also an LlmAgent.
33
35
  - `disallow_transfer_to_peer` option of this agent is False (default).
34
36
 
35
- Depending on the target agent flow type, the transfer may be automatically
36
- reversed. The condition is as below:
37
-
38
- - If the flow type of the tranferee agent is also auto, transfee agent will
39
- remain as the active agent. The transfee agent will respond to the user's
40
- next message directly.
41
- - If the flow type of the transfere agent is not auto, the active agent will
42
- be reversed back to previous agent.
43
-
44
- TODO: allow user to config auto-reverse function.
37
+ Depending on the target agent type, the transfer may be automatically
38
+ reversed. (see Runner._find_agent_to_run method for which agent will remain
39
+ active to handle next user message.)
45
40
  """
46
41
 
47
42
  def __init__(self):
@@ -42,6 +42,7 @@ from ...models.llm_response import LlmResponse
42
42
  from ...telemetry import trace_call_llm
43
43
  from ...telemetry import trace_send_data
44
44
  from ...telemetry import tracer
45
+ from ...tools.base_toolset import BaseToolset
45
46
  from ...tools.tool_context import ToolContext
46
47
 
47
48
  if TYPE_CHECKING:
@@ -194,7 +195,12 @@ class BaseLlmFlow(ABC):
194
195
  if live_request.close:
195
196
  await llm_connection.close()
196
197
  return
197
- if live_request.blob:
198
+
199
+ if live_request.activity_start:
200
+ await llm_connection.send_realtime(types.ActivityStart())
201
+ elif live_request.activity_end:
202
+ await llm_connection.send_realtime(types.ActivityEnd())
203
+ elif live_request.blob:
198
204
  # Cache audio data here for transcription
199
205
  if not invocation_context.transcription_cache:
200
206
  invocation_context.transcription_cache = []
@@ -205,6 +211,7 @@ class BaseLlmFlow(ABC):
205
211
  TranscriptionEntry(role='user', data=live_request.blob)
206
212
  )
207
213
  await llm_connection.send_realtime(live_request.blob)
214
+
208
215
  if live_request.content:
209
216
  await llm_connection.send_content(live_request.content)
210
217
 
@@ -283,14 +290,10 @@ class BaseLlmFlow(ABC):
283
290
  async for event in self._run_one_step_async(invocation_context):
284
291
  last_event = event
285
292
  yield event
286
- if not last_event or last_event.is_final_response():
293
+ if not last_event or last_event.is_final_response() or last_event.partial:
294
+ if last_event and last_event.partial:
295
+ logger.warning('The last event is partial, which is not expected.')
287
296
  break
288
- if last_event.partial:
289
- # TODO: handle this in BaseLlm level.
290
- raise ValueError(
291
- f"Last event shouldn't be partial. LLM max output limit may be"
292
- f' reached.'
293
- )
294
297
 
295
298
  async def _run_one_step_async(
296
299
  self,
@@ -339,13 +342,25 @@ class BaseLlmFlow(ABC):
339
342
  yield event
340
343
 
341
344
  # Run processors for tools.
342
- for tool in await agent.canonical_tools(
343
- ReadonlyContext(invocation_context)
344
- ):
345
+ for tool_union in agent.tools:
345
346
  tool_context = ToolContext(invocation_context)
346
- await tool.process_llm_request(
347
- tool_context=tool_context, llm_request=llm_request
347
+
348
+ # If it's a toolset, process it first
349
+ if isinstance(tool_union, BaseToolset):
350
+ await tool_union.process_llm_request(
351
+ tool_context=tool_context, llm_request=llm_request
352
+ )
353
+
354
+ from ...agents.llm_agent import _convert_tool_union_to_tools
355
+
356
+ # Then process all tools from this tool union
357
+ tools = await _convert_tool_union_to_tools(
358
+ tool_union, ReadonlyContext(invocation_context)
348
359
  )
360
+ for tool in tools:
361
+ await tool.process_llm_request(
362
+ tool_context=tool_context, llm_request=llm_request
363
+ )
349
364
 
350
365
  async def _postprocess_async(
351
366
  self,
@@ -569,21 +584,32 @@ class BaseLlmFlow(ABC):
569
584
  if not isinstance(agent, LlmAgent):
570
585
  return
571
586
 
572
- if not agent.canonical_before_model_callbacks:
573
- return
574
-
575
587
  callback_context = CallbackContext(
576
588
  invocation_context, event_actions=model_response_event.actions
577
589
  )
578
590
 
591
+ # First run callbacks from the plugins.
592
+ callback_response = (
593
+ await invocation_context.plugin_manager.run_before_model_callback(
594
+ callback_context=callback_context,
595
+ llm_request=llm_request,
596
+ )
597
+ )
598
+ if callback_response:
599
+ return callback_response
600
+
601
+ # If no overrides are provided from the plugins, further run the canonical
602
+ # callbacks.
603
+ if not agent.canonical_before_model_callbacks:
604
+ return
579
605
  for callback in agent.canonical_before_model_callbacks:
580
- before_model_callback_content = callback(
606
+ callback_response = callback(
581
607
  callback_context=callback_context, llm_request=llm_request
582
608
  )
583
- if inspect.isawaitable(before_model_callback_content):
584
- before_model_callback_content = await before_model_callback_content
585
- if before_model_callback_content:
586
- return before_model_callback_content
609
+ if inspect.isawaitable(callback_response):
610
+ callback_response = await callback_response
611
+ if callback_response:
612
+ return callback_response
587
613
 
588
614
  async def _handle_after_model_callback(
589
615
  self,
@@ -597,21 +623,32 @@ class BaseLlmFlow(ABC):
597
623
  if not isinstance(agent, LlmAgent):
598
624
  return
599
625
 
600
- if not agent.canonical_after_model_callbacks:
601
- return
602
-
603
626
  callback_context = CallbackContext(
604
627
  invocation_context, event_actions=model_response_event.actions
605
628
  )
606
629
 
630
+ # First run callbacks from the plugins.
631
+ callback_response = (
632
+ await invocation_context.plugin_manager.run_after_model_callback(
633
+ callback_context=CallbackContext(invocation_context),
634
+ llm_response=llm_response,
635
+ )
636
+ )
637
+ if callback_response:
638
+ return callback_response
639
+
640
+ # If no overrides are provided from the plugins, further run the canonical
641
+ # callbacks.
642
+ if not agent.canonical_after_model_callbacks:
643
+ return
607
644
  for callback in agent.canonical_after_model_callbacks:
608
- after_model_callback_content = callback(
645
+ callback_response = callback(
609
646
  callback_context=callback_context, llm_response=llm_response
610
647
  )
611
- if inspect.isawaitable(after_model_callback_content):
612
- after_model_callback_content = await after_model_callback_content
613
- if after_model_callback_content:
614
- return after_model_callback_content
648
+ if inspect.isawaitable(callback_response):
649
+ callback_response = await callback_response
650
+ if callback_response:
651
+ return callback_response
615
652
 
616
653
  def _finalize_model_response_event(
617
654
  self,
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
157
157
  for function_call in function_calls:
158
158
  if function_call.id in function_responses_ids:
159
159
  function_call_event_idx = idx
160
- break
161
- if function_call_event_idx != -1:
162
- # in case the last response event only have part of the responses
163
- # for the function calls in the function call event
164
- for function_call in function_calls:
165
- function_responses_ids.add(function_call.id)
160
+ function_call_ids = {
161
+ function_call.id for function_call in function_calls
162
+ }
163
+ # last response event should only contain the responses for the
164
+ # function calls in the same function call event
165
+ if not function_responses_ids.issubset(function_call_ids):
166
+ raise ValueError(
167
+ 'Last response event should only contain the responses for the'
168
+ ' function calls in the same function call event. Function'
169
+ f' call ids found : {function_call_ids}, function response'
170
+ f' ids provided: {function_responses_ids}'
171
+ )
172
+ # collect all function responses from the function call event to
173
+ # the last response event
174
+ function_responses_ids = function_call_ids
166
175
  break
167
176
 
168
177
  if function_call_event_idx == -1:
@@ -363,10 +372,7 @@ def _merge_function_response_events(
363
372
  list is in increasing order of timestamp; 2. the first event is the
364
373
  initial function_response event; 3. all later events should contain at
365
374
  least one function_response part that related to the function_call
366
- event. (Note, 3. may not be true when aync function return some
367
- intermediate response, there could also be some intermediate model
368
- response event without any function_response and such event will be
369
- ignored.)
375
+ event.
370
376
  Caveat: This implementation doesn't support when a parallel function_call
371
377
  event contains async function_call of the same name.
372
378
 
@@ -153,37 +153,67 @@ async def handle_function_calls_async(
153
153
  # do not use "args" as the variable name, because it is a reserved keyword
154
154
  # in python debugger.
155
155
  function_args = function_call.args or {}
156
- function_response: Optional[dict] = None
157
156
 
158
- for callback in agent.canonical_before_tool_callbacks:
159
- function_response = callback(
160
- tool=tool, args=function_args, tool_context=tool_context
161
- )
162
- if inspect.isawaitable(function_response):
163
- function_response = await function_response
164
- if function_response:
165
- break
157
+ # Step 1: Check if plugin before_tool_callback overrides the function
158
+ # response.
159
+ function_response = (
160
+ await invocation_context.plugin_manager.run_before_tool_callback(
161
+ tool=tool, tool_args=function_args, tool_context=tool_context
162
+ )
163
+ )
166
164
 
167
- if not function_response:
165
+ # Step 2: If no overrides are provided from the plugins, further run the
166
+ # canonical callback.
167
+ if function_response is None:
168
+ for callback in agent.canonical_before_tool_callbacks:
169
+ function_response = callback(
170
+ tool=tool, args=function_args, tool_context=tool_context
171
+ )
172
+ if inspect.isawaitable(function_response):
173
+ function_response = await function_response
174
+ if function_response:
175
+ break
176
+
177
+ # Step 3: Otherwise, proceed calling the tool normally.
178
+ if function_response is None:
168
179
  function_response = await __call_tool_async(
169
180
  tool, args=function_args, tool_context=tool_context
170
181
  )
171
182
 
172
- for callback in agent.canonical_after_tool_callbacks:
173
- altered_function_response = callback(
174
- tool=tool,
175
- args=function_args,
176
- tool_context=tool_context,
177
- tool_response=function_response,
178
- )
179
- if inspect.isawaitable(altered_function_response):
180
- altered_function_response = await altered_function_response
181
- if altered_function_response is not None:
182
- function_response = altered_function_response
183
- break
183
+ # Step 4: Check if plugin after_tool_callback overrides the function
184
+ # response.
185
+ altered_function_response = (
186
+ await invocation_context.plugin_manager.run_after_tool_callback(
187
+ tool=tool,
188
+ tool_args=function_args,
189
+ tool_context=tool_context,
190
+ result=function_response,
191
+ )
192
+ )
193
+
194
+ # Step 5: If no overrides are provided from the plugins, further run the
195
+ # canonical after_tool_callbacks.
196
+ if altered_function_response is None:
197
+ for callback in agent.canonical_after_tool_callbacks:
198
+ altered_function_response = callback(
199
+ tool=tool,
200
+ args=function_args,
201
+ tool_context=tool_context,
202
+ tool_response=function_response,
203
+ )
204
+ if inspect.isawaitable(altered_function_response):
205
+ altered_function_response = await altered_function_response
206
+ if altered_function_response:
207
+ break
208
+
209
+ # Step 6: If alternative response exists from after_tool_callback, use it
210
+ # instead of the original function response.
211
+ if altered_function_response is not None:
212
+ function_response = altered_function_response
184
213
 
185
214
  if tool.is_long_running:
186
- # Allow long running function to return None to not provide function response.
215
+ # Allow long running function to return None to not provide function
216
+ # response.
187
217
  if not function_response:
188
218
  continue
189
219
 
@@ -237,35 +267,27 @@ async def handle_function_calls_live(
237
267
  # in python debugger.
238
268
  function_args = function_call.args or {}
239
269
  function_response = None
240
- # # Calls the tool if before_tool_callback does not exist or returns None.
241
- # if agent.before_tool_callback:
242
- # function_response = agent.before_tool_callback(
243
- # tool, function_args, tool_context
244
- # )
245
- if agent.before_tool_callback:
246
- function_response = agent.before_tool_callback(
270
+
271
+ # Handle before_tool_callbacks - iterate through the canonical callback
272
+ # list
273
+ for callback in agent.canonical_before_tool_callbacks:
274
+ function_response = callback(
247
275
  tool=tool, args=function_args, tool_context=tool_context
248
276
  )
249
277
  if inspect.isawaitable(function_response):
250
278
  function_response = await function_response
279
+ if function_response:
280
+ break
251
281
 
252
- if not function_response:
282
+ if function_response is None:
253
283
  function_response = await _process_function_live_helper(
254
284
  tool, tool_context, function_call, function_args, invocation_context
255
285
  )
256
286
 
257
287
  # Calls after_tool_callback if it exists.
258
- # if agent.after_tool_callback:
259
- # new_response = agent.after_tool_callback(
260
- # tool,
261
- # function_args,
262
- # tool_context,
263
- # function_response,
264
- # )
265
- # if new_response:
266
- # function_response = new_response
267
- if agent.after_tool_callback:
268
- altered_function_response = agent.after_tool_callback(
288
+ altered_function_response = None
289
+ for callback in agent.canonical_after_tool_callbacks:
290
+ altered_function_response = callback(
269
291
  tool=tool,
270
292
  args=function_args,
271
293
  tool_context=tool_context,
@@ -273,8 +295,11 @@ async def handle_function_calls_live(
273
295
  )
274
296
  if inspect.isawaitable(altered_function_response):
275
297
  altered_function_response = await altered_function_response
276
- if altered_function_response is not None:
277
- function_response = altered_function_response
298
+ if altered_function_response:
299
+ break
300
+
301
+ if altered_function_response is not None:
302
+ function_response = altered_function_response
278
303
 
279
304
  if tool.is_long_running:
280
305
  # Allow async function to return None to not provide function response.
@@ -480,6 +505,16 @@ def __build_response_event(
480
505
  return function_response_event
481
506
 
482
507
 
508
+ def deep_merge_dicts(d1: dict, d2: dict) -> dict:
509
+ """Recursively merges d2 into d1."""
510
+ for key, value in d2.items():
511
+ if key in d1 and isinstance(d1[key], dict) and isinstance(value, dict):
512
+ d1[key] = deep_merge_dicts(d1[key], value)
513
+ else:
514
+ d1[key] = value
515
+ return d1
516
+
517
+
483
518
  def merge_parallel_function_response_events(
484
519
  function_response_events: list['Event'],
485
520
  ) -> 'Event':
@@ -498,15 +533,17 @@ def merge_parallel_function_response_events(
498
533
  base_event = function_response_events[0]
499
534
 
500
535
  # Merge actions from all events
501
-
502
- merged_actions = EventActions()
503
- merged_requested_auth_configs = {}
536
+ merged_actions_data = {}
504
537
  for event in function_response_events:
505
- merged_requested_auth_configs.update(event.actions.requested_auth_configs)
506
- merged_actions = merged_actions.model_copy(
507
- update=event.actions.model_dump()
508
- )
509
- merged_actions.requested_auth_configs = merged_requested_auth_configs
538
+ if event.actions:
539
+ # Use `by_alias=True` because it converts the model to a dictionary while respecting field aliases, ensuring that the enum fields are correctly handled without creating a duplicate.
540
+ merged_actions_data = deep_merge_dicts(
541
+ merged_actions_data,
542
+ event.actions.model_dump(exclude_none=True, by_alias=True),
543
+ )
544
+
545
+ merged_actions = EventActions.model_validate(merged_actions_data)
546
+
510
547
  # Create the new merged event
511
548
  merged_event = Event(
512
549
  invocation_id=Event.new_id(),
@@ -14,6 +14,7 @@
14
14
  from __future__ import annotations
15
15
 
16
16
  import re
17
+ import threading
17
18
  from typing import TYPE_CHECKING
18
19
 
19
20
  from typing_extensions import override
@@ -42,38 +43,43 @@ class InMemoryMemoryService(BaseMemoryService):
42
43
 
43
44
  Uses keyword matching instead of semantic search.
44
45
 
45
- It is not suitable for multi-threaded production environments. Use it for
46
- testing and development only.
46
+ This class is thread-safe; however, it should be used for testing and
47
+ development only.
47
48
  """
48
49
 
49
50
  def __init__(self):
51
+ self._lock = threading.Lock()
52
+
50
53
  self._session_events: dict[str, dict[str, list[Event]]] = {}
51
- """Keys are app_name/user_id, session_id. Values are session event lists."""
54
+ """Keys are "{app_name}/{user_id}". Values are dicts of session_id to
55
+ session event lists.
56
+ """
52
57
 
53
58
  @override
54
59
  async def add_session_to_memory(self, session: Session):
55
60
  user_key = _user_key(session.app_name, session.user_id)
56
- self._session_events[user_key] = self._session_events.get(
57
- _user_key(session.app_name, session.user_id), {}
58
- )
59
- self._session_events[user_key][session.id] = [
60
- event
61
- for event in session.events
62
- if event.content and event.content.parts
63
- ]
61
+
62
+ with self._lock:
63
+ self._session_events[user_key] = self._session_events.get(user_key, {})
64
+ self._session_events[user_key][session.id] = [
65
+ event
66
+ for event in session.events
67
+ if event.content and event.content.parts
68
+ ]
64
69
 
65
70
  @override
66
71
  async def search_memory(
67
72
  self, *, app_name: str, user_id: str, query: str
68
73
  ) -> SearchMemoryResponse:
69
74
  user_key = _user_key(app_name, user_id)
70
- if user_key not in self._session_events:
71
- return SearchMemoryResponse()
72
75
 
73
- words_in_query = set(query.lower().split())
76
+ with self._lock:
77
+ session_event_lists = self._session_events.get(user_key, {})
78
+
79
+ words_in_query = _extract_words_lower(query)
74
80
  response = SearchMemoryResponse()
75
81
 
76
- for session_events in self._session_events[user_key].values():
82
+ for session_events in session_event_lists.values():
77
83
  for event in session_events:
78
84
  if not event.content or not event.content.parts:
79
85
  continue
@@ -16,13 +16,15 @@ from __future__ import annotations
16
16
 
17
17
  import json
18
18
  import logging
19
+ from typing import Any
20
+ from typing import Dict
19
21
  from typing import Optional
20
22
  from typing import TYPE_CHECKING
21
23
 
24
+ from google.genai import Client
25
+ from google.genai import types
22
26
  from typing_extensions import override
23
27
 
24
- from google import genai
25
-
26
28
  from .base_memory_service import BaseMemoryService
27
29
  from .base_memory_service import SearchMemoryResponse
28
30
  from .memory_entry import MemoryEntry
@@ -84,7 +86,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
84
86
  path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
85
87
  request_dict=request_dict,
86
88
  )
87
- logger.info(f'Generate memory response: {api_response}')
89
+ logger.info('Generate memory response received.')
90
+ logger.debug('Generate memory response: %s', api_response)
88
91
  else:
89
92
  logger.info('No events to add to memory.')
90
93
 
@@ -106,7 +109,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
106
109
  },
107
110
  )
108
111
  api_response = _convert_api_response(api_response)
109
- logger.info(f'Search memory response: {api_response}')
112
+ logger.info('Search memory response received.')
113
+ logger.debug('Search memory response: %s', api_response)
110
114
 
111
115
  if not api_response or not api_response.get('retrievedMemories', None):
112
116
  return SearchMemoryResponse()
@@ -117,10 +121,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
117
121
  memory_events.append(
118
122
  MemoryEntry(
119
123
  author='user',
120
- content=genai.types.Content(
121
- parts=[
122
- genai.types.Part(text=memory.get('memory').get('fact'))
123
- ],
124
+ content=types.Content(
125
+ parts=[types.Part(text=memory.get('memory').get('fact'))],
124
126
  role='user',
125
127
  ),
126
128
  timestamp=memory.get('updateTime'),
@@ -137,13 +139,13 @@ class VertexAiMemoryBankService(BaseMemoryService):
137
139
  Returns:
138
140
  An API client for the given project and location.
139
141
  """
140
- client = genai.Client(
142
+ client = Client(
141
143
  vertexai=True, project=self._project, location=self._location
142
144
  )
143
145
  return client._api_client
144
146
 
145
147
 
146
- def _convert_api_response(api_response):
148
+ def _convert_api_response(api_response) -> Dict[str, Any]:
147
149
  """Converts the API response to a JSON object based on the type."""
148
150
  if hasattr(api_response, 'body'):
149
151
  return json.loads(api_response.body)