google-adk 0.5.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139) hide show
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/callback_context.py +2 -6
  3. google/adk/agents/llm_agent.py +122 -30
  4. google/adk/agents/loop_agent.py +1 -1
  5. google/adk/agents/parallel_agent.py +7 -0
  6. google/adk/agents/readonly_context.py +8 -0
  7. google/adk/agents/run_config.py +1 -1
  8. google/adk/agents/sequential_agent.py +31 -0
  9. google/adk/agents/transcription_entry.py +4 -2
  10. google/adk/artifacts/gcs_artifact_service.py +1 -1
  11. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  12. google/adk/auth/auth_credential.py +10 -2
  13. google/adk/auth/auth_preprocessor.py +7 -1
  14. google/adk/auth/auth_tool.py +3 -4
  15. google/adk/cli/agent_graph.py +5 -5
  16. google/adk/cli/browser/index.html +4 -4
  17. google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
  18. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  19. google/adk/cli/cli.py +10 -9
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +109 -115
  22. google/adk/cli/cli_tools_click.py +179 -67
  23. google/adk/cli/fast_api.py +248 -197
  24. google/adk/cli/utils/agent_loader.py +137 -0
  25. google/adk/cli/utils/cleanup.py +40 -0
  26. google/adk/cli/utils/common.py +23 -0
  27. google/adk/cli/utils/evals.py +83 -0
  28. google/adk/cli/utils/logs.py +8 -5
  29. google/adk/code_executors/__init__.py +3 -1
  30. google/adk/code_executors/built_in_code_executor.py +52 -0
  31. google/adk/code_executors/code_execution_utils.py +2 -1
  32. google/adk/code_executors/container_code_executor.py +0 -1
  33. google/adk/code_executors/vertex_ai_code_executor.py +6 -8
  34. google/adk/evaluation/__init__.py +1 -1
  35. google/adk/evaluation/agent_evaluator.py +168 -128
  36. google/adk/evaluation/eval_case.py +104 -0
  37. google/adk/evaluation/eval_metrics.py +74 -0
  38. google/adk/evaluation/eval_result.py +86 -0
  39. google/adk/evaluation/eval_set.py +39 -0
  40. google/adk/evaluation/eval_set_results_manager.py +47 -0
  41. google/adk/evaluation/eval_sets_manager.py +43 -0
  42. google/adk/evaluation/evaluation_generator.py +88 -113
  43. google/adk/evaluation/evaluator.py +58 -0
  44. google/adk/evaluation/local_eval_set_results_manager.py +113 -0
  45. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  46. google/adk/evaluation/response_evaluator.py +106 -1
  47. google/adk/evaluation/trajectory_evaluator.py +84 -2
  48. google/adk/events/event.py +6 -1
  49. google/adk/events/event_actions.py +6 -1
  50. google/adk/examples/base_example_provider.py +1 -0
  51. google/adk/examples/example_util.py +3 -2
  52. google/adk/flows/llm_flows/_code_execution.py +9 -1
  53. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  54. google/adk/flows/llm_flows/base_llm_flow.py +58 -21
  55. google/adk/flows/llm_flows/contents.py +3 -1
  56. google/adk/flows/llm_flows/functions.py +9 -8
  57. google/adk/flows/llm_flows/instructions.py +18 -80
  58. google/adk/flows/llm_flows/single_flow.py +2 -2
  59. google/adk/memory/__init__.py +1 -1
  60. google/adk/memory/_utils.py +23 -0
  61. google/adk/memory/base_memory_service.py +23 -21
  62. google/adk/memory/in_memory_memory_service.py +57 -25
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  65. google/adk/models/anthropic_llm.py +16 -9
  66. google/adk/models/base_llm.py +2 -1
  67. google/adk/models/base_llm_connection.py +2 -0
  68. google/adk/models/gemini_llm_connection.py +11 -11
  69. google/adk/models/google_llm.py +12 -2
  70. google/adk/models/lite_llm.py +80 -23
  71. google/adk/models/llm_response.py +16 -3
  72. google/adk/models/registry.py +1 -1
  73. google/adk/runners.py +98 -42
  74. google/adk/sessions/__init__.py +1 -1
  75. google/adk/sessions/_session_util.py +2 -1
  76. google/adk/sessions/base_session_service.py +6 -33
  77. google/adk/sessions/database_session_service.py +57 -67
  78. google/adk/sessions/in_memory_session_service.py +106 -24
  79. google/adk/sessions/session.py +3 -0
  80. google/adk/sessions/vertex_ai_session_service.py +44 -51
  81. google/adk/telemetry.py +7 -2
  82. google/adk/tools/__init__.py +4 -7
  83. google/adk/tools/_memory_entry_utils.py +30 -0
  84. google/adk/tools/agent_tool.py +10 -10
  85. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  86. google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
  87. google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +96 -0
  93. google/adk/tools/bigquery/__init__.py +28 -0
  94. google/adk/tools/bigquery/bigquery_credentials.py +216 -0
  95. google/adk/tools/bigquery/bigquery_tool.py +116 -0
  96. google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
  97. google/adk/tools/function_parameter_parse_util.py +9 -2
  98. google/adk/tools/function_tool.py +33 -3
  99. google/adk/tools/get_user_choice_tool.py +1 -0
  100. google/adk/tools/google_api_tool/__init__.py +24 -70
  101. google/adk/tools/google_api_tool/google_api_tool.py +12 -6
  102. google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
  103. google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
  104. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  105. google/adk/tools/google_search_tool.py +2 -2
  106. google/adk/tools/langchain_tool.py +96 -49
  107. google/adk/tools/load_memory_tool.py +14 -5
  108. google/adk/tools/mcp_tool/__init__.py +3 -2
  109. google/adk/tools/mcp_tool/conversion_utils.py +6 -2
  110. google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
  111. google/adk/tools/mcp_tool/mcp_tool.py +35 -32
  112. google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
  113. google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
  114. google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
  115. google/adk/tools/openapi_tool/common/common.py +5 -1
  116. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
  117. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
  118. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
  119. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
  120. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  121. google/adk/tools/preload_memory_tool.py +27 -18
  122. google/adk/tools/retrieval/__init__.py +1 -1
  123. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  124. google/adk/tools/toolbox_toolset.py +107 -0
  125. google/adk/tools/transfer_to_agent_tool.py +0 -1
  126. google/adk/utils/__init__.py +13 -0
  127. google/adk/utils/instructions_utils.py +131 -0
  128. google/adk/version.py +1 -1
  129. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
  130. google_adk-1.1.0.dist-info/RECORD +200 -0
  131. google/adk/agents/remote_agent.py +0 -50
  132. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
  133. google/adk/cli/fast_api.py.orig +0 -728
  134. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  135. google/adk/tools/toolbox_tool.py +0 -46
  136. google_adk-0.5.0.dist-info/RECORD +0 -180
  137. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
  138. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
  139. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py CHANGED
