openai-agents 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

Files changed (53)
  1. agents/__init__.py +223 -0
  2. agents/_config.py +23 -0
  3. agents/_debug.py +17 -0
  4. agents/_run_impl.py +792 -0
  5. agents/_utils.py +61 -0
  6. agents/agent.py +159 -0
  7. agents/agent_output.py +144 -0
  8. agents/computer.py +107 -0
  9. agents/exceptions.py +63 -0
  10. agents/extensions/handoff_filters.py +67 -0
  11. agents/extensions/handoff_prompt.py +19 -0
  12. agents/function_schema.py +340 -0
  13. agents/guardrail.py +320 -0
  14. agents/handoffs.py +236 -0
  15. agents/items.py +246 -0
  16. agents/lifecycle.py +105 -0
  17. agents/logger.py +3 -0
  18. agents/model_settings.py +36 -0
  19. agents/models/__init__.py +0 -0
  20. agents/models/_openai_shared.py +34 -0
  21. agents/models/fake_id.py +5 -0
  22. agents/models/interface.py +107 -0
  23. agents/models/openai_chatcompletions.py +952 -0
  24. agents/models/openai_provider.py +65 -0
  25. agents/models/openai_responses.py +384 -0
  26. agents/result.py +220 -0
  27. agents/run.py +904 -0
  28. agents/run_context.py +26 -0
  29. agents/stream_events.py +58 -0
  30. agents/strict_schema.py +167 -0
  31. agents/tool.py +288 -0
  32. agents/tracing/__init__.py +97 -0
  33. agents/tracing/create.py +306 -0
  34. agents/tracing/logger.py +3 -0
  35. agents/tracing/processor_interface.py +69 -0
  36. agents/tracing/processors.py +261 -0
  37. agents/tracing/scope.py +45 -0
  38. agents/tracing/setup.py +211 -0
  39. agents/tracing/span_data.py +188 -0
  40. agents/tracing/spans.py +264 -0
  41. agents/tracing/traces.py +195 -0
  42. agents/tracing/util.py +17 -0
  43. agents/usage.py +22 -0
  44. agents/version.py +7 -0
  45. openai_agents-0.0.3.dist-info/METADATA +204 -0
  46. openai_agents-0.0.3.dist-info/RECORD +49 -0
  47. openai_agents-0.0.3.dist-info/licenses/LICENSE +21 -0
  48. openai-agents/example.py +0 -2
  49. openai_agents-0.0.1.dist-info/METADATA +0 -17
  50. openai_agents-0.0.1.dist-info/RECORD +0 -6
  51. openai_agents-0.0.1.dist-info/licenses/LICENSE +0 -20
  52. {openai-agents → agents/extensions}/__init__.py +0 -0
  53. {openai_agents-0.0.1.dist-info → openai_agents-0.0.3.dist-info}/WHEEL +0 -0
agents/function_schema.py ADDED
@@ -0,0 +1,340 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import inspect
5
+ import logging
6
+ import re
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
9
+
10
+ from griffe import Docstring, DocstringSectionKind
11
+ from pydantic import BaseModel, Field, create_model
12
+
13
+ from .exceptions import UserError
14
+ from .run_context import RunContextWrapper
15
+ from .strict_schema import ensure_strict_json_schema
16
+
17
+
18
@dataclass
class FuncSchema:
    """
    Schema describing a Python function so that it can be exposed to an LLM as a tool.
    """

    name: str
    """Name the tool is exposed under."""
    description: str | None
    """Human-readable description of what the function does, if one is available."""
    params_pydantic_model: type[BaseModel]
    """Pydantic model whose fields mirror the function's parameters."""
    params_json_schema: dict[str, Any]
    """JSON schema for the parameters, generated from `params_pydantic_model`."""
    signature: inspect.Signature
    """The `inspect.Signature` of the wrapped function."""
    takes_context: bool = False
    """True when the first parameter is a `RunContextWrapper` (it must be the first one)."""

    def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
        """
        Split validated model data into the positional and keyword arguments needed to call
        the wrapped function.

        Args:
            data: The validated parameters. Values are read by attribute name; a missing
                attribute is treated as `None`.

        Returns:
            An `(args, kwargs)` tuple.
        """
        args: list[Any] = []
        kwargs: dict[str, Any] = {}
        after_star_args = False

        params = list(self.signature.parameters.items())
        if self.takes_context:
            # The leading RunContextWrapper parameter is never filled from `data`.
            params = params[1:]

        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for param_name, param in params:
            value = getattr(data, param_name, None)
            kind = param.kind
            if kind == inspect.Parameter.VAR_POSITIONAL:
                # *args: splice the collected values in and switch to keywords from here on.
                args.extend(value or [])
                after_star_args = True
            elif kind == inspect.Parameter.VAR_KEYWORD:
                # **kwargs: merge the collected mapping into the keyword arguments.
                kwargs.update(value or {})
            elif kind in positional_kinds and not after_star_args:
                args.append(value)
            else:
                # Keyword-only parameters, plus any positional-capable one seen after *args.
                kwargs[param_name] = value
        return args, kwargs
