google-adk 1.6.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/converters/event_converter.py +5 -85
- google/adk/a2a/converters/request_converter.py +1 -2
- google/adk/a2a/executor/a2a_agent_executor.py +45 -16
- google/adk/a2a/logs/log_utils.py +1 -2
- google/adk/a2a/utils/__init__.py +0 -0
- google/adk/a2a/utils/agent_card_builder.py +544 -0
- google/adk/a2a/utils/agent_to_a2a.py +118 -0
- google/adk/agents/__init__.py +5 -0
- google/adk/agents/agent_config.py +46 -0
- google/adk/agents/base_agent.py +239 -41
- google/adk/agents/callback_context.py +41 -0
- google/adk/agents/common_configs.py +79 -0
- google/adk/agents/config_agent_utils.py +184 -0
- google/adk/agents/config_schemas/AgentConfig.json +566 -0
- google/adk/agents/invocation_context.py +5 -1
- google/adk/agents/live_request_queue.py +15 -0
- google/adk/agents/llm_agent.py +201 -9
- google/adk/agents/loop_agent.py +35 -1
- google/adk/agents/parallel_agent.py +24 -3
- google/adk/agents/remote_a2a_agent.py +17 -5
- google/adk/agents/sequential_agent.py +22 -1
- google/adk/artifacts/gcs_artifact_service.py +110 -20
- google/adk/auth/auth_handler.py +3 -3
- google/adk/auth/credential_manager.py +23 -23
- google/adk/auth/credential_service/base_credential_service.py +6 -6
- google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
- google/adk/auth/credential_service/session_state_credential_service.py +8 -8
- google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
- google/adk/auth/oauth2_credential_util.py +2 -2
- google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
- google/adk/cli/agent_graph.py +3 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/main-W7QZBYAR.js +3914 -0
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli_eval.py +87 -12
- google/adk/cli/cli_tools_click.py +143 -82
- google/adk/cli/fast_api.py +150 -69
- google/adk/cli/utils/agent_loader.py +35 -1
- google/adk/code_executors/base_code_executor.py +14 -19
- google/adk/code_executors/built_in_code_executor.py +4 -1
- google/adk/evaluation/base_eval_service.py +46 -2
- google/adk/evaluation/eval_metrics.py +4 -0
- google/adk/evaluation/eval_sets_manager.py +5 -1
- google/adk/evaluation/evaluation_generator.py +1 -1
- google/adk/evaluation/final_response_match_v2.py +2 -2
- google/adk/evaluation/gcs_eval_sets_manager.py +2 -1
- google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
- google/adk/evaluation/local_eval_service.py +389 -0
- google/adk/evaluation/local_eval_set_results_manager.py +2 -2
- google/adk/evaluation/local_eval_sets_manager.py +24 -9
- google/adk/evaluation/metric_evaluator_registry.py +16 -6
- google/adk/evaluation/vertex_ai_eval_facade.py +7 -1
- google/adk/events/event.py +7 -2
- google/adk/flows/llm_flows/auto_flow.py +6 -11
- google/adk/flows/llm_flows/base_llm_flow.py +66 -29
- google/adk/flows/llm_flows/contents.py +16 -10
- google/adk/flows/llm_flows/functions.py +89 -52
- google/adk/memory/in_memory_memory_service.py +21 -15
- google/adk/memory/vertex_ai_memory_bank_service.py +12 -10
- google/adk/models/anthropic_llm.py +46 -6
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/gemini_llm_connection.py +17 -6
- google/adk/models/google_llm.py +46 -11
- google/adk/models/lite_llm.py +52 -22
- google/adk/plugins/__init__.py +17 -0
- google/adk/plugins/base_plugin.py +317 -0
- google/adk/plugins/plugin_manager.py +265 -0
- google/adk/runners.py +122 -18
- google/adk/sessions/database_session_service.py +51 -52
- google/adk/sessions/vertex_ai_session_service.py +27 -12
- google/adk/tools/__init__.py +2 -0
- google/adk/tools/_automatic_function_calling_util.py +20 -2
- google/adk/tools/agent_tool.py +15 -3
- google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
- google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
- google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
- google/adk/tools/base_tool.py +9 -9
- google/adk/tools/base_toolset.py +29 -5
- google/adk/tools/bigquery/__init__.py +3 -3
- google/adk/tools/bigquery/metadata_tool.py +2 -0
- google/adk/tools/bigquery/query_tool.py +15 -1
- google/adk/tools/computer_use/__init__.py +13 -0
- google/adk/tools/computer_use/base_computer.py +265 -0
- google/adk/tools/computer_use/computer_use_tool.py +166 -0
- google/adk/tools/computer_use/computer_use_toolset.py +220 -0
- google/adk/tools/enterprise_search_tool.py +4 -2
- google/adk/tools/exit_loop_tool.py +1 -0
- google/adk/tools/google_api_tool/google_api_tool.py +16 -1
- google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
- google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
- google/adk/tools/google_search_tool.py +4 -2
- google/adk/tools/langchain_tool.py +16 -6
- google/adk/tools/long_running_tool.py +21 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +5 -0
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
- google/adk/tools/tool_context.py +0 -10
- google/adk/tools/url_context_tool.py +4 -2
- google/adk/tools/vertex_ai_search_tool.py +4 -2
- google/adk/utils/model_name_utils.py +90 -0
- google/adk/version.py +1 -1
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/METADATA +3 -2
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/RECORD +108 -91
- google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/WHEEL +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.8.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py
CHANGED
@@ -17,9 +17,14 @@ from __future__ import annotations
|
|
17
17
|
import asyncio
|
18
18
|
import logging
|
19
19
|
import queue
|
20
|
+
import time
|
21
|
+
from typing import Any
|
20
22
|
from typing import AsyncGenerator
|
23
|
+
from typing import Callable
|
21
24
|
from typing import Generator
|
25
|
+
from typing import List
|
22
26
|
from typing import Optional
|
27
|
+
import uuid
|
23
28
|
import warnings
|
24
29
|
|
25
30
|
from google.genai import types
|
@@ -36,10 +41,13 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
|
36
41
|
from .auth.credential_service.base_credential_service import BaseCredentialService
|
37
42
|
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
|
38
43
|
from .events.event import Event
|
44
|
+
from .events.event import EventActions
|
39
45
|
from .flows.llm_flows.functions import find_matching_function_call
|
40
46
|
from .memory.base_memory_service import BaseMemoryService
|
41
47
|
from .memory.in_memory_memory_service import InMemoryMemoryService
|
42
48
|
from .platform.thread import create_thread
|
49
|
+
from .plugins.base_plugin import BasePlugin
|
50
|
+
from .plugins.plugin_manager import PluginManager
|
43
51
|
from .sessions.base_session_service import BaseSessionService
|
44
52
|
from .sessions.in_memory_session_service import InMemorySessionService
|
45
53
|
from .sessions.session import Session
|
@@ -60,6 +68,7 @@ class Runner:
|
|
60
68
|
app_name: The application name of the runner.
|
61
69
|
agent: The root agent to run.
|
62
70
|
artifact_service: The artifact service for the runner.
|
71
|
+
plugin_manager: The plugin manager for the runner.
|
63
72
|
session_service: The session service for the runner.
|
64
73
|
memory_service: The memory service for the runner.
|
65
74
|
"""
|
@@ -70,6 +79,8 @@ class Runner:
|
|
70
79
|
"""The root agent to run."""
|
71
80
|
artifact_service: Optional[BaseArtifactService] = None
|
72
81
|
"""The artifact service for the runner."""
|
82
|
+
plugin_manager: PluginManager
|
83
|
+
"""The plugin manager for the runner."""
|
73
84
|
session_service: BaseSessionService
|
74
85
|
"""The session service for the runner."""
|
75
86
|
memory_service: Optional[BaseMemoryService] = None
|
@@ -82,6 +93,7 @@ class Runner:
|
|
82
93
|
*,
|
83
94
|
app_name: str,
|
84
95
|
agent: BaseAgent,
|
96
|
+
plugins: Optional[List[BasePlugin]] = None,
|
85
97
|
artifact_service: Optional[BaseArtifactService] = None,
|
86
98
|
session_service: BaseSessionService,
|
87
99
|
memory_service: Optional[BaseMemoryService] = None,
|
@@ -102,6 +114,7 @@ class Runner:
|
|
102
114
|
self.session_service = session_service
|
103
115
|
self.memory_service = memory_service
|
104
116
|
self.credential_service = credential_service
|
117
|
+
self.plugin_manager = PluginManager(plugins=plugins)
|
105
118
|
|
106
119
|
def run(
|
107
120
|
self,
|
@@ -113,8 +126,9 @@ class Runner:
|
|
113
126
|
) -> Generator[Event, None, None]:
|
114
127
|
"""Runs the agent.
|
115
128
|
|
116
|
-
NOTE:
|
117
|
-
|
129
|
+
NOTE:
|
130
|
+
This sync interface is only for local testing and convenience purpose.
|
131
|
+
Consider using `run_async` for production usage.
|
118
132
|
|
119
133
|
Args:
|
120
134
|
user_id: The user ID of the session.
|
@@ -164,6 +178,7 @@ class Runner:
|
|
164
178
|
user_id: str,
|
165
179
|
session_id: str,
|
166
180
|
new_message: types.Content,
|
181
|
+
state_delta: Optional[dict[str, Any]] = None,
|
167
182
|
run_config: RunConfig = RunConfig(),
|
168
183
|
) -> AsyncGenerator[Event, None]:
|
169
184
|
"""Main entry method to run the agent in this runner.
|
@@ -191,19 +206,83 @@ class Runner:
|
|
191
206
|
)
|
192
207
|
root_agent = self.agent
|
193
208
|
|
209
|
+
# Modify user message before execution.
|
210
|
+
modified_user_message = (
|
211
|
+
await invocation_context.plugin_manager.run_on_user_message_callback(
|
212
|
+
invocation_context=invocation_context, user_message=new_message
|
213
|
+
)
|
214
|
+
)
|
215
|
+
if modified_user_message is not None:
|
216
|
+
new_message = modified_user_message
|
217
|
+
|
194
218
|
if new_message:
|
195
219
|
await self._append_new_message_to_session(
|
196
220
|
session,
|
197
221
|
new_message,
|
198
222
|
invocation_context,
|
199
223
|
run_config.save_input_blobs_as_artifacts,
|
224
|
+
state_delta,
|
200
225
|
)
|
201
226
|
|
202
227
|
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
203
|
-
|
228
|
+
|
229
|
+
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
230
|
+
async for event in ctx.agent.run_async(ctx):
|
231
|
+
yield event
|
232
|
+
|
233
|
+
async for event in self._exec_with_plugin(
|
234
|
+
invocation_context, session, execute
|
235
|
+
):
|
236
|
+
yield event
|
237
|
+
|
238
|
+
async def _exec_with_plugin(
|
239
|
+
self,
|
240
|
+
invocation_context: InvocationContext,
|
241
|
+
session: Session,
|
242
|
+
execute_fn: Callable[[InvocationContext], AsyncGenerator[Event, None]],
|
243
|
+
) -> AsyncGenerator[Event, None]:
|
244
|
+
"""Wraps execution with plugin callbacks.
|
245
|
+
|
246
|
+
Args:
|
247
|
+
invocation_context: The invocation context
|
248
|
+
session: The current session
|
249
|
+
execute_fn: A callable that returns an AsyncGenerator of Events
|
250
|
+
|
251
|
+
Yields:
|
252
|
+
Events from the execution, including any generated by plugins
|
253
|
+
"""
|
254
|
+
|
255
|
+
plugin_manager = invocation_context.plugin_manager
|
256
|
+
|
257
|
+
# Step 1: Run the before_run callbacks to see if we should early exit.
|
258
|
+
early_exit_result = await plugin_manager.run_before_run_callback(
|
259
|
+
invocation_context=invocation_context
|
260
|
+
)
|
261
|
+
if isinstance(early_exit_result, Event):
|
262
|
+
await self.session_service.append_event(
|
263
|
+
session=session,
|
264
|
+
event=Event(
|
265
|
+
invocation_id=invocation_context.invocation_id,
|
266
|
+
author='model',
|
267
|
+
content=early_exit_result,
|
268
|
+
),
|
269
|
+
)
|
270
|
+
yield early_exit_result
|
271
|
+
else:
|
272
|
+
# Step 2: Otherwise continue with normal execution
|
273
|
+
async for event in execute_fn(invocation_context):
|
204
274
|
if not event.partial:
|
205
275
|
await self.session_service.append_event(session=session, event=event)
|
206
|
-
|
276
|
+
# Step 3: Run the on_event callbacks to optionally modify the event.
|
277
|
+
modified_event = await plugin_manager.run_on_event_callback(
|
278
|
+
invocation_context=invocation_context, event=event
|
279
|
+
)
|
280
|
+
yield (modified_event if modified_event else event)
|
281
|
+
|
282
|
+
# Step 4: Run the after_run callbacks to optionally modify the context.
|
283
|
+
await plugin_manager.run_after_run_callback(
|
284
|
+
invocation_context=invocation_context
|
285
|
+
)
|
207
286
|
|
208
287
|
async def _append_new_message_to_session(
|
209
288
|
self,
|
@@ -211,6 +290,7 @@ class Runner:
|
|
211
290
|
new_message: types.Content,
|
212
291
|
invocation_context: InvocationContext,
|
213
292
|
save_input_blobs_as_artifacts: bool = False,
|
293
|
+
state_delta: Optional[dict[str, Any]] = None,
|
214
294
|
):
|
215
295
|
"""Appends a new message to the session.
|
216
296
|
|
@@ -242,11 +322,19 @@ class Runner:
|
|
242
322
|
text=f'Uploaded file: {file_name}. It is saved into artifacts'
|
243
323
|
)
|
244
324
|
# Appends only. We do not yield the event because it's not from the model.
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
325
|
+
if state_delta:
|
326
|
+
event = Event(
|
327
|
+
invocation_id=invocation_context.invocation_id,
|
328
|
+
author='user',
|
329
|
+
actions=EventActions(state_delta=state_delta),
|
330
|
+
content=new_message,
|
331
|
+
)
|
332
|
+
else:
|
333
|
+
event = Event(
|
334
|
+
invocation_id=invocation_context.invocation_id,
|
335
|
+
author='user',
|
336
|
+
content=new_message,
|
337
|
+
)
|
250
338
|
await self.session_service.append_event(session=session, event=event)
|
251
339
|
|
252
340
|
async def run_live(
|
@@ -278,7 +366,7 @@ class Runner:
|
|
278
366
|
This feature is **experimental** and its API or behavior may change
|
279
367
|
in future releases.
|
280
368
|
|
281
|
-
..
|
369
|
+
.. NOTE::
|
282
370
|
Either `session` or both `user_id` and `session_id` must be provided.
|
283
371
|
"""
|
284
372
|
if session is None and (user_id is None or session_id is None):
|
@@ -345,8 +433,14 @@ class Runner:
|
|
345
433
|
invocation_context.active_streaming_tools[tool.__name__] = (
|
346
434
|
active_streaming_tool
|
347
435
|
)
|
348
|
-
|
349
|
-
|
436
|
+
|
437
|
+
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
438
|
+
async for event in ctx.agent.run_live(ctx):
|
439
|
+
yield event
|
440
|
+
|
441
|
+
async for event in self._exec_with_plugin(
|
442
|
+
invocation_context, session, execute
|
443
|
+
):
|
350
444
|
yield event
|
351
445
|
|
352
446
|
def _find_agent_to_run(
|
@@ -355,9 +449,10 @@ class Runner:
|
|
355
449
|
"""Finds the agent to run to continue the session.
|
356
450
|
|
357
451
|
A qualified agent must be either of:
|
452
|
+
|
358
453
|
- The agent that returned a function call and the last user message is a
|
359
454
|
function response to this function call.
|
360
|
-
- The root agent
|
455
|
+
- The root agent.
|
361
456
|
- An LlmAgent who replied last and is capable to transfer to any other agent
|
362
457
|
in the agent hierarchy.
|
363
458
|
|
@@ -366,7 +461,8 @@ class Runner:
|
|
366
461
|
root_agent: The root agent of the runner.
|
367
462
|
|
368
463
|
Returns:
|
369
|
-
The agent
|
464
|
+
The agent to run. (the active agent that should reply to the latest user
|
465
|
+
message)
|
370
466
|
"""
|
371
467
|
# If the last event is a function response, should send this response to
|
372
468
|
# the agent that returned the corressponding function call regardless the
|
@@ -395,8 +491,8 @@ class Runner:
|
|
395
491
|
def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
|
396
492
|
"""Whether the agent to run can transfer to any other agent in the agent tree.
|
397
493
|
|
398
|
-
This typically means all agent_to_run's
|
399
|
-
|
494
|
+
This typically means all agent_to_run's ancestor can transfer to their
|
495
|
+
parent_agent all the way to the root_agent.
|
400
496
|
|
401
497
|
Args:
|
402
498
|
agent_to_run: The agent to check for transferability.
|
@@ -407,7 +503,7 @@ class Runner:
|
|
407
503
|
agent = agent_to_run
|
408
504
|
while agent:
|
409
505
|
if not isinstance(agent, LlmAgent):
|
410
|
-
# Only LLM-based Agent can
|
506
|
+
# Only LLM-based Agent can provide agent transfer capability.
|
411
507
|
return False
|
412
508
|
if agent.disallow_transfer_to_parent:
|
413
509
|
return False
|
@@ -450,6 +546,7 @@ class Runner:
|
|
450
546
|
session_service=self.session_service,
|
451
547
|
memory_service=self.memory_service,
|
452
548
|
credential_service=self.credential_service,
|
549
|
+
plugin_manager=self.plugin_manager,
|
453
550
|
invocation_id=invocation_id,
|
454
551
|
agent=self.agent,
|
455
552
|
session=session,
|
@@ -538,7 +635,13 @@ class InMemoryRunner(Runner):
|
|
538
635
|
session service for the runner.
|
539
636
|
"""
|
540
637
|
|
541
|
-
def __init__(
|
638
|
+
def __init__(
|
639
|
+
self,
|
640
|
+
agent: BaseAgent,
|
641
|
+
*,
|
642
|
+
app_name: str = 'InMemoryRunner',
|
643
|
+
plugins: Optional[list[BasePlugin]] = None,
|
644
|
+
):
|
542
645
|
"""Initializes the InMemoryRunner.
|
543
646
|
|
544
647
|
Args:
|
@@ -551,6 +654,7 @@ class InMemoryRunner(Runner):
|
|
551
654
|
app_name=app_name,
|
552
655
|
agent=agent,
|
553
656
|
artifact_service=InMemoryArtifactService(),
|
657
|
+
plugins=plugins,
|
554
658
|
session_service=self._in_memory_session_service,
|
555
659
|
memory_service=InMemoryMemoryService(),
|
556
660
|
)
|
@@ -137,7 +137,7 @@ class StorageSession(Base):
|
|
137
137
|
DateTime(), default=func.now(), onupdate=func.now()
|
138
138
|
)
|
139
139
|
|
140
|
-
storage_events: Mapped[list[
|
140
|
+
storage_events: Mapped[list[StorageEvent]] = relationship(
|
141
141
|
"StorageEvent",
|
142
142
|
back_populates="storage_session",
|
143
143
|
)
|
@@ -160,6 +160,26 @@ class StorageSession(Base):
|
|
160
160
|
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
|
161
161
|
return self.update_time.timestamp()
|
162
162
|
|
163
|
+
def to_session(
|
164
|
+
self,
|
165
|
+
state: dict[str, Any] | None = None,
|
166
|
+
events: list[Event] | None = None,
|
167
|
+
) -> Session:
|
168
|
+
"""Converts the storage session to a session object."""
|
169
|
+
if state is None:
|
170
|
+
state = {}
|
171
|
+
if events is None:
|
172
|
+
events = []
|
173
|
+
|
174
|
+
return Session(
|
175
|
+
app_name=self.app_name,
|
176
|
+
user_id=self.user_id,
|
177
|
+
id=self.id,
|
178
|
+
state=state,
|
179
|
+
events=events,
|
180
|
+
last_update_time=self.update_timestamp_tz,
|
181
|
+
)
|
182
|
+
|
163
183
|
|
164
184
|
class StorageEvent(Base):
|
165
185
|
"""Represents an event stored in the database."""
|
@@ -373,11 +393,11 @@ class DatabaseSessionService(BaseSessionService):
|
|
373
393
|
# 4. Build the session object with generated id
|
374
394
|
# 5. Return the session
|
375
395
|
|
376
|
-
with self.database_session_factory() as
|
396
|
+
with self.database_session_factory() as sql_session:
|
377
397
|
|
378
398
|
# Fetch app and user states from storage
|
379
|
-
storage_app_state =
|
380
|
-
storage_user_state =
|
399
|
+
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
400
|
+
storage_user_state = sql_session.get(
|
381
401
|
StorageUserState, (app_name, user_id)
|
382
402
|
)
|
383
403
|
|
@@ -387,12 +407,12 @@ class DatabaseSessionService(BaseSessionService):
|
|
387
407
|
# Create state tables if not exist
|
388
408
|
if not storage_app_state:
|
389
409
|
storage_app_state = StorageAppState(app_name=app_name, state={})
|
390
|
-
|
410
|
+
sql_session.add(storage_app_state)
|
391
411
|
if not storage_user_state:
|
392
412
|
storage_user_state = StorageUserState(
|
393
413
|
app_name=app_name, user_id=user_id, state={}
|
394
414
|
)
|
395
|
-
|
415
|
+
sql_session.add(storage_user_state)
|
396
416
|
|
397
417
|
# Extract state deltas
|
398
418
|
app_state_delta, user_state_delta, session_state = _extract_state_delta(
|
@@ -416,21 +436,15 @@ class DatabaseSessionService(BaseSessionService):
|
|
416
436
|
id=session_id,
|
417
437
|
state=session_state,
|
418
438
|
)
|
419
|
-
|
420
|
-
|
439
|
+
sql_session.add(storage_session)
|
440
|
+
sql_session.commit()
|
421
441
|
|
422
|
-
|
442
|
+
sql_session.refresh(storage_session)
|
423
443
|
|
424
444
|
# Merge states for response
|
425
445
|
merged_state = _merge_state(app_state, user_state, session_state)
|
426
|
-
session =
|
427
|
-
|
428
|
-
user_id=str(storage_session.user_id),
|
429
|
-
id=str(storage_session.id),
|
430
|
-
state=merged_state,
|
431
|
-
last_update_time=storage_session.update_timestamp_tz,
|
432
|
-
)
|
433
|
-
return session
|
446
|
+
session = storage_session.to_session(state=merged_state)
|
447
|
+
return session
|
434
448
|
|
435
449
|
@override
|
436
450
|
async def get_session(
|
@@ -444,8 +458,8 @@ class DatabaseSessionService(BaseSessionService):
|
|
444
458
|
# 1. Get the storage session entry from session table
|
445
459
|
# 2. Get all the events based on session id and filtering config
|
446
460
|
# 3. Convert and return the session
|
447
|
-
with self.database_session_factory() as
|
448
|
-
storage_session =
|
461
|
+
with self.database_session_factory() as sql_session:
|
462
|
+
storage_session = sql_session.get(
|
449
463
|
StorageSession, (app_name, user_id, session_id)
|
450
464
|
)
|
451
465
|
if storage_session is None:
|
@@ -458,7 +472,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
458
472
|
timestamp_filter = True
|
459
473
|
|
460
474
|
storage_events = (
|
461
|
-
|
475
|
+
sql_session.query(StorageEvent)
|
462
476
|
.filter(StorageEvent.app_name == app_name)
|
463
477
|
.filter(StorageEvent.session_id == storage_session.id)
|
464
478
|
.filter(StorageEvent.user_id == user_id)
|
@@ -473,8 +487,8 @@ class DatabaseSessionService(BaseSessionService):
|
|
473
487
|
)
|
474
488
|
|
475
489
|
# Fetch states from storage
|
476
|
-
storage_app_state =
|
477
|
-
storage_user_state =
|
490
|
+
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
491
|
+
storage_user_state = sql_session.get(
|
478
492
|
StorageUserState, (app_name, user_id)
|
479
493
|
)
|
480
494
|
|
@@ -486,51 +500,38 @@ class DatabaseSessionService(BaseSessionService):
|
|
486
500
|
merged_state = _merge_state(app_state, user_state, session_state)
|
487
501
|
|
488
502
|
# Convert storage session to session
|
489
|
-
|
490
|
-
|
491
|
-
user_id=user_id,
|
492
|
-
id=session_id,
|
493
|
-
state=merged_state,
|
494
|
-
last_update_time=storage_session.update_timestamp_tz,
|
495
|
-
)
|
496
|
-
session.events = [e.to_event() for e in reversed(storage_events)]
|
503
|
+
events = [e.to_event() for e in reversed(storage_events)]
|
504
|
+
session = storage_session.to_session(state=merged_state, events=events)
|
497
505
|
return session
|
498
506
|
|
499
507
|
@override
|
500
508
|
async def list_sessions(
|
501
509
|
self, *, app_name: str, user_id: str
|
502
510
|
) -> ListSessionsResponse:
|
503
|
-
with self.database_session_factory() as
|
511
|
+
with self.database_session_factory() as sql_session:
|
504
512
|
results = (
|
505
|
-
|
513
|
+
sql_session.query(StorageSession)
|
506
514
|
.filter(StorageSession.app_name == app_name)
|
507
515
|
.filter(StorageSession.user_id == user_id)
|
508
516
|
.all()
|
509
517
|
)
|
510
518
|
sessions = []
|
511
519
|
for storage_session in results:
|
512
|
-
|
513
|
-
app_name=app_name,
|
514
|
-
user_id=user_id,
|
515
|
-
id=storage_session.id,
|
516
|
-
state={},
|
517
|
-
last_update_time=storage_session.update_timestamp_tz,
|
518
|
-
)
|
519
|
-
sessions.append(session)
|
520
|
+
sessions.append(storage_session.to_session())
|
520
521
|
return ListSessionsResponse(sessions=sessions)
|
521
522
|
|
522
523
|
@override
|
523
524
|
async def delete_session(
|
524
525
|
self, app_name: str, user_id: str, session_id: str
|
525
526
|
) -> None:
|
526
|
-
with self.database_session_factory() as
|
527
|
+
with self.database_session_factory() as sql_session:
|
527
528
|
stmt = delete(StorageSession).where(
|
528
529
|
StorageSession.app_name == app_name,
|
529
530
|
StorageSession.user_id == user_id,
|
530
531
|
StorageSession.id == session_id,
|
531
532
|
)
|
532
|
-
|
533
|
-
|
533
|
+
sql_session.execute(stmt)
|
534
|
+
sql_session.commit()
|
534
535
|
|
535
536
|
@override
|
536
537
|
async def append_event(self, session: Session, event: Event) -> Event:
|
@@ -542,8 +543,8 @@ class DatabaseSessionService(BaseSessionService):
|
|
542
543
|
# 1. Check if timestamp is stale
|
543
544
|
# 2. Update session attributes based on event config
|
544
545
|
# 3. Store event to table
|
545
|
-
with self.database_session_factory() as
|
546
|
-
storage_session =
|
546
|
+
with self.database_session_factory() as sql_session:
|
547
|
+
storage_session = sql_session.get(
|
547
548
|
StorageSession, (session.app_name, session.user_id, session.id)
|
548
549
|
)
|
549
550
|
|
@@ -557,10 +558,8 @@ class DatabaseSessionService(BaseSessionService):
|
|
557
558
|
)
|
558
559
|
|
559
560
|
# Fetch states from storage
|
560
|
-
storage_app_state =
|
561
|
-
|
562
|
-
)
|
563
|
-
storage_user_state = session_factory.get(
|
561
|
+
storage_app_state = sql_session.get(StorageAppState, (session.app_name))
|
562
|
+
storage_user_state = sql_session.get(
|
564
563
|
StorageUserState, (session.app_name, session.user_id)
|
565
564
|
)
|
566
565
|
|
@@ -589,10 +588,10 @@ class DatabaseSessionService(BaseSessionService):
|
|
589
588
|
session_state.update(session_state_delta)
|
590
589
|
storage_session.state = session_state
|
591
590
|
|
592
|
-
|
591
|
+
sql_session.add(StorageEvent.from_event(session, event))
|
593
592
|
|
594
|
-
|
595
|
-
|
593
|
+
sql_session.commit()
|
594
|
+
sql_session.refresh(storage_session)
|
596
595
|
|
597
596
|
# Update timestamp with commit time
|
598
597
|
session.last_update_time = storage_session.update_timestamp_tz
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
from __future__ import annotations
|
15
15
|
|
16
|
-
import asyncio
|
17
16
|
import json
|
18
17
|
import logging
|
19
18
|
import os
|
@@ -110,7 +109,8 @@ class VertexAiSessionService(BaseSessionService):
|
|
110
109
|
request_dict=session_json_dict,
|
111
110
|
)
|
112
111
|
api_response = _convert_api_response(api_response)
|
113
|
-
logger.info(
|
112
|
+
logger.info('Create session response received.')
|
113
|
+
logger.debug('Create session response: %s', api_response)
|
114
114
|
|
115
115
|
session_id = api_response['name'].split('/')[-3]
|
116
116
|
operation_id = api_response['name'].split('/')[-1]
|
@@ -216,29 +216,36 @@ class VertexAiSessionService(BaseSessionService):
|
|
216
216
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
217
217
|
request_dict={},
|
218
218
|
)
|
219
|
-
|
219
|
+
converted_api_response = _convert_api_response(list_events_api_response)
|
220
220
|
|
221
|
-
# Handles empty response case
|
222
|
-
if not
|
221
|
+
# Handles empty response case where there are no events to fetch
|
222
|
+
if not converted_api_response or converted_api_response.get(
|
223
223
|
'httpHeaders', None
|
224
224
|
):
|
225
225
|
return session
|
226
226
|
|
227
227
|
session.events += [
|
228
228
|
_from_api_event(event)
|
229
|
-
for event in
|
229
|
+
for event in converted_api_response['sessionEvents']
|
230
230
|
]
|
231
231
|
|
232
|
-
while
|
233
|
-
page_token =
|
232
|
+
while converted_api_response.get('nextPageToken', None):
|
233
|
+
page_token = converted_api_response.get('nextPageToken', None)
|
234
234
|
list_events_api_response = await api_client.async_request(
|
235
235
|
http_method='GET',
|
236
236
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
|
237
237
|
request_dict={},
|
238
238
|
)
|
239
|
+
converted_api_response = _convert_api_response(list_events_api_response)
|
240
|
+
|
241
|
+
# Handles empty response case where there are no more events to fetch
|
242
|
+
if not converted_api_response or converted_api_response.get(
|
243
|
+
'httpHeaders', None
|
244
|
+
):
|
245
|
+
break
|
239
246
|
session.events += [
|
240
247
|
_from_api_event(event)
|
241
|
-
for event in
|
248
|
+
for event in converted_api_response['sessionEvents']
|
242
249
|
]
|
243
250
|
|
244
251
|
session.events = [
|
@@ -344,16 +351,24 @@ class VertexAiSessionService(BaseSessionService):
|
|
344
351
|
|
345
352
|
return match.groups()[-1]
|
346
353
|
|
354
|
+
def _api_client_http_options_override(
|
355
|
+
self,
|
356
|
+
) -> Optional[genai.types.HttpOptions]:
|
357
|
+
return None
|
358
|
+
|
347
359
|
def _get_api_client(self):
|
348
360
|
"""Instantiates an API client for the given project and location.
|
349
361
|
|
350
362
|
It needs to be instantiated inside each request so that the event loop
|
351
363
|
management can be properly propagated.
|
352
364
|
"""
|
353
|
-
|
365
|
+
api_client = genai.Client(
|
354
366
|
vertexai=True, project=self._project, location=self._location
|
355
|
-
)
|
356
|
-
|
367
|
+
)._api_client
|
368
|
+
|
369
|
+
if new_options := self._api_client_http_options_override():
|
370
|
+
api_client._http_options = new_options
|
371
|
+
return api_client
|
357
372
|
|
358
373
|
|
359
374
|
def _is_vertex_express_mode(
|
google/adk/tools/__init__.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
|
16
16
|
from ..auth.auth_tool import AuthToolArguments
|
17
|
+
from .agent_tool import AgentTool
|
17
18
|
from .apihub_tool.apihub_toolset import APIHubToolset
|
18
19
|
from .base_tool import BaseTool
|
19
20
|
from .example_tool import ExampleTool
|
@@ -31,6 +32,7 @@ from .url_context_tool import url_context
|
|
31
32
|
from .vertex_ai_search_tool import VertexAiSearchTool
|
32
33
|
|
33
34
|
__all__ = [
|
35
|
+
'AgentTool',
|
34
36
|
'APIHubToolset',
|
35
37
|
'AuthToolArguments',
|
36
38
|
'BaseTool',
|
@@ -20,7 +20,6 @@ import typing
|
|
20
20
|
from typing import Any
|
21
21
|
from typing import Callable
|
22
22
|
from typing import Dict
|
23
|
-
from typing import Literal
|
24
23
|
from typing import Optional
|
25
24
|
from typing import Union
|
26
25
|
|
@@ -329,7 +328,26 @@ def from_function_with_options(
|
|
329
328
|
return declaration
|
330
329
|
|
331
330
|
return_annotation = inspect.signature(func).return_annotation
|
332
|
-
|
331
|
+
|
332
|
+
# Handle functions with no return annotation or that return None
|
333
|
+
if (
|
334
|
+
return_annotation is inspect._empty
|
335
|
+
or return_annotation is None
|
336
|
+
or return_annotation is type(None)
|
337
|
+
):
|
338
|
+
# Create a response schema for None/null return
|
339
|
+
return_value = inspect.Parameter(
|
340
|
+
'return_value',
|
341
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
342
|
+
annotation=None,
|
343
|
+
)
|
344
|
+
declaration.response = (
|
345
|
+
_function_parameter_parse_util._parse_schema_from_parameter(
|
346
|
+
variant,
|
347
|
+
return_value,
|
348
|
+
func.__name__,
|
349
|
+
)
|
350
|
+
)
|
333
351
|
return declaration
|
334
352
|
|
335
353
|
return_value = inspect.Parameter(
|