openai-agents 0.0.19__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +0 -1
- agents/_run_impl.py +30 -0
- agents/agent.py +7 -3
- agents/extensions/models/litellm_model.py +7 -3
- agents/handoffs.py +14 -0
- agents/mcp/__init__.py +13 -1
- agents/mcp/server.py +140 -15
- agents/mcp/util.py +89 -5
- agents/model_settings.py +52 -6
- agents/models/chatcmpl_converter.py +12 -0
- agents/models/chatcmpl_stream_handler.py +127 -15
- agents/models/openai_chatcompletions.py +12 -10
- agents/models/openai_responses.py +14 -4
- agents/repl.py +1 -4
- agents/run.py +25 -6
- agents/tool.py +25 -0
- agents/tracing/__init__.py +1 -2
- agents/tracing/processor_interface.py +1 -1
- {openai_agents-0.0.19.dist-info → openai_agents-0.1.0.dist-info}/METADATA +14 -6
- {openai_agents-0.0.19.dist-info → openai_agents-0.1.0.dist-info}/RECORD +22 -22
- {openai_agents-0.0.19.dist-info → openai_agents-0.1.0.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.19.dist-info → openai_agents-0.1.0.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
agents/_run_impl.py
CHANGED
|
@@ -28,6 +28,9 @@ from openai.types.responses.response_computer_tool_call import (
|
|
|
28
28
|
ActionType,
|
|
29
29
|
ActionWait,
|
|
30
30
|
)
|
|
31
|
+
from openai.types.responses.response_input_item_param import (
|
|
32
|
+
ComputerCallOutputAcknowledgedSafetyCheck,
|
|
33
|
+
)
|
|
31
34
|
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
|
|
32
35
|
from openai.types.responses.response_output_item import (
|
|
33
36
|
ImageGenerationCall,
|
|
@@ -67,6 +70,7 @@ from .run_context import RunContextWrapper, TContext
|
|
|
67
70
|
from .stream_events import RunItemStreamEvent, StreamEvent
|
|
68
71
|
from .tool import (
|
|
69
72
|
ComputerTool,
|
|
73
|
+
ComputerToolSafetyCheckData,
|
|
70
74
|
FunctionTool,
|
|
71
75
|
FunctionToolResult,
|
|
72
76
|
HostedMCPTool,
|
|
@@ -638,6 +642,29 @@ class RunImpl:
|
|
|
638
642
|
results: list[RunItem] = []
|
|
639
643
|
# Need to run these serially, because each action can affect the computer state
|
|
640
644
|
for action in actions:
|
|
645
|
+
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
|
|
646
|
+
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
|
|
647
|
+
acknowledged = []
|
|
648
|
+
for check in action.tool_call.pending_safety_checks:
|
|
649
|
+
data = ComputerToolSafetyCheckData(
|
|
650
|
+
ctx_wrapper=context_wrapper,
|
|
651
|
+
agent=agent,
|
|
652
|
+
tool_call=action.tool_call,
|
|
653
|
+
safety_check=check,
|
|
654
|
+
)
|
|
655
|
+
maybe = action.computer_tool.on_safety_check(data)
|
|
656
|
+
ack = await maybe if inspect.isawaitable(maybe) else maybe
|
|
657
|
+
if ack:
|
|
658
|
+
acknowledged.append(
|
|
659
|
+
ComputerCallOutputAcknowledgedSafetyCheck(
|
|
660
|
+
id=check.id,
|
|
661
|
+
code=check.code,
|
|
662
|
+
message=check.message,
|
|
663
|
+
)
|
|
664
|
+
)
|
|
665
|
+
else:
|
|
666
|
+
raise UserError("Computer tool safety check was not acknowledged")
|
|
667
|
+
|
|
641
668
|
results.append(
|
|
642
669
|
await ComputerAction.execute(
|
|
643
670
|
agent=agent,
|
|
@@ -645,6 +672,7 @@ class RunImpl:
|
|
|
645
672
|
hooks=hooks,
|
|
646
673
|
context_wrapper=context_wrapper,
|
|
647
674
|
config=config,
|
|
675
|
+
acknowledged_safety_checks=acknowledged,
|
|
648
676
|
)
|
|
649
677
|
)
|
|
650
678
|
|
|
@@ -998,6 +1026,7 @@ class ComputerAction:
|
|
|
998
1026
|
hooks: RunHooks[TContext],
|
|
999
1027
|
context_wrapper: RunContextWrapper[TContext],
|
|
1000
1028
|
config: RunConfig,
|
|
1029
|
+
acknowledged_safety_checks: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None,
|
|
1001
1030
|
) -> RunItem:
|
|
1002
1031
|
output_func = (
|
|
1003
1032
|
cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
|
|
@@ -1036,6 +1065,7 @@ class ComputerAction:
|
|
|
1036
1065
|
"image_url": image_url,
|
|
1037
1066
|
},
|
|
1038
1067
|
type="computer_call_output",
|
|
1068
|
+
acknowledged_safety_checks=acknowledged_safety_checks,
|
|
1039
1069
|
),
|
|
1040
1070
|
)
|
|
1041
1071
|
|
agents/agent.py
CHANGED
|
@@ -256,14 +256,18 @@ class Agent(Generic[TContext]):
|
|
|
256
256
|
"""Get the prompt for the agent."""
|
|
257
257
|
return await PromptUtil.to_model_input(self.prompt, run_context, self)
|
|
258
258
|
|
|
259
|
-
async def get_mcp_tools(
|
|
259
|
+
async def get_mcp_tools(
|
|
260
|
+
self, run_context: RunContextWrapper[TContext]
|
|
261
|
+
) -> list[Tool]:
|
|
260
262
|
"""Fetches the available tools from the MCP servers."""
|
|
261
263
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
262
|
-
return await MCPUtil.get_all_function_tools(
|
|
264
|
+
return await MCPUtil.get_all_function_tools(
|
|
265
|
+
self.mcp_servers, convert_schemas_to_strict, run_context, self
|
|
266
|
+
)
|
|
263
267
|
|
|
264
268
|
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
|
265
269
|
"""All agent tools, including MCP tools and function tools."""
|
|
266
|
-
mcp_tools = await self.get_mcp_tools()
|
|
270
|
+
mcp_tools = await self.get_mcp_tools(run_context)
|
|
267
271
|
|
|
268
272
|
async def _check_tool_enabled(tool: Tool) -> bool:
|
|
269
273
|
if not isinstance(tool, FunctionTool):
|
|
@@ -98,7 +98,11 @@ class LitellmModel(Model):
|
|
|
98
98
|
logger.debug("Received model response")
|
|
99
99
|
else:
|
|
100
100
|
logger.debug(
|
|
101
|
-
f"LLM resp:\n{
|
|
101
|
+
f"""LLM resp:\n{
|
|
102
|
+
json.dumps(
|
|
103
|
+
response.choices[0].message.model_dump(), indent=2, ensure_ascii=False
|
|
104
|
+
)
|
|
105
|
+
}\n"""
|
|
102
106
|
)
|
|
103
107
|
|
|
104
108
|
if hasattr(response, "usage"):
|
|
@@ -269,8 +273,8 @@ class LitellmModel(Model):
|
|
|
269
273
|
else:
|
|
270
274
|
logger.debug(
|
|
271
275
|
f"Calling Litellm model: {self.model}\n"
|
|
272
|
-
f"{json.dumps(converted_messages, indent=2)}\n"
|
|
273
|
-
f"Tools:\n{json.dumps(converted_tools, indent=2)}\n"
|
|
276
|
+
f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n"
|
|
277
|
+
f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n"
|
|
274
278
|
f"Stream: {stream}\n"
|
|
275
279
|
f"Tool choice: {tool_choice}\n"
|
|
276
280
|
f"Response format: {response_format}\n"
|
agents/handoffs.py
CHANGED
|
@@ -15,6 +15,7 @@ from .run_context import RunContextWrapper, TContext
|
|
|
15
15
|
from .strict_schema import ensure_strict_json_schema
|
|
16
16
|
from .tracing.spans import SpanError
|
|
17
17
|
from .util import _error_tracing, _json, _transforms
|
|
18
|
+
from .util._types import MaybeAwaitable
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
from .agent import Agent
|
|
@@ -99,6 +100,11 @@ class Handoff(Generic[TContext]):
|
|
|
99
100
|
True, as it increases the likelihood of correct JSON input.
|
|
100
101
|
"""
|
|
101
102
|
|
|
103
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
|
|
104
|
+
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
|
|
105
|
+
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
|
|
106
|
+
a handoff based on your context/state."""
|
|
107
|
+
|
|
102
108
|
def get_transfer_message(self, agent: Agent[Any]) -> str:
|
|
103
109
|
return json.dumps({"assistant": agent.name})
|
|
104
110
|
|
|
@@ -121,6 +127,7 @@ def handoff(
|
|
|
121
127
|
tool_name_override: str | None = None,
|
|
122
128
|
tool_description_override: str | None = None,
|
|
123
129
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
130
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
124
131
|
) -> Handoff[TContext]: ...
|
|
125
132
|
|
|
126
133
|
|
|
@@ -133,6 +140,7 @@ def handoff(
|
|
|
133
140
|
tool_description_override: str | None = None,
|
|
134
141
|
tool_name_override: str | None = None,
|
|
135
142
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
143
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
136
144
|
) -> Handoff[TContext]: ...
|
|
137
145
|
|
|
138
146
|
|
|
@@ -144,6 +152,7 @@ def handoff(
|
|
|
144
152
|
tool_description_override: str | None = None,
|
|
145
153
|
tool_name_override: str | None = None,
|
|
146
154
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
155
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
147
156
|
) -> Handoff[TContext]: ...
|
|
148
157
|
|
|
149
158
|
|
|
@@ -154,6 +163,7 @@ def handoff(
|
|
|
154
163
|
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
|
|
155
164
|
input_type: type[THandoffInput] | None = None,
|
|
156
165
|
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
|
|
166
|
+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
157
167
|
) -> Handoff[TContext]:
|
|
158
168
|
"""Create a handoff from an agent.
|
|
159
169
|
|
|
@@ -166,6 +176,9 @@ def handoff(
|
|
|
166
176
|
input_type: the type of the input to the handoff. If provided, the input will be validated
|
|
167
177
|
against this type. Only relevant if you pass a function that takes an input.
|
|
168
178
|
input_filter: a function that filters the inputs that are passed to the next agent.
|
|
179
|
+
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
|
|
180
|
+
context and agent and returns whether the handoff is enabled. Disabled handoffs are
|
|
181
|
+
hidden from the LLM at runtime.
|
|
169
182
|
"""
|
|
170
183
|
assert (on_handoff and input_type) or not (on_handoff and input_type), (
|
|
171
184
|
"You must provide either both on_handoff and input_type, or neither"
|
|
@@ -233,4 +246,5 @@ def handoff(
|
|
|
233
246
|
on_invoke_handoff=_invoke_handoff,
|
|
234
247
|
input_filter=input_filter,
|
|
235
248
|
agent_name=agent.name,
|
|
249
|
+
is_enabled=is_enabled,
|
|
236
250
|
)
|
agents/mcp/__init__.py
CHANGED
|
@@ -11,7 +11,14 @@ try:
|
|
|
11
11
|
except ImportError:
|
|
12
12
|
pass
|
|
13
13
|
|
|
14
|
-
from .util import
|
|
14
|
+
from .util import (
|
|
15
|
+
MCPUtil,
|
|
16
|
+
ToolFilter,
|
|
17
|
+
ToolFilterCallable,
|
|
18
|
+
ToolFilterContext,
|
|
19
|
+
ToolFilterStatic,
|
|
20
|
+
create_static_tool_filter,
|
|
21
|
+
)
|
|
15
22
|
|
|
16
23
|
__all__ = [
|
|
17
24
|
"MCPServer",
|
|
@@ -22,4 +29,9 @@ __all__ = [
|
|
|
22
29
|
"MCPServerStreamableHttp",
|
|
23
30
|
"MCPServerStreamableHttpParams",
|
|
24
31
|
"MCPUtil",
|
|
32
|
+
"ToolFilter",
|
|
33
|
+
"ToolFilterCallable",
|
|
34
|
+
"ToolFilterContext",
|
|
35
|
+
"ToolFilterStatic",
|
|
36
|
+
"create_static_tool_filter",
|
|
25
37
|
]
|
agents/mcp/server.py
CHANGED
|
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import asyncio
|
|
5
|
+
import inspect
|
|
5
6
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
|
6
7
|
from datetime import timedelta
|
|
7
8
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Literal
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
9
10
|
|
|
10
11
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
11
12
|
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
|
@@ -17,6 +18,11 @@ from typing_extensions import NotRequired, TypedDict
|
|
|
17
18
|
|
|
18
19
|
from ..exceptions import UserError
|
|
19
20
|
from ..logger import logger
|
|
21
|
+
from ..run_context import RunContextWrapper
|
|
22
|
+
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..agent import Agent
|
|
20
26
|
|
|
21
27
|
|
|
22
28
|
class MCPServer(abc.ABC):
|
|
@@ -44,7 +50,11 @@ class MCPServer(abc.ABC):
|
|
|
44
50
|
pass
|
|
45
51
|
|
|
46
52
|
@abc.abstractmethod
|
|
47
|
-
async def list_tools(
|
|
53
|
+
async def list_tools(
|
|
54
|
+
self,
|
|
55
|
+
run_context: RunContextWrapper[Any],
|
|
56
|
+
agent: Agent[Any],
|
|
57
|
+
) -> list[MCPTool]:
|
|
48
58
|
"""List the tools available on the server."""
|
|
49
59
|
pass
|
|
50
60
|
|
|
@@ -57,7 +67,12 @@ class MCPServer(abc.ABC):
|
|
|
57
67
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
58
68
|
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
|
59
69
|
|
|
60
|
-
def __init__(
|
|
70
|
+
def __init__(
|
|
71
|
+
self,
|
|
72
|
+
cache_tools_list: bool,
|
|
73
|
+
client_session_timeout_seconds: float | None,
|
|
74
|
+
tool_filter: ToolFilter = None,
|
|
75
|
+
):
|
|
61
76
|
"""
|
|
62
77
|
Args:
|
|
63
78
|
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
|
@@ -68,6 +83,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
68
83
|
(by avoiding a round-trip to the server every time).
|
|
69
84
|
|
|
70
85
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
86
|
+
tool_filter: The tool filter to use for filtering tools.
|
|
71
87
|
"""
|
|
72
88
|
self.session: ClientSession | None = None
|
|
73
89
|
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
|
@@ -81,6 +97,88 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
81
97
|
self._cache_dirty = True
|
|
82
98
|
self._tools_list: list[MCPTool] | None = None
|
|
83
99
|
|
|
100
|
+
self.tool_filter = tool_filter
|
|
101
|
+
|
|
102
|
+
async def _apply_tool_filter(
|
|
103
|
+
self,
|
|
104
|
+
tools: list[MCPTool],
|
|
105
|
+
run_context: RunContextWrapper[Any],
|
|
106
|
+
agent: Agent[Any],
|
|
107
|
+
) -> list[MCPTool]:
|
|
108
|
+
"""Apply the tool filter to the list of tools."""
|
|
109
|
+
if self.tool_filter is None:
|
|
110
|
+
return tools
|
|
111
|
+
|
|
112
|
+
# Handle static tool filter
|
|
113
|
+
if isinstance(self.tool_filter, dict):
|
|
114
|
+
return self._apply_static_tool_filter(tools, self.tool_filter)
|
|
115
|
+
|
|
116
|
+
# Handle callable tool filter (dynamic filter)
|
|
117
|
+
else:
|
|
118
|
+
return await self._apply_dynamic_tool_filter(tools, run_context, agent)
|
|
119
|
+
|
|
120
|
+
def _apply_static_tool_filter(
|
|
121
|
+
self,
|
|
122
|
+
tools: list[MCPTool],
|
|
123
|
+
static_filter: ToolFilterStatic
|
|
124
|
+
) -> list[MCPTool]:
|
|
125
|
+
"""Apply static tool filtering based on allowlist and blocklist."""
|
|
126
|
+
filtered_tools = tools
|
|
127
|
+
|
|
128
|
+
# Apply allowed_tool_names filter (whitelist)
|
|
129
|
+
if "allowed_tool_names" in static_filter:
|
|
130
|
+
allowed_names = static_filter["allowed_tool_names"]
|
|
131
|
+
filtered_tools = [t for t in filtered_tools if t.name in allowed_names]
|
|
132
|
+
|
|
133
|
+
# Apply blocked_tool_names filter (blacklist)
|
|
134
|
+
if "blocked_tool_names" in static_filter:
|
|
135
|
+
blocked_names = static_filter["blocked_tool_names"]
|
|
136
|
+
filtered_tools = [t for t in filtered_tools if t.name not in blocked_names]
|
|
137
|
+
|
|
138
|
+
return filtered_tools
|
|
139
|
+
|
|
140
|
+
async def _apply_dynamic_tool_filter(
|
|
141
|
+
self,
|
|
142
|
+
tools: list[MCPTool],
|
|
143
|
+
run_context: RunContextWrapper[Any],
|
|
144
|
+
agent: Agent[Any],
|
|
145
|
+
) -> list[MCPTool]:
|
|
146
|
+
"""Apply dynamic tool filtering using a callable filter function."""
|
|
147
|
+
|
|
148
|
+
# Ensure we have a callable filter and cast to help mypy
|
|
149
|
+
if not callable(self.tool_filter):
|
|
150
|
+
raise ValueError("Tool filter must be callable for dynamic filtering")
|
|
151
|
+
tool_filter_func = cast(ToolFilterCallable, self.tool_filter)
|
|
152
|
+
|
|
153
|
+
# Create filter context
|
|
154
|
+
filter_context = ToolFilterContext(
|
|
155
|
+
run_context=run_context,
|
|
156
|
+
agent=agent,
|
|
157
|
+
server_name=self.name,
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
filtered_tools = []
|
|
161
|
+
for tool in tools:
|
|
162
|
+
try:
|
|
163
|
+
# Call the filter function with context
|
|
164
|
+
result = tool_filter_func(filter_context, tool)
|
|
165
|
+
|
|
166
|
+
if inspect.isawaitable(result):
|
|
167
|
+
should_include = await result
|
|
168
|
+
else:
|
|
169
|
+
should_include = result
|
|
170
|
+
|
|
171
|
+
if should_include:
|
|
172
|
+
filtered_tools.append(tool)
|
|
173
|
+
except Exception as e:
|
|
174
|
+
logger.error(
|
|
175
|
+
f"Error applying tool filter to tool '{tool.name}' on server '{self.name}': {e}"
|
|
176
|
+
)
|
|
177
|
+
# On error, exclude the tool for safety
|
|
178
|
+
continue
|
|
179
|
+
|
|
180
|
+
return filtered_tools
|
|
181
|
+
|
|
84
182
|
@abc.abstractmethod
|
|
85
183
|
def create_streams(
|
|
86
184
|
self,
|
|
@@ -131,21 +229,30 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
131
229
|
await self.cleanup()
|
|
132
230
|
raise
|
|
133
231
|
|
|
134
|
-
async def list_tools(
|
|
232
|
+
async def list_tools(
|
|
233
|
+
self,
|
|
234
|
+
run_context: RunContextWrapper[Any],
|
|
235
|
+
agent: Agent[Any],
|
|
236
|
+
) -> list[MCPTool]:
|
|
135
237
|
"""List the tools available on the server."""
|
|
136
238
|
if not self.session:
|
|
137
239
|
raise UserError("Server not initialized. Make sure you call `connect()` first.")
|
|
138
240
|
|
|
139
241
|
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
|
|
140
242
|
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
243
|
+
tools = self._tools_list
|
|
244
|
+
else:
|
|
245
|
+
# Reset the cache dirty to False
|
|
246
|
+
self._cache_dirty = False
|
|
247
|
+
# Fetch the tools from the server
|
|
248
|
+
self._tools_list = (await self.session.list_tools()).tools
|
|
249
|
+
tools = self._tools_list
|
|
250
|
+
|
|
251
|
+
# Filter tools based on tool_filter
|
|
252
|
+
filtered_tools = tools
|
|
253
|
+
if self.tool_filter is not None:
|
|
254
|
+
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
|
|
255
|
+
return filtered_tools
|
|
149
256
|
|
|
150
257
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
|
|
151
258
|
"""Invoke a tool on the server."""
|
|
@@ -206,6 +313,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|
|
206
313
|
cache_tools_list: bool = False,
|
|
207
314
|
name: str | None = None,
|
|
208
315
|
client_session_timeout_seconds: float | None = 5,
|
|
316
|
+
tool_filter: ToolFilter = None,
|
|
209
317
|
):
|
|
210
318
|
"""Create a new MCP server based on the stdio transport.
|
|
211
319
|
|
|
@@ -223,8 +331,13 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|
|
223
331
|
name: A readable name for the server. If not provided, we'll create one from the
|
|
224
332
|
command.
|
|
225
333
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
334
|
+
tool_filter: The tool filter to use for filtering tools.
|
|
226
335
|
"""
|
|
227
|
-
super().__init__(
|
|
336
|
+
super().__init__(
|
|
337
|
+
cache_tools_list,
|
|
338
|
+
client_session_timeout_seconds,
|
|
339
|
+
tool_filter,
|
|
340
|
+
)
|
|
228
341
|
|
|
229
342
|
self.params = StdioServerParameters(
|
|
230
343
|
command=params["command"],
|
|
@@ -283,6 +396,7 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|
|
283
396
|
cache_tools_list: bool = False,
|
|
284
397
|
name: str | None = None,
|
|
285
398
|
client_session_timeout_seconds: float | None = 5,
|
|
399
|
+
tool_filter: ToolFilter = None,
|
|
286
400
|
):
|
|
287
401
|
"""Create a new MCP server based on the HTTP with SSE transport.
|
|
288
402
|
|
|
@@ -302,8 +416,13 @@ class MCPServerSse(_MCPServerWithClientSession):
|
|
|
302
416
|
URL.
|
|
303
417
|
|
|
304
418
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
419
|
+
tool_filter: The tool filter to use for filtering tools.
|
|
305
420
|
"""
|
|
306
|
-
super().__init__(
|
|
421
|
+
super().__init__(
|
|
422
|
+
cache_tools_list,
|
|
423
|
+
client_session_timeout_seconds,
|
|
424
|
+
tool_filter,
|
|
425
|
+
)
|
|
307
426
|
|
|
308
427
|
self.params = params
|
|
309
428
|
self._name = name or f"sse: {self.params['url']}"
|
|
@@ -362,6 +481,7 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|
|
362
481
|
cache_tools_list: bool = False,
|
|
363
482
|
name: str | None = None,
|
|
364
483
|
client_session_timeout_seconds: float | None = 5,
|
|
484
|
+
tool_filter: ToolFilter = None,
|
|
365
485
|
):
|
|
366
486
|
"""Create a new MCP server based on the Streamable HTTP transport.
|
|
367
487
|
|
|
@@ -382,8 +502,13 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
|
|
|
382
502
|
URL.
|
|
383
503
|
|
|
384
504
|
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
|
|
505
|
+
tool_filter: The tool filter to use for filtering tools.
|
|
385
506
|
"""
|
|
386
|
-
super().__init__(
|
|
507
|
+
super().__init__(
|
|
508
|
+
cache_tools_list,
|
|
509
|
+
client_session_timeout_seconds,
|
|
510
|
+
tool_filter,
|
|
511
|
+
)
|
|
387
512
|
|
|
388
513
|
self.params = params
|
|
389
514
|
self._name = name or f"streamable_http: {self.params['url']}"
|
agents/mcp/util.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import json
|
|
3
|
-
from
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
5
|
+
|
|
6
|
+
from typing_extensions import NotRequired, TypedDict
|
|
4
7
|
|
|
5
8
|
from agents.strict_schema import ensure_strict_json_schema
|
|
6
9
|
|
|
@@ -10,25 +13,102 @@ from ..logger import logger
|
|
|
10
13
|
from ..run_context import RunContextWrapper
|
|
11
14
|
from ..tool import FunctionTool, Tool
|
|
12
15
|
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
|
|
16
|
+
from ..util._types import MaybeAwaitable
|
|
13
17
|
|
|
14
18
|
if TYPE_CHECKING:
|
|
15
19
|
from mcp.types import Tool as MCPTool
|
|
16
20
|
|
|
21
|
+
from ..agent import Agent
|
|
17
22
|
from .server import MCPServer
|
|
18
23
|
|
|
19
24
|
|
|
25
|
+
@dataclass
|
|
26
|
+
class ToolFilterContext:
|
|
27
|
+
"""Context information available to tool filter functions."""
|
|
28
|
+
|
|
29
|
+
run_context: RunContextWrapper[Any]
|
|
30
|
+
"""The current run context."""
|
|
31
|
+
|
|
32
|
+
agent: "Agent[Any]"
|
|
33
|
+
"""The agent that is requesting the tool list."""
|
|
34
|
+
|
|
35
|
+
server_name: str
|
|
36
|
+
"""The name of the MCP server."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
|
|
40
|
+
"""A function that determines whether a tool should be available.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
context: The context information including run context, agent, and server name.
|
|
44
|
+
tool: The MCP tool to filter.
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
Whether the tool should be available (True) or filtered out (False).
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ToolFilterStatic(TypedDict):
|
|
52
|
+
"""Static tool filter configuration using allowlists and blocklists."""
|
|
53
|
+
|
|
54
|
+
allowed_tool_names: NotRequired[list[str]]
|
|
55
|
+
"""Optional list of tool names to allow (whitelist).
|
|
56
|
+
If set, only these tools will be available."""
|
|
57
|
+
|
|
58
|
+
blocked_tool_names: NotRequired[list[str]]
|
|
59
|
+
"""Optional list of tool names to exclude (blacklist).
|
|
60
|
+
If set, these tools will be filtered out."""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
|
|
64
|
+
"""A tool filter that can be either a function, static configuration, or None (no filtering)."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def create_static_tool_filter(
|
|
68
|
+
allowed_tool_names: Optional[list[str]] = None,
|
|
69
|
+
blocked_tool_names: Optional[list[str]] = None,
|
|
70
|
+
) -> Optional[ToolFilterStatic]:
|
|
71
|
+
"""Create a static tool filter from allowlist and blocklist parameters.
|
|
72
|
+
|
|
73
|
+
This is a convenience function for creating a ToolFilterStatic.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
allowed_tool_names: Optional list of tool names to allow (whitelist).
|
|
77
|
+
blocked_tool_names: Optional list of tool names to exclude (blacklist).
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
A ToolFilterStatic if any filtering is specified, None otherwise.
|
|
81
|
+
"""
|
|
82
|
+
if allowed_tool_names is None and blocked_tool_names is None:
|
|
83
|
+
return None
|
|
84
|
+
|
|
85
|
+
filter_dict: ToolFilterStatic = {}
|
|
86
|
+
if allowed_tool_names is not None:
|
|
87
|
+
filter_dict["allowed_tool_names"] = allowed_tool_names
|
|
88
|
+
if blocked_tool_names is not None:
|
|
89
|
+
filter_dict["blocked_tool_names"] = blocked_tool_names
|
|
90
|
+
|
|
91
|
+
return filter_dict
|
|
92
|
+
|
|
93
|
+
|
|
20
94
|
class MCPUtil:
|
|
21
95
|
"""Set of utilities for interop between MCP and Agents SDK tools."""
|
|
22
96
|
|
|
23
97
|
@classmethod
|
|
24
98
|
async def get_all_function_tools(
|
|
25
|
-
cls,
|
|
99
|
+
cls,
|
|
100
|
+
servers: list["MCPServer"],
|
|
101
|
+
convert_schemas_to_strict: bool,
|
|
102
|
+
run_context: RunContextWrapper[Any],
|
|
103
|
+
agent: "Agent[Any]",
|
|
26
104
|
) -> list[Tool]:
|
|
27
105
|
"""Get all function tools from a list of MCP servers."""
|
|
28
106
|
tools = []
|
|
29
107
|
tool_names: set[str] = set()
|
|
30
108
|
for server in servers:
|
|
31
|
-
server_tools = await cls.get_function_tools(
|
|
109
|
+
server_tools = await cls.get_function_tools(
|
|
110
|
+
server, convert_schemas_to_strict, run_context, agent
|
|
111
|
+
)
|
|
32
112
|
server_tool_names = {tool.name for tool in server_tools}
|
|
33
113
|
if len(server_tool_names & tool_names) > 0:
|
|
34
114
|
raise UserError(
|
|
@@ -42,12 +122,16 @@ class MCPUtil:
|
|
|
42
122
|
|
|
43
123
|
@classmethod
|
|
44
124
|
async def get_function_tools(
|
|
45
|
-
cls,
|
|
125
|
+
cls,
|
|
126
|
+
server: "MCPServer",
|
|
127
|
+
convert_schemas_to_strict: bool,
|
|
128
|
+
run_context: RunContextWrapper[Any],
|
|
129
|
+
agent: "Agent[Any]",
|
|
46
130
|
) -> list[Tool]:
|
|
47
131
|
"""Get all function tools from a single MCP server."""
|
|
48
132
|
|
|
49
133
|
with mcp_tools_span(server=server.name) as span:
|
|
50
|
-
tools = await server.list_tools()
|
|
134
|
+
tools = await server.list_tools(run_context, agent)
|
|
51
135
|
span.span_data.result = [tool.name for tool in tools]
|
|
52
136
|
|
|
53
137
|
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
|