70
+
71
+
72
@dataclass
class FuncDocumentation:
    """Name, summary and per-parameter descriptions of a function, read from its docstring."""

    name: str
    """The function's `__name__`."""
    description: str | None
    """Summary text taken from the docstring, or None when there is none."""
    param_descriptions: dict[str, str] | None
    """Parameter name -> description from the docstring, or None when none were found."""
82
+
83
+
84
+ DocstringStyle = Literal["google", "numpy", "sphinx"]
85
+
86
+
87
+ # As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
88
+ # code approximates it.
89
+ def _detect_docstring_style(doc: str) -> DocstringStyle:
90
+ scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}
91
+
92
+ # Sphinx style detection: look for :param, :type, :return:, and :rtype:
93
+ sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"]
94
+ for pattern in sphinx_patterns:
95
+ if re.search(pattern, doc, re.MULTILINE):
96
+ scores["sphinx"] += 1
97
+
98
+ # Numpy style detection: look for headers like 'Parameters', 'Returns', or 'Yields' followed by
99
+ # a dashed underline
100
+ numpy_patterns = [
101
+ r"^Parameters\s*\n\s*-{3,}",
102
+ r"^Returns\s*\n\s*-{3,}",
103
+ r"^Yields\s*\n\s*-{3,}",
104
+ ]
105
+ for pattern in numpy_patterns:
106
+ if re.search(pattern, doc, re.MULTILINE):
107
+ scores["numpy"] += 1
108
+
109
+ # Google style detection: look for section headers with a trailing colon
110
+ google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"]
111
+ for pattern in google_patterns:
112
+ if re.search(pattern, doc, re.MULTILINE):
113
+ scores["google"] += 1
114
+
115
+ max_score = max(scores.values())
116
+ if max_score == 0:
117
+ return "google"
118
+
119
+ # Priority order: sphinx > numpy > google in case of tie
120
+ styles: list[DocstringStyle] = ["sphinx", "numpy", "google"]
121
+
122
+ for style in styles:
123
+ if scores[style] == max_score:
124
+ return style
125
+
126
+ return "google"
127
+
128
+
129
+ @contextlib.contextmanager
130
+ def _suppress_griffe_logging():
131
+ # Supresses warnings about missing annotations for params
132
+ logger = logging.getLogger("griffe")
133
+ previous_level = logger.getEffectiveLevel()
134
+ logger.setLevel(logging.ERROR)
135
+ try:
136
+ yield
137
+ finally:
138
+ logger.setLevel(previous_level)
139
+
140
+
141
def generate_func_documentation(
    func: Callable[..., Any], style: DocstringStyle | None = None
) -> FuncDocumentation:
    """
    Pull the name, description and parameter descriptions out of a function's docstring, in
    preparation for sending it to an LLM as a tool.

    Args:
        func: The function whose docstring should be read.
        style: Docstring convention to parse with ("google", "numpy" or "sphinx"). When
            omitted, the convention is guessed from the docstring's contents.

    Returns:
        A FuncDocumentation. When the function has no docstring, both `description` and
        `param_descriptions` are None.
    """
    name = func.__name__
    doc = inspect.getdoc(func)
    if not doc:
        return FuncDocumentation(name=name, description=None, param_descriptions=None)

    # Silence griffe's parse-time warnings (e.g. unannotated parameters).
    with _suppress_griffe_logging():
        parser = style or _detect_docstring_style(doc)
        parsed = Docstring(doc, lineno=1, parser=parser).parse()

    # The first free-text section serves as the description.
    description: str | None = None
    for section in parsed:
        if section.kind == DocstringSectionKind.text:
            description = section.value
            break

    param_descriptions: dict[str, str] = {}
    for section in parsed:
        if section.kind != DocstringSectionKind.parameters:
            continue
        for param in section.value:
            param_descriptions[param.name] = param.description

    return FuncDocumentation(
        name=name,
        description=description,
        param_descriptions=param_descriptions or None,
    )
