google-adk 1.6.1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. google/adk/a2a/converters/event_converter.py +5 -85
  2. google/adk/a2a/converters/request_converter.py +1 -2
  3. google/adk/a2a/executor/a2a_agent_executor.py +45 -16
  4. google/adk/a2a/logs/log_utils.py +1 -2
  5. google/adk/a2a/utils/__init__.py +0 -0
  6. google/adk/a2a/utils/agent_card_builder.py +544 -0
  7. google/adk/a2a/utils/agent_to_a2a.py +118 -0
  8. google/adk/agents/__init__.py +5 -0
  9. google/adk/agents/agent_config.py +46 -0
  10. google/adk/agents/base_agent.py +239 -41
  11. google/adk/agents/callback_context.py +41 -0
  12. google/adk/agents/common_configs.py +79 -0
  13. google/adk/agents/config_agent_utils.py +184 -0
  14. google/adk/agents/config_schemas/AgentConfig.json +566 -0
  15. google/adk/agents/invocation_context.py +5 -1
  16. google/adk/agents/live_request_queue.py +15 -0
  17. google/adk/agents/llm_agent.py +201 -9
  18. google/adk/agents/loop_agent.py +35 -1
  19. google/adk/agents/parallel_agent.py +24 -3
  20. google/adk/agents/remote_a2a_agent.py +17 -5
  21. google/adk/agents/sequential_agent.py +22 -1
  22. google/adk/artifacts/gcs_artifact_service.py +110 -20
  23. google/adk/auth/auth_handler.py +3 -3
  24. google/adk/auth/credential_manager.py +23 -23
  25. google/adk/auth/credential_service/base_credential_service.py +6 -6
  26. google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
  27. google/adk/auth/credential_service/session_state_credential_service.py +8 -8
  28. google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
  29. google/adk/auth/oauth2_credential_util.py +2 -2
  30. google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
  31. google/adk/cli/agent_graph.py +3 -1
  32. google/adk/cli/browser/index.html +2 -2
  33. google/adk/cli/browser/main-W7QZBYAR.js +3914 -0
  34. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  35. google/adk/cli/cli_eval.py +87 -12
  36. google/adk/cli/cli_tools_click.py +143 -82
  37. google/adk/cli/fast_api.py +150 -69
  38. google/adk/cli/utils/agent_loader.py +35 -1
  39. google/adk/code_executors/base_code_executor.py +14 -19
  40. google/adk/code_executors/built_in_code_executor.py +4 -1
  41. google/adk/evaluation/base_eval_service.py +46 -2
  42. google/adk/evaluation/eval_metrics.py +4 -0
  43. google/adk/evaluation/eval_sets_manager.py +5 -1
  44. google/adk/evaluation/evaluation_generator.py +1 -1
  45. google/adk/evaluation/final_response_match_v2.py +2 -2
  46. google/adk/evaluation/gcs_eval_sets_manager.py +2 -1
  47. google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
  48. google/adk/evaluation/local_eval_service.py +389 -0
  49. google/adk/evaluation/local_eval_set_results_manager.py +2 -2
  50. google/adk/evaluation/local_eval_sets_manager.py +24 -9
  51. google/adk/evaluation/metric_evaluator_registry.py +16 -6
  52. google/adk/evaluation/vertex_ai_eval_facade.py +7 -1
  53. google/adk/events/event.py +7 -2
  54. google/adk/flows/llm_flows/auto_flow.py +6 -11
  55. google/adk/flows/llm_flows/base_llm_flow.py +66 -29
  56. google/adk/flows/llm_flows/contents.py +16 -10
  57. google/adk/flows/llm_flows/functions.py +89 -52
  58. google/adk/memory/in_memory_memory_service.py +21 -15
  59. google/adk/memory/vertex_ai_memory_bank_service.py +12 -10
  60. google/adk/models/anthropic_llm.py +46 -6
  61. google/adk/models/base_llm_connection.py +2 -0
  62. google/adk/models/gemini_llm_connection.py +17 -6
  63. google/adk/models/google_llm.py +46 -11
  64. google/adk/models/lite_llm.py +52 -22
  65. google/adk/plugins/__init__.py +17 -0
  66. google/adk/plugins/base_plugin.py +317 -0
  67. google/adk/plugins/plugin_manager.py +265 -0
  68. google/adk/runners.py +122 -18
  69. google/adk/sessions/database_session_service.py +51 -52
  70. google/adk/sessions/vertex_ai_session_service.py +27 -12
  71. google/adk/tools/__init__.py +2 -0
  72. google/adk/tools/_automatic_function_calling_util.py +20 -2
  73. google/adk/tools/agent_tool.py +15 -3
  74. google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
  75. google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
  76. google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
  77. google/adk/tools/base_tool.py +9 -9
  78. google/adk/tools/base_toolset.py +29 -5
  79. google/adk/tools/bigquery/__init__.py +3 -3
  80. google/adk/tools/bigquery/metadata_tool.py +2 -0
  81. google/adk/tools/bigquery/query_tool.py +15 -1
  82. google/adk/tools/computer_use/__init__.py +13 -0
  83. google/adk/tools/computer_use/base_computer.py +265 -0
  84. google/adk/tools/computer_use/computer_use_tool.py +166 -0
  85. google/adk/tools/computer_use/computer_use_toolset.py +220 -0
  86. google/adk/tools/enterprise_search_tool.py +4 -2
  87. google/adk/tools/exit_loop_tool.py +1 -0
  88. google/adk/tools/google_api_tool/google_api_tool.py +16 -1
  89. google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
  90. google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
  91. google/adk/tools/google_search_tool.py +4 -2
  92. google/adk/tools/langchain_tool.py +16 -6
  93. google/adk/tools/long_running_tool.py +21 -0
  94. google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
  95. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +5 -0
  96. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
  97. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
  98. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
  99. google/adk/tools/tool_context.py +0 -10
  100. google/adk/tools/url_context_tool.py +4 -2
  101. google/adk/tools/vertex_ai_search_tool.py +4 -2
  102. google/adk/utils/model_name_utils.py +90 -0
  103. google/adk/version.py +1 -1
  104. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/METADATA +3 -2
  105. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/RECORD +108 -91
  106. google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
  107. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
  108. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/WHEEL +0 -0
  109. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/entry_points.txt +0 -0
  110. {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py CHANGED
@@ -17,9 +17,14 @@ from __future__ import annotations
17
17
  import asyncio
18
18
  import logging
19
19
  import queue
20
+ import time
21
+ from typing import Any
20
22
  from typing import AsyncGenerator
23
+ from typing import Callable
21
24
  from typing import Generator
25
+ from typing import List
22
26
  from typing import Optional
27
+ import uuid
23
28
  import warnings
24
29
 
25
30
  from google.genai import types
@@ -36,10 +41,13 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService
36
41
  from .auth.credential_service.base_credential_service import BaseCredentialService
37
42
  from .code_executors.built_in_code_executor import BuiltInCodeExecutor
38
43
  from .events.event import Event
44
+ from .events.event import EventActions
39
45
  from .flows.llm_flows.functions import find_matching_function_call
40
46
  from .memory.base_memory_service import BaseMemoryService
41
47
  from .memory.in_memory_memory_service import InMemoryMemoryService
42
48
  from .platform.thread import create_thread
49
+ from .plugins.base_plugin import BasePlugin
50
+ from .plugins.plugin_manager import PluginManager
43
51
  from .sessions.base_session_service import BaseSessionService
44
52
  from .sessions.in_memory_session_service import InMemorySessionService
45
53
  from .sessions.session import Session
@@ -60,6 +68,7 @@ class Runner:
60
68
  app_name: The application name of the runner.
61
69
  agent: The root agent to run.
62
70
  artifact_service: The artifact service for the runner.
71
+ plugin_manager: The plugin manager for the runner.
63
72
  session_service: The session service for the runner.
64
73
  memory_service: The memory service for the runner.
65
74
  """
@@ -70,6 +79,8 @@ class Runner:
70
79
  """The root agent to run."""
71
80
  artifact_service: Optional[BaseArtifactService] = None
72
81
  """The artifact service for the runner."""
82
+ plugin_manager: PluginManager
83
+ """The plugin manager for the runner."""
73
84
  session_service: BaseSessionService
74
85
  """The session service for the runner."""
75
86
  memory_service: Optional[BaseMemoryService] = None
@@ -82,6 +93,7 @@ class Runner:
82
93
  *,
83
94
  app_name: str,
84
95
  agent: BaseAgent,
96
+ plugins: Optional[List[BasePlugin]] = None,
85
97
  artifact_service: Optional[BaseArtifactService] = None,
86
98
  session_service: BaseSessionService,
87
99
  memory_service: Optional[BaseMemoryService] = None,
@@ -102,6 +114,7 @@ class Runner:
102
114
  self.session_service = session_service
103
115
  self.memory_service = memory_service
104
116
  self.credential_service = credential_service
117
+ self.plugin_manager = PluginManager(plugins=plugins)
105
118
 
106
119
  def run(
107
120
  self,
@@ -113,8 +126,9 @@ class Runner:
113
126
  ) -> Generator[Event, None, None]:
114
127
  """Runs the agent.
115
128
 
116
- NOTE: This sync interface is only for local testing and convenience purpose.
117
- Consider using `run_async` for production usage.
129
+ NOTE:
130
+ This sync interface is only for local testing and convenience purpose.
131
+ Consider using `run_async` for production usage.
118
132
 
119
133
  Args:
120
134
  user_id: The user ID of the session.
@@ -164,6 +178,7 @@ class Runner:
164
178
  user_id: str,
165
179
  session_id: str,
166
180
  new_message: types.Content,
181
+ state_delta: Optional[dict[str, Any]] = None,
167
182
  run_config: RunConfig = RunConfig(),
168
183
  ) -> AsyncGenerator[Event, None]:
