openai-agents 0.0.19__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic; see the package's registry page for more details.

Files changed (43)
  1. agents/__init__.py +5 -2
  2. agents/_run_impl.py +35 -1
  3. agents/agent.py +65 -29
  4. agents/extensions/models/litellm_model.py +7 -3
  5. agents/function_schema.py +11 -1
  6. agents/guardrail.py +5 -1
  7. agents/handoffs.py +14 -0
  8. agents/lifecycle.py +26 -17
  9. agents/mcp/__init__.py +13 -1
  10. agents/mcp/server.py +173 -16
  11. agents/mcp/util.py +89 -6
  12. agents/memory/__init__.py +3 -0
  13. agents/memory/session.py +369 -0
  14. agents/model_settings.py +60 -6
  15. agents/models/chatcmpl_converter.py +31 -2
  16. agents/models/chatcmpl_stream_handler.py +128 -16
  17. agents/models/openai_chatcompletions.py +12 -10
  18. agents/models/openai_responses.py +25 -8
  19. agents/realtime/README.md +3 -0
  20. agents/realtime/__init__.py +174 -0
  21. agents/realtime/agent.py +80 -0
  22. agents/realtime/config.py +128 -0
  23. agents/realtime/events.py +216 -0
  24. agents/realtime/items.py +91 -0
  25. agents/realtime/model.py +69 -0
  26. agents/realtime/model_events.py +159 -0
  27. agents/realtime/model_inputs.py +100 -0
  28. agents/realtime/openai_realtime.py +584 -0
  29. agents/realtime/runner.py +118 -0
  30. agents/realtime/session.py +502 -0
  31. agents/repl.py +1 -4
  32. agents/run.py +131 -10
  33. agents/tool.py +30 -6
  34. agents/tool_context.py +16 -3
  35. agents/tracing/__init__.py +1 -2
  36. agents/tracing/processor_interface.py +1 -1
  37. agents/voice/models/openai_stt.py +1 -1
  38. agents/voice/pipeline.py +6 -0
  39. agents/voice/workflow.py +8 -0
  40. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/METADATA +133 -8
  41. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/RECORD +43 -29
  42. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/WHEEL +0 -0
  43. {openai_agents-0.0.19.dist-info → openai_agents-0.2.0.dist-info}/licenses/LICENSE +0 -0
agents/mcp/server.py CHANGED
@@ -2,21 +2,27 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import asyncio
5
+ import inspect
5
6
  from contextlib import AbstractAsyncContextManager, AsyncExitStack
6
7
  from datetime import timedelta
7
8
  from pathlib import Path
8
- from typing import Any, Literal
9
+ from typing import TYPE_CHECKING, Any, Literal, cast
9
10
 
10
11
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11
12
  from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
12
13
  from mcp.client.sse import sse_client
13
14
  from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
14
15
  from mcp.shared.message import SessionMessage
15
- from mcp.types import CallToolResult, InitializeResult
16
+ from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
16
17
  from typing_extensions import NotRequired, TypedDict
17
18
 
18
19
  from ..exceptions import UserError
19
20
  from ..logger import logger
21
+ from ..run_context import RunContextWrapper
22
+ from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
23
+
24
+ if TYPE_CHECKING:
25
+ from ..agent import AgentBase
20
26
 
21
27
 
22
28
  class MCPServer(abc.ABC):
@@ -44,7 +50,11 @@ class MCPServer(abc.ABC):
44
50
  pass
45
51
 
46
52
  @abc.abstractmethod
47
- async def list_tools(self) -> list[MCPTool]:
53
+ async def list_tools(
54
+ self,
55
+ run_context: RunContextWrapper[Any] | None = None,
56
+ agent: AgentBase | None = None,
57
+ ) -> list[MCPTool]:
48
58
  """List the tools available on the server."""
49
59
  pass
50
60
 
@@ -53,11 +63,30 @@ class MCPServer(abc.ABC):
53
63
  """Invoke a tool on the server."""
54
64
  pass
55
65
 
66
+ @abc.abstractmethod
67
+ async def list_prompts(
68
+ self,
69
+ ) -> ListPromptsResult:
70
+ """List the prompts available on the server."""
71
+ pass
72
+
73
+ @abc.abstractmethod
74
+ async def get_prompt(
75
+ self, name: str, arguments: dict[str, Any] | None = None
76
+ ) -> GetPromptResult:
77
+ """Get a specific prompt from the server."""
78
+ pass
79
+
56
80
 
57
81
  class _MCPServerWithClientSession(MCPServer, abc.ABC):
58
82
  """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
59
83
 
60
- def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
84
+ def __init__(
85
+ self,
86
+ cache_tools_list: bool,
87
+ client_session_timeout_seconds: float | None,
88
+ tool_filter: ToolFilter = None,
89
+ ):
61
90
  """
