openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +62 -30
- agents/agent_output.py +2 -2
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/handoffs.py +32 -14
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +82 -11
- agents/mcp/util.py +16 -9
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +20 -3
- agents/models/chatcmpl_stream_handler.py +134 -43
- agents/models/openai_responses.py +12 -5
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +177 -0
- agents/realtime/agent.py +89 -0
- agents/realtime/config.py +188 -0
- agents/realtime/events.py +216 -0
- agents/realtime/handoffs.py +165 -0
- agents/realtime/items.py +184 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +670 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +535 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
agents/mcp/server.py
CHANGED
|
@@ -13,7 +13,7 @@ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_cli
|
|
|
13
13
|
from mcp.client.sse import sse_client
|
|
14
14
|
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
|
15
15
|
from mcp.shared.message import SessionMessage
|
|
16
|
-
from mcp.types import CallToolResult, InitializeResult
|
|
16
|
+
from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
|
|
17
17
|
from typing_extensions import NotRequired, TypedDict
|
|
18
18
|
|
|
19
19
|
from ..exceptions import UserError
|
|
@@ -22,12 +22,23 @@ from ..run_context import RunContextWrapper
|
|
|
22
22
|
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
|
|
23
23
|
|
|
24
24
|
if TYPE_CHECKING:
|
|
25
|
-
from ..agent import
|
|
25
|
+
from ..agent import AgentBase
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class MCPServer(abc.ABC):
|
|
29
29
|
"""Base class for Model Context Protocol servers."""
|
|
30
30
|
|
|
31
|
+
def __init__(self, use_structured_content: bool = False):
    """Initialize settings shared by every MCP server implementation.

    Args:
        use_structured_content: Whether to use the tool result's `structuredContent` when
            calling an MCP tool. Defaults to False for backwards compatibility - most MCP
            servers still include the structured content in the tool result's `content`,
            and using it by default will cause duplicate content. You can set this to True
            if you know the server will not duplicate the structured content in `content`.
    """
    # Read by the tool-invocation path to decide whether structured output is appended.
    self.use_structured_content = use_structured_content