169
184
  """Main entry method to run the agent in this runner.
@@ -191,19 +206,83 @@ class Runner:
191
206
  )
192
207
  root_agent = self.agent
193
208
 
209
+ # Modify user message before execution.
210
+ modified_user_message = (
211
+ await invocation_context.plugin_manager.run_on_user_message_callback(
212
+ invocation_context=invocation_context, user_message=new_message
213
+ )
214
+ )
215
+ if modified_user_message is not None:
216
+ new_message = modified_user_message
217
+
194
218
  if new_message:
195
219
  await self._append_new_message_to_session(
196
220
  session,
197
221
  new_message,
198
222
  invocation_context,
199
223
  run_config.save_input_blobs_as_artifacts,
224
+ state_delta,
200
225
  )
201
226
 
202
227
  invocation_context.agent = self._find_agent_to_run(session, root_agent)
203
- async for event in invocation_context.agent.run_async(invocation_context):
228
+
229
+ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
230
+ async for event in ctx.agent.run_async(ctx):
231
+ yield event
232
+
233
+ async for event in self._exec_with_plugin(
234
+ invocation_context, session, execute
235
+ ):
236
+ yield event
237
+
238
+ async def _exec_with_plugin(
239
+ self,
240
+ invocation_context: InvocationContext,
241
+ session: Session,
242
+ execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]],
243
+ ) -> AsyncGenerator[Event, None]:
244
+ """Wraps execution with plugin callbacks.
245
+
246
+ Args:
247
+ invocation_context: The invocation context
248
+ session: The current session
249
+ execute_fn: A callable that returns an AsyncGenerator of Events
250
+
251
+ Yields:
252
+ Events from the execution, including any generated by plugins
253
+ """
254
+
255
+ plugin_manager = invocation_context.plugin_manager
256
+
257
+ # Step 1: Run the before_run callbacks to see if we should early exit.
258
+ early_exit_result = await plugin_manager.run_before_run_callback(
259
+ invocation_context=invocation_context
260
+ )
261
+ if isinstance(early_exit_result, Event):
262
+ await self.session_service.append_event(
263
+ session=session,
264
+ event=Event(
265
+ invocation_id=invocation_context.invocation_id,
266
+ author='model',
267
+ content=early_exit_result,
268
+ ),
269
+ )
270
+ yield early_exit_result
271
+ else:
272
+ # Step 2: Otherwise continue with normal execution
273
+ async for event in execute_fn(invocation_context):
204
274
  if not event.partial:
205
275
  await self.session_service.append_event(session=session, event=event)
206
- yield event
276
+ # Step 3: Run the on_event callbacks to optionally modify the event.
277
+ modified_event = await plugin_manager.run_on_event_callback(
278
+ invocation_context=invocation_context, event=event
279
+ )
280
+ yield (modified_event if modified_event else event)
281
+
282
+ # Step 4: Run the after_run callbacks to optionally modify the context.
283
+ await plugin_manager.run_after_run_callback(
284
+ invocation_context=invocation_context
285
+ )
207
286
 
208
287
  async def _append_new_message_to_session(
209
288
  self,
@@ -211,6 +290,7 @@ class Runner:
211
290
  new_message: types.Content,
212
291
  invocation_context: InvocationContext,
213
292
  save_input_blobs_as_artifacts: bool = False,
293
+ state_delta: Optional[dict[str, Any]] = None,
214
294
  ):
215
295
  """Appends a new message to the session.
