openai-agents 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +20 -0
- agents/_run_impl.py +230 -6
- agents/agent.py +19 -3
- agents/agent_output.py +1 -1
- agents/exceptions.py +38 -5
- agents/extensions/models/litellm_model.py +13 -1
- agents/extensions/visualization.py +35 -18
- agents/handoffs.py +1 -1
- agents/items.py +57 -3
- agents/mcp/server.py +9 -7
- agents/mcp/util.py +1 -1
- agents/models/chatcmpl_stream_handler.py +25 -1
- agents/models/openai_chatcompletions.py +31 -6
- agents/models/openai_responses.py +44 -13
- agents/result.py +43 -13
- agents/run.py +35 -6
- agents/stream_events.py +3 -0
- agents/tool.py +128 -3
- agents/tracing/processors.py +29 -3
- agents/usage.py +21 -1
- agents/util/_pretty_print.py +12 -0
- agents/voice/model.py +2 -0
- {openai_agents-0.0.15.dist-info → openai_agents-0.0.17.dist-info}/METADATA +2 -2
- {openai_agents-0.0.15.dist-info → openai_agents-0.0.17.dist-info}/RECORD +26 -26
- {openai_agents-0.0.15.dist-info → openai_agents-0.0.17.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.15.dist-info → openai_agents-0.0.17.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
|
@@ -14,6 +14,7 @@ from .exceptions import (
|
|
|
14
14
|
MaxTurnsExceeded,
|
|
15
15
|
ModelBehaviorError,
|
|
16
16
|
OutputGuardrailTripwireTriggered,
|
|
17
|
+
RunErrorDetails,
|
|
17
18
|
UserError,
|
|
18
19
|
)
|
|
19
20
|
from .guardrail import (
|
|
@@ -54,10 +55,19 @@ from .stream_events import (
|
|
|
54
55
|
StreamEvent,
|
|
55
56
|
)
|
|
56
57
|
from .tool import (
|
|
58
|
+
CodeInterpreterTool,
|
|
57
59
|
ComputerTool,
|
|
58
60
|
FileSearchTool,
|
|
59
61
|
FunctionTool,
|
|
60
62
|
FunctionToolResult,
|
|
63
|
+
HostedMCPTool,
|
|
64
|
+
ImageGenerationTool,
|
|
65
|
+
LocalShellCommandRequest,
|
|
66
|
+
LocalShellExecutor,
|
|
67
|
+
LocalShellTool,
|
|
68
|
+
MCPToolApprovalFunction,
|
|
69
|
+
MCPToolApprovalFunctionResult,
|
|
70
|
+
MCPToolApprovalRequest,
|
|
61
71
|
Tool,
|
|
62
72
|
WebSearchTool,
|
|
63
73
|
default_tool_error_function,
|
|
@@ -195,6 +205,7 @@ __all__ = [
|
|
|
195
205
|
"AgentHooks",
|
|
196
206
|
"RunContextWrapper",
|
|
197
207
|
"TContext",
|
|
208
|
+
"RunErrorDetails",
|
|
198
209
|
"RunResult",
|
|
199
210
|
"RunResultStreaming",
|
|
200
211
|
"RunConfig",
|
|
@@ -206,8 +217,17 @@ __all__ = [
|
|
|
206
217
|
"FunctionToolResult",
|
|
207
218
|
"ComputerTool",
|
|
208
219
|
"FileSearchTool",
|
|
220
|
+
"CodeInterpreterTool",
|
|
221
|
+
"ImageGenerationTool",
|
|
222
|
+
"LocalShellCommandRequest",
|
|
223
|
+
"LocalShellExecutor",
|
|
224
|
+
"LocalShellTool",
|
|
209
225
|
"Tool",
|
|
210
226
|
"WebSearchTool",
|
|
227
|
+
"HostedMCPTool",
|
|
228
|
+
"MCPToolApprovalFunction",
|
|
229
|
+
"MCPToolApprovalRequest",
|
|
230
|
+
"MCPToolApprovalFunctionResult",
|
|
211
231
|
"function_tool",
|
|
212
232
|
"Usage",
|
|
213
233
|
"add_trace_processor",
|
agents/_run_impl.py
CHANGED
|
@@ -14,6 +14,9 @@ from openai.types.responses import (
|
|
|
14
14
|
ResponseFunctionWebSearch,
|
|
15
15
|
ResponseOutputMessage,
|
|
16
16
|
)
|
|
17
|
+
from openai.types.responses.response_code_interpreter_tool_call import (
|
|
18
|
+
ResponseCodeInterpreterToolCall,
|
|
19
|
+
)
|
|
17
20
|
from openai.types.responses.response_computer_tool_call import (
|
|
18
21
|
ActionClick,
|
|
19
22
|
ActionDoubleClick,
|
|
@@ -25,7 +28,14 @@ from openai.types.responses.response_computer_tool_call import (
|
|
|
25
28
|
ActionType,
|
|
26
29
|
ActionWait,
|
|
27
30
|
)
|
|
28
|
-
from openai.types.responses.response_input_param import ComputerCallOutput
|
|
31
|
+
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
|
|
32
|
+
from openai.types.responses.response_output_item import (
|
|
33
|
+
ImageGenerationCall,
|
|
34
|
+
LocalShellCall,
|
|
35
|
+
McpApprovalRequest,
|
|
36
|
+
McpCall,
|
|
37
|
+
McpListTools,
|
|
38
|
+
)
|
|
29
39
|
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
|
30
40
|
|
|
31
41
|
from .agent import Agent, ToolsToFinalOutputResult
|
|
@@ -38,6 +48,9 @@ from .items import (
|
|
|
38
48
|
HandoffCallItem,
|
|
39
49
|
HandoffOutputItem,
|
|
40
50
|
ItemHelpers,
|
|
51
|
+
MCPApprovalRequestItem,
|
|
52
|
+
MCPApprovalResponseItem,
|
|
53
|
+
MCPListToolsItem,
|
|
41
54
|
MessageOutputItem,
|
|
42
55
|
ModelResponse,
|
|
43
56
|
ReasoningItem,
|
|
@@ -52,7 +65,16 @@ from .model_settings import ModelSettings
|
|
|
52
65
|
from .models.interface import ModelTracing
|
|
53
66
|
from .run_context import RunContextWrapper, TContext
|
|
54
67
|
from .stream_events import RunItemStreamEvent, StreamEvent
|
|
55
|
-
from .tool import
|
|
68
|
+
from .tool import (
|
|
69
|
+
ComputerTool,
|
|
70
|
+
FunctionTool,
|
|
71
|
+
FunctionToolResult,
|
|
72
|
+
HostedMCPTool,
|
|
73
|
+
LocalShellCommandRequest,
|
|
74
|
+
LocalShellTool,
|
|
75
|
+
MCPToolApprovalRequest,
|
|
76
|
+
Tool,
|
|
77
|
+
)
|
|
56
78
|
from .tracing import (
|
|
57
79
|
SpanError,
|
|
58
80
|
Trace,
|
|
@@ -112,15 +134,29 @@ class ToolRunComputerAction:
|
|
|
112
134
|
computer_tool: ComputerTool
|
|
113
135
|
|
|
114
136
|
|
|
137
|
+
@dataclass
|
|
138
|
+
class ToolRunMCPApprovalRequest:
|
|
139
|
+
request_item: McpApprovalRequest
|
|
140
|
+
mcp_tool: HostedMCPTool
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@dataclass
|
|
144
|
+
class ToolRunLocalShellCall:
|
|
145
|
+
tool_call: LocalShellCall
|
|
146
|
+
local_shell_tool: LocalShellTool
|
|
147
|
+
|
|
148
|
+
|
|
115
149
|
@dataclass
|
|
116
150
|
class ProcessedResponse:
|
|
117
151
|
new_items: list[RunItem]
|
|
118
152
|
handoffs: list[ToolRunHandoff]
|
|
119
153
|
functions: list[ToolRunFunction]
|
|
120
154
|
computer_actions: list[ToolRunComputerAction]
|
|
155
|
+
local_shell_calls: list[ToolRunLocalShellCall]
|
|
121
156
|
tools_used: list[str] # Names of all tools used, including hosted tools
|
|
157
|
+
mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks
|
|
122
158
|
|
|
123
|
-
def
|
|
159
|
+
def has_tools_or_approvals_to_run(self) -> bool:
|
|
124
160
|
# Handoffs, functions and computer actions need local processing
|
|
125
161
|
# Hosted tools have already run, so there's nothing to do.
|
|
126
162
|
return any(
|
|
@@ -128,6 +164,8 @@ class ProcessedResponse:
|
|
|
128
164
|
self.handoffs,
|
|
129
165
|
self.functions,
|
|
130
166
|
self.computer_actions,
|
|
167
|
+
self.local_shell_calls,
|
|
168
|
+
self.mcp_approval_requests,
|
|
131
169
|
]
|
|
132
170
|
)
|
|
133
171
|
|
|
@@ -226,7 +264,16 @@ class RunImpl:
|
|
|
226
264
|
new_step_items.extend([result.run_item for result in function_results])
|
|
227
265
|
new_step_items.extend(computer_results)
|
|
228
266
|
|
|
229
|
-
#
|
|
267
|
+
# Next, run the MCP approval requests
|
|
268
|
+
if processed_response.mcp_approval_requests:
|
|
269
|
+
approval_results = await cls.execute_mcp_approval_requests(
|
|
270
|
+
agent=agent,
|
|
271
|
+
approval_requests=processed_response.mcp_approval_requests,
|
|
272
|
+
context_wrapper=context_wrapper,
|
|
273
|
+
)
|
|
274
|
+
new_step_items.extend(approval_results)
|
|
275
|
+
|
|
276
|
+
# Next, check if there are any handoffs
|
|
230
277
|
if run_handoffs := processed_response.handoffs:
|
|
231
278
|
return await cls.execute_handoffs(
|
|
232
279
|
agent=agent,
|
|
@@ -240,7 +287,7 @@ class RunImpl:
|
|
|
240
287
|
run_config=run_config,
|
|
241
288
|
)
|
|
242
289
|
|
|
243
|
-
#
|
|
290
|
+
# Next, we'll check if the tool use should result in a final output
|
|
244
291
|
check_tool_use = await cls._check_for_final_output_from_tools(
|
|
245
292
|
agent=agent,
|
|
246
293
|
tool_results=function_results,
|
|
@@ -295,7 +342,7 @@ class RunImpl:
|
|
|
295
342
|
)
|
|
296
343
|
elif (
|
|
297
344
|
not output_schema or output_schema.is_plain_text()
|
|
298
|
-
) and not processed_response.
|
|
345
|
+
) and not processed_response.has_tools_or_approvals_to_run():
|
|
299
346
|
return await cls.execute_final_output(
|
|
300
347
|
agent=agent,
|
|
301
348
|
original_input=original_input,
|
|
@@ -343,10 +390,20 @@ class RunImpl:
|
|
|
343
390
|
run_handoffs = []
|
|
344
391
|
functions = []
|
|
345
392
|
computer_actions = []
|
|
393
|
+
local_shell_calls = []
|
|
394
|
+
mcp_approval_requests = []
|
|
346
395
|
tools_used: list[str] = []
|
|
347
396
|
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
|
348
397
|
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
|
|
349
398
|
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
|
|
399
|
+
local_shell_tool = next(
|
|
400
|
+
(tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
|
|
401
|
+
)
|
|
402
|
+
hosted_mcp_server_map = {
|
|
403
|
+
tool.tool_config["server_label"]: tool
|
|
404
|
+
for tool in all_tools
|
|
405
|
+
if isinstance(tool, HostedMCPTool)
|
|
406
|
+
}
|
|
350
407
|
|
|
351
408
|
for output in response.output:
|
|
352
409
|
if isinstance(output, ResponseOutputMessage):
|
|
@@ -375,6 +432,57 @@ class RunImpl:
|
|
|
375
432
|
computer_actions.append(
|
|
376
433
|
ToolRunComputerAction(tool_call=output, computer_tool=computer_tool)
|
|
377
434
|
)
|
|
435
|
+
elif isinstance(output, McpApprovalRequest):
|
|
436
|
+
items.append(MCPApprovalRequestItem(raw_item=output, agent=agent))
|
|
437
|
+
if output.server_label not in hosted_mcp_server_map:
|
|
438
|
+
_error_tracing.attach_error_to_current_span(
|
|
439
|
+
SpanError(
|
|
440
|
+
message="MCP server label not found",
|
|
441
|
+
data={"server_label": output.server_label},
|
|
442
|
+
)
|
|
443
|
+
)
|
|
444
|
+
raise ModelBehaviorError(f"MCP server label {output.server_label} not found")
|
|
445
|
+
else:
|
|
446
|
+
server = hosted_mcp_server_map[output.server_label]
|
|
447
|
+
if server.on_approval_request:
|
|
448
|
+
mcp_approval_requests.append(
|
|
449
|
+
ToolRunMCPApprovalRequest(
|
|
450
|
+
request_item=output,
|
|
451
|
+
mcp_tool=server,
|
|
452
|
+
)
|
|
453
|
+
)
|
|
454
|
+
else:
|
|
455
|
+
logger.warning(
|
|
456
|
+
f"MCP server {output.server_label} has no on_approval_request hook"
|
|
457
|
+
)
|
|
458
|
+
elif isinstance(output, McpListTools):
|
|
459
|
+
items.append(MCPListToolsItem(raw_item=output, agent=agent))
|
|
460
|
+
elif isinstance(output, McpCall):
|
|
461
|
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
462
|
+
tools_used.append("mcp")
|
|
463
|
+
elif isinstance(output, ImageGenerationCall):
|
|
464
|
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
465
|
+
tools_used.append("image_generation")
|
|
466
|
+
elif isinstance(output, ResponseCodeInterpreterToolCall):
|
|
467
|
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
468
|
+
tools_used.append("code_interpreter")
|
|
469
|
+
elif isinstance(output, LocalShellCall):
|
|
470
|
+
items.append(ToolCallItem(raw_item=output, agent=agent))
|
|
471
|
+
tools_used.append("local_shell")
|
|
472
|
+
if not local_shell_tool:
|
|
473
|
+
_error_tracing.attach_error_to_current_span(
|
|
474
|
+
SpanError(
|
|
475
|
+
message="Local shell tool not found",
|
|
476
|
+
data={},
|
|
477
|
+
)
|
|
478
|
+
)
|
|
479
|
+
raise ModelBehaviorError(
|
|
480
|
+
"Model produced local shell call without a local shell tool."
|
|
481
|
+
)
|
|
482
|
+
local_shell_calls.append(
|
|
483
|
+
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
|
|
484
|
+
)
|
|
485
|
+
|
|
378
486
|
elif not isinstance(output, ResponseFunctionToolCall):
|
|
379
487
|
logger.warning(f"Unexpected output type, ignoring: {type(output)}")
|
|
380
488
|
continue
|
|
@@ -416,7 +524,9 @@ class RunImpl:
|
|
|
416
524
|
handoffs=run_handoffs,
|
|
417
525
|
functions=functions,
|
|
418
526
|
computer_actions=computer_actions,
|
|
527
|
+
local_shell_calls=local_shell_calls,
|
|
419
528
|
tools_used=tools_used,
|
|
529
|
+
mcp_approval_requests=mcp_approval_requests,
|
|
420
530
|
)
|
|
421
531
|
|
|
422
532
|
@classmethod
|
|
@@ -489,6 +599,30 @@ class RunImpl:
|
|
|
489
599
|
for tool_run, result in zip(tool_runs, results)
|
|
490
600
|
]
|
|
491
601
|
|
|
602
|
+
@classmethod
|
|
603
|
+
async def execute_local_shell_calls(
|
|
604
|
+
cls,
|
|
605
|
+
*,
|
|
606
|
+
agent: Agent[TContext],
|
|
607
|
+
calls: list[ToolRunLocalShellCall],
|
|
608
|
+
context_wrapper: RunContextWrapper[TContext],
|
|
609
|
+
hooks: RunHooks[TContext],
|
|
610
|
+
config: RunConfig,
|
|
611
|
+
) -> list[RunItem]:
|
|
612
|
+
results: list[RunItem] = []
|
|
613
|
+
# Need to run these serially, because each call can affect the local shell state
|
|
614
|
+
for call in calls:
|
|
615
|
+
results.append(
|
|
616
|
+
await LocalShellAction.execute(
|
|
617
|
+
agent=agent,
|
|
618
|
+
call=call,
|
|
619
|
+
hooks=hooks,
|
|
620
|
+
context_wrapper=context_wrapper,
|
|
621
|
+
config=config,
|
|
622
|
+
)
|
|
623
|
+
)
|
|
624
|
+
return results
|
|
625
|
+
|
|
492
626
|
@classmethod
|
|
493
627
|
async def execute_computer_actions(
|
|
494
628
|
cls,
|
|
@@ -643,6 +777,40 @@ class RunImpl:
|
|
|
643
777
|
next_step=NextStepHandoff(new_agent),
|
|
644
778
|
)
|
|
645
779
|
|
|
780
|
+
@classmethod
|
|
781
|
+
async def execute_mcp_approval_requests(
|
|
782
|
+
cls,
|
|
783
|
+
*,
|
|
784
|
+
agent: Agent[TContext],
|
|
785
|
+
approval_requests: list[ToolRunMCPApprovalRequest],
|
|
786
|
+
context_wrapper: RunContextWrapper[TContext],
|
|
787
|
+
) -> list[RunItem]:
|
|
788
|
+
async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem:
|
|
789
|
+
callback = approval_request.mcp_tool.on_approval_request
|
|
790
|
+
assert callback is not None, "Callback is required for MCP approval requests"
|
|
791
|
+
maybe_awaitable_result = callback(
|
|
792
|
+
MCPToolApprovalRequest(context_wrapper, approval_request.request_item)
|
|
793
|
+
)
|
|
794
|
+
if inspect.isawaitable(maybe_awaitable_result):
|
|
795
|
+
result = await maybe_awaitable_result
|
|
796
|
+
else:
|
|
797
|
+
result = maybe_awaitable_result
|
|
798
|
+
reason = result.get("reason", None)
|
|
799
|
+
raw_item: McpApprovalResponse = {
|
|
800
|
+
"approval_request_id": approval_request.request_item.id,
|
|
801
|
+
"approve": result["approve"],
|
|
802
|
+
"type": "mcp_approval_response",
|
|
803
|
+
}
|
|
804
|
+
if not result["approve"] and reason:
|
|
805
|
+
raw_item["reason"] = reason
|
|
806
|
+
return MCPApprovalResponseItem(
|
|
807
|
+
raw_item=raw_item,
|
|
808
|
+
agent=agent,
|
|
809
|
+
)
|
|
810
|
+
|
|
811
|
+
tasks = [run_single_approval(approval_request) for approval_request in approval_requests]
|
|
812
|
+
return await asyncio.gather(*tasks)
|
|
813
|
+
|
|
646
814
|
@classmethod
|
|
647
815
|
async def execute_final_output(
|
|
648
816
|
cls,
|
|
@@ -727,6 +895,11 @@ class RunImpl:
|
|
|
727
895
|
event = RunItemStreamEvent(item=item, name="tool_output")
|
|
728
896
|
elif isinstance(item, ReasoningItem):
|
|
729
897
|
event = RunItemStreamEvent(item=item, name="reasoning_item_created")
|
|
898
|
+
elif isinstance(item, MCPApprovalRequestItem):
|
|
899
|
+
event = RunItemStreamEvent(item=item, name="mcp_approval_requested")
|
|
900
|
+
elif isinstance(item, MCPListToolsItem):
|
|
901
|
+
event = RunItemStreamEvent(item=item, name="mcp_list_tools")
|
|
902
|
+
|
|
730
903
|
else:
|
|
731
904
|
logger.warning(f"Unexpected item type: {type(item)}")
|
|
732
905
|
event = None
|
|
@@ -919,3 +1092,54 @@ class ComputerAction:
|
|
|
919
1092
|
await computer.wait()
|
|
920
1093
|
|
|
921
1094
|
return await computer.screenshot()
|
|
1095
|
+
|
|
1096
|
+
|
|
1097
|
+
class LocalShellAction:
|
|
1098
|
+
@classmethod
|
|
1099
|
+
async def execute(
|
|
1100
|
+
cls,
|
|
1101
|
+
*,
|
|
1102
|
+
agent: Agent[TContext],
|
|
1103
|
+
call: ToolRunLocalShellCall,
|
|
1104
|
+
hooks: RunHooks[TContext],
|
|
1105
|
+
context_wrapper: RunContextWrapper[TContext],
|
|
1106
|
+
config: RunConfig,
|
|
1107
|
+
) -> RunItem:
|
|
1108
|
+
await asyncio.gather(
|
|
1109
|
+
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
|
|
1110
|
+
(
|
|
1111
|
+
agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
|
|
1112
|
+
if agent.hooks
|
|
1113
|
+
else _coro.noop_coroutine()
|
|
1114
|
+
),
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1117
|
+
request = LocalShellCommandRequest(
|
|
1118
|
+
ctx_wrapper=context_wrapper,
|
|
1119
|
+
data=call.tool_call,
|
|
1120
|
+
)
|
|
1121
|
+
output = call.local_shell_tool.executor(request)
|
|
1122
|
+
if inspect.isawaitable(output):
|
|
1123
|
+
result = await output
|
|
1124
|
+
else:
|
|
1125
|
+
result = output
|
|
1126
|
+
|
|
1127
|
+
await asyncio.gather(
|
|
1128
|
+
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
|
|
1129
|
+
(
|
|
1130
|
+
agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
|
|
1131
|
+
if agent.hooks
|
|
1132
|
+
else _coro.noop_coroutine()
|
|
1133
|
+
),
|
|
1134
|
+
)
|
|
1135
|
+
|
|
1136
|
+
return ToolCallOutputItem(
|
|
1137
|
+
agent=agent,
|
|
1138
|
+
output=output,
|
|
1139
|
+
raw_item={
|
|
1140
|
+
"type": "local_shell_call_output",
|
|
1141
|
+
"id": call.tool_call.call_id,
|
|
1142
|
+
"output": result,
|
|
1143
|
+
# "id": "out" + call.tool_call.id, # TODO remove this, it should be optional
|
|
1144
|
+
},
|
|
1145
|
+
)
|
agents/agent.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import dataclasses
|
|
4
5
|
import inspect
|
|
5
6
|
from collections.abc import Awaitable
|
|
@@ -17,7 +18,7 @@ from .mcp import MCPUtil
|
|
|
17
18
|
from .model_settings import ModelSettings
|
|
18
19
|
from .models.interface import Model
|
|
19
20
|
from .run_context import RunContextWrapper, TContext
|
|
20
|
-
from .tool import FunctionToolResult, Tool, function_tool
|
|
21
|
+
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
|
|
21
22
|
from .util import _transforms
|
|
22
23
|
from .util._types import MaybeAwaitable
|
|
23
24
|
|
|
@@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
|
|
|
246
247
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
247
248
|
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
|
248
249
|
|
|
249
|
-
async def get_all_tools(self) -> list[Tool]:
|
|
250
|
+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
|
250
251
|
"""All agent tools, including MCP tools and function tools."""
|
|
251
252
|
mcp_tools = await self.get_mcp_tools()
|
|
252
|
-
|
|
253
|
+
|
|
254
|
+
async def _check_tool_enabled(tool: Tool) -> bool:
|
|
255
|
+
if not isinstance(tool, FunctionTool):
|
|
256
|
+
return True
|
|
257
|
+
|
|
258
|
+
attr = tool.is_enabled
|
|
259
|
+
if isinstance(attr, bool):
|
|
260
|
+
return attr
|
|
261
|
+
res = attr(run_context, self)
|
|
262
|
+
if inspect.isawaitable(res):
|
|
263
|
+
return bool(await res)
|
|
264
|
+
return bool(res)
|
|
265
|
+
|
|
266
|
+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
|
267
|
+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
|
268
|
+
return [*mcp_tools, *enabled]
|
agents/agent_output.py
CHANGED
|
@@ -38,7 +38,7 @@ class AgentOutputSchemaBase(abc.ABC):
|
|
|
38
38
|
@abc.abstractmethod
|
|
39
39
|
def is_strict_json_schema(self) -> bool:
|
|
40
40
|
"""Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
|
|
41
|
-
features, but guarantees
|
|
41
|
+
features, but guarantees valid JSON. See here for details:
|
|
42
42
|
https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
|
|
43
43
|
"""
|
|
44
44
|
pass
|
agents/exceptions.py
CHANGED
|
@@ -1,12 +1,42 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
2
5
|
|
|
3
6
|
if TYPE_CHECKING:
|
|
7
|
+
from .agent import Agent
|
|
4
8
|
from .guardrail import InputGuardrailResult, OutputGuardrailResult
|
|
9
|
+
from .items import ModelResponse, RunItem, TResponseInputItem
|
|
10
|
+
from .run_context import RunContextWrapper
|
|
11
|
+
|
|
12
|
+
from .util._pretty_print import pretty_print_run_error_details
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class RunErrorDetails:
|
|
17
|
+
"""Data collected from an agent run when an exception occurs."""
|
|
18
|
+
|
|
19
|
+
input: str | list[TResponseInputItem]
|
|
20
|
+
new_items: list[RunItem]
|
|
21
|
+
raw_responses: list[ModelResponse]
|
|
22
|
+
last_agent: Agent[Any]
|
|
23
|
+
context_wrapper: RunContextWrapper[Any]
|
|
24
|
+
input_guardrail_results: list[InputGuardrailResult]
|
|
25
|
+
output_guardrail_results: list[OutputGuardrailResult]
|
|
26
|
+
|
|
27
|
+
def __str__(self) -> str:
|
|
28
|
+
return pretty_print_run_error_details(self)
|
|
5
29
|
|
|
6
30
|
|
|
7
31
|
class AgentsException(Exception):
|
|
8
32
|
"""Base class for all exceptions in the Agents SDK."""
|
|
9
33
|
|
|
34
|
+
run_data: RunErrorDetails | None
|
|
35
|
+
|
|
36
|
+
def __init__(self, *args: object) -> None:
|
|
37
|
+
super().__init__(*args)
|
|
38
|
+
self.run_data = None
|
|
39
|
+
|
|
10
40
|
|
|
11
41
|
class MaxTurnsExceeded(AgentsException):
|
|
12
42
|
"""Exception raised when the maximum number of turns is exceeded."""
|
|
@@ -15,6 +45,7 @@ class MaxTurnsExceeded(AgentsException):
|
|
|
15
45
|
|
|
16
46
|
def __init__(self, message: str):
|
|
17
47
|
self.message = message
|
|
48
|
+
super().__init__(message)
|
|
18
49
|
|
|
19
50
|
|
|
20
51
|
class ModelBehaviorError(AgentsException):
|
|
@@ -26,6 +57,7 @@ class ModelBehaviorError(AgentsException):
|
|
|
26
57
|
|
|
27
58
|
def __init__(self, message: str):
|
|
28
59
|
self.message = message
|
|
60
|
+
super().__init__(message)
|
|
29
61
|
|
|
30
62
|
|
|
31
63
|
class UserError(AgentsException):
|
|
@@ -35,15 +67,16 @@ class UserError(AgentsException):
|
|
|
35
67
|
|
|
36
68
|
def __init__(self, message: str):
|
|
37
69
|
self.message = message
|
|
70
|
+
super().__init__(message)
|
|
38
71
|
|
|
39
72
|
|
|
40
73
|
class InputGuardrailTripwireTriggered(AgentsException):
|
|
41
74
|
"""Exception raised when a guardrail tripwire is triggered."""
|
|
42
75
|
|
|
43
|
-
guardrail_result:
|
|
76
|
+
guardrail_result: InputGuardrailResult
|
|
44
77
|
"""The result data of the guardrail that was triggered."""
|
|
45
78
|
|
|
46
|
-
def __init__(self, guardrail_result:
|
|
79
|
+
def __init__(self, guardrail_result: InputGuardrailResult):
|
|
47
80
|
self.guardrail_result = guardrail_result
|
|
48
81
|
super().__init__(
|
|
49
82
|
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
|
|
@@ -53,10 +86,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
|
|
|
53
86
|
class OutputGuardrailTripwireTriggered(AgentsException):
|
|
54
87
|
"""Exception raised when a guardrail tripwire is triggered."""
|
|
55
88
|
|
|
56
|
-
guardrail_result:
|
|
89
|
+
guardrail_result: OutputGuardrailResult
|
|
57
90
|
"""The result data of the guardrail that was triggered."""
|
|
58
91
|
|
|
59
|
-
def __init__(self, guardrail_result:
|
|
92
|
+
def __init__(self, guardrail_result: OutputGuardrailResult):
|
|
60
93
|
self.guardrail_result = guardrail_result
|
|
61
94
|
super().__init__(
|
|
62
95
|
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
|
|
agents/extensions/models/litellm_model.py
CHANGED
|
@@ -5,7 +5,7 @@ import time
|
|
|
5
5
|
from collections.abc import AsyncIterator
|
|
6
6
|
from typing import Any, Literal, cast, overload
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
|
9
9
|
|
|
10
10
|
from agents.exceptions import ModelBehaviorError
|
|
11
11
|
|
|
@@ -107,6 +107,18 @@ class LitellmModel(Model):
|
|
|
107
107
|
input_tokens=response_usage.prompt_tokens,
|
|
108
108
|
output_tokens=response_usage.completion_tokens,
|
|
109
109
|
total_tokens=response_usage.total_tokens,
|
|
110
|
+
input_tokens_details=InputTokensDetails(
|
|
111
|
+
cached_tokens=getattr(
|
|
112
|
+
response_usage.prompt_tokens_details, "cached_tokens", 0
|
|
113
|
+
)
|
|
114
|
+
or 0
|
|
115
|
+
),
|
|
116
|
+
output_tokens_details=OutputTokensDetails(
|
|
117
|
+
reasoning_tokens=getattr(
|
|
118
|
+
response_usage.completion_tokens_details, "reasoning_tokens", 0
|
|
119
|
+
)
|
|
120
|
+
or 0
|
|
121
|
+
),
|
|
110
122
|
)
|
|
111
123
|
if response.usage
|
|
112
124
|
else Usage()
|
|
agents/extensions/visualization.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import graphviz # type: ignore
|
|
4
4
|
|
|
@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
|
|
|
31
31
|
return "".join(parts)
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
34
|
+
def get_all_nodes(
|
|
35
|
+
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
|
|
36
|
+
) -> str:
|
|
35
37
|
"""
|
|
36
38
|
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
|
37
39
|
|
|
@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
|
41
43
|
Returns:
|
|
42
44
|
str: The DOT format string representing the nodes.
|
|
43
45
|
"""
|
|
46
|
+
if visited is None:
|
|
47
|
+
visited = set()
|
|
48
|
+
if agent.name in visited:
|
|
49
|
+
return ""
|
|
50
|
+
visited.add(agent.name)
|
|
51
|
+
|
|
44
52
|
parts = []
|
|
45
53
|
|
|
46
54
|
# Start and end the graph
|
|
47
|
-
parts.append(
|
|
48
|
-
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
|
49
|
-
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
50
|
-
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
|
51
|
-
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
52
|
-
)
|
|
53
|
-
# Ensure parent agent node is colored
|
|
54
55
|
if not parent:
|
|
56
|
+
parts.append(
|
|
57
|
+
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
|
58
|
+
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
59
|
+
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
|
60
|
+
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
61
|
+
)
|
|
62
|
+
# Ensure parent agent node is colored
|
|
55
63
|
parts.append(
|
|
56
64
|
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
|
|
57
65
|
"fillcolor=lightyellow, width=1.5, height=0.8];"
|
|
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
|
71
79
|
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
|
72
80
|
)
|
|
73
81
|
if isinstance(handoff, Agent):
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
82
|
+
if handoff.name not in visited:
|
|
83
|
+
parts.append(
|
|
84
|
+
f'"{handoff.name}" [label="{handoff.name}", '
|
|
85
|
+
f"shape=box, style=filled, style=rounded, "
|
|
86
|
+
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
|
87
|
+
)
|
|
88
|
+
parts.append(get_all_nodes(handoff, agent, visited))
|
|
80
89
|
|
|
81
90
|
return "".join(parts)
|
|
82
91
|
|
|
83
92
|
|
|
84
|
-
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
93
|
+
def get_all_edges(
|
|
94
|
+
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
|
|
95
|
+
) -> str:
|
|
85
96
|
"""
|
|
86
97
|
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
|
87
98
|
|
|
@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
|
92
103
|
Returns:
|
|
93
104
|
str: The DOT format string representing the edges.
|
|
94
105
|
"""
|
|
106
|
+
if visited is None:
|
|
107
|
+
visited = set()
|
|
108
|
+
if agent.name in visited:
|
|
109
|
+
return ""
|
|
110
|
+
visited.add(agent.name)
|
|
111
|
+
|
|
95
112
|
parts = []
|
|
96
113
|
|
|
97
114
|
if not parent:
|
|
@@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
|
109
126
|
if isinstance(handoff, Agent):
|
|
110
127
|
parts.append(f"""
|
|
111
128
|
"{agent.name}" -> "{handoff.name}";""")
|
|
112
|
-
parts.append(get_all_edges(handoff, agent))
|
|
129
|
+
parts.append(get_all_edges(handoff, agent, visited))
|
|
113
130
|
|
|
114
131
|
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
|
|
115
132
|
parts.append(f'"{agent.name}" -> "__end__";')
|
|
@@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|
|
117
134
|
return "".join(parts)
|
|
118
135
|
|
|
119
136
|
|
|
120
|
-
def draw_graph(agent: Agent, filename:
|
|
137
|
+
def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
|
|
121
138
|
"""
|
|
122
139
|
Draws the graph for the given agent and optionally saves it as a PNG file.
|
|
123
140
|
|
agents/handoffs.py
CHANGED
|
@@ -168,7 +168,7 @@ def handoff(
|
|
|
168
168
|
input_filter: a function that filters the inputs that are passed to the next agent.
|
|
169
169
|
"""
|
|
170
170
|
assert (on_handoff and input_type) or not (on_handoff and input_type), (
|
|
171
|
-
"You must provide either both
|
|
171
|
+
"You must provide either both on_handoff and input_type, or neither"
|
|
172
172
|
)
|
|
173
173
|
type_adapter: TypeAdapter[Any] | None
|
|
174
174
|
if input_type is not None:
|