@@ -21,8 +21,8 @@ import threading
21
21
  from typing import AsyncGenerator
22
22
  from typing import Generator
23
23
  from typing import Optional
24
+ import warnings
24
25
 
25
- from deprecated import deprecated
26
26
  from google.genai import types
27
27
 
28
28
  from .agents.active_streaming_tool import ActiveStreamingTool
@@ -32,9 +32,9 @@ from .agents.invocation_context import new_invocation_context_id
32
32
  from .agents.live_request_queue import LiveRequestQueue
33
33
  from .agents.llm_agent import LlmAgent
34
34
  from .agents.run_config import RunConfig
35
- from .agents.run_config import StreamingMode
36
35
  from .artifacts.base_artifact_service import BaseArtifactService
37
36
  from .artifacts.in_memory_artifact_service import InMemoryArtifactService
37
+ from .code_executors.built_in_code_executor import BuiltInCodeExecutor
38
38
  from .events.event import Event
39
39
  from .memory.base_memory_service import BaseMemoryService
40
40
  from .memory.in_memory_memory_service import InMemoryMemoryService
@@ -42,9 +42,9 @@ from .sessions.base_session_service import BaseSessionService
42
42
  from .sessions.in_memory_session_service import InMemorySessionService
43
43
  from .sessions.session import Session
