openai-agents 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic.
- agents/__init__.py +223 -0
- agents/_config.py +23 -0
- agents/_debug.py +17 -0
- agents/_run_impl.py +792 -0
- agents/_utils.py +61 -0
- agents/agent.py +159 -0
- agents/agent_output.py +144 -0
- agents/computer.py +107 -0
- agents/exceptions.py +63 -0
- agents/extensions/handoff_filters.py +67 -0
- agents/extensions/handoff_prompt.py +19 -0
- agents/function_schema.py +340 -0
- agents/guardrail.py +320 -0
- agents/handoffs.py +236 -0
- agents/items.py +246 -0
- agents/lifecycle.py +105 -0
- agents/logger.py +3 -0
- agents/model_settings.py +36 -0
- agents/models/__init__.py +0 -0
- agents/models/_openai_shared.py +34 -0
- agents/models/fake_id.py +5 -0
- agents/models/interface.py +107 -0
- agents/models/openai_chatcompletions.py +952 -0
- agents/models/openai_provider.py +65 -0
- agents/models/openai_responses.py +384 -0
- agents/result.py +220 -0
- agents/run.py +904 -0
- agents/run_context.py +26 -0
- agents/stream_events.py +58 -0
- agents/strict_schema.py +167 -0
- agents/tool.py +288 -0
- agents/tracing/__init__.py +97 -0
- agents/tracing/create.py +306 -0
- agents/tracing/logger.py +3 -0
- agents/tracing/processor_interface.py +69 -0
- agents/tracing/processors.py +261 -0
- agents/tracing/scope.py +45 -0
- agents/tracing/setup.py +211 -0
- agents/tracing/span_data.py +188 -0
- agents/tracing/spans.py +264 -0
- agents/tracing/traces.py +195 -0
- agents/tracing/util.py +17 -0
- agents/usage.py +22 -0
- agents/version.py +7 -0
- openai_agents-0.0.3.dist-info/METADATA +204 -0
- openai_agents-0.0.3.dist-info/RECORD +49 -0
- openai_agents-0.0.3.dist-info/licenses/LICENSE +21 -0
- openai-agents/example.py +0 -2
- openai_agents-0.0.1.dist-info/METADATA +0 -17
- openai_agents-0.0.1.dist-info/RECORD +0 -6
- openai_agents-0.0.1.dist-info/licenses/LICENSE +0 -20
- {openai-agents → agents/extensions}/__init__.py +0 -0
- {openai_agents-0.0.1.dist-info → openai_agents-0.0.3.dist-info}/WHEEL +0 -0
agents/function_schema.py
ADDED
@@ -0,0 +1,340 @@
from __future__ import annotations

import contextlib
import inspect
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints

from griffe import Docstring, DocstringSectionKind
from pydantic import BaseModel, Field, create_model

from .exceptions import UserError
from .run_context import RunContextWrapper
from .strict_schema import ensure_strict_json_schema


@dataclass
class FuncSchema:
    """
    Captures the schema for a python function, in preparation for sending it to an LLM as a tool.
    """

    name: str
    """The name of the function."""
    description: str | None
    """The description of the function."""
    params_pydantic_model: type[BaseModel]
    """A Pydantic model that represents the function's parameters."""
    params_json_schema: dict[str, Any]
    """The JSON schema for the function's parameters, derived from the Pydantic model."""
    signature: inspect.Signature
    """The signature of the function."""
    takes_context: bool = False
    """Whether the function takes a RunContextWrapper argument (must be the first argument)."""

    def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
        """
        Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
        the original function.
        """
        positional_args: list[Any] = []
        keyword_args: dict[str, Any] = {}
        seen_var_positional = False

        # Use enumerate() so we can skip the first parameter if it's context.
        for idx, (name, param) in enumerate(self.signature.parameters.items()):
            # If the function takes a RunContextWrapper and this is the first parameter, skip it.
            if self.takes_context and idx == 0:
                continue

            value = getattr(data, name, None)
            if param.kind == param.VAR_POSITIONAL:
                # e.g. *args: extend positional args and mark that *args is now seen
                positional_args.extend(value or [])
                seen_var_positional = True
            elif param.kind == param.VAR_KEYWORD:
                # e.g. **kwargs handling
                keyword_args.update(value or {})
            elif param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                # Before *args, add to positional args. After *args, add to keyword args.
                if not seen_var_positional:
                    positional_args.append(value)
                else:
                    keyword_args[name] = value
            else:
                # For KEYWORD_ONLY parameters, always use keyword args.
                keyword_args[name] = value
        return positional_args, keyword_args


@dataclass
class FuncDocumentation:
    """Contains metadata about a python function, extracted from its docstring."""

    name: str
    """The name of the function, via `__name__`."""
    description: str | None
    """The description of the function, derived from the docstring."""
    param_descriptions: dict[str, str] | None
    """The parameter descriptions of the function, derived from the docstring."""


DocstringStyle = Literal["google", "numpy", "sphinx"]


# As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
# code approximates it.
def _detect_docstring_style(doc: str) -> DocstringStyle:
    scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}

    # Sphinx style detection: look for :param, :type, :return:, and :rtype:
    sphinx_patterns = [r"^:param\s", r"^:type\s", r"^:return:", r"^:rtype:"]
    for pattern in sphinx_patterns:
        if re.search(pattern, doc, re.MULTILINE):
            scores["sphinx"] += 1

    # Numpy style detection: look for headers like 'Parameters', 'Returns', or 'Yields' followed by
    # a dashed underline
    numpy_patterns = [
        r"^Parameters\s*\n\s*-{3,}",
        r"^Returns\s*\n\s*-{3,}",
        r"^Yields\s*\n\s*-{3,}",
    ]
    for pattern in numpy_patterns:
        if re.search(pattern, doc, re.MULTILINE):
            scores["numpy"] += 1

    # Google style detection: look for section headers with a trailing colon
    google_patterns = [r"^(Args|Arguments):", r"^(Returns):", r"^(Raises):"]
    for pattern in google_patterns:
        if re.search(pattern, doc, re.MULTILINE):
            scores["google"] += 1

    max_score = max(scores.values())
    if max_score == 0:
        return "google"

    # Priority order: sphinx > numpy > google in case of tie
    styles: list[DocstringStyle] = ["sphinx", "numpy", "google"]

    for style in styles:
        if scores[style] == max_score:
            return style

    return "google"


@contextlib.contextmanager
def _suppress_griffe_logging():
    # Supresses warnings about missing annotations for params
    logger = logging.getLogger("griffe")
    previous_level = logger.getEffectiveLevel()
    logger.setLevel(logging.ERROR)
    try:
        yield
    finally:
        logger.setLevel(previous_level)


def generate_func_documentation(
    func: Callable[..., Any], style: DocstringStyle | None = None
) -> FuncDocumentation:
    """
    Extracts metadata from a function docstring, in preparation for sending it to an LLM as a tool.

    Args:
        func: The function to extract documentation from.
        style: The style of the docstring to use for parsing. If not provided, we will attempt to
            auto-detect the style.

    Returns:
        A FuncDocumentation object containing the function's name, description, and parameter
        descriptions.
    """
    name = func.__name__
    doc = inspect.getdoc(func)
    if not doc:
        return FuncDocumentation(name=name, description=None, param_descriptions=None)

    with _suppress_griffe_logging():
        docstring = Docstring(doc, lineno=1, parser=style or _detect_docstring_style(doc))
        parsed = docstring.parse()

    description: str | None = next(
        (section.value for section in parsed if section.kind == DocstringSectionKind.text), None
    )

    param_descriptions: dict[str, str] = {
        param.name: param.description
        for section in parsed
        if section.kind == DocstringSectionKind.parameters
        for param in section.value
    }

    return FuncDocumentation(
        name=func.__name__,
        description=description,
        param_descriptions=param_descriptions or None,
    )


def function_schema(
    func: Callable[..., Any],
    docstring_style: DocstringStyle | None = None,
    name_override: str | None = None,
    description_override: str | None = None,
    use_docstring_info: bool = True,
    strict_json_schema: bool = True,
) -> FuncSchema:
    """
    Given a python function, extracts a `FuncSchema` from it, capturing the name, description,
    parameter descriptions, and other metadata.

    Args:
        func: The function to extract the schema from.
        docstring_style: The style of the docstring to use for parsing. If not provided, we will
            attempt to auto-detect the style.
        name_override: If provided, use this name instead of the function's `__name__`.
        description_override: If provided, use this description instead of the one derived from the
            docstring.
        use_docstring_info: If True, uses the docstring to generate the description and parameter
            descriptions.
        strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
            the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
            recommend setting this to True, as it increases the likelihood of the LLM providing
            correct JSON input.

    Returns:
        A `FuncSchema` object containing the function's name, description, parameter descriptions,
        and other metadata.
    """

    # 1. Grab docstring info
    if use_docstring_info:
        doc_info = generate_func_documentation(func, docstring_style)
        param_descs = doc_info.param_descriptions or {}
    else:
        doc_info = None
        param_descs = {}

    func_name = name_override or doc_info.name if doc_info else func.__name__

    # 2. Inspect function signature and get type hints
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)
    params = list(sig.parameters.items())
    takes_context = False
    filtered_params = []

    if params:
        first_name, first_param = params[0]
        # Prefer the evaluated type hint if available
        ann = type_hints.get(first_name, first_param.annotation)
        if ann != inspect._empty:
            origin = get_origin(ann) or ann
            if origin is RunContextWrapper:
                takes_context = True  # Mark that the function takes context
            else:
                filtered_params.append((first_name, first_param))
        else:
            filtered_params.append((first_name, first_param))

    # For parameters other than the first, raise error if any use RunContextWrapper.
    for name, param in params[1:]:
        ann = type_hints.get(name, param.annotation)
        if ann != inspect._empty:
            origin = get_origin(ann) or ann
            if origin is RunContextWrapper:
                raise UserError(
                    f"RunContextWrapper param found at non-first position in function"
                    f" {func.__name__}"
                )
        filtered_params.append((name, param))

    # We will collect field definitions for create_model as a dict:
    #   field_name -> (type_annotation, default_value_or_Field(...))
    fields: dict[str, Any] = {}

    for name, param in filtered_params:
        ann = type_hints.get(name, param.annotation)
        default = param.default

        # If there's no type hint, assume `Any`
        if ann == inspect._empty:
            ann = Any

        # If a docstring param description exists, use it
        field_description = param_descs.get(name, None)

        # Handle different parameter kinds
        if param.kind == param.VAR_POSITIONAL:
            # e.g. *args: extend positional args
            if get_origin(ann) is tuple:
                # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
                args_of_tuple = get_args(ann)
                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
                    ann = list[args_of_tuple[0]]  # type: ignore
                else:
                    ann = list[Any]
            else:
                # If user wrote *args: int, treat as List[int]
                ann = list[ann]  # type: ignore

            # Default factory to empty list
            fields[name] = (
                ann,
                Field(default_factory=list, description=field_description),  # type: ignore
            )

        elif param.kind == param.VAR_KEYWORD:
            # **kwargs handling
            if get_origin(ann) is dict:
                # e.g. def foo(**kwargs: dict[str, int])
                dict_args = get_args(ann)
                if len(dict_args) == 2:
                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
                else:
                    ann = dict[str, Any]
            else:
                # e.g. def foo(**kwargs: int) -> Dict[str, int]
                ann = dict[str, ann]  # type: ignore

            fields[name] = (
                ann,
                Field(default_factory=dict, description=field_description),  # type: ignore
            )

        else:
            # Normal parameter
            if default == inspect._empty:
                # Required field
                fields[name] = (
                    ann,
                    Field(..., description=field_description),
                )
            else:
                # Parameter with a default value
                fields[name] = (
                    ann,
                    Field(default=default, description=field_description),
                )

    # 3. Dynamically build a Pydantic model
    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)

    # 4. Build JSON schema from that model
    json_schema = dynamic_model.model_json_schema()
    if strict_json_schema:
        json_schema = ensure_strict_json_schema(json_schema)

    # 5. Return as a FuncSchema dataclass
    return FuncSchema(
        name=func_name,
        description=description_override or doc_info.description if doc_info else None,
        params_pydantic_model=dynamic_model,
        params_json_schema=json_schema,
        signature=sig,
        takes_context=takes_context,
    )
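For orientation, a minimal usage sketch of `function_schema` and `FuncSchema.to_call_args`, assuming the wheel installs the `agents` module listed above; the sample function, its parameters, and the printed values are illustrative and not part of the package:

    from agents.function_schema import function_schema

    def get_weather(city: str, units: str = "metric") -> str:
        """Fetch the weather for a city.

        Args:
            city: The city to look up.
            units: Either "metric" or "imperial".
        """
        return f"It is sunny in {city} ({units})."

    schema = function_schema(get_weather)
    print(schema.name)                # "get_weather"
    print(schema.params_json_schema)  # strict JSON schema describing "city" and "units"

    # Validate JSON arguments (e.g. produced by the LLM) and call the original function.
    args_model = schema.params_pydantic_model.model_validate({"city": "Paris"})
    args, kwargs = schema.to_call_args(args_model)
    print(get_weather(*args, **kwargs))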
agents/guardrail.py
ADDED
@@ -0,0 +1,320 @@
from __future__ import annotations

import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload

from typing_extensions import TypeVar

from ._utils import MaybeAwaitable
from .exceptions import UserError
from .items import TResponseInputItem
from .run_context import RunContextWrapper, TContext

if TYPE_CHECKING:
    from .agent import Agent


@dataclass
class GuardrailFunctionOutput:
    """The output of a guardrail function."""

    output_info: Any
    """
    Optional information about the guardrail's output. For example, the guardrail could include
    information about the checks it performed and granular results.
    """

    tripwire_triggered: bool
    """
    Whether the tripwire was triggered. If triggered, the agent's execution will be halted.
    """


@dataclass
class InputGuardrailResult:
    """The result of a guardrail run."""

    guardrail: InputGuardrail[Any]
    """
    The guardrail that was run.
    """

    output: GuardrailFunctionOutput
    """The output of the guardrail function."""


@dataclass
class OutputGuardrailResult:
    """The result of a guardrail run."""

    guardrail: OutputGuardrail[Any]
    """
    The guardrail that was run.
    """

    agent_output: Any
    """
    The output of the agent that was checked by the guardrail.
    """

    agent: Agent[Any]
    """
    The agent that was checked by the guardrail.
    """

    output: GuardrailFunctionOutput
    """The output of the guardrail function."""


@dataclass
class InputGuardrail(Generic[TContext]):
    """Input guardrails are checks that run in parallel to the agent's execution.
    They can be used to do things like:
    - Check if input messages are off-topic
    - Take over control of the agent's execution if an unexpected input is detected

    You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
    create an `InputGuardrail` manually.

    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, the agent
    execution will immediately stop and a `InputGuardrailTripwireTriggered` exception will be raised
    """

    guardrail_function: Callable[
        [RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]],
        MaybeAwaitable[GuardrailFunctionOutput],
    ]
    """A function that receives the the agent input and the context, and returns a
    `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
    include information about the guardrail's output.
    """

    name: str | None = None
    """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
    function's name.
    """

    def get_name(self) -> str:
        if self.name:
            return self.name

        return self.guardrail_function.__name__

    async def run(
        self,
        agent: Agent[Any],
        input: str | list[TResponseInputItem],
        context: RunContextWrapper[TContext],
    ) -> InputGuardrailResult:
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        output = self.guardrail_function(context, agent, input)
        if inspect.isawaitable(output):
            return InputGuardrailResult(
                guardrail=self,
                output=await output,
            )

        return InputGuardrailResult(
            guardrail=self,
            output=output,
        )


@dataclass
class OutputGuardrail(Generic[TContext]):
    """Output guardrails are checks that run on the final output of an agent.
    They can be used to do check if the output passes certain validation criteria

    You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
    or create an `OutputGuardrail` manually.

    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, a
    `OutputGuardrailTripwireTriggered` exception will be raised.
    """

    guardrail_function: Callable[
        [RunContextWrapper[TContext], Agent[Any], Any],
        MaybeAwaitable[GuardrailFunctionOutput],
    ]
    """A function that receives the final agent, its output, and the context, and returns a
    `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally
    include information about the guardrail's output.
    """

    name: str | None = None
    """The name of the guardrail, used for tracing. If not provided, we'll use the guardrail
    function's name.
    """

    def get_name(self) -> str:
        if self.name:
            return self.name

        return self.guardrail_function.__name__

    async def run(
        self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any
    ) -> OutputGuardrailResult:
        if not callable(self.guardrail_function):
            raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}")

        output = self.guardrail_function(context, agent, agent_output)
        if inspect.isawaitable(output):
            return OutputGuardrailResult(
                guardrail=self,
                agent=agent,
                agent_output=agent_output,
                output=await output,
            )

        return OutputGuardrailResult(
            guardrail=self,
            agent=agent,
            agent_output=agent_output,
            output=output,
        )


TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)

# For InputGuardrail
_InputGuardrailFuncSync = Callable[
    [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
    GuardrailFunctionOutput,
]
_InputGuardrailFuncAsync = Callable[
    [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]],
    Awaitable[GuardrailFunctionOutput],
]


@overload
def input_guardrail(
    func: _InputGuardrailFuncSync[TContext_co],
) -> InputGuardrail[TContext_co]: ...


@overload
def input_guardrail(
    func: _InputGuardrailFuncAsync[TContext_co],
) -> InputGuardrail[TContext_co]: ...


@overload
def input_guardrail(
    *,
    name: str | None = None,
) -> Callable[
    [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
    InputGuardrail[TContext_co],
]: ...


def input_guardrail(
    func: _InputGuardrailFuncSync[TContext_co]
    | _InputGuardrailFuncAsync[TContext_co]
    | None = None,
    *,
    name: str | None = None,
) -> (
    InputGuardrail[TContext_co]
    | Callable[
        [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
        InputGuardrail[TContext_co],
    ]
):
    """
    Decorator that transforms a sync or async function into an `InputGuardrail`.
    It can be used directly (no parentheses) or with keyword args, e.g.:

        @input_guardrail
        def my_sync_guardrail(...): ...

        @input_guardrail(name="guardrail_name")
        async def my_async_guardrail(...): ...
    """

    def decorator(
        f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
    ) -> InputGuardrail[TContext_co]:
        return InputGuardrail(guardrail_function=f, name=name)

    if func is not None:
        # Decorator was used without parentheses
        return decorator(func)

    # Decorator used with keyword arguments
    return decorator


_OutputGuardrailFuncSync = Callable[
    [RunContextWrapper[TContext_co], "Agent[Any]", Any],
    GuardrailFunctionOutput,
]
_OutputGuardrailFuncAsync = Callable[
    [RunContextWrapper[TContext_co], "Agent[Any]", Any],
    Awaitable[GuardrailFunctionOutput],
]


@overload
def output_guardrail(
    func: _OutputGuardrailFuncSync[TContext_co],
) -> OutputGuardrail[TContext_co]: ...


@overload
def output_guardrail(
    func: _OutputGuardrailFuncAsync[TContext_co],
) -> OutputGuardrail[TContext_co]: ...


@overload
def output_guardrail(
    *,
    name: str | None = None,
) -> Callable[
    [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
    OutputGuardrail[TContext_co],
]: ...


def output_guardrail(
    func: _OutputGuardrailFuncSync[TContext_co]
    | _OutputGuardrailFuncAsync[TContext_co]
    | None = None,
    *,
    name: str | None = None,
) -> (
    OutputGuardrail[TContext_co]
    | Callable[
        [_OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co]],
        OutputGuardrail[TContext_co],
    ]
):
    """
    Decorator that transforms a sync or async function into an `OutputGuardrail`.
    It can be used directly (no parentheses) or with keyword args, e.g.:

        @output_guardrail
        def my_sync_guardrail(...): ...

        @output_guardrail(name="guardrail_name")
        async def my_async_guardrail(...): ...
    """

    def decorator(
        f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co],
    ) -> OutputGuardrail[TContext_co]:
        return OutputGuardrail(guardrail_function=f, name=name)

    if func is not None:
        # Decorator was used without parentheses
        return decorator(func)

    # Decorator used with keyword arguments
    return decorator
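A short sketch of how the `@input_guardrail` decorator and `GuardrailFunctionOutput` fit together; the check itself is illustrative, only the imported names come from the module above:

    from agents.guardrail import GuardrailFunctionOutput, input_guardrail

    @input_guardrail(name="block_empty_input")
    def block_empty_input(context, agent, user_input) -> GuardrailFunctionOutput:
        # Trip the wire when the user sends an empty or whitespace-only string.
        is_empty = isinstance(user_input, str) and not user_input.strip()
        return GuardrailFunctionOutput(
            output_info={"empty_input": is_empty},
            tripwire_triggered=is_empty,
        )

    # `block_empty_input` is now an InputGuardrail instance; a runner can await
    # `block_empty_input.run(agent, user_input, context)` to get an InputGuardrailResult.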