openai-agents 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic.
- agents/__init__.py +223 -0
- agents/_config.py +23 -0
- agents/_debug.py +17 -0
- agents/_run_impl.py +792 -0
- agents/_utils.py +61 -0
- agents/agent.py +159 -0
- agents/agent_output.py +144 -0
- agents/computer.py +107 -0
- agents/exceptions.py +63 -0
- agents/extensions/handoff_filters.py +67 -0
- agents/extensions/handoff_prompt.py +19 -0
- agents/function_schema.py +340 -0
- agents/guardrail.py +320 -0
- agents/handoffs.py +236 -0
- agents/items.py +246 -0
- agents/lifecycle.py +105 -0
- agents/logger.py +3 -0
- agents/model_settings.py +35 -0
- agents/models/__init__.py +0 -0
- agents/models/_openai_shared.py +34 -0
- agents/models/fake_id.py +5 -0
- agents/models/interface.py +107 -0
- agents/models/openai_chatcompletions.py +952 -0
- agents/models/openai_provider.py +65 -0
- agents/models/openai_responses.py +384 -0
- agents/result.py +220 -0
- agents/run.py +904 -0
- agents/run_context.py +26 -0
- agents/stream_events.py +58 -0
- agents/strict_schema.py +167 -0
- agents/tool.py +286 -0
- agents/tracing/__init__.py +97 -0
- agents/tracing/create.py +306 -0
- agents/tracing/logger.py +3 -0
- agents/tracing/processor_interface.py +69 -0
- agents/tracing/processors.py +261 -0
- agents/tracing/scope.py +45 -0
- agents/tracing/setup.py +211 -0
- agents/tracing/span_data.py +188 -0
- agents/tracing/spans.py +264 -0
- agents/tracing/traces.py +195 -0
- agents/tracing/util.py +17 -0
- agents/usage.py +22 -0
- agents/version.py +7 -0
- openai_agents-0.0.2.dist-info/METADATA +202 -0
- openai_agents-0.0.2.dist-info/RECORD +49 -0
- openai_agents-0.0.2.dist-info/licenses/LICENSE +21 -0
- openai-agents/example.py +0 -2
- openai_agents-0.0.1.dist-info/METADATA +0 -17
- openai_agents-0.0.1.dist-info/RECORD +0 -6
- openai_agents-0.0.1.dist-info/licenses/LICENSE +0 -20
- {openai-agents → agents/extensions}/__init__.py +0 -0
- {openai_agents-0.0.1.dist-info → openai_agents-0.0.2.dist-info}/WHEEL +0 -0

agents/models/openai_provider.py
ADDED

@@ -0,0 +1,65 @@
from __future__ import annotations

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient

from . import _openai_shared
from .interface import Model, ModelProvider
from .openai_chatcompletions import OpenAIChatCompletionsModel
from .openai_responses import OpenAIResponsesModel

DEFAULT_MODEL: str = "gpt-4o"


_http_client: httpx.AsyncClient | None = None


# If we create a new httpx client for each request, that would mean no sharing of connection pools,
# which would mean worse latency and resource usage. So, we share the client across requests.
def shared_http_client() -> httpx.AsyncClient:
    global _http_client
    if _http_client is None:
        _http_client = DefaultAsyncHttpxClient()
    return _http_client


class OpenAIProvider(ModelProvider):
    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        organization: str | None = None,
        project: str | None = None,
        use_responses: bool | None = None,
    ) -> None:
        if openai_client is not None:
            assert api_key is None and base_url is None, (
                "Don't provide api_key or base_url if you provide openai_client"
            )
            self._client = openai_client
        else:
            self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
                api_key=api_key or _openai_shared.get_default_openai_key(),
                base_url=base_url,
                organization=organization,
                project=project,
                http_client=shared_http_client(),
            )

        self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
        if use_responses is not None:
            self._use_responses = use_responses
        else:
            self._use_responses = _openai_shared.get_use_responses_by_default()

    def get_model(self, model_name: str | None) -> Model:
        if model_name is None:
            model_name = DEFAULT_MODEL

        return (
            OpenAIResponsesModel(model=model_name, openai_client=self._client)
            if self._use_responses
            else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client)
        )
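
Taken together, shared_http_client() and get_model() mean one provider instance reuses a single connection pool and switches backends per call. A minimal usage sketch under assumptions not shown in this diff (an OPENAI_API_KEY in the environment, and the agents package importable; the flag value and printed name are illustrative):

# Hypothetical usage sketch; not part of the released files above.
from agents.models.openai_provider import OpenAIProvider

provider = OpenAIProvider(use_responses=False)  # force the Chat Completions backend
model = provider.get_model(None)                # None falls back to DEFAULT_MODEL ("gpt-4o")
print(type(model).__name__)                     # -> OpenAIChatCompletionsModel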

agents/models/openai_responses.py
ADDED

@@ -0,0 +1,384 @@
from __future__ import annotations

import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, overload

from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream, NotGiven
from openai.types import ChatModel
from openai.types.responses import (
    Response,
    ResponseCompletedEvent,
    ResponseStreamEvent,
    ResponseTextConfigParam,
    ToolParam,
    WebSearchToolParam,
    response_create_params,
)

from .. import _debug
from ..agent_output import AgentOutputSchema
from ..exceptions import UserError
from ..handoffs import Handoff
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
from ..logger import logger
from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool
from ..tracing import SpanError, response_span
from ..usage import Usage
from ..version import __version__
from .interface import Model, ModelTracing

if TYPE_CHECKING:
    from ..model_settings import ModelSettings


_USER_AGENT = f"Agents/Python {__version__}"
_HEADERS = {"User-Agent": _USER_AGENT}

# From the Responses API
IncludeLiteral = Literal[
    "file_search_call.results",
    "message.input_image.image_url",
    "computer_call_output.output.image_url",
]


class OpenAIResponsesModel(Model):
    """
    Implementation of `Model` that uses the OpenAI Responses API.
    """

    def __init__(
        self,
        model: str | ChatModel,
        openai_client: AsyncOpenAI,
    ) -> None:
        self.model = model
        self._client = openai_client

    def _non_null_or_not_given(self, value: Any) -> Any:
        return value if value is not None else NOT_GIVEN

    async def get_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> ModelResponse:
        with response_span(disabled=tracing.is_disabled()) as span_response:
            try:
                response = await self._fetch_response(
                    system_instructions,
                    input,
                    model_settings,
                    tools,
                    output_schema,
                    handoffs,
                    stream=False,
                )

                if _debug.DONT_LOG_MODEL_DATA:
                    logger.debug("LLM responsed")
                else:
                    logger.debug(
                        "LLM resp:\n"
                        f"{json.dumps([x.model_dump() for x in response.output], indent=2)}\n"
                    )

                usage = (
                    Usage(
                        requests=1,
                        input_tokens=response.usage.input_tokens,
                        output_tokens=response.usage.output_tokens,
                        total_tokens=response.usage.total_tokens,
                    )
                    if response.usage
                    else Usage()
                )

                if tracing.include_data():
                    span_response.span_data.response = response
                    span_response.span_data.input = input
            except Exception as e:
                span_response.set_error(
                    SpanError(
                        message="Error getting response",
                        data={
                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
                        },
                    )
                )
                logger.error(f"Error getting response: {e}")
                raise

            return ModelResponse(
                output=response.output,
                usage=usage,
                referenceable_id=response.id,
            )

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
    ) -> AsyncIterator[ResponseStreamEvent]:
        """
        Yields a partial message as it is generated, as well as the usage information.
        """
        with response_span(disabled=tracing.is_disabled()) as span_response:
            try:
                stream = await self._fetch_response(
                    system_instructions,
                    input,
                    model_settings,
                    tools,
                    output_schema,
                    handoffs,
                    stream=True,
                )

                final_response: Response | None = None

                async for chunk in stream:
                    if isinstance(chunk, ResponseCompletedEvent):
                        final_response = chunk.response
                    yield chunk

                if final_response and tracing.include_data():
                    span_response.span_data.response = final_response
                    span_response.span_data.input = input

            except Exception as e:
                span_response.set_error(
                    SpanError(
                        message="Error streaming response",
                        data={
                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
                        },
                    )
                )
                logger.error(f"Error streaming response: {e}")
                raise

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        stream: Literal[True],
    ) -> AsyncStream[ResponseStreamEvent]: ...

    @overload
    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        stream: Literal[False],
    ) -> Response: ...

    async def _fetch_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchema | None,
        handoffs: list[Handoff],
        stream: Literal[True] | Literal[False] = False,
    ) -> Response | AsyncStream[ResponseStreamEvent]:
        list_input = ItemHelpers.input_to_new_input_list(input)

        parallel_tool_calls = (
            True if model_settings.parallel_tool_calls and tools and len(tools) > 0 else NOT_GIVEN
        )

        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
        converted_tools = Converter.convert_tools(tools, handoffs)
        response_format = Converter.get_response_format(output_schema)

        if _debug.DONT_LOG_MODEL_DATA:
            logger.debug("Calling LLM")
        else:
            logger.debug(
                f"Calling LLM {self.model} with input:\n"
                f"{json.dumps(list_input, indent=2)}\n"
                f"Tools:\n{json.dumps(converted_tools.tools, indent=2)}\n"
                f"Stream: {stream}\n"
                f"Tool choice: {tool_choice}\n"
                f"Response format: {response_format}\n"
            )

        return await self._client.responses.create(
            instructions=self._non_null_or_not_given(system_instructions),
            model=self.model,
            input=list_input,
            include=converted_tools.includes,
            tools=converted_tools.tools,
            temperature=self._non_null_or_not_given(model_settings.temperature),
            top_p=self._non_null_or_not_given(model_settings.top_p),
            truncation=self._non_null_or_not_given(model_settings.truncation),
            tool_choice=tool_choice,
            parallel_tool_calls=parallel_tool_calls,
            stream=stream,
            extra_headers=_HEADERS,
            text=response_format,
        )

    def _get_client(self) -> AsyncOpenAI:
        if self._client is None:
            self._client = AsyncOpenAI()
        return self._client


@dataclass
class ConvertedTools:
    tools: list[ToolParam]
    includes: list[IncludeLiteral]


class Converter:
    @classmethod
    def convert_tool_choice(
        cls, tool_choice: Literal["auto", "required", "none"] | str | None
    ) -> response_create_params.ToolChoice | NotGiven:
        if tool_choice is None:
            return NOT_GIVEN
        elif tool_choice == "required":
            return "required"
        elif tool_choice == "auto":
            return "auto"
        elif tool_choice == "none":
            return "none"
        elif tool_choice == "file_search":
            return {
                "type": "file_search",
            }
        elif tool_choice == "web_search_preview":
            return {
                "type": "web_search_preview",
            }
        elif tool_choice == "computer_use_preview":
            return {
                "type": "computer_use_preview",
            }
        else:
            return {
                "type": "function",
                "name": tool_choice,
            }

    @classmethod
    def get_response_format(
        cls, output_schema: AgentOutputSchema | None
    ) -> ResponseTextConfigParam | NotGiven:
        if output_schema is None or output_schema.is_plain_text():
            return NOT_GIVEN
        else:
            return {
                "format": {
                    "type": "json_schema",
                    "name": "final_output",
                    "schema": output_schema.json_schema(),
                    "strict": output_schema.strict_json_schema,
                }
            }

    @classmethod
    def convert_tools(
        cls,
        tools: list[Tool],
        handoffs: list[Handoff[Any]],
    ) -> ConvertedTools:
        converted_tools: list[ToolParam] = []
        includes: list[IncludeLiteral] = []

        computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)]
        if len(computer_tools) > 1:
            raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}")

        for tool in tools:
            converted_tool, include = cls._convert_tool(tool)
            converted_tools.append(converted_tool)
            if include:
                includes.append(include)

        for handoff in handoffs:
            converted_tools.append(cls._convert_handoff_tool(handoff))

        return ConvertedTools(tools=converted_tools, includes=includes)

    @classmethod
    def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]:
        """Returns converted tool and includes"""

        if isinstance(tool, FunctionTool):
            converted_tool: ToolParam = {
                "name": tool.name,
                "parameters": tool.params_json_schema,
                "strict": tool.strict_json_schema,
                "type": "function",
                "description": tool.description,
            }
            includes: IncludeLiteral | None = None
        elif isinstance(tool, WebSearchTool):
            ws: WebSearchToolParam = {
                "type": "web_search_preview",
                "user_location": tool.user_location,
                "search_context_size": tool.search_context_size,
            }
            converted_tool = ws
            includes = None
        elif isinstance(tool, FileSearchTool):
            converted_tool = {
                "type": "file_search",
                "vector_store_ids": tool.vector_store_ids,
            }
            if tool.max_num_results:
                converted_tool["max_num_results"] = tool.max_num_results
            if tool.ranking_options:
                converted_tool["ranking_options"] = tool.ranking_options
            if tool.filters:
                converted_tool["filters"] = tool.filters

            includes = "file_search_call.results" if tool.include_search_results else None
        elif isinstance(tool, ComputerTool):
            converted_tool = {
                "type": "computer-preview",
                "environment": tool.computer.environment,
                "display_width": tool.computer.dimensions[0],
                "display_height": tool.computer.dimensions[1],
            }
            includes = None

        else:
            raise UserError(f"Unknown tool type: {type(tool)}, tool")

        return converted_tool, includes

    @classmethod
    def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam:
        return {
            "name": handoff.tool_name,
            "parameters": handoff.input_json_schema,
            "strict": handoff.strict_json_schema,
            "type": "function",
            "description": handoff.tool_description,
        }
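
Because the Converter helpers above are plain classmethods, the tool-choice mapping can be exercised directly. A small sketch (the specific input strings are illustrative; the behavior follows the branches of convert_tool_choice shown above):

# Hypothetical check, not part of the released files above.
from agents.models.openai_responses import Converter

# Built-in modes pass through as string literals.
assert Converter.convert_tool_choice("auto") == "auto"
# Hosted tool names become typed dicts.
assert Converter.convert_tool_choice("file_search") == {"type": "file_search"}
# Anything else is treated as the name of a specific function tool.
assert Converter.convert_tool_choice("get_weather") == {"type": "function", "name": "get_weather"}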
agents/result.py
ADDED
@@ -0,0 +1,220 @@
from __future__ import annotations

import abc
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast

from typing_extensions import TypeVar

from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchema
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .stream_events import StreamEvent
from .tracing import Trace

if TYPE_CHECKING:
    from ._run_impl import QueueCompleteSentinel
    from .agent import Agent

T = TypeVar("T")


@dataclass
class RunResultBase(abc.ABC):
    input: str | list[TResponseInputItem]
    """The original input items i.e. the items before run() was called. This may be a mutated
    version of the input, if there are handoff input filters that mutate the input.
    """

    new_items: list[RunItem]
    """The new items generated during the agent run. These include things like new messages, tool
    calls and their outputs, etc.
    """

    raw_responses: list[ModelResponse]
    """The raw LLM responses generated by the model during the agent run."""

    final_output: Any
    """The output of the last agent."""

    input_guardrail_results: list[InputGuardrailResult]
    """Guardrail results for the input messages."""

    output_guardrail_results: list[OutputGuardrailResult]
    """Guardrail results for the final output of the agent."""

    @property
    @abc.abstractmethod
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run."""

    def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T:
        """A convenience method to cast the final output to a specific type. By default, the cast
        is only for the typechecker. If you set `raise_if_incorrect_type` to True, we'll raise a
        TypeError if the final output is not of the given type.

        Args:
            cls: The type to cast the final output to.
            raise_if_incorrect_type: If True, we'll raise a TypeError if the final output is not of
                the given type.

        Returns:
            The final output casted to the given type.
        """
        if raise_if_incorrect_type and not isinstance(self.final_output, cls):
            raise TypeError(f"Final output is not of type {cls.__name__}")

        return cast(T, self.final_output)

    def to_input_list(self) -> list[TResponseInputItem]:
        """Creates a new input list, merging the original input with all the new items generated."""
        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
        new_items = [item.to_input_item() for item in self.new_items]

        return original_items + new_items


@dataclass
class RunResult(RunResultBase):
    _last_agent: Agent[Any]

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run."""
        return self._last_agent


@dataclass
class RunResultStreaming(RunResultBase):
    """The result of an agent run in streaming mode. You can use the `stream_events` method to
    receive semantic events as they are generated.

    The streaming method will raise:
    - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
    - A GuardrailTripwireTriggered exception if a guardrail is tripped.
    """

    current_agent: Agent[Any]
    """The current agent that is running."""

    current_turn: int
    """The current turn number."""

    max_turns: int
    """The maximum number of turns the agent can run for."""

    final_output: Any
    """The final output of the agent. This is None until the agent has finished running."""

    _current_agent_output_schema: AgentOutputSchema | None = field(repr=False)

    _trace: Trace | None = field(repr=False)

    is_complete: bool = False
    """Whether the agent has finished running."""

    # Queues that the background run_loop writes to
    _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
        default_factory=asyncio.Queue, repr=False
    )
    _input_guardrail_queue: asyncio.Queue[InputGuardrailResult] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Store the asyncio tasks that we're waiting on
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
    _stored_exception: Exception | None = field(default=None, repr=False)

    @property
    def last_agent(self) -> Agent[Any]:
        """The last agent that was run. Updates as the agent run progresses, so the true last agent
        is only available after the agent run is complete.
        """
        return self.current_agent

    async def stream_events(self) -> AsyncIterator[StreamEvent]:
        """Stream deltas for new items as they are generated. We're using the types from the
        OpenAI Responses API, so these are semantic events: each event has a `type` field that
        describes the type of the event, along with the data for that event.

        This will raise:
        - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
        - A GuardrailTripwireTriggered exception if a guardrail is tripped.
        """
        while True:
            self._check_errors()
            if self._stored_exception:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            if self.is_complete and self._event_queue.empty():
                break

            try:
                item = await self._event_queue.get()
            except asyncio.CancelledError:
                break

            if isinstance(item, QueueCompleteSentinel):
                self._event_queue.task_done()
                # Check for errors, in case the queue was completed due to an exception
                self._check_errors()
                break

            yield item
            self._event_queue.task_done()

        if self._trace:
            self._trace.finish(reset_current=True)

        self._cleanup_tasks()

        if self._stored_exception:
            raise self._stored_exception

    def _check_errors(self):
        if self.current_turn > self.max_turns:
            self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")

        # Fetch all the completed guardrail results from the queue and raise if needed
        while not self._input_guardrail_queue.empty():
            guardrail_result = self._input_guardrail_queue.get_nowait()
            if guardrail_result.output.tripwire_triggered:
                self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)

        # Check the tasks for any exceptions
        if self._run_impl_task and self._run_impl_task.done():
            exc = self._run_impl_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

        if self._input_guardrails_task and self._input_guardrails_task.done():
            exc = self._input_guardrails_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

        if self._output_guardrails_task and self._output_guardrails_task.done():
            exc = self._output_guardrails_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self):
        if self._run_impl_task and not self._run_impl_task.done():
            self._run_impl_task.cancel()

        if self._input_guardrails_task and not self._input_guardrails_task.done():
            self._input_guardrails_task.cancel()

        if self._output_guardrails_task and not self._output_guardrails_task.done():
            self._output_guardrails_task.cancel()
            self._output_guardrails_task.cancel()
            self._output_guardrails_task.cancel()
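
A sketch of how RunResultStreaming.stream_events() might be consumed, assuming the Runner.run_streamed entry point in agents/run.py (part of this release but not shown in this hunk) returns the streaming result above and that Agent and Runner are re-exported from agents/__init__.py; the agent name and input are illustrative:

# Hypothetical consumer sketch; not part of the released files above.
import asyncio

from agents import Agent, Runner  # assumed re-exports from agents/__init__.py

async def main() -> None:
    agent = Agent(name="assistant", instructions="Reply briefly.")
    result = Runner.run_streamed(agent, input="hello")
    # Events are drained from the background run's queue; guardrail or
    # max-turns errors stored during the run are re-raised here.
    async for event in result.stream_events():
        print(event)
    print(result.final_output)

asyncio.run(main())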