openai-agents 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

agents/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .exceptions import (
14
14
  MaxTurnsExceeded,
15
15
  ModelBehaviorError,
16
16
  OutputGuardrailTripwireTriggered,
17
+ RunErrorDetails,
17
18
  UserError,
18
19
  )
19
20
  from .guardrail import (
@@ -44,6 +45,8 @@ from .models.interface import Model, ModelProvider, ModelTracing
44
45
  from .models.openai_chatcompletions import OpenAIChatCompletionsModel
45
46
  from .models.openai_provider import OpenAIProvider
46
47
  from .models.openai_responses import OpenAIResponsesModel
48
+ from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
49
+ from .repl import run_demo_loop
47
50
  from .result import RunResult, RunResultStreaming
48
51
  from .run import RunConfig, Runner
49
52
  from .run_context import RunContextWrapper, TContext
@@ -159,6 +162,7 @@ __all__ = [
159
162
  "ToolsToFinalOutputFunction",
160
163
  "ToolsToFinalOutputResult",
161
164
  "Runner",
165
+ "run_demo_loop",
162
166
  "Model",
163
167
  "ModelProvider",
164
168
  "ModelTracing",
@@ -175,6 +179,9 @@ __all__ = [
175
179
  "AgentsException",
176
180
  "InputGuardrailTripwireTriggered",
177
181
  "OutputGuardrailTripwireTriggered",
182
+ "DynamicPromptFunction",
183
+ "GenerateDynamicPromptData",
184
+ "Prompt",
178
185
  "MaxTurnsExceeded",
179
186
  "ModelBehaviorError",
180
187
  "UserError",
@@ -204,6 +211,7 @@ __all__ = [
204
211
  "AgentHooks",
205
212
  "RunContextWrapper",
206
213
  "TContext",
214
+ "RunErrorDetails",
207
215
  "RunResult",
208
216
  "RunResultStreaming",
209
217
  "RunConfig",
agents/_run_impl.py CHANGED
@@ -33,6 +33,7 @@ from openai.types.responses.response_output_item import (
33
33
  ImageGenerationCall,
34
34
  LocalShellCall,
35
35
  McpApprovalRequest,
36
+ McpCall,
36
37
  McpListTools,
37
38
  )
38
39
  from openai.types.responses.response_reasoning_item import ResponseReasoningItem
@@ -74,6 +75,7 @@ from .tool import (
74
75
  MCPToolApprovalRequest,
75
76
  Tool,
76
77
  )
78
+ from .tool_context import ToolContext
77
79
  from .tracing import (
78
80
  SpanError,
79
81
  Trace,
@@ -456,6 +458,9 @@ class RunImpl:
456
458
  )
457
459
  elif isinstance(output, McpListTools):
458
460
  items.append(MCPListToolsItem(raw_item=output, agent=agent))
461
+ elif isinstance(output, McpCall):
462
+ items.append(ToolCallItem(raw_item=output, agent=agent))
463
+ tools_used.append("mcp")
459
464
  elif isinstance(output, ImageGenerationCall):
460
465
  items.append(ToolCallItem(raw_item=output, agent=agent))
461
466
  tools_used.append("image_generation")
@@ -539,23 +544,24 @@ class RunImpl:
539
544
  func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
540
545
  ) -> Any:
541
546
  with function_span(func_tool.name) as span_fn:
547
+ tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
542
548
  if config.trace_include_sensitive_data:
543
549
  span_fn.span_data.input = tool_call.arguments
544
550
  try:
545
551
  _, _, result = await asyncio.gather(
546
- hooks.on_tool_start(context_wrapper, agent, func_tool),
552
+ hooks.on_tool_start(tool_context, agent, func_tool),
547
553
  (
548
- agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
554
+ agent.hooks.on_tool_start(tool_context, agent, func_tool)
549
555
  if agent.hooks
550
556
  else _coro.noop_coroutine()
551
557
  ),
552
- func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
558
+ func_tool.on_invoke_tool(tool_context, tool_call.arguments),
553
559
  )
554
560
 
555
561
  await asyncio.gather(
556
- hooks.on_tool_end(context_wrapper, agent, func_tool, result),
562
+ hooks.on_tool_end(tool_context, agent, func_tool, result),
557
563
  (
558
- agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
564
+ agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
559
565
  if agent.hooks
560
566
  else _coro.noop_coroutine()
561
567
  ),
agents/agent.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import dataclasses
4
5
  import inspect
5
6
  from collections.abc import Awaitable
6
7
  from dataclasses import dataclass, field
7
8
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
8
9
 
10
+ from openai.types.responses.response_prompt_param import ResponsePromptParam
9
11
  from typing_extensions import NotRequired, TypeAlias, TypedDict
10
12
 
11
13
  from .agent_output import AgentOutputSchemaBase
@@ -16,8 +18,9 @@ from .logger import logger
16
18
  from .mcp import MCPUtil
17
19
  from .model_settings import ModelSettings
18
20
  from .models.interface import Model
21
+ from .prompts import DynamicPromptFunction, Prompt, PromptUtil
19
22
  from .run_context import RunContextWrapper, TContext
20
- from .tool import FunctionToolResult, Tool, function_tool
23
+ from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
21
24
  from .util import _transforms
22
25
  from .util._types import MaybeAwaitable
23
26
 
@@ -94,6 +97,12 @@ class Agent(Generic[TContext]):
94
97
  return a string.
95
98
  """
96
99
 
100
+ prompt: Prompt | DynamicPromptFunction | None = None
101
+ """A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
102
+ configure the instructions, tools and other config for an agent outside of your code. Only
103
+ usable with OpenAI models, using the Responses API.
104
+ """
105
+
97
106
  handoff_description: str | None = None
98
107
  """A description of the agent. This is used when the agent is used as a handoff, so that an
99
108
  LLM knows what it does and when to invoke it.
@@ -241,12 +250,33 @@ class Agent(Generic[TContext]):
241
250
 
242
251
  return None
243
252
 
253
+ async def get_prompt(
254
+ self, run_context: RunContextWrapper[TContext]
255
+ ) -> ResponsePromptParam | None:
256
+ """Get the prompt for the agent."""
257
+ return await PromptUtil.to_model_input(self.prompt, run_context, self)
258
+
244
259
  async def get_mcp_tools(self) -> list[Tool]:
245
260
  """Fetches the available tools from the MCP servers."""
246
261
  convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
247
262
  return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
248
263
 
249
- async def get_all_tools(self) -> list[Tool]:
264
+ async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
250
265
  """All agent tools, including MCP tools and function tools."""
251
266
  mcp_tools = await self.get_mcp_tools()
252
- return mcp_tools + self.tools
267
+
268
+ async def _check_tool_enabled(tool: Tool) -> bool:
269
+ if not isinstance(tool, FunctionTool):
270
+ return True
271
+
272
+ attr = tool.is_enabled
273
+ if isinstance(attr, bool):
274
+ return attr
275
+ res = attr(run_context, self)
276
+ if inspect.isawaitable(res):
277
+ return bool(await res)
278
+ return bool(res)
279
+
280
+ results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
281
+ enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
282
+ return [*mcp_tools, *enabled]
agents/agent_output.py CHANGED
@@ -38,7 +38,7 @@ class AgentOutputSchemaBase(abc.ABC):
38
38
  @abc.abstractmethod
39
39
  def is_strict_json_schema(self) -> bool:
40
40
  """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41
- features, but guarantees valis JSON. See here for details:
41
+ features, but guarantees valid JSON. See here for details:
42
42
  https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43
43
  """
44
44
  pass
agents/exceptions.py CHANGED
@@ -1,12 +1,42 @@
1
- from typing import TYPE_CHECKING
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any
2
5
 
3
6
  if TYPE_CHECKING:
7
+ from .agent import Agent
4
8
  from .guardrail import InputGuardrailResult, OutputGuardrailResult
9
+ from .items import ModelResponse, RunItem, TResponseInputItem
10
+ from .run_context import RunContextWrapper
11
+
12
+ from .util._pretty_print import pretty_print_run_error_details
13
+
14
+
15
+ @dataclass
16
+ class RunErrorDetails:
17
+ """Data collected from an agent run when an exception occurs."""
18
+
19
+ input: str | list[TResponseInputItem]
20
+ new_items: list[RunItem]
21
+ raw_responses: list[ModelResponse]
22
+ last_agent: Agent[Any]
23
+ context_wrapper: RunContextWrapper[Any]
24
+ input_guardrail_results: list[InputGuardrailResult]
25
+ output_guardrail_results: list[OutputGuardrailResult]
26
+
27
+ def __str__(self) -> str:
28
+ return pretty_print_run_error_details(self)
5
29
 
6
30
 
7
31
  class AgentsException(Exception):
8
32
  """Base class for all exceptions in the Agents SDK."""
9
33
 
34
+ run_data: RunErrorDetails | None
35
+
36
+ def __init__(self, *args: object) -> None:
37
+ super().__init__(*args)
38
+ self.run_data = None
39
+
10
40
 
11
41
  class MaxTurnsExceeded(AgentsException):
12
42
  """Exception raised when the maximum number of turns is exceeded."""
@@ -15,6 +45,7 @@ class MaxTurnsExceeded(AgentsException):
15
45
 
16
46
  def __init__(self, message: str):
17
47
  self.message = message
48
+ super().__init__(message)
18
49
 
19
50
 
20
51
  class ModelBehaviorError(AgentsException):
@@ -26,6 +57,7 @@ class ModelBehaviorError(AgentsException):
26
57
 
27
58
  def __init__(self, message: str):
28
59
  self.message = message
60
+ super().__init__(message)
29
61
 
30
62
 
31
63
  class UserError(AgentsException):
@@ -35,15 +67,16 @@ class UserError(AgentsException):
35
67
 
36
68
  def __init__(self, message: str):
37
69
  self.message = message
70
+ super().__init__(message)
38
71
 
39
72
 
40
73
  class InputGuardrailTripwireTriggered(AgentsException):
41
74
  """Exception raised when a guardrail tripwire is triggered."""
42
75
 
43
- guardrail_result: "InputGuardrailResult"
76
+ guardrail_result: InputGuardrailResult
44
77
  """The result data of the guardrail that was triggered."""
45
78
 
46
- def __init__(self, guardrail_result: "InputGuardrailResult"):
79
+ def __init__(self, guardrail_result: InputGuardrailResult):
47
80
  self.guardrail_result = guardrail_result
48
81
  super().__init__(
49
82
  f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
@@ -53,10 +86,10 @@ class InputGuardrailTripwireTriggered(AgentsException):
53
86
  class OutputGuardrailTripwireTriggered(AgentsException):
54
87
  """Exception raised when a guardrail tripwire is triggered."""
55
88
 
56
- guardrail_result: "OutputGuardrailResult"
89
+ guardrail_result: OutputGuardrailResult
57
90
  """The result data of the guardrail that was triggered."""
58
91
 
59
- def __init__(self, guardrail_result: "OutputGuardrailResult"):
92
+ def __init__(self, guardrail_result: OutputGuardrailResult):
60
93
  self.guardrail_result = guardrail_result
61
94
  super().__init__(
62
95
  f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
agents/extensions/models/litellm_model.py CHANGED
@@ -5,7 +5,6 @@ import time
5
5
  from collections.abc import AsyncIterator
6
6
  from typing import Any, Literal, cast, overload
7
7
 
8
- import litellm.types
9
8
  from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
10
9
 
11
10
  from agents.exceptions import ModelBehaviorError
@@ -72,6 +71,7 @@ class LitellmModel(Model):
72
71
  handoffs: list[Handoff],
73
72
  tracing: ModelTracing,
74
73
  previous_response_id: str | None,
74
+ prompt: Any | None = None,
75
75
  ) -> ModelResponse:
76
76
  with generation_span(
77
77
  model=str(self.model),
@@ -89,6 +89,7 @@ class LitellmModel(Model):
89
89
  span_generation,
90
90
  tracing,
91
91
  stream=False,
92
+ prompt=prompt,
92
93
  )
93
94
 
94
95
  assert isinstance(response.choices[0], litellm.types.utils.Choices)
@@ -112,11 +113,13 @@ class LitellmModel(Model):
112
113
  cached_tokens=getattr(
113
114
  response_usage.prompt_tokens_details, "cached_tokens", 0
114
115
  )
116
+ or 0
115
117
  ),
116
118
  output_tokens_details=OutputTokensDetails(
117
119
  reasoning_tokens=getattr(
118
120
  response_usage.completion_tokens_details, "reasoning_tokens", 0
119
121
  )
122
+ or 0
120
123
  ),
121
124
  )
122
125
  if response.usage
@@ -152,8 +155,8 @@ class LitellmModel(Model):
152
155
  output_schema: AgentOutputSchemaBase | None,
153
156
  handoffs: list[Handoff],
154
157
  tracing: ModelTracing,
155
- *,
156
158
  previous_response_id: str | None,
159
+ prompt: Any | None = None,
157
160
  ) -> AsyncIterator[TResponseStreamEvent]:
158
161
  with generation_span(
159
162
  model=str(self.model),
@@ -171,6 +174,7 @@ class LitellmModel(Model):
171
174
  span_generation,
172
175
  tracing,
173
176
  stream=True,
177
+ prompt=prompt,
174
178
  )
175
179
 
176
180
  final_response: Response | None = None
@@ -201,6 +205,7 @@ class LitellmModel(Model):
201
205
  span: Span[GenerationSpanData],
202
206
  tracing: ModelTracing,
203
207
  stream: Literal[True],
208
+ prompt: Any | None = None,
204
209
  ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
205
210
 
206
211
  @overload
@@ -215,6 +220,7 @@ class LitellmModel(Model):
215
220
  span: Span[GenerationSpanData],
216
221
  tracing: ModelTracing,
217
222
  stream: Literal[False],
223
+ prompt: Any | None = None,
218
224
  ) -> litellm.types.utils.ModelResponse: ...
219
225
 
220
226
  async def _fetch_response(
@@ -228,6 +234,7 @@ class LitellmModel(Model):
228
234
  span: Span[GenerationSpanData],
229
235
  tracing: ModelTracing,
230
236
  stream: bool = False,
237
+ prompt: Any | None = None,
231
238
  ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
232
239
  converted_messages = Converter.items_to_messages(input)
233
240
 
@@ -283,6 +290,10 @@ class LitellmModel(Model):
283
290
  if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
284
291
  extra_kwargs.update(model_settings.extra_body)
285
292
 
293
+ # Add kwargs from model_settings.extra_args, filtering out None values
294
+ if model_settings.extra_args:
295
+ extra_kwargs.update(model_settings.extra_args)
296
+
286
297
  ret = await litellm.acompletion(
287
298
  model=self.model,
288
299
  messages=converted_messages,
agents/extensions/visualization.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional
1
+ from __future__ import annotations
2
2
 
3
3
  import graphviz # type: ignore
4
4
 
@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
31
31
  return "".join(parts)
32
32
 
33
33
 
34
- def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
34
+ def get_all_nodes(
35
+ agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
36
+ ) -> str:
35
37
  """
36
38
  Recursively generates the nodes for the given agent and its handoffs in DOT format.
37
39
 
@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
41
43
  Returns:
42
44
  str: The DOT format string representing the nodes.
43
45
  """
46
+ if visited is None:
47
+ visited = set()
48
+ if agent.name in visited:
49
+ return ""
50
+ visited.add(agent.name)
51
+
44
52
  parts = []
45
53
 
46
54
  # Start and end the graph
47
- parts.append(
48
- '"__start__" [label="__start__", shape=ellipse, style=filled, '
49
- "fillcolor=lightblue, width=0.5, height=0.3];"
50
- '"__end__" [label="__end__", shape=ellipse, style=filled, '
51
- "fillcolor=lightblue, width=0.5, height=0.3];"
52
- )
53
- # Ensure parent agent node is colored
54
55
  if not parent:
56
+ parts.append(
57
+ '"__start__" [label="__start__", shape=ellipse, style=filled, '
58
+ "fillcolor=lightblue, width=0.5, height=0.3];"
59
+ '"__end__" [label="__end__", shape=ellipse, style=filled, '
60
+ "fillcolor=lightblue, width=0.5, height=0.3];"
61
+ )
62
+ # Ensure parent agent node is colored
55
63
  parts.append(
56
64
  f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
57
65
  "fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
71
79
  f"fillcolor=lightyellow, width=1.5, height=0.8];"
72
80
  )
73
81
  if isinstance(handoff, Agent):
74
- parts.append(
75
- f'"{handoff.name}" [label="{handoff.name}", '
76
- f"shape=box, style=filled, style=rounded, "
77
- f"fillcolor=lightyellow, width=1.5, height=0.8];"
78
- )
79
- parts.append(get_all_nodes(handoff))
82
+ if handoff.name not in visited:
83
+ parts.append(
84
+ f'"{handoff.name}" [label="{handoff.name}", '
85
+ f"shape=box, style=filled, style=rounded, "
86
+ f"fillcolor=lightyellow, width=1.5, height=0.8];"
87
+ )
88
+ parts.append(get_all_nodes(handoff, agent, visited))
80
89
 
81
90
  return "".join(parts)
82
91
 
83
92
 
84
- def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
93
+ def get_all_edges(
94
+ agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
95
+ ) -> str:
85
96
  """
86
97
  Recursively generates the edges for the given agent and its handoffs in DOT format.
87
98
 
@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
92
103
  Returns:
93
104
  str: The DOT format string representing the edges.
94
105
  """
106
+ if visited is None:
107
+ visited = set()
108
+ if agent.name in visited:
109
+ return ""
110
+ visited.add(agent.name)
111
+
95
112
  parts = []
96
113
 
97
114
  if not parent:
@@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
109
126
  if isinstance(handoff, Agent):
110
127
  parts.append(f"""
111
128
  "{agent.name}" -> "{handoff.name}";""")
112
- parts.append(get_all_edges(handoff, agent))
129
+ parts.append(get_all_edges(handoff, agent, visited))
113
130
 
114
131
  if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
115
132
  parts.append(f'"{agent.name}" -> "__end__";')
@@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
117
134
  return "".join(parts)
118
135
 
119
136
 
120
- def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
137
+ def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
121
138
  """
122
139
  Draws the graph for the given agent and optionally saves it as a PNG file.
123
140
 
agents/function_schema.py CHANGED
@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field, create_model
13
13
  from .exceptions import UserError
14
14
  from .run_context import RunContextWrapper
15
15
  from .strict_schema import ensure_strict_json_schema
16
+ from .tool_context import ToolContext
16
17
 
17
18
 
18
19
  @dataclass
@@ -222,7 +223,8 @@ def function_schema(
222
223
  doc_info = None
223
224
  param_descs = {}
224
225
 
225
- func_name = name_override or doc_info.name if doc_info else func.__name__
226
+ # Ensure name_override takes precedence even if docstring info is disabled.
227
+ func_name = name_override or (doc_info.name if doc_info else func.__name__)
226
228
 
227
229
  # 2. Inspect function signature and get type hints
228
230
  sig = inspect.signature(func)
@@ -237,21 +239,21 @@ def function_schema(
237
239
  ann = type_hints.get(first_name, first_param.annotation)
238
240
  if ann != inspect._empty:
239
241
  origin = get_origin(ann) or ann
240
- if origin is RunContextWrapper:
242
+ if origin is RunContextWrapper or origin is ToolContext:
241
243
  takes_context = True # Mark that the function takes context
242
244
  else:
243
245
  filtered_params.append((first_name, first_param))
244
246
  else:
245
247
  filtered_params.append((first_name, first_param))
246
248
 
247
- # For parameters other than the first, raise error if any use RunContextWrapper.
249
+ # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
248
250
  for name, param in params[1:]:
249
251
  ann = type_hints.get(name, param.annotation)
250
252
  if ann != inspect._empty:
251
253
  origin = get_origin(ann) or ann
252
- if origin is RunContextWrapper:
254
+ if origin is RunContextWrapper or origin is ToolContext:
253
255
  raise UserError(
254
- f"RunContextWrapper param found at non-first position in function"
256
+ f"RunContextWrapper/ToolContext param found at non-first position in function"
255
257
  f" {func.__name__}"
256
258
  )
257
259
  filtered_params.append((name, param))
agents/handoffs.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
+ import json
4
5
  from collections.abc import Awaitable
5
6
  from dataclasses import dataclass
6
7
  from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
@@ -99,8 +100,7 @@ class Handoff(Generic[TContext]):
99
100
  """
100
101
 
101
102
  def get_transfer_message(self, agent: Agent[Any]) -> str:
102
- base = f"{{'assistant': '{agent.name}'}}"
103
- return base
103
+ return json.dumps({"assistant": agent.name})
104
104
 
105
105
  @classmethod
106
106
  def default_tool_name(cls, agent: Agent[Any]) -> str:
@@ -168,7 +168,7 @@ def handoff(
168
168
  input_filter: a function that filters the inputs that are passed to the next agent.
169
169
  """
170
170
  assert (on_handoff and input_type) or not (on_handoff and input_type), (
171
- "You must provide either both on_input and input_type, or neither"
171
+ "You must provide either both on_handoff and input_type, or neither"
172
172
  )
173
173
  type_adapter: TypeAdapter[Any] | None
174
174
  if input_type is not None:
agents/mcp/server.py CHANGED
@@ -88,7 +88,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
88
88
  tuple[
89
89
  MemoryObjectReceiveStream[SessionMessage | Exception],
90
90
  MemoryObjectSendStream[SessionMessage],
91
- GetSessionIdCallback | None
91
+ GetSessionIdCallback | None,
92
92
  ]
93
93
  ]:
94
94
  """Create the streams for the server."""
@@ -243,7 +243,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
243
243
  tuple[
244
244
  MemoryObjectReceiveStream[SessionMessage | Exception],
245
245
  MemoryObjectSendStream[SessionMessage],
246
- GetSessionIdCallback | None
246
+ GetSessionIdCallback | None,
247
247
  ]
248
248
  ]:
249
249
  """Create the streams for the server."""
@@ -314,7 +314,7 @@ class MCPServerSse(_MCPServerWithClientSession):
314
314
  tuple[
315
315
  MemoryObjectReceiveStream[SessionMessage | Exception],
316
316
  MemoryObjectSendStream[SessionMessage],
317
- GetSessionIdCallback | None
317
+ GetSessionIdCallback | None,
318
318
  ]
319
319
  ]:
320
320
  """Create the streams for the server."""
@@ -340,10 +340,10 @@ class MCPServerStreamableHttpParams(TypedDict):
340
340
  headers: NotRequired[dict[str, str]]
341
341
  """The headers to send to the server."""
342
342
 
343
- timeout: NotRequired[timedelta]
343
+ timeout: NotRequired[timedelta | float]
344
344
  """The timeout for the HTTP request. Defaults to 5 seconds."""
345
345
 
346
- sse_read_timeout: NotRequired[timedelta]
346
+ sse_read_timeout: NotRequired[timedelta | float]
347
347
  """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
348
348
 
349
349
  terminate_on_close: NotRequired[bool]
@@ -394,16 +394,16 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
394
394
  tuple[
395
395
  MemoryObjectReceiveStream[SessionMessage | Exception],
396
396
  MemoryObjectSendStream[SessionMessage],
397
- GetSessionIdCallback | None
397
+ GetSessionIdCallback | None,
398
398
  ]
399
399
  ]:
400
400
  """Create the streams for the server."""
401
401
  return streamablehttp_client(
402
402
  url=self.params["url"],
403
403
  headers=self.params.get("headers", None),
404
- timeout=self.params.get("timeout", timedelta(seconds=30)),
405
- sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
406
- terminate_on_close=self.params.get("terminate_on_close", True)
404
+ timeout=self.params.get("timeout", 5),
405
+ sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
406
+ terminate_on_close=self.params.get("terminate_on_close", True),
407
407
  )
408
408
 
409
409
  @property
agents/mcp/util.py CHANGED
@@ -116,7 +116,7 @@ class MCPUtil:
116
116
  if len(result.content) == 1:
117
117
  tool_output = result.content[0].model_dump_json()
118
118
  elif len(result.content) > 1:
119
- tool_output = json.dumps([item.model_dump() for item in result.content])
119
+ tool_output = json.dumps([item.model_dump(mode="json") for item in result.content])
120
120
  else:
121
121
  logger.error(f"Errored MCP tool result: {result}")
122
122
  tool_output = "Error running tool."
agents/model_settings.py CHANGED
@@ -73,6 +73,11 @@ class ModelSettings:
73
73
  """Additional headers to provide with the request.
74
74
  Defaults to None if not provided."""
75
75
 
76
+ extra_args: dict[str, Any] | None = None
77
+ """Arbitrary keyword arguments to pass to the model API call.
78
+ These will be passed directly to the underlying model provider's API.
79
+ Use with caution as not all models support all parameters."""
80
+
76
81
  def resolve(self, override: ModelSettings | None) -> ModelSettings:
77
82
  """Produce a new ModelSettings by overlaying any non-None values from the
78
83
  override on top of this instance."""
@@ -84,6 +89,16 @@ class ModelSettings:
84
89
  for field in fields(self)
85
90
  if getattr(override, field.name) is not None
86
91
  }
92
+
93
+ # Handle extra_args merging specially - merge dictionaries instead of replacing
94
+ if self.extra_args is not None or override.extra_args is not None:
95
+ merged_args = {}
96
+ if self.extra_args:
97
+ merged_args.update(self.extra_args)
98
+ if override.extra_args:
99
+ merged_args.update(override.extra_args)
100
+ changes["extra_args"] = merged_args if merged_args else None
101
+
87
102
  return replace(self, **changes)
88
103
 
89
104
  def to_json_dict(self) -> dict[str, Any]:
agents/models/interface.py CHANGED
@@ -5,6 +5,8 @@ import enum
5
5
  from collections.abc import AsyncIterator
6
6
  from typing import TYPE_CHECKING
7
7
 
8
+ from openai.types.responses.response_prompt_param import ResponsePromptParam
9
+
8
10
  from ..agent_output import AgentOutputSchemaBase
9
11
  from ..handoffs import Handoff
10
12
  from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
@@ -46,6 +48,7 @@ class Model(abc.ABC):
46
48
  tracing: ModelTracing,
47
49
  *,
48
50
  previous_response_id: str | None,
51
+ prompt: ResponsePromptParam | None,
49
52
  ) -> ModelResponse:
50
53
  """Get a response from the model.
51
54
 
@@ -59,6 +62,7 @@ class Model(abc.ABC):
59
62
  tracing: Tracing configuration.
60
63
  previous_response_id: the ID of the previous response. Generally not used by the model,
61
64
  except for the OpenAI Responses API.
65
+ prompt: The prompt config to use for the model.
62
66
 
63
67
  Returns:
64
68
  The full model response.
@@ -77,6 +81,7 @@ class Model(abc.ABC):
77
81
  tracing: ModelTracing,
78
82
  *,
79
83
  previous_response_id: str | None,
84
+ prompt: ResponsePromptParam | None,
80
85
  ) -> AsyncIterator[TResponseStreamEvent]:
81
86
  """Stream a response from the model.
82
87
 
@@ -90,6 +95,7 @@ class Model(abc.ABC):
90
95
  tracing: Tracing configuration.
91
96
  previous_response_id: the ID of the previous response. Generally not used by the model,
92
97
  except for the OpenAI Responses API.
98
+ prompt: The prompt config to use for the model.
93
99
 
94
100
  Returns:
95
101
  An iterator of response stream events, in OpenAI Responses format.