181
+
182
+
183
def function_schema(
    func: Callable[..., Any],
    docstring_style: DocstringStyle | None = None,
    name_override: str | None = None,
    description_override: str | None = None,
    use_docstring_info: bool = True,
    strict_json_schema: bool = True,
) -> FuncSchema:
    """
    Given a python function, extracts a `FuncSchema` from it, capturing the name, description,
    parameter descriptions, and other metadata.

    Args:
        func: The function to extract the schema from.
        docstring_style: The style of the docstring to use for parsing. If not provided, we will
            attempt to auto-detect the style.
        name_override: If provided, use this name instead of the function's `__name__`. Applies
            whether or not `use_docstring_info` is set.
        description_override: If provided, use this description instead of the one derived from
            the docstring. Applies whether or not `use_docstring_info` is set.
        use_docstring_info: If True, uses the docstring to generate the description and parameter
            descriptions.
        strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
            the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
            recommend setting this to True, as it increases the likelihood of the LLM providing
            correct JSON input.

    Returns:
        A `FuncSchema` object containing the function's name, description, parameter descriptions,
        and other metadata.

    Raises:
        UserError: If a parameter other than the first is annotated as `RunContextWrapper`.
    """
    # 1. Grab docstring info.
    doc_info: FuncDocumentation | None = None
    param_descs: dict[str, str] = {}
    if use_docstring_info:
        doc_info = generate_func_documentation(func, docstring_style)
        param_descs = doc_info.param_descriptions or {}

    # Overrides always win. The parentheses are essential: `a or b if c else d` parses as
    # `(a or b) if c else d`, which would drop the override whenever doc_info is None.
    func_name = name_override or (doc_info.name if doc_info else func.__name__)
    description = description_override or (doc_info.description if doc_info else None)

    # 2. Inspect the signature; find (and validate the position of) any RunContextWrapper param.
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    takes_context = False
    filtered_params: list[tuple[str, inspect.Parameter]] = []

    for idx, (name, param) in enumerate(sig.parameters.items()):
        # Prefer the evaluated type hint (resolves postponed/string annotations).
        ann = type_hints.get(name, param.annotation)
        is_context_param = (
            ann is not inspect.Parameter.empty and (get_origin(ann) or ann) is RunContextWrapper
        )
        if not is_context_param:
            filtered_params.append((name, param))
        elif idx == 0:
            # A leading RunContextWrapper is left out of the parameters model.
            takes_context = True
        else:
            raise UserError(
                f"RunContextWrapper param found at non-first position in function"
                f" {func.__name__}"
            )

    # 3. Collect create_model field definitions: field_name -> (annotation, Field(...)).
    fields: dict[str, Any] = {}

    for name, param in filtered_params:
        ann = type_hints.get(name, param.annotation)
        if ann is inspect.Parameter.empty:
            # No type hint: accept anything.
            ann = Any

        field_description = param_descs.get(name)

        if param.kind == param.VAR_POSITIONAL:
            # `*args: T` annotates each element, so the collected value is a list[T]. A
            # `tuple[T, ...]` annotation is also accepted and mapped to list[T].
            if get_origin(ann) is tuple:
                args_of_tuple = get_args(ann)
                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
                    ann = list[args_of_tuple[0]]  # type: ignore
                else:
                    ann = list[Any]
            else:
                ann = list[ann]  # type: ignore
            fields[name] = (
                ann,
                Field(default_factory=list, description=field_description),  # type: ignore
            )

        elif param.kind == param.VAR_KEYWORD:
            # `**kwargs: T` annotates each value, so the collected value is a dict[str, T]. A
            # two-argument `dict[K, V]` annotation is used as-is.
            if get_origin(ann) is dict:
                dict_args = get_args(ann)
                if len(dict_args) == 2:
                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
                else:
                    ann = dict[str, Any]
            else:
                ann = dict[str, ann]  # type: ignore
            fields[name] = (
                ann,
                Field(default_factory=dict, description=field_description),  # type: ignore
            )

        elif param.default is inspect.Parameter.empty:
            # Required field.
            fields[name] = (ann, Field(..., description=field_description))

        else:
            # Parameter with a default. The `is` check above (rather than `==`) keeps defaults
            # with an overloaded `==` (e.g. arrays) from breaking the comparison.
            fields[name] = (ann, Field(default=param.default, description=field_description))

    # 4. Dynamically build a Pydantic model and derive its JSON schema.
    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
    json_schema = dynamic_model.model_json_schema()
    if strict_json_schema:
        json_schema = ensure_strict_json_schema(json_schema)

    # 5. Return as a FuncSchema dataclass.
    return FuncSchema(
        name=func_name,
        description=description,
        params_pydantic_model=dynamic_model,
        params_json_schema=json_schema,
        signature=sig,
        takes_context=takes_context,
    )
