openai-agents 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic; see the package's page on the registry for more details.

agents/__init__.py CHANGED
@@ -70,6 +70,7 @@ from .tracing import (
70
70
  GenerationSpanData,
71
71
  GuardrailSpanData,
72
72
  HandoffSpanData,
73
+ MCPListToolsSpanData,
73
74
  Span,
74
75
  SpanData,
75
76
  SpanError,
@@ -89,6 +90,7 @@ from .tracing import (
89
90
  get_current_trace,
90
91
  guardrail_span,
91
92
  handoff_span,
93
+ mcp_tools_span,
92
94
  set_trace_processors,
93
95
  set_tracing_disabled,
94
96
  set_tracing_export_api_key,
@@ -98,6 +100,7 @@ from .tracing import (
98
100
  transcription_span,
99
101
  )
100
102
  from .usage import Usage
103
+ from .version import __version__
101
104
 
102
105
 
103
106
  def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
@@ -220,6 +223,7 @@ __all__ = [
220
223
  "speech_group_span",
221
224
  "transcription_span",
222
225
  "speech_span",
226
+ "mcp_tools_span",
223
227
  "trace",
224
228
  "Trace",
225
229
  "TracingProcessor",
@@ -234,6 +238,7 @@ __all__ = [
234
238
  "HandoffSpanData",
235
239
  "SpeechGroupSpanData",
236
240
  "SpeechSpanData",
241
+ "MCPListToolsSpanData",
237
242
  "TranscriptionSpanData",
238
243
  "set_default_openai_key",
239
244
  "set_default_openai_client",
@@ -243,4 +248,5 @@ __all__ = [
243
248
  "gen_trace_id",
244
249
  "gen_span_id",
245
250
  "default_tool_error_function",
251
+ "__version__",
246
252
  ]
agents/_run_impl.py CHANGED
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import dataclasses
4
5
  import inspect
5
6
  from collections.abc import Awaitable
6
- from dataclasses import dataclass
7
+ from dataclasses import dataclass, field
7
8
  from typing import TYPE_CHECKING, Any, cast
8
9
 
9
10
  from openai.types.responses import (
@@ -47,10 +48,11 @@ from .items import (
47
48
  )
48
49
  from .lifecycle import RunHooks
49
50
  from .logger import logger
51
+ from .model_settings import ModelSettings
50
52
  from .models.interface import ModelTracing
51
53
  from .run_context import RunContextWrapper, TContext
52
54
  from .stream_events import RunItemStreamEvent, StreamEvent
53
- from .tool import ComputerTool, FunctionTool, FunctionToolResult
55
+ from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
54
56
  from .tracing import (
55
57
  SpanError,
56
58
  Trace,
@@ -75,6 +77,23 @@ QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
75
77
  _NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
76
78
 
77
79
 
80
@dataclass
class AgentToolUseTracker:
    """Records, per agent, the names of the tools it has used during a run.

    Agents are matched with ``==``. The tracker copies every tool-name list it is given, so its
    in-place bookkeeping never mutates a list that the caller still owns.
    """

    agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
    """Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""

    def _tools_for(self, agent: Agent[Any]) -> list[str] | None:
        """Return the recorded tool-name list for ``agent``, or None if it has no entry yet."""
        for tracked_agent, tool_names in self.agent_to_tools:
            if tracked_agent == agent:
                return tool_names
        return None

    def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
        """Append ``tool_names`` to the tools recorded for ``agent``, creating an entry if needed.

        Args:
            agent: The agent that used the tools.
            tool_names: Names of the tools used; the list itself is not retained.
        """
        existing = self._tools_for(agent)
        if existing is None:
            # Store a copy: the stored list is extended in place by later calls, and that must
            # not show up as a side effect on the list the caller passed in.
            self.agent_to_tools.append((agent, list(tool_names)))
        else:
            existing.extend(tool_names)

    def has_used_tools(self, agent: Agent[Any]) -> bool:
        """Return True if at least one tool use has been recorded for ``agent``."""
        tools = self._tools_for(agent)
        return tools is not None and len(tools) > 0
95
+
96
+
78
97
  @dataclass
79
98
  class ToolRunHandoff:
80
99
  handoff: Handoff
@@ -99,6 +118,7 @@ class ProcessedResponse:
99
118
  handoffs: list[ToolRunHandoff]
100
119
  functions: list[ToolRunFunction]
101
120
  computer_actions: list[ToolRunComputerAction]
121
+ tools_used: list[str] # Names of all tools used, including hosted tools
102
122
 
103
123
  def has_tools_to_run(self) -> bool:
104
124
  # Handoffs, functions and computer actions need local processing
@@ -296,11 +316,24 @@ class RunImpl:
296
316
  next_step=NextStepRunAgain(),
297
317
  )
298
318
 
319
@classmethod
def maybe_reset_tool_choice(
    cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
) -> ModelSettings:
    """Return ``model_settings`` with ``tool_choice`` cleared once the agent has used a tool.

    The reset only happens when ``agent.reset_tool_choice`` is exactly ``True`` and the tracker
    has recorded at least one tool use for ``agent``. Otherwise the original settings object is
    returned unchanged. When a reset happens, a modified copy is returned and the input is left
    untouched.
    """
    # The tracker is consulted only when the flag is set (short-circuit `and`).
    should_reset = agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent)
    if not should_reset:
        return model_settings
    return dataclasses.replace(model_settings, tool_choice=None)
330
+
299
331
  @classmethod
300
332
  def process_model_response(
301
333
  cls,
302
334
  *,
303
335
  agent: Agent[Any],
336
+ all_tools: list[Tool],
304
337
  response: ModelResponse,
305
338
  output_schema: AgentOutputSchema | None,
306
339
  handoffs: list[Handoff],
@@ -310,22 +343,25 @@ class RunImpl:
310
343
  run_handoffs = []
311
344
  functions = []
312
345
  computer_actions = []
313
-
346
+ tools_used: list[str] = []
314
347
  handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
315
- function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
316
- computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
348
+ function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
349
+ computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
317
350
 
318
351
  for output in response.output:
319
352
  if isinstance(output, ResponseOutputMessage):
320
353
  items.append(MessageOutputItem(raw_item=output, agent=agent))
321
354
  elif isinstance(output, ResponseFileSearchToolCall):
322
355
  items.append(ToolCallItem(raw_item=output, agent=agent))
356
+ tools_used.append("file_search")
323
357
  elif isinstance(output, ResponseFunctionWebSearch):
324
358
  items.append(ToolCallItem(raw_item=output, agent=agent))
359
+ tools_used.append("web_search")
325
360
  elif isinstance(output, ResponseReasoningItem):
326
361
  items.append(ReasoningItem(raw_item=output, agent=agent))
327
362
  elif isinstance(output, ResponseComputerToolCall):
328
363
  items.append(ToolCallItem(raw_item=output, agent=agent))
364
+ tools_used.append("computer_use")
329
365
  if not computer_tool:
330
366
  _error_tracing.attach_error_to_current_span(
331
367
  SpanError(
@@ -347,6 +383,8 @@ class RunImpl:
347
383
  if not isinstance(output, ResponseFunctionToolCall):
348
384
  continue
349
385
 
386
+ tools_used.append(output.name)
387
+
350
388
  # Handoffs
351
389
  if output.name in handoff_map:
352
390
  items.append(HandoffCallItem(raw_item=output, agent=agent))
@@ -378,6 +416,7 @@ class RunImpl:
378
416
  handoffs=run_handoffs,
379
417
  functions=functions,
380
418
  computer_actions=computer_actions,
419
+ tools_used=tools_used,
381
420
  )
382
421
 
383
422
  @classmethod
@@ -490,7 +529,8 @@ class RunImpl:
490
529
  run_config: RunConfig,
491
530
  ) -> SingleStepResult:
492
531
  # If there is more than one handoff, add tool responses that reject those handoffs
493
- if len(run_handoffs) > 1:
532
+ multiple_handoffs = len(run_handoffs) > 1
533
+ if multiple_handoffs:
494
534
  output_message = "Multiple handoffs detected, ignoring this one."
495
535
  new_step_items.extend(
496
536
  [
@@ -512,6 +552,16 @@ class RunImpl:
512
552
  context_wrapper, actual_handoff.tool_call.arguments
513
553
  )
514
554
  span_handoff.span_data.to_agent = new_agent.name
555
+ if multiple_handoffs:
556
+ requested_agents = [handoff.handoff.agent_name for handoff in run_handoffs]
557
+ span_handoff.set_error(
558
+ SpanError(
559
+ message="Multiple handoffs requested",
560
+ data={
561
+ "requested_agents": requested_agents,
562
+ },
563
+ )
564
+ )
515
565
 
516
566
  # Append a tool output item for the handoff
517
567
  new_step_items.append(
agents/agent.py CHANGED
@@ -6,12 +6,13 @@ from collections.abc import Awaitable
6
6
  from dataclasses import dataclass, field
7
7
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
8
8
 
9
- from typing_extensions import TypeAlias, TypedDict
9
+ from typing_extensions import NotRequired, TypeAlias, TypedDict
10
10
 
11
11
  from .guardrail import InputGuardrail, OutputGuardrail
12
12
  from .handoffs import Handoff
13
13
  from .items import ItemHelpers
14
14
  from .logger import logger
15
+ from .mcp import MCPUtil
15
16
  from .model_settings import ModelSettings
16
17
  from .models.interface import Model
17
18
  from .run_context import RunContextWrapper, TContext
@@ -21,6 +22,7 @@ from .util._types import MaybeAwaitable
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from .lifecycle import AgentHooks
25
+ from .mcp import MCPServer
24
26
  from .result import RunResult
25
27
 
26
28
 
@@ -42,7 +44,7 @@ ToolsToFinalOutputFunction: TypeAlias = Callable[
42
44
  MaybeAwaitable[ToolsToFinalOutputResult],
43
45
  ]
44
46
  """A function that takes a run context and a list of tool results, and returns a
45
- `ToolToFinalOutputResult`.
47
+ `ToolsToFinalOutputResult`.
46
48
  """
47
49
 
48
50
 
@@ -51,6 +53,15 @@ class StopAtTools(TypedDict):
51
53
  """A list of tool names, any of which will stop the agent from running further."""
52
54
 
53
55
 
56
+ class MCPConfig(TypedDict):
57
+ """Configuration for MCP servers."""
58
+
59
+ convert_schemas_to_strict: NotRequired[bool]
60
+ """If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
61
+ best-effort conversion, so some schemas may not be convertible. Defaults to False.
62
+ """
63
+
64
+
54
65
  @dataclass
55
66
  class Agent(Generic[TContext]):
56
67
  """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@@ -107,6 +118,19 @@ class Agent(Generic[TContext]):
107
118
  tools: list[Tool] = field(default_factory=list)
108
119
  """A list of tools that the agent can use."""
109
120
 
121
+ mcp_servers: list[MCPServer] = field(default_factory=list)
122
+ """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
123
+ the agent can use. Every time the agent runs, it will include tools from these servers in the
124
+ list of available tools.
125
+
126
+ NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
127
+ `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
128
+ longer needed.
129
+ """
130
+
131
+ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
132
+ """Configuration for MCP servers."""
133
+
110
134
  input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
111
135
  """A list of checks that run in parallel to the agent's execution, before generating a
112
136
  response. Runs only if the agent is the first agent in the chain.
@@ -143,6 +167,10 @@ class Agent(Generic[TContext]):
143
167
  web search, etc are always processed by the LLM.
144
168
  """
145
169
 
170
+ reset_tool_choice: bool = True
171
+ """Whether to reset the tool choice to the default value after a tool has been called. Defaults
172
+ to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
173
+
146
174
  def clone(self, **kwargs: Any) -> Agent[TContext]:
147
175
  """Make a copy of the agent, with the given arguments changed. For example, you could do:
148
176
  ```
@@ -205,3 +233,13 @@ class Agent(Generic[TContext]):
205
233
  logger.error(f"Instructions must be a string or a function, got {self.instructions}")
206
234
 
207
235
  return None
236
+
237
async def get_mcp_tools(self) -> list[Tool]:
    """Collect the function tools exposed by every server in ``self.mcp_servers``.

    Whether their schemas are converted to strict mode is read from
    ``self.mcp_config["convert_schemas_to_strict"]``; a missing key means False.
    """
    strict = self.mcp_config.get("convert_schemas_to_strict", False)
    servers = self.mcp_servers
    return await MCPUtil.get_all_function_tools(servers, strict)
241
+
242
async def get_all_tools(self) -> list[Tool]:
    """Return a new list holding the MCP-server tools followed by the agent's own ``tools``."""
    # MCP tools come first; neither source list is modified.
    return [*(await self.get_mcp_tools()), *self.tools]
@@ -0,0 +1,137 @@
1
+ from typing import Optional
2
+
3
+ import graphviz # type: ignore
4
+
5
+ from agents import Agent
6
+ from agents.handoffs import Handoff
7
+ from agents.tool import Tool
8
+
9
+
10
def get_main_graph(agent: Agent) -> str:
    """
    Builds the complete DOT source for ``agent``. The output is the ``digraph`` header, then
    every node from ``get_all_nodes``, then every edge from ``get_all_edges``, then the closing
    brace.

    Args:
        agent (Agent): The root agent of the graph.

    Returns:
        str: The complete DOT document as a single string.
    """
    header = """
digraph G {
graph [splines=true];
node [fontname="Arial"];
edge [penwidth=1.5];
"""
    # Nodes are generated before edges, matching their order in the output.
    nodes = get_all_nodes(agent)
    edges = get_all_edges(agent)
    return header + nodes + edges + "}"
32
+
33
+
34
def get_all_nodes(
    agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[str]] = None
) -> str:
    """
    Recursively generates the nodes for the given agent and its handoffs in DOT format.

    The ``__start__``/``__end__`` terminals and the root agent's node are emitted only for the
    root call (``parent`` is None). A handoff agent's node is declared by the agent that hands
    off to it. Each agent (keyed by name, as DOT nodes are) is expanded at most once, so cyclic
    handoffs such as A -> B -> A terminate instead of recursing forever.

    Args:
        agent (Agent): The agent for which the nodes are to be generated.
        parent (Agent, optional): The agent handing off to ``agent``; None for the root call.
        visited (set[str], optional): Names of agents already expanded. It is shared across the
            recursion; callers normally leave it as None.

    Returns:
        str: The DOT format string representing the nodes.
    """
    if visited is None:
        visited = set()
    if agent.name in visited:
        return ""
    visited.add(agent.name)

    parts = []

    if parent is None:
        # Start and end the graph (once, at the root).
        parts.append(
            '"__start__" [label="__start__", shape=ellipse, style=filled, '
            "fillcolor=lightblue, width=0.5, height=0.3];"
            '"__end__" [label="__end__", shape=ellipse, style=filled, '
            "fillcolor=lightblue, width=0.5, height=0.3];"
        )
        # Ensure the root agent node is colored.
        parts.append(
            f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
            "fillcolor=lightyellow, width=1.5, height=0.8];"
        )

    for tool in agent.tools:
        parts.append(
            f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
            f"fillcolor=lightgreen, width=0.5, height=0.3];"
        )

    for handoff in agent.handoffs:
        if isinstance(handoff, Handoff):
            parts.append(
                f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
                f"shape=box, style=filled, style=rounded, "
                f"fillcolor=lightyellow, width=1.5, height=0.8];"
            )
        if isinstance(handoff, Agent):
            # Skip re-declaring an agent already emitted (e.g. the root, reached via a cycle).
            # A later declaration of the same DOT node would override its attributes.
            if handoff.name not in visited:
                parts.append(
                    f'"{handoff.name}" [label="{handoff.name}", '
                    f"shape=box, style=filled, style=rounded, "
                    f"fillcolor=lightyellow, width=1.5, height=0.8];"
                )
            # Pass the parent so the child is not treated as a new root.
            parts.append(get_all_nodes(handoff, agent, visited))

    return "".join(parts)
82
+
83
+
84
def get_all_edges(
    agent: Agent, parent: Optional[Agent] = None, visited: Optional[set[str]] = None
) -> str:
    """
    Recursively generates the edges for the given agent and its handoffs in DOT format.

    Each agent's outgoing edges are emitted once. The edge into an already-expanded agent is
    still drawn, but that agent is not expanded again, so cyclic handoffs terminate.

    Args:
        agent (Agent): The agent for which the edges are to be generated.
        parent (Agent, optional): The parent agent. Defaults to None (the root call).
        visited (set[str], optional): Names of agents already expanded. It is shared across the
            recursion; callers normally leave it as None.

    Returns:
        str: The DOT format string representing the edges.
    """
    if visited is None:
        visited = set()
    if agent.name in visited:
        return ""
    visited.add(agent.name)

    parts = []

    if parent is None:
        parts.append(f'"__start__" -> "{agent.name}";')

    for tool in agent.tools:
        parts.append(
            f'\n"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];'
            f'\n"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];'
        )

    for handoff in agent.handoffs:
        if isinstance(handoff, Handoff):
            parts.append(f'\n"{agent.name}" -> "{handoff.agent_name}";')
        if isinstance(handoff, Agent):
            parts.append(f'\n"{agent.name}" -> "{handoff.name}";')
            parts.append(get_all_edges(handoff, agent, visited))

    # An agent without handoffs ends the run. The former `isinstance(agent, Tool)` guard was
    # dropped because `agent` is always an Agent here.
    # NOTE(review): `Tool` appears to be a typing Union alias (hence the original's
    # `# type: ignore`). isinstance() against such an alias raises TypeError on Python 3.9.
    if not agent.handoffs:
        parts.append(f'"{agent.name}" -> "__end__";')

    return "".join(parts)
118
+
119
+
120
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
    """
    Draws the graph for the given agent and optionally saves it as a PNG file.

    Args:
        agent (Agent): The agent for which the graph is to be drawn.
        filename (str, optional): Base output path without extension. When given, graphviz
            renders ``<filename>.png``. The intermediate DOT source file it writes at
            ``filename`` is deleted afterwards, so only the PNG remains.

    Returns:
        graphviz.Source: The graphviz Source object representing the graph.
    """
    graph = graphviz.Source(get_main_graph(agent))

    if filename:
        # Without cleanup=True, render() leaves the DOT source file next to the image.
        graph.render(filename, format="png", cleanup=True)

    return graph
agents/mcp/__init__.py ADDED
@@ -0,0 +1,21 @@
1
# Public names of this package. It is built up as imports succeed, so that
# `from agents.mcp import *` never references a name that failed to import.
__all__: list[str] = []

try:
    from .server import (
        MCPServer,
        MCPServerSse,
        MCPServerSseParams,
        MCPServerStdio,
        MCPServerStdioParams,
    )
except ImportError:
    # Best-effort: the server classes are optional. Presumably `.server` fails to import when an
    # optional dependency is unavailable — TODO confirm. MCPUtil is still exported below.
    pass
else:
    __all__ += [
        "MCPServer",
        "MCPServerSse",
        "MCPServerSseParams",
        "MCPServerStdio",
        "MCPServerStdioParams",
    ]

from .util import MCPUtil

__all__ += ["MCPUtil"]