62
91
  Args:
63
92
  cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@@ -68,6 +97,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
68
97
  (by avoiding a round-trip to the server every time).
69
98
 
70
99
  client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
100
+ tool_filter: The tool filter to use for filtering tools.
71
101
  """
72
102
  self.session: ClientSession | None = None
73
103
  self.exit_stack: AsyncExitStack = AsyncExitStack()
@@ -81,6 +111,86 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
81
111
  self._cache_dirty = True
82
112
  self._tools_list: list[MCPTool] | None = None
83
113
 
114
+ self.tool_filter = tool_filter
115
+
116
+ async def _apply_tool_filter(
117
+ self,
118
+ tools: list[MCPTool],
119
+ run_context: RunContextWrapper[Any],
120
+ agent: AgentBase,
121
+ ) -> list[MCPTool]:
122
+ """Apply the tool filter to the list of tools."""
123
+ if self.tool_filter is None:
124
+ return tools
125
+
126
+ # Handle static tool filter
127
+ if isinstance(self.tool_filter, dict):
128
+ return self._apply_static_tool_filter(tools, self.tool_filter)
129
+
130
+ # Handle callable tool filter (dynamic filter)
131
+ else:
132
+ return await self._apply_dynamic_tool_filter(tools, run_context, agent)
133
+
134
+ def _apply_static_tool_filter(
135
+ self, tools: list[MCPTool], static_filter: ToolFilterStatic
136
+ ) -> list[MCPTool]:
137
+ """Apply static tool filtering based on allowlist and blocklist."""
138
+ filtered_tools = tools
139
+
140
+ # Apply allowed_tool_names filter (whitelist)
141
+ if "allowed_tool_names" in static_filter:
142
+ allowed_names = static_filter["allowed_tool_names"]
143
+ filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
144
+
145
+ # Apply blocked_tool_names filter (blacklist)
146
+ if "blocked_tool_names" in static_filter:
147
+ blocked_names = static_filter["blocked_tool_names"]
148
+ filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
149
+
150
+ return filtered_tools
151
+
152
+ async def _apply_dynamic_tool_filter(
153
+ self,
154
+ tools: list[MCPTool],
155
+ run_context: RunContextWrapper[Any],
156
+ agent: AgentBase,
157
+ ) -> list[MCPTool]:
158
+ """Apply dynamic tool filtering using a callable filter function."""
159
+
160
+ # Ensure we have a callable filter and cast to help mypy
161
+ if not callable(self.tool_filter):
162
+ raise ValueError("Tool filter must be callable for dynamic filtering")
163
+ tool_filter_func = cast(ToolFilterCallable, self.tool_filter)
164
+
165
+ # Create filter context
166
+ filter_context = ToolFilterContext(
167
+ run_context=run_context,
168
+ agent=agent,
169
+ server_name=self.name,
170
+ )
171
+
172
+ filtered_tools = []
173
+ for tool in tools:
174
+ try:
175
+ # Call the filter function with context
176
+ result = tool_filter_func(filter_context, tool)
177
+
178
+ if inspect.isawaitable(result):
179
+ should_include = await result
180
+ else:
181
+ should_include = result
182
+
183
+ if should_include:
184
+ filtered_tools.append(tool)
185
+ except Exception as e:
186
+ logger.error(
187
+ f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}"
188
+ )
189
+ # On error, exclude the tool for safety
190
+ continue
191
+
192
+ return filtered_tools
193
+
84
194
  @abc.abstractmethod
85
195
  def create_streams(
86
196
  self,
@@ -131,21 +241,32 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
131
241
  await self.cleanup()
132
242
  raise
133
243
 
134
- async def list_tools(self) -> list[MCPTool]:
244
+ async def list_tools(
245
+ self,
246
+ run_context: RunContextWrapper[Any] | None = None,
247
+ agent: AgentBase | None = None,
248
+ ) -> list[MCPTool]:
135
249
  """List the tools available on the server."""
136
250
  if not self.session:
137
251
  raise UserError("Server not initialized. Make sure you call `connect()` first.")
138
252
 
139
253
  # Return from cache if caching is enabled, we have tools, and the cache is not dirty
140
254
  if self.cache_tools_list and not self._cache_dirty and self._tools_list:
141
- return self._tools_list
142
-
143
- # Reset the cache dirty to False
144
- self._cache_dirty = False
145
-
146
- # Fetch the tools from the server
147
- self._tools_list = (await self.session.list_tools()).tools
148
- return self._tools_list
255
+ tools = self._tools_list
256
+ else:
257
+ # Reset the cache dirty to False
258
+ self._cache_dirty = False
259
+ # Fetch the tools from the server
260
+ self._tools_list = (await self.session.list_tools()).tools
261
+ tools = self._tools_list
262
+
263
+ # Filter tools based on tool_filter
264
+ filtered_tools = tools
265
+ if self.tool_filter is not None:
266
+ if run_context is None or agent is None:
267
+ raise UserError("run_context and agent are required for dynamic tool filtering")
268
+ filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
269
+ return filtered_tools
149
270
 
150
271
  async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
151
272
  """Invoke a tool on the server."""
@@ -154,6 +275,24 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
154
275
 
155
276
  return await self.session.call_tool(tool_name, arguments)
156
277
 
278
+ async def list_prompts(
279
+ self,
280
+ ) -> ListPromptsResult:
281
+ """List the prompts available on the server."""
282
+ if not self.session:
283
+ raise UserError("Server not initialized. Make sure you call `connect()` first.")
284
+
285
+ return await self.session.list_prompts()
286
+
287
+ async def get_prompt(
288
+ self, name: str, arguments: dict[str, Any] | None = None
289
+ ) -> GetPromptResult:
290
+ """Get a specific prompt from the server."""
291
+ if not self.session:
292
+ raise UserError("Server not initialized. Make sure you call `connect()` first.")
293
+
294
+ return await self.session.get_prompt(name, arguments)
295
+
157
296
  async def cleanup(self):
158
297
  """Cleanup the server."""
159
298
  async with self._cleanup_lock:
@@ -206,6 +345,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
206
345
  cache_tools_list: bool = False,
207
346
  name: str | None = None,
208
347
  client_session_timeout_seconds: float | None = 5,
348
+ tool_filter: ToolFilter = None,
209
349
  ):
210
350
  """Create a new MCP server based on the stdio transport.