agents/guardrail.py ADDED
@@ -0,0 +1,320 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
7
+
8
+ from typing_extensions import TypeVar
9
+
10
+ from ._utils import MaybeAwaitable
11
+ from .exceptions import UserError
12
+ from .items import TResponseInputItem
13
+ from .run_context import RunContextWrapper, TContext
14
+
15
+ if TYPE_CHECKING:
16
+ from .agent import Agent
17
+
18
+
19
@dataclass
class GuardrailFunctionOutput:
    """Value returned by a guardrail function."""

    output_info: Any
    """
    Arbitrary, optional details produced by the guardrail, such as which checks it ran and what
    each of them found.
    """

    tripwire_triggered: bool
    """
    True if the tripwire fired. A triggered tripwire halts the agent's execution.
    """
33
+
34
+
35
@dataclass
class InputGuardrailResult:
    """Outcome of running an `InputGuardrail`."""

    guardrail: InputGuardrail[Any]
    """The input guardrail that produced this result."""

    output: GuardrailFunctionOutput
    """What the guardrail function returned."""
46
+
47
+
48
@dataclass
class OutputGuardrailResult:
    """Outcome of running an `OutputGuardrail` against an agent's output."""

    guardrail: OutputGuardrail[Any]
    """The output guardrail that produced this result."""

    agent_output: Any
    """The agent output that the guardrail inspected."""

    agent: Agent[Any]
    """The agent whose output was inspected."""

    output: GuardrailFunctionOutput
    """What the guardrail function returned."""
