openai-agents 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- agents/__init__.py +6 -0
- agents/_run_impl.py +56 -6
- agents/agent.py +40 -2
- agents/extensions/visualization.py +137 -0
- agents/mcp/__init__.py +21 -0
- agents/mcp/server.py +301 -0
- agents/mcp/util.py +131 -0
- agents/model_settings.py +25 -13
- agents/models/openai_chatcompletions.py +13 -3
- agents/models/openai_responses.py +9 -2
- agents/py.typed +1 -0
- agents/run.py +45 -7
- agents/strict_schema.py +1 -1
- agents/tracing/__init__.py +4 -0
- agents/tracing/create.py +29 -0
- agents/tracing/processors.py +26 -8
- agents/tracing/span_data.py +33 -3
- agents/version.py +1 -1
- agents/voice/imports.py +1 -1
- agents/voice/models/openai_stt.py +1 -2
- {openai_agents-0.0.6.dist-info → openai_agents-0.0.8.dist-info}/METADATA +5 -2
- {openai_agents-0.0.6.dist-info → openai_agents-0.0.8.dist-info}/RECORD +24 -19
- {openai_agents-0.0.6.dist-info → openai_agents-0.0.8.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.6.dist-info → openai_agents-0.0.8.dist-info}/licenses/LICENSE +0 -0
agents/mcp/server.py
ADDED
@@ -0,0 +1,301 @@
from __future__ import annotations

import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Any, Literal

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict

from ..exceptions import UserError
from ..logger import logger


class MCPServer(abc.ABC):
    """Base class for Model Context Protocol servers."""

    @abc.abstractmethod
    async def connect(self):
        """Connect to the server. For example, this might mean spawning a subprocess or
        opening a network connection. The server is expected to remain connected until
        `cleanup()` is called.
        """
        pass

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """A readable name for the server."""
        pass

    @abc.abstractmethod
    async def cleanup(self):
        """Cleanup the server. For example, this might mean closing a subprocess or
        closing a network connection.
        """
        pass

    @abc.abstractmethod
    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        pass

    @abc.abstractmethod
    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        pass


class _MCPServerWithClientSession(MCPServer, abc.ABC):
    """Base class for MCP servers that use a `ClientSession` to communicate with the server."""

    def __init__(self, cache_tools_list: bool):
        """
        Args:
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
        """
        self.session: ClientSession | None = None
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.cache_tools_list = cache_tools_list

        # The cache is always dirty at startup, so that we fetch tools at least once
        self._cache_dirty = True
        self._tools_list: list[MCPTool] | None = None

    @abc.abstractmethod
    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        pass

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.cleanup()

    def invalidate_tools_cache(self):
        """Invalidate the tools cache."""
        self._cache_dirty = True

    async def connect(self):
        """Connect to the server."""
        try:
            transport = await self.exit_stack.enter_async_context(self.create_streams())
            read, write = transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error(f"Error initializing MCP server: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        # Return from cache if caching is enabled, we have tools, and the cache is not dirty
        if self.cache_tools_list and not self._cache_dirty and self._tools_list:
            return self._tools_list

        # Reset the cache dirty to False
        self._cache_dirty = False

        # Fetch the tools from the server
        self._tools_list = (await self.session.list_tools()).tools
        return self._tools_list

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        return await self.session.call_tool(tool_name, arguments)

    async def cleanup(self):
        """Cleanup the server."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
            except Exception as e:
                logger.error(f"Error cleaning up server: {e}")


class MCPServerStdioParams(TypedDict):
    """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
    import.
    """

    command: str
    """The executable to run to start the server. For example, `python` or `node`."""

    args: NotRequired[list[str]]
    """Command line args to pass to the `command` executable. For example, `['foo.py']` or
    `['server.js', '--port', '8080']`."""

    env: NotRequired[dict[str, str]]
    """The environment variables to set for the server."""

    cwd: NotRequired[str | Path]
    """The working directory to use when spawning the process."""

    encoding: NotRequired[str]
    """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""

    encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
    """The text encoding error handler. Defaults to `strict`.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values.
    """


class MCPServerStdio(_MCPServerWithClientSession):
    """MCP server implementation that uses the stdio transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
    details.
    """

    def __init__(
        self,
        params: MCPServerStdioParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Create a new MCP server based on the stdio transport.

        Args:
            params: The params that configure the server. This includes the command to run to
                start the server, the args to pass to the command, the environment variables to
                set for the server, the working directory to use when spawning the process, and
                the text encoding used when sending/receiving messages to the server.
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
            name: A readable name for the server. If not provided, we'll create one from the
                command.
        """
        super().__init__(cache_tools_list)

        self.params = StdioServerParameters(
            command=params["command"],
            args=params.get("args", []),
            env=params.get("env"),
            cwd=params.get("cwd"),
            encoding=params.get("encoding", "utf-8"),
            encoding_error_handler=params.get("encoding_error_handler", "strict"),
        )

        self._name = name or f"stdio: {self.params.command}"

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return stdio_client(self.params)

    @property
    def name(self) -> str:
        """A readable name for the server."""
        return self._name


class MCPServerSseParams(TypedDict):
    """Mirrors the params in `mcp.client.sse.sse_client`."""

    url: str
    """The URL of the server."""

    headers: NotRequired[dict[str, str]]
    """The headers to send to the server."""

    timeout: NotRequired[float]
    """The timeout for the HTTP request. Defaults to 5 seconds."""

    sse_read_timeout: NotRequired[float]
    """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""


class MCPServerSse(_MCPServerWithClientSession):
    """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
    for details.
    """

    def __init__(
        self,
        params: MCPServerSseParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Create a new MCP server based on the HTTP with SSE transport.

        Args:
            params: The params that configure the server. This includes the URL of the server,
                the headers to send to the server, the timeout for the HTTP request, and the
                timeout for the SSE connection.

            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).

            name: A readable name for the server. If not provided, we'll create one from the
                URL.
        """
        super().__init__(cache_tools_list)

        self.params = params
        self._name = name or f"sse: {self.params['url']}"

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return sse_client(
            url=self.params["url"],
            headers=self.params.get("headers", None),
            timeout=self.params.get("timeout", 5),
            sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
        )

    @property
    def name(self) -> str:
        """A readable name for the server."""
        return self._name
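For orientation, a minimal usage sketch of the new stdio server (not part of the diff; it assumes `MCPServerStdio` is re-exported from `agents.mcp` by the new `agents/mcp/__init__.py`, and the filesystem-server command is only an illustrative choice):

import asyncio

from agents.mcp import MCPServerStdio


async def main() -> None:
    # __aenter__ calls connect(); cache_tools_list=True skips a server
    # round-trip on repeated list_tools() calls until the cache is invalidated.
    async with MCPServerStdio(
        params={
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "."],
        },
        cache_tools_list=True,
    ) as server:
        tools = await server.list_tools()
        print([tool.name for tool in tools])


asyncio.run(main())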
agents/mcp/util.py
ADDED
@@ -0,0 +1,131 @@
import functools
import json
from typing import TYPE_CHECKING, Any

from agents.strict_schema import ensure_strict_json_schema

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
from ..run_context import RunContextWrapper
from ..tool import FunctionTool, Tool
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span

if TYPE_CHECKING:
    from mcp.types import Tool as MCPTool

    from .server import MCPServer


class MCPUtil:
    """Set of utilities for interop between MCP and Agents SDK tools."""

    @classmethod
    async def get_all_function_tools(
        cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
    ) -> list[Tool]:
        """Get all function tools from a list of MCP servers."""
        tools = []
        tool_names: set[str] = set()
        for server in servers:
            server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
            server_tool_names = {tool.name for tool in server_tools}
            if len(server_tool_names & tool_names) > 0:
                raise UserError(
                    f"Duplicate tool names found across MCP servers: "
                    f"{server_tool_names & tool_names}"
                )
            tool_names.update(server_tool_names)
            tools.extend(server_tools)

        return tools

    @classmethod
    async def get_function_tools(
        cls, server: "MCPServer", convert_schemas_to_strict: bool
    ) -> list[Tool]:
        """Get all function tools from a single MCP server."""

        with mcp_tools_span(server=server.name) as span:
            tools = await server.list_tools()
            span.span_data.result = [tool.name for tool in tools]

        return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]

    @classmethod
    def to_function_tool(
        cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
    ) -> FunctionTool:
        """Convert an MCP tool to an Agents SDK function tool."""
        invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
        schema, is_strict = tool.inputSchema, False
        if convert_schemas_to_strict:
            try:
                schema = ensure_strict_json_schema(schema)
                is_strict = True
            except Exception as e:
                logger.info(f"Error converting MCP schema to strict mode: {e}")

        return FunctionTool(
            name=tool.name,
            description=tool.description or "",
            params_json_schema=schema,
            on_invoke_tool=invoke_func,
            strict_json_schema=is_strict,
        )

    @classmethod
    async def invoke_mcp_tool(
        cls, server: "MCPServer", tool: "MCPTool", context: RunContextWrapper[Any], input_json: str
    ) -> str:
        """Invoke an MCP tool and return the result as a string."""
        try:
            json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
        except Exception as e:
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
            raise ModelBehaviorError(
                f"Invalid JSON input for tool {tool.name}: {input_json}"
            ) from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Invoking MCP tool {tool.name}")
        else:
            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

        try:
            result = await server.call_tool(tool.name, json_data)
        except Exception as e:
            logger.error(f"Error invoking MCP tool {tool.name}: {e}")
            raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"MCP tool {tool.name} completed.")
        else:
            logger.debug(f"MCP tool {tool.name} returned {result}")

        # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
        # string. We'll try to convert.
        if len(result.content) == 1:
            tool_output = result.content[0].model_dump_json()
        elif len(result.content) > 1:
            tool_output = json.dumps([item.model_dump() for item in result.content])
        else:
            logger.error(f"Errored MCP tool result: {result}")
            tool_output = "Error running tool."

        current_span = get_current_span()
        if current_span:
            if isinstance(current_span.span_data, FunctionSpanData):
                current_span.span_data.output = tool_output
                current_span.span_data.mcp_data = {
                    "server": server.name,
                }
            else:
                logger.warning(
                    f"Current span is not a FunctionSpanData, skipping tool output: {current_span}"
                )

        return tool_output
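A short sketch of how these utilities are typically driven (not part of the diff; `servers` is assumed to be a list of already-connected `MCPServer` instances such as the `MCPServerStdio` example above):

from agents.mcp.util import MCPUtil


async def agent_tools(servers) -> list:
    # Each MCP tool becomes an Agents SDK FunctionTool. With
    # convert_schemas_to_strict=True the input schema is passed through
    # ensure_strict_json_schema, falling back to the original schema on error.
    return await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)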
agents/model_settings.py
CHANGED
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, fields, replace
 from typing import Literal
 
+from openai.types.shared import Reasoning
+
 
 @dataclass
 class ModelSettings:
@@ -30,8 +32,9 @@ class ModelSettings:
     tool_choice: Literal["auto", "required", "none"] | str | None = None
     """The tool choice to use when calling the model."""
 
-    parallel_tool_calls: bool | None =
-    """Whether to use parallel tool calls when calling the model.
+    parallel_tool_calls: bool | None = None
+    """Whether to use parallel tool calls when calling the model.
+    Defaults to False if not provided."""
 
     truncation: Literal["auto", "disabled"] | None = None
     """The truncation strategy to use when calling the model."""
@@ -39,18 +42,27 @@ class ModelSettings:
     max_tokens: int | None = None
     """The maximum number of output tokens to generate."""
 
+    reasoning: Reasoning | None = None
+    """Configuration options for
+    [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+    """
+
+    metadata: dict[str, str] | None = None
+    """Metadata to include with the model response call."""
+
+    store: bool | None = None
+    """Whether to store the generated model response for later retrieval.
+    Defaults to True if not provided."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
        override on top of this instance."""
         if override is None:
             return self
-
-
-
-
-
-
-
-            truncation=override.truncation or self.truncation,
-            max_tokens=override.max_tokens or self.max_tokens,
-        )
+
+        changes = {
+            field.name: getattr(override, field.name)
+            for field in fields(self)
+            if getattr(override, field.name) is not None
+        }
+        return replace(self, **changes)
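The rewritten `resolve()` overlays every non-None field generically via `dataclasses.fields`/`replace`, so the new `reasoning`, `metadata`, and `store` fields are picked up without listing them by hand. A small sketch of the resulting behavior (illustrative values only):

from agents.model_settings import ModelSettings

base = ModelSettings(temperature=0.3, max_tokens=1024)
override = ModelSettings(temperature=0.7, store=False)

resolved = base.resolve(override)
# Non-None override fields win; everything else falls through from `base`.
assert resolved.temperature == 0.7
assert resolved.max_tokens == 1024
assert resolved.store is False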
agents/models/openai_chatcompletions.py
CHANGED

@@ -518,6 +518,11 @@ class OpenAIChatCompletionsModel(Model):
             f"Response format: {response_format}\n"
         )
 
+        # Match the behavior of Responses where store is True when not given
+        store = model_settings.store if model_settings.store is not None else True
+
+        reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
+
         ret = await self._get_client().chat.completions.create(
             model=self.model,
             messages=converted_messages,
@@ -532,7 +537,10 @@ class OpenAIChatCompletionsModel(Model):
             parallel_tool_calls=parallel_tool_calls,
             stream=stream,
             stream_options={"include_usage": True} if stream else NOT_GIVEN,
+            store=store,
+            reasoning_effort=self._non_null_or_not_given(reasoning_effort),
             extra_headers=_HEADERS,
+            metadata=model_settings.metadata,
         )
 
         if isinstance(ret, ChatCompletion):
@@ -551,6 +559,7 @@ class OpenAIChatCompletionsModel(Model):
             temperature=model_settings.temperature,
             tools=[],
             parallel_tool_calls=parallel_tool_calls or False,
+            reasoning=model_settings.reasoning,
         )
         return response, ret
 
@@ -757,7 +766,7 @@ class _Converter:
         elif isinstance(c, dict) and c.get("type") == "input_file":
             raise UserError(f"File uploads are not supported for chat completions {c}")
         else:
-            raise UserError(f"
+            raise UserError(f"Unknown content: {c}")
         return out
 
     @classmethod
@@ -919,12 +928,13 @@ class _Converter:
         elif func_call := cls.maybe_function_tool_call(item):
            asst = ensure_assistant_message()
            tool_calls = list(asst.get("tool_calls", []))
+            arguments = func_call["arguments"] if func_call["arguments"] else "{}"
            new_tool_call = ChatCompletionMessageToolCallParam(
                id=func_call["call_id"],
                type="function",
                function={
                    "name": func_call["name"],
-                    "arguments":
+                    "arguments": arguments,
                },
            )
            tool_calls.append(new_tool_call)
@@ -967,7 +977,7 @@ class ToolConverter:
         }
 
         raise UserError(
-            f"Hosted tools are not supported with the ChatCompletions API.
+            f"Hosted tools are not supported with the ChatCompletions API. Got tool type: "
             f"{type(tool)}, tool: {tool}"
         )
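These hunks let `ModelSettings.store`, `reasoning`, and `metadata` reach the Chat Completions API, with `store` defaulting to True to match the Responses behavior. A hedged sketch of setting them from application code (the `Agent` constructor usage follows the SDK's usual pattern; the model name is only illustrative):

from openai.types.shared import Reasoning

from agents import Agent, ModelSettings

agent = Agent(
    name="Assistant",
    model="o3-mini",  # illustrative; reasoning effort only applies to reasoning-capable models
    model_settings=ModelSettings(
        reasoning=Reasoning(effort="high"),  # forwarded as reasoning_effort for chat completions
        store=False,                         # otherwise treated as True, matching Responses
        metadata={"experiment": "demo"},
    ),
)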
agents/models/openai_responses.py
CHANGED

@@ -83,7 +83,7 @@ class OpenAIResponsesModel(Model):
         )
 
         if _debug.DONT_LOG_MODEL_DATA:
-            logger.debug("LLM
+            logger.debug("LLM responded")
         else:
             logger.debug(
                 "LLM resp:\n"
@@ -208,7 +208,11 @@ class OpenAIResponsesModel(Model):
         list_input = ItemHelpers.input_to_new_input_list(input)
 
         parallel_tool_calls = (
-            True
+            True
+            if model_settings.parallel_tool_calls and tools and len(tools) > 0
+            else False
+            if model_settings.parallel_tool_calls is False
+            else NOT_GIVEN
         )
 
         tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
@@ -242,6 +246,9 @@ class OpenAIResponsesModel(Model):
             stream=stream,
             extra_headers=_HEADERS,
             text=response_format,
+            store=self._non_null_or_not_given(model_settings.store),
+            reasoning=self._non_null_or_not_given(model_settings.reasoning),
+            metadata=model_settings.metadata,
         )
 
     def _get_client(self) -> AsyncOpenAI:
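The new `parallel_tool_calls` expression is three-valued: `True` is only sent when it was requested and tools are actually present, an explicit `False` is passed through, and an unset value becomes `NOT_GIVEN` so the API default applies. The same logic in isolation (a hypothetical helper for illustration, not part of the SDK):

from openai import NOT_GIVEN


def resolve_parallel_tool_calls(setting: bool | None, tools: list) -> object:
    # Mirrors the conditional expression used in OpenAIResponsesModel above.
    if setting and tools and len(tools) > 0:
        return True
    if setting is False:
        return False
    return NOT_GIVEN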
agents/py.typed
ADDED
@@ -0,0 +1 @@
+
(py.typed is the PEP 561 marker file that tells type checkers the package ships inline type information)