google-adk 0.5.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/callback_context.py +2 -6
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +8 -0
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +10 -2
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +4 -4
- google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +10 -9
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +109 -115
- google/adk/cli/cli_tools_click.py +179 -67
- google/adk/cli/fast_api.py +248 -197
- google/adk/cli/utils/agent_loader.py +137 -0
- google/adk/cli/utils/cleanup.py +40 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -0
- google/adk/cli/utils/logs.py +8 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/code_executors/code_execution_utils.py +2 -1
- google/adk/code_executors/container_code_executor.py +0 -1
- google/adk/code_executors/vertex_ai_code_executor.py +6 -8
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +104 -0
- google/adk/evaluation/eval_metrics.py +74 -0
- google/adk/evaluation/eval_result.py +86 -0
- google/adk/evaluation/eval_set.py +39 -0
- google/adk/evaluation/eval_set_results_manager.py +47 -0
- google/adk/evaluation/eval_sets_manager.py +43 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +58 -0
- google/adk/evaluation/local_eval_set_results_manager.py +113 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -1
- google/adk/evaluation/trajectory_evaluator.py +84 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/base_example_provider.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +58 -21
- google/adk/flows/llm_flows/contents.py +3 -1
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +18 -80
- google/adk/flows/llm_flows/single_flow.py +2 -2
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/base_llm.py +2 -1
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +12 -2
- google/adk/models/lite_llm.py +80 -23
- google/adk/models/llm_response.py +16 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +98 -42
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +2 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +57 -67
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +44 -51
- google/adk/telemetry.py +7 -2
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +10 -10
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
- google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
- google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +96 -0
- google/adk/tools/bigquery/__init__.py +28 -0
- google/adk/tools/bigquery/bigquery_credentials.py +216 -0
- google/adk/tools/bigquery/bigquery_tool.py +116 -0
- google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
- google/adk/tools/function_parameter_parse_util.py +9 -2
- google/adk/tools/function_tool.py +33 -3
- google/adk/tools/get_user_choice_tool.py +1 -0
- google/adk/tools/google_api_tool/__init__.py +24 -70
- google/adk/tools/google_api_tool/google_api_tool.py +12 -6
- google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
- google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/google_search_tool.py +2 -2
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +6 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
- google/adk/tools/mcp_tool/mcp_tool.py +35 -32
- google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
- google/adk/tools/openapi_tool/common/common.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +107 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/utils/__init__.py +13 -0
- google/adk/utils/instructions_utils.py +131 -0
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
- google_adk-1.1.0.dist-info/RECORD +200 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
- google/adk/cli/fast_api.py.orig +0 -728
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py
CHANGED
@@ -21,8 +21,8 @@ import threading
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
from typing import Generator
|
23
23
|
from typing import Optional
|
24
|
+
import warnings
|
24
25
|
|
25
|
-
from deprecated import deprecated
|
26
26
|
from google.genai import types
|
27
27
|
|
28
28
|
from .agents.active_streaming_tool import ActiveStreamingTool
|
@@ -32,9 +32,9 @@ from .agents.invocation_context import new_invocation_context_id
|
|
32
32
|
from .agents.live_request_queue import LiveRequestQueue
|
33
33
|
from .agents.llm_agent import LlmAgent
|
34
34
|
from .agents.run_config import RunConfig
|
35
|
-
from .agents.run_config import StreamingMode
|
36
35
|
from .artifacts.base_artifact_service import BaseArtifactService
|
37
36
|
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
37
|
+
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
|
38
38
|
from .events.event import Event
|
39
39
|
from .memory.base_memory_service import BaseMemoryService
|
40
40
|
from .memory.in_memory_memory_service import InMemoryMemoryService
|
@@ -42,9 +42,9 @@ from .sessions.base_session_service import BaseSessionService
|
|
42
42
|
from .sessions.in_memory_session_service import InMemorySessionService
|
43
43
|
from .sessions.session import Session
|
44
44
|
from .telemetry import tracer
|
45
|
-
from .tools.
|
45
|
+
from .tools.base_toolset import BaseToolset
|
46
46
|
|
47
|
-
logger = logging.getLogger(__name__)
|
47
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
48
48
|
|
49
49
|
|
50
50
|
class Runner:
|
@@ -172,7 +172,7 @@ class Runner:
|
|
172
172
|
The events generated by the agent.
|
173
173
|
"""
|
174
174
|
with tracer.start_as_current_span('invocation'):
|
175
|
-
session = self.session_service.get_session(
|
175
|
+
session = await self.session_service.get_session(
|
176
176
|
app_name=self.app_name, user_id=user_id, session_id=session_id
|
177
177
|
)
|
178
178
|
if not session:
|
@@ -196,7 +196,7 @@ class Runner:
|
|
196
196
|
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
197
197
|
async for event in invocation_context.agent.run_async(invocation_context):
|
198
198
|
if not event.partial:
|
199
|
-
self.session_service.append_event(session=session, event=event)
|
199
|
+
await self.session_service.append_event(session=session, event=event)
|
200
200
|
yield event
|
201
201
|
|
202
202
|
async def _append_new_message_to_session(
|
@@ -241,30 +241,57 @@ class Runner:
|
|
241
241
|
author='user',
|
242
242
|
content=new_message,
|
243
243
|
)
|
244
|
-
self.session_service.append_event(session=session, event=event)
|
244
|
+
await self.session_service.append_event(session=session, event=event)
|
245
245
|
|
246
246
|
async def run_live(
|
247
247
|
self,
|
248
248
|
*,
|
249
|
-
|
249
|
+
user_id: Optional[str] = None,
|
250
|
+
session_id: Optional[str] = None,
|
250
251
|
live_request_queue: LiveRequestQueue,
|
251
252
|
run_config: RunConfig = RunConfig(),
|
253
|
+
session: Optional[Session] = None,
|
252
254
|
) -> AsyncGenerator[Event, None]:
|
253
255
|
"""Runs the agent in live mode (experimental feature).
|
254
256
|
|
255
257
|
Args:
|
256
|
-
|
258
|
+
user_id: The user ID for the session. Required if `session` is None.
|
259
|
+
session_id: The session ID for the session. Required if `session` is
|
260
|
+
None.
|
257
261
|
live_request_queue: The queue for live requests.
|
258
262
|
run_config: The run config for the agent.
|
263
|
+
session: The session to use. This parameter is deprecated, please use
|
264
|
+
`user_id` and `session_id` instead.
|
259
265
|
|
260
266
|
Yields:
|
261
|
-
|
267
|
+
AsyncGenerator[Event, None]: An asynchronous generator that yields
|
268
|
+
`Event`
|
269
|
+
objects as they are produced by the agent during its live execution.
|
262
270
|
|
263
271
|
.. warning::
|
264
272
|
This feature is **experimental** and its API or behavior may change
|
265
273
|
in future releases.
|
274
|
+
|
275
|
+
.. note::
|
276
|
+
Either `session` or both `user_id` and `session_id` must be provided.
|
266
277
|
"""
|
267
|
-
|
278
|
+
if session is None and (user_id is None or session_id is None):
|
279
|
+
raise ValueError(
|
280
|
+
'Either session or user_id and session_id must be provided.'
|
281
|
+
)
|
282
|
+
if session is not None:
|
283
|
+
warnings.warn(
|
284
|
+
'The `session` parameter is deprecated. Please use `user_id` and'
|
285
|
+
' `session_id` instead.',
|
286
|
+
DeprecationWarning,
|
287
|
+
stacklevel=2,
|
288
|
+
)
|
289
|
+
if not session:
|
290
|
+
session = await self.session_service.get_session(
|
291
|
+
app_name=self.app_name, user_id=user_id, session_id=session_id
|
292
|
+
)
|
293
|
+
if not session:
|
294
|
+
raise ValueError(f'Session not found: {session_id}')
|
268
295
|
invocation_context = self._new_invocation_context_for_live(
|
269
296
|
session,
|
270
297
|
live_request_queue=live_request_queue,
|
@@ -276,37 +303,29 @@ class Runner:
|
|
276
303
|
|
277
304
|
invocation_context.active_streaming_tools = {}
|
278
305
|
# TODO(hangfei): switch to use canonical_tools.
|
279
|
-
for
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
306
|
+
# for shell agents, there is no tools associated with it so we should skip.
|
307
|
+
if hasattr(invocation_context.agent, 'tools'):
|
308
|
+
for tool in invocation_context.agent.tools:
|
309
|
+
# replicate a LiveRequestQueue for streaming tools that relis on
|
310
|
+
# LiveRequestQueue
|
311
|
+
from typing import get_type_hints
|
312
|
+
|
313
|
+
type_hints = get_type_hints(tool)
|
314
|
+
for arg_type in type_hints.values():
|
315
|
+
if arg_type is LiveRequestQueue:
|
316
|
+
if not invocation_context.active_streaming_tools:
|
317
|
+
invocation_context.active_streaming_tools = {}
|
318
|
+
active_streaming_tools = ActiveStreamingTool(
|
319
|
+
stream=LiveRequestQueue()
|
320
|
+
)
|
321
|
+
invocation_context.active_streaming_tools[tool.__name__] = (
|
322
|
+
active_streaming_tools
|
323
|
+
)
|
295
324
|
|
296
325
|
async for event in invocation_context.agent.run_live(invocation_context):
|
297
|
-
self.session_service.append_event(session=session, event=event)
|
326
|
+
await self.session_service.append_event(session=session, event=event)
|
298
327
|
yield event
|
299
328
|
|
300
|
-
async def close_session(self, session: Session):
|
301
|
-
"""Closes a session and adds it to the memory service (experimental feature).
|
302
|
-
|
303
|
-
Args:
|
304
|
-
session: The session to close.
|
305
|
-
"""
|
306
|
-
if self.memory_service:
|
307
|
-
await self.memory_service.add_session_to_memory(session)
|
308
|
-
self.session_service.close_session(session=session)
|
309
|
-
|
310
329
|
def _find_agent_to_run(
|
311
330
|
self, session: Session, root_agent: BaseAgent
|
312
331
|
) -> BaseAgent:
|
@@ -391,8 +410,8 @@ class Runner:
|
|
391
410
|
f'CFC is not supported for model: {model_name} in agent:'
|
392
411
|
f' {self.agent.name}'
|
393
412
|
)
|
394
|
-
if
|
395
|
-
self.agent.
|
413
|
+
if not isinstance(self.agent.code_executor, BuiltInCodeExecutor):
|
414
|
+
self.agent.code_executor = BuiltInCodeExecutor()
|
396
415
|
|
397
416
|
return InvocationContext(
|
398
417
|
artifact_service=self.artifact_service,
|
@@ -430,12 +449,46 @@ class Runner:
|
|
430
449
|
run_config.output_audio_transcription = (
|
431
450
|
types.AudioTranscriptionConfig()
|
432
451
|
)
|
452
|
+
if not run_config.input_audio_transcription:
|
453
|
+
# need this input transcription for agent transferring in live mode.
|
454
|
+
run_config.input_audio_transcription = types.AudioTranscriptionConfig()
|
433
455
|
return self._new_invocation_context(
|
434
456
|
session,
|
435
457
|
live_request_queue=live_request_queue,
|
436
458
|
run_config=run_config,
|
437
459
|
)
|
438
460
|
|
461
|
+
def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
|
462
|
+
toolsets = set()
|
463
|
+
if isinstance(agent, LlmAgent):
|
464
|
+
for tool_union in agent.tools:
|
465
|
+
if isinstance(tool_union, BaseToolset):
|
466
|
+
toolsets.add(tool_union)
|
467
|
+
for sub_agent in agent.sub_agents:
|
468
|
+
toolsets.update(self._collect_toolset(sub_agent))
|
469
|
+
return toolsets
|
470
|
+
|
471
|
+
async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
|
472
|
+
"""Clean up toolsets with proper task context management."""
|
473
|
+
if not toolsets_to_close:
|
474
|
+
return
|
475
|
+
|
476
|
+
# This maintains the same task context throughout cleanup
|
477
|
+
for toolset in toolsets_to_close:
|
478
|
+
try:
|
479
|
+
logger.info('Closing toolset: %s', type(toolset).__name__)
|
480
|
+
# Use asyncio.wait_for to add timeout protection
|
481
|
+
await asyncio.wait_for(toolset.close(), timeout=10.0)
|
482
|
+
logger.info('Successfully closed toolset: %s', type(toolset).__name__)
|
483
|
+
except asyncio.TimeoutError:
|
484
|
+
logger.warning('Toolset %s cleanup timed out', type(toolset).__name__)
|
485
|
+
except Exception as e:
|
486
|
+
logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)
|
487
|
+
|
488
|
+
async def close(self):
|
489
|
+
"""Closes the runner."""
|
490
|
+
await self._cleanup_toolsets(self._collect_toolset(self.agent))
|
491
|
+
|
439
492
|
|
440
493
|
class InMemoryRunner(Runner):
|
441
494
|
"""An in-memory Runner for testing and development.
|
@@ -448,9 +501,11 @@ class InMemoryRunner(Runner):
|
|
448
501
|
agent: The root agent to run.
|
449
502
|
app_name: The application name of the runner. Defaults to
|
450
503
|
'InMemoryRunner'.
|
504
|
+
_in_memory_session_service: Deprecated. Please don't use. The in-memory
|
505
|
+
session service for the runner.
|
451
506
|
"""
|
452
507
|
|
453
|
-
def __init__(self, agent:
|
508
|
+
def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
|
454
509
|
"""Initializes the InMemoryRunner.
|
455
510
|
|
456
511
|
Args:
|
@@ -458,10 +513,11 @@ class InMemoryRunner(Runner):
|
|
458
513
|
app_name: The application name of the runner. Defaults to
|
459
514
|
'InMemoryRunner'.
|
460
515
|
"""
|
516
|
+
self._in_memory_session_service = InMemorySessionService()
|
461
517
|
super().__init__(
|
462
518
|
app_name=app_name,
|
463
519
|
agent=agent,
|
464
520
|
artifact_service=InMemoryArtifactService(),
|
465
|
-
session_service=
|
521
|
+
session_service=self._in_memory_session_service,
|
466
522
|
memory_service=InMemoryMemoryService(),
|
467
523
|
)
|
google/adk/sessions/__init__.py
CHANGED
@@ -40,13 +40,6 @@ class ListSessionsResponse(BaseModel):
|
|
40
40
|
sessions: list[Session] = Field(default_factory=list)
|
41
41
|
|
42
42
|
|
43
|
-
class ListEventsResponse(BaseModel):
|
44
|
-
"""The response of listing events in a session."""
|
45
|
-
|
46
|
-
events: list[Event] = Field(default_factory=list)
|
47
|
-
next_page_token: Optional[str] = None
|
48
|
-
|
49
|
-
|
50
43
|
class BaseSessionService(abc.ABC):
|
51
44
|
"""Base class for session services.
|
52
45
|
|
@@ -54,7 +47,7 @@ class BaseSessionService(abc.ABC):
|
|
54
47
|
"""
|
55
48
|
|
56
49
|
@abc.abstractmethod
|
57
|
-
def create_session(
|
50
|
+
async def create_session(
|
58
51
|
self,
|
59
52
|
*,
|
60
53
|
app_name: str,
|
@@ -74,10 +67,9 @@ class BaseSessionService(abc.ABC):
|
|
74
67
|
Returns:
|
75
68
|
session: The newly created session instance.
|
76
69
|
"""
|
77
|
-
pass
|
78
70
|
|
79
71
|
@abc.abstractmethod
|
80
|
-
def get_session(
|
72
|
+
async def get_session(
|
81
73
|
self,
|
82
74
|
*,
|
83
75
|
app_name: str,
|
@@ -86,39 +78,20 @@ class BaseSessionService(abc.ABC):
|
|
86
78
|
config: Optional[GetSessionConfig] = None,
|
87
79
|
) -> Optional[Session]:
|
88
80
|
"""Gets a session."""
|
89
|
-
pass
|
90
81
|
|
91
82
|
@abc.abstractmethod
|
92
|
-
def list_sessions(
|
83
|
+
async def list_sessions(
|
93
84
|
self, *, app_name: str, user_id: str
|
94
85
|
) -> ListSessionsResponse:
|
95
86
|
"""Lists all the sessions."""
|
96
|
-
pass
|
97
87
|
|
98
88
|
@abc.abstractmethod
|
99
|
-
def delete_session(
|
89
|
+
async def delete_session(
|
100
90
|
self, *, app_name: str, user_id: str, session_id: str
|
101
91
|
) -> None:
|
102
92
|
"""Deletes a session."""
|
103
|
-
pass
|
104
|
-
|
105
|
-
@abc.abstractmethod
|
106
|
-
def list_events(
|
107
|
-
self,
|
108
|
-
*,
|
109
|
-
app_name: str,
|
110
|
-
user_id: str,
|
111
|
-
session_id: str,
|
112
|
-
) -> ListEventsResponse:
|
113
|
-
"""Lists events in a session."""
|
114
|
-
pass
|
115
|
-
|
116
|
-
def close_session(self, *, session: Session):
|
117
|
-
"""Closes a session."""
|
118
|
-
# TODO: determine whether we want to finalize the session here.
|
119
|
-
pass
|
120
93
|
|
121
|
-
def append_event(self, session: Session, event: Event) -> Event:
|
94
|
+
async def append_event(self, session: Session, event: Event) -> Event:
|
122
95
|
"""Appends an event to a session object."""
|
123
96
|
if event.partial:
|
124
97
|
return event
|
@@ -126,7 +99,7 @@ class BaseSessionService(abc.ABC):
|
|
126
99
|
session.events.append(event)
|
127
100
|
return event
|
128
101
|
|
129
|
-
def __update_session_state(self, session: Session, event: Event):
|
102
|
+
def __update_session_state(self, session: Session, event: Event) -> None:
|
130
103
|
"""Updates the session state based on the event."""
|
131
104
|
if not event.actions or not event.actions.state_delta:
|
132
105
|
return
|
@@ -15,7 +15,8 @@ import copy
|
|
15
15
|
from datetime import datetime
|
16
16
|
import json
|
17
17
|
import logging
|
18
|
-
from typing import Any
|
18
|
+
from typing import Any
|
19
|
+
from typing import Optional
|
19
20
|
import uuid
|
20
21
|
|
21
22
|
from sqlalchemy import Boolean
|
@@ -45,27 +46,22 @@ from sqlalchemy.types import TypeDecorator
|
|
45
46
|
from typing_extensions import override
|
46
47
|
from tzlocal import get_localzone
|
47
48
|
|
48
|
-
from ..events.event import Event
|
49
49
|
from . import _session_util
|
50
|
+
from ..events.event import Event
|
50
51
|
from .base_session_service import BaseSessionService
|
51
52
|
from .base_session_service import GetSessionConfig
|
52
|
-
from .base_session_service import ListEventsResponse
|
53
53
|
from .base_session_service import ListSessionsResponse
|
54
54
|
from .session import Session
|
55
55
|
from .state import State
|
56
56
|
|
57
|
-
|
58
|
-
logger = logging.getLogger(__name__)
|
57
|
+
logger = logging.getLogger("google_adk." + __name__)
|
59
58
|
|
60
59
|
DEFAULT_MAX_KEY_LENGTH = 128
|
61
60
|
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
62
61
|
|
63
62
|
|
64
63
|
class DynamicJSON(TypeDecorator):
|
65
|
-
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
66
|
-
|
67
|
-
serialization for other databases.
|
68
|
-
"""
|
64
|
+
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
|
69
65
|
|
70
66
|
impl = Text # Default implementation is TEXT
|
71
67
|
|
@@ -243,10 +239,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
243
239
|
"""A session service that uses a database for storage."""
|
244
240
|
|
245
241
|
def __init__(self, db_url: str):
|
246
|
-
"""
|
247
|
-
Args:
|
248
|
-
db_url: The database URL to connect to.
|
249
|
-
"""
|
242
|
+
"""Initializes the database session service with a database URL."""
|
250
243
|
# 1. Create DB engine for db connection
|
251
244
|
# 2. Create all tables based on schema
|
252
245
|
# 3. Initialize all properties
|
@@ -275,7 +268,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
275
268
|
self.inspector = inspect(self.db_engine)
|
276
269
|
|
277
270
|
# DB session factory method
|
278
|
-
self.
|
271
|
+
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
|
279
272
|
sessionmaker(bind=self.db_engine)
|
280
273
|
)
|
281
274
|
|
@@ -284,7 +277,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
284
277
|
Base.metadata.create_all(self.db_engine)
|
285
278
|
|
286
279
|
@override
|
287
|
-
def create_session(
|
280
|
+
async def create_session(
|
288
281
|
self,
|
289
282
|
*,
|
290
283
|
app_name: str,
|
@@ -298,11 +291,11 @@ class DatabaseSessionService(BaseSessionService):
|
|
298
291
|
# 4. Build the session object with generated id
|
299
292
|
# 5. Return the session
|
300
293
|
|
301
|
-
with self.
|
294
|
+
with self.database_session_factory() as session_factory:
|
302
295
|
|
303
296
|
# Fetch app and user states from storage
|
304
|
-
storage_app_state =
|
305
|
-
storage_user_state =
|
297
|
+
storage_app_state = session_factory.get(StorageAppState, (app_name))
|
298
|
+
storage_user_state = session_factory.get(
|
306
299
|
StorageUserState, (app_name, user_id)
|
307
300
|
)
|
308
301
|
|
@@ -312,12 +305,12 @@ class DatabaseSessionService(BaseSessionService):
|
|
312
305
|
# Create state tables if not exist
|
313
306
|
if not storage_app_state:
|
314
307
|
storage_app_state = StorageAppState(app_name=app_name, state={})
|
315
|
-
|
308
|
+
session_factory.add(storage_app_state)
|
316
309
|
if not storage_user_state:
|
317
310
|
storage_user_state = StorageUserState(
|
318
311
|
app_name=app_name, user_id=user_id, state={}
|
319
312
|
)
|
320
|
-
|
313
|
+
session_factory.add(storage_user_state)
|
321
314
|
|
322
315
|
# Extract state deltas
|
323
316
|
app_state_delta, user_state_delta, session_state = _extract_state_delta(
|
@@ -341,10 +334,10 @@ class DatabaseSessionService(BaseSessionService):
|
|
341
334
|
id=session_id,
|
342
335
|
state=session_state,
|
343
336
|
)
|
344
|
-
|
345
|
-
|
337
|
+
session_factory.add(storage_session)
|
338
|
+
session_factory.commit()
|
346
339
|
|
347
|
-
|
340
|
+
session_factory.refresh(storage_session)
|
348
341
|
|
349
342
|
# Merge states for response
|
350
343
|
merged_state = _merge_state(app_state, user_state, session_state)
|
@@ -358,7 +351,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
358
351
|
return session
|
359
352
|
|
360
353
|
@override
|
361
|
-
def get_session(
|
354
|
+
async def get_session(
|
362
355
|
self,
|
363
356
|
*,
|
364
357
|
app_name: str,
|
@@ -369,29 +362,35 @@ class DatabaseSessionService(BaseSessionService):
|
|
369
362
|
# 1. Get the storage session entry from session table
|
370
363
|
# 2. Get all the events based on session id and filtering config
|
371
364
|
# 3. Convert and return the session
|
372
|
-
with self.
|
373
|
-
storage_session =
|
365
|
+
with self.database_session_factory() as session_factory:
|
366
|
+
storage_session = session_factory.get(
|
374
367
|
StorageSession, (app_name, user_id, session_id)
|
375
368
|
)
|
376
369
|
if storage_session is None:
|
377
370
|
return None
|
378
371
|
|
372
|
+
if config and config.after_timestamp:
|
373
|
+
after_dt = datetime.fromtimestamp(config.after_timestamp)
|
374
|
+
timestamp_filter = StorageEvent.timestamp >= after_dt
|
375
|
+
else:
|
376
|
+
timestamp_filter = True
|
377
|
+
|
379
378
|
storage_events = (
|
380
|
-
|
379
|
+
session_factory.query(StorageEvent)
|
381
380
|
.filter(StorageEvent.session_id == storage_session.id)
|
382
|
-
.filter(
|
383
|
-
|
384
|
-
|
385
|
-
|
381
|
+
.filter(timestamp_filter)
|
382
|
+
.order_by(StorageEvent.timestamp.desc())
|
383
|
+
.limit(
|
384
|
+
config.num_recent_events
|
385
|
+
if config and config.num_recent_events
|
386
|
+
else None
|
386
387
|
)
|
387
|
-
.limit(config.num_recent_events if config else None)
|
388
|
-
.order_by(StorageEvent.timestamp.asc())
|
389
388
|
.all()
|
390
389
|
)
|
391
390
|
|
392
391
|
# Fetch states from storage
|
393
|
-
storage_app_state =
|
394
|
-
storage_user_state =
|
392
|
+
storage_app_state = session_factory.get(StorageAppState, (app_name))
|
393
|
+
storage_user_state = session_factory.get(
|
395
394
|
StorageUserState, (app_name, user_id)
|
396
395
|
)
|
397
396
|
|
@@ -427,17 +426,17 @@ class DatabaseSessionService(BaseSessionService):
|
|
427
426
|
error_message=e.error_message,
|
428
427
|
interrupted=e.interrupted,
|
429
428
|
)
|
430
|
-
for e in storage_events
|
429
|
+
for e in reversed(storage_events)
|
431
430
|
]
|
432
431
|
return session
|
433
432
|
|
434
433
|
@override
|
435
|
-
def list_sessions(
|
434
|
+
async def list_sessions(
|
436
435
|
self, *, app_name: str, user_id: str
|
437
436
|
) -> ListSessionsResponse:
|
438
|
-
with self.
|
437
|
+
with self.database_session_factory() as session_factory:
|
439
438
|
results = (
|
440
|
-
|
439
|
+
session_factory.query(StorageSession)
|
441
440
|
.filter(StorageSession.app_name == app_name)
|
442
441
|
.filter(StorageSession.user_id == user_id)
|
443
442
|
.all()
|
@@ -455,20 +454,20 @@ class DatabaseSessionService(BaseSessionService):
|
|
455
454
|
return ListSessionsResponse(sessions=sessions)
|
456
455
|
|
457
456
|
@override
|
458
|
-
def delete_session(
|
457
|
+
async def delete_session(
|
459
458
|
self, app_name: str, user_id: str, session_id: str
|
460
459
|
) -> None:
|
461
|
-
with self.
|
460
|
+
with self.database_session_factory() as session_factory:
|
462
461
|
stmt = delete(StorageSession).where(
|
463
462
|
StorageSession.app_name == app_name,
|
464
463
|
StorageSession.user_id == user_id,
|
465
464
|
StorageSession.id == session_id,
|
466
465
|
)
|
467
|
-
|
468
|
-
|
466
|
+
session_factory.execute(stmt)
|
467
|
+
session_factory.commit()
|
469
468
|
|
470
469
|
@override
|
471
|
-
def append_event(self, session: Session, event: Event) -> Event:
|
470
|
+
async def append_event(self, session: Session, event: Event) -> Event:
|
472
471
|
logger.info(f"Append event: {event} to session {session.id}")
|
473
472
|
|
474
473
|
if event.partial:
|
@@ -477,24 +476,25 @@ class DatabaseSessionService(BaseSessionService):
|
|
477
476
|
# 1. Check if timestamp is stale
|
478
477
|
# 2. Update session attributes based on event config
|
479
478
|
# 3. Store event to table
|
480
|
-
with self.
|
481
|
-
storage_session =
|
479
|
+
with self.database_session_factory() as session_factory:
|
480
|
+
storage_session = session_factory.get(
|
482
481
|
StorageSession, (session.app_name, session.user_id, session.id)
|
483
482
|
)
|
484
483
|
|
485
484
|
if storage_session.update_time.timestamp() > session.last_update_time:
|
486
485
|
raise ValueError(
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
486
|
+
"The last_update_time provided in the session object"
|
487
|
+
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
|
488
|
+
" earlier than the update_time in the storage_session"
|
489
|
+
f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check"
|
490
|
+
" if it is a stale session."
|
491
|
+
)
|
492
492
|
|
493
493
|
# Fetch states from storage
|
494
|
-
storage_app_state =
|
494
|
+
storage_app_state = session_factory.get(
|
495
495
|
StorageAppState, (session.app_name)
|
496
496
|
)
|
497
|
-
storage_user_state =
|
497
|
+
storage_user_state = session_factory.get(
|
498
498
|
StorageUserState, (session.app_name, session.user_id)
|
499
499
|
)
|
500
500
|
|
@@ -543,28 +543,18 @@ class DatabaseSessionService(BaseSessionService):
|
|
543
543
|
if event.content:
|
544
544
|
storage_event.content = _session_util.encode_content(event.content)
|
545
545
|
|
546
|
-
|
546
|
+
session_factory.add(storage_event)
|
547
547
|
|
548
|
-
|
549
|
-
|
548
|
+
session_factory.commit()
|
549
|
+
session_factory.refresh(storage_session)
|
550
550
|
|
551
551
|
# Update timestamp with commit time
|
552
552
|
session.last_update_time = storage_session.update_time.timestamp()
|
553
553
|
|
554
554
|
# Also update the in-memory session
|
555
|
-
super().append_event(session=session, event=event)
|
555
|
+
await super().append_event(session=session, event=event)
|
556
556
|
return event
|
557
557
|
|
558
|
-
@override
|
559
|
-
def list_events(
|
560
|
-
self,
|
561
|
-
*,
|
562
|
-
app_name: str,
|
563
|
-
user_id: str,
|
564
|
-
session_id: str,
|
565
|
-
) -> ListEventsResponse:
|
566
|
-
raise NotImplementedError()
|
567
|
-
|
568
558
|
|
569
559
|
def convert_event(event: StorageEvent) -> Event:
|
570
560
|
"""Converts a storage event to an event."""
|