google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +0 -5
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +6 -1
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +172 -99
- google/adk/cli/cli_tools_click.py +147 -64
- google/adk/cli/fast_api.py +330 -148
- google/adk/cli/fast_api.py.orig +174 -80
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -2
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +54 -15
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +13 -5
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +9 -2
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +77 -21
- google/adk/models/llm_response.py +14 -2
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +65 -41
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +58 -65
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +9 -9
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py
CHANGED
@@ -21,8 +21,8 @@ import threading
|
|
21
21
|
from typing import AsyncGenerator
|
22
22
|
from typing import Generator
|
23
23
|
from typing import Optional
|
24
|
+
import warnings
|
24
25
|
|
25
|
-
from deprecated import deprecated
|
26
26
|
from google.genai import types
|
27
27
|
|
28
28
|
from .agents.active_streaming_tool import ActiveStreamingTool
|
@@ -32,7 +32,6 @@ from .agents.invocation_context import new_invocation_context_id
|
|
32
32
|
from .agents.live_request_queue import LiveRequestQueue
|
33
33
|
from .agents.llm_agent import LlmAgent
|
34
34
|
from .agents.run_config import RunConfig
|
35
|
-
from .agents.run_config import StreamingMode
|
36
35
|
from .artifacts.base_artifact_service import BaseArtifactService
|
37
36
|
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
38
37
|
from .events.event import Event
|
@@ -42,9 +41,9 @@ from .sessions.base_session_service import BaseSessionService
|
|
42
41
|
from .sessions.in_memory_session_service import InMemorySessionService
|
43
42
|
from .sessions.session import Session
|
44
43
|
from .telemetry import tracer
|
45
|
-
from .tools.
|
44
|
+
from .tools._built_in_code_execution_tool import built_in_code_execution
|
46
45
|
|
47
|
-
logger = logging.getLogger(__name__)
|
46
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
48
47
|
|
49
48
|
|
50
49
|
class Runner:
|
@@ -172,7 +171,7 @@ class Runner:
|
|
172
171
|
The events generated by the agent.
|
173
172
|
"""
|
174
173
|
with tracer.start_as_current_span('invocation'):
|
175
|
-
session = self.session_service.get_session(
|
174
|
+
session = await self.session_service.get_session(
|
176
175
|
app_name=self.app_name, user_id=user_id, session_id=session_id
|
177
176
|
)
|
178
177
|
if not session:
|
@@ -196,7 +195,7 @@ class Runner:
|
|
196
195
|
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
197
196
|
async for event in invocation_context.agent.run_async(invocation_context):
|
198
197
|
if not event.partial:
|
199
|
-
self.session_service.append_event(session=session, event=event)
|
198
|
+
await self.session_service.append_event(session=session, event=event)
|
200
199
|
yield event
|
201
200
|
|
202
201
|
async def _append_new_message_to_session(
|
@@ -241,30 +240,57 @@ class Runner:
|
|
241
240
|
author='user',
|
242
241
|
content=new_message,
|
243
242
|
)
|
244
|
-
self.session_service.append_event(session=session, event=event)
|
243
|
+
await self.session_service.append_event(session=session, event=event)
|
245
244
|
|
246
245
|
async def run_live(
|
247
246
|
self,
|
248
247
|
*,
|
249
|
-
|
248
|
+
user_id: Optional[str] = None,
|
249
|
+
session_id: Optional[str] = None,
|
250
250
|
live_request_queue: LiveRequestQueue,
|
251
251
|
run_config: RunConfig = RunConfig(),
|
252
|
+
session: Optional[Session] = None,
|
252
253
|
) -> AsyncGenerator[Event, None]:
|
253
254
|
"""Runs the agent in live mode (experimental feature).
|
254
255
|
|
255
256
|
Args:
|
256
|
-
|
257
|
+
user_id: The user ID for the session. Required if `session` is None.
|
258
|
+
session_id: The session ID for the session. Required if `session` is
|
259
|
+
None.
|
257
260
|
live_request_queue: The queue for live requests.
|
258
261
|
run_config: The run config for the agent.
|
262
|
+
session: The session to use. This parameter is deprecated, please use
|
263
|
+
`user_id` and `session_id` instead.
|
259
264
|
|
260
265
|
Yields:
|
261
|
-
|
266
|
+
AsyncGenerator[Event, None]: An asynchronous generator that yields
|
267
|
+
`Event`
|
268
|
+
objects as they are produced by the agent during its live execution.
|
262
269
|
|
263
270
|
.. warning::
|
264
271
|
This feature is **experimental** and its API or behavior may change
|
265
272
|
in future releases.
|
273
|
+
|
274
|
+
.. note::
|
275
|
+
Either `session` or both `user_id` and `session_id` must be provided.
|
266
276
|
"""
|
267
|
-
|
277
|
+
if session is None and (user_id is None or session_id is None):
|
278
|
+
raise ValueError(
|
279
|
+
'Either session or user_id and session_id must be provided.'
|
280
|
+
)
|
281
|
+
if session is not None:
|
282
|
+
warnings.warn(
|
283
|
+
'The `session` parameter is deprecated. Please use `user_id` and'
|
284
|
+
' `session_id` instead.',
|
285
|
+
DeprecationWarning,
|
286
|
+
stacklevel=2,
|
287
|
+
)
|
288
|
+
if not session:
|
289
|
+
session = self.session_service.get_session(
|
290
|
+
app_name=self.app_name, user_id=user_id, session_id=session_id
|
291
|
+
)
|
292
|
+
if not session:
|
293
|
+
raise ValueError(f'Session not found: {session_id}')
|
268
294
|
invocation_context = self._new_invocation_context_for_live(
|
269
295
|
session,
|
270
296
|
live_request_queue=live_request_queue,
|
@@ -276,37 +302,29 @@ class Runner:
|
|
276
302
|
|
277
303
|
invocation_context.active_streaming_tools = {}
|
278
304
|
# TODO(hangfei): switch to use canonical_tools.
|
279
|
-
for
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
305
|
+
# for shell agents, there is no tools associated with it so we should skip.
|
306
|
+
if hasattr(invocation_context.agent, 'tools'):
|
307
|
+
for tool in invocation_context.agent.tools:
|
308
|
+
# replicate a LiveRequestQueue for streaming tools that relis on
|
309
|
+
# LiveRequestQueue
|
310
|
+
from typing import get_type_hints
|
311
|
+
|
312
|
+
type_hints = get_type_hints(tool)
|
313
|
+
for arg_type in type_hints.values():
|
314
|
+
if arg_type is LiveRequestQueue:
|
315
|
+
if not invocation_context.active_streaming_tools:
|
316
|
+
invocation_context.active_streaming_tools = {}
|
317
|
+
active_streaming_tools = ActiveStreamingTool(
|
318
|
+
stream=LiveRequestQueue()
|
319
|
+
)
|
320
|
+
invocation_context.active_streaming_tools[tool.__name__] = (
|
321
|
+
active_streaming_tools
|
322
|
+
)
|
295
323
|
|
296
324
|
async for event in invocation_context.agent.run_live(invocation_context):
|
297
|
-
self.session_service.append_event(session=session, event=event)
|
325
|
+
await self.session_service.append_event(session=session, event=event)
|
298
326
|
yield event
|
299
327
|
|
300
|
-
async def close_session(self, session: Session):
|
301
|
-
"""Closes a session and adds it to the memory service (experimental feature).
|
302
|
-
|
303
|
-
Args:
|
304
|
-
session: The session to close.
|
305
|
-
"""
|
306
|
-
if self.memory_service:
|
307
|
-
await self.memory_service.add_session_to_memory(session)
|
308
|
-
self.session_service.close_session(session=session)
|
309
|
-
|
310
328
|
def _find_agent_to_run(
|
311
329
|
self, session: Session, root_agent: BaseAgent
|
312
330
|
) -> BaseAgent:
|
@@ -391,7 +409,7 @@ class Runner:
|
|
391
409
|
f'CFC is not supported for model: {model_name} in agent:'
|
392
410
|
f' {self.agent.name}'
|
393
411
|
)
|
394
|
-
if built_in_code_execution not in self.agent.canonical_tools:
|
412
|
+
if built_in_code_execution not in self.agent.canonical_tools():
|
395
413
|
self.agent.tools.append(built_in_code_execution)
|
396
414
|
|
397
415
|
return InvocationContext(
|
@@ -430,6 +448,9 @@ class Runner:
|
|
430
448
|
run_config.output_audio_transcription = (
|
431
449
|
types.AudioTranscriptionConfig()
|
432
450
|
)
|
451
|
+
if not run_config.input_audio_transcription:
|
452
|
+
# need this input transcription for agent transferring in live mode.
|
453
|
+
run_config.input_audio_transcription = types.AudioTranscriptionConfig()
|
433
454
|
return self._new_invocation_context(
|
434
455
|
session,
|
435
456
|
live_request_queue=live_request_queue,
|
@@ -448,9 +469,11 @@ class InMemoryRunner(Runner):
|
|
448
469
|
agent: The root agent to run.
|
449
470
|
app_name: The application name of the runner. Defaults to
|
450
471
|
'InMemoryRunner'.
|
472
|
+
_in_memory_session_service: Deprecated. Please don't use. The in-memory
|
473
|
+
session service for the runner.
|
451
474
|
"""
|
452
475
|
|
453
|
-
def __init__(self, agent:
|
476
|
+
def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
|
454
477
|
"""Initializes the InMemoryRunner.
|
455
478
|
|
456
479
|
Args:
|
@@ -458,10 +481,11 @@ class InMemoryRunner(Runner):
|
|
458
481
|
app_name: The application name of the runner. Defaults to
|
459
482
|
'InMemoryRunner'.
|
460
483
|
"""
|
484
|
+
self._in_memory_session_service = InMemorySessionService()
|
461
485
|
super().__init__(
|
462
486
|
app_name=app_name,
|
463
487
|
agent=agent,
|
464
488
|
artifact_service=InMemoryArtifactService(),
|
465
|
-
session_service=
|
489
|
+
session_service=self._in_memory_session_service,
|
466
490
|
memory_service=InMemoryMemoryService(),
|
467
491
|
)
|
google/adk/sessions/__init__.py
CHANGED
@@ -40,13 +40,6 @@ class ListSessionsResponse(BaseModel):
|
|
40
40
|
sessions: list[Session] = Field(default_factory=list)
|
41
41
|
|
42
42
|
|
43
|
-
class ListEventsResponse(BaseModel):
|
44
|
-
"""The response of listing events in a session."""
|
45
|
-
|
46
|
-
events: list[Event] = Field(default_factory=list)
|
47
|
-
next_page_token: Optional[str] = None
|
48
|
-
|
49
|
-
|
50
43
|
class BaseSessionService(abc.ABC):
|
51
44
|
"""Base class for session services.
|
52
45
|
|
@@ -54,7 +47,7 @@ class BaseSessionService(abc.ABC):
|
|
54
47
|
"""
|
55
48
|
|
56
49
|
@abc.abstractmethod
|
57
|
-
def create_session(
|
50
|
+
async def create_session(
|
58
51
|
self,
|
59
52
|
*,
|
60
53
|
app_name: str,
|
@@ -74,10 +67,9 @@ class BaseSessionService(abc.ABC):
|
|
74
67
|
Returns:
|
75
68
|
session: The newly created session instance.
|
76
69
|
"""
|
77
|
-
pass
|
78
70
|
|
79
71
|
@abc.abstractmethod
|
80
|
-
def get_session(
|
72
|
+
async def get_session(
|
81
73
|
self,
|
82
74
|
*,
|
83
75
|
app_name: str,
|
@@ -86,39 +78,20 @@ class BaseSessionService(abc.ABC):
|
|
86
78
|
config: Optional[GetSessionConfig] = None,
|
87
79
|
) -> Optional[Session]:
|
88
80
|
"""Gets a session."""
|
89
|
-
pass
|
90
81
|
|
91
82
|
@abc.abstractmethod
|
92
|
-
def list_sessions(
|
83
|
+
async def list_sessions(
|
93
84
|
self, *, app_name: str, user_id: str
|
94
85
|
) -> ListSessionsResponse:
|
95
86
|
"""Lists all the sessions."""
|
96
|
-
pass
|
97
87
|
|
98
88
|
@abc.abstractmethod
|
99
|
-
def delete_session(
|
89
|
+
async def delete_session(
|
100
90
|
self, *, app_name: str, user_id: str, session_id: str
|
101
91
|
) -> None:
|
102
92
|
"""Deletes a session."""
|
103
|
-
pass
|
104
|
-
|
105
|
-
@abc.abstractmethod
|
106
|
-
def list_events(
|
107
|
-
self,
|
108
|
-
*,
|
109
|
-
app_name: str,
|
110
|
-
user_id: str,
|
111
|
-
session_id: str,
|
112
|
-
) -> ListEventsResponse:
|
113
|
-
"""Lists events in a session."""
|
114
|
-
pass
|
115
|
-
|
116
|
-
def close_session(self, *, session: Session):
|
117
|
-
"""Closes a session."""
|
118
|
-
# TODO: determine whether we want to finalize the session here.
|
119
|
-
pass
|
120
93
|
|
121
|
-
def append_event(self, session: Session, event: Event) -> Event:
|
94
|
+
async def append_event(self, session: Session, event: Event) -> Event:
|
122
95
|
"""Appends an event to a session object."""
|
123
96
|
if event.partial:
|
124
97
|
return event
|
@@ -126,7 +99,7 @@ class BaseSessionService(abc.ABC):
|
|
126
99
|
session.events.append(event)
|
127
100
|
return event
|
128
101
|
|
129
|
-
def __update_session_state(self, session: Session, event: Event):
|
102
|
+
def __update_session_state(self, session: Session, event: Event) -> None:
|
130
103
|
"""Updates the session state based on the event."""
|
131
104
|
if not event.actions or not event.actions.state_delta:
|
132
105
|
return
|
@@ -13,9 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import copy
|
15
15
|
from datetime import datetime
|
16
|
+
from datetime import timezone
|
16
17
|
import json
|
17
18
|
import logging
|
18
|
-
from typing import Any
|
19
|
+
from typing import Any
|
20
|
+
from typing import Optional
|
19
21
|
import uuid
|
20
22
|
|
21
23
|
from sqlalchemy import Boolean
|
@@ -49,23 +51,18 @@ from ..events.event import Event
|
|
49
51
|
from . import _session_util
|
50
52
|
from .base_session_service import BaseSessionService
|
51
53
|
from .base_session_service import GetSessionConfig
|
52
|
-
from .base_session_service import ListEventsResponse
|
53
54
|
from .base_session_service import ListSessionsResponse
|
54
55
|
from .session import Session
|
55
56
|
from .state import State
|
56
57
|
|
57
|
-
|
58
|
-
logger = logging.getLogger(__name__)
|
58
|
+
logger = logging.getLogger("google_adk." + __name__)
|
59
59
|
|
60
60
|
DEFAULT_MAX_KEY_LENGTH = 128
|
61
61
|
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
62
62
|
|
63
63
|
|
64
64
|
class DynamicJSON(TypeDecorator):
|
65
|
-
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
66
|
-
|
67
|
-
serialization for other databases.
|
68
|
-
"""
|
65
|
+
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
|
69
66
|
|
70
67
|
impl = Text # Default implementation is TEXT
|
71
68
|
|
@@ -243,10 +240,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
243
240
|
"""A session service that uses a database for storage."""
|
244
241
|
|
245
242
|
def __init__(self, db_url: str):
|
246
|
-
"""
|
247
|
-
Args:
|
248
|
-
db_url: The database URL to connect to.
|
249
|
-
"""
|
243
|
+
"""Initializes the database session service with a database URL."""
|
250
244
|
# 1. Create DB engine for db connection
|
251
245
|
# 2. Create all tables based on schema
|
252
246
|
# 3. Initialize all properties
|
@@ -275,7 +269,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
275
269
|
self.inspector = inspect(self.db_engine)
|
276
270
|
|
277
271
|
# DB session factory method
|
278
|
-
self.
|
272
|
+
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
|
279
273
|
sessionmaker(bind=self.db_engine)
|
280
274
|
)
|
281
275
|
|
@@ -284,7 +278,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
284
278
|
Base.metadata.create_all(self.db_engine)
|
285
279
|
|
286
280
|
@override
|
287
|
-
def create_session(
|
281
|
+
async def create_session(
|
288
282
|
self,
|
289
283
|
*,
|
290
284
|
app_name: str,
|
@@ -298,11 +292,11 @@ class DatabaseSessionService(BaseSessionService):
|
|
298
292
|
# 4. Build the session object with generated id
|
299
293
|
# 5. Return the session
|
300
294
|
|
301
|
-
with self.
|
295
|
+
with self.database_session_factory() as session_factory:
|
302
296
|
|
303
297
|
# Fetch app and user states from storage
|
304
|
-
storage_app_state =
|
305
|
-
storage_user_state =
|
298
|
+
storage_app_state = session_factory.get(StorageAppState, (app_name))
|
299
|
+
storage_user_state = session_factory.get(
|
306
300
|
StorageUserState, (app_name, user_id)
|
307
301
|
)
|
308
302
|
|
@@ -312,12 +306,12 @@ class DatabaseSessionService(BaseSessionService):
|
|
312
306
|
# Create state tables if not exist
|
313
307
|
if not storage_app_state:
|
314
308
|
storage_app_state = StorageAppState(app_name=app_name, state={})
|
315
|
-
|
309
|
+
session_factory.add(storage_app_state)
|
316
310
|
if not storage_user_state:
|
317
311
|
storage_user_state = StorageUserState(
|
318
312
|
app_name=app_name, user_id=user_id, state={}
|
319
313
|
)
|
320
|
-
|
314
|
+
session_factory.add(storage_user_state)
|
321
315
|
|
322
316
|
# Extract state deltas
|
323
317
|
app_state_delta, user_state_delta, session_state = _extract_state_delta(
|
@@ -341,10 +335,10 @@ class DatabaseSessionService(BaseSessionService):
|
|
341
335
|
id=session_id,
|
342
336
|
state=session_state,
|
343
337
|
)
|
344
|
-
|
345
|
-
|
338
|
+
session_factory.add(storage_session)
|
339
|
+
session_factory.commit()
|
346
340
|
|
347
|
-
|
341
|
+
session_factory.refresh(storage_session)
|
348
342
|
|
349
343
|
# Merge states for response
|
350
344
|
merged_state = _merge_state(app_state, user_state, session_state)
|
@@ -358,7 +352,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
358
352
|
return session
|
359
353
|
|
360
354
|
@override
|
361
|
-
def get_session(
|
355
|
+
async def get_session(
|
362
356
|
self,
|
363
357
|
*,
|
364
358
|
app_name: str,
|
@@ -369,29 +363,37 @@ class DatabaseSessionService(BaseSessionService):
|
|
369
363
|
# 1. Get the storage session entry from session table
|
370
364
|
# 2. Get all the events based on session id and filtering config
|
371
365
|
# 3. Convert and return the session
|
372
|
-
with self.
|
373
|
-
storage_session =
|
366
|
+
with self.database_session_factory() as session_factory:
|
367
|
+
storage_session = session_factory.get(
|
374
368
|
StorageSession, (app_name, user_id, session_id)
|
375
369
|
)
|
376
370
|
if storage_session is None:
|
377
371
|
return None
|
378
372
|
|
373
|
+
if config and config.after_timestamp:
|
374
|
+
after_dt = datetime.fromtimestamp(
|
375
|
+
config.after_timestamp, tz=timezone.utc
|
376
|
+
)
|
377
|
+
timestamp_filter = StorageEvent.timestamp > after_dt
|
378
|
+
else:
|
379
|
+
timestamp_filter = True
|
380
|
+
|
379
381
|
storage_events = (
|
380
|
-
|
382
|
+
session_factory.query(StorageEvent)
|
381
383
|
.filter(StorageEvent.session_id == storage_session.id)
|
382
|
-
.filter(
|
383
|
-
StorageEvent.timestamp < config.after_timestamp
|
384
|
-
if config
|
385
|
-
else True
|
386
|
-
)
|
387
|
-
.limit(config.num_recent_events if config else None)
|
384
|
+
.filter(timestamp_filter)
|
388
385
|
.order_by(StorageEvent.timestamp.asc())
|
386
|
+
.limit(
|
387
|
+
config.num_recent_events
|
388
|
+
if config and config.num_recent_events
|
389
|
+
else None
|
390
|
+
)
|
389
391
|
.all()
|
390
392
|
)
|
391
393
|
|
392
394
|
# Fetch states from storage
|
393
|
-
storage_app_state =
|
394
|
-
storage_user_state =
|
395
|
+
storage_app_state = session_factory.get(StorageAppState, (app_name))
|
396
|
+
storage_user_state = session_factory.get(
|
395
397
|
StorageUserState, (app_name, user_id)
|
396
398
|
)
|
397
399
|
|
@@ -432,12 +434,12 @@ class DatabaseSessionService(BaseSessionService):
|
|
432
434
|
return session
|
433
435
|
|
434
436
|
@override
|
435
|
-
def list_sessions(
|
437
|
+
async def list_sessions(
|
436
438
|
self, *, app_name: str, user_id: str
|
437
439
|
) -> ListSessionsResponse:
|
438
|
-
with self.
|
440
|
+
with self.database_session_factory() as session_factory:
|
439
441
|
results = (
|
440
|
-
|
442
|
+
session_factory.query(StorageSession)
|
441
443
|
.filter(StorageSession.app_name == app_name)
|
442
444
|
.filter(StorageSession.user_id == user_id)
|
443
445
|
.all()
|
@@ -455,20 +457,20 @@ class DatabaseSessionService(BaseSessionService):
|
|
455
457
|
return ListSessionsResponse(sessions=sessions)
|
456
458
|
|
457
459
|
@override
|
458
|
-
def delete_session(
|
460
|
+
async def delete_session(
|
459
461
|
self, app_name: str, user_id: str, session_id: str
|
460
462
|
) -> None:
|
461
|
-
with self.
|
463
|
+
with self.database_session_factory() as session_factory:
|
462
464
|
stmt = delete(StorageSession).where(
|
463
465
|
StorageSession.app_name == app_name,
|
464
466
|
StorageSession.user_id == user_id,
|
465
467
|
StorageSession.id == session_id,
|
466
468
|
)
|
467
|
-
|
468
|
-
|
469
|
+
session_factory.execute(stmt)
|
470
|
+
session_factory.commit()
|
469
471
|
|
470
472
|
@override
|
471
|
-
def append_event(self, session: Session, event: Event) -> Event:
|
473
|
+
async def append_event(self, session: Session, event: Event) -> Event:
|
472
474
|
logger.info(f"Append event: {event} to session {session.id}")
|
473
475
|
|
474
476
|
if event.partial:
|
@@ -477,24 +479,25 @@ class DatabaseSessionService(BaseSessionService):
|
|
477
479
|
# 1. Check if timestamp is stale
|
478
480
|
# 2. Update session attributes based on event config
|
479
481
|
# 3. Store event to table
|
480
|
-
with self.
|
481
|
-
storage_session =
|
482
|
+
with self.database_session_factory() as session_factory:
|
483
|
+
storage_session = session_factory.get(
|
482
484
|
StorageSession, (session.app_name, session.user_id, session.id)
|
483
485
|
)
|
484
486
|
|
485
487
|
if storage_session.update_time.timestamp() > session.last_update_time:
|
486
488
|
raise ValueError(
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
489
|
+
"The last_update_time provided in the session object"
|
490
|
+
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
|
491
|
+
" earlier than the update_time in the storage_session"
|
492
|
+
f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check"
|
493
|
+
" if it is a stale session."
|
494
|
+
)
|
492
495
|
|
493
496
|
# Fetch states from storage
|
494
|
-
storage_app_state =
|
497
|
+
storage_app_state = session_factory.get(
|
495
498
|
StorageAppState, (session.app_name)
|
496
499
|
)
|
497
|
-
storage_user_state =
|
500
|
+
storage_user_state = session_factory.get(
|
498
501
|
StorageUserState, (session.app_name, session.user_id)
|
499
502
|
)
|
500
503
|
|
@@ -543,28 +546,18 @@ class DatabaseSessionService(BaseSessionService):
|
|
543
546
|
if event.content:
|
544
547
|
storage_event.content = _session_util.encode_content(event.content)
|
545
548
|
|
546
|
-
|
549
|
+
session_factory.add(storage_event)
|
547
550
|
|
548
|
-
|
549
|
-
|
551
|
+
session_factory.commit()
|
552
|
+
session_factory.refresh(storage_session)
|
550
553
|
|
551
554
|
# Update timestamp with commit time
|
552
555
|
session.last_update_time = storage_session.update_time.timestamp()
|
553
556
|
|
554
557
|
# Also update the in-memory session
|
555
|
-
super().append_event(session=session, event=event)
|
558
|
+
await super().append_event(session=session, event=event)
|
556
559
|
return event
|
557
560
|
|
558
|
-
@override
|
559
|
-
def list_events(
|
560
|
-
self,
|
561
|
-
*,
|
562
|
-
app_name: str,
|
563
|
-
user_id: str,
|
564
|
-
session_id: str,
|
565
|
-
) -> ListEventsResponse:
|
566
|
-
raise NotImplementedError()
|
567
|
-
|
568
561
|
|
569
562
|
def convert_event(event: StorageEvent) -> Event:
|
570
563
|
"""Converts a storage event to an event."""
|