google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/base_agent.py.orig +330 -0
  3. google/adk/agents/callback_context.py +0 -5
  4. google/adk/agents/llm_agent.py +122 -30
  5. google/adk/agents/loop_agent.py +1 -1
  6. google/adk/agents/parallel_agent.py +7 -0
  7. google/adk/agents/readonly_context.py +7 -1
  8. google/adk/agents/run_config.py +1 -1
  9. google/adk/agents/sequential_agent.py +31 -0
  10. google/adk/agents/transcription_entry.py +4 -2
  11. google/adk/artifacts/gcs_artifact_service.py +1 -1
  12. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  13. google/adk/auth/auth_credential.py +6 -1
  14. google/adk/auth/auth_preprocessor.py +7 -1
  15. google/adk/auth/auth_tool.py +3 -4
  16. google/adk/cli/agent_graph.py +5 -5
  17. google/adk/cli/browser/index.html +2 -2
  18. google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
  19. google/adk/cli/cli.py +7 -7
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +172 -99
  22. google/adk/cli/cli_tools_click.py +147 -64
  23. google/adk/cli/fast_api.py +330 -148
  24. google/adk/cli/fast_api.py.orig +174 -80
  25. google/adk/cli/utils/common.py +23 -0
  26. google/adk/cli/utils/evals.py +83 -1
  27. google/adk/cli/utils/logs.py +13 -5
  28. google/adk/code_executors/__init__.py +3 -1
  29. google/adk/code_executors/built_in_code_executor.py +52 -0
  30. google/adk/evaluation/__init__.py +1 -1
  31. google/adk/evaluation/agent_evaluator.py +168 -128
  32. google/adk/evaluation/eval_case.py +102 -0
  33. google/adk/evaluation/eval_set.py +37 -0
  34. google/adk/evaluation/eval_sets_manager.py +42 -0
  35. google/adk/evaluation/evaluation_generator.py +88 -113
  36. google/adk/evaluation/evaluator.py +56 -0
  37. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  38. google/adk/evaluation/response_evaluator.py +106 -2
  39. google/adk/evaluation/trajectory_evaluator.py +83 -2
  40. google/adk/events/event.py +6 -1
  41. google/adk/events/event_actions.py +6 -1
  42. google/adk/examples/example_util.py +3 -2
  43. google/adk/flows/llm_flows/_code_execution.py +9 -1
  44. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  45. google/adk/flows/llm_flows/base_llm_flow.py +54 -15
  46. google/adk/flows/llm_flows/functions.py +9 -8
  47. google/adk/flows/llm_flows/instructions.py +13 -5
  48. google/adk/flows/llm_flows/single_flow.py +1 -1
  49. google/adk/memory/__init__.py +1 -1
  50. google/adk/memory/_utils.py +23 -0
  51. google/adk/memory/base_memory_service.py +23 -21
  52. google/adk/memory/base_memory_service.py.orig +76 -0
  53. google/adk/memory/in_memory_memory_service.py +57 -25
  54. google/adk/memory/memory_entry.py +37 -0
  55. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  56. google/adk/models/anthropic_llm.py +16 -9
  57. google/adk/models/gemini_llm_connection.py +11 -11
  58. google/adk/models/google_llm.py +9 -2
  59. google/adk/models/google_llm.py.orig +305 -0
  60. google/adk/models/lite_llm.py +77 -21
  61. google/adk/models/llm_response.py +14 -2
  62. google/adk/models/registry.py +1 -1
  63. google/adk/runners.py +65 -41
  64. google/adk/sessions/__init__.py +1 -1
  65. google/adk/sessions/base_session_service.py +6 -33
  66. google/adk/sessions/database_session_service.py +58 -65
  67. google/adk/sessions/in_memory_session_service.py +106 -24
  68. google/adk/sessions/session.py +3 -0
  69. google/adk/sessions/vertex_ai_session_service.py +23 -45
  70. google/adk/telemetry.py +3 -0
  71. google/adk/tools/__init__.py +4 -7
  72. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  73. google/adk/tools/_memory_entry_utils.py +30 -0
  74. google/adk/tools/agent_tool.py +9 -9
  75. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  76. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  77. google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
  78. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  79. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  80. google/adk/tools/base_toolset.py +58 -0
  81. google/adk/tools/enterprise_search_tool.py +65 -0
  82. google/adk/tools/function_parameter_parse_util.py +2 -2
  83. google/adk/tools/google_api_tool/__init__.py +18 -70
  84. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  85. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  86. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  87. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  88. google/adk/tools/langchain_tool.py +96 -49
  89. google/adk/tools/load_memory_tool.py +14 -5
  90. google/adk/tools/mcp_tool/__init__.py +3 -2
  91. google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
  92. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  93. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  94. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  95. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  96. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
  97. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  98. google/adk/tools/preload_memory_tool.py +27 -18
  99. google/adk/tools/retrieval/__init__.py +1 -1
  100. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  101. google/adk/tools/toolbox_toolset.py +79 -0
  102. google/adk/tools/transfer_to_agent_tool.py +0 -1
  103. google/adk/version.py +1 -1
  104. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  105. google_adk-1.0.0.dist-info/RECORD +195 -0
  106. google/adk/agents/remote_agent.py +0 -50
  107. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  108. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  109. google/adk/tools/toolbox_tool.py +0 -46
  110. google_adk-0.5.0.dist-info/RECORD +0 -180
  111. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  112. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  113. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py CHANGED