216
296
 
@@ -242,11 +322,19 @@ class Runner:
242
322
  text=f'Uploaded file: {file_name}. It is saved into artifacts'
243
323
  )
244
324
  # Appends only. We do not yield the event because it's not from the model.
245
- event = Event(
246
- invocation_id=invocation_context.invocation_id,
247
- author='user',
248
- content=new_message,
249
- )
325
+ if state_delta:
326
+ event = Event(
327
+ invocation_id=invocation_context.invocation_id,
328
+ author='user',
329
+ actions=EventActions(state_delta=state_delta),
330
+ content=new_message,
331
+ )
332
+ else:
333
+ event = Event(
334
+ invocation_id=invocation_context.invocation_id,
335
+ author='user',
336
+ content=new_message,
337
+ )
250
338
  await self.session_service.append_event(session=session, event=event)
251
339
 
252
340
  async def run_live(
@@ -278,7 +366,7 @@ class Runner:
278
366
  This feature is **experimental** and its API or behavior may change
279
367
  in future releases.
280
368
 
281
- .. note::
369
+ .. NOTE::
282
370
  Either `session` or both `user_id` and `session_id` must be provided.
283
371
  """
284
372
  if session is None and (user_id is None or session_id is None):
@@ -345,8 +433,14 @@ class Runner:
345
433
  invocation_context.active_streaming_tools[tool.__name__] = (
346
434
  active_streaming_tool
347
435
  )
348
- async for event in invocation_context.agent.run_live(invocation_context):
349
- await self.session_service.append_event(session=session, event=event)
436
+
437
+ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
438
+ async for event in ctx.agent.run_live(ctx):
439
+ yield event
440
+
441
+ async for event in self._exec_with_plugin(
442
+ invocation_context, session, execute
443
+ ):
350
444
  yield event
351
445
 
352
446
  def _find_agent_to_run(
@@ -355,9 +449,10 @@ class Runner:
355
449
  """Finds the agent to run to continue the session.
356
450
 
357
451
  A qualified agent must be either of:
452
+
358
453
  - The agent that returned a function call and the last user message is a
359
454
  function response to this function call.
360
- - The root agent;
455
+ - The root agent.
361
456
  - An LlmAgent who replied last and is capable to transfer to any other agent
362
457
  in the agent hierarchy.
363
458
 
@@ -366,7 +461,8 @@ class Runner:
366
461
  root_agent: The root agent of the runner.
367
462
 
368
463
  Returns:
369
- The agent of the last message in the session or the root agent.
464
+ The agent to run. (the active agent that should reply to the latest user
465
+ message)
370
466
  """
371
467
  # If the last event is a function response, should send this response to
372
468
  # the agent that returned the corressponding function call regardless the
@@ -395,8 +491,8 @@ class Runner:
395
491
  def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
396
492
  """Whether the agent to run can transfer to any other agent in the agent tree.
397
493
 
398
- This typically means all agent_to_run's parent through root agent can
399
- transfer to their parent_agent.
494
+ This typically means all agent_to_run's ancestor can transfer to their
495
+ parent_agent all the way to the root_agent.
400
496
 
401
497
  Args:
402
498
  agent_to_run: The agent to check for transferability.
@@ -407,7 +503,7 @@ class Runner:
407
503
  agent = agent_to_run
408
504
  while agent:
409
505
  if not isinstance(agent, LlmAgent):
410
- # Only LLM-based Agent can provider agent transfer capability.
506
+ # Only LLM-based Agent can provide agent transfer capability.
411
507
  return False
412
508
  if agent.disallow_transfer_to_parent:
413
509
  return False
@@ -450,6 +546,7 @@ class Runner:
450
546
  session_service=self.session_service,
451
547
  memory_service=self.memory_service,
452
548
  credential_service=self.credential_service,
549
+ plugin_manager=self.plugin_manager,
453
550
  invocation_id=invocation_id,
454
551
  agent=self.agent,
455
552
  session=session,
@@ -538,7 +635,13 @@ class InMemoryRunner(Runner):
538
635
  session service for the runner.
539
636
  """
540
637
 
541
- def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
638
+ def __init__(
639
+ self,
640
+ agent: BaseAgent,
641
+ *,
642
+ app_name: str = 'InMemoryRunner',
643
+ plugins: Optional[list[BasePlugin]] = None,
644
+ ):
542
645
  """Initializes the InMemoryRunner.
543
646
 
544
647
  Args:
@@ -551,6 +654,7 @@ class InMemoryRunner(Runner):
551
654
  app_name=app_name,
552
655
  agent=agent,
553
656
  artifact_service=InMemoryArtifactService(),
657
+ plugins=plugins,
554
658
  session_service=self._in_memory_session_service,
555
659
  memory_service=InMemoryMemoryService(),
556
660
  )