44
44
  from .telemetry import tracer
45
- from .tools.built_in_code_execution_tool import built_in_code_execution
45
+ from .tools.base_toolset import BaseToolset
46
46
 
47
- logger = logging.getLogger(__name__)
47
+ logger = logging.getLogger('google_adk.' + __name__)
48
48
 
49
49
 
50
50
  class Runner:
@@ -172,7 +172,7 @@ class Runner:
172
172
  The events generated by the agent.
173
173
  """
174
174
  with tracer.start_as_current_span('invocation'):
175
- session = self.session_service.get_session(
175
+ session = await self.session_service.get_session(
176
176
  app_name=self.app_name, user_id=user_id, session_id=session_id
177
177
  )
178
178
  if not session:
@@ -196,7 +196,7 @@ class Runner:
196
196
  invocation_context.agent = self._find_agent_to_run(session, root_agent)
197
197
  async for event in invocation_context.agent.run_async(invocation_context):
198
198
  if not event.partial:
199
- self.session_service.append_event(session=session, event=event)
199
+ await self.session_service.append_event(session=session, event=event)
200
200
  yield event
201
201
 
202
202
  async def _append_new_message_to_session(
@@ -241,30 +241,57 @@ class Runner:
241
241
  author='user',
242
242
  content=new_message,
243
243
  )
244
- self.session_service.append_event(session=session, event=event)
244
+ await self.session_service.append_event(session=session, event=event)
245
245
 
246
246
  async def run_live(
247
247
  self,
248
248
  *,
249
- session: Session,
249
+ user_id: Optional[str] = None,
250
+ session_id: Optional[str] = None,
250
251
  live_request_queue: LiveRequestQueue,
251
252
  run_config: RunConfig = RunConfig(),
253
+ session: Optional[Session] = None,
252
254
  ) -> AsyncGenerator[Event, None]:
253
255
  """Runs the agent in live mode (experimental feature).
254
256
 
255
257
  Args:
256
- session: The session to use.
258
+ user_id: The user ID for the session. Required if `session` is None.
259
+ session_id: The session ID for the session. Required if `session` is
260
+ None.
257
261
  live_request_queue: The queue for live requests.
258
262
  run_config: The run config for the agent.
263
+ session: The session to use. This parameter is deprecated, please use
264
+ `user_id` and `session_id` instead.
259
265
 
260
266
  Yields:
261
- The events generated by the agent.
267
+ AsyncGenerator[Event, None]: An asynchronous generator that yields
268
+ `Event`
269
+ objects as they are produced by the agent during its live execution.
262
270
 
263
271
  .. warning::
264
272
  This feature is **experimental** and its API or behavior may change
265
273
  in future releases.
