openai-agents 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

agents/mcp/server.py ADDED
@@ -0,0 +1,301 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import asyncio
5
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+
9
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10
+ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
11
+ from mcp.client.sse import sse_client
12
+ from mcp.types import CallToolResult, JSONRPCMessage
13
+ from typing_extensions import NotRequired, TypedDict
14
+
15
+ from ..exceptions import UserError
16
+ from ..logger import logger
17
+
18
+
19
class MCPServer(abc.ABC):
    """Abstract interface implemented by every Model Context Protocol server."""

    @abc.abstractmethod
    async def connect(self):
        """Establish the connection to the server.

        Implementations might, for example, spawn a subprocess or open a network connection.
        Once connected, the server is expected to stay connected until `cleanup()` is called.
        """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Human-readable name identifying this server."""

    @abc.abstractmethod
    async def cleanup(self):
        """Release whatever `connect()` acquired.

        Implementations might, for example, close a subprocess or a network connection.
        """

    @abc.abstractmethod
    async def list_tools(self) -> list[MCPTool]:
        """Return the tools that the server currently exposes."""

    @abc.abstractmethod
    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Run the tool called `tool_name` on the server with the given `arguments`."""
52
+
53
+
54
class _MCPServerWithClientSession(MCPServer, abc.ABC):
    """Base class for MCP servers that use a `ClientSession` to communicate with the server.

    Subclasses only provide `create_streams()` (the transport) and `name`; connection
    management, tool listing (with optional caching), tool invocation and cleanup live here.
    Instances can be used as async context managers: entering calls `connect()`, exiting
    calls `cleanup()`.
    """

    def __init__(self, cache_tools_list: bool):
        """
        Args:
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can
                drastically improve latency (by avoiding a round-trip to the server every time).
        """
        self.session: ClientSession | None = None
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.cache_tools_list = cache_tools_list

        # The cache is always dirty at startup, so that we fetch tools at least once
        self._cache_dirty = True
        self._tools_list: list[MCPTool] | None = None

    @abc.abstractmethod
    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server.

        Returns:
            An async context manager yielding the `(read, write)` stream pair that
            `connect()` hands to `ClientSession`.
        """
        pass

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.cleanup()

    def invalidate_tools_cache(self):
        """Invalidate the tools cache, forcing the next `list_tools()` call to hit the server."""
        self._cache_dirty = True

    async def connect(self):
        """Connect to the server.

        Opens the transport from `create_streams()`, wraps it in a `ClientSession`, and
        initializes the session. Both context managers are registered on `self.exit_stack`
        so that `cleanup()` tears them down.

        Raises:
            Exception: Any error raised while connecting is logged and re-raised after
                `cleanup()` has released whatever was already opened.
        """
        try:
            transport = await self.exit_stack.enter_async_context(self.create_streams())
            read, write = transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error(f"Error initializing MCP server: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server.

        Returns:
            The cached list when caching is enabled and the cache is clean, otherwise a list
            freshly fetched from the server.

        Raises:
            UserError: If `connect()` has not been called (there is no session).
        """
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        # Return from cache if caching is enabled, we have a fetched list, and the cache is not
        # dirty. `is not None` (rather than truthiness) lets an empty tools list be cached too.
        if self.cache_tools_list and not self._cache_dirty and self._tools_list is not None:
            return self._tools_list

        # Mark the cache clean *before* awaiting, so that an `invalidate_tools_cache()` call
        # arriving while the request is in flight is not overwritten afterwards.
        self._cache_dirty = False

        try:
            # Fetch the tools from the server
            self._tools_list = (await self.session.list_tools()).tools
        except BaseException:
            # The fetch did not complete: keep the cache dirty so that a previously cached
            # (possibly invalidated) list is not served as if it were fresh.
            self._cache_dirty = True
            raise

        return self._tools_list

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server.

        Raises:
            UserError: If `connect()` has not been called (there is no session).
        """
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        return await self.session.call_tool(tool_name, arguments)

    async def cleanup(self):
        """Cleanup the server.

        Closes everything registered on `self.exit_stack`. Errors raised while closing are
        logged and not propagated (best-effort cleanup). The session reference is always
        cleared, even when closing fails, so later calls report "not initialized" instead of
        using a session whose transport has been torn down.
        """
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logger.error(f"Error cleaning up server: {e}")
            finally:
                self.session = None
143
+
144
+
145
class MCPServerStdioParams(TypedDict):
    """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
    import.
    """

    command: str
    """The executable to run to start the server. For example, `python` or `node`."""

    args: NotRequired[list[str]]
    """Command line args to pass to the `command` executable. For example, `['foo.py']` or
    `['server.js', '--port', '8080']`."""

    env: NotRequired[dict[str, str]]
    """The environment variables to set for the server."""

    cwd: NotRequired[str | Path]
    """The working directory to use when spawning the process."""

    encoding: NotRequired[str]
    """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""

    encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
    """The text encoding error handler. Defaults to `strict`.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values.
    """