211
351
 
@@ -223,8 +363,13 @@ class MCPServerStdio(_MCPServerWithClientSession):
223
363
  name: A readable name for the server. If not provided, we'll create one from the
224
364
  command.
225
365
  client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
366
+ tool_filter: The tool filter to use for filtering tools.
226
367
  """
227
- super().__init__(cache_tools_list, client_session_timeout_seconds)
368
+ super().__init__(
369
+ cache_tools_list,
370
+ client_session_timeout_seconds,
371
+ tool_filter,
372
+ )
228
373
 
229
374
  self.params = StdioServerParameters(
230
375
  command=params["command"],
@@ -283,6 +428,7 @@ class MCPServerSse(_MCPServerWithClientSession):
283
428
  cache_tools_list: bool = False,
284
429
  name: str | None = None,
285
430
  client_session_timeout_seconds: float | None = 5,
431
+ tool_filter: ToolFilter = None,
286
432
  ):
287
433
  """Create a new MCP server based on the HTTP with SSE transport.
288
434
 
@@ -302,8 +448,13 @@ class MCPServerSse(_MCPServerWithClientSession):
302
448
  URL.
303
449
 
304
450
  client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
451
+ tool_filter: The tool filter to use for filtering tools.
305
452
  """
306
- super().__init__(cache_tools_list, client_session_timeout_seconds)
453
+ super().__init__(
454
+ cache_tools_list,
455
+ client_session_timeout_seconds,
456
+ tool_filter,
457
+ )
307
458
 
308
459
  self.params = params
309
460
  self._name = name or f"sse: {self.params['url']}"
@@ -362,6 +513,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
362
513
  cache_tools_list: bool = False,
363
514
  name: str | None = None,
364
515
  client_session_timeout_seconds: float | None = 5,
516
+ tool_filter: ToolFilter = None,
365
517
  ):
366
518
  """Create a new MCP server based on the Streamable HTTP transport.
367
519
 
@@ -382,8 +534,13 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
382
534
  URL.
383
535
 
384
536
  client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
