openai-agents 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of openai-agents might be problematic.
- agents/__init__.py +3 -1
- agents/_run_impl.py +44 -7
- agents/agent.py +36 -4
- agents/extensions/memory/__init__.py +15 -0
- agents/extensions/memory/sqlalchemy_session.py +312 -0
- agents/extensions/models/litellm_model.py +11 -6
- agents/extensions/models/litellm_provider.py +3 -1
- agents/function_schema.py +2 -2
- agents/handoffs.py +3 -3
- agents/lifecycle.py +40 -1
- agents/mcp/server.py +59 -8
- agents/memory/__init__.py +9 -2
- agents/memory/openai_conversations_session.py +94 -0
- agents/memory/session.py +0 -270
- agents/memory/sqlite_session.py +275 -0
- agents/model_settings.py +8 -3
- agents/models/__init__.py +13 -0
- agents/models/chatcmpl_converter.py +5 -0
- agents/models/chatcmpl_stream_handler.py +81 -17
- agents/models/default_models.py +58 -0
- agents/models/interface.py +4 -0
- agents/models/openai_chatcompletions.py +4 -2
- agents/models/openai_provider.py +3 -1
- agents/models/openai_responses.py +24 -10
- agents/realtime/config.py +3 -0
- agents/realtime/events.py +11 -0
- agents/realtime/model_events.py +10 -0
- agents/realtime/openai_realtime.py +39 -5
- agents/realtime/session.py +7 -0
- agents/repl.py +7 -3
- agents/run.py +132 -7
- agents/tool.py +9 -1
- agents/tracing/processors.py +2 -2
- {openai_agents-0.2.8.dist-info → openai_agents-0.2.10.dist-info}/METADATA +16 -14
- {openai_agents-0.2.8.dist-info → openai_agents-0.2.10.dist-info}/RECORD +37 -32
- {openai_agents-0.2.8.dist-info → openai_agents-0.2.10.dist-info}/WHEEL +0 -0
- {openai_agents-0.2.8.dist-info → openai_agents-0.2.10.dist-info}/licenses/LICENSE +0 -0
agents/mcp/server.py CHANGED

@@ -3,10 +3,11 @@ from __future__ import annotations
 import abc
 import asyncio
 import inspect
+from collections.abc import Awaitable
 from contextlib import AbstractAsyncContextManager, AsyncExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal,
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
 
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client

@@ -19,7 +20,9 @@ from typing_extensions import NotRequired, TypedDict
 from ..exceptions import UserError
 from ..logger import logger
 from ..run_context import RunContextWrapper
-from .util import ToolFilter,
+from .util import ToolFilter, ToolFilterContext, ToolFilterStatic
+
+T = TypeVar("T")
 
 if TYPE_CHECKING:
     from ..agent import AgentBase

@@ -98,6 +101,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         client_session_timeout_seconds: float | None,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
     ):
         """
         Args:

@@ -115,6 +120,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, used for exponential
+                backoff between retries.
         """
         super().__init__(use_structured_content=use_structured_content)
         self.session: ClientSession | None = None

@@ -124,6 +133,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         self.server_initialize_result: InitializeResult | None = None
 
         self.client_session_timeout_seconds = client_session_timeout_seconds
+        self.max_retry_attempts = max_retry_attempts
+        self.retry_backoff_seconds_base = retry_backoff_seconds_base
 
         # The cache is always dirty at startup, so that we fetch tools at least once
         self._cache_dirty = True

@@ -175,10 +186,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
     ) -> list[MCPTool]:
         """Apply dynamic tool filtering using a callable filter function."""
 
-        # Ensure we have a callable filter
+        # Ensure we have a callable filter
         if not callable(self.tool_filter):
             raise ValueError("Tool filter must be callable for dynamic filtering")
-        tool_filter_func =
+        tool_filter_func = self.tool_filter
 
         # Create filter context
         filter_context = ToolFilterContext(

@@ -233,6 +244,18 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         """Invalidate the tools cache."""
         self._cache_dirty = True
 
+    async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
+        attempts = 0
+        while True:
+            try:
+                return await func()
+            except Exception:
+                attempts += 1
+                if self.max_retry_attempts != -1 and attempts > self.max_retry_attempts:
+                    raise
+                backoff = self.retry_backoff_seconds_base * (2 ** (attempts - 1))
+                await asyncio.sleep(backoff)
+
     async def connect(self):
         """Connect to the server."""
         try:

@@ -267,15 +290,17 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         """List the tools available on the server."""
         if not self.session:
             raise UserError("Server not initialized. Make sure you call `connect()` first.")
+        session = self.session
+        assert session is not None
 
         # Return from cache if caching is enabled, we have tools, and the cache is not dirty
         if self.cache_tools_list and not self._cache_dirty and self._tools_list:
             tools = self._tools_list
         else:
-            # Reset the cache dirty to False
-            self._cache_dirty = False
             # Fetch the tools from the server
-            self._tools_list = (await self.session.list_tools()).tools
+            result = await self._run_with_retries(lambda: session.list_tools())
+            self._tools_list = result.tools
+            self._cache_dirty = False
             tools = self._tools_list
 
         # Filter tools based on tool_filter

@@ -290,8 +315,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         """Invoke a tool on the server."""
         if not self.session:
             raise UserError("Server not initialized. Make sure you call `connect()` first.")
+        session = self.session
+        assert session is not None
 
-        return await self.session.call_tool(tool_name, arguments)
+        return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
 
     async def list_prompts(
         self,

@@ -365,6 +392,8 @@ class MCPServerStdio(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
     ):
         """Create a new MCP server based on the stdio transport.
 

@@ -388,12 +417,18 @@ class MCPServerStdio(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
         )
 
         self.params = StdioServerParameters(

@@ -455,6 +490,8 @@ class MCPServerSse(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
     ):
         """Create a new MCP server based on the HTTP with SSE transport.
 

@@ -480,12 +517,18 @@ class MCPServerSse(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
         )
 
         self.params = params

@@ -547,6 +590,8 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
     ):
         """Create a new MCP server based on the Streamable HTTP transport.
 

@@ -573,12 +618,18 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
        )
 
         self.params = params
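Taken together, the server.py changes add opt-in retry support: list_tools() and call_tool() are now routed through _run_with_retries, which retries on any exception with exponential backoff (retry_backoff_seconds_base * 2 ** (attempts - 1)), and a max_retry_attempts of -1 retries indefinitely. Below is a minimal usage sketch of the new constructor parameters; the server command and script name are placeholders, not taken from this diff:

import asyncio

from agents.mcp import MCPServerStdio


async def main() -> None:
    server = MCPServerStdio(
        params={"command": "python", "args": ["my_mcp_server.py"]},  # placeholder server
        cache_tools_list=True,
        max_retry_attempts=3,  # default 0 = no retries; -1 = retry forever
        retry_backoff_seconds_base=1.0,  # waits 1s, 2s, 4s between the 3 retries
    )
    await server.connect()
    try:
        # Both of these calls now go through _run_with_retries internally.
        tools = await server.list_tools()
        print([tool.name for tool in tools])
    finally:
        await server.cleanup()


asyncio.run(main())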
agents/memory/__init__.py CHANGED

@@ -1,3 +1,10 @@
-from .session import Session, SessionABC, SQLiteSession
+from .openai_conversations_session import OpenAIConversationsSession
+from .session import Session, SessionABC
+from .sqlite_session import SQLiteSession
 
-__all__ = ["Session", "SessionABC", "SQLiteSession"]
+__all__ = [
+    "Session",
+    "SessionABC",
+    "SQLiteSession",
+    "OpenAIConversationsSession",
+]

agents/memory/openai_conversations_session.py ADDED

@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from openai import AsyncOpenAI
+
+from agents.models._openai_shared import get_default_openai_client
+
+from ..items import TResponseInputItem
+from .session import SessionABC
+
+
+async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str:
+    _maybe_openai_client = openai_client
+    if openai_client is None:
+        _maybe_openai_client = get_default_openai_client() or AsyncOpenAI()
+    # this never be None here
+    _openai_client: AsyncOpenAI = _maybe_openai_client  # type: ignore [assignment]
+
+    response = await _openai_client.conversations.create(items=[])
+    return response.id
+
+
+_EMPTY_SESSION_ID = ""
+
+
+class OpenAIConversationsSession(SessionABC):
+    def __init__(
+        self,
+        *,
+        conversation_id: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+    ):
+        self._session_id: str | None = conversation_id
+        _openai_client = openai_client
+        if _openai_client is None:
+            _openai_client = get_default_openai_client() or AsyncOpenAI()
+        # this never be None here
+        self._openai_client: AsyncOpenAI = _openai_client
+
+    async def _get_session_id(self) -> str:
+        if self._session_id is None:
+            self._session_id = await start_openai_conversations_session(self._openai_client)
+        return self._session_id
+
+    async def _clear_session_id(self) -> None:
+        self._session_id = None
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        session_id = await self._get_session_id()
+        all_items = []
+        if limit is None:
+            async for item in self._openai_client.conversations.items.list(
+                conversation_id=session_id,
+                order="asc",
+            ):
+                # calling model_dump() to make this serializable
+                all_items.append(item.model_dump())
+        else:
+            async for item in self._openai_client.conversations.items.list(
+                conversation_id=session_id,
+                limit=limit,
+                order="desc",
+            ):
+                # calling model_dump() to make this serializable
+                all_items.append(item.model_dump())
+                if limit is not None and len(all_items) >= limit:
+                    break
+            all_items.reverse()
+
+        return all_items  # type: ignore
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        session_id = await self._get_session_id()
+        await self._openai_client.conversations.items.create(
+            conversation_id=session_id,
+            items=items,
+        )
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        session_id = await self._get_session_id()
+        items = await self.get_items(limit=1)
+        if not items:
+            return None
+        item_id: str = str(items[0]["id"])  # type: ignore [typeddict-item]
+        await self._openai_client.conversations.items.delete(
+            conversation_id=session_id, item_id=item_id
+        )
+        return items[0]
+
+    async def clear_session(self) -> None:
+        session_id = await self._get_session_id()
+        await self._openai_client.conversations.delete(
+            conversation_id=session_id,
+        )
+        await self._clear_session_id()
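The new OpenAIConversationsSession stores history in the hosted OpenAI Conversations API rather than a local database, creating a conversation lazily on first use when no conversation_id is given. Here is a short sketch of how it plugs into a run; the agent name and prompts are illustrative:

import asyncio

from agents import Agent, Runner
from agents.memory import OpenAIConversationsSession


async def main() -> None:
    # No conversation_id: a hosted conversation is created on first use.
    session = OpenAIConversationsSession()
    agent = Agent(name="Assistant", instructions="Reply concisely.")

    result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session)
    print(result.final_output)

    # The follow-up resolves "it" from items stored in the hosted conversation.
    result = await Runner.run(agent, "What state is it in?", session=session)
    print(result.final_output)


asyncio.run(main())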
agents/memory/session.py CHANGED

@@ -1,11 +1,6 @@
 from __future__ import annotations
 
-import asyncio
-import json
-import sqlite3
-import threading
 from abc import ABC, abstractmethod
-from pathlib import Path
 from typing import TYPE_CHECKING, Protocol, runtime_checkable
 
 if TYPE_CHECKING:

@@ -102,268 +97,3 @@ class SessionABC(ABC):
     async def clear_session(self) -> None:
         """Clear all items for this session."""
         ...
-
-
-class SQLiteSession(SessionABC):
-    """SQLite-based implementation of session storage.
-
-    This implementation stores conversation history in a SQLite database.
-    By default, uses an in-memory database that is lost when the process ends.
-    For persistent storage, provide a file path.
-    """
-
-    def __init__(
-        self,
-        session_id: str,
-        db_path: str | Path = ":memory:",
-        sessions_table: str = "agent_sessions",
-        messages_table: str = "agent_messages",
-    ):
-        """Initialize the SQLite session.
-
-        Args:
-            session_id: Unique identifier for the conversation session
-            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
-            sessions_table: Name of the table to store session metadata. Defaults to
-                'agent_sessions'
-            messages_table: Name of the table to store message data. Defaults to 'agent_messages'
-        """
-        self.session_id = session_id
-        self.db_path = db_path
-        self.sessions_table = sessions_table
-        self.messages_table = messages_table
-        self._local = threading.local()
-        self._lock = threading.Lock()
-
-        # For in-memory databases, we need a shared connection to avoid thread isolation
-        # For file databases, we use thread-local connections for better concurrency
-        self._is_memory_db = str(db_path) == ":memory:"
-        if self._is_memory_db:
-            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
-            self._shared_connection.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(self._shared_connection)
-        else:
-            # For file databases, initialize the schema once since it persists
-            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
-            init_conn.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(init_conn)
-            init_conn.close()
-
-    def _get_connection(self) -> sqlite3.Connection:
-        """Get a database connection."""
-        if self._is_memory_db:
-            # Use shared connection for in-memory database to avoid thread isolation
-            return self._shared_connection
-        else:
-            # Use thread-local connections for file databases
-            if not hasattr(self._local, "connection"):
-                self._local.connection = sqlite3.connect(
-                    str(self.db_path),
-                    check_same_thread=False,
-                )
-                self._local.connection.execute("PRAGMA journal_mode=WAL")
-            assert isinstance(self._local.connection, sqlite3.Connection), (
-                f"Expected sqlite3.Connection, got {type(self._local.connection)}"
-            )
-            return self._local.connection
-
-    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
-        """Initialize the database schema for a specific connection."""
-        conn.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
-                session_id TEXT PRIMARY KEY,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            )
-        """
-        )
-
-        conn.execute(
-            f"""
-            CREATE TABLE IF NOT EXISTS {self.messages_table} (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                session_id TEXT NOT NULL,
-                message_data TEXT NOT NULL,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
-                    ON DELETE CASCADE
-            )
-        """
-        )
-
-        conn.execute(
-            f"""
-            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
-            ON {self.messages_table} (session_id, created_at)
-        """
-        )
-
-        conn.commit()
-
-    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
-        """Retrieve the conversation history for this session.
-
-        Args:
-            limit: Maximum number of items to retrieve. If None, retrieves all items.
-                When specified, returns the latest N items in chronological order.
-
-        Returns:
-            List of input items representing the conversation history
-        """
-
-        def _get_items_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                if limit is None:
-                    # Fetch all items in chronological order
-                    cursor = conn.execute(
-                        f"""
-                        SELECT message_data FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at ASC
-                    """,
-                        (self.session_id,),
-                    )
-                else:
-                    # Fetch the latest N items in chronological order
-                    cursor = conn.execute(
-                        f"""
-                        SELECT message_data FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at DESC
-                        LIMIT ?
-                    """,
-                        (self.session_id, limit),
-                    )
-
-                rows = cursor.fetchall()
-
-            # Reverse to get chronological order when using DESC
-            if limit is not None:
-                rows = list(reversed(rows))
-
-            items = []
-            for (message_data,) in rows:
-                try:
-                    item = json.loads(message_data)
-                    items.append(item)
-                except json.JSONDecodeError:
-                    # Skip invalid JSON entries
-                    continue
-
-            return items
-
-        return await asyncio.to_thread(_get_items_sync)
-
-    async def add_items(self, items: list[TResponseInputItem]) -> None:
-        """Add new items to the conversation history.
-
-        Args:
-            items: List of input items to add to the history
-        """
-        if not items:
-            return
-
-        def _add_items_sync():
-            conn = self._get_connection()
-
-            with self._lock if self._is_memory_db else threading.Lock():
-                # Ensure session exists
-                conn.execute(
-                    f"""
-                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
-                """,
-                    (self.session_id,),
-                )
-
-                # Add items
-                message_data = [(self.session_id, json.dumps(item)) for item in items]
-                conn.executemany(
-                    f"""
-                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
-                """,
-                    message_data,
-                )
-
-                # Update session timestamp
-                conn.execute(
-                    f"""
-                    UPDATE {self.sessions_table}
-                    SET updated_at = CURRENT_TIMESTAMP
-                    WHERE session_id = ?
-                """,
-                    (self.session_id,),
-                )
-
-                conn.commit()
-
-        await asyncio.to_thread(_add_items_sync)
-
-    async def pop_item(self) -> TResponseInputItem | None:
-        """Remove and return the most recent item from the session.
-
-        Returns:
-            The most recent item if it exists, None if the session is empty
-        """
-
-        def _pop_item_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                # Use DELETE with RETURNING to atomically delete and return the most recent item
-                cursor = conn.execute(
-                    f"""
-                    DELETE FROM {self.messages_table}
-                    WHERE id = (
-                        SELECT id FROM {self.messages_table}
-                        WHERE session_id = ?
-                        ORDER BY created_at DESC
-                        LIMIT 1
-                    )
-                    RETURNING message_data
-                """,
-                    (self.session_id,),
-                )
-
-                result = cursor.fetchone()
-                conn.commit()
-
-                if result:
-                    message_data = result[0]
-                    try:
-                        item = json.loads(message_data)
-                        return item
-                    except json.JSONDecodeError:
-                        # Return None for corrupted JSON entries (already deleted)
-                        return None
-
-                return None
-
-        return await asyncio.to_thread(_pop_item_sync)
-
-    async def clear_session(self) -> None:
-        """Clear all items for this session."""
-
-        def _clear_session_sync():
-            conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
-                conn.execute(
-                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
-                    (self.session_id,),
-                )
-                conn.execute(
-                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
-                    (self.session_id,),
-                )
-                conn.commit()
-
-        await asyncio.to_thread(_clear_session_sync)
-
-    def close(self) -> None:
-        """Close the database connection."""
-        if self._is_memory_db:
-            if hasattr(self, "_shared_connection"):
-                self._shared_connection.close()
-        else:
-            if hasattr(self._local, "connection"):
-                self._local.connection.close()