274
+
275
+ .. note::
276
+ Either `session` or both `user_id` and `session_id` must be provided.
266
277
  """
267
- # TODO: right now, only works for a single audio agent without FC.
278
+ if session is None and (user_id is None or session_id is None):
279
+ raise ValueError(
280
+ 'Either session or user_id and session_id must be provided.'
281
+ )
282
+ if session is not None:
283
+ warnings.warn(
284
+ 'The `session` parameter is deprecated. Please use `user_id` and'
285
+ ' `session_id` instead.',
286
+ DeprecationWarning,
287
+ stacklevel=2,
288
+ )
289
+ if not session:
290
+ session = await self.session_service.get_session(
291
+ app_name=self.app_name, user_id=user_id, session_id=session_id
292
+ )
293
+ if not session:
294
+ raise ValueError(f'Session not found: {session_id}')
268
295
  invocation_context = self._new_invocation_context_for_live(
269
296
  session,
270
297
  live_request_queue=live_request_queue,
@@ -276,37 +303,29 @@ class Runner:
276
303
 
277
304
  invocation_context.active_streaming_tools = {}
278
305
  # TODO(hangfei): switch to use canonical_tools.
279
- for tool in invocation_context.agent.tools:
280
- # replicate a LiveRequestQueue for streaming tools that relis on
281
- # LiveRequestQueue
282
- from typing import get_type_hints
283
-
284
- type_hints = get_type_hints(tool)
285
- for arg_type in type_hints.values():
286
- if arg_type is LiveRequestQueue:
287
- if not invocation_context.active_streaming_tools:
288
- invocation_context.active_streaming_tools = {}
289
- active_streaming_tools = ActiveStreamingTool(
290
- stream=LiveRequestQueue()
291
- )
292
- invocation_context.active_streaming_tools[tool.__name__] = (
293
- active_streaming_tools
294
- )
306
+ # for shell agents, there is no tools associated with it so we should skip.
307
+ if hasattr(invocation_context.agent, 'tools'):
308
+ for tool in invocation_context.agent.tools:
309
+ # replicate a LiveRequestQueue for streaming tools that relis on
310
+ # LiveRequestQueue
311
+ from typing import get_type_hints
312
+
313
+ type_hints = get_type_hints(tool)
314
+ for arg_type in type_hints.values():
315
+ if arg_type is LiveRequestQueue:
316
+ if not invocation_context.active_streaming_tools:
317
+ invocation_context.active_streaming_tools = {}
318
+ active_streaming_tools = ActiveStreamingTool(
319
+ stream=LiveRequestQueue()
320
+ )
321
+ invocation_context.active_streaming_tools[tool.__name__] = (
322
+ active_streaming_tools
323
+ )
295
324
 
296
325
  async for event in invocation_context.agent.run_live(invocation_context):
297
- self.session_service.append_event(session=session, event=event)
326
+ await self.session_service.append_event(session=session, event=event)
298
327
  yield event
299
328
 
300
- async def close_session(self, session: Session):
301
- """Closes a session and adds it to the memory service (experimental feature).
302
-
303
- Args:
304
- session: The session to close.
305
- """
306
- if self.memory_service:
307
- await self.memory_service.add_session_to_memory(session)
308
- self.session_service.close_session(session=session)
309
-
310
329
  def _find_agent_to_run(
311
330
  self, session: Session, root_agent: BaseAgent
312
331
  ) -> BaseAgent:
@@ -391,8 +410,8 @@ class Runner:
391
410
  f'CFC is not supported for model: {model_name} in agent:'
392
411
  f' {self.agent.name}'
393
412
  )
394
- if built_in_code_execution not in self.agent.canonical_tools:
395
- self.agent.tools.append(built_in_code_execution)
413
+ if not isinstance(self.agent.code_executor, BuiltInCodeExecutor):
414
+ self.agent.code_executor = BuiltInCodeExecutor()
396
415
 
397
416
  return InvocationContext(
398
417
  artifact_service=self.artifact_service,
@@ -430,12 +449,46 @@ class Runner:
430
449
  run_config.output_audio_transcription = (
431
450
  types.AudioTranscriptionConfig()
432
451
  )
452
+ if not run_config.input_audio_transcription:
453
+ # need this input transcription for agent transferring in live mode.
454
+ run_config.input_audio_transcription = types.AudioTranscriptionConfig()
433
455
  return self._new_invocation_context(
434
456
  session,
435
457
  live_request_queue=live_request_queue,
436
458
  run_config=run_config,
437
459
  )
438
460
 
461
+ def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
462
+ toolsets = set()
463
+ if isinstance(agent, LlmAgent):
464
+ for tool_union in agent.tools:
465
+ if isinstance(tool_union, BaseToolset):
466
+ toolsets.add(tool_union)
467
+ for sub_agent in agent.sub_agents:
468
+ toolsets.update(self._collect_toolset(sub_agent))
469
+ return toolsets
470
+
471
+ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
472
+ """Clean up toolsets with proper task context management."""
473
+ if not toolsets_to_close:
474
+ return
475
+
476
+ # This maintains the same task context throughout cleanup
477
+ for toolset in toolsets_to_close:
478
+ try:
479
+ logger.info('Closing toolset: %s', type(toolset).__name__)
480
+ # Use asyncio.wait_for to add timeout protection
481
+ await asyncio.wait_for(toolset.close(), timeout=10.0)
482
+ logger.info('Successfully closed toolset: %s', type(toolset).__name__)
483
+ except asyncio.TimeoutError:
484
+ logger.warning('Toolset %s cleanup timed out', type(toolset).__name__)
485
+ except Exception as e:
486
+ logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)
487
+
488
+ async def close(self):
489
+ """Closes the runner."""
490
+ await self._cleanup_toolsets(self._collect_toolset(self.agent))
491
+
439
492
 
440
493
  class InMemoryRunner(Runner):
441
494
  """An in-memory Runner for testing and development.