537
+ tool_filter: The tool filter to use for filtering tools.
385
538
  """
386
- super().__init__(cache_tools_list, client_session_timeout_seconds)
539
+ super().__init__(
540
+ cache_tools_list,
541
+ client_session_timeout_seconds,
542
+ tool_filter,
543
+ )
387
544
 
388
545
  self.params = params
389
546
  self._name = name or f"streamable_http: {self.params['url']}"
agents/mcp/util.py CHANGED
@@ -1,34 +1,113 @@
1
1
  import functools
2
2
  import json
3
- from typing import TYPE_CHECKING, Any
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
4
5
 
5
- from agents.strict_schema import ensure_strict_json_schema
6
+ from typing_extensions import NotRequired, TypedDict
6
7
 
7
8
  from .. import _debug
8
9
  from ..exceptions import AgentsException, ModelBehaviorError, UserError
9
10
  from ..logger import logger
10
11
  from ..run_context import RunContextWrapper
12
+ from ..strict_schema import ensure_strict_json_schema
11
13
  from ..tool import FunctionTool, Tool
12
14
  from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
15
+ from ..util._types import MaybeAwaitable
13
16
 
14
17
  if TYPE_CHECKING:
15
18
  from mcp.types import Tool as MCPTool
16
19
 
20
+ from ..agent import AgentBase
17
21
  from .server import MCPServer
18
22
 
19
23
 
24
+ @dataclass
25
+ class ToolFilterContext:
26
+ """Context information available to tool filter functions."""
27
+
28
+ run_context: RunContextWrapper[Any]
29
+ """The current run context."""
30
+
31
+ agent: "AgentBase"
32
+ """The agent that is requesting the tool list."""
33
+
34
+ server_name: str
35
+ """The name of the MCP server."""
36
+
37
+
38
+ ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
39
+ """A function that determines whether a tool should be available.
40
+
41
+ Args:
42
+ context: The context information including run context, agent, and server name.
43
+ tool: The MCP tool to filter.
44
+
45
+ Returns:
46
+ Whether the tool should be available (True) or filtered out (False).
47
+ """
48
+
49
+
50
+ class ToolFilterStatic(TypedDict):
51
+ """Static tool filter configuration using allowlists and blocklists."""
52
+
53
+ allowed_tool_names: NotRequired[list[str]]
54
+ """Optional list of tool names to allow (whitelist).
55
+ If set, only these tools will be available."""
56
+
57
+ blocked_tool_names: NotRequired[list[str]]
58
+ """Optional list of tool names to exclude (blacklist).
59
+ If set, these tools will be filtered out."""
60
+
61
+
62
+ ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
63
+ """A tool filter that can be either a function, static configuration, or None (no filtering)."""
64
+
65
+
66
+ def create_static_tool_filter(
67
+ allowed_tool_names: Optional[list[str]] = None,
68
+ blocked_tool_names: Optional[list[str]] = None,
69
+ ) -> Optional[ToolFilterStatic]:
70
+ """Create a static tool filter from allowlist and blocklist parameters.
71
+
72
+ This is a convenience function for creating a ToolFilterStatic.
73
+
74
+ Args:
75
+ allowed_tool_names: Optional list of tool names to allow (whitelist).
76
+ blocked_tool_names: Optional list of tool names to exclude (blacklist).
77
+
78
+ Returns:
79
+ A ToolFilterStatic if any filtering is specified, None otherwise.
80
+ """
81
+ if allowed_tool_names is None and blocked_tool_names is None:
82
+ return None
83
+
84
+ filter_dict: ToolFilterStatic = {}
85
+ if allowed_tool_names is not None:
86
+ filter_dict["allowed_tool_names"] = allowed_tool_names
87
+ if blocked_tool_names is not None:
88
+ filter_dict["blocked_tool_names"] = blocked_tool_names
89
+
90
+ return filter_dict
91
+
92
+
20
93
  class MCPUtil:
21
94
  """Set of utilities for interop between MCP and Agents SDK tools."""
22
95
 
23
96
  @classmethod
24
97
  async def get_all_function_tools(
25
- cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
98
+ cls,
99
+ servers: list["MCPServer"],
100
+ convert_schemas_to_strict: bool,
101
+ run_context: RunContextWrapper[Any],
102
+ agent: "AgentBase",
26
103
  ) -> list[Tool]:
27
104
  """Get all function tools from a list of MCP servers."""
28
105
  tools = []
29
106
  tool_names: set[str] = set()
30
107
  for server in servers:
31
- server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
108
+ server_tools = await cls.get_function_tools(
109
+ server, convert_schemas_to_strict, run_context, agent
110
+ )
32
111
  server_tool_names = {tool.name for tool in server_tools}
33
112
  if len(server_tool_names & tool_names) > 0:
34
113
  raise UserError(
@@ -42,12 +121,16 @@ class MCPUtil:
42
121
 
43
122
  @classmethod
44
123
  async def get_function_tools(
45
- cls, server: "MCPServer", convert_schemas_to_strict: bool
124
+ cls,
125
+ server: "MCPServer",
126
+ convert_schemas_to_strict: bool,
127
+ run_context: RunContextWrapper[Any],
128
+ agent: "AgentBase",
46
129
  ) -> list[Tool]:
47
130
  """Get all function tools from a single MCP server."""
48
131
 
49
132
  with mcp_tools_span(server=server.name) as span:
50
- tools = await server.list_tools()
133
+ tools = await server.list_tools(run_context, agent)
51
134
  span.span_data.result = [tool.name for tool in tools]
52
135
 
53
136
  return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
@@ -0,0 +1,3 @@
1
+ from .session import Session, SQLiteSession
2
+
3
+ __all__ = ["Session", "SQLiteSession"]