openai-agents 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

agents/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Literal
5
5
  from openai import AsyncOpenAI
6
6
 
7
7
  from . import _config
8
- from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
8
+ from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
9
9
  from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
10
10
  from .computer import AsyncComputer, Button, Computer, Environment
11
11
  from .exceptions import (
@@ -40,6 +40,7 @@ from .items import (
40
40
  TResponseInputItem,
41
41
  )
42
42
  from .lifecycle import AgentHooks, RunHooks
43
+ from .memory import Session, SQLiteSession
43
44
  from .model_settings import ModelSettings
44
45
  from .models.interface import Model, ModelProvider, ModelTracing
45
46
  from .models.openai_chatcompletions import OpenAIChatCompletionsModel
@@ -160,6 +161,7 @@ def enable_verbose_stdout_logging():
160
161
 
161
162
  __all__ = [
162
163
  "Agent",
164
+ "AgentBase",
163
165
  "ToolsToFinalOutputFunction",
164
166
  "ToolsToFinalOutputResult",
165
167
  "Runner",
@@ -209,6 +211,8 @@ __all__ = [
209
211
  "ItemHelpers",
210
212
  "RunHooks",
211
213
  "AgentHooks",
214
+ "Session",
215
+ "SQLiteSession",
212
216
  "RunContextWrapper",
213
217
  "TContext",
214
218
  "RunErrorDetails",
agents/_run_impl.py CHANGED
@@ -548,7 +548,11 @@ class RunImpl:
548
548
  func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
549
549
  ) -> Any:
550
550
  with function_span(func_tool.name) as span_fn:
551
- tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
551
+ tool_context = ToolContext.from_agent_context(
552
+ context_wrapper,
553
+ tool_call.call_id,
554
+ tool_call=tool_call,
555
+ )
552
556
  if config.trace_include_sensitive_data:
553
557
  span_fn.span_data.input = tool_call.arguments
554
558
  try:
agents/agent.py CHANGED
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
67
67
 
68
68
 
69
69
  @dataclass
70
- class Agent(Generic[TContext]):
70
+ class AgentBase(Generic[TContext]):
71
+ """Base class for `Agent` and `RealtimeAgent`."""
72
+
73
+ name: str
74
+ """The name of the agent."""
75
+
76
+ handoff_description: str | None = None
77
+ """A description of the agent. This is used when the agent is used as a handoff, so that an
78
+ LLM knows what it does and when to invoke it.
79
+ """
80
+
81
+ tools: list[Tool] = field(default_factory=list)
82
+ """A list of tools that the agent can use."""
83
+
84
+ mcp_servers: list[MCPServer] = field(default_factory=list)
85
+ """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
86
+ the agent can use. Every time the agent runs, it will include tools from these servers in the
87
+ list of available tools.
88
+
89
+ NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
90
+ `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
91
+ longer needed.
92
+ """
93
+
94
+ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
95
+ """Configuration for MCP servers."""
96
+
97
+ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
98
+ """Fetches the available tools from the MCP servers."""
99
+ convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
100
+ return await MCPUtil.get_all_function_tools(
101
+ self.mcp_servers, convert_schemas_to_strict, run_context, self
102
+ )
103
+
104
+ async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
105
+ """All agent tools, including MCP tools and function tools."""
106
+ mcp_tools = await self.get_mcp_tools(run_context)
107
+
108
+ async def _check_tool_enabled(tool: Tool) -> bool:
109
+ if not isinstance(tool, FunctionTool):
110
+ return True
111
+
112
+ attr = tool.is_enabled
113
+ if isinstance(attr, bool):
114
+ return attr
115
+ res = attr(run_context, self)
116
+ if inspect.isawaitable(res):
117
+ return bool(await res)
118
+ return bool(res)
119
+
120
+ results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
121
+ enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
122
+ return [*mcp_tools, *enabled]
123
+
124
+
125
+ @dataclass
126
+ class Agent(AgentBase, Generic[TContext]):
71
127
  """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
72
128
 
73
129
  We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
76
132
 
77
133
  Agents are generic on the context type. The context is a (mutable) object you create. It is
78
134
  passed to tool functions, handoffs, guardrails, etc.
79
- """
80
135
 
81
- name: str
82
- """The name of the agent."""
136
+ See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
137
+ """
83
138
 
84
139
  instructions: (
85
140
  str
@@ -103,11 +158,6 @@ class Agent(Generic[TContext]):
103
158
  usable with OpenAI models, using the Responses API.
104
159
  """
105
160
 
106
- handoff_description: str | None = None
107
- """A description of the agent. This is used when the agent is used as a handoff, so that an
108
- LLM knows what it does and when to invoke it.
109
- """
110
-
111
161
  handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
112
162
  """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
113
163
  and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
@@ -125,22 +175,6 @@ class Agent(Generic[TContext]):
125
175
  """Configures model-specific tuning parameters (e.g. temperature, top_p).
126
176
  """
127
177
 
128
- tools: list[Tool] = field(default_factory=list)
129
- """A list of tools that the agent can use."""
130
-
131
- mcp_servers: list[MCPServer] = field(default_factory=list)
132
- """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
133
- the agent can use. Every time the agent runs, it will include tools from these servers in the
134
- list of available tools.
135
-
136
- NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
137
- `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
138
- longer needed.
139
- """
140
-
141
- mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
142
- """Configuration for MCP servers."""
143
-
144
178
  input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
145
179
  """A list of checks that run in parallel to the agent's execution, before generating a
146
180
  response. Runs only if the agent is the first agent in the chain.
@@ -176,7 +210,7 @@ class Agent(Generic[TContext]):
176
210
  The final output will be the output of the first matching tool call. The LLM does not
177
211
  process the result of the tool call.
178
212
  - A function: If you pass a function, it will be called with the run context and the list of
179
- tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool
213
+ tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
180
214
  calls result in a final output.
181
215
 
182
216
  NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
@@ -256,9 +290,7 @@ class Agent(Generic[TContext]):
256
290
  """Get the prompt for the agent."""
257
291
  return await PromptUtil.to_model_input(self.prompt, run_context, self)
258
292
 
259
- async def get_mcp_tools(
260
- self, run_context: RunContextWrapper[TContext]
261
- ) -> list[Tool]:
293
+ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
262
294
  """Fetches the available tools from the MCP servers."""
263
295
  convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
264
296
  return await MCPUtil.get_all_function_tools(
agents/function_schema.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
9
9
 
10
10
  from griffe import Docstring, DocstringSectionKind
11
11
  from pydantic import BaseModel, Field, create_model
12
+ from pydantic.fields import FieldInfo
12
13
 
13
14
  from .exceptions import UserError
14
15
  from .run_context import RunContextWrapper
@@ -319,6 +320,14 @@ def function_schema(
319
320
  ann,
320
321
  Field(..., description=field_description),
321
322
  )
323
+ elif isinstance(default, FieldInfo):
324
+ # Parameter with a default value that is a Field(...)
325
+ fields[name] = (
326
+ ann,
327
+ FieldInfo.merge_field_infos(
328
+ default, description=field_description or default.description
329
+ ),
330
+ )
322
331
  else:
323
332
  # Parameter with a default value
324
333
  fields[name] = (
@@ -337,7 +346,8 @@ def function_schema(
337
346
  # 5. Return as a FuncSchema dataclass
338
347
  return FuncSchema(
339
348
  name=func_name,
340
- description=description_override or doc_info.description if doc_info else None,
349
+ # Ensure description_override takes precedence even if docstring info is disabled.
350
+ description=description_override or (doc_info.description if doc_info else None),
341
351
  params_pydantic_model=dynamic_model,
342
352
  params_json_schema=json_schema,
343
353
  signature=sig,
agents/guardrail.py CHANGED
@@ -241,7 +241,11 @@ def input_guardrail(
241
241
  def decorator(
242
242
  f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
243
243
  ) -> InputGuardrail[TContext_co]:
244
- return InputGuardrail(guardrail_function=f, name=name)
244
+ return InputGuardrail(
245
+ guardrail_function=f,
246
+ # If not set, guardrail name uses the function’s name by default.
247
+ name=name if name else f.__name__
248
+ )
245
249
 
246
250
  if func is not None:
247
251
  # Decorator was used without parentheses
agents/lifecycle.py CHANGED
@@ -1,25 +1,27 @@
1
1
  from typing import Any, Generic
2
2
 
3
- from .agent import Agent
3
+ from typing_extensions import TypeVar
4
+
5
+ from .agent import Agent, AgentBase
4
6
  from .run_context import RunContextWrapper, TContext
5
7
  from .tool import Tool
6
8
 
9
+ TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
10
+
7
11
 
8
- class RunHooks(Generic[TContext]):
12
+ class RunHooksBase(Generic[TContext, TAgent]):
9
13
  """A class that receives callbacks on various lifecycle events in an agent run. Subclass and
10
14
  override the methods you need.
11
15
  """
12
16
 
13
- async def on_agent_start(
14
- self, context: RunContextWrapper[TContext], agent: Agent[TContext]
15
- ) -> None:
17
+ async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
16
18
  """Called before the agent is invoked. Called each time the current agent changes."""
17
19
  pass
18
20
 
19
21
  async def on_agent_end(
20
22
  self,
21
23
  context: RunContextWrapper[TContext],
22
- agent: Agent[TContext],
24
+ agent: TAgent,
23
25
  output: Any,
24
26
  ) -> None:
25
27
  """Called when the agent produces a final output."""
@@ -28,8 +30,8 @@ class RunHooks(Generic[TContext]):
28
30
  async def on_handoff(
29
31
  self,
30
32
  context: RunContextWrapper[TContext],
31
- from_agent: Agent[TContext],
32
- to_agent: Agent[TContext],
33
+ from_agent: TAgent,
34
+ to_agent: TAgent,
33
35
  ) -> None:
34
36
  """Called when a handoff occurs."""
35
37
  pass
@@ -37,7 +39,7 @@ class RunHooks(Generic[TContext]):
37
39
  async def on_tool_start(
38
40
  self,
39
41
  context: RunContextWrapper[TContext],
40
- agent: Agent[TContext],
42
+ agent: TAgent,
41
43
  tool: Tool,
42
44
  ) -> None:
43
45
  """Called before a tool is invoked."""
@@ -46,7 +48,7 @@ class RunHooks(Generic[TContext]):
46
48
  async def on_tool_end(
47
49
  self,
48
50
  context: RunContextWrapper[TContext],
49
- agent: Agent[TContext],
51
+ agent: TAgent,
50
52
  tool: Tool,
51
53
  result: str,
52
54
  ) -> None:
@@ -54,14 +56,14 @@ class RunHooks(Generic[TContext]):
54
56
  pass
55
57
 
56
58
 
57
- class AgentHooks(Generic[TContext]):
59
+ class AgentHooksBase(Generic[TContext, TAgent]):
58
60
  """A class that receives callbacks on various lifecycle events for a specific agent. You can
59
61
  set this on `agent.hooks` to receive events for that specific agent.
60
62
 
61
63
  Subclass and override the methods you need.
62
64
  """
63
65
 
64
- async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
66
+ async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
65
67
  """Called before the agent is invoked. Called each time the running agent is changed to this
66
68
  agent."""
67
69
  pass
@@ -69,7 +71,7 @@ class AgentHooks(Generic[TContext]):
69
71
  async def on_end(
70
72
  self,
71
73
  context: RunContextWrapper[TContext],
72
- agent: Agent[TContext],
74
+ agent: TAgent,
73
75
  output: Any,
74
76
  ) -> None:
75
77
  """Called when the agent produces a final output."""
@@ -78,8 +80,8 @@ class AgentHooks(Generic[TContext]):
78
80
  async def on_handoff(
79
81
  self,
80
82
  context: RunContextWrapper[TContext],
81
- agent: Agent[TContext],
82
- source: Agent[TContext],
83
+ agent: TAgent,
84
+ source: TAgent,
83
85
  ) -> None:
84
86
  """Called when the agent is being handed off to. The `source` is the agent that is handing
85
87
  off to this agent."""
@@ -88,7 +90,7 @@ class AgentHooks(Generic[TContext]):
88
90
  async def on_tool_start(
89
91
  self,
90
92
  context: RunContextWrapper[TContext],
91
- agent: Agent[TContext],
93
+ agent: TAgent,
92
94
  tool: Tool,
93
95
  ) -> None:
94
96
  """Called before a tool is invoked."""
@@ -97,9 +99,16 @@ class AgentHooks(Generic[TContext]):
97
99
  async def on_tool_end(
98
100
  self,
99
101
  context: RunContextWrapper[TContext],
100
- agent: Agent[TContext],
102
+ agent: TAgent,
101
103
  tool: Tool,
102
104
  result: str,
103
105
  ) -> None:
104
106
  """Called after a tool is invoked."""
105
107
  pass
108
+
109
+
110
+ RunHooks = RunHooksBase[TContext, Agent]
111
+ """Run hooks when using `Agent`."""
112
+
113
+ AgentHooks = AgentHooksBase[TContext, Agent]
114
+ """Agent hooks for `Agent`s."""
agents/mcp/server.py CHANGED
@@ -13,7 +13,7 @@ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_cli
13
13
  from mcp.client.sse import sse_client
14
14
  from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
15
15
  from mcp.shared.message import SessionMessage
16
- from mcp.types import CallToolResult, InitializeResult
16
+ from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
17
17
  from typing_extensions import NotRequired, TypedDict
18
18
 
19
19
  from ..exceptions import UserError
@@ -22,7 +22,7 @@ from ..run_context import RunContextWrapper
22
22
  from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
23
23
 
24
24
  if TYPE_CHECKING:
25
- from ..agent import Agent
25
+ from ..agent import AgentBase
26
26
 
27
27
 
28
28
  class MCPServer(abc.ABC):
@@ -52,8 +52,8 @@ class MCPServer(abc.ABC):
52
52
  @abc.abstractmethod
53
53
  async def list_tools(
54
54
  self,
55
- run_context: RunContextWrapper[Any],
56
- agent: Agent[Any],
55
+ run_context: RunContextWrapper[Any] | None = None,
56
+ agent: AgentBase | None = None,
57
57
  ) -> list[MCPTool]:
58
58
  """List the tools available on the server."""
59
59
  pass
@@ -63,6 +63,20 @@ class MCPServer(abc.ABC):
63
63
  """Invoke a tool on the server."""
64
64
  pass
65
65
 
66
+ @abc.abstractmethod
67
+ async def list_prompts(
68
+ self,
69
+ ) -> ListPromptsResult:
70
+ """List the prompts available on the server."""
71
+ pass
72
+
73
+ @abc.abstractmethod
74
+ async def get_prompt(
75
+ self, name: str, arguments: dict[str, Any] | None = None
76
+ ) -> GetPromptResult:
77
+ """Get a specific prompt from the server."""
78
+ pass
79
+
66
80
 
67
81
  class _MCPServerWithClientSession(MCPServer, abc.ABC):
68
82
  """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
@@ -103,7 +117,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
103
117
  self,
104
118
  tools: list[MCPTool],
105
119
  run_context: RunContextWrapper[Any],
106
- agent: Agent[Any],
120
+ agent: AgentBase,
107
121
  ) -> list[MCPTool]:
108
122
  """Apply the tool filter to the list of tools."""
109
123
  if self.tool_filter is None:
@@ -118,9 +132,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
118
132
  return await self._apply_dynamic_tool_filter(tools, run_context, agent)
119
133
 
120
134
  def _apply_static_tool_filter(
121
- self,
122
- tools: list[MCPTool],
123
- static_filter: ToolFilterStatic
135
+ self, tools: list[MCPTool], static_filter: ToolFilterStatic
124
136
  ) -> list[MCPTool]:
125
137
  """Apply static tool filtering based on allowlist and blocklist."""
126
138
  filtered_tools = tools
@@ -141,7 +153,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
141
153
  self,
142
154
  tools: list[MCPTool],
143
155
  run_context: RunContextWrapper[Any],
144
- agent: Agent[Any],
156
+ agent: AgentBase,
145
157
  ) -> list[MCPTool]:
146
158
  """Apply dynamic tool filtering using a callable filter function."""
147
159
 
@@ -231,8 +243,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
231
243
 
232
244
  async def list_tools(
233
245
  self,
234
- run_context: RunContextWrapper[Any],
235
- agent: Agent[Any],
246
+ run_context: RunContextWrapper[Any] | None = None,
247
+ agent: AgentBase | None = None,
236
248
  ) -> list[MCPTool]:
237
249
  """List the tools available on the server."""
238
250
  if not self.session:
@@ -251,6 +263,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
251
263
  # Filter tools based on tool_filter
252
264
  filtered_tools = tools
253
265
  if self.tool_filter is not None:
266
+ if run_context is None or agent is None:
267
+ raise UserError("run_context and agent are required for dynamic tool filtering")
254
268
  filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
255
269
  return filtered_tools
256
270
 
@@ -261,6 +275,24 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
261
275
 
262
276
  return await self.session.call_tool(tool_name, arguments)
263
277
 
278
+ async def list_prompts(
279
+ self,
280
+ ) -> ListPromptsResult:
281
+ """List the prompts available on the server."""
282
+ if not self.session:
283
+ raise UserError("Server not initialized. Make sure you call `connect()` first.")
284
+
285
+ return await self.session.list_prompts()
286
+
287
+ async def get_prompt(
288
+ self, name: str, arguments: dict[str, Any] | None = None
289
+ ) -> GetPromptResult:
290
+ """Get a specific prompt from the server."""
291
+ if not self.session:
292
+ raise UserError("Server not initialized. Make sure you call `connect()` first.")
293
+
294
+ return await self.session.get_prompt(name, arguments)
295
+
264
296
  async def cleanup(self):
265
297
  """Cleanup the server."""
266
298
  async with self._cleanup_lock:
agents/mcp/util.py CHANGED
@@ -5,12 +5,11 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5
5
 
6
6
  from typing_extensions import NotRequired, TypedDict
7
7
 
8
- from agents.strict_schema import ensure_strict_json_schema
9
-
10
8
  from .. import _debug
11
9
  from ..exceptions import AgentsException, ModelBehaviorError, UserError
12
10
  from ..logger import logger
13
11
  from ..run_context import RunContextWrapper
12
+ from ..strict_schema import ensure_strict_json_schema
14
13
  from ..tool import FunctionTool, Tool
15
14
  from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
16
15
  from ..util._types import MaybeAwaitable
@@ -18,7 +17,7 @@ from ..util._types import MaybeAwaitable
18
17
  if TYPE_CHECKING:
19
18
  from mcp.types import Tool as MCPTool
20
19
 
21
- from ..agent import Agent
20
+ from ..agent import AgentBase
22
21
  from .server import MCPServer
23
22
 
24
23
 
@@ -29,7 +28,7 @@ class ToolFilterContext:
29
28
  run_context: RunContextWrapper[Any]
30
29
  """The current run context."""
31
30
 
32
- agent: "Agent[Any]"
31
+ agent: "AgentBase"
33
32
  """The agent that is requesting the tool list."""
34
33
 
35
34
  server_name: str
@@ -100,7 +99,7 @@ class MCPUtil:
100
99
  servers: list["MCPServer"],
101
100
  convert_schemas_to_strict: bool,
102
101
  run_context: RunContextWrapper[Any],
103
- agent: "Agent[Any]",
102
+ agent: "AgentBase",
104
103
  ) -> list[Tool]:
105
104
  """Get all function tools from a list of MCP servers."""
106
105
  tools = []
@@ -126,7 +125,7 @@ class MCPUtil:
126
125
  server: "MCPServer",
127
126
  convert_schemas_to_strict: bool,
128
127
  run_context: RunContextWrapper[Any],
129
- agent: "Agent[Any]",
128
+ agent: "AgentBase",
130
129
  ) -> list[Tool]:
131
130
  """Get all function tools from a single MCP server."""
132
131
 
@@ -0,0 +1,3 @@
1
+ from .session import Session, SQLiteSession
2
+
3
+ __all__ = ["Session", "SQLiteSession"]