openai-agents 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +3 -1
- agents/_run_impl.py +40 -6
- agents/extensions/memory/sqlalchemy_session.py +45 -31
- agents/extensions/models/litellm_model.py +7 -4
- agents/handoffs.py +3 -3
- agents/memory/__init__.py +9 -2
- agents/memory/openai_conversations_session.py +94 -0
- agents/memory/session.py +0 -270
- agents/memory/sqlite_session.py +275 -0
- agents/model_settings.py +4 -2
- agents/models/chatcmpl_stream_handler.py +81 -17
- agents/models/interface.py +4 -0
- agents/models/openai_chatcompletions.py +4 -2
- agents/models/openai_responses.py +24 -10
- agents/realtime/openai_realtime.py +14 -3
- agents/run.py +110 -7
- agents/tool.py +4 -0
- agents/tracing/processors.py +2 -2
- {openai_agents-0.2.9.dist-info → openai_agents-0.2.10.dist-info}/METADATA +2 -2
- {openai_agents-0.2.9.dist-info → openai_agents-0.2.10.dist-info}/RECORD +22 -20
- {openai_agents-0.2.9.dist-info → openai_agents-0.2.10.dist-info}/WHEEL +0 -0
- {openai_agents-0.2.9.dist-info → openai_agents-0.2.10.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
|
@@ -46,7 +46,7 @@ from .items import (
|
|
|
46
46
|
TResponseInputItem,
|
|
47
47
|
)
|
|
48
48
|
from .lifecycle import AgentHooks, RunHooks
|
|
49
|
-
from .memory import Session, SQLiteSession
|
|
49
|
+
from .memory import OpenAIConversationsSession, Session, SessionABC, SQLiteSession
|
|
50
50
|
from .model_settings import ModelSettings
|
|
51
51
|
from .models.interface import Model, ModelProvider, ModelTracing
|
|
52
52
|
from .models.multi_provider import MultiProvider
|
|
@@ -221,7 +221,9 @@ __all__ = [
|
|
|
221
221
|
"RunHooks",
|
|
222
222
|
"AgentHooks",
|
|
223
223
|
"Session",
|
|
224
|
+
"SessionABC",
|
|
224
225
|
"SQLiteSession",
|
|
226
|
+
"OpenAIConversationsSession",
|
|
225
227
|
"RunContextWrapper",
|
|
226
228
|
"TContext",
|
|
227
229
|
"RunErrorDetails",
|
agents/_run_impl.py
CHANGED
|
@@ -509,13 +509,29 @@ class RunImpl:
|
|
|
509
509
|
# Regular function tool call
|
|
510
510
|
else:
|
|
511
511
|
if output.name not in function_map:
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
512
|
+
if output_schema is not None and output.name == "json_tool_call":
|
|
513
|
+
# LiteLLM could generate non-existent tool calls for structured outputs
|
|
514
|
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
515
|
+
functions.append(
|
|
516
|
+
ToolRunFunction(
|
|
517
|
+
tool_call=output,
|
|
518
|
+
# this tool does not exist in function_map, so generate ad-hoc one,
|
|
519
|
+
# which just parses the input if it's a string, and returns the
|
|
520
|
+
# value otherwise
|
|
521
|
+
function_tool=_build_litellm_json_tool_call(output),
|
|
522
|
+
)
|
|
516
523
|
)
|
|
517
|
-
|
|
518
|
-
|
|
524
|
+
continue
|
|
525
|
+
else:
|
|
526
|
+
_error_tracing.attach_error_to_current_span(
|
|
527
|
+
SpanError(
|
|
528
|
+
message="Tool not found",
|
|
529
|
+
data={"tool_name": output.name},
|
|
530
|
+
)
|
|
531
|
+
)
|
|
532
|
+
error = f"Tool {output.name} not found in agent {agent.name}"
|
|
533
|
+
raise ModelBehaviorError(error)
|
|
534
|
+
|
|
519
535
|
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
520
536
|
functions.append(
|
|
521
537
|
ToolRunFunction(
|
|
@@ -1193,3 +1209,21 @@ class LocalShellAction:
|
|
|
1193
1209
|
# "id": "out" + call.tool_call.id, # TODO remove this, it should be optional
|
|
1194
1210
|
},
|
|
1195
1211
|
)
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def _build_litellm_json_tool_call(output: ResponseFunctionToolCall) -> FunctionTool:
|
|
1215
|
+
async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any:
|
|
1216
|
+
if isinstance(value, str):
|
|
1217
|
+
import json
|
|
1218
|
+
|
|
1219
|
+
return json.loads(value)
|
|
1220
|
+
return value
|
|
1221
|
+
|
|
1222
|
+
return FunctionTool(
|
|
1223
|
+
name=output.name,
|
|
1224
|
+
description=output.name,
|
|
1225
|
+
params_json_schema={},
|
|
1226
|
+
on_invoke_tool=on_invoke_tool,
|
|
1227
|
+
strict_json_schema=True,
|
|
1228
|
+
is_enabled=True,
|
|
1229
|
+
)
|
|
@@ -64,23 +64,19 @@ class SQLAlchemySession(SessionABC):
|
|
|
64
64
|
create_tables: bool = False,
|
|
65
65
|
sessions_table: str = "agent_sessions",
|
|
66
66
|
messages_table: str = "agent_messages",
|
|
67
|
-
):
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
Defaults to *False* for production use. Set to *True* for development
|
|
81
|
-
and testing when migrations aren't used.
|
|
82
|
-
sessions_table, messages_table
|
|
83
|
-
Override default table names if needed.
|
|
67
|
+
):
|
|
68
|
+
"""Initializes a new SQLAlchemySession.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
session_id (str): Unique identifier for the conversation.
|
|
72
|
+
engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
|
|
73
|
+
must be created with an async driver (e.g., 'postgresql+asyncpg://',
|
|
74
|
+
'mysql+aiomysql://', or 'sqlite+aiosqlite://').
|
|
75
|
+
create_tables (bool, optional): Whether to automatically create the required
|
|
76
|
+
tables and indexes. Defaults to False for production use. Set to True for
|
|
77
|
+
development and testing when migrations aren't used.
|
|
78
|
+
sessions_table (str, optional): Override the default table name for sessions if needed.
|
|
79
|
+
messages_table (str, optional): Override the default table name for messages if needed.
|
|
84
80
|
"""
|
|
85
81
|
self.session_id = session_id
|
|
86
82
|
self._engine = engine
|
|
@@ -132,9 +128,7 @@ class SQLAlchemySession(SessionABC):
|
|
|
132
128
|
)
|
|
133
129
|
|
|
134
130
|
# Async session factory
|
|
135
|
-
self._session_factory = async_sessionmaker(
|
|
136
|
-
self._engine, expire_on_commit=False
|
|
137
|
-
)
|
|
131
|
+
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
|
|
138
132
|
|
|
139
133
|
self._create_tables = create_tables
|
|
140
134
|
|
|
@@ -152,16 +146,16 @@ class SQLAlchemySession(SessionABC):
|
|
|
152
146
|
) -> SQLAlchemySession:
|
|
153
147
|
"""Create a session from a database URL string.
|
|
154
148
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
149
|
+
Args:
|
|
150
|
+
session_id (str): Conversation ID.
|
|
151
|
+
url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
|
|
152
|
+
engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
|
|
153
|
+
sqlalchemy.ext.asyncio.create_async_engine.
|
|
154
|
+
**kwargs: Additional keyword arguments forwarded to the main constructor
|
|
155
|
+
(e.g., create_tables, custom table names, etc.).
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
SQLAlchemySession: An instance of SQLAlchemySession connected to the specified database.
|
|
165
159
|
"""
|
|
166
160
|
engine_kwargs = engine_kwargs or {}
|
|
167
161
|
engine = create_async_engine(url, **engine_kwargs)
|
|
@@ -186,6 +180,15 @@ class SQLAlchemySession(SessionABC):
|
|
|
186
180
|
self._create_tables = False # Only create once
|
|
187
181
|
|
|
188
182
|
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
|
|
183
|
+
"""Retrieve the conversation history for this session.
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
limit: Maximum number of items to retrieve. If None, retrieves all items.
|
|
187
|
+
When specified, returns the latest N items in chronological order.
|
|
188
|
+
|
|
189
|
+
Returns:
|
|
190
|
+
List of input items representing the conversation history
|
|
191
|
+
"""
|
|
189
192
|
await self._ensure_tables()
|
|
190
193
|
async with self._session_factory() as sess:
|
|
191
194
|
if limit is None:
|
|
@@ -220,6 +223,11 @@ class SQLAlchemySession(SessionABC):
|
|
|
220
223
|
return items
|
|
221
224
|
|
|
222
225
|
async def add_items(self, items: list[TResponseInputItem]) -> None:
|
|
226
|
+
"""Add new items to the conversation history.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
items: List of input items to add to the history
|
|
230
|
+
"""
|
|
223
231
|
if not items:
|
|
224
232
|
return
|
|
225
233
|
|
|
@@ -258,6 +266,11 @@ class SQLAlchemySession(SessionABC):
|
|
|
258
266
|
)
|
|
259
267
|
|
|
260
268
|
async def pop_item(self) -> TResponseInputItem | None:
|
|
269
|
+
"""Remove and return the most recent item from the session.
|
|
270
|
+
|
|
271
|
+
Returns:
|
|
272
|
+
The most recent item if it exists, None if the session is empty
|
|
273
|
+
"""
|
|
261
274
|
await self._ensure_tables()
|
|
262
275
|
async with self._session_factory() as sess:
|
|
263
276
|
async with sess.begin():
|
|
@@ -286,7 +299,8 @@ class SQLAlchemySession(SessionABC):
|
|
|
286
299
|
except json.JSONDecodeError:
|
|
287
300
|
return None
|
|
288
301
|
|
|
289
|
-
async def clear_session(self) -> None:
|
|
302
|
+
async def clear_session(self) -> None:
|
|
303
|
+
"""Clear all items for this session."""
|
|
290
304
|
await self._ensure_tables()
|
|
291
305
|
async with self._session_factory() as sess:
|
|
292
306
|
async with sess.begin():
|
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import time
|
|
5
5
|
from collections.abc import AsyncIterator
|
|
6
|
+
from copy import copy
|
|
6
7
|
from typing import Any, Literal, cast, overload
|
|
7
8
|
|
|
8
9
|
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
|
@@ -82,7 +83,8 @@ class LitellmModel(Model):
|
|
|
82
83
|
output_schema: AgentOutputSchemaBase | None,
|
|
83
84
|
handoffs: list[Handoff],
|
|
84
85
|
tracing: ModelTracing,
|
|
85
|
-
previous_response_id: str | None,
|
|
86
|
+
previous_response_id: str | None = None, # unused
|
|
87
|
+
conversation_id: str | None = None, # unused
|
|
86
88
|
prompt: Any | None = None,
|
|
87
89
|
) -> ModelResponse:
|
|
88
90
|
with generation_span(
|
|
@@ -171,7 +173,8 @@ class LitellmModel(Model):
|
|
|
171
173
|
output_schema: AgentOutputSchemaBase | None,
|
|
172
174
|
handoffs: list[Handoff],
|
|
173
175
|
tracing: ModelTracing,
|
|
174
|
-
previous_response_id: str | None,
|
|
176
|
+
previous_response_id: str | None = None, # unused
|
|
177
|
+
conversation_id: str | None = None, # unused
|
|
175
178
|
prompt: Any | None = None,
|
|
176
179
|
) -> AsyncIterator[TResponseStreamEvent]:
|
|
177
180
|
with generation_span(
|
|
@@ -300,9 +303,9 @@ class LitellmModel(Model):
|
|
|
300
303
|
|
|
301
304
|
extra_kwargs = {}
|
|
302
305
|
if model_settings.extra_query:
|
|
303
|
-
extra_kwargs["extra_query"] = model_settings.extra_query
|
|
306
|
+
extra_kwargs["extra_query"] = copy(model_settings.extra_query)
|
|
304
307
|
if model_settings.metadata:
|
|
305
|
-
extra_kwargs["metadata"] = model_settings.metadata
|
|
308
|
+
extra_kwargs["metadata"] = copy(model_settings.metadata)
|
|
306
309
|
if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
|
|
307
310
|
extra_kwargs.update(model_settings.extra_body)
|
|
308
311
|
|
agents/handoffs.py
CHANGED
|
@@ -119,9 +119,9 @@ class Handoff(Generic[TContext, TAgent]):
|
|
|
119
119
|
True, as it increases the likelihood of correct JSON input.
|
|
120
120
|
"""
|
|
121
121
|
|
|
122
|
-
is_enabled: bool | Callable[
|
|
123
|
-
|
|
124
|
-
|
|
122
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
|
|
123
|
+
True
|
|
124
|
+
)
|
|
125
125
|
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
|
|
126
126
|
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
|
|
127
127
|
a handoff based on your context/state."""
|
agents/memory/__init__.py
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .openai_conversations_session import OpenAIConversationsSession
|
|
2
|
+
from .session import Session, SessionABC
|
|
3
|
+
from .sqlite_session import SQLiteSession
|
|
2
4
|
|
|
3
|
-
__all__ = [
|
|
5
|
+
__all__ = [
|
|
6
|
+
"Session",
|
|
7
|
+
"SessionABC",
|
|
8
|
+
"SQLiteSession",
|
|
9
|
+
"OpenAIConversationsSession",
|
|
10
|
+
]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from openai import AsyncOpenAI
|
|
4
|
+
|
|
5
|
+
from agents.models._openai_shared import get_default_openai_client
|
|
6
|
+
|
|
7
|
+
from ..items import TResponseInputItem
|
|
8
|
+
from .session import SessionABC
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = None) -> str:
|
|
12
|
+
_maybe_openai_client = openai_client
|
|
13
|
+
if openai_client is None:
|
|
14
|
+
_maybe_openai_client = get_default_openai_client() or AsyncOpenAI()
|
|
15
|
+
# this never be None here
|
|
16
|
+
_openai_client: AsyncOpenAI = _maybe_openai_client # type: ignore [assignment]
|
|
17
|
+
|
|
18
|
+
response = await _openai_client.conversations.create(items=[])
|
|
19
|
+
return response.id
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
_EMPTY_SESSION_ID = ""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OpenAIConversationsSession(SessionABC):
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
*,
|
|
29
|
+
conversation_id: str | None = None,
|
|
30
|
+
openai_client: AsyncOpenAI | None = None,
|
|
31
|
+
):
|
|
32
|
+
self._session_id: str | None = conversation_id
|
|
33
|
+
_openai_client = openai_client
|
|
34
|
+
if _openai_client is None:
|
|
35
|
+
_openai_client = get_default_openai_client() or AsyncOpenAI()
|
|
36
|
+
# this never be None here
|
|
37
|
+
self._openai_client: AsyncOpenAI = _openai_client
|
|
38
|
+
|
|
39
|
+
async def _get_session_id(self) -> str:
|
|
40
|
+
if self._session_id is None:
|
|
41
|
+
self._session_id = await start_openai_conversations_session(self._openai_client)
|
|
42
|
+
return self._session_id
|
|
43
|
+
|
|
44
|
+
async def _clear_session_id(self) -> None:
|
|
45
|
+
self._session_id = None
|
|
46
|
+
|
|
47
|
+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
|
|
48
|
+
session_id = await self._get_session_id()
|
|
49
|
+
all_items = []
|
|
50
|
+
if limit is None:
|
|
51
|
+
async for item in self._openai_client.conversations.items.list(
|
|
52
|
+
conversation_id=session_id,
|
|
53
|
+
order="asc",
|
|
54
|
+
):
|
|
55
|
+
# calling model_dump() to make this serializable
|
|
56
|
+
all_items.append(item.model_dump())
|
|
57
|
+
else:
|
|
58
|
+
async for item in self._openai_client.conversations.items.list(
|
|
59
|
+
conversation_id=session_id,
|
|
60
|
+
limit=limit,
|
|
61
|
+
order="desc",
|
|
62
|
+
):
|
|
63
|
+
# calling model_dump() to make this serializable
|
|
64
|
+
all_items.append(item.model_dump())
|
|
65
|
+
if limit is not None and len(all_items) >= limit:
|
|
66
|
+
break
|
|
67
|
+
all_items.reverse()
|
|
68
|
+
|
|
69
|
+
return all_items # type: ignore
|
|
70
|
+
|
|
71
|
+
async def add_items(self, items: list[TResponseInputItem]) -> None:
|
|
72
|
+
session_id = await self._get_session_id()
|
|
73
|
+
await self._openai_client.conversations.items.create(
|
|
74
|
+
conversation_id=session_id,
|
|
75
|
+
items=items,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
async def pop_item(self) -> TResponseInputItem | None:
|
|
79
|
+
session_id = await self._get_session_id()
|
|
80
|
+
items = await self.get_items(limit=1)
|
|
81
|
+
if not items:
|
|
82
|
+
return None
|
|
83
|
+
item_id: str = str(items[0]["id"]) # type: ignore [typeddict-item]
|
|
84
|
+
await self._openai_client.conversations.items.delete(
|
|
85
|
+
conversation_id=session_id, item_id=item_id
|
|
86
|
+
)
|
|
87
|
+
return items[0]
|
|
88
|
+
|
|
89
|
+
async def clear_session(self) -> None:
|
|
90
|
+
session_id = await self._get_session_id()
|
|
91
|
+
await self._openai_client.conversations.delete(
|
|
92
|
+
conversation_id=session_id,
|
|
93
|
+
)
|
|
94
|
+
await self._clear_session_id()
|
agents/memory/session.py
CHANGED
|
@@ -1,11 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
|
-
import json
|
|
5
|
-
import sqlite3
|
|
6
|
-
import threading
|
|
7
3
|
from abc import ABC, abstractmethod
|
|
8
|
-
from pathlib import Path
|
|
9
4
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
10
5
|
|
|
11
6
|
if TYPE_CHECKING:
|
|
@@ -102,268 +97,3 @@ class SessionABC(ABC):
|
|
|
102
97
|
async def clear_session(self) -> None:
|
|
103
98
|
"""Clear all items for this session."""
|
|
104
99
|
...
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
class SQLiteSession(SessionABC):
|
|
108
|
-
"""SQLite-based implementation of session storage.
|
|
109
|
-
|
|
110
|
-
This implementation stores conversation history in a SQLite database.
|
|
111
|
-
By default, uses an in-memory database that is lost when the process ends.
|
|
112
|
-
For persistent storage, provide a file path.
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
def __init__(
|
|
116
|
-
self,
|
|
117
|
-
session_id: str,
|
|
118
|
-
db_path: str | Path = ":memory:",
|
|
119
|
-
sessions_table: str = "agent_sessions",
|
|
120
|
-
messages_table: str = "agent_messages",
|
|
121
|
-
):
|
|
122
|
-
"""Initialize the SQLite session.
|
|
123
|
-
|
|
124
|
-
Args:
|
|
125
|
-
session_id: Unique identifier for the conversation session
|
|
126
|
-
db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
|
|
127
|
-
sessions_table: Name of the table to store session metadata. Defaults to
|
|
128
|
-
'agent_sessions'
|
|
129
|
-
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
|
|
130
|
-
"""
|
|
131
|
-
self.session_id = session_id
|
|
132
|
-
self.db_path = db_path
|
|
133
|
-
self.sessions_table = sessions_table
|
|
134
|
-
self.messages_table = messages_table
|
|
135
|
-
self._local = threading.local()
|
|
136
|
-
self._lock = threading.Lock()
|
|
137
|
-
|
|
138
|
-
# For in-memory databases, we need a shared connection to avoid thread isolation
|
|
139
|
-
# For file databases, we use thread-local connections for better concurrency
|
|
140
|
-
self._is_memory_db = str(db_path) == ":memory:"
|
|
141
|
-
if self._is_memory_db:
|
|
142
|
-
self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
|
|
143
|
-
self._shared_connection.execute("PRAGMA journal_mode=WAL")
|
|
144
|
-
self._init_db_for_connection(self._shared_connection)
|
|
145
|
-
else:
|
|
146
|
-
# For file databases, initialize the schema once since it persists
|
|
147
|
-
init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
|
148
|
-
init_conn.execute("PRAGMA journal_mode=WAL")
|
|
149
|
-
self._init_db_for_connection(init_conn)
|
|
150
|
-
init_conn.close()
|
|
151
|
-
|
|
152
|
-
def _get_connection(self) -> sqlite3.Connection:
|
|
153
|
-
"""Get a database connection."""
|
|
154
|
-
if self._is_memory_db:
|
|
155
|
-
# Use shared connection for in-memory database to avoid thread isolation
|
|
156
|
-
return self._shared_connection
|
|
157
|
-
else:
|
|
158
|
-
# Use thread-local connections for file databases
|
|
159
|
-
if not hasattr(self._local, "connection"):
|
|
160
|
-
self._local.connection = sqlite3.connect(
|
|
161
|
-
str(self.db_path),
|
|
162
|
-
check_same_thread=False,
|
|
163
|
-
)
|
|
164
|
-
self._local.connection.execute("PRAGMA journal_mode=WAL")
|
|
165
|
-
assert isinstance(self._local.connection, sqlite3.Connection), (
|
|
166
|
-
f"Expected sqlite3.Connection, got {type(self._local.connection)}"
|
|
167
|
-
)
|
|
168
|
-
return self._local.connection
|
|
169
|
-
|
|
170
|
-
def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
|
|
171
|
-
"""Initialize the database schema for a specific connection."""
|
|
172
|
-
conn.execute(
|
|
173
|
-
f"""
|
|
174
|
-
CREATE TABLE IF NOT EXISTS {self.sessions_table} (
|
|
175
|
-
session_id TEXT PRIMARY KEY,
|
|
176
|
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
177
|
-
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
178
|
-
)
|
|
179
|
-
"""
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
conn.execute(
|
|
183
|
-
f"""
|
|
184
|
-
CREATE TABLE IF NOT EXISTS {self.messages_table} (
|
|
185
|
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
186
|
-
session_id TEXT NOT NULL,
|
|
187
|
-
message_data TEXT NOT NULL,
|
|
188
|
-
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
189
|
-
FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
|
|
190
|
-
ON DELETE CASCADE
|
|
191
|
-
)
|
|
192
|
-
"""
|
|
193
|
-
)
|
|
194
|
-
|
|
195
|
-
conn.execute(
|
|
196
|
-
f"""
|
|
197
|
-
CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
|
|
198
|
-
ON {self.messages_table} (session_id, created_at)
|
|
199
|
-
"""
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
conn.commit()
|
|
203
|
-
|
|
204
|
-
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
|
|
205
|
-
"""Retrieve the conversation history for this session.
|
|
206
|
-
|
|
207
|
-
Args:
|
|
208
|
-
limit: Maximum number of items to retrieve. If None, retrieves all items.
|
|
209
|
-
When specified, returns the latest N items in chronological order.
|
|
210
|
-
|
|
211
|
-
Returns:
|
|
212
|
-
List of input items representing the conversation history
|
|
213
|
-
"""
|
|
214
|
-
|
|
215
|
-
def _get_items_sync():
|
|
216
|
-
conn = self._get_connection()
|
|
217
|
-
with self._lock if self._is_memory_db else threading.Lock():
|
|
218
|
-
if limit is None:
|
|
219
|
-
# Fetch all items in chronological order
|
|
220
|
-
cursor = conn.execute(
|
|
221
|
-
f"""
|
|
222
|
-
SELECT message_data FROM {self.messages_table}
|
|
223
|
-
WHERE session_id = ?
|
|
224
|
-
ORDER BY created_at ASC
|
|
225
|
-
""",
|
|
226
|
-
(self.session_id,),
|
|
227
|
-
)
|
|
228
|
-
else:
|
|
229
|
-
# Fetch the latest N items in chronological order
|
|
230
|
-
cursor = conn.execute(
|
|
231
|
-
f"""
|
|
232
|
-
SELECT message_data FROM {self.messages_table}
|
|
233
|
-
WHERE session_id = ?
|
|
234
|
-
ORDER BY created_at DESC
|
|
235
|
-
LIMIT ?
|
|
236
|
-
""",
|
|
237
|
-
(self.session_id, limit),
|
|
238
|
-
)
|
|
239
|
-
|
|
240
|
-
rows = cursor.fetchall()
|
|
241
|
-
|
|
242
|
-
# Reverse to get chronological order when using DESC
|
|
243
|
-
if limit is not None:
|
|
244
|
-
rows = list(reversed(rows))
|
|
245
|
-
|
|
246
|
-
items = []
|
|
247
|
-
for (message_data,) in rows:
|
|
248
|
-
try:
|
|
249
|
-
item = json.loads(message_data)
|
|
250
|
-
items.append(item)
|
|
251
|
-
except json.JSONDecodeError:
|
|
252
|
-
# Skip invalid JSON entries
|
|
253
|
-
continue
|
|
254
|
-
|
|
255
|
-
return items
|
|
256
|
-
|
|
257
|
-
return await asyncio.to_thread(_get_items_sync)
|
|
258
|
-
|
|
259
|
-
async def add_items(self, items: list[TResponseInputItem]) -> None:
|
|
260
|
-
"""Add new items to the conversation history.
|
|
261
|
-
|
|
262
|
-
Args:
|
|
263
|
-
items: List of input items to add to the history
|
|
264
|
-
"""
|
|
265
|
-
if not items:
|
|
266
|
-
return
|
|
267
|
-
|
|
268
|
-
def _add_items_sync():
|
|
269
|
-
conn = self._get_connection()
|
|
270
|
-
|
|
271
|
-
with self._lock if self._is_memory_db else threading.Lock():
|
|
272
|
-
# Ensure session exists
|
|
273
|
-
conn.execute(
|
|
274
|
-
f"""
|
|
275
|
-
INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
|
|
276
|
-
""",
|
|
277
|
-
(self.session_id,),
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
# Add items
|
|
281
|
-
message_data = [(self.session_id, json.dumps(item)) for item in items]
|
|
282
|
-
conn.executemany(
|
|
283
|
-
f"""
|
|
284
|
-
INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
|
|
285
|
-
""",
|
|
286
|
-
message_data,
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
# Update session timestamp
|
|
290
|
-
conn.execute(
|
|
291
|
-
f"""
|
|
292
|
-
UPDATE {self.sessions_table}
|
|
293
|
-
SET updated_at = CURRENT_TIMESTAMP
|
|
294
|
-
WHERE session_id = ?
|
|
295
|
-
""",
|
|
296
|
-
(self.session_id,),
|
|
297
|
-
)
|
|
298
|
-
|
|
299
|
-
conn.commit()
|
|
300
|
-
|
|
301
|
-
await asyncio.to_thread(_add_items_sync)
|
|
302
|
-
|
|
303
|
-
async def pop_item(self) -> TResponseInputItem | None:
|
|
304
|
-
"""Remove and return the most recent item from the session.
|
|
305
|
-
|
|
306
|
-
Returns:
|
|
307
|
-
The most recent item if it exists, None if the session is empty
|
|
308
|
-
"""
|
|
309
|
-
|
|
310
|
-
def _pop_item_sync():
|
|
311
|
-
conn = self._get_connection()
|
|
312
|
-
with self._lock if self._is_memory_db else threading.Lock():
|
|
313
|
-
# Use DELETE with RETURNING to atomically delete and return the most recent item
|
|
314
|
-
cursor = conn.execute(
|
|
315
|
-
f"""
|
|
316
|
-
DELETE FROM {self.messages_table}
|
|
317
|
-
WHERE id = (
|
|
318
|
-
SELECT id FROM {self.messages_table}
|
|
319
|
-
WHERE session_id = ?
|
|
320
|
-
ORDER BY created_at DESC
|
|
321
|
-
LIMIT 1
|
|
322
|
-
)
|
|
323
|
-
RETURNING message_data
|
|
324
|
-
""",
|
|
325
|
-
(self.session_id,),
|
|
326
|
-
)
|
|
327
|
-
|
|
328
|
-
result = cursor.fetchone()
|
|
329
|
-
conn.commit()
|
|
330
|
-
|
|
331
|
-
if result:
|
|
332
|
-
message_data = result[0]
|
|
333
|
-
try:
|
|
334
|
-
item = json.loads(message_data)
|
|
335
|
-
return item
|
|
336
|
-
except json.JSONDecodeError:
|
|
337
|
-
# Return None for corrupted JSON entries (already deleted)
|
|
338
|
-
return None
|
|
339
|
-
|
|
340
|
-
return None
|
|
341
|
-
|
|
342
|
-
return await asyncio.to_thread(_pop_item_sync)
|
|
343
|
-
|
|
344
|
-
async def clear_session(self) -> None:
|
|
345
|
-
"""Clear all items for this session."""
|
|
346
|
-
|
|
347
|
-
def _clear_session_sync():
|
|
348
|
-
conn = self._get_connection()
|
|
349
|
-
with self._lock if self._is_memory_db else threading.Lock():
|
|
350
|
-
conn.execute(
|
|
351
|
-
f"DELETE FROM {self.messages_table} WHERE session_id = ?",
|
|
352
|
-
(self.session_id,),
|
|
353
|
-
)
|
|
354
|
-
conn.execute(
|
|
355
|
-
f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
|
|
356
|
-
(self.session_id,),
|
|
357
|
-
)
|
|
358
|
-
conn.commit()
|
|
359
|
-
|
|
360
|
-
await asyncio.to_thread(_clear_session_sync)
|
|
361
|
-
|
|
362
|
-
def close(self) -> None:
|
|
363
|
-
"""Close the database connection."""
|
|
364
|
-
if self._is_memory_db:
|
|
365
|
-
if hasattr(self, "_shared_connection"):
|
|
366
|
-
self._shared_connection.close()
|
|
367
|
-
else:
|
|
368
|
-
if hasattr(self._local, "connection"):
|
|
369
|
-
self._local.connection.close()
|