openai-agents 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +61 -29
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +43 -11
- agents/mcp/util.py +5 -6
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +19 -2
- agents/models/chatcmpl_stream_handler.py +1 -1
- agents/models/openai_responses.py +11 -4
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +174 -0
- agents/realtime/agent.py +80 -0
- agents/realtime/config.py +128 -0
- agents/realtime/events.py +216 -0
- agents/realtime/items.py +91 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +584 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +502 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/METADATA +120 -3
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/RECORD +36 -22
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Literal
|
|
|
5
5
|
from openai import AsyncOpenAI
|
|
6
6
|
|
|
7
7
|
from . import _config
|
|
8
|
-
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
|
8
|
+
from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
|
9
9
|
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
|
|
10
10
|
from .computer import AsyncComputer, Button, Computer, Environment
|
|
11
11
|
from .exceptions import (
|
|
@@ -40,6 +40,7 @@ from .items import (
|
|
|
40
40
|
TResponseInputItem,
|
|
41
41
|
)
|
|
42
42
|
from .lifecycle import AgentHooks, RunHooks
|
|
43
|
+
from .memory import Session, SQLiteSession
|
|
43
44
|
from .model_settings import ModelSettings
|
|
44
45
|
from .models.interface import Model, ModelProvider, ModelTracing
|
|
45
46
|
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
|
|
@@ -160,6 +161,7 @@ def enable_verbose_stdout_logging():
|
|
|
160
161
|
|
|
161
162
|
__all__ = [
|
|
162
163
|
"Agent",
|
|
164
|
+
"AgentBase",
|
|
163
165
|
"ToolsToFinalOutputFunction",
|
|
164
166
|
"ToolsToFinalOutputResult",
|
|
165
167
|
"Runner",
|
|
@@ -209,6 +211,8 @@ __all__ = [
|
|
|
209
211
|
"ItemHelpers",
|
|
210
212
|
"RunHooks",
|
|
211
213
|
"AgentHooks",
|
|
214
|
+
"Session",
|
|
215
|
+
"SQLiteSession",
|
|
212
216
|
"RunContextWrapper",
|
|
213
217
|
"TContext",
|
|
214
218
|
"RunErrorDetails",
|
agents/_run_impl.py
CHANGED
|
@@ -548,7 +548,11 @@ class RunImpl:
|
|
|
548
548
|
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
|
549
549
|
) -> Any:
|
|
550
550
|
with function_span(func_tool.name) as span_fn:
|
|
551
|
-
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
|
|
551
|
+
tool_context = ToolContext.from_agent_context(
|
|
552
|
+
context_wrapper,
|
|
553
|
+
tool_call.call_id,
|
|
554
|
+
tool_call=tool_call,
|
|
555
|
+
)
|
|
552
556
|
if config.trace_include_sensitive_data:
|
|
553
557
|
span_fn.span_data.input = tool_call.arguments
|
|
554
558
|
try:
|
agents/agent.py
CHANGED
|
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
@dataclass
|
|
70
|
-
class Agent(Generic[TContext]):
|
|
70
|
+
class AgentBase(Generic[TContext]):
|
|
71
|
+
"""Base class for `Agent` and `RealtimeAgent`."""
|
|
72
|
+
|
|
73
|
+
name: str
|
|
74
|
+
"""The name of the agent."""
|
|
75
|
+
|
|
76
|
+
handoff_description: str | None = None
|
|
77
|
+
"""A description of the agent. This is used when the agent is used as a handoff, so that an
|
|
78
|
+
LLM knows what it does and when to invoke it.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
tools: list[Tool] = field(default_factory=list)
|
|
82
|
+
"""A list of tools that the agent can use."""
|
|
83
|
+
|
|
84
|
+
mcp_servers: list[MCPServer] = field(default_factory=list)
|
|
85
|
+
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
|
|
86
|
+
the agent can use. Every time the agent runs, it will include tools from these servers in the
|
|
87
|
+
list of available tools.
|
|
88
|
+
|
|
89
|
+
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
|
|
90
|
+
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
|
|
91
|
+
longer needed.
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
|
95
|
+
"""Configuration for MCP servers."""
|
|
96
|
+
|
|
97
|
+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
|
|
98
|
+
"""Fetches the available tools from the MCP servers."""
|
|
99
|
+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
100
|
+
return await MCPUtil.get_all_function_tools(
|
|
101
|
+
self.mcp_servers, convert_schemas_to_strict, run_context, self
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
|
105
|
+
"""All agent tools, including MCP tools and function tools."""
|
|
106
|
+
mcp_tools = await self.get_mcp_tools(run_context)
|
|
107
|
+
|
|
108
|
+
async def _check_tool_enabled(tool: Tool) -> bool:
|
|
109
|
+
if not isinstance(tool, FunctionTool):
|
|
110
|
+
return True
|
|
111
|
+
|
|
112
|
+
attr = tool.is_enabled
|
|
113
|
+
if isinstance(attr, bool):
|
|
114
|
+
return attr
|
|
115
|
+
res = attr(run_context, self)
|
|
116
|
+
if inspect.isawaitable(res):
|
|
117
|
+
return bool(await res)
|
|
118
|
+
return bool(res)
|
|
119
|
+
|
|
120
|
+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
|
121
|
+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
|
122
|
+
return [*mcp_tools, *enabled]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
|
|
126
|
+
class Agent(AgentBase, Generic[TContext]):
|
|
71
127
|
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
|
72
128
|
|
|
73
129
|
We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
|
|
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
|
|
|
76
132
|
|
|
77
133
|
Agents are generic on the context type. The context is a (mutable) object you create. It is
|
|
78
134
|
passed to tool functions, handoffs, guardrails, etc.
|
|
79
|
-
"""
|
|
80
135
|
|
|
81
|
-
|
|
82
|
-
"""
|
|
136
|
+
See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
|
|
137
|
+
"""
|
|
83
138
|
|
|
84
139
|
instructions: (
|
|
85
140
|
str
|
|
@@ -103,11 +158,6 @@ class Agent(Generic[TContext]):
|
|
|
103
158
|
usable with OpenAI models, using the Responses API.
|
|
104
159
|
"""
|
|
105
160
|
|
|
106
|
-
handoff_description: str | None = None
|
|
107
|
-
"""A description of the agent. This is used when the agent is used as a handoff, so that an
|
|
108
|
-
LLM knows what it does and when to invoke it.
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
161
|
handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
|
|
112
162
|
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
|
|
113
163
|
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
|
|
@@ -125,22 +175,6 @@ class Agent(Generic[TContext]):
|
|
|
125
175
|
"""Configures model-specific tuning parameters (e.g. temperature, top_p).
|
|
126
176
|
"""
|
|
127
177
|
|
|
128
|
-
tools: list[Tool] = field(default_factory=list)
|
|
129
|
-
"""A list of tools that the agent can use."""
|
|
130
|
-
|
|
131
|
-
mcp_servers: list[MCPServer] = field(default_factory=list)
|
|
132
|
-
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
|
|
133
|
-
the agent can use. Every time the agent runs, it will include tools from these servers in the
|
|
134
|
-
list of available tools.
|
|
135
|
-
|
|
136
|
-
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
|
|
137
|
-
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
|
|
138
|
-
longer needed.
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
|
-
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
|
142
|
-
"""Configuration for MCP servers."""
|
|
143
|
-
|
|
144
178
|
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
|
145
179
|
"""A list of checks that run in parallel to the agent's execution, before generating a
|
|
146
180
|
response. Runs only if the agent is the first agent in the chain.
|
|
@@ -176,7 +210,7 @@ class Agent(Generic[TContext]):
|
|
|
176
210
|
The final output will be the output of the first matching tool call. The LLM does not
|
|
177
211
|
process the result of the tool call.
|
|
178
212
|
- A function: If you pass a function, it will be called with the run context and the list of
|
|
179
|
-
tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
|
|
213
|
+
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
|
|
180
214
|
calls result in a final output.
|
|
181
215
|
|
|
182
216
|
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
|
@@ -256,9 +290,7 @@ class Agent(Generic[TContext]):
|
|
|
256
290
|
"""Get the prompt for the agent."""
|
|
257
291
|
return await PromptUtil.to_model_input(self.prompt, run_context, self)
|
|
258
292
|
|
|
259
|
-
async def get_mcp_tools(
|
|
260
|
-
self, run_context: RunContextWrapper[TContext]
|
|
261
|
-
) -> list[Tool]:
|
|
293
|
+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
|
|
262
294
|
"""Fetches the available tools from the MCP servers."""
|
|
263
295
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
264
296
|
return await MCPUtil.get_all_function_tools(
|
agents/function_schema.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
|
|
|
9
9
|
|
|
10
10
|
from griffe import Docstring, DocstringSectionKind
|
|
11
11
|
from pydantic import BaseModel, Field, create_model
|
|
12
|
+
from pydantic.fields import FieldInfo
|
|
12
13
|
|
|
13
14
|
from .exceptions import UserError
|
|
14
15
|
from .run_context import RunContextWrapper
|
|
@@ -319,6 +320,14 @@ def function_schema(
|
|
|
319
320
|
ann,
|
|
320
321
|
Field(..., description=field_description),
|
|
321
322
|
)
|
|
323
|
+
elif isinstance(default, FieldInfo):
|
|
324
|
+
# Parameter with a default value that is a Field(...)
|
|
325
|
+
fields[name] = (
|
|
326
|
+
ann,
|
|
327
|
+
FieldInfo.merge_field_infos(
|
|
328
|
+
default, description=field_description or default.description
|
|
329
|
+
),
|
|
330
|
+
)
|
|
322
331
|
else:
|
|
323
332
|
# Parameter with a default value
|
|
324
333
|
fields[name] = (
|
|
@@ -337,7 +346,8 @@ def function_schema(
|
|
|
337
346
|
# 5. Return as a FuncSchema dataclass
|
|
338
347
|
return FuncSchema(
|
|
339
348
|
name=func_name,
|
|
340
|
-
description=description_override or doc_info.description if doc_info else None,
|
349
|
+
# Ensure description_override takes precedence even if docstring info is disabled.
|
|
350
|
+
description=description_override or (doc_info.description if doc_info else None),
|
|
341
351
|
params_pydantic_model=dynamic_model,
|
|
342
352
|
params_json_schema=json_schema,
|
|
343
353
|
signature=sig,
|
agents/guardrail.py
CHANGED
|
@@ -241,7 +241,11 @@ def input_guardrail(
|
|
|
241
241
|
def decorator(
|
|
242
242
|
f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
|
|
243
243
|
) -> InputGuardrail[TContext_co]:
|
|
244
|
-
return InputGuardrail(guardrail_function=f, name=name)
|
|
244
|
+
return InputGuardrail(
|
|
245
|
+
guardrail_function=f,
|
|
246
|
+
# If not set, guardrail name uses the function’s name by default.
|
|
247
|
+
name=name if name else f.__name__
|
|
248
|
+
)
|
|
245
249
|
|
|
246
250
|
if func is not None:
|
|
247
251
|
# Decorator was used without parentheses
|
agents/lifecycle.py
CHANGED
|
@@ -1,25 +1,27 @@
|
|
|
1
1
|
from typing import Any, Generic
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from typing_extensions import TypeVar
|
|
4
|
+
|
|
5
|
+
from .agent import Agent, AgentBase
|
|
4
6
|
from .run_context import RunContextWrapper, TContext
|
|
5
7
|
from .tool import Tool
|
|
6
8
|
|
|
9
|
+
TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
|
|
10
|
+
|
|
7
11
|
|
|
8
|
-
class RunHooks(Generic[TContext]):
|
|
12
|
+
class RunHooksBase(Generic[TContext, TAgent]):
|
|
9
13
|
"""A class that receives callbacks on various lifecycle events in an agent run. Subclass and
|
|
10
14
|
override the methods you need.
|
|
11
15
|
"""
|
|
12
16
|
|
|
13
|
-
async def on_agent_start(
|
|
14
|
-
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
|
|
15
|
-
) -> None:
|
|
17
|
+
async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
|
|
16
18
|
"""Called before the agent is invoked. Called each time the current agent changes."""
|
|
17
19
|
pass
|
|
18
20
|
|
|
19
21
|
async def on_agent_end(
|
|
20
22
|
self,
|
|
21
23
|
context: RunContextWrapper[TContext],
|
|
22
|
-
agent: Agent[TContext],
|
|
24
|
+
agent: TAgent,
|
|
23
25
|
output: Any,
|
|
24
26
|
) -> None:
|
|
25
27
|
"""Called when the agent produces a final output."""
|
|
@@ -28,8 +30,8 @@ class RunHooks(Generic[TContext]):
|
|
|
28
30
|
async def on_handoff(
|
|
29
31
|
self,
|
|
30
32
|
context: RunContextWrapper[TContext],
|
|
31
|
-
from_agent: Agent[TContext],
|
|
32
|
-
to_agent: Agent[TContext],
|
|
33
|
+
from_agent: TAgent,
|
|
34
|
+
to_agent: TAgent,
|
|
33
35
|
) -> None:
|
|
34
36
|
"""Called when a handoff occurs."""
|
|
35
37
|
pass
|
|
@@ -37,7 +39,7 @@ class RunHooks(Generic[TContext]):
|
|
|
37
39
|
async def on_tool_start(
|
|
38
40
|
self,
|
|
39
41
|
context: RunContextWrapper[TContext],
|
|
40
|
-
agent: Agent[TContext],
|
|
42
|
+
agent: TAgent,
|
|
41
43
|
tool: Tool,
|
|
42
44
|
) -> None:
|
|
43
45
|
"""Called before a tool is invoked."""
|
|
@@ -46,7 +48,7 @@ class RunHooks(Generic[TContext]):
|
|
|
46
48
|
async def on_tool_end(
|
|
47
49
|
self,
|
|
48
50
|
context: RunContextWrapper[TContext],
|
|
49
|
-
agent: Agent[TContext],
|
|
51
|
+
agent: TAgent,
|
|
50
52
|
tool: Tool,
|
|
51
53
|
result: str,
|
|
52
54
|
) -> None:
|
|
@@ -54,14 +56,14 @@ class RunHooks(Generic[TContext]):
|
|
|
54
56
|
pass
|
|
55
57
|
|
|
56
58
|
|
|
57
|
-
class AgentHooks(Generic[TContext]):
|
|
59
|
+
class AgentHooksBase(Generic[TContext, TAgent]):
|
|
58
60
|
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
|
|
59
61
|
set this on `agent.hooks` to receive events for that specific agent.
|
|
60
62
|
|
|
61
63
|
Subclass and override the methods you need.
|
|
62
64
|
"""
|
|
63
65
|
|
|
64
|
-
async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
|
|
66
|
+
async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
|
|
65
67
|
"""Called before the agent is invoked. Called each time the running agent is changed to this
|
|
66
68
|
agent."""
|
|
67
69
|
pass
|
|
@@ -69,7 +71,7 @@ class AgentHooks(Generic[TContext]):
|
|
|
69
71
|
async def on_end(
|
|
70
72
|
self,
|
|
71
73
|
context: RunContextWrapper[TContext],
|
|
72
|
-
agent: Agent[TContext],
|
|
74
|
+
agent: TAgent,
|
|
73
75
|
output: Any,
|
|
74
76
|
) -> None:
|
|
75
77
|
"""Called when the agent produces a final output."""
|
|
@@ -78,8 +80,8 @@ class AgentHooks(Generic[TContext]):
|
|
|
78
80
|
async def on_handoff(
|
|
79
81
|
self,
|
|
80
82
|
context: RunContextWrapper[TContext],
|
|
81
|
-
agent: Agent[TContext],
|
|
82
|
-
source: Agent[TContext],
|
|
83
|
+
agent: TAgent,
|
|
84
|
+
source: TAgent,
|
|
83
85
|
) -> None:
|
|
84
86
|
"""Called when the agent is being handed off to. The `source` is the agent that is handing
|
|
85
87
|
off to this agent."""
|
|
@@ -88,7 +90,7 @@ class AgentHooks(Generic[TContext]):
|
|
|
88
90
|
async def on_tool_start(
|
|
89
91
|
self,
|
|
90
92
|
context: RunContextWrapper[TContext],
|
|
91
|
-
agent: Agent[TContext],
|
|
93
|
+
agent: TAgent,
|
|
92
94
|
tool: Tool,
|
|
93
95
|
) -> None:
|
|
94
96
|
"""Called before a tool is invoked."""
|
|
@@ -97,9 +99,16 @@ class AgentHooks(Generic[TContext]):
|
|
|
97
99
|
async def on_tool_end(
|
|
98
100
|
self,
|
|
99
101
|
context: RunContextWrapper[TContext],
|
|
100
|
-
agent: Agent[TContext],
|
|
102
|
+
agent: TAgent,
|
|
101
103
|
tool: Tool,
|
|
102
104
|
result: str,
|
|
103
105
|
) -> None:
|
|
104
106
|
"""Called after a tool is invoked."""
|
|
105
107
|
pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
RunHooks = RunHooksBase[TContext, Agent]
|
|
111
|
+
"""Run hooks when using `Agent`."""
|
|
112
|
+
|
|
113
|
+
AgentHooks = AgentHooksBase[TContext, Agent]
|
|
114
|
+
"""Agent hooks for `Agent`s."""
|
agents/mcp/server.py
CHANGED
|
@@ -13,7 +13,7 @@ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_cli
|
|
|
13
13
|
from mcp.client.sse import sse_client
|
|
14
14
|
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
|
15
15
|
from mcp.shared.message import SessionMessage
|
|
16
|
-
from mcp.types import CallToolResult, InitializeResult
|
|
16
|
+
from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
|
|
17
17
|
from typing_extensions import NotRequired, TypedDict
|
|
18
18
|
|
|
19
19
|
from ..exceptions import UserError
|
|
@@ -22,7 +22,7 @@ from ..run_context import RunContextWrapper
|
|
|
22
22
|
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
|
|
23
23
|
|
|
24
24
|
if TYPE_CHECKING:
|
|
25
|
-
from ..agent import Agent
|
|
25
|
+
from ..agent import AgentBase
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class MCPServer(abc.ABC):
|
|
@@ -52,8 +52,8 @@ class MCPServer(abc.ABC):
|
|
|
52
52
|
@abc.abstractmethod
|
|
53
53
|
async def list_tools(
|
|
54
54
|
self,
|
|
55
|
-
run_context: RunContextWrapper[Any],
|
|
56
|
-
agent: Agent[Any],
|
|
55
|
+
run_context: RunContextWrapper[Any] | None = None,
|
|
56
|
+
agent: AgentBase | None = None,
|
|
57
57
|
) -> list[MCPTool]:
|
|
58
58
|
"""List the tools available on the server."""
|
|
59
59
|
pass
|
|
@@ -63,6 +63,20 @@ class MCPServer(abc.ABC):
|
|
|
63
63
|
"""Invoke a tool on the server."""
|
|
64
64
|
pass
|
|
65
65
|
|
|
66
|
+
@abc.abstractmethod
|
|
67
|
+
async def list_prompts(
|
|
68
|
+
self,
|
|
69
|
+
) -> ListPromptsResult:
|
|
70
|
+
"""List the prompts available on the server."""
|
|
71
|
+
pass
|
|
72
|
+
|
|
73
|
+
@abc.abstractmethod
|
|
74
|
+
async def get_prompt(
|
|
75
|
+
self, name: str, arguments: dict[str, Any] | None = None
|
|
76
|
+
) -> GetPromptResult:
|
|
77
|
+
"""Get a specific prompt from the server."""
|
|
78
|
+
pass
|
|
79
|
+
|
|
66
80
|
|
|
67
81
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
68
82
|
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
|
@@ -103,7 +117,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
103
117
|
self,
|
|
104
118
|
tools: list[MCPTool],
|
|
105
119
|
run_context: RunContextWrapper[Any],
|
|
106
|
-
agent: Agent[Any],
|
|
120
|
+
agent: AgentBase,
|
|
107
121
|
) -> list[MCPTool]:
|
|
108
122
|
"""Apply the tool filter to the list of tools."""
|
|
109
123
|
if self.tool_filter is None:
|
|
@@ -118,9 +132,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
118
132
|
return await self._apply_dynamic_tool_filter(tools, run_context, agent)
|
|
119
133
|
|
|
120
134
|
def _apply_static_tool_filter(
|
|
121
|
-
self,
|
|
122
|
-
tools: list[MCPTool],
|
|
123
|
-
static_filter: ToolFilterStatic
|
|
135
|
+
self, tools: list[MCPTool], static_filter: ToolFilterStatic
|
|
124
136
|
) -> list[MCPTool]:
|
|
125
137
|
"""Apply static tool filtering based on allowlist and blocklist."""
|
|
126
138
|
filtered_tools = tools
|
|
@@ -141,7 +153,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
141
153
|
self,
|
|
142
154
|
tools: list[MCPTool],
|
|
143
155
|
run_context: RunContextWrapper[Any],
|
|
144
|
-
agent:
|
|
156
|
+
agent: AgentBase,
|
|
145
157
|
) -> list[MCPTool]:
|
|
146
158
|
"""Apply dynamic tool filtering using a callable filter function."""
|
|
147
159
|
|
|
@@ -231,8 +243,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
231
243
|
|
|
232
244
|
async def list_tools(
|
|
233
245
|
self,
|
|
234
|
-
run_context: RunContextWrapper[Any],
|
|
235
|
-
agent:
|
|
246
|
+
run_context: RunContextWrapper[Any] | None = None,
|
|
247
|
+
agent: AgentBase | None = None,
|
|
236
248
|
) -> list[MCPTool]:
|
|
237
249
|
"""List the tools available on the server."""
|
|
238
250
|
if not self.session:
|
|
@@ -251,6 +263,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
251
263
|
# Filter tools based on tool_filter
|
|
252
264
|
filtered_tools = tools
|
|
253
265
|
if self.tool_filter is not None:
|
|
266
|
+
if run_context is None or agent is None:
|
|
267
|
+
raise UserError("run_context and agent are required for dynamic tool filtering")
|
|
254
268
|
filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
|
|
255
269
|
return filtered_tools
|
|
256
270
|
|
|
@@ -261,6 +275,24 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
261
275
|
|
|
262
276
|
return await self.session.call_tool(tool_name, arguments)
|
|
263
277
|
|
|
278
|
+
async def list_prompts(
|
|
279
|
+
self,
|
|
280
|
+
) -> ListPromptsResult:
|
|
281
|
+
"""List the prompts available on the server."""
|
|
282
|
+
if not self.session:
|
|
283
|
+
raise UserError("Server not initialized. Make sure you call `connect()` first.")
|
|
284
|
+
|
|
285
|
+
return await self.session.list_prompts()
|
|
286
|
+
|
|
287
|
+
async def get_prompt(
|
|
288
|
+
self, name: str, arguments: dict[str, Any] | None = None
|
|
289
|
+
) -> GetPromptResult:
|
|
290
|
+
"""Get a specific prompt from the server."""
|
|
291
|
+
if not self.session:
|
|
292
|
+
raise UserError("Server not initialized. Make sure you call `connect()` first.")
|
|
293
|
+
|
|
294
|
+
return await self.session.get_prompt(name, arguments)
|
|
295
|
+
|
|
264
296
|
async def cleanup(self):
|
|
265
297
|
"""Cleanup the server."""
|
|
266
298
|
async with self._cleanup_lock:
|
agents/mcp/util.py
CHANGED
|
@@ -5,12 +5,11 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
|
5
5
|
|
|
6
6
|
from typing_extensions import NotRequired, TypedDict
|
|
7
7
|
|
|
8
|
-
from agents.strict_schema import ensure_strict_json_schema
|
|
9
|
-
|
|
10
8
|
from .. import _debug
|
|
11
9
|
from ..exceptions import AgentsException, ModelBehaviorError, UserError
|
|
12
10
|
from ..logger import logger
|
|
13
11
|
from ..run_context import RunContextWrapper
|
|
12
|
+
from ..strict_schema import ensure_strict_json_schema
|
|
14
13
|
from ..tool import FunctionTool, Tool
|
|
15
14
|
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
|
|
16
15
|
from ..util._types import MaybeAwaitable
|
|
@@ -18,7 +17,7 @@ from ..util._types import MaybeAwaitable
|
|
|
18
17
|
if TYPE_CHECKING:
|
|
19
18
|
from mcp.types import Tool as MCPTool
|
|
20
19
|
|
|
21
|
-
from ..agent import Agent
|
|
20
|
+
from ..agent import AgentBase
|
|
22
21
|
from .server import MCPServer
|
|
23
22
|
|
|
24
23
|
|
|
@@ -29,7 +28,7 @@ class ToolFilterContext:
|
|
|
29
28
|
run_context: RunContextWrapper[Any]
|
|
30
29
|
"""The current run context."""
|
|
31
30
|
|
|
32
|
-
agent: "Agent[Any]"
|
|
31
|
+
agent: "AgentBase"
|
|
33
32
|
"""The agent that is requesting the tool list."""
|
|
34
33
|
|
|
35
34
|
server_name: str
|
|
@@ -100,7 +99,7 @@ class MCPUtil:
|
|
|
100
99
|
servers: list["MCPServer"],
|
|
101
100
|
convert_schemas_to_strict: bool,
|
|
102
101
|
run_context: RunContextWrapper[Any],
|
|
103
|
-
agent: "Agent[Any]",
|
|
102
|
+
agent: "AgentBase",
|
|
104
103
|
) -> list[Tool]:
|
|
105
104
|
"""Get all function tools from a list of MCP servers."""
|
|
106
105
|
tools = []
|
|
@@ -126,7 +125,7 @@ class MCPUtil:
|
|
|
126
125
|
server: "MCPServer",
|
|
127
126
|
convert_schemas_to_strict: bool,
|
|
128
127
|
run_context: RunContextWrapper[Any],
|
|
129
|
-
agent: "Agent[Any]",
|
|
128
|
+
agent: "AgentBase",
|
|
130
129
|
) -> list[Tool]:
|
|
131
130
|
"""Get all function tools from a single MCP server."""
|
|
132
131
|
|