172
+
173
+
174
class MCPServerStdio(_MCPServerWithClientSession):
    """MCP server that talks to a spawned process over the stdio transport.

    The transport is described in the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio).
    """

    def __init__(
        self,
        params: MCPServerStdioParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Build a stdio-based MCP server from a plain params dict.

        Args:
            params: Configuration for the server process: the command to run, plus optionally
                its args, environment variables, working directory, text encoding and
                encoding error handler. Missing optional keys fall back to an empty args list,
                `None` for `env` and `cwd`, `utf-8` encoding and the `strict` error handler.
            cache_tools_list: Whether to cache the tools list. When `True` the list is fetched
                from the server once and then reused until `invalidate_tools_cache()` is
                called; when `False` every `list_tools()` call goes to the server. Enable it
                if the server's tools list is known not to change, since it avoids a
                round-trip on every call.
            name: A readable name for the server. When omitted (or empty), one is derived from
                the command.
        """
        super().__init__(cache_tools_list)

        stdio_params = StdioServerParameters(
            command=params["command"],
            args=params.get("args", []),
            env=params.get("env"),
            cwd=params.get("cwd"),
            encoding=params.get("encoding", "utf-8"),
            encoding_error_handler=params.get("encoding_error_handler", "strict"),
        )
        self.params = stdio_params

        if name:
            self._name = name
        else:
            self._name = f"stdio: {stdio_params.command}"

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Return the stdio client context manager configured with `self.params`."""
        return stdio_client(self.params)

    @property
    def name(self) -> str:
        """The readable name chosen (or derived) at construction time."""
        return self._name
230
+
231
+
232
class MCPServerSseParams(TypedDict):
    """Mirrors the params in `mcp.client.sse.sse_client`."""

    # Only `url` is required; the remaining keys are optional (`NotRequired`).
    url: str
    """The URL of the server."""

    headers: NotRequired[dict[str, str]]
    """The headers to send to the server."""

    timeout: NotRequired[float]
    """The timeout for the HTTP request. Defaults to 5 seconds."""

    sse_read_timeout: NotRequired[float]
    """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
246
+
247
+
248
class MCPServerSse(_MCPServerWithClientSession):
    """MCP server reached over the HTTP with SSE transport.

    The transport is described in the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse).
    """

    def __init__(
        self,
        params: MCPServerSseParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Build an SSE-based MCP server from a plain params dict.

        Args:
            params: Configuration for the connection: the server URL, plus optionally the
                headers to send, the HTTP request timeout and the SSE connection timeout.

            cache_tools_list: Whether to cache the tools list. When `True` the list is fetched
                from the server once and then reused until `invalidate_tools_cache()` is
                called; when `False` every `list_tools()` call goes to the server. Enable it
                if the server's tools list is known not to change, since it avoids a
                round-trip on every call.

            name: A readable name for the server. When omitted (or empty), one is derived from
                the URL.
        """
        super().__init__(cache_tools_list)

        self.params = params
        if name:
            self._name = name
        else:
            self._name = f"sse: {self.params['url']}"

    @property
    def name(self) -> str:
        """The readable name chosen (or derived) at construction time."""
        return self._name

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Return the SSE client context manager configured from `self.params`.

        Optional keys that are absent fall back to no extra headers, a 5 second HTTP timeout
        and a 5 minute SSE read timeout.
        """
        config = self.params
        five_minutes = 60 * 5
        return sse_client(
            url=config["url"],
            headers=config.get("headers", None),
            timeout=config.get("timeout", 5),
            sse_read_timeout=config.get("sse_read_timeout", five_minutes),
        )
agents/mcp/util.py ADDED
@@ -0,0 +1,131 @@
1
+ import functools
2
+ import json
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from agents.strict_schema import ensure_strict_json_schema
6
+
7
+ from .. import _debug
8
+ from ..exceptions import AgentsException, ModelBehaviorError, UserError
9
+ from ..logger import logger
10
+ from ..run_context import RunContextWrapper
11
+ from ..tool import FunctionTool, Tool
12
+ from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
13
+
14
+ if TYPE_CHECKING:
15
+ from mcp.types import Tool as MCPTool
16
+
17
+ from .server import MCPServer
18
+
19
+
20
class MCPUtil:
    """Set of utilities for interop between MCP and Agents SDK tools."""

    @classmethod
    async def get_all_function_tools(
        cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
    ) -> list[Tool]:
        """Get all function tools from a list of MCP servers.

        Args:
            servers: The servers to query, in order.
            convert_schemas_to_strict: Forwarded to `get_function_tools()` for every server.

        Returns:
            The function tools of all servers, concatenated in server order.

        Raises:
            UserError: If a tool name from one server was already provided by an earlier
                server in the list.
        """
        tools: list[Tool] = []
        tool_names: set[str] = set()
        for server in servers:
            server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
            server_tool_names = {tool.name for tool in server_tools}
            duplicates = server_tool_names & tool_names
            if duplicates:
                raise UserError(
                    f"Duplicate tool names found across MCP servers: "
                    f"{duplicates}"
                )
            tool_names.update(server_tool_names)
            tools.extend(server_tools)

        return tools

    @classmethod
    async def get_function_tools(
        cls, server: "MCPServer", convert_schemas_to_strict: bool
    ) -> list[Tool]:
        """Get all function tools from a single MCP server.

        The `list_tools()` call is wrapped in an MCP tools tracing span whose result is set to
        the names of the tools returned.
        """

        with mcp_tools_span(server=server.name) as span:
            tools = await server.list_tools()
            span.span_data.result = [tool.name for tool in tools]

        return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]

    @classmethod
    def to_function_tool(
        cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
    ) -> FunctionTool:
        """Convert an MCP tool to an Agents SDK function tool.

        Args:
            tool: The MCP tool description to convert.
            server: The server that will be asked to run the tool when it is invoked.
            convert_schemas_to_strict: Whether to try converting the tool's input schema to a
                strict JSON schema. If the conversion raises, the error is logged and the
                tool is created as non-strict.

        Returns:
            A `FunctionTool` whose invocation is forwarded to `invoke_mcp_tool()`.
        """
        invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
        schema, is_strict = tool.inputSchema, False
        if convert_schemas_to_strict:
            try:
                schema = ensure_strict_json_schema(schema)
                is_strict = True
            except Exception as e:
                # Best effort: fall back to a non-strict tool rather than failing.
                logger.info(f"Error converting MCP schema to strict mode: {e}")

        return FunctionTool(
            name=tool.name,
            description=tool.description or "",
            params_json_schema=schema,
            on_invoke_tool=invoke_func,
            strict_json_schema=is_strict,
        )

    @classmethod
    async def invoke_mcp_tool(
        cls, server: "MCPServer", tool: "MCPTool", context: RunContextWrapper[Any], input_json: str
    ) -> str:
        """Invoke an MCP tool and return the result as a string.

        Args:
            server: The server to run the tool on.
            tool: The MCP tool to run.
            context: The run context; not used by this method.
            input_json: The tool arguments as a JSON string. An empty string means no
                arguments (`{}`).

        Returns:
            The single content item serialized as JSON, a JSON list of all content items when
            there are several, or the fixed string "Error running tool." when the result has
            no content.

        Raises:
            ModelBehaviorError: If `input_json` is not valid JSON.
            AgentsException: If calling the tool on the server raises.
        """
        try:
            json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
        except Exception as e:
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
            raise ModelBehaviorError(
                f"Invalid JSON input for tool {tool.name}: {input_json}"
            ) from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Invoking MCP tool {tool.name}")
        else:
            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

        try:
            result = await server.call_tool(tool.name, json_data)
        except Exception as e:
            logger.error(f"Error invoking MCP tool {tool.name}: {e}")
            raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"MCP tool {tool.name} completed.")
        else:
            logger.debug(f"MCP tool {tool.name} returned {result}")

        # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
        # string. We'll try to convert.
        if len(result.content) == 1:
            tool_output = result.content[0].model_dump_json()
        elif len(result.content) > 1:
            tool_output = json.dumps([item.model_dump() for item in result.content])
        else:
            # Respect the tool-data logging switch here as well, like every other log above.
            if _debug.DONT_LOG_TOOL_DATA:
                logger.error(f"Errored MCP tool result for tool {tool.name}")
            else:
                logger.error(f"Errored MCP tool result: {result}")
            tool_output = "Error running tool."

        # Attach the output and the originating server to the current function span, if any.
        current_span = get_current_span()
        if current_span:
            if isinstance(current_span.span_data, FunctionSpanData):
                current_span.span_data.output = tool_output
                current_span.span_data.mcp_data = {
                    "server": server.name,
                }
            else:
                logger.warning(
                    f"Current span is not a FunctionSpanData, skipping tool output: {current_span}"
                )

        return tool_output
agents/model_settings.py CHANGED
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass
3
+ from dataclasses import dataclass, fields, replace
4
4
  from typing import Literal
5
5
 
6
+ from openai.types.shared import Reasoning
7
+
6
8
 
7
9
  @dataclass
8
10
  class ModelSettings:
@@ -30,8 +32,9 @@ class ModelSettings:
30
32
  tool_choice: Literal["auto", "required", "none"] | str | None = None
31
33
  """The tool choice to use when calling the model."""
32
34
 
33
- parallel_tool_calls: bool | None = False
34
- """Whether to use parallel tool calls when calling the model."""
35
+ parallel_tool_calls: bool | None = None
36
+ """Whether to use parallel tool calls when calling the model.
37
+ Defaults to False if not provided."""
35
38
 
36
39
  truncation: Literal["auto", "disabled"] | None = None
37
40
  """The truncation strategy to use when calling the model."""
@@ -39,18 +42,27 @@ class ModelSettings:
39
42
  max_tokens: int | None = None
40
43
  """The maximum number of output tokens to generate."""
41
44
 
45
+ reasoning: Reasoning | None = None
46
+ """Configuration options for
47
+ [reasoning models](https://platform.openai.com/docs/guides/reasoning).
48
+ """
49
+
50
+ metadata: dict[str, str] | None = None
51
+ """Metadata to include with the model response call."""
52
+
53
+ store: bool | None = None
54
+ """Whether to store the generated model response for later retrieval.
55
+ Defaults to True if not provided."""
56
+
42
57
  def resolve(self, override: ModelSettings | None) -> ModelSettings:
43
58
  """Produce a new ModelSettings by overlaying any non-None values from the
44
59
  override on top of this instance."""
45
60
  if override is None:
46
61
  return self
47
- return ModelSettings(
48
- temperature=override.temperature or self.temperature,
49
- top_p=override.top_p or self.top_p,
50
- frequency_penalty=override.frequency_penalty or self.frequency_penalty,
51
- presence_penalty=override.presence_penalty or self.presence_penalty,
52
- tool_choice=override.tool_choice or self.tool_choice,
53
- parallel_tool_calls=override.parallel_tool_calls or self.parallel_tool_calls,
54
- truncation=override.truncation or self.truncation,
55
- max_tokens=override.max_tokens or self.max_tokens,
56
- )
62
+
63
+ changes = {
64
+ field.name: getattr(override, field.name)
65
+ for field in fields(self)
66
+ if getattr(override, field.name) is not None
67
+ }
68
+ return replace(self, **changes)
@@ -518,6 +518,11 @@ class OpenAIChatCompletionsModel(Model):
518
518
  f"Response format: {response_format}\n"
519
519
  )
520
520
 
521
+ # Match the behavior of Responses where store is True when not given
522
+ store = model_settings.store if model_settings.store is not None else True
523
+
524
+ reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
525
+
521
526
  ret = await self._get_client().chat.completions.create(
522
527
  model=self.model,
523
528
  messages=converted_messages,
@@ -532,7 +537,10 @@ class OpenAIChatCompletionsModel(Model):
532
537
  parallel_tool_calls=parallel_tool_calls,
533
538
  stream=stream,
534
539
  stream_options={"include_usage": True} if stream else NOT_GIVEN,
540
+ store=store,
541
+ reasoning_effort=self._non_null_or_not_given(reasoning_effort),
535
542
  extra_headers=_HEADERS,
543
+ metadata=model_settings.metadata,
536
544
  )
537
545
 
538
546
  if isinstance(ret, ChatCompletion):
@@ -551,6 +559,7 @@ class OpenAIChatCompletionsModel(Model):
551
559
  temperature=model_settings.temperature,
552
560
  tools=[],
553
561
  parallel_tool_calls=parallel_tool_calls or False,
562
+ reasoning=model_settings.reasoning,
554
563
  )
555
564
  return response, ret
556
565
 
@@ -757,7 +766,7 @@ class _Converter:
757
766
  elif isinstance(c, dict) and c.get("type") == "input_file":
758
767
  raise UserError(f"File uploads are not supported for chat completions {c}")
759
768
  else:
760
- raise UserError(f"Unknonw content: {c}")
769
+ raise UserError(f"Unknown content: {c}")
761
770
  return out
762
771
 
763
772
  @classmethod
@@ -919,12 +928,13 @@ class _Converter:
919
928
  elif func_call := cls.maybe_function_tool_call(item):
920
929
  asst = ensure_assistant_message()
921
930
  tool_calls = list(asst.get("tool_calls", []))
931
+ arguments = func_call["arguments"] if func_call["arguments"] else "{}"
922
932
  new_tool_call = ChatCompletionMessageToolCallParam(
923
933
  id=func_call["call_id"],
924
934
  type="function",
925
935
  function={
926
936
  "name": func_call["name"],
927
- "arguments": func_call["arguments"],
937
+ "arguments": arguments,
928
938
  },
929
939
  )
930
940
  tool_calls.append(new_tool_call)
@@ -967,7 +977,7 @@ class ToolConverter:
967
977
  }
968
978
 
969
979
  raise UserError(
970
- f"Hosted tools are not supported with the ChatCompletions API. FGot tool type: "
980
+ f"Hosted tools are not supported with the ChatCompletions API. Got tool type: "
971
981
  f"{type(tool)}, tool: {tool}"
972
982
  )
973
983
 
@@ -83,7 +83,7 @@ class OpenAIResponsesModel(Model):
83
83
  )
84
84
 
85
85
  if _debug.DONT_LOG_MODEL_DATA:
86
- logger.debug("LLM responsed")
86
+ logger.debug("LLM responded")
87
87
  else:
88
88
  logger.debug(
89
89
  "LLM resp:\n"
@@ -208,7 +208,11 @@ class OpenAIResponsesModel(Model):
208
208
  list_input = ItemHelpers.input_to_new_input_list(input)
209
209
 
210
210
  parallel_tool_calls = (
211
- True if model_settings.parallel_tool_calls and tools and len(tools) > 0 else NOT_GIVEN
211
+ True
212
+ if model_settings.parallel_tool_calls and tools and len(tools) > 0
213
+ else False
214
+ if model_settings.parallel_tool_calls is False
215
+ else NOT_GIVEN
212
216
  )
213
217
 
214
218
  tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
@@ -242,6 +246,9 @@ class OpenAIResponsesModel(Model):
242
246
  stream=stream,
243
247
  extra_headers=_HEADERS,
244
248
  text=response_format,
249
+ store=self._non_null_or_not_given(model_settings.store),
250
+ reasoning=self._non_null_or_not_given(model_settings.reasoning),
251
+ metadata=model_settings.metadata,
245
252
  )
246
253
 
247
254
  def _get_client(self) -> AsyncOpenAI:
agents/py.typed ADDED
@@ -0,0 +1 @@
1
+