@@ -21,8 +21,8 @@ import threading
21
21
  from typing import AsyncGenerator
22
22
  from typing import Generator
23
23
  from typing import Optional
24
+ import warnings
24
25
 
25
- from deprecated import deprecated
26
26
  from google.genai import types
27
27
 
28
28
  from .agents.active_streaming_tool import ActiveStreamingTool
@@ -32,7 +32,6 @@ from .agents.invocation_context import new_invocation_context_id
32
32
  from .agents.live_request_queue import LiveRequestQueue
33
33
  from .agents.llm_agent import LlmAgent
34
34
  from .agents.run_config import RunConfig
35
- from .agents.run_config import StreamingMode
36
35
  from .artifacts.base_artifact_service import BaseArtifactService
37
36
  from .artifacts.in_memory_artifact_service import InMemoryArtifactService
38
37
  from .events.event import Event
@@ -42,9 +41,9 @@ from .sessions.base_session_service import BaseSessionService
42
41
  from .sessions.in_memory_session_service import InMemorySessionService
43
42
  from .sessions.session import Session
44
43
  from .telemetry import tracer
45
- from .tools.built_in_code_execution_tool import built_in_code_execution
44
+ from .tools._built_in_code_execution_tool import built_in_code_execution
46
45
 
47
- logger = logging.getLogger(__name__)
46
+ logger = logging.getLogger('google_adk.' + __name__)
48
47
 
49
48
 
50
49
  class Runner:
@@ -172,7 +171,7 @@ class Runner:
172
171
  The events generated by the agent.
173
172
  """
174
173
  with tracer.start_as_current_span('invocation'):
175
- session = self.session_service.get_session(
174
+ session = await self.session_service.get_session(
176
175
  app_name=self.app_name, user_id=user_id, session_id=session_id
177
176
  )
178
177
  if not session:
@@ -196,7 +195,7 @@ class Runner:
196
195
  invocation_context.agent = self._find_agent_to_run(session, root_agent)
197
196
  async for event in invocation_context.agent.run_async(invocation_context):
198
197
  if not event.partial:
199
- self.session_service.append_event(session=session, event=event)
198
+ await self.session_service.append_event(session=session, event=event)
200
199
  yield event
201
200
 
202
201
  async def _append_new_message_to_session(
@@ -241,30 +240,57 @@ class Runner:
241
240
  author='user',
242
241
  content=new_message,
243
242
  )
244
- self.session_service.append_event(session=session, event=event)
243
+ await self.session_service.append_event(session=session, event=event)
245
244
 
246
245
  async def run_live(
247
246
  self,
248
247
  *,
249
- session: Session,
248
+ user_id: Optional[str] = None,
249
+ session_id: Optional[str] = None,
250
250
  live_request_queue: LiveRequestQueue,
251
251
  run_config: RunConfig = RunConfig(),
252
+ session: Optional[Session] = None,
252
253
  ) -> AsyncGenerator[Event, None]:
253
254
  """Runs the agent in live mode (experimental feature).