@@ -448,9 +501,11 @@ class InMemoryRunner(Runner):
448
501
  agent: The root agent to run.
449
502
  app_name: The application name of the runner. Defaults to
450
503
  'InMemoryRunner'.
504
+ _in_memory_session_service: Deprecated. Please don't use. The in-memory
505
+ session service for the runner.
451
506
  """
452
507
 
453
- def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'):
508
+ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
454
509
  """Initializes the InMemoryRunner.
455
510
 
456
511
  Args:
@@ -458,10 +513,11 @@ class InMemoryRunner(Runner):
458
513
  app_name: The application name of the runner. Defaults to
459
514
  'InMemoryRunner'.
460
515
  """
516
+ self._in_memory_session_service = InMemorySessionService()
461
517
  super().__init__(
462
518
  app_name=app_name,
463
519
  agent=agent,
464
520
  artifact_service=InMemoryArtifactService(),
465
- session_service=InMemorySessionService(),
521
+ session_service=self._in_memory_session_service,
466
522
  memory_service=InMemoryMemoryService(),
467
523
  )
@@ -19,7 +19,7 @@ from .session import Session
19
19
  from .state import State
20
20
  from .vertex_ai_session_service import VertexAiSessionService
21
21
 
22
- logger = logging.getLogger(__name__)
22
+ logger = logging.getLogger('google_adk.' + __name__)
23
23
 
24
24
 
25
25
  __all__ = [
@@ -15,7 +15,8 @@
15
15
  """Utility functions for session service."""
16
16
 
17
17
  import base64
18
- from typing import Any, Optional
18
+ from typing import Any
19
+ from typing import Optional
19
20
 
20
21
  from google.genai import types
21
22
 
@@ -40,13 +40,6 @@ class ListSessionsResponse(BaseModel):
40
40
  sessions: list[Session] = Field(default_factory=list)
41
41
 
42
42
 
43
- class ListEventsResponse(BaseModel):
44
- """The response of listing events in a session."""
45
-
46
- events: list[Event] = Field(default_factory=list)
47
- next_page_token: Optional[str] = None
48
-
49
-
50
43
  class BaseSessionService(abc.ABC):
51
44
  """Base class for session services.
52
45
 
@@ -54,7 +47,7 @@ class BaseSessionService(abc.ABC):
54
47
  """
55
48
 
56
49
  @abc.abstractmethod
57
- def create_session(
50
+ async def create_session(
58
51
  self,
59
52
  *,
60
53
  app_name: str,
@@ -74,10 +67,9 @@ class BaseSessionService(abc.ABC):
74
67
  Returns:
75
68
  session: The newly created session instance.
76
69
  """
77
- pass
78
70
 
79
71
  @abc.abstractmethod