@@ -137,7 +137,7 @@ class StorageSession(Base):
137
137
  DateTime(), default=func.now(), onupdate=func.now()
138
138
  )
139
139
 
140
- storage_events: Mapped[list["StorageEvent"]] = relationship(
140
+ storage_events: Mapped[list[StorageEvent]] = relationship(
141
141
  "StorageEvent",
142
142
  back_populates="storage_session",
143
143
  )
@@ -160,6 +160,26 @@ class StorageSession(Base):
160
160
  return self.update_time.replace(tzinfo=timezone.utc).timestamp()
161
161
  return self.update_time.timestamp()
162
162
 
163
+ def to_session(
164
+ self,
165
+ state: dict[str, Any] | None = None,
166
+ events: list[Event] | None = None,
167
+ ) -> Session:
168
+ """Converts the storage session to a session object."""
169
+ if state is None:
170
+ state = {}
171
+ if events is None:
172
+ events = []
173
+
174
+ return Session(
175
+ app_name=self.app_name,
176
+ user_id=self.user_id,
177
+ id=self.id,
178
+ state=state,
179
+ events=events,
180
+ last_update_time=self.update_timestamp_tz,
181
+ )
182
+
163
183
 
164
184
  class StorageEvent(Base):
165
185
  """Represents an event stored in the database."""
@@ -373,11 +393,11 @@ class DatabaseSessionService(BaseSessionService):
373
393
  # 4. Build the session object with generated id
374
394
  # 5. Return the session
375
395
 
376
- with self.database_session_factory() as session_factory:
396
+ with self.database_session_factory() as sql_session:
377
397
 
378
398
  # Fetch app and user states from storage
379
- storage_app_state = session_factory.get(StorageAppState, (app_name))
380
- storage_user_state = session_factory.get(
399
+ storage_app_state = sql_session.get(StorageAppState, (app_name))
400
+ storage_user_state = sql_session.get(
381
401
  StorageUserState, (app_name, user_id)
382
402
  )
383
403
 
@@ -387,12 +407,12 @@ class DatabaseSessionService(BaseSessionService):
387
407
  # Create state tables if not exist
388
408
  if not storage_app_state:
389
409
  storage_app_state = StorageAppState(app_name=app_name, state={})
390
- session_factory.add(storage_app_state)
410
+ sql_session.add(storage_app_state)
391
411
  if not storage_user_state:
392
412
  storage_user_state = StorageUserState(
393
413
  app_name=app_name, user_id=user_id, state={}
394
414
  )
395
- session_factory.add(storage_user_state)
415
+ sql_session.add(storage_user_state)
396
416
 
397
417
  # Extract state deltas
398
418
  app_state_delta, user_state_delta, session_state = _extract_state_delta(
@@ -416,21 +436,15 @@ class DatabaseSessionService(BaseSessionService):
416
436
  id=session_id,
417
437
  state=session_state,
418
438
  )
419
- session_factory.add(storage_session)
420
- session_factory.commit()
439
+ sql_session.add(storage_session)
440
+ sql_session.commit()
421
441
 
422
- session_factory.refresh(storage_session)
442
+ sql_session.refresh(storage_session)
423
443
 
424
444
  # Merge states for response
425
445
  merged_state = _merge_state(app_state, user_state, session_state)
426
- session = Session(
427
- app_name=str(storage_session.app_name),
428
- user_id=str(storage_session.user_id),
429
- id=str(storage_session.id),
430
- state=merged_state,
431
- last_update_time=storage_session.update_timestamp_tz,
432
- )
433
- return session
446
+ session = storage_session.to_session(state=merged_state)
447
+ return session
434
448
 
435
449
  @override
436
450
  async def get_session(
@@ -444,8 +458,8 @@ class DatabaseSessionService(BaseSessionService):
444
458
  # 1. Get the storage session entry from session table
445
459
  # 2. Get all the events based on session id and filtering config
446
460
  # 3. Convert and return the session
447
- with self.database_session_factory() as session_factory:
448
- storage_session = session_factory.get(
461
+ with self.database_session_factory() as sql_session:
462
+ storage_session = sql_session.get(
449
463
  StorageSession, (app_name, user_id, session_id)
450
464
  )
451
465
  if storage_session is None:
@@ -458,7 +472,7 @@ class DatabaseSessionService(BaseSessionService):
458
472
  timestamp_filter = True
459
473
 
460
474
  storage_events = (
461
- session_factory.query(StorageEvent)
475
+ sql_session.query(StorageEvent)
462
476
  .filter(StorageEvent.app_name == app_name)
463
477
  .filter(StorageEvent.session_id == storage_session.id)
464
478
  .filter(StorageEvent.user_id == user_id)
@@ -473,8 +487,8 @@ class DatabaseSessionService(BaseSessionService):
473
487
  )
474
488
 
475
489
  # Fetch states from storage
476
- storage_app_state = session_factory.get(StorageAppState, (app_name))
477
- storage_user_state = session_factory.get(
490
+ storage_app_state = sql_session.get(StorageAppState, (app_name))
491
+ storage_user_state = sql_session.get(
478
492
  StorageUserState, (app_name, user_id)
479
493
  )
480
494
 
@@ -486,51 +500,38 @@ class DatabaseSessionService(BaseSessionService):
486
500
  merged_state = _merge_state(app_state, user_state, session_state)
487
501
 
488
502
  # Convert storage session to session
489
- session = Session(
490
- app_name=app_name,
491
- user_id=user_id,
492
- id=session_id,
493
- state=merged_state,
494
- last_update_time=storage_session.update_timestamp_tz,
495
- )
496
- session.events = [e.to_event() for e in reversed(storage_events)]
503
+ events = [e.to_event() for e in reversed(storage_events)]
504
+ session = storage_session.to_session(state=merged_state, events=events)
497
505
  return session
498
506
 
499
507
  @override
500
508
  async def list_sessions(
501
509
  self, *, app_name: str, user_id: str
502
510
  ) -> ListSessionsResponse:
503
- with self.database_session_factory() as session_factory:
511
+ with self.database_session_factory() as sql_session:
504
512
  results = (
505
- session_factory.query(StorageSession)
513
+ sql_session.query(StorageSession)
506
514
  .filter(StorageSession.app_name == app_name)
507
515
  .filter(StorageSession.user_id == user_id)
508
516
  .all()
509
517
  )
510
518
  sessions = []
511
519
  for storage_session in results:
512
- session = Session(
513
- app_name=app_name,
514
- user_id=user_id,
515
- id=storage_session.id,
516
- state={},
517
- last_update_time=storage_session.update_timestamp_tz,
518
- )
519
- sessions.append(session)
520
+ sessions.append(storage_session.to_session())
520
521
  return ListSessionsResponse(sessions=sessions)
521
522
 
522
523
  @override
523
524
  async def delete_session(
524
525
  self, app_name: str, user_id: str, session_id: str
525
526
  ) -> None:
526
- with self.database_session_factory() as session_factory:
527
+ with self.database_session_factory() as sql_session:
527
528
  stmt = delete(StorageSession).where(
528
529
  StorageSession.app_name == app_name,
529
530
  StorageSession.user_id == user_id,
530
531
  StorageSession.id == session_id,
531
532
  )
532
- session_factory.execute(stmt)
533
- session_factory.commit()
533
+ sql_session.execute(stmt)
534
+ sql_session.commit()
534
535
 
535
536
  @override
536
537
  async def append_event(self, session: Session, event: Event) -> Event:
@@ -542,8 +543,8 @@ class DatabaseSessionService(BaseSessionService):
542
543
  # 1. Check if timestamp is stale
543
544
  # 2. Update session attributes based on event config
544
545
  # 3. Store event to table
545
- with self.database_session_factory() as session_factory:
546
- storage_session = session_factory.get(
546
+ with self.database_session_factory() as sql_session:
547
+ storage_session = sql_session.get(
547
548
  StorageSession, (session.app_name, session.user_id, session.id)
548
549
  )
549
550
 
@@ -557,10 +558,8 @@ class DatabaseSessionService(BaseSessionService):
557
558
  )
558
559
 
559
560
  # Fetch states from storage
560
- storage_app_state = session_factory.get(
561
- StorageAppState, (session.app_name)
562
- )
563
- storage_user_state = session_factory.get(
561
+ storage_app_state = sql_session.get(StorageAppState, (session.app_name))
562
+ storage_user_state = sql_session.get(
564
563
  StorageUserState, (session.app_name, session.user_id)
565
564
  )
566
565
 
@@ -589,10 +588,10 @@ class DatabaseSessionService(BaseSessionService):
589
588
  session_state.update(session_state_delta)
590
589
  storage_session.state = session_state
591
590
 
592
- session_factory.add(StorageEvent.from_event(session, event))
591
+ sql_session.add(StorageEvent.from_event(session, event))
593
592
 
594
- session_factory.commit()
595
- session_factory.refresh(storage_session)
593
+ sql_session.commit()
594
+ sql_session.refresh(storage_session)
596
595
 
597
596
  # Update timestamp with commit time
598
597
  session.last_update_time = storage_session.update_timestamp_tz
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  from __future__ import annotations
15
15
 
16
- import asyncio
17
16
  import json
18
17
  import logging
19
18
  import os
@@ -110,7 +109,8 @@ class VertexAiSessionService(BaseSessionService):
110
109
  request_dict=session_json_dict,
111
110
  )
112
111
  api_response = _convert_api_response(api_response)
113
- logger.info(f'Create Session response {api_response}')
112
+ logger.info('Create session response received.')
113
+ logger.debug('Create session response: %s', api_response)
114
114
 
115
115
  session_id = api_response['name'].split('/')[-3]
116
116
  operation_id = api_response['name'].split('/')[-1]
@@ -216,29 +216,36 @@ class VertexAiSessionService(BaseSessionService):
216
216
  path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
217
217
  request_dict={},
218
218
  )
219
- list_events_api_response = _convert_api_response(list_events_api_response)
219
+ converted_api_response = _convert_api_response(list_events_api_response)
220
220
 
221
- # Handles empty response case
222
- if not list_events_api_response or list_events_api_response.get(
221
+ # Handles empty response case where there are no events to fetch
222
+ if not converted_api_response or converted_api_response.get(
223
223
  'httpHeaders', None
224
224
  ):
225
225
  return session
226
226
 
227
227
  session.events += [
228
228
  _from_api_event(event)
229
- for event in list_events_api_response['sessionEvents']
229
+ for event in converted_api_response['sessionEvents']
230
230
  ]
231
231
 
232
- while list_events_api_response.get('nextPageToken', None):
233
- page_token = list_events_api_response.get('nextPageToken', None)
232
+ while converted_api_response.get('nextPageToken', None):
233
+ page_token = converted_api_response.get('nextPageToken', None)
234
234
  list_events_api_response = await api_client.async_request(
235
235
  http_method='GET',
236
236
  path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
237
237
  request_dict={},
238
238
  )
239
+ converted_api_response = _convert_api_response(list_events_api_response)
240
+
241
+ # Handles empty response case where there are no more events to fetch
242
+ if not converted_api_response or converted_api_response.get(
243
+ 'httpHeaders', None
244
+ ):
245
+ break
239
246
  session.events += [
240
247
  _from_api_event(event)
241
- for event in list_events_api_response['sessionEvents']
248
+ for event in converted_api_response['sessionEvents']
242
249
  ]
243
250
 
244
251
  session.events = [
@@ -344,16 +351,24 @@ class VertexAiSessionService(BaseSessionService):
344
351
 
345
352
  return match.groups()[-1]
346
353
 
354
+ def _api_client_http_options_override(
355
+ self,
356
+ ) -> Optional[genai.types.HttpOptions]:
357
+ return None
358
+
347
359
  def _get_api_client(self):
348
360
  """Instantiates an API client for the given project and location.
349
361
 
350
362
  It needs to be instantiated inside each request so that the event loop
351
363
  management can be properly propagated.
352
364
  """
353
- client = genai.Client(
365
+ api_client = genai.Client(
354
366
  vertexai=True, project=self._project, location=self._location
355
- )
356
- return client._api_client
367
+ )._api_client
368
+
369
+ if new_options := self._api_client_http_options_override():
370
+ api_client._http_options = new_options
371
+ return api_client
357
372
 
358
373
 
359
374
  def _is_vertex_express_mode(
@@ -14,6 +14,7 @@
14
14
 
15
15
 
16
16
  from ..auth.auth_tool import AuthToolArguments
17
+ from .agent_tool import AgentTool
17
18
  from .apihub_tool.apihub_toolset import APIHubToolset
18
19
  from .base_tool import BaseTool
19
20
  from .example_tool import ExampleTool
@@ -31,6 +32,7 @@ from .url_context_tool import url_context
31
32
  from .vertex_ai_search_tool import VertexAiSearchTool
32
33
 
33
34
  __all__ = [
35
+ 'AgentTool',
34
36
  'APIHubToolset',
35
37
  'AuthToolArguments',
36
38
  'BaseTool',
@@ -20,7 +20,6 @@ import typing
20
20
  from typing import Any
21
21
  from typing import Callable
22
22
  from typing import Dict
23
- from typing import Literal
24
23
  from typing import Optional
25
24
  from typing import Union
26
25
 
@@ -329,7 +328,26 @@ def from_function_with_options(
329
328
  return declaration
330
329
 
331
330
  return_annotation = inspect.signature(func).return_annotation
332
- if return_annotation is inspect._empty:
331
+
332
+ # Handle functions with no return annotation or that return None
333
+ if (
334
+ return_annotation is inspect._empty
335
+ or return_annotation is None
336
+ or return_annotation is type(None)
337
+ ):
338
+ # Create a response schema for None/null return
339
+ return_value = inspect.Parameter(
340
+ 'return_value',
341
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
342
+ annotation=None,
343
+ )
344
+ declaration.response = (
345
+ _function_parameter_parse_util._parse_schema_from_parameter(
346
+ variant,
347
+ return_value,
348
+ func.__name__,
349
+ )
350
+ )
333
351
  return declaration
334
352
 
335
353
  return_value = inspect.Parameter(