254
255
 
255
256
  Args:
256
- session: The session to use.
257
+ user_id: The user ID for the session. Required if `session` is None.
258
+ session_id: The session ID for the session. Required if `session` is
259
+ None.
257
260
  live_request_queue: The queue for live requests.
258
261
  run_config: The run config for the agent.
262
+ session: The session to use. This parameter is deprecated, please use
263
+ `user_id` and `session_id` instead.
259
264
 
260
265
  Yields:
261
- The events generated by the agent.
266
+ AsyncGenerator[Event, None]: An asynchronous generator that yields
267
+ `Event`
268
+ objects as they are produced by the agent during its live execution.
262
269
 
263
270
  .. warning::
264
271
  This feature is **experimental** and its API or behavior may change
265
272
  in future releases.
273
+
274
+ .. note::
275
+ Either `session` or both `user_id` and `session_id` must be provided.
266
276
  """
267
- # TODO: right now, only works for a single audio agent without FC.
277
+ if session is None and (user_id is None or session_id is None):
278
+ raise ValueError(
279
+ 'Either session or user_id and session_id must be provided.'
280
+ )
281
+ if session is not None:
282
+ warnings.warn(
283
+ 'The `session` parameter is deprecated. Please use `user_id` and'
284
+ ' `session_id` instead.',
285
+ DeprecationWarning,
286
+ stacklevel=2,
287
+ )
288
+ if not session:
289
+ session = self.session_service.get_session(
290
+ app_name=self.app_name, user_id=user_id, session_id=session_id
291
+ )
292
+ if not session:
293
+ raise ValueError(f'Session not found: {session_id}')
268
294
  invocation_context = self._new_invocation_context_for_live(
269
295
  session,
270
296
  live_request_queue=live_request_queue,
@@ -276,37 +302,29 @@ class Runner:
276
302
 
277
303
  invocation_context.active_streaming_tools = {}
278
304
  # TODO(hangfei): switch to use canonical_tools.
279
- for tool in invocation_context.agent.tools:
280
- # replicate a LiveRequestQueue for streaming tools that relis on
281
- # LiveRequestQueue
282
- from typing import get_type_hints
283
-
284
- type_hints = get_type_hints(tool)
285
- for arg_type in type_hints.values():
286
- if arg_type is LiveRequestQueue:
287
- if not invocation_context.active_streaming_tools:
288
- invocation_context.active_streaming_tools = {}
289
- active_streaming_tools = ActiveStreamingTool(
290
- stream=LiveRequestQueue()
291
- )
292
- invocation_context.active_streaming_tools[tool.__name__] = (
293
- active_streaming_tools
294
- )
305
+ # for shell agents, there is no tools associated with it so we should skip.
306
+ if hasattr(invocation_context.agent, 'tools'):
307
+ for tool in invocation_context.agent.tools:
308
+ # replicate a LiveRequestQueue for streaming tools that relis on
309
+ # LiveRequestQueue
310
+ from typing import get_type_hints
311
+
312
+ type_hints = get_type_hints(tool)
313
+ for arg_type in type_hints.values():
314
+ if arg_type is LiveRequestQueue:
315
+ if not invocation_context.active_streaming_tools:
316
+ invocation_context.active_streaming_tools = {}
317
+ active_streaming_tools = ActiveStreamingTool(
318
+ stream=LiveRequestQueue()
319
+ )
320
+ invocation_context.active_streaming_tools[tool.__name__] = (
321
+ active_streaming_tools
322
+ )
295
323
 
296
324
  async for event in invocation_context.agent.run_live(invocation_context):
297
- self.session_service.append_event(session=session, event=event)
325
+ await self.session_service.append_event(session=session, event=event)
298
326
  yield event
299
327
 
300
- async def close_session(self, session: Session):
301
- """Closes a session and adds it to the memory service (experimental feature).
302
-
303
- Args:
304
- session: The session to close.
305
- """
306
- if self.memory_service:
307
- await self.memory_service.add_session_to_memory(session)
308
- self.session_service.close_session(session=session)
309
-
310
328
  def _find_agent_to_run(
311
329
  self, session: Session, root_agent: BaseAgent
312
330
  ) -> BaseAgent:
@@ -391,7 +409,7 @@ class Runner:
391
409
  f'CFC is not supported for model: {model_name} in agent:'
392
410
  f' {self.agent.name}'
393
411
  )
394
- if built_in_code_execution not in self.agent.canonical_tools:
412
+ if built_in_code_execution not in self.agent.canonical_tools():
395
413
  self.agent.tools.append(built_in_code_execution)
396
414
 
397
415
  return InvocationContext(
@@ -430,6 +448,9 @@ class Runner:
430
448
  run_config.output_audio_transcription = (
431
449
  types.AudioTranscriptionConfig()
432
450
  )
451
+ if not run_config.input_audio_transcription:
452
+ # need this input transcription for agent transferring in live mode.
453
+ run_config.input_audio_transcription = types.AudioTranscriptionConfig()
433
454
  return self._new_invocation_context(
434
455
  session,
435
456
  live_request_queue=live_request_queue,
@@ -448,9 +469,11 @@ class InMemoryRunner(Runner):
448
469
  agent: The root agent to run.
449
470
  app_name: The application name of the runner. Defaults to
450
471
  'InMemoryRunner'.
472
+ _in_memory_session_service: Deprecated. Please don't use. The in-memory
473
+ session service for the runner.
451
474
  """