80
- def get_session(
72
+ async def get_session(
81
73
  self,
82
74
  *,
83
75
  app_name: str,
@@ -86,39 +78,20 @@ class BaseSessionService(abc.ABC):
86
78
  config: Optional[GetSessionConfig] = None,
87
79
  ) -> Optional[Session]:
88
80
  """Gets a session."""
89
- pass
90
81
 
91
82
  @abc.abstractmethod
92
- def list_sessions(
83
+ async def list_sessions(
93
84
  self, *, app_name: str, user_id: str
94
85
  ) -> ListSessionsResponse:
95
86
  """Lists all the sessions."""
96
- pass
97
87
 
98
88
  @abc.abstractmethod
99
- def delete_session(
89
+ async def delete_session(
100
90
  self, *, app_name: str, user_id: str, session_id: str
101
91
  ) -> None:
102
92
  """Deletes a session."""
103
- pass
104
-
105
- @abc.abstractmethod
106
- def list_events(
107
- self,
108
- *,
109
- app_name: str,
110
- user_id: str,
111
- session_id: str,
112
- ) -> ListEventsResponse:
113
- """Lists events in a session."""
114
- pass
115
-
116
- def close_session(self, *, session: Session):
117
- """Closes a session."""
118
- # TODO: determine whether we want to finalize the session here.
119
- pass
120
93
 
121
- def append_event(self, session: Session, event: Event) -> Event:
94
+ async def append_event(self, session: Session, event: Event) -> Event:
122
95
  """Appends an event to a session object."""
123
96
  if event.partial:
124
97
  return event
@@ -126,7 +99,7 @@ class BaseSessionService(abc.ABC):
126
99
  session.events.append(event)
127
100
  return event
128
101
 
129
- def __update_session_state(self, session: Session, event: Event):
102
+ def __update_session_state(self, session: Session, event: Event) -> None:
130
103
  """Updates the session state based on the event."""
131
104
  if not event.actions or not event.actions.state_delta:
132
105
  return
@@ -15,7 +15,8 @@ import copy
15
15
  from datetime import datetime
16
16
  import json
17
17
  import logging
18
- from typing import Any, Optional
18
+ from typing import Any
19
+ from typing import Optional
19
20
  import uuid
20
21
 
21
22
  from sqlalchemy import Boolean
@@ -45,27 +46,22 @@ from sqlalchemy.types import TypeDecorator
45
46
  from typing_extensions import override
46
47
  from tzlocal import get_localzone
47
48
 
48
- from ..events.event import Event
49
49
  from . import _session_util
50
+ from ..events.event import Event
50
51
  from .base_session_service import BaseSessionService
51
52
  from .base_session_service import GetSessionConfig
52
- from .base_session_service import ListEventsResponse
53
53
  from .base_session_service import ListSessionsResponse
54
54
  from .session import Session
55
55
  from .state import State
56
56
 
57
-
58
- logger = logging.getLogger(__name__)
57
+ logger = logging.getLogger("google_adk." + __name__)
59
58
 
60
59
  DEFAULT_MAX_KEY_LENGTH = 128
61
60
  DEFAULT_MAX_VARCHAR_LENGTH = 256
62
61
 
63
62
 
64
63
  class DynamicJSON(TypeDecorator):
65
- """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
66
-
67
- serialization for other databases.
68
- """
64
+ """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
69
65
 
70
66
  impl = Text # Default implementation is TEXT
71
67
 
@@ -243,10 +239,7 @@ class DatabaseSessionService(BaseSessionService):
243
239
  """A session service that uses a database for storage."""
244
240
 
245
241
  def __init__(self, db_url: str):
246
- """
247
- Args:
248
- db_url: The database URL to connect to.
249
- """
242
+ """Initializes the database session service with a database URL."""
250
243
  # 1. Create DB engine for db connection
251
244
  # 2. Create all tables based on schema
252
245
  # 3. Initialize all properties
@@ -275,7 +268,7 @@ class DatabaseSessionService(BaseSessionService):
275
268
  self.inspector = inspect(self.db_engine)
276
269
 
277
270
  # DB session factory method
278
- self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
271
+ self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
279
272
  sessionmaker(bind=self.db_engine)
280
273
  )
281
274
 
@@ -284,7 +277,7 @@ class DatabaseSessionService(BaseSessionService):
284
277
  Base.metadata.create_all(self.db_engine)
285
278
 
286
279
  @override
287
- def create_session(
280
+ async def create_session(
288
281
  self,
289
282
  *,
290
283
  app_name: str,
@@ -298,11 +291,11 @@ class DatabaseSessionService(BaseSessionService):
298
291
  # 4. Build the session object with generated id
299
292
  # 5. Return the session
300
293
 
301
- with self.DatabaseSessionFactory() as sessionFactory:
294
+ with self.database_session_factory() as session_factory:
302
295
 
303
296
  # Fetch app and user states from storage
304
- storage_app_state = sessionFactory.get(StorageAppState, (app_name))
305
- storage_user_state = sessionFactory.get(
297
+ storage_app_state = session_factory.get(StorageAppState, (app_name))
298
+ storage_user_state = session_factory.get(
306
299
  StorageUserState, (app_name, user_id)
307
300
  )
308
301
 
@@ -312,12 +305,12 @@ class DatabaseSessionService(BaseSessionService):
312
305
  # Create state tables if not exist
313
306
  if not storage_app_state:
314
307
  storage_app_state = StorageAppState(app_name=app_name, state={})
315
- sessionFactory.add(storage_app_state)
308
+ session_factory.add(storage_app_state)
316
309
  if not storage_user_state:
317
310
  storage_user_state = StorageUserState(
318
311
  app_name=app_name, user_id=user_id, state={}
319
312
  )
320
- sessionFactory.add(storage_user_state)
313
+ session_factory.add(storage_user_state)
321
314
 
322
315
  # Extract state deltas
323
316
  app_state_delta, user_state_delta, session_state = _extract_state_delta(
@@ -341,10 +334,10 @@ class DatabaseSessionService(BaseSessionService):
341
334
  id=session_id,
342
335
  state=session_state,
343
336
  )
344
- sessionFactory.add(storage_session)
345
- sessionFactory.commit()
337
+ session_factory.add(storage_session)
338
+ session_factory.commit()
346
339
 
347
- sessionFactory.refresh(storage_session)
340
+ session_factory.refresh(storage_session)
348
341
 
349
342
  # Merge states for response
350
343
  merged_state = _merge_state(app_state, user_state, session_state)
@@ -358,7 +351,7 @@ class DatabaseSessionService(BaseSessionService):
358
351
  return session
359
352
 
360
353
  @override
361
- def get_session(
354
+ async def get_session(
362
355
  self,
363
356
  *,
364
357
  app_name: str,
@@ -369,29 +362,35 @@ class DatabaseSessionService(BaseSessionService):
369
362
  # 1. Get the storage session entry from session table
370
363
  # 2. Get all the events based on session id and filtering config
371
364
  # 3. Convert and return the session
372
- with self.DatabaseSessionFactory() as sessionFactory:
373
- storage_session = sessionFactory.get(
365
+ with self.database_session_factory() as session_factory:
366
+ storage_session = session_factory.get(
374
367
  StorageSession, (app_name, user_id, session_id)
375
368
  )
376
369
  if storage_session is None:
377
370
  return None
378
371
 
372
+ if config and config.after_timestamp:
373
+ after_dt = datetime.fromtimestamp(config.after_timestamp)
374
+ timestamp_filter = StorageEvent.timestamp >= after_dt
375
+ else:
376
+ timestamp_filter = True
377
+
379
378
  storage_events = (
380
- sessionFactory.query(StorageEvent)
379
+ session_factory.query(StorageEvent)
381
380
  .filter(StorageEvent.session_id == storage_session.id)
382
- .filter(
383
- StorageEvent.timestamp < config.after_timestamp
384
- if config
385
- else True
381
+ .filter(timestamp_filter)
382
+ .order_by(StorageEvent.timestamp.desc())
383
+ .limit(
384
+ config.num_recent_events
385
+ if config and config.num_recent_events
386
+ else None
386
387
  )
387
- .limit(config.num_recent_events if config else None)
388
- .order_by(StorageEvent.timestamp.asc())
389
388
  .all()
390
389
  )
391
390
 
392
391
  # Fetch states from storage
393
- storage_app_state = sessionFactory.get(StorageAppState, (app_name))
394
- storage_user_state = sessionFactory.get(
392
+ storage_app_state = session_factory.get(StorageAppState, (app_name))
393
+ storage_user_state = session_factory.get(
395
394
  StorageUserState, (app_name, user_id)
396
395
  )
397
396
 
@@ -427,17 +426,17 @@ class DatabaseSessionService(BaseSessionService):
427
426
  error_message=e.error_message,
428
427
  interrupted=e.interrupted,
429
428
  )
430
- for e in storage_events
429
+ for e in reversed(storage_events)
431
430
  ]
432
431
  return session
433
432
 
434
433
  @override
435
- def list_sessions(
434
+ async def list_sessions(
436
435
  self, *, app_name: str, user_id: str
437
436
  ) -> ListSessionsResponse:
438
- with self.DatabaseSessionFactory() as sessionFactory:
437
+ with self.database_session_factory() as session_factory:
439
438
  results = (
440
- sessionFactory.query(StorageSession)
439
+ session_factory.query(StorageSession)
441
440
  .filter(StorageSession.app_name == app_name)
442
441
  .filter(StorageSession.user_id == user_id)
443
442
  .all()
@@ -455,20 +454,20 @@ class DatabaseSessionService(BaseSessionService):
455
454
  return ListSessionsResponse(sessions=sessions)
456
455
 
457
456
  @override
458
- def delete_session(
457
+ async def delete_session(
459
458
  self, app_name: str, user_id: str, session_id: str
460
459
  ) -> None:
461
- with self.DatabaseSessionFactory() as sessionFactory:
460
+ with self.database_session_factory() as session_factory:
462
461
  stmt = delete(StorageSession).where(
463
462
  StorageSession.app_name == app_name,
464
463
  StorageSession.user_id == user_id,
465
464
  StorageSession.id == session_id,
466
465
  )
467
- sessionFactory.execute(stmt)
468
- sessionFactory.commit()
466
+ session_factory.execute(stmt)
467
+ session_factory.commit()
469
468
 
470
469
  @override
471
- def append_event(self, session: Session, event: Event) -> Event:
470
+ async def append_event(self, session: Session, event: Event) -> Event:
472
471
  logger.info(f"Append event: {event} to session {session.id}")
473
472
 
474
473
  if event.partial:
@@ -477,24 +476,25 @@ class DatabaseSessionService(BaseSessionService):
477
476
  # 1. Check if timestamp is stale
478
477
  # 2. Update session attributes based on event config
479
478
  # 3. Store event to table
480
- with self.DatabaseSessionFactory() as sessionFactory:
481
- storage_session = sessionFactory.get(
479
+ with self.database_session_factory() as session_factory:
480
+ storage_session = session_factory.get(
482
481
  StorageSession, (session.app_name, session.user_id, session.id)
483
482
  )
484
483
 
485
484
  if storage_session.update_time.timestamp() > session.last_update_time:
486
485
  raise ValueError(
487
- f"Session last_update_time "
488
- f"{datetime.fromtimestamp(session.last_update_time):%Y-%m-%d %H:%M:%S} "
489
- f"is later than the update_time in storage "
490
- f"{storage_session.update_time:%Y-%m-%d %H:%M:%S}"
491
- )
486
+ "The last_update_time provided in the session object"
487
+ f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
488
+ " earlier than the update_time in the storage_session"
489
+ f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check"
490
+ " if it is a stale session."
491
+ )
492
492
 
493
493
  # Fetch states from storage
494
- storage_app_state = sessionFactory.get(
494
+ storage_app_state = session_factory.get(
495
495
  StorageAppState, (session.app_name)
496
496
  )
497
- storage_user_state = sessionFactory.get(
497
+ storage_user_state = session_factory.get(
498
498
  StorageUserState, (session.app_name, session.user_id)
499
499
  )
500
500
 
@@ -543,28 +543,18 @@ class DatabaseSessionService(BaseSessionService):
543
543
  if event.content:
544
544
  storage_event.content = _session_util.encode_content(event.content)
545
545
 
546
- sessionFactory.add(storage_event)
546
+ session_factory.add(storage_event)
547
547
 
548
- sessionFactory.commit()
549
- sessionFactory.refresh(storage_session)
548
+ session_factory.commit()
549
+ session_factory.refresh(storage_session)
550
550
 
551
551
  # Update timestamp with commit time
552
552
  session.last_update_time = storage_session.update_time.timestamp()
553
553
 
554
554
  # Also update the in-memory session
555
- super().append_event(session=session, event=event)
555
+ await super().append_event(session=session, event=event)
556
556
  return event
557
557
 
558
- @override
559
- def list_events(
560
- self,
561
- *,
562
- app_name: str,
563
- user_id: str,
564
- session_id: str,
565
- ) -> ListEventsResponse:
566
- raise NotImplementedError()
567
-
568
558
 
569
559
  def convert_event(event: StorageEvent) -> Event:
570
560
  """Converts a storage event to an event."""