openai-agents 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

agents/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .exceptions import (
     MaxTurnsExceeded,
     ModelBehaviorError,
     OutputGuardrailTripwireTriggered,
+    RunErrorDetails,
     UserError,
 )
 from .guardrail import (
@@ -54,10 +55,19 @@ from .stream_events import (
     StreamEvent,
 )
 from .tool import (
+    CodeInterpreterTool,
     ComputerTool,
     FileSearchTool,
     FunctionTool,
     FunctionToolResult,
+    HostedMCPTool,
+    ImageGenerationTool,
+    LocalShellCommandRequest,
+    LocalShellExecutor,
+    LocalShellTool,
+    MCPToolApprovalFunction,
+    MCPToolApprovalFunctionResult,
+    MCPToolApprovalRequest,
     Tool,
     WebSearchTool,
     default_tool_error_function,
@@ -195,6 +205,7 @@ __all__ = [
     "AgentHooks",
     "RunContextWrapper",
     "TContext",
+    "RunErrorDetails",
     "RunResult",
     "RunResultStreaming",
     "RunConfig",
@@ -206,8 +217,17 @@ __all__ = [
     "FunctionToolResult",
     "ComputerTool",
     "FileSearchTool",
+    "CodeInterpreterTool",
+    "ImageGenerationTool",
+    "LocalShellCommandRequest",
+    "LocalShellExecutor",
+    "LocalShellTool",
     "Tool",
     "WebSearchTool",
+    "HostedMCPTool",
+    "MCPToolApprovalFunction",
+    "MCPToolApprovalRequest",
+    "MCPToolApprovalFunctionResult",
     "function_tool",
     "Usage",
     "add_trace_processor",
agents/_run_impl.py CHANGED
@@ -14,6 +14,9 @@ from openai.types.responses import (
     ResponseFunctionWebSearch,
     ResponseOutputMessage,
 )
+from openai.types.responses.response_code_interpreter_tool_call import (
+    ResponseCodeInterpreterToolCall,
+)
 from openai.types.responses.response_computer_tool_call import (
     ActionClick,
     ActionDoubleClick,
@@ -25,7 +28,14 @@ from openai.types.responses.response_computer_tool_call import (
     ActionType,
     ActionWait,
 )
-from openai.types.responses.response_input_param import ComputerCallOutput
+from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
+from openai.types.responses.response_output_item import (
+    ImageGenerationCall,
+    LocalShellCall,
+    McpApprovalRequest,
+    McpCall,
+    McpListTools,
+)
 from openai.types.responses.response_reasoning_item import ResponseReasoningItem

 from .agent import Agent, ToolsToFinalOutputResult
@@ -38,6 +48,9 @@ from .items import (
     HandoffCallItem,
     HandoffOutputItem,
     ItemHelpers,
+    MCPApprovalRequestItem,
+    MCPApprovalResponseItem,
+    MCPListToolsItem,
     MessageOutputItem,
     ModelResponse,
     ReasoningItem,
@@ -52,7 +65,16 @@ from .model_settings import ModelSettings
 from .models.interface import ModelTracing
 from .run_context import RunContextWrapper, TContext
 from .stream_events import RunItemStreamEvent, StreamEvent
-from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
+from .tool import (
+    ComputerTool,
+    FunctionTool,
+    FunctionToolResult,
+    HostedMCPTool,
+    LocalShellCommandRequest,
+    LocalShellTool,
+    MCPToolApprovalRequest,
+    Tool,
+)
 from .tracing import (
     SpanError,
     Trace,
@@ -112,15 +134,29 @@ class ToolRunComputerAction:
     computer_tool: ComputerTool


+@dataclass
+class ToolRunMCPApprovalRequest:
+    request_item: McpApprovalRequest
+    mcp_tool: HostedMCPTool
+
+
+@dataclass
+class ToolRunLocalShellCall:
+    tool_call: LocalShellCall
+    local_shell_tool: LocalShellTool
+
+
 @dataclass
 class ProcessedResponse:
     new_items: list[RunItem]
     handoffs: list[ToolRunHandoff]
     functions: list[ToolRunFunction]
     computer_actions: list[ToolRunComputerAction]
+    local_shell_calls: list[ToolRunLocalShellCall]
     tools_used: list[str]  # Names of all tools used, including hosted tools
+    mcp_approval_requests: list[ToolRunMCPApprovalRequest]  # Only requests with callbacks

-    def has_tools_to_run(self) -> bool:
+    def has_tools_or_approvals_to_run(self) -> bool:
         # Handoffs, functions and computer actions need local processing
         # Hosted tools have already run, so there's nothing to do.
         return any(
@@ -128,6 +164,8 @@ class ProcessedResponse:
                 self.handoffs,
                 self.functions,
                 self.computer_actions,
+                self.local_shell_calls,
+                self.mcp_approval_requests,
             ]
         )

@@ -226,7 +264,16 @@ class RunImpl:
         new_step_items.extend([result.run_item for result in function_results])
         new_step_items.extend(computer_results)

-        # Second, check if there are any handoffs
+        # Next, run the MCP approval requests
+        if processed_response.mcp_approval_requests:
+            approval_results = await cls.execute_mcp_approval_requests(
+                agent=agent,
+                approval_requests=processed_response.mcp_approval_requests,
+                context_wrapper=context_wrapper,
+            )
+            new_step_items.extend(approval_results)
+
+        # Next, check if there are any handoffs
         if run_handoffs := processed_response.handoffs:
             return await cls.execute_handoffs(
                 agent=agent,
@@ -240,7 +287,7 @@
                 run_config=run_config,
             )

-        # Third, we'll check if the tool use should result in a final output
+        # Next, we'll check if the tool use should result in a final output
         check_tool_use = await cls._check_for_final_output_from_tools(
             agent=agent,
             tool_results=function_results,
@@ -295,7 +342,7 @@
             )
         elif (
             not output_schema or output_schema.is_plain_text()
-        ) and not processed_response.has_tools_to_run():
+        ) and not processed_response.has_tools_or_approvals_to_run():
             return await cls.execute_final_output(
                 agent=agent,
                 original_input=original_input,
@@ -343,10 +390,20 @@
         run_handoffs = []
         functions = []
         computer_actions = []
+        local_shell_calls = []
+        mcp_approval_requests = []
         tools_used: list[str] = []
         handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
         function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
         computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
+        local_shell_tool = next(
+            (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
+        )
+        hosted_mcp_server_map = {
+            tool.tool_config["server_label"]: tool
+            for tool in all_tools
+            if isinstance(tool, HostedMCPTool)
+        }

         for output in response.output:
             if isinstance(output, ResponseOutputMessage):
@@ -375,6 +432,57 @@
                 computer_actions.append(
                     ToolRunComputerAction(tool_call=output, computer_tool=computer_tool)
                 )
+            elif isinstance(output, McpApprovalRequest):
+                items.append(MCPApprovalRequestItem(raw_item=output, agent=agent))
+                if output.server_label not in hosted_mcp_server_map:
+                    _error_tracing.attach_error_to_current_span(
+                        SpanError(
+                            message="MCP server label not found",
+                            data={"server_label": output.server_label},
+                        )
+                    )
+                    raise ModelBehaviorError(f"MCP server label {output.server_label} not found")
+                else:
+                    server = hosted_mcp_server_map[output.server_label]
+                    if server.on_approval_request:
+                        mcp_approval_requests.append(
+                            ToolRunMCPApprovalRequest(
+                                request_item=output,
+                                mcp_tool=server,
+                            )
+                        )
+                    else:
+                        logger.warning(
+                            f"MCP server {output.server_label} has no on_approval_request hook"
+                        )
+            elif isinstance(output, McpListTools):
+                items.append(MCPListToolsItem(raw_item=output, agent=agent))
+            elif isinstance(output, McpCall):
+                items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("mcp")
+            elif isinstance(output, ImageGenerationCall):
+                items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("image_generation")
+            elif isinstance(output, ResponseCodeInterpreterToolCall):
+                items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("code_interpreter")
+            elif isinstance(output, LocalShellCall):
+                items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("local_shell")
+                if not local_shell_tool:
+                    _error_tracing.attach_error_to_current_span(
+                        SpanError(
+                            message="Local shell tool not found",
+                            data={},
+                        )
+                    )
+                    raise ModelBehaviorError(
+                        "Model produced local shell call without a local shell tool."
+                    )
+                local_shell_calls.append(
+                    ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
+                )
+
             elif not isinstance(output, ResponseFunctionToolCall):
                 logger.warning(f"Unexpected output type, ignoring: {type(output)}")
                 continue
@@ -416,7 +524,9 @@
             handoffs=run_handoffs,
             functions=functions,
             computer_actions=computer_actions,
+            local_shell_calls=local_shell_calls,
             tools_used=tools_used,
+            mcp_approval_requests=mcp_approval_requests,
         )

     @classmethod
@@ -489,6 +599,30 @@
             for tool_run, result in zip(tool_runs, results)
         ]

+    @classmethod
+    async def execute_local_shell_calls(
+        cls,
+        *,
+        agent: Agent[TContext],
+        calls: list[ToolRunLocalShellCall],
+        context_wrapper: RunContextWrapper[TContext],
+        hooks: RunHooks[TContext],
+        config: RunConfig,
+    ) -> list[RunItem]:
+        results: list[RunItem] = []
+        # Need to run these serially, because each call can affect the local shell state
+        for call in calls:
+            results.append(
+                await LocalShellAction.execute(
+                    agent=agent,
+                    call=call,
+                    hooks=hooks,
+                    context_wrapper=context_wrapper,
+                    config=config,
+                )
+            )
+        return results
+
     @classmethod
     async def execute_computer_actions(
         cls,
@@ -643,6 +777,40 @@
             next_step=NextStepHandoff(new_agent),
         )

+    @classmethod
+    async def execute_mcp_approval_requests(
+        cls,
+        *,
+        agent: Agent[TContext],
+        approval_requests: list[ToolRunMCPApprovalRequest],
+        context_wrapper: RunContextWrapper[TContext],
+    ) -> list[RunItem]:
+        async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem:
+            callback = approval_request.mcp_tool.on_approval_request
+            assert callback is not None, "Callback is required for MCP approval requests"
+            maybe_awaitable_result = callback(
+                MCPToolApprovalRequest(context_wrapper, approval_request.request_item)
+            )
+            if inspect.isawaitable(maybe_awaitable_result):
+                result = await maybe_awaitable_result
+            else:
+                result = maybe_awaitable_result
+            reason = result.get("reason", None)
+            raw_item: McpApprovalResponse = {
+                "approval_request_id": approval_request.request_item.id,
+                "approve": result["approve"],
+                "type": "mcp_approval_response",
+            }
+            if not result["approve"] and reason:
+                raw_item["reason"] = reason
+            return MCPApprovalResponseItem(
+                raw_item=raw_item,
+                agent=agent,
+            )
+
+        tasks = [run_single_approval(approval_request) for approval_request in approval_requests]
+        return await asyncio.gather(*tasks)
+
     @classmethod
     async def execute_final_output(
         cls,
@@ -727,6 +895,11 @@
             event = RunItemStreamEvent(item=item, name="tool_output")
         elif isinstance(item, ReasoningItem):
             event = RunItemStreamEvent(item=item, name="reasoning_item_created")
+        elif isinstance(item, MCPApprovalRequestItem):
+            event = RunItemStreamEvent(item=item, name="mcp_approval_requested")
+        elif isinstance(item, MCPListToolsItem):
+            event = RunItemStreamEvent(item=item, name="mcp_list_tools")
+
         else:
             logger.warning(f"Unexpected item type: {type(item)}")
             event = None
@@ -919,3 +1092,54 @@ class ComputerAction:
             await computer.wait()

         return await computer.screenshot()
+
+
+class LocalShellAction:
+    @classmethod
+    async def execute(
+        cls,
+        *,
+        agent: Agent[TContext],
+        call: ToolRunLocalShellCall,
+        hooks: RunHooks[TContext],
+        context_wrapper: RunContextWrapper[TContext],
+        config: RunConfig,
+    ) -> RunItem:
+        await asyncio.gather(
+            hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
+            (
+                agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
+
+        request = LocalShellCommandRequest(
+            ctx_wrapper=context_wrapper,
+            data=call.tool_call,
+        )
+        output = call.local_shell_tool.executor(request)
+        if inspect.isawaitable(output):
+            result = await output
+        else:
+            result = output
+
+        await asyncio.gather(
+            hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
+            (
+                agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
+
+        return ToolCallOutputItem(
+            agent=agent,
+            output=output,
+            raw_item={
+                "type": "local_shell_call_output",
+                "id": call.tool_call.call_id,
+                "output": result,
+                # "id": "out" + call.tool_call.id,  # TODO remove this, it should be optional
+            },
+        )
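
Taken together, the _run_impl.py changes wire two new tool types through the run loop: hosted MCP calls that may pause for approval, and local shell calls executed by a user-supplied executor. A minimal usage sketch, grounded in the attribute names this diff shows (tool_config["server_label"], on_approval_request, executor); the exact tool_config payload follows the Responses API MCP tool parameters and is an assumption here, as is the server URL:

from agents import (
    Agent,
    HostedMCPTool,
    LocalShellCommandRequest,
    LocalShellTool,
    MCPToolApprovalFunctionResult,
    MCPToolApprovalRequest,
)

def approve_mcp(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult:
    # The run loop reads result["approve"] and, on rejection, copies an
    # optional "reason" onto the mcp_approval_response item.
    return {"approve": True}

def run_command(request: LocalShellCommandRequest) -> str:
    # request.data is the raw LocalShellCall; the return value becomes the
    # "output" field of the local_shell_call_output item. Sync or async both work.
    return "uptime: 3 days"  # placeholder instead of actually running a command

agent = Agent(
    name="Assistant",
    tools=[
        HostedMCPTool(
            tool_config={
                "type": "mcp",
                "server_label": "docs",  # keys this tool in hosted_mcp_server_map
                "server_url": "https://example.com/mcp",  # hypothetical server
                "require_approval": "always",
            },
            on_approval_request=approve_mcp,
        ),
        LocalShellTool(executor=run_command),
    ],
)

Without an on_approval_request callback the run loop only logs a warning when the model asks for approval, so hosted MCP tools that require approval should always set one.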
agents/agent.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import dataclasses
 import inspect
 from collections.abc import Awaitable
@@ -17,7 +18,7 @@ from .mcp import MCPUtil
 from .model_settings import ModelSettings
 from .models.interface import Model
 from .run_context import RunContextWrapper, TContext
-from .tool import FunctionToolResult, Tool, function_tool
+from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
 from .util import _transforms
 from .util._types import MaybeAwaitable

@@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
         convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
         return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)

-    async def get_all_tools(self) -> list[Tool]:
+    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
         """All agent tools, including MCP tools and function tools."""
         mcp_tools = await self.get_mcp_tools()
-        return mcp_tools + self.tools
+
+        async def _check_tool_enabled(tool: Tool) -> bool:
+            if not isinstance(tool, FunctionTool):
+                return True
+
+            attr = tool.is_enabled
+            if isinstance(attr, bool):
+                return attr
+            res = attr(run_context, self)
+            if inspect.isawaitable(res):
+                return bool(await res)
+            return bool(res)
+
+        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
+        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
+        return [*mcp_tools, *enabled]
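
get_all_tools() now receives the run context and drops FunctionTools whose is_enabled check fails. A minimal sketch; the diff only shows the is_enabled attribute being read, so this assigns it on the FunctionTool instance rather than assuming a decorator parameter:

from agents import Agent, RunContextWrapper, function_tool

@function_tool
def get_weather(city: str) -> str:
    return f"Sunny in {city}"

def admins_only(ctx: RunContextWrapper, agent: Agent) -> bool:
    # Per _check_tool_enabled, is_enabled may be a bool or a sync/async
    # callable taking (run_context, agent); truthiness decides inclusion.
    return bool(getattr(ctx.context, "is_admin", False))

get_weather.is_enabled = admins_only  # field read by the new filtering logic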
agents/agent_output.py CHANGED
@@ -38,7 +38,7 @@ class AgentOutputSchemaBase(abc.ABC):
     @abc.abstractmethod
     def is_strict_json_schema(self) -> bool:
         """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
-        features, but guarantees valis JSON. See here for details:
+        features, but guarantees valid JSON. See here for details:
         https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
         """
         pass
agents/exceptions.py CHANGED
@@ -1,12 +1,42 @@
-from typing import TYPE_CHECKING
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
+    from .agent import Agent
     from .guardrail import InputGuardrailResult, OutputGuardrailResult
+    from .items import ModelResponse, RunItem, TResponseInputItem
+    from .run_context import RunContextWrapper
+
+from .util._pretty_print import pretty_print_run_error_details
+
+
+@dataclass
+class RunErrorDetails:
+    """Data collected from an agent run when an exception occurs."""
+
+    input: str | list[TResponseInputItem]
+    new_items: list[RunItem]
+    raw_responses: list[ModelResponse]
+    last_agent: Agent[Any]
+    context_wrapper: RunContextWrapper[Any]
+    input_guardrail_results: list[InputGuardrailResult]
+    output_guardrail_results: list[OutputGuardrailResult]
+
+    def __str__(self) -> str:
+        return pretty_print_run_error_details(self)


 class AgentsException(Exception):
     """Base class for all exceptions in the Agents SDK."""

+    run_data: RunErrorDetails | None
+
+    def __init__(self, *args: object) -> None:
+        super().__init__(*args)
+        self.run_data = None
+

 class MaxTurnsExceeded(AgentsException):
     """Exception raised when the maximum number of turns is exceeded."""
@@ -15,6 +45,7 @@ class MaxTurnsExceeded(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class ModelBehaviorError(AgentsException):
@@ -26,6 +57,7 @@ class ModelBehaviorError(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class UserError(AgentsException):
@@ -35,15 +67,16 @@ class UserError(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class InputGuardrailTripwireTriggered(AgentsException):
     """Exception raised when a guardrail tripwire is triggered."""

-    guardrail_result: "InputGuardrailResult"
+    guardrail_result: InputGuardrailResult
     """The result data of the guardrail that was triggered."""

-    def __init__(self, guardrail_result: "InputGuardrailResult"):
+    def __init__(self, guardrail_result: InputGuardrailResult):
         self.guardrail_result = guardrail_result
         super().__init__(
             f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
@@ -53,10 +86,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
 class OutputGuardrailTripwireTriggered(AgentsException):
     """Exception raised when a guardrail tripwire is triggered."""

-    guardrail_result: "OutputGuardrailResult"
+    guardrail_result: OutputGuardrailResult
     """The result data of the guardrail that was triggered."""

-    def __init__(self, guardrail_result: "OutputGuardrailResult"):
+    def __init__(self, guardrail_result: OutputGuardrailResult):
         self.guardrail_result = guardrail_result
         super().__init__(
             f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
agents/extensions/models/litellm_model.py CHANGED
@@ -5,7 +5,7 @@ import time
 from collections.abc import AsyncIterator
 from typing import Any, Literal, cast, overload

-import litellm.types
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from agents.exceptions import ModelBehaviorError

@@ -107,6 +107,18 @@
                     input_tokens=response_usage.prompt_tokens,
                     output_tokens=response_usage.completion_tokens,
                     total_tokens=response_usage.total_tokens,
+                    input_tokens_details=InputTokensDetails(
+                        cached_tokens=getattr(
+                            response_usage.prompt_tokens_details, "cached_tokens", 0
+                        )
+                        or 0
+                    ),
+                    output_tokens_details=OutputTokensDetails(
+                        reasoning_tokens=getattr(
+                            response_usage.completion_tokens_details, "reasoning_tokens", 0
+                        )
+                        or 0
+                    ),
                 )
                 if response.usage
                 else Usage()
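
The layered guard in this hunk is deliberate: getattr covers a details object that is missing entirely, while the trailing "or 0" covers an attribute that exists but is None, which some providers apparently return through LiteLLM. A standalone illustration with a hypothetical details class:

class _Details:
    cached_tokens = None  # attribute present but unset

assert (getattr(None, "cached_tokens", 0) or 0) == 0        # details object missing
assert (getattr(_Details(), "cached_tokens", 0) or 0) == 0  # attribute is None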
agents/extensions/visualization.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations

 import graphviz  # type: ignore

@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
     return "".join(parts)


-def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
+def get_all_nodes(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
     """
     Recursively generates the nodes for the given agent and its handoffs in DOT format.

@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
     Returns:
         str: The DOT format string representing the nodes.
     """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
     parts = []

     # Start and end the graph
-    parts.append(
-        '"__start__" [label="__start__", shape=ellipse, style=filled, '
-        "fillcolor=lightblue, width=0.5, height=0.3];"
-        '"__end__" [label="__end__", shape=ellipse, style=filled, '
-        "fillcolor=lightblue, width=0.5, height=0.3];"
-    )
-    # Ensure parent agent node is colored
     if not parent:
+        parts.append(
+            '"__start__" [label="__start__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+            '"__end__" [label="__end__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+        )
+        # Ensure parent agent node is colored
         parts.append(
             f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
             "fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
             f"fillcolor=lightyellow, width=1.5, height=0.8];"
         )
         if isinstance(handoff, Agent):
-            parts.append(
-                f'"{handoff.name}" [label="{handoff.name}", '
-                f"shape=box, style=filled, style=rounded, "
-                f"fillcolor=lightyellow, width=1.5, height=0.8];"
-            )
-            parts.append(get_all_nodes(handoff))
+            if handoff.name not in visited:
+                parts.append(
+                    f'"{handoff.name}" [label="{handoff.name}", '
+                    f"shape=box, style=filled, style=rounded, "
+                    f"fillcolor=lightyellow, width=1.5, height=0.8];"
+                )
+            parts.append(get_all_nodes(handoff, agent, visited))

     return "".join(parts)


-def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
+def get_all_edges(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
     """
     Recursively generates the edges for the given agent and its handoffs in DOT format.

@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
     Returns:
         str: The DOT format string representing the edges.
     """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
     parts = []

     if not parent:
@@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
         if isinstance(handoff, Agent):
             parts.append(f"""
    "{agent.name}" -> "{handoff.name}";""")
-            parts.append(get_all_edges(handoff, agent))
+            parts.append(get_all_edges(handoff, agent, visited))

     if not agent.handoffs and not isinstance(agent, Tool):  # type: ignore
         parts.append(f'"{agent.name}" -> "__end__";')
@@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
     return "".join(parts)


-def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
+def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
     """
     Draws the graph for the given agent and optionally saves it as a PNG file.

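
The practical effect of the new visited set: draw_graph terminates on cyclic handoff graphs instead of recursing forever. A small sketch:

from agents import Agent
from agents.extensions.visualization import draw_graph

triage = Agent(name="Triage")
billing = Agent(name="Billing", handoffs=[triage])
triage.handoffs.append(billing)  # Triage -> Billing -> Triage cycle

# Each agent is now emitted once; the back-edge still appears in the DOT output.
draw_graph(triage)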
agents/handoffs.py CHANGED
@@ -168,7 +168,7 @@ def handoff(
         input_filter: a function that filters the inputs that are passed to the next agent.
     """
     assert (on_handoff and input_type) or not (on_handoff and input_type), (
-        "You must provide either both on_input and input_type, or neither"
+        "You must provide either both on_handoff and input_type, or neither"
     )
     type_adapter: TypeAdapter[Any] | None
     if input_type is not None:
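
For reference, the pairing the corrected assertion message describes, with a hypothetical escalation payload:

from pydantic import BaseModel

from agents import Agent, RunContextWrapper, handoff

class EscalationData(BaseModel):
    reason: str

async def on_escalation(ctx: RunContextWrapper, data: EscalationData) -> None:
    print(f"Escalating because: {data.reason}")

escalation_agent = Agent(name="Escalation agent")

# on_handoff and input_type are provided together, per the corrected message.
escalation = handoff(escalation_agent, on_handoff=on_escalation, input_type=EscalationData)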