69
+
70
+
71
@dataclass
class InputGuardrail(Generic[TContext]):
    """Input guardrails are checks that run in parallel to the agent's execution.

    They can be used to do things like:
    - Check if input messages are off-topic
    - Take over control of the agent's execution if an unexpected input is detected

    You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
    create an `InputGuardrail` manually.

    The guardrail function returns a `GuardrailFunctionOutput`. If its `tripwire_triggered` is
    `True`, the agent execution will immediately stop and an `InputGuardrailTripwireTriggered`
    exception will be raised.
    """

    guardrail_function: Callable[
        [RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]],
        MaybeAwaitable[GuardrailFunctionOutput],
    ]
    """A function that receives the run context, the agent and the agent input (in that order)
    and returns a `GuardrailFunctionOutput`, either directly or as an awaitable. The output marks
    whether the tripwire was triggered, and can optionally include information about the
    guardrail's output.
    """

    name: str | None = None
    """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
    function's name.
    """

    def get_name(self) -> str:
        """Return `name` if it is set, otherwise the guardrail function's `__name__`."""
        if self.name:
            return self.name

        return self.guardrail_function.__name__

    async def run(
        self,
        agent: Agent[Any],
        input: str | list[TResponseInputItem],
        context: RunContextWrapper[TContext],
    ) -> InputGuardrailResult:
        """
        Invoke the guardrail function on the agent input and wrap what it returns.

        Args:
            agent: The agent whose input is being checked.
            input: The agent input (a string or a list of input items).
            context: The run context, passed through to the guardrail function.

        Returns:
            An `InputGuardrailResult` pairing this guardrail with the function's output.

        Raises:
            UserError: If `guardrail_function` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        output = self.guardrail_function(context, agent, input)
        # Sync guardrail functions return the output directly; async ones return an awaitable.
        if inspect.isawaitable(output):
            output = await output

        return InputGuardrailResult(guardrail=self, output=output)
125
+
126
+
127
@dataclass
class OutputGuardrail(Generic[TContext]):
    """Output guardrails are checks that run on the final output of an agent.

    They can be used to check whether the output passes certain validation criteria.

    You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
    or create an `OutputGuardrail` manually.

    The guardrail function returns a `GuardrailFunctionOutput`. If its `tripwire_triggered` is
    `True`, an `OutputGuardrailTripwireTriggered` exception will be raised.
    """

    guardrail_function: Callable[
        [RunContextWrapper[TContext], Agent[Any], Any],
        MaybeAwaitable[GuardrailFunctionOutput],
    ]
    """A function that receives the run context, the agent and the agent's output (in that order)
    and returns a `GuardrailFunctionOutput`, either directly or as an awaitable. The output marks
    whether the tripwire was triggered, and can optionally include information about the
    guardrail's output.
    """

    name: str | None = None
    """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
    function's name.
    """

    def get_name(self) -> str:
        """Return `name` if it is set, otherwise the guardrail function's `__name__`."""
        if self.name:
            return self.name

        return self.guardrail_function.__name__

    async def run(
        self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any
    ) -> OutputGuardrailResult:
        """
        Invoke the guardrail function on an agent's output and wrap what it returns.

        Args:
            context: The run context, passed through to the guardrail function.
            agent: The agent that produced `agent_output`.
            agent_output: The output to check.

        Returns:
            An `OutputGuardrailResult` recording this guardrail, the agent, the checked output
            and the function's output.

        Raises:
            UserError: If `guardrail_function` is not callable.
        """
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        output = self.guardrail_function(context, agent, agent_output)
        # Sync guardrail functions return the output directly; async ones return an awaitable.
        if inspect.isawaitable(output):
            output = await output

        return OutputGuardrailResult(
            guardrail=self,
            agent=agent,
            agent_output=agent_output,
            output=output,
        )
180
+
181
+
182
TContext_co = TypeVar("TContext_co", covariant=True, bound=Any)

# Callable shapes accepted by `input_guardrail`: a plain function returning a
# GuardrailFunctionOutput, or one returning an awaitable of it. These aliases are evaluated at
# runtime, which is why `Union[...]` is used instead of the `X | Y` operator (runtime `|` between
# types needs Python 3.10+). "Agent[Any]" is a string because Agent is only imported for type
# checking.
_InputGuardrailFuncSync = Callable[
    [
        RunContextWrapper[TContext_co],
        "Agent[Any]",
        Union[str, list[TResponseInputItem]],
    ],
    GuardrailFunctionOutput,
]
_InputGuardrailFuncAsync = Callable[
    [
        RunContextWrapper[TContext_co],
        "Agent[Any]",
        Union[str, list[TResponseInputItem]],
    ],
    Awaitable[GuardrailFunctionOutput],
]
193
+
194
+
195
@overload
def input_guardrail(
    func: _InputGuardrailFuncSync[TContext_co],
) -> InputGuardrail[TContext_co]: ...