|
|
41
|
+
|
|
31
42
|
@abc.abstractmethod
|
|
32
43
|
async def connect(self):
|
|
33
44
|
"""Connect to the server. For example, this might mean spawning a subprocess or
|
|
@@ -52,8 +63,8 @@ class MCPServer(abc.ABC):
|
|
|
52
63
|
@abc.abstractmethod
|
|
53
64
|
async def list_tools(
|
|
54
65
|
self,
|
|
55
|
-
run_context: RunContextWrapper[Any],
|
|
56
|
-
agent:
|
|
66
|
+
run_context: RunContextWrapper[Any] | None = None,
|
|
67
|
+
agent: AgentBase | None = None,
|
|
57
68
|
) -> list[MCPTool]:
|
|
58
69
|
"""List the tools available on the server."""
|
|
59
70
|
pass
|
|
@@ -63,6 +74,20 @@ class MCPServer(abc.ABC):
|
|
|
63
74
|
"""Invoke a tool on the server."""
|
|
64
75
|
pass
|
|
65
76
|
|
|
77
|
+
@abc.abstractmethod
async def list_prompts(self) -> ListPromptsResult:
    """Return every prompt the MCP server advertises.

    Concrete server classes must override this coroutine.
    """
|
|
83
|
+
|
|
84
|
+
@abc.abstractmethod
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
    """Fetch one prompt, identified by `name`, from the MCP server.

    Args:
        name: Identifier of the prompt to retrieve.
        arguments: Optional template arguments forwarded with the request.

    Concrete server classes must override this coroutine.
    """
|
|
90
|
+
|
|
66
91
|
|
|
67
92
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
68
93
|
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
|
@@ -72,6 +97,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
72
97
|
cache_tools_list: bool,
|
|
73
98
|
client_session_timeout_seconds: float | None,
|
|
74
99
|
tool_filter: ToolFilter = None,
|
|
100
|
+
use_structured_content: bool = False,
|
|
75
101
|
):
|
|
76
102
|
"""
|
|
77
103
|
Args:
|
|
@@ -84,7 +110,13 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
84
110
|
|
|
85
111
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
86
112
|
tool_filter: The tool filter to use for filtering tools.
|
|
113
|
+
use_structured_content: Whether to use `tool_result.structured_content` when calling an
|
|
114
|
+
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
|
|
115
|
+
include the structured content in the `tool_result.content`, and using it by
|
|
116
|
+
default will cause duplicate content. You can set this to True if you know the
|
|
117
|
+
server will not duplicate the structured content in the `tool_result.content`.
|
|
87
118
|
"""
|
|
119
|
+
super().__init__(use_structured_content=use_structured_content)
|
|
88
120
|
self.session: ClientSession | None = None
|
|
89
121
|
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
|
90
122
|
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
|
@@ -103,7 +135,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
103
135
|
self,
|
|
104
136
|
tools: list[MCPTool],
|
|
105
137
|
run_context: RunContextWrapper[Any],
|
|
106
|
-
agent:
|
|
138
|
+
agent: AgentBase,
|
|
107
139
|
) -> list[MCPTool]:
|
|
108
140
|
"""Apply the tool filter to the list of tools."""
|
|
109
141
|
if self.tool_filter is None:
|
|
@@ -118,9 +150,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
118
150
|
return await self._apply_dynamic_tool_filter(tools, run_context, agent)
|
|
119
151
|
|
|
120
152
|
def _apply_static_tool_filter(
|
|
121
|
-
self,
|
|
122
|
-
tools: list[MCPTool],
|
|
123
|
-
static_filter: ToolFilterStatic
|
|
153
|
+
self, tools: list[MCPTool], static_filter: ToolFilterStatic
|
|
124
154
|
) -> list[MCPTool]:
|
|
125
155
|
"""Apply static tool filtering based on allowlist and blocklist."""
|
|
126
156
|
filtered_tools = tools
|
|
@@ -141,7 +171,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
141
171
|
self,
|
|
142
172
|
tools: list[MCPTool],
|
|
143
173
|
run_context: RunContextWrapper[Any],
|
|
144
|
-
agent:
|
|
174
|
+
agent: AgentBase,
|
|
145
175
|
) -> list[MCPTool]:
|
|
146
176
|
"""Apply dynamic tool filtering using a callable filter function."""
|
|
147
177
|
|
|
@@ -231,8 +261,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
231
261
|
|
|
232
262
|
async def list_tools(
|
|
233
263
|
self,
|
|
234
|
-
run_context: RunContextWrapper[Any],
|
|
235
|
-
agent:
|
|
264
|
+
run_context: RunContextWrapper[Any] | None = None,
|
|
265
|
+
agent: AgentBase | None = None,
|
|
236
266
|
) -> list[MCPTool]:
|
|
237
267
|
"""List the tools available on the server."""
|
|
238
268
|
if not self.session:
|
|
@@ -251,6 +281,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
251
281
|
# Filter tools based on tool_filter
|
|
252
282
|
filtered_tools = tools
|
|
253
283
|
if self.tool_filter is not None:
|
|
284
|
+
if run_context is None or agent is None:
|
|
285
|
+
raise UserError("run_context and agent are required for dynamic tool filtering")
|
|
254
286
|
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
|
|
255
287
|
return filtered_tools
|
|
256
288
|
|
|
@@ -261,6 +293,24 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
261
293
|
|
|
262
294
|
return await self.session.call_tool(tool_name, arguments)
|
|
263
295
|
|
|
296
|
+
async def list_prompts(self) -> ListPromptsResult:
    """Ask the connected MCP server for its prompt catalogue.

    Raises:
        UserError: If no client session exists yet (i.e. `connect()` was not called).
    """
    session = self.session
    if not session:
        raise UserError("Server not initialized. Make sure you call `connect()` first.")
    prompts = await session.list_prompts()
    return prompts
|
|
304
|
+
|
|
305
|
+
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
    """Request a single prompt from the connected MCP server.

    Args:
        name: Identifier of the prompt to retrieve.
        arguments: Optional template arguments, passed through unchanged.

    Raises:
        UserError: If no client session exists yet (i.e. `connect()` was not called).
    """
    session = self.session
    if not session:
        raise UserError("Server not initialized. Make sure you call `connect()` first.")
    prompt = await session.get_prompt(name, arguments)
    return prompt
|
|
313
|
+
|
|
264
314
|
async def cleanup(self):
|
|
265
315
|
"""Cleanup the server."""
|
|
266
316
|
async with self._cleanup_lock:
|
|
@@ -314,6 +364,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|
|
314
364
|
name: str | None = None,
|
|
315
365
|
client_session_timeout_seconds: float | None = 5,
|
|
316
366
|
tool_filter: ToolFilter = None,
|
|
367
|
+
use_structured_content: bool = False,
|
|
317
368
|
):
|
|
318
369
|
"""Create a new MCP server based on the stdio transport.
|
|
319
370
|
|
|
@@ -332,11 +383,17 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|
|
332
383
|
command.
|
|
333
384
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
334
385
|
tool_filter: The tool filter to use for filtering tools.
|
|
386
|
+
use_structured_content: Whether to use `tool_result.structured_content` when calling an
|
|
387
|
+
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
|
|
388
|
+
include the structured content in the `tool_result.content`, and using it by
|
|
389
|
+
default will cause duplicate content. You can set this to True if you know the
|
|
390
|
+
server will not duplicate the structured content in the `tool_result.content`.
|
|
335
391
|
"""
|
|
336
392
|
super().__init__(
|
|
337
393
|
cache_tools_list,
|
|
338
394
|
client_session_timeout_seconds,
|
|
339
395
|
tool_filter,
|
|
396
|
+
use_structured_content,
|
|
340
397
|
)
|
|
341
398
|
|
|
342
399
|
self.params = StdioServerParameters(
|
|
@@ -397,6 +454,7 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|
|
397
454
|
name: str | None = None,
|
|
398
455
|
client_session_timeout_seconds: float | None = 5,
|
|
399
456
|
tool_filter: ToolFilter = None,
|
|
457
|
+
use_structured_content: bool = False,
|
|
400
458
|
):
|
|
401
459
|
"""Create a new MCP server based on the HTTP with SSE transport.
|
|
402
460
|
|
|
@@ -417,11 +475,17 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|
|
417
475
|
|
|
418
476
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
419
477
|
tool_filter: The tool filter to use for filtering tools.
|
|
478
|
+
use_structured_content: Whether to use `tool_result.structured_content` when calling an
|
|
479
|
+
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
|
|
480
|
+
include the structured content in the `tool_result.content`, and using it by
|
|
481
|
+
default will cause duplicate content. You can set this to True if you know the
|
|
482
|
+
server will not duplicate the structured content in the `tool_result.content`.
|
|
420
483
|
"""
|
|
421
484
|
super().__init__(
|
|
422
485
|
cache_tools_list,
|
|
423
486
|
client_session_timeout_seconds,
|
|
424
487
|
tool_filter,
|
|
488
|
+
use_structured_content,
|
|
425
489
|
)
|
|
426
490
|
|
|
427
491
|
self.params = params
|
|
@@ -482,6 +546,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|
|
482
546
|
name: str | None = None,
|
|
483
547
|
client_session_timeout_seconds: float | None = 5,
|
|
484
548
|
tool_filter: ToolFilter = None,
|
|
549
|
+
use_structured_content: bool = False,
|
|
485
550
|
):
|
|
486
551
|
"""Create a new MCP server based on the Streamable HTTP transport.
|
|
487
552
|
|
|
@@ -503,11 +568,17 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|
|
503
568
|
|
|
504
569
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
505
570
|
tool_filter: The tool filter to use for filtering tools.
|
|
571
|
+
use_structured_content: Whether to use `tool_result.structured_content` when calling an
|
|
572
|
+
MCP tool. Defaults to False for backwards compatibility - most MCP servers still
|
|
573
|
+
include the structured content in the `tool_result.content`, and using it by
|
|
574
|
+
default will cause duplicate content. You can set this to True if you know the
|
|
575
|
+
server will not duplicate the structured content in the `tool_result.content`.
|
|
506
576
|
"""
|
|
507
577
|
super().__init__(
|
|
508
578
|
cache_tools_list,
|
|
509
579
|
client_session_timeout_seconds,
|
|
510
580
|
tool_filter,
|
|
581
|
+
use_structured_content,
|
|
511
582
|
)
|
|
512
583
|
|
|
513
584
|
self.params = params
|
agents/mcp/util.py
CHANGED
|
@@ -5,12 +5,11 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
|
5
5
|
|
|
6
6
|
from typing_extensions import NotRequired, TypedDict
|
|
7
7
|
|
|
8
|
-
from agents.strict_schema import ensure_strict_json_schema
|
|
9
|
-
|
|
10
8
|
from .. import _debug
|
|
11
9
|
from ..exceptions import AgentsException, ModelBehaviorError, UserError
|
|
12
10
|
from ..logger import logger
|
|
13
11
|
from ..run_context import RunContextWrapper
|
|
12
|
+
from ..strict_schema import ensure_strict_json_schema
|
|
14
13
|
from ..tool import FunctionTool, Tool
|
|
15
14
|
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
|
|
16
15
|
from ..util._types import MaybeAwaitable
|
|
@@ -18,7 +17,7 @@ from ..util._types import MaybeAwaitable
|
|
|
18
17
|
if TYPE_CHECKING:
|
|
19
18
|
from mcp.types import Tool as MCPTool
|
|
20
19
|
|
|
21
|
-
from ..agent import
|
|
20
|
+
from ..agent import AgentBase
|
|
22
21
|
from .server import MCPServer
|
|
23
22
|
|
|
24
23
|
|
|
@@ -29,7 +28,7 @@ class ToolFilterContext:
|
|
|
29
28
|
run_context: RunContextWrapper[Any]
|
|
30
29
|
"""The current run context."""
|
|
31
30
|
|
|
32
|
-
agent: "
|
|
31
|
+
agent: "AgentBase"
|
|
33
32
|
"""The agent that is requesting the tool list."""
|
|
34
33
|
|
|
35
34
|
server_name: str
|
|
@@ -100,7 +99,7 @@ class MCPUtil:
|
|
|
100
99
|
servers: list["MCPServer"],
|
|
101
100
|
convert_schemas_to_strict: bool,
|
|
102
101
|
run_context: RunContextWrapper[Any],
|
|
103
|
-
agent: "
|
|
102
|
+
agent: "AgentBase",
|
|
104
103
|
) -> list[Tool]:
|
|
105
104
|
"""Get all function tools from a list of MCP servers."""
|
|
106
105
|
tools = []
|
|
@@ -126,7 +125,7 @@ class MCPUtil:
|
|
|
126
125
|
server: "MCPServer",
|
|
127
126
|
convert_schemas_to_strict: bool,
|
|
128
127
|
run_context: RunContextWrapper[Any],
|
|
129
|
-
agent: "
|
|
128
|
+
agent: "AgentBase",
|
|
130
129
|
) -> list[Tool]:
|
|
131
130
|
"""Get all function tools from a single MCP server."""
|
|
132
131
|
|
|
@@ -199,11 +198,19 @@ class MCPUtil:
|
|
|
199
198
|
# string. We'll try to convert.
|
|
200
199
|
if len(result.content) == 1:
|
|
201
200
|
tool_output = result.content[0].model_dump_json()
|
|
201
|
+
# Append structured content if it exists and we're using it.
|
|
202
|
+
if server.use_structured_content and result.structuredContent:
|
|
203
|
+
tool_output = f"{tool_output}\n{json.dumps(result.structuredContent)}"
|
|
202
204
|
elif len(result.content) > 1:
|
|
203
|
-
|
|
205
|
+
tool_results = [item.model_dump(mode="json") for item in result.content]
|
|
206
|
+
if server.use_structured_content and result.structuredContent:
|
|
207
|
+
tool_results.append(result.structuredContent)
|
|
208
|
+
tool_output = json.dumps(tool_results)
|
|
209
|
+
elif server.use_structured_content and result.structuredContent:
|
|
210
|
+
tool_output = json.dumps(result.structuredContent)
|
|
204
211
|
else:
|
|
205
|
-
|
|
206
|
-
tool_output = "
|
|
212
|
+
# Empty content is a valid result (e.g., "no results found")
|
|
213
|
+
tool_output = "[]"
|
|
207
214
|
|
|
208
215
|
current_span = get_current_span()
|
|
209
216
|
if current_span:
|
agents/memory/session.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import sqlite3
|
|
6
|
+
import threading
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..items import TResponseInputItem
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class Session(Protocol):
    """Structural interface for conversation-history stores.

    Any object that exposes a ``session_id`` attribute together with the four
    coroutines below satisfies this protocol. Agents use it to keep context
    between turns without the caller managing memory by hand.
    """

    session_id: str

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return the stored history for this session.

        Args:
            limit: Upper bound on how many items to return. ``None`` returns the
                whole history; otherwise the newest ``limit`` items are returned,
                oldest first.

        Returns:
            The session's input items in chronological order.
        """

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to the end of this session's history.

        Args:
            items: Input items to store.
        """

    async def pop_item(self) -> TResponseInputItem | None:
        """Delete the newest item and hand it back.

        Returns:
            The removed item, or ``None`` when the history is empty.
        """

    async def clear_session(self) -> None:
        """Erase every stored item belonging to this session."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class SessionABC(ABC):
    """Abstract base class for session implementations.

    Session stores conversation history for a specific session, allowing
    agents to maintain context without requiring explicit manual memory management.

    This ABC is intended for internal use and as a base class for concrete implementations.
    Third-party libraries should implement the Session protocol instead.
    """

    session_id: str

    @abstractmethod
    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history
        """
        ...

    @abstractmethod
    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Args:
            items: List of input items to add to the history
        """
        ...

    @abstractmethod
    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty
        """
        ...

    @abstractmethod
    async def clear_session(self) -> None:
        """Clear all items for this session."""
        ...


class SQLiteSession(SessionABC):
    """SQLite-based implementation of session storage.

    This implementation stores conversation history in a SQLite database.
    By default, uses an in-memory database that is lost when the process ends.
    For persistent storage, provide a file path.

    Items are ordered by their AUTOINCREMENT row id (insertion order), not by
    ``created_at``: that timestamp has one-second resolution, so items written
    within the same second would otherwise tie.

    NOTE: ``sessions_table`` and ``messages_table`` are interpolated directly into
    SQL text, so they must come from trusted configuration, never from user input.
    """

    def __init__(
        self,
        session_id: str,
        db_path: str | Path = ":memory:",
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initialize the SQLite session.

        Args:
            session_id: Unique identifier for the conversation session
            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
            sessions_table: Name of the table to store session metadata. Defaults to
                'agent_sessions'
            messages_table: Name of the table to store message data. Defaults to 'agent_messages'
        """
        self.session_id = session_id
        self.db_path = db_path
        self.sessions_table = sessions_table
        self.messages_table = messages_table
        self._local = threading.local()
        self._lock = threading.Lock()
        # Every thread-local connection opened for a file database. Queries run on
        # asyncio.to_thread worker threads, so close() must be able to reach them all.
        self._connections: list[sqlite3.Connection] = []
        self._connections_lock = threading.Lock()

        # For in-memory databases, we need a shared connection to avoid thread isolation
        # For file databases, we use thread-local connections for better concurrency
        self._is_memory_db = str(db_path) == ":memory:"
        if self._is_memory_db:
            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self._shared_connection.execute("PRAGMA journal_mode=WAL")
            self._init_db_for_connection(self._shared_connection)
        else:
            # For file databases, initialize the schema once since it persists
            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            try:
                init_conn.execute("PRAGMA journal_mode=WAL")
                self._init_db_for_connection(init_conn)
            finally:
                # Release the bootstrap connection even if schema creation fails.
                init_conn.close()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the connection the current thread should use."""
        if self._is_memory_db:
            # Use shared connection for in-memory database to avoid thread isolation
            return self._shared_connection

        # Use thread-local connections for file databases
        conn = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            self._local.connection = conn
            with self._connections_lock:
                self._connections.append(conn)
        return conn

    def _operation_lock(self):
        """Return the context manager that guards one database operation.

        The in-memory database shares one connection across threads, so access to it
        is serialized with ``self._lock``. File databases give each thread its own
        connection, so no Python-level lock is taken.
        """
        import contextlib

        return self._lock if self._is_memory_db else contextlib.nullcontext()

    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
        """Create the sessions/messages tables and index on ``conn`` if missing."""
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
        )

        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.messages_table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
                    ON DELETE CASCADE
            )
        """
        )

        conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
            ON {self.messages_table} (session_id, created_at)
        """
        )

        conn.commit()

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history. Rows whose
            stored JSON cannot be decoded are skipped.
        """

        def _get_items_sync() -> list[TResponseInputItem]:
            conn = self._get_connection()
            with self._operation_lock():
                if limit is None:
                    # Fetch all items in insertion order
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id ASC
                    """,
                        (self.session_id,),
                    )
                    rows = cursor.fetchall()
                else:
                    # Fetch the newest N items, then flip them back to chronological order
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT ?
                    """,
                        (self.session_id, limit),
                    )
                    rows = cursor.fetchall()
                    rows.reverse()

            items = []
            for (message_data,) in rows:
                try:
                    items.append(json.loads(message_data))
                except json.JSONDecodeError:
                    # Skip invalid JSON entries
                    continue
            return items

        return await asyncio.to_thread(_get_items_sync)

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Args:
            items: List of input items to add to the history. An empty list is a no-op.
        """
        if not items:
            return

        def _add_items_sync() -> None:
            conn = self._get_connection()
            # `with conn` commits on success and rolls back on error, so a failed
            # insert never leaves a transaction open on a reused connection.
            with self._operation_lock(), conn:
                # Ensure session exists
                conn.execute(
                    f"""
                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
                """,
                    (self.session_id,),
                )

                # Add items
                message_data = [(self.session_id, json.dumps(item)) for item in items]
                conn.executemany(
                    f"""
                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
                """,
                    message_data,
                )

                # Update session timestamp
                conn.execute(
                    f"""
                    UPDATE {self.sessions_table}
                    SET updated_at = CURRENT_TIMESTAMP
                    WHERE session_id = ?
                """,
                    (self.session_id,),
                )

        await asyncio.to_thread(_add_items_sync)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If the
            newest row holds undecodable JSON it is still deleted and None is returned.
        """

        def _pop_item_sync() -> TResponseInputItem | None:
            conn = self._get_connection()
            with self._operation_lock(), conn:
                # DELETE ... RETURNING atomically removes and returns the newest item.
                cursor = conn.execute(
                    f"""
                    DELETE FROM {self.messages_table}
                    WHERE id = (
                        SELECT id FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT 1
                    )
                    RETURNING message_data
                """,
                    (self.session_id,),
                )
                # fetchall() fully steps the statement before the commit on exit.
                rows = cursor.fetchall()

            if not rows:
                return None
            try:
                return json.loads(rows[0][0])
            except json.JSONDecodeError:
                # Return None for corrupted JSON entries (already deleted)
                return None

        return await asyncio.to_thread(_pop_item_sync)

    async def clear_session(self) -> None:
        """Clear all items for this session and remove its metadata row."""

        def _clear_session_sync() -> None:
            conn = self._get_connection()
            with self._operation_lock(), conn:
                conn.execute(
                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.execute(
                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
                    (self.session_id,),
                )

        await asyncio.to_thread(_clear_session_sync)

    def close(self) -> None:
        """Close every database connection this session opened.

        For file databases this includes connections created on worker threads by
        ``asyncio.to_thread``, not just the calling thread's connection.
        """
        if self._is_memory_db:
            if hasattr(self, "_shared_connection"):
                self._shared_connection.close()
            return

        with self._connections_lock:
            connections, self._connections = self._connections, []
        for conn in connections:
            conn.close()
|