openai-agents 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of openai-agents might be problematic.
- agents/__init__.py +8 -0
- agents/_run_impl.py +11 -5
- agents/agent.py +33 -3
- agents/agent_output.py +1 -1
- agents/exceptions.py +38 -5
- agents/extensions/models/litellm_model.py +13 -2
- agents/extensions/visualization.py +35 -18
- agents/function_schema.py +7 -5
- agents/handoffs.py +3 -3
- agents/mcp/server.py +9 -9
- agents/mcp/util.py +1 -1
- agents/model_settings.py +15 -0
- agents/models/interface.py +6 -0
- agents/models/openai_chatcompletions.py +26 -6
- agents/models/openai_responses.py +10 -0
- agents/prompts.py +76 -0
- agents/repl.py +65 -0
- agents/result.py +43 -13
- agents/run.py +48 -8
- agents/stream_events.py +1 -0
- agents/tool.py +26 -5
- agents/tool_context.py +29 -0
- agents/tracing/processors.py +29 -3
- agents/util/_pretty_print.py +12 -0
- agents/voice/model.py +2 -0
- {openai_agents-0.0.16.dist-info → openai_agents-0.0.18.dist-info}/METADATA +6 -3
- {openai_agents-0.0.16.dist-info → openai_agents-0.0.18.dist-info}/RECORD +29 -26
- {openai_agents-0.0.16.dist-info → openai_agents-0.0.18.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.16.dist-info → openai_agents-0.0.18.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from .exceptions import (
     MaxTurnsExceeded,
     ModelBehaviorError,
     OutputGuardrailTripwireTriggered,
+    RunErrorDetails,
     UserError,
 )
 from .guardrail import (
@@ -44,6 +45,8 @@ from .models.interface import Model, ModelProvider, ModelTracing
 from .models.openai_chatcompletions import OpenAIChatCompletionsModel
 from .models.openai_provider import OpenAIProvider
 from .models.openai_responses import OpenAIResponsesModel
+from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
+from .repl import run_demo_loop
 from .result import RunResult, RunResultStreaming
 from .run import RunConfig, Runner
 from .run_context import RunContextWrapper, TContext
@@ -159,6 +162,7 @@ __all__ = [
     "ToolsToFinalOutputFunction",
     "ToolsToFinalOutputResult",
     "Runner",
+    "run_demo_loop",
     "Model",
     "ModelProvider",
     "ModelTracing",
@@ -175,6 +179,9 @@ __all__ = [
     "AgentsException",
     "InputGuardrailTripwireTriggered",
     "OutputGuardrailTripwireTriggered",
+    "DynamicPromptFunction",
+    "GenerateDynamicPromptData",
+    "Prompt",
     "MaxTurnsExceeded",
     "ModelBehaviorError",
     "UserError",
@@ -204,6 +211,7 @@ __all__ = [
     "AgentHooks",
     "RunContextWrapper",
     "TContext",
+    "RunErrorDetails",
     "RunResult",
     "RunResultStreaming",
     "RunConfig",
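The new exports above come from the added agents/prompts.py and agents/repl.py modules. A minimal usage sketch for run_demo_loop, assuming it simply runs an interactive console loop around a single agent (the agent name and instructions here are illustrative):

import asyncio

from agents import Agent, run_demo_loop


async def main() -> None:
    # run_demo_loop reads user input from stdin and prints the agent's replies
    # until the user exits (behaviour inferred from the new agents/repl.py).
    agent = Agent(name="Assistant", instructions="Reply concisely.")
    await run_demo_loop(agent)


if __name__ == "__main__":
    asyncio.run(main())
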
agents/_run_impl.py
CHANGED
@@ -33,6 +33,7 @@ from openai.types.responses.response_output_item import (
     ImageGenerationCall,
     LocalShellCall,
     McpApprovalRequest,
+    McpCall,
     McpListTools,
 )
 from openai.types.responses.response_reasoning_item import ResponseReasoningItem
@@ -74,6 +75,7 @@ from .tool import (
     MCPToolApprovalRequest,
     Tool,
 )
+from .tool_context import ToolContext
 from .tracing import (
     SpanError,
     Trace,
@@ -456,6 +458,9 @@ class RunImpl:
                 )
             elif isinstance(output, McpListTools):
                 items.append(MCPListToolsItem(raw_item=output, agent=agent))
+            elif isinstance(output, McpCall):
+                items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("mcp")
             elif isinstance(output, ImageGenerationCall):
                 items.append(ToolCallItem(raw_item=output, agent=agent))
                 tools_used.append("image_generation")
@@ -539,23 +544,24 @@ class RunImpl:
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
         ) -> Any:
             with function_span(func_tool.name) as span_fn:
+                tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
                 if config.trace_include_sensitive_data:
                     span_fn.span_data.input = tool_call.arguments
                 try:
                     _, _, result = await asyncio.gather(
-                        hooks.on_tool_start(
+                        hooks.on_tool_start(tool_context, agent, func_tool),
                         (
-                            agent.hooks.on_tool_start(
+                            agent.hooks.on_tool_start(tool_context, agent, func_tool)
                             if agent.hooks
                             else _coro.noop_coroutine()
                         ),
-                        func_tool.on_invoke_tool(
+                        func_tool.on_invoke_tool(tool_context, tool_call.arguments),
                     )

                     await asyncio.gather(
-                        hooks.on_tool_end(
+                        hooks.on_tool_end(tool_context, agent, func_tool, result),
                         (
-                            agent.hooks.on_tool_end(
+                            agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
                             if agent.hooks
                             else _coro.noop_coroutine()
                         ),
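The last hunk shows lifecycle tool hooks now receiving a ToolContext built from the run context plus the tool call's call_id, rather than the bare context wrapper. A hedged sketch of a hooks class that logs tool usage, assuming RunHooks keeps the (context, agent, tool[, result]) ordering visible in the diff; ToolContext is expected to be usable wherever a RunContextWrapper was accepted:

from typing import Any

from agents import Agent, RunContextWrapper, RunHooks
from agents.tool import Tool


class LoggingHooks(RunHooks[Any]):
    async def on_tool_start(
        self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Tool
    ) -> None:
        # In 0.0.18 `context` is a ToolContext for function tools.
        print(f"[{agent.name}] starting tool {getattr(tool, 'name', tool)}")

    async def on_tool_end(
        self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Tool, result: str
    ) -> None:
        print(f"[{agent.name}] finished tool {getattr(tool, 'name', tool)}")
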
agents/agent.py
CHANGED
@@ -1,11 +1,13 @@
 from __future__ import annotations

+import asyncio
 import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast

+from openai.types.responses.response_prompt_param import ResponsePromptParam
 from typing_extensions import NotRequired, TypeAlias, TypedDict

 from .agent_output import AgentOutputSchemaBase
@@ -16,8 +18,9 @@ from .logger import logger
 from .mcp import MCPUtil
 from .model_settings import ModelSettings
 from .models.interface import Model
+from .prompts import DynamicPromptFunction, Prompt, PromptUtil
 from .run_context import RunContextWrapper, TContext
-from .tool import FunctionToolResult, Tool, function_tool
+from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
 from .util import _transforms
 from .util._types import MaybeAwaitable

@@ -94,6 +97,12 @@ class Agent(Generic[TContext]):
     return a string.
     """

+    prompt: Prompt | DynamicPromptFunction | None = None
+    """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
+    configure the instructions, tools and other config for an agent outside of your code. Only
+    usable with OpenAI models, using the Responses API.
+    """
+
     handoff_description: str | None = None
     """A description of the agent. This is used when the agent is used as a handoff, so that an
     LLM knows what it does and when to invoke it.
@@ -241,12 +250,33 @@ class Agent(Generic[TContext]):

         return None

+    async def get_prompt(
+        self, run_context: RunContextWrapper[TContext]
+    ) -> ResponsePromptParam | None:
+        """Get the prompt for the agent."""
+        return await PromptUtil.to_model_input(self.prompt, run_context, self)
+
     async def get_mcp_tools(self) -> list[Tool]:
         """Fetches the available tools from the MCP servers."""
         convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
         return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)

-    async def get_all_tools(self) -> list[Tool]:
+    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
         """All agent tools, including MCP tools and function tools."""
         mcp_tools = await self.get_mcp_tools()
-
+
+        async def _check_tool_enabled(tool: Tool) -> bool:
+            if not isinstance(tool, FunctionTool):
+                return True
+
+            attr = tool.is_enabled
+            if isinstance(attr, bool):
+                return attr
+            res = attr(run_context, self)
+            if inspect.isawaitable(res):
+                return bool(await res)
+            return bool(res)
+
+        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
+        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
+        return [*mcp_tools, *enabled]
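get_all_tools() now takes the run context and filters FunctionTools through their is_enabled attribute, which may be a plain bool or a (run_context, agent) callable returning a bool or an awaitable. A hedged sketch, assuming the function_tool decorator in 0.0.18 (see the agents/tool.py changes) accepts an is_enabled argument that ends up on the resulting FunctionTool:

from agents import Agent, RunContextWrapper, function_tool


def beta_only(ctx: RunContextWrapper[dict], agent: Agent[dict]) -> bool:
    # Enable the tool only when the user-supplied context carries a beta flag.
    return bool(ctx.context and ctx.context.get("beta"))


@function_tool(is_enabled=beta_only)
def summarize(text: str) -> str:
    """Summarize the given text."""
    return text[:100]


@function_tool(is_enabled=False)
def legacy_search(query: str) -> str:
    """Always filtered out by get_all_tools() in this sketch."""
    return query


agent = Agent(name="Helper", tools=[summarize, legacy_search])
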
agents/agent_output.py
CHANGED
@@ -38,7 +38,7 @@ class AgentOutputSchemaBase(abc.ABC):
     @abc.abstractmethod
     def is_strict_json_schema(self) -> bool:
         """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
-        features, but guarantees
+        features, but guarantees valid JSON. See here for details:
         https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
         """
         pass
agents/exceptions.py
CHANGED
@@ -1,12 +1,42 @@
-from
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
+    from .agent import Agent
     from .guardrail import InputGuardrailResult, OutputGuardrailResult
+    from .items import ModelResponse, RunItem, TResponseInputItem
+    from .run_context import RunContextWrapper
+
+from .util._pretty_print import pretty_print_run_error_details
+
+
+@dataclass
+class RunErrorDetails:
+    """Data collected from an agent run when an exception occurs."""
+
+    input: str | list[TResponseInputItem]
+    new_items: list[RunItem]
+    raw_responses: list[ModelResponse]
+    last_agent: Agent[Any]
+    context_wrapper: RunContextWrapper[Any]
+    input_guardrail_results: list[InputGuardrailResult]
+    output_guardrail_results: list[OutputGuardrailResult]
+
+    def __str__(self) -> str:
+        return pretty_print_run_error_details(self)


 class AgentsException(Exception):
     """Base class for all exceptions in the Agents SDK."""

+    run_data: RunErrorDetails | None
+
+    def __init__(self, *args: object) -> None:
+        super().__init__(*args)
+        self.run_data = None
+

 class MaxTurnsExceeded(AgentsException):
     """Exception raised when the maximum number of turns is exceeded."""
@@ -15,6 +45,7 @@ class MaxTurnsExceeded(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class ModelBehaviorError(AgentsException):
@@ -26,6 +57,7 @@ class ModelBehaviorError(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class UserError(AgentsException):
@@ -35,15 +67,16 @@ class UserError(AgentsException):

     def __init__(self, message: str):
         self.message = message
+        super().__init__(message)


 class InputGuardrailTripwireTriggered(AgentsException):
     """Exception raised when a guardrail tripwire is triggered."""

-    guardrail_result:
+    guardrail_result: InputGuardrailResult
     """The result data of the guardrail that was triggered."""

-    def __init__(self, guardrail_result:
+    def __init__(self, guardrail_result: InputGuardrailResult):
         self.guardrail_result = guardrail_result
         super().__init__(
             f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
@@ -53,10 +86,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
 class OutputGuardrailTripwireTriggered(AgentsException):
     """Exception raised when a guardrail tripwire is triggered."""

-    guardrail_result:
+    guardrail_result: OutputGuardrailResult
     """The result data of the guardrail that was triggered."""

-    def __init__(self, guardrail_result:
+    def __init__(self, guardrail_result: OutputGuardrailResult):
         self.guardrail_result = guardrail_result
         super().__init__(
             f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
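With run_data now living on AgentsException, callers can inspect the partial run when an exception escapes the runner. A hedged sketch (it assumes the runner attaches RunErrorDetails before re-raising, which the agents/run.py changes in this release appear to add):

from agents import Agent, MaxTurnsExceeded, Runner


async def run_with_diagnostics(agent: Agent, prompt: str) -> None:
    try:
        await Runner.run(agent, prompt, max_turns=2)
    except MaxTurnsExceeded as exc:
        # run_data stays None unless the runner populated it.
        if exc.run_data is not None:
            print("last agent:", exc.run_data.last_agent.name)
            print("items produced so far:", len(exc.run_data.new_items))
            print(exc.run_data)  # __str__ delegates to pretty_print_run_error_details
        raise
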
agents/extensions/models/litellm_model.py
CHANGED

@@ -5,7 +5,6 @@ import time
 from collections.abc import AsyncIterator
 from typing import Any, Literal, cast, overload

-import litellm.types
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from agents.exceptions import ModelBehaviorError
@@ -72,6 +71,7 @@ class LitellmModel(Model):
         handoffs: list[Handoff],
         tracing: ModelTracing,
         previous_response_id: str | None,
+        prompt: Any | None = None,
     ) -> ModelResponse:
         with generation_span(
             model=str(self.model),
@@ -89,6 +89,7 @@ class LitellmModel(Model):
                 span_generation,
                 tracing,
                 stream=False,
+                prompt=prompt,
             )

             assert isinstance(response.choices[0], litellm.types.utils.Choices)
@@ -112,11 +113,13 @@ class LitellmModel(Model):
                         cached_tokens=getattr(
                             response_usage.prompt_tokens_details, "cached_tokens", 0
                         )
+                        or 0
                     ),
                     output_tokens_details=OutputTokensDetails(
                         reasoning_tokens=getattr(
                             response_usage.completion_tokens_details, "reasoning_tokens", 0
                         )
+                        or 0
                     ),
                 )
                 if response.usage
@@ -152,8 +155,8 @@ class LitellmModel(Model):
         output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
         previous_response_id: str | None,
+        prompt: Any | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         with generation_span(
             model=str(self.model),
@@ -171,6 +174,7 @@ class LitellmModel(Model):
                 span_generation,
                 tracing,
                 stream=True,
+                prompt=prompt,
             )

             final_response: Response | None = None
@@ -201,6 +205,7 @@ class LitellmModel(Model):
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[True],
+        prompt: Any | None = None,
    ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...

    @overload
@@ -215,6 +220,7 @@ class LitellmModel(Model):
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: Literal[False],
+        prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse: ...

     async def _fetch_response(
@@ -228,6 +234,7 @@ class LitellmModel(Model):
         span: Span[GenerationSpanData],
         tracing: ModelTracing,
         stream: bool = False,
+        prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
         converted_messages = Converter.items_to_messages(input)

@@ -283,6 +290,10 @@ class LitellmModel(Model):
         if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
             extra_kwargs.update(model_settings.extra_body)

+        # Add kwargs from model_settings.extra_args, filtering out None values
+        if model_settings.extra_args:
+            extra_kwargs.update(model_settings.extra_args)
+
         ret = await litellm.acompletion(
             model=self.model,
             messages=converted_messages,
agents/extensions/visualization.py
CHANGED

@@ -1,4 +1,4 @@
-from
+from __future__ import annotations

 import graphviz  # type: ignore

@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
     return "".join(parts)


-def get_all_nodes(
+def get_all_nodes(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
     """
     Recursively generates the nodes for the given agent and its handoffs in DOT format.

@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
     Returns:
         str: The DOT format string representing the nodes.
     """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
     parts = []

     # Start and end the graph
-    parts.append(
-        '"__start__" [label="__start__", shape=ellipse, style=filled, '
-        "fillcolor=lightblue, width=0.5, height=0.3];"
-        '"__end__" [label="__end__", shape=ellipse, style=filled, '
-        "fillcolor=lightblue, width=0.5, height=0.3];"
-    )
-    # Ensure parent agent node is colored
     if not parent:
+        parts.append(
+            '"__start__" [label="__start__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+            '"__end__" [label="__end__", shape=ellipse, style=filled, '
+            "fillcolor=lightblue, width=0.5, height=0.3];"
+        )
+        # Ensure parent agent node is colored
         parts.append(
             f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
             "fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
             f"fillcolor=lightyellow, width=1.5, height=0.8];"
         )
         if isinstance(handoff, Agent):
-
-
-
-
-
-
+            if handoff.name not in visited:
+                parts.append(
+                    f'"{handoff.name}" [label="{handoff.name}", '
+                    f"shape=box, style=filled, style=rounded, "
+                    f"fillcolor=lightyellow, width=1.5, height=0.8];"
+                )
+            parts.append(get_all_nodes(handoff, agent, visited))

     return "".join(parts)


-def get_all_edges(
+def get_all_edges(
+    agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
+) -> str:
     """
     Recursively generates the edges for the given agent and its handoffs in DOT format.

@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
     Returns:
         str: The DOT format string representing the edges.
     """
+    if visited is None:
+        visited = set()
+    if agent.name in visited:
+        return ""
+    visited.add(agent.name)
+
     parts = []

     if not parent:
@@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
         if isinstance(handoff, Agent):
             parts.append(f"""
     "{agent.name}" -> "{handoff.name}";""")
-            parts.append(get_all_edges(handoff, agent))
+            parts.append(get_all_edges(handoff, agent, visited))

     if not agent.handoffs and not isinstance(agent, Tool):  # type: ignore
         parts.append(f'"{agent.name}" -> "__end__";')
@@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
     return "".join(parts)


-def draw_graph(agent: Agent, filename:
+def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
     """
     Draws the graph for the given agent and optionally saves it as a PNG file.
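The visited set threaded through get_all_nodes and get_all_edges prevents infinite recursion when agents hand off to one another. A small sketch exercising that case (agent names are illustrative):

from agents import Agent
from agents.extensions.visualization import draw_graph

triage = Agent(name="Triage", instructions="Route the request.")
billing = Agent(name="Billing", instructions="Handle billing questions.")

# A handoff cycle previously recursed without bound; the visited set now breaks it.
triage.handoffs = [billing]
billing.handoffs = [triage]

source = draw_graph(triage)  # graphviz.Source with the DOT text
print(source.source)         # pass filename="graph" to also render an image
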
agents/function_schema.py
CHANGED
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field, create_model
 from .exceptions import UserError
 from .run_context import RunContextWrapper
 from .strict_schema import ensure_strict_json_schema
+from .tool_context import ToolContext


 @dataclass
@@ -222,7 +223,8 @@ def function_schema(
         doc_info = None
         param_descs = {}

-
+    # Ensure name_override takes precedence even if docstring info is disabled.
+    func_name = name_override or (doc_info.name if doc_info else func.__name__)

     # 2. Inspect function signature and get type hints
     sig = inspect.signature(func)
@@ -237,21 +239,21 @@ def function_schema(
         ann = type_hints.get(first_name, first_param.annotation)
         if ann != inspect._empty:
             origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
+            if origin is RunContextWrapper or origin is ToolContext:
                 takes_context = True  # Mark that the function takes context
             else:
                 filtered_params.append((first_name, first_param))
         else:
             filtered_params.append((first_name, first_param))

-    # For parameters other than the first, raise error if any use RunContextWrapper.
+    # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
     for name, param in params[1:]:
         ann = type_hints.get(name, param.annotation)
         if ann != inspect._empty:
             origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
+            if origin is RunContextWrapper or origin is ToolContext:
                 raise UserError(
-                    f"RunContextWrapper param found at non-first position in function"
+                    f"RunContextWrapper/ToolContext param found at non-first position in function"
                     f" {func.__name__}"
                 )
         filtered_params.append((name, param))
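function_schema() now treats a first parameter annotated as ToolContext like RunContextWrapper: it is excluded from the generated JSON schema and injected at invocation time. A hedged sketch; the tool_call_id attribute is an assumption based on ToolContext.from_agent_context(context_wrapper, tool_call.call_id) in the _run_impl.py hunk above:

from agents import function_tool
from agents.tool_context import ToolContext


@function_tool
def echo_with_call_id(ctx: ToolContext, message: str) -> str:
    """Echo the message along with the id of the tool call that invoked it."""
    # The exact attribute name lives in the new agents/tool_context.py; this is a guess.
    call_id = getattr(ctx, "tool_call_id", "unknown")
    return f"{message} (call id: {call_id})"
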
agents/handoffs.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import inspect
+import json
 from collections.abc import Awaitable
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
@@ -99,8 +100,7 @@ class Handoff(Generic[TContext]):
     """

     def get_transfer_message(self, agent: Agent[Any]) -> str:
-
-        return base
+        return json.dumps({"assistant": agent.name})

     @classmethod
     def default_tool_name(cls, agent: Agent[Any]) -> str:
@@ -168,7 +168,7 @@ def handoff(
         input_filter: a function that filters the inputs that are passed to the next agent.
     """
     assert (on_handoff and input_type) or not (on_handoff and input_type), (
-        "You must provide either both
+        "You must provide either both on_handoff and input_type, or neither"
     )
     type_adapter: TypeAdapter[Any] | None
     if input_type is not None:
agents/mcp/server.py
CHANGED
@@ -88,7 +88,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -243,7 +243,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -314,7 +314,7 @@ class MCPServerSse(_MCPServerWithClientSession):
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
@@ -340,10 +340,10 @@ class MCPServerStreamableHttpParams(TypedDict):
     headers: NotRequired[dict[str, str]]
     """The headers to send to the server."""

-    timeout: NotRequired[timedelta]
+    timeout: NotRequired[timedelta | float]
     """The timeout for the HTTP request. Defaults to 5 seconds."""

-    sse_read_timeout: NotRequired[timedelta]
+    sse_read_timeout: NotRequired[timedelta | float]
     """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""

     terminate_on_close: NotRequired[bool]
@@ -394,16 +394,16 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
         tuple[
             MemoryObjectReceiveStream[SessionMessage | Exception],
             MemoryObjectSendStream[SessionMessage],
-            GetSessionIdCallback | None
+            GetSessionIdCallback | None,
         ]
     ]:
         """Create the streams for the server."""
         return streamablehttp_client(
             url=self.params["url"],
             headers=self.params.get("headers", None),
-            timeout=self.params.get("timeout",
-            sse_read_timeout=self.params.get("sse_read_timeout",
-            terminate_on_close=self.params.get("terminate_on_close", True)
+            timeout=self.params.get("timeout", 5),
+            sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
+            terminate_on_close=self.params.get("terminate_on_close", True),
         )

     @property
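MCPServerStreamableHttpParams now accepts plain floats (in seconds) as well as timedelta for timeout and sse_read_timeout, matching the numeric defaults (5 and 60 * 5) passed to streamablehttp_client above. An illustrative configuration (the URL is a placeholder):

from agents.mcp import MCPServerStreamableHttp

server = MCPServerStreamableHttp(
    params={
        "url": "http://localhost:8000/mcp",  # placeholder endpoint
        "timeout": 10,                        # seconds; a timedelta still works too
        "sse_read_timeout": 120,
    },
)

# Typical usage (sketch): connect, then hand the server to an Agent.
# async with server:
#     agent = Agent(name="MCP agent", mcp_servers=[server])
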
agents/mcp/util.py
CHANGED
@@ -116,7 +116,7 @@ class MCPUtil:
         if len(result.content) == 1:
             tool_output = result.content[0].model_dump_json()
         elif len(result.content) > 1:
-            tool_output = json.dumps([item.model_dump() for item in result.content])
+            tool_output = json.dumps([item.model_dump(mode="json") for item in result.content])
         else:
             logger.error(f"Errored MCP tool result: {result}")
             tool_output = "Error running tool."
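mode="json" makes pydantic coerce every field to a JSON-compatible value before json.dumps runs, which matters when MCP content items carry values such as datetimes or bytes that plain model_dump() would return as raw Python objects. A standalone illustration of the difference:

import json
from datetime import datetime

from pydantic import BaseModel


class Item(BaseModel):
    text: str
    created_at: datetime


item = Item(text="hi", created_at=datetime(2024, 1, 1))

print(json.dumps([item.model_dump(mode="json")]))  # datetime rendered as an ISO string
try:
    json.dumps([item.model_dump()])                # TypeError: datetime is not JSON serializable
except TypeError as err:
    print(err)
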
agents/model_settings.py
CHANGED
@@ -73,6 +73,11 @@ class ModelSettings:
     """Additional headers to provide with the request.
     Defaults to None if not provided."""

+    extra_args: dict[str, Any] | None = None
+    """Arbitrary keyword arguments to pass to the model API call.
+    These will be passed directly to the underlying model provider's API.
+    Use with caution as not all models support all parameters."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
         override on top of this instance."""
@@ -84,6 +89,16 @@ class ModelSettings:
             for field in fields(self)
             if getattr(override, field.name) is not None
         }
+
+        # Handle extra_args merging specially - merge dictionaries instead of replacing
+        if self.extra_args is not None or override.extra_args is not None:
+            merged_args = {}
+            if self.extra_args:
+                merged_args.update(self.extra_args)
+            if override.extra_args:
+                merged_args.update(override.extra_args)
+            changes["extra_args"] = merged_args if merged_args else None
+
         return replace(self, **changes)

     def to_json_dict(self) -> dict[str, Any]:
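extra_args gives ModelSettings a pass-through channel for provider-specific keyword arguments, and resolve() merges the two dictionaries instead of letting the override replace the base wholesale. A short sketch of that merge behaviour (the keys are illustrative, not ones the SDK defines):

from agents import ModelSettings

base = ModelSettings(temperature=0.3, extra_args={"service_tier": "default", "user": "alice"})
override = ModelSettings(extra_args={"user": "bob"})

resolved = base.resolve(override)
assert resolved.temperature == 0.3  # non-None base values survive
assert resolved.extra_args == {"service_tier": "default", "user": "bob"}  # override wins per key
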
agents/models/interface.py
CHANGED
@@ -5,6 +5,8 @@ import enum
 from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING

+from openai.types.responses.response_prompt_param import ResponsePromptParam
+
 from ..agent_output import AgentOutputSchemaBase
 from ..handoffs import Handoff
 from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
@@ -46,6 +48,7 @@ class Model(abc.ABC):
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        prompt: ResponsePromptParam | None,
     ) -> ModelResponse:
         """Get a response from the model.

@@ -59,6 +62,7 @@ class Model(abc.ABC):
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the model,
                 except for the OpenAI Responses API.
+            prompt: The prompt config to use for the model.

         Returns:
             The full model response.
@@ -77,6 +81,7 @@ class Model(abc.ABC):
         tracing: ModelTracing,
         *,
         previous_response_id: str | None,
+        prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """Stream a response from the model.

@@ -90,6 +95,7 @@ class Model(abc.ABC):
             tracing: Tracing configuration.
             previous_response_id: the ID of the previous response. Generally not used by the model,
                 except for the OpenAI Responses API.
+            prompt: The prompt config to use for the model.

         Returns:
             An iterator of response stream events, in OpenAI Responses format.