openai-agents 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

agents/__init__.py CHANGED
@@ -21,6 +21,8 @@ from .exceptions import (
21
21
  ModelBehaviorError,
22
22
  OutputGuardrailTripwireTriggered,
23
23
  RunErrorDetails,
24
+ ToolInputGuardrailTripwireTriggered,
25
+ ToolOutputGuardrailTripwireTriggered,
24
26
  UserError,
25
27
  )
26
28
  from .guardrail import (
@@ -83,6 +85,17 @@ from .tool import (
83
85
  default_tool_error_function,
84
86
  function_tool,
85
87
  )
88
+ from .tool_guardrails import (
89
+ ToolGuardrailFunctionOutput,
90
+ ToolInputGuardrail,
91
+ ToolInputGuardrailData,
92
+ ToolInputGuardrailResult,
93
+ ToolOutputGuardrail,
94
+ ToolOutputGuardrailData,
95
+ ToolOutputGuardrailResult,
96
+ tool_input_guardrail,
97
+ tool_output_guardrail,
98
+ )
86
99
  from .tracing import (
87
100
  AgentSpanData,
88
101
  CustomSpanData,
@@ -125,7 +138,7 @@ from .version import __version__
125
138
 
126
139
 
127
140
  def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
128
- """Set the default OpenAI API key to use for LLM requests (and optionally tracing(). This is
141
+ """Set the default OpenAI API key to use for LLM requests (and optionally tracing()). This is
129
142
  only necessary if the OPENAI_API_KEY environment variable is not already set.
130
143
 
131
144
  If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
@@ -191,6 +204,8 @@ __all__ = [
191
204
  "AgentsException",
192
205
  "InputGuardrailTripwireTriggered",
193
206
  "OutputGuardrailTripwireTriggered",
207
+ "ToolInputGuardrailTripwireTriggered",
208
+ "ToolOutputGuardrailTripwireTriggered",
194
209
  "DynamicPromptFunction",
195
210
  "GenerateDynamicPromptData",
196
211
  "Prompt",
@@ -204,6 +219,15 @@ __all__ = [
204
219
  "GuardrailFunctionOutput",
205
220
  "input_guardrail",
206
221
  "output_guardrail",
222
+ "ToolInputGuardrail",
223
+ "ToolOutputGuardrail",
224
+ "ToolGuardrailFunctionOutput",
225
+ "ToolInputGuardrailData",
226
+ "ToolInputGuardrailResult",
227
+ "ToolOutputGuardrailData",
228
+ "ToolOutputGuardrailResult",
229
+ "tool_input_guardrail",
230
+ "tool_output_guardrail",
207
231
  "handoff",
208
232
  "Handoff",
209
233
  "HandoffInputData",
agents/_run_impl.py CHANGED
@@ -44,7 +44,13 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
44
44
  from .agent import Agent, ToolsToFinalOutputResult
45
45
  from .agent_output import AgentOutputSchemaBase
46
46
  from .computer import AsyncComputer, Computer
47
- from .exceptions import AgentsException, ModelBehaviorError, UserError
47
+ from .exceptions import (
48
+ AgentsException,
49
+ ModelBehaviorError,
50
+ ToolInputGuardrailTripwireTriggered,
51
+ ToolOutputGuardrailTripwireTriggered,
52
+ UserError,
53
+ )
48
54
  from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
49
55
  from .handoffs import Handoff, HandoffInputData
50
56
  from .items import (
@@ -80,6 +86,12 @@ from .tool import (
80
86
  Tool,
81
87
  )
82
88
  from .tool_context import ToolContext
89
+ from .tool_guardrails import (
90
+ ToolInputGuardrailData,
91
+ ToolInputGuardrailResult,
92
+ ToolOutputGuardrailData,
93
+ ToolOutputGuardrailResult,
94
+ )
83
95
  from .tracing import (
84
96
  SpanError,
85
97
  Trace,
@@ -208,6 +220,12 @@ class SingleStepResult:
208
220
  next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
209
221
  """The next step to take."""
210
222
 
223
+ tool_input_guardrail_results: list[ToolInputGuardrailResult]
224
+ """Tool input guardrail results from this step."""
225
+
226
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult]
227
+ """Tool output guardrail results from this step."""
228
+
211
229
  @property
212
230
  def generated_items(self) -> list[RunItem]:
213
231
  """Items generated during the agent run (i.e. everything generated after
@@ -250,7 +268,10 @@ class RunImpl:
250
268
  new_step_items.extend(processed_response.new_items)
251
269
 
252
270
  # First, lets run the tool calls - function tools and computer actions
253
- function_results, computer_results = await asyncio.gather(
271
+ (
272
+ (function_results, tool_input_guardrail_results, tool_output_guardrail_results),
273
+ computer_results,
274
+ ) = await asyncio.gather(
254
275
  cls.execute_function_tool_calls(
255
276
  agent=agent,
256
277
  tool_runs=processed_response.functions,
@@ -320,6 +341,8 @@ class RunImpl:
320
341
  final_output=check_tool_use.final_output,
321
342
  hooks=hooks,
322
343
  context_wrapper=context_wrapper,
344
+ tool_input_guardrail_results=tool_input_guardrail_results,
345
+ tool_output_guardrail_results=tool_output_guardrail_results,
323
346
  )
324
347
 
325
348
  # Now we can check if the model also produced a final output
@@ -343,6 +366,8 @@ class RunImpl:
343
366
  final_output=final_output,
344
367
  hooks=hooks,
345
368
  context_wrapper=context_wrapper,
369
+ tool_input_guardrail_results=tool_input_guardrail_results,
370
+ tool_output_guardrail_results=tool_output_guardrail_results,
346
371
  )
347
372
  elif not output_schema or output_schema.is_plain_text():
348
373
  return await cls.execute_final_output(
@@ -354,6 +379,8 @@ class RunImpl:
354
379
  final_output=potential_final_output_text or "",
355
380
  hooks=hooks,
356
381
  context_wrapper=context_wrapper,
382
+ tool_input_guardrail_results=tool_input_guardrail_results,
383
+ tool_output_guardrail_results=tool_output_guardrail_results,
357
384
  )
358
385
 
359
386
  # If there's no final output, we can just run again
@@ -363,6 +390,8 @@ class RunImpl:
363
390
  pre_step_items=pre_step_items,
364
391
  new_step_items=new_step_items,
365
392
  next_step=NextStepRunAgain(),
393
+ tool_input_guardrail_results=tool_input_guardrail_results,
394
+ tool_output_guardrail_results=tool_output_guardrail_results,
366
395
  )
367
396
 
368
397
  @classmethod
@@ -547,6 +576,155 @@ class RunImpl:
547
576
  mcp_approval_requests=mcp_approval_requests,
548
577
  )
549
578
 
579
+ @classmethod
580
+ async def _execute_input_guardrails(
581
+ cls,
582
+ *,
583
+ func_tool: FunctionTool,
584
+ tool_context: ToolContext[TContext],
585
+ agent: Agent[TContext],
586
+ tool_input_guardrail_results: list[ToolInputGuardrailResult],
587
+ ) -> str | None:
588
+ """Execute input guardrails for a tool.
589
+
590
+ Args:
591
+ func_tool: The function tool being executed.
592
+ tool_context: The tool execution context.
593
+ agent: The agent executing the tool.
594
+ tool_input_guardrail_results: List to append guardrail results to.
595
+
596
+ Returns:
597
+ None if tool execution should proceed, or a message string if execution should be
598
+ skipped.
599
+
600
+ Raises:
601
+ ToolInputGuardrailTripwireTriggered: If a guardrail triggers an exception.
602
+ """
603
+ if not func_tool.tool_input_guardrails:
604
+ return None
605
+
606
+ for guardrail in func_tool.tool_input_guardrails:
607
+ gr_out = await guardrail.run(
608
+ ToolInputGuardrailData(
609
+ context=tool_context,
610
+ agent=agent,
611
+ )
612
+ )
613
+
614
+ # Store the guardrail result
615
+ tool_input_guardrail_results.append(
616
+ ToolInputGuardrailResult(
617
+ guardrail=guardrail,
618
+ output=gr_out,
619
+ )
620
+ )
621
+
622
+ # Handle different behavior types
623
+ if gr_out.behavior["type"] == "raise_exception":
624
+ raise ToolInputGuardrailTripwireTriggered(guardrail=guardrail, output=gr_out)
625
+ elif gr_out.behavior["type"] == "reject_content":
626
+ # Set final_result to the message and skip tool execution
627
+ return gr_out.behavior["message"]
628
+ elif gr_out.behavior["type"] == "allow":
629
+ # Continue to next guardrail or tool execution
630
+ continue
631
+
632
+ return None
633
+
634
+ @classmethod
635
+ async def _execute_output_guardrails(
636
+ cls,
637
+ *,
638
+ func_tool: FunctionTool,
639
+ tool_context: ToolContext[TContext],
640
+ agent: Agent[TContext],
641
+ real_result: Any,
642
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult],
643
+ ) -> Any:
644
+ """Execute output guardrails for a tool.
645
+
646
+ Args:
647
+ func_tool: The function tool being executed.
648
+ tool_context: The tool execution context.
649
+ agent: The agent executing the tool.
650
+ real_result: The actual result from the tool execution.
651
+ tool_output_guardrail_results: List to append guardrail results to.
652
+
653
+ Returns:
654
+ The final result after guardrail processing (may be modified).
655
+
656
+ Raises:
657
+ ToolOutputGuardrailTripwireTriggered: If a guardrail triggers an exception.
658
+ """
659
+ if not func_tool.tool_output_guardrails:
660
+ return real_result
661
+
662
+ final_result = real_result
663
+ for output_guardrail in func_tool.tool_output_guardrails:
664
+ gr_out = await output_guardrail.run(
665
+ ToolOutputGuardrailData(
666
+ context=tool_context,
667
+ agent=agent,
668
+ output=real_result,
669
+ )
670
+ )
671
+
672
+ # Store the guardrail result
673
+ tool_output_guardrail_results.append(
674
+ ToolOutputGuardrailResult(
675
+ guardrail=output_guardrail,
676
+ output=gr_out,
677
+ )
678
+ )
679
+
680
+ # Handle different behavior types
681
+ if gr_out.behavior["type"] == "raise_exception":
682
+ raise ToolOutputGuardrailTripwireTriggered(
683
+ guardrail=output_guardrail, output=gr_out
684
+ )
685
+ elif gr_out.behavior["type"] == "reject_content":
686
+ # Override the result with the guardrail message
687
+ final_result = gr_out.behavior["message"]
688
+ break
689
+ elif gr_out.behavior["type"] == "allow":
690
+ # Continue to next guardrail
691
+ continue
692
+
693
+ return final_result
694
+
695
+ @classmethod
696
+ async def _execute_tool_with_hooks(
697
+ cls,
698
+ *,
699
+ func_tool: FunctionTool,
700
+ tool_context: ToolContext[TContext],
701
+ agent: Agent[TContext],
702
+ hooks: RunHooks[TContext],
703
+ tool_call: ResponseFunctionToolCall,
704
+ ) -> Any:
705
+ """Execute the core tool function with before/after hooks.
706
+
707
+ Args:
708
+ func_tool: The function tool being executed.
709
+ tool_context: The tool execution context.
710
+ agent: The agent executing the tool.
711
+ hooks: The run hooks to execute.
712
+ tool_call: The tool call details.
713
+
714
+ Returns:
715
+ The result from the tool execution.
716
+ """
717
+ await asyncio.gather(
718
+ hooks.on_tool_start(tool_context, agent, func_tool),
719
+ (
720
+ agent.hooks.on_tool_start(tool_context, agent, func_tool)
721
+ if agent.hooks
722
+ else _coro.noop_coroutine()
723
+ ),
724
+ )
725
+
726
+ return await func_tool.on_invoke_tool(tool_context, tool_call.arguments)
727
+
550
728
  @classmethod
551
729
  async def execute_function_tool_calls(
552
730
  cls,
@@ -556,7 +734,13 @@ class RunImpl:
556
734
  hooks: RunHooks[TContext],
557
735
  context_wrapper: RunContextWrapper[TContext],
558
736
  config: RunConfig,
559
- ) -> list[FunctionToolResult]:
737
+ ) -> tuple[
738
+ list[FunctionToolResult], list[ToolInputGuardrailResult], list[ToolOutputGuardrailResult]
739
+ ]:
740
+ # Collect guardrail results
741
+ tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
742
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
743
+
560
744
  async def run_single_tool(
561
745
  func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
562
746
  ) -> Any:
@@ -569,24 +753,48 @@ class RunImpl:
569
753
  if config.trace_include_sensitive_data:
570
754
  span_fn.span_data.input = tool_call.arguments
571
755
  try:
572
- _, _, result = await asyncio.gather(
573
- hooks.on_tool_start(tool_context, agent, func_tool),
574
- (
575
- agent.hooks.on_tool_start(tool_context, agent, func_tool)
576
- if agent.hooks
577
- else _coro.noop_coroutine()
578
- ),
579
- func_tool.on_invoke_tool(tool_context, tool_call.arguments),
756
+ # 1) Run input tool guardrails, if any
757
+ rejected_message = await cls._execute_input_guardrails(
758
+ func_tool=func_tool,
759
+ tool_context=tool_context,
760
+ agent=agent,
761
+ tool_input_guardrail_results=tool_input_guardrail_results,
580
762
  )
581
763
 
582
- await asyncio.gather(
583
- hooks.on_tool_end(tool_context, agent, func_tool, result),
584
- (
585
- agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
586
- if agent.hooks
587
- else _coro.noop_coroutine()
588
- ),
589
- )
764
+ if rejected_message is not None:
765
+ # Input guardrail rejected the tool call
766
+ final_result = rejected_message
767
+ else:
768
+ # 2) Actually run the tool
769
+ real_result = await cls._execute_tool_with_hooks(
770
+ func_tool=func_tool,
771
+ tool_context=tool_context,
772
+ agent=agent,
773
+ hooks=hooks,
774
+ tool_call=tool_call,
775
+ )
776
+
777
+ # 3) Run output tool guardrails, if any
778
+ final_result = await cls._execute_output_guardrails(
779
+ func_tool=func_tool,
780
+ tool_context=tool_context,
781
+ agent=agent,
782
+ real_result=real_result,
783
+ tool_output_guardrail_results=tool_output_guardrail_results,
784
+ )
785
+
786
+ # 4) Tool end hooks (with final result, which may have been overridden)
787
+ await asyncio.gather(
788
+ hooks.on_tool_end(tool_context, agent, func_tool, final_result),
789
+ (
790
+ agent.hooks.on_tool_end(
791
+ tool_context, agent, func_tool, final_result
792
+ )
793
+ if agent.hooks
794
+ else _coro.noop_coroutine()
795
+ ),
796
+ )
797
+ result = final_result
590
798
  except Exception as e:
591
799
  _error_tracing.attach_error_to_current_span(
592
800
  SpanError(
@@ -609,7 +817,7 @@ class RunImpl:
609
817
 
610
818
  results = await asyncio.gather(*tasks)
611
819
 
612
- return [
820
+ function_tool_results = [
613
821
  FunctionToolResult(
614
822
  tool=tool_run.function_tool,
615
823
  output=result,
@@ -622,6 +830,8 @@ class RunImpl:
622
830
  for tool_run, result in zip(tool_runs, results)
623
831
  ]
624
832
 
833
+ return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results
834
+
625
835
  @classmethod
626
836
  async def execute_local_shell_calls(
627
837
  cls,
@@ -825,6 +1035,8 @@ class RunImpl:
825
1035
  pre_step_items=pre_step_items,
826
1036
  new_step_items=new_step_items,
827
1037
  next_step=NextStepHandoff(new_agent),
1038
+ tool_input_guardrail_results=[],
1039
+ tool_output_guardrail_results=[],
828
1040
  )
829
1041
 
830
1042
  @classmethod
@@ -873,6 +1085,8 @@ class RunImpl:
873
1085
  final_output: Any,
874
1086
  hooks: RunHooks[TContext],
875
1087
  context_wrapper: RunContextWrapper[TContext],
1088
+ tool_input_guardrail_results: list[ToolInputGuardrailResult],
1089
+ tool_output_guardrail_results: list[ToolOutputGuardrailResult],
876
1090
  ) -> SingleStepResult:
877
1091
  # Run the on_end hooks
878
1092
  await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)
@@ -883,6 +1097,8 @@ class RunImpl:
883
1097
  pre_step_items=pre_step_items,
884
1098
  new_step_items=new_step_items,
885
1099
  next_step=NextStepFinalOutput(final_output),
1100
+ tool_input_guardrail_results=tool_input_guardrail_results,
1101
+ tool_output_guardrail_results=tool_output_guardrail_results,
886
1102
  )
887
1103
 
888
1104
  @classmethod
agents/exceptions.py CHANGED
@@ -8,6 +8,11 @@ if TYPE_CHECKING:
8
8
  from .guardrail import InputGuardrailResult, OutputGuardrailResult
9
9
  from .items import ModelResponse, RunItem, TResponseInputItem
10
10
  from .run_context import RunContextWrapper
11
+ from .tool_guardrails import (
12
+ ToolGuardrailFunctionOutput,
13
+ ToolInputGuardrail,
14
+ ToolOutputGuardrail,
15
+ )
11
16
 
12
17
  from .util._pretty_print import pretty_print_run_error_details
13
18
 
@@ -94,3 +99,33 @@ class OutputGuardrailTripwireTriggered(AgentsException):
94
99
  super().__init__(
95
100
  f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
96
101
  )
102
+
103
+
104
+ class ToolInputGuardrailTripwireTriggered(AgentsException):
105
+ """Exception raised when a tool input guardrail tripwire is triggered."""
106
+
107
+ guardrail: ToolInputGuardrail[Any]
108
+ """The guardrail that was triggered."""
109
+
110
+ output: ToolGuardrailFunctionOutput
111
+ """The output from the guardrail function."""
112
+
113
+ def __init__(self, guardrail: ToolInputGuardrail[Any], output: ToolGuardrailFunctionOutput):
114
+ self.guardrail = guardrail
115
+ self.output = output
116
+ super().__init__(f"Tool input guardrail {guardrail.__class__.__name__} triggered tripwire")
117
+
118
+
119
+ class ToolOutputGuardrailTripwireTriggered(AgentsException):
120
+ """Exception raised when a tool output guardrail tripwire is triggered."""
121
+
122
+ guardrail: ToolOutputGuardrail[Any]
123
+ """The guardrail that was triggered."""
124
+
125
+ output: ToolGuardrailFunctionOutput
126
+ """The output from the guardrail function."""
127
+
128
+ def __init__(self, guardrail: ToolOutputGuardrail[Any], output: ToolGuardrailFunctionOutput):
129
+ self.guardrail = guardrail
130
+ self.output = output
131
+ super().__init__(f"Tool output guardrail {guardrail.__class__.__name__} triggered tripwire")
agents/extensions/memory/__init__.py CHANGED
@@ -12,7 +12,9 @@ from typing import Any
12
12
 
13
13
  __all__: list[str] = [
14
14
  "EncryptedSession",
15
+ "RedisSession",
15
16
  "SQLAlchemySession",
17
+ "AdvancedSQLiteSession",
16
18
  ]
17
19
 
18
20
 
@@ -28,6 +30,17 @@ def __getattr__(name: str) -> Any:
28
30
  "Install it with: pip install openai-agents[encrypt]"
29
31
  ) from e
30
32
 
33
+ if name == "RedisSession":
34
+ try:
35
+ from .redis_session import RedisSession # noqa: F401
36
+
37
+ return RedisSession
38
+ except ModuleNotFoundError as e:
39
+ raise ImportError(
40
+ "RedisSession requires the 'redis' extra. "
41
+ "Install it with: pip install openai-agents[redis]"
42
+ ) from e
43
+
31
44
  if name == "SQLAlchemySession":
32
45
  try:
33
46
  from .sqlalchemy_session import SQLAlchemySession # noqa: F401
@@ -39,4 +52,14 @@ def __getattr__(name: str) -> Any:
39
52
  "Install it with: pip install openai-agents[sqlalchemy]"
40
53
  ) from e
41
54
 
55
+ if name == "AdvancedSQLiteSession":
56
+ try:
57
+ from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401
58
+
59
+ return AdvancedSQLiteSession
60
+ except ModuleNotFoundError as e:
61
+ raise ImportError(
62
+ f"Failed to import AdvancedSQLiteSession: {e}"
63
+ ) from e
64
+
42
65
  raise AttributeError(f"module {__name__} has no attribute {name}")