452
475
 
453
- def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'):
476
+ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
454
477
  """Initializes the InMemoryRunner.
455
478
 
456
479
  Args:
@@ -458,10 +481,11 @@ class InMemoryRunner(Runner):
458
481
  app_name: The application name of the runner. Defaults to
459
482
  'InMemoryRunner'.
460
483
  """
484
+ self._in_memory_session_service = InMemorySessionService()
461
485
  super().__init__(
462
486
  app_name=app_name,
463
487
  agent=agent,
464
488
  artifact_service=InMemoryArtifactService(),
465
- session_service=InMemorySessionService(),
489
+ session_service=self._in_memory_session_service,
466
490
  memory_service=InMemoryMemoryService(),
467
491
  )
google/adk/sessions/__init__.py CHANGED
@@ -19,7 +19,7 @@ from .session import Session
19
19
  from .state import State
20
20
  from .vertex_ai_session_service import VertexAiSessionService
21
21
 
22
- logger = logging.getLogger(__name__)
22
+ logger = logging.getLogger('google_adk.' + __name__)
23
23
 
24
24
 
25
25
  __all__ = [
google/adk/sessions/base_session_service.py CHANGED
@@ -40,13 +40,6 @@ class ListSessionsResponse(BaseModel):
40
40
  sessions: list[Session] = Field(default_factory=list)
41
41
 
42
42
 
43
- class ListEventsResponse(BaseModel):
44
- """The response of listing events in a session."""
45
-
46
- events: list[Event] = Field(default_factory=list)
47
- next_page_token: Optional[str] = None
48
-
49
-
50
43
  class BaseSessionService(abc.ABC):
51
44
  """Base class for session services.
52
45
 
@@ -54,7 +47,7 @@ class BaseSessionService(abc.ABC):
54
47
  """
55
48
 
56
49
  @abc.abstractmethod
57
- def create_session(
50
+ async def create_session(
58
51
  self,
59
52
  *,
60
53
  app_name: str,
@@ -74,10 +67,9 @@ class BaseSessionService(abc.ABC):
74
67
  Returns:
75
68
  session: The newly created session instance.
76
69
  """
77
- pass
78
70
 
79
71
  @abc.abstractmethod
80
- def get_session(
72
+ async def get_session(
81
73
  self,
82
74
  *,
83
75
  app_name: str,
@@ -86,39 +78,20 @@ class BaseSessionService(abc.ABC):
86
78
  config: Optional[GetSessionConfig] = None,
87
79
  ) -> Optional[Session]:
88
80
  """Gets a session."""
89
- pass
90
81
 
91
82
  @abc.abstractmethod
92
- def list_sessions(
83
+ async def list_sessions(
93
84
  self, *, app_name: str, user_id: str
94
85
  ) -> ListSessionsResponse:
95
86
  """Lists all the sessions."""
96
- pass
97
87
 
98
88
  @abc.abstractmethod
99
- def delete_session(
89
+ async def delete_session(
100
90
  self, *, app_name: str, user_id: str, session_id: str
101
91
  ) -> None:
102
92
  """Deletes a session."""
103
- pass
104
-
105
- @abc.abstractmethod
106
- def list_events(
107
- self,
108
- *,
109
- app_name: str,
110
- user_id: str,
111
- session_id: str,
112
- ) -> ListEventsResponse:
113
- """Lists events in a session."""
114
- pass
115
-
116
- def close_session(self, *, session: Session):
117
- """Closes a session."""
118
- # TODO: determine whether we want to finalize the session here.
119
- pass
120
93
 
121
- def append_event(self, session: Session, event: Event) -> Event:
94
+ async def append_event(self, session: Session, event: Event) -> Event:
122
95
  """Appends an event to a session object."""
123
96
  if event.partial:
124
97
  return event
@@ -126,7 +99,7 @@ class BaseSessionService(abc.ABC):
126
99
  session.events.append(event)
127
100
  return event
128
101
 
129
- def __update_session_state(self, session: Session, event: Event):
102
+ def __update_session_state(self, session: Session, event: Event) -> None:
130
103
  """Updates the session state based on the event."""
131
104
  if not event.actions or not event.actions.state_delta:
132
105
  return
google/adk/sessions/database_session_service.py CHANGED
@@ -13,9 +13,11 @@
13
13
  # limitations under the License.
14
14
  import copy
15
15
  from datetime import datetime
16
+ from datetime import timezone
16
17
  import json
17
18
  import logging
18
- from typing import Any, Optional
19
+ from typing import Any
20
+ from typing import Optional
19
21
  import uuid
20
22
 
21
23
  from sqlalchemy import Boolean
@@ -49,23 +51,18 @@ from ..events.event import Event
49
51
  from . import _session_util
50
52
  from .base_session_service import BaseSessionService
51
53
  from .base_session_service import GetSessionConfig
52
- from .base_session_service import ListEventsResponse
53
54
  from .base_session_service import ListSessionsResponse
54
55
  from .session import Session
55
56
  from .state import State
56
57
 
57
-
58
- logger = logging.getLogger(__name__)
58
+ logger = logging.getLogger("google_adk." + __name__)
59
59
 
60
60
  DEFAULT_MAX_KEY_LENGTH = 128
61
61
  DEFAULT_MAX_VARCHAR_LENGTH = 256
62
62
 
63
63
 
64
64
  class DynamicJSON(TypeDecorator):
65
- """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
66
-
67
- serialization for other databases.
68
- """
65
+ """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
69
66
 
70
67
  impl = Text # Default implementation is TEXT
71
68
 
@@ -243,10 +240,7 @@ class DatabaseSessionService(BaseSessionService):
243
240
  """A session service that uses a database for storage."""
244
241
 
245
242
  def __init__(self, db_url: str):
246
- """
247
- Args:
248
- db_url: The database URL to connect to.
249
- """
243
+ """Initializes the database session service with a database URL."""
250
244
  # 1. Create DB engine for db connection
251
245
  # 2. Create all tables based on schema
252
246
  # 3. Initialize all properties
@@ -275,7 +269,7 @@ class DatabaseSessionService(BaseSessionService):
275
269
  self.inspector = inspect(self.db_engine)
276
270
 
277
271
  # DB session factory method
278
- self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
272
+ self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
279
273
  sessionmaker(bind=self.db_engine)
280
274
  )
281
275
 
@@ -284,7 +278,7 @@ class DatabaseSessionService(BaseSessionService):
284
278
  Base.metadata.create_all(self.db_engine)
285
279
 
286
280
  @override
287
- def create_session(
281
+ async def create_session(
288
282
  self,
289
283
  *,
290
284
  app_name: str,
@@ -298,11 +292,11 @@ class DatabaseSessionService(BaseSessionService):
298
292
  # 4. Build the session object with generated id
299
293
  # 5. Return the session
300
294
 
301
- with self.DatabaseSessionFactory() as sessionFactory:
295
+ with self.database_session_factory() as session_factory:
302
296
 
303
297
  # Fetch app and user states from storage
304
- storage_app_state = sessionFactory.get(StorageAppState, (app_name))
305
- storage_user_state = sessionFactory.get(
298
+ storage_app_state = session_factory.get(StorageAppState, (app_name))
299
+ storage_user_state = session_factory.get(
306
300
  StorageUserState, (app_name, user_id)
307
301
  )
308
302
 
@@ -312,12 +306,12 @@ class DatabaseSessionService(BaseSessionService):
312
306
  # Create state tables if not exist
313
307
  if not storage_app_state:
314
308
  storage_app_state = StorageAppState(app_name=app_name, state={})
315
- sessionFactory.add(storage_app_state)
309
+ session_factory.add(storage_app_state)
316
310
  if not storage_user_state:
317
311
  storage_user_state = StorageUserState(
318
312
  app_name=app_name, user_id=user_id, state={}
319
313
  )
320
- sessionFactory.add(storage_user_state)
314
+ session_factory.add(storage_user_state)
321
315
 
322
316
  # Extract state deltas
323
317
  app_state_delta, user_state_delta, session_state = _extract_state_delta(
@@ -341,10 +335,10 @@ class DatabaseSessionService(BaseSessionService):
341
335
  id=session_id,
342
336
  state=session_state,
343
337
  )
344
- sessionFactory.add(storage_session)
345
- sessionFactory.commit()
338
+ session_factory.add(storage_session)
339
+ session_factory.commit()
346
340
 
347
- sessionFactory.refresh(storage_session)
341
+ session_factory.refresh(storage_session)
348
342
 
349
343
  # Merge states for response
350
344
  merged_state = _merge_state(app_state, user_state, session_state)
@@ -358,7 +352,7 @@ class DatabaseSessionService(BaseSessionService):
358
352
  return session
359
353
 
360
354
  @override
361
- def get_session(
355
+ async def get_session(
362
356
  self,
363
357
  *,
364
358
  app_name: str,
@@ -369,29 +363,37 @@ class DatabaseSessionService(BaseSessionService):
369
363
  # 1. Get the storage session entry from session table
370
364
  # 2. Get all the events based on session id and filtering config
371
365
  # 3. Convert and return the session
372
- with self.DatabaseSessionFactory() as sessionFactory:
373
- storage_session = sessionFactory.get(
366
+ with self.database_session_factory() as session_factory:
367
+ storage_session = session_factory.get(
374
368
  StorageSession, (app_name, user_id, session_id)
375
369
  )
376
370
  if storage_session is None:
377
371
  return None
378
372
 
373
+ if config and config.after_timestamp:
374
+ after_dt = datetime.fromtimestamp(
375
+ config.after_timestamp, tz=timezone.utc
376
+ )
377
+ timestamp_filter = StorageEvent.timestamp > after_dt
378
+ else:
379
+ timestamp_filter = True
380
+
379
381
  storage_events = (
380
- sessionFactory.query(StorageEvent)
382
+ session_factory.query(StorageEvent)
381
383
  .filter(StorageEvent.session_id == storage_session.id)
382
- .filter(
383
- StorageEvent.timestamp < config.after_timestamp
384
- if config
385
- else True
386
- )
387
- .limit(config.num_recent_events if config else None)
384
+ .filter(timestamp_filter)
388
385
  .order_by(StorageEvent.timestamp.asc())
386
+ .limit(
387
+ config.num_recent_events
388
+ if config and config.num_recent_events
389
+ else None
390
+ )
389
391
  .all()
390
392
  )
391
393
 
392
394
  # Fetch states from storage
393
- storage_app_state = sessionFactory.get(StorageAppState, (app_name))
394
- storage_user_state = sessionFactory.get(
395
+ storage_app_state = session_factory.get(StorageAppState, (app_name))
396
+ storage_user_state = session_factory.get(
395
397
  StorageUserState, (app_name, user_id)
396
398
  )
397
399
 
@@ -432,12 +434,12 @@ class DatabaseSessionService(BaseSessionService):
432
434
  return session
433
435
 
434
436
  @override
435
- def list_sessions(
437
+ async def list_sessions(
436
438
  self, *, app_name: str, user_id: str
437
439
  ) -> ListSessionsResponse:
438
- with self.DatabaseSessionFactory() as sessionFactory:
440
+ with self.database_session_factory() as session_factory:
439
441
  results = (
440
- sessionFactory.query(StorageSession)
442
+ session_factory.query(StorageSession)
441
443
  .filter(StorageSession.app_name == app_name)
442
444
  .filter(StorageSession.user_id == user_id)
443
445
  .all()
@@ -455,20 +457,20 @@ class DatabaseSessionService(BaseSessionService):
455
457
  return ListSessionsResponse(sessions=sessions)
456
458
 
457
459
  @override
458
- def delete_session(
460
+ async def delete_session(
459
461
  self, app_name: str, user_id: str, session_id: str
460
462
  ) -> None:
461
- with self.DatabaseSessionFactory() as sessionFactory:
463
+ with self.database_session_factory() as session_factory:
462
464
  stmt = delete(StorageSession).where(
463
465
  StorageSession.app_name == app_name,
464
466
  StorageSession.user_id == user_id,
465
467
  StorageSession.id == session_id,
466
468
  )
467
- sessionFactory.execute(stmt)
468
- sessionFactory.commit()
469
+ session_factory.execute(stmt)
470
+ session_factory.commit()
469
471
 
470
472
  @override
471
- def append_event(self, session: Session, event: Event) -> Event:
473
+ async def append_event(self, session: Session, event: Event) -> Event:
472
474
  logger.info(f"Append event: {event} to session {session.id}")
473
475
 
474
476
  if event.partial:
@@ -477,24 +479,25 @@ class DatabaseSessionService(BaseSessionService):
477
479
  # 1. Check if timestamp is stale
478
480
  # 2. Update session attributes based on event config
479
481
  # 3. Store event to table
480
- with self.DatabaseSessionFactory() as sessionFactory:
481
- storage_session = sessionFactory.get(
482
+ with self.database_session_factory() as session_factory:
483
+ storage_session = session_factory.get(
482
484
  StorageSession, (session.app_name, session.user_id, session.id)
483
485
  )
484
486
 
485
487
  if storage_session.update_time.timestamp() > session.last_update_time:
486
488
  raise ValueError(
487
- f"Session last_update_time "
488
- f"{datetime.fromtimestamp(session.last_update_time):%Y-%m-%d %H:%M:%S} "
489
- f"is later than the update_time in storage "
490
- f"{storage_session.update_time:%Y-%m-%d %H:%M:%S}"
491
- )
489
+ "The last_update_time provided in the session object"
490
+ f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
491
+ " earlier than the update_time in the storage_session"
492
+ f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check"
493
+ " if it is a stale session."
494
+ )
492
495
 
493
496
  # Fetch states from storage
494
- storage_app_state = sessionFactory.get(
497
+ storage_app_state = session_factory.get(
495
498
  StorageAppState, (session.app_name)
496
499
  )
497
- storage_user_state = sessionFactory.get(
500
+ storage_user_state = session_factory.get(
498
501
  StorageUserState, (session.app_name, session.user_id)
499
502
  )
500
503
 
@@ -543,28 +546,18 @@ class DatabaseSessionService(BaseSessionService):
543
546
  if event.content:
544
547
  storage_event.content = _session_util.encode_content(event.content)
545
548
 
546
- sessionFactory.add(storage_event)
549
+ session_factory.add(storage_event)
547
550
 
548
- sessionFactory.commit()
549
- sessionFactory.refresh(storage_session)
551
+ session_factory.commit()
552
+ session_factory.refresh(storage_session)
550
553
 
551
554
  # Update timestamp with commit time
552
555
  session.last_update_time = storage_session.update_time.timestamp()
553
556
 
554
557
  # Also update the in-memory session
555
- super().append_event(session=session, event=event)
558
+ await super().append_event(session=session, event=event)
556
559
  return event
557
560
 
558
- @override
559
- def list_events(
560
- self,
561
- *,
562
- app_name: str,
563
- user_id: str,
564
- session_id: str,
565
- ) -> ListEventsResponse:
566
- raise NotImplementedError()
567
-
568
561
 
569
562
  def convert_event(event: StorageEvent) -> Event:
570
563
  """Converts a storage event to an event."""