@overload
def input_guardrail(
    func: _InputGuardrailFuncAsync[TContext_co],
) -> InputGuardrail[TContext_co]: ...


@overload
def input_guardrail(
    *,
    name: str | None = None,
) -> Callable[
    [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
    InputGuardrail[TContext_co],
]: ...


def input_guardrail(
    func: _InputGuardrailFuncSync[TContext_co]
    | _InputGuardrailFuncAsync[TContext_co]
    | None = None,
    *,
    name: str | None = None,
) -> (
    InputGuardrail[TContext_co]
    | Callable[
        [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
        InputGuardrail[TContext_co],
    ]
):
    """
    Turn a sync or async function into an `InputGuardrail`.

    Both decorator forms are supported:

        @input_guardrail
        def my_sync_guardrail(...): ...

        @input_guardrail(name="guardrail_name")
        async def my_async_guardrail(...): ...

    Args:
        func: The decorated function when used without parentheses; None otherwise.
        name: Optional name stored on the resulting `InputGuardrail`.

    Returns:
        An `InputGuardrail` when `func` is given, otherwise a decorator that builds one.
    """

    def wrap(
        guardrail_func: _InputGuardrailFuncSync[TContext_co]
        | _InputGuardrailFuncAsync[TContext_co],
    ) -> InputGuardrail[TContext_co]:
        return InputGuardrail(guardrail_function=guardrail_func, name=name)

    # `@input_guardrail(...)` leaves func as None, so hand back the decorator. With a bare
    # `@input_guardrail`, the function arrives directly and is wrapped right away.
    return wrap if func is None else wrap(func)
252
+
253
+
254
# Callable shapes accepted by `output_guardrail`: a plain function returning a
# GuardrailFunctionOutput, or one returning an awaitable of it. "Agent[Any]" is a string because
# Agent is only imported for type checking.
_OutputGuardrailFuncSync = Callable[
    [
        RunContextWrapper[TContext_co],
        "Agent[Any]",
        Any,
    ],
    GuardrailFunctionOutput,
]
_OutputGuardrailFuncAsync = Callable[
    [
        RunContextWrapper[TContext_co],
        "Agent[Any]",
        Any,
    ],
    Awaitable[GuardrailFunctionOutput],
]
262
+
263
+
264
@overload
def output_guardrail(
    func: _OutputGuardrailFuncSync[TContext_co],
) -> OutputGuardrail[TContext_co]: ...


@overload
def output_guardrail(
    func: _OutputGuardrailFuncAsync[TContext_co],
) -> OutputGuardrail[TContext_co]: ...


@overload
def output_guardrail(
    *,
    name: str | None = None,
) -> Callable[
    [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
    OutputGuardrail[TContext_co],
]: ...


def output_guardrail(
    func: _OutputGuardrailFuncSync[TContext_co]
    | _OutputGuardrailFuncAsync[TContext_co]
    | None = None,
    *,
    name: str | None = None,
) -> (
    OutputGuardrail[TContext_co]
    | Callable[
        [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
        OutputGuardrail[TContext_co],
    ]
):
    """
    Turn a sync or async function into an `OutputGuardrail`.

    Both decorator forms are supported:

        @output_guardrail
        def my_sync_guardrail(...): ...

        @output_guardrail(name="guardrail_name")
        async def my_async_guardrail(...): ...

    Args:
        func: The decorated function when used without parentheses; None otherwise.
        name: Optional name stored on the resulting `OutputGuardrail`.

    Returns:
        An `OutputGuardrail` when `func` is given, otherwise a decorator that builds one.
    """

    def wrap(
        guardrail_func: _OutputGuardrailFuncSync[TContext_co]
        | _OutputGuardrailFuncAsync[TContext_co],
    ) -> OutputGuardrail[TContext_co]:
        return OutputGuardrail(guardrail_function=guardrail_func, name=name)

    # `@output_guardrail(...)` leaves func as None, so hand back the decorator. With a bare
    # `@output_guardrail`, the function arrives directly and is wrapped right away.
    return wrap if func is None else wrap(func)