openai-agents 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of openai-agents might be problematic.

agents/agent.py CHANGED
@@ -214,7 +214,7 @@ class Agent(AgentBase, Generic[TContext]):
     calls result in a final output.
 
     NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
-    web search, etc are always processed by the LLM.
+    web search, etc. are always processed by the LLM.
     """
 
     reset_tool_choice: bool = True
@@ -289,30 +289,3 @@ class Agent(AgentBase, Generic[TContext]):
     ) -> ResponsePromptParam | None:
         """Get the prompt for the agent."""
         return await PromptUtil.to_model_input(self.prompt, run_context, self)
-
-    async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
-        """Fetches the available tools from the MCP servers."""
-        convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
-        return await MCPUtil.get_all_function_tools(
-            self.mcp_servers, convert_schemas_to_strict, run_context, self
-        )
-
-    async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
-        """All agent tools, including MCP tools and function tools."""
-        mcp_tools = await self.get_mcp_tools(run_context)
-
-        async def _check_tool_enabled(tool: Tool) -> bool:
-            if not isinstance(tool, FunctionTool):
-                return True
-
-            attr = tool.is_enabled
-            if isinstance(attr, bool):
-                return attr
-            res = attr(run_context, self)
-            if inspect.isawaitable(res):
-                return bool(await res)
-            return bool(res)
-
-        results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
-        enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
-        return [*mcp_tools, *enabled]
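
For context, the `get_all_tools` logic removed above is what honors a function tool's `is_enabled` flag, which may be a plain bool or a sync/async callable taking the run context and the agent. A minimal standalone sketch of that filtering pattern, using generic stand-in types rather than the SDK's own:

    import asyncio
    import inspect
    from typing import Any, Callable, Union

    # `IsEnabled` mirrors FunctionTool.is_enabled: a bool, or a (possibly async)
    # callable taking (run_context, agent). All types here are generic stand-ins.
    IsEnabled = Union[bool, Callable[[Any, Any], Any]]

    async def filter_enabled(tools: list[Any], run_context: Any, agent: Any) -> list[Any]:
        """Keep tools whose `is_enabled` resolves truthy, mirroring the removed logic."""

        async def check(tool: Any) -> bool:
            attr: IsEnabled = getattr(tool, "is_enabled", True)
            if isinstance(attr, bool):
                return attr
            res = attr(run_context, agent)  # may return a value or an awaitable
            return bool(await res) if inspect.isawaitable(res) else bool(res)

        results = await asyncio.gather(*(check(t) for t in tools))
        return [t for t, ok in zip(tools, results) if ok]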
agents/agent_output.py CHANGED
@@ -116,7 +116,7 @@ class AgentOutputSchema(AgentOutputSchemaBase):
             raise UserError(
                 "Strict JSON schema is enabled, but the output type is not valid. "
                 "Either make the output type strict, "
-                "or wrap your type with AgentOutputSchema(your_type, strict_json_schema=False)"
+                "or wrap your type with AgentOutputSchema(YourType, strict_json_schema=False)"
             ) from e
 
     def is_plain_text(self) -> bool:
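
The reworded error message points to the escape hatch for output types that can't satisfy a strict JSON schema. A brief sketch of that wrapping; the `Weather` type here is hypothetical:

    from dataclasses import dataclass

    from agents import Agent, AgentOutputSchema

    @dataclass
    class Weather:  # hypothetical output type
        city: str
        readings: dict[str, float]  # free-form keys are not expressible in strict mode

    agent = Agent(
        name="weather_agent",
        # Opt this output type out of strict schema validation, as the error suggests.
        output_type=AgentOutputSchema(Weather, strict_json_schema=False),
    )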
agents/extensions/models/litellm_model.py CHANGED
@@ -45,6 +45,14 @@ from ...tracing.spans import Span
 from ...usage import Usage
 
 
+class InternalChatCompletionMessage(ChatCompletionMessage):
+    """
+    An internal subclass to carry reasoning_content without modifying the original model.
+    """
+
+    reasoning_content: str
+
+
 class LitellmModel(Model):
     """This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
     Anthropic, Gemini, Mistral, and many other models.
@@ -364,13 +372,18 @@ class LitellmConverter:
             provider_specific_fields.get("refusal", None) if provider_specific_fields else None
         )
 
-        return ChatCompletionMessage(
+        reasoning_content = ""
+        if hasattr(message, "reasoning_content") and message.reasoning_content:
+            reasoning_content = message.reasoning_content
+
+        return InternalChatCompletionMessage(
             content=message.content,
             refusal=refusal,
             role="assistant",
             annotations=cls.convert_annotations_to_openai(message),
             audio=message.get("audio", None),  # litellm deletes audio if not present
             tool_calls=tool_calls,
+            reasoning_content=reasoning_content,
         )
 
     @classmethod
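
The pattern above, a pydantic subclass carrying one extra field, keeps the upstream `ChatCompletionMessage` untouched while still exposing `reasoning_content` to downstream code. A self-contained illustration of the same pattern with toy models:

    from typing import Optional

    from pydantic import BaseModel

    class Message(BaseModel):  # toy stand-in for ChatCompletionMessage
        role: str
        content: Optional[str] = None

    class InternalMessage(Message):
        # Extra provider-specific field; the base model stays untouched.
        reasoning_content: str = ""

    msg = InternalMessage(role="assistant", content="hi", reasoning_content="step 1...")
    # Code that only knows about `Message` keeps working; aware code reads the
    # extra field defensively, much as the converter above does:
    reasoning = getattr(msg, "reasoning_content", "")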
agents/function_schema.py CHANGED
@@ -76,7 +76,7 @@ class FuncSchema:
 
 @dataclass
 class FuncDocumentation:
-    """Contains metadata about a python function, extracted from its docstring."""
+    """Contains metadata about a Python function, extracted from its docstring."""
 
     name: str
     """The name of the function, via `__name__`."""
@@ -194,7 +194,7 @@ def function_schema(
     strict_json_schema: bool = True,
 ) -> FuncSchema:
     """
-    Given a python function, extracts a `FuncSchema` from it, capturing the name, description,
+    Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
     parameter descriptions, and other metadata.
 
     Args:
@@ -208,7 +208,7 @@ def function_schema(
             descriptions.
         strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
             the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
-            recommend setting this to True, as it increases the likelihood of the LLM providing
+            recommend setting this to True, as it increases the likelihood of the LLM producing
            correct JSON input.
 
     Returns:
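
For reference, `function_schema` is the extractor that powers `@function_tool`. A hedged sketch of calling it directly; it assumes the `agents.function_schema` import path and the `FuncSchema` attributes named below:

    from agents.function_schema import function_schema

    def get_weather(city: str, unit: str = "celsius") -> str:
        """Fetch the weather for a city.

        Args:
            city: Name of the city.
            unit: Temperature unit to report.
        """
        return f"21 degrees {unit} in {city}"

    schema = function_schema(get_weather)
    print(schema.name)                # "get_weather"
    print(schema.description)         # "Fetch the weather for a city."
    print(schema.params_json_schema)  # strict JSON schema for (city, unit)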
agents/guardrail.py CHANGED
@@ -78,8 +78,9 @@ class InputGuardrail(Generic[TContext]):
     You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
     create an `InputGuardrail` manually.
 
-    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, the agent
-    execution will immediately stop and a `InputGuardrailTripwireTriggered` exception will be raised
+    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`,
+    the agent's execution will immediately stop, and
+    an `InputGuardrailTripwireTriggered` exception will be raised
     """
 
     guardrail_function: Callable[
@@ -132,7 +133,7 @@ class OutputGuardrail(Generic[TContext]):
     You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
     or create an `OutputGuardrail` manually.
 
-    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, a
+    Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, an
     `OutputGuardrailTripwireTriggered` exception will be raised.
     """
 
@@ -314,7 +315,11 @@ def output_guardrail(
     def decorator(
         f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co],
     ) -> OutputGuardrail[TContext_co]:
-        return OutputGuardrail(guardrail_function=f, name=name)
+        return OutputGuardrail(
+            guardrail_function=f,
+            # Guardrail name defaults to function's name when not specified (None).
+            name=name if name else f.__name__,
+        )
 
     if func is not None:
         # Decorator was used without parentheses
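
The behavioral change in `output_guardrail`: when no explicit `name` is passed, the guardrail now reports the decorated function's own name instead of `None`. A short sketch, assuming the public exports shown:

    from agents import (
        Agent,
        GuardrailFunctionOutput,
        RunContextWrapper,
        output_guardrail,
    )

    @output_guardrail  # used without parentheses and without name=
    def no_profanity(
        ctx: RunContextWrapper, agent: Agent, output: str
    ) -> GuardrailFunctionOutput:
        return GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered="darn" in str(output).lower(),
        )

    # Previously `no_profanity.name` was None here; it now falls back to the
    # function name, which makes tracing and debug output more useful.
    print(no_profanity.name)  # "no_profanity"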
agents/items.py CHANGED
@@ -5,6 +5,7 @@ import copy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union
 
+import pydantic
 from openai.types.responses import (
     Response,
     ResponseComputerToolCall,
@@ -212,7 +213,7 @@ RunItem: TypeAlias = Union[
 """An item generated by an agent."""
 
 
-@dataclass
+@pydantic.dataclasses.dataclass
 class ModelResponse:
     output: list[TResponseOutputItem]
     """A list of outputs (messages, tool calls, etc) generated by the model"""
agents/model_settings.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import dataclasses
 from collections.abc import Mapping
-from dataclasses import dataclass, fields, replace
+from dataclasses import fields, replace
 from typing import Annotated, Any, Literal, Union
 
 from openai import Omit as _Omit
@@ -10,6 +10,7 @@ from openai._types import Body, Query
 from openai.types.responses import ResponseIncludable
 from openai.types.shared import Reasoning
 from pydantic import BaseModel, GetCoreSchemaHandler
+from pydantic.dataclasses import dataclass
 from pydantic_core import core_schema
 from typing_extensions import TypeAlias
agents/models/chatcmpl_converter.py CHANGED
@@ -36,6 +36,7 @@ from openai.types.responses import (
     ResponseOutputRefusal,
     ResponseOutputText,
     ResponseReasoningItem,
+    ResponseReasoningItemParam,
 )
 from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
 from openai.types.responses.response_reasoning_item import Summary
@@ -210,6 +211,12 @@ class Converter:
             return cast(ResponseOutputMessageParam, item)
         return None
 
+    @classmethod
+    def maybe_reasoning_message(cls, item: Any) -> ResponseReasoningItemParam | None:
+        if isinstance(item, dict) and item.get("type") == "reasoning":
+            return cast(ResponseReasoningItemParam, item)
+        return None
+
     @classmethod
     def extract_text_content(
         cls, content: str | Iterable[ResponseInputContentParam]
@@ -459,7 +466,11 @@ class Converter:
                     f"Encountered an item_reference, which is not supported: {item_ref}"
                 )
 
-            # 7) If we haven't recognized it => fail or ignore
+            # 7) reasoning message => not handled
+            elif cls.maybe_reasoning_message(item):
+                pass
+
+            # 8) If we haven't recognized it => fail or ignore
             else:
                 raise UserError(f"Unhandled item type or structure: {item}")
agents/models/chatcmpl_stream_handler.py CHANGED
@@ -198,6 +198,7 @@ class ChatCmplStreamHandler:
                     is not None,  # fixed 0 -> 0 or 1
                     type="response.output_text.delta",
                     sequence_number=sequence_number.get_and_increment(),
+                    logprobs=[],
                 )
                 # Accumulate the text into the response part
                 state.text_content_index_and_output[1].text += delta.content
@@ -288,10 +289,11 @@ class ChatCmplStreamHandler:
                     function_call = state.function_calls[tc_delta.index]
 
                     # Start streaming as soon as we have function name and call_id
-                    if (not state.function_call_streaming[tc_delta.index] and
-                            function_call.name and
-                            function_call.call_id):
-
+                    if (
+                        not state.function_call_streaming[tc_delta.index]
+                        and function_call.name
+                        and function_call.call_id
+                    ):
                         # Calculate the output index for this function call
                         function_call_starting_index = 0
                         if state.reasoning_content_index_and_output:
@@ -308,9 +310,9 @@ class ChatCmplStreamHandler:
 
                         # Mark this function call as streaming and store its output index
                         state.function_call_streaming[tc_delta.index] = True
-                        state.function_call_output_idx[
-                            tc_delta.index
-                        ] = function_call_starting_index
+                        state.function_call_output_idx[tc_delta.index] = (
+                            function_call_starting_index
+                        )
 
                         # Send initial function call added event
                         yield ResponseOutputItemAddedEvent(
@@ -327,10 +329,11 @@ class ChatCmplStreamHandler:
                     )
 
                 # Stream arguments if we've started streaming this function call
-                if (state.function_call_streaming.get(tc_delta.index, False) and
-                        tc_function and
-                        tc_function.arguments):
-
+                if (
+                    state.function_call_streaming.get(tc_delta.index, False)
+                    and tc_function
+                    and tc_function.arguments
+                ):
                    output_index = state.function_call_output_idx[tc_delta.index]
                     yield ResponseFunctionCallArgumentsDeltaEvent(
                         delta=tc_function.arguments,
@@ -493,9 +496,9 @@ class ChatCmplStreamHandler:
         final_response.output = outputs
         final_response.usage = (
             ResponseUsage(
-                input_tokens=usage.prompt_tokens,
-                output_tokens=usage.completion_tokens,
-                total_tokens=usage.total_tokens,
+                input_tokens=usage.prompt_tokens or 0,
+                output_tokens=usage.completion_tokens or 0,
+                total_tokens=usage.total_tokens or 0,
                 output_tokens_details=OutputTokensDetails(
                     reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
                     if usage.completion_tokens_details
agents/realtime/__init__.py CHANGED
@@ -47,6 +47,8 @@ from .model import (
     RealtimeModel,
     RealtimeModelConfig,
     RealtimeModelListener,
+    RealtimePlaybackState,
+    RealtimePlaybackTracker,
 )
 from .model_events import (
     RealtimeConnectionStatus,
@@ -139,6 +141,8 @@ __all__ = [
     "RealtimeModel",
     "RealtimeModelConfig",
     "RealtimeModelListener",
+    "RealtimePlaybackTracker",
+    "RealtimePlaybackState",
     # Model Events
     "RealtimeConnectionStatus",
     "RealtimeModelAudioDoneEvent",
agents/realtime/_default_tracker.py ADDED
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import datetime
+
+from ._util import calculate_audio_length_ms
+from .config import RealtimeAudioFormat
+
+
+@dataclass
+class ModelAudioState:
+    initial_received_time: datetime
+    audio_length_ms: float
+
+
+class ModelAudioTracker:
+    def __init__(self) -> None:
+        # (item_id, item_content_index) -> ModelAudioState
+        self._states: dict[tuple[str, int], ModelAudioState] = {}
+        self._last_audio_item: tuple[str, int] | None = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Called when the model wants to set the audio format."""
+        self._format = format
+
+    def on_audio_delta(self, item_id: str, item_content_index: int, audio_bytes: bytes) -> None:
+        """Called when an audio delta is received from the model."""
+        ms = calculate_audio_length_ms(self._format, audio_bytes)
+        new_key = (item_id, item_content_index)
+
+        self._last_audio_item = new_key
+        if new_key not in self._states:
+            self._states[new_key] = ModelAudioState(datetime.now(), ms)
+        else:
+            self._states[new_key].audio_length_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called when the audio playback has been interrupted."""
+        self._last_audio_item = None
+
+    def get_state(self, item_id: str, item_content_index: int) -> ModelAudioState | None:
+        """Called when the model wants to get the current playback state."""
+        return self._states.get((item_id, item_content_index))
+
+    def get_last_audio_item(self) -> tuple[str, int] | None:
+        """Called when the model wants to get the last audio item ID and content index."""
+        return self._last_audio_item
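
A quick sketch of how this default tracker accumulates audio length across deltas for the same item, using the G.711 branch of the length helper defined in the companion `_util` module shown next. The module path follows the header above, and driver code is illustrative:

    from agents.realtime._default_tracker import ModelAudioTracker

    tracker = ModelAudioTracker()
    tracker.set_audio_format("g711_ulaw")  # 8,000 bytes per second of audio

    # Two one-second deltas for the same (item_id, content_index) accumulate:
    tracker.on_audio_delta("item_1", 0, b"\x00" * 8000)
    tracker.on_audio_delta("item_1", 0, b"\x00" * 8000)

    state = tracker.get_state("item_1", 0)
    assert state is not None and state.audio_length_ms == 2000.0
    assert tracker.get_last_audio_item() == ("item_1", 0)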
agents/realtime/_util.py ADDED
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+from .config import RealtimeAudioFormat
+
+
+def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float:
+    if format and format.startswith("g711"):
+        return (len(audio_bytes) / 8000) * 1000
+    return (len(audio_bytes) / 24 / 2) * 1000
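
Sanity-checking the G.711 branch: G.711 carries 8,000 one-byte samples per second, so bytes divided by 8,000 gives seconds and the factor of 1,000 converts to milliseconds; the PCM branch analogously divides by 24 (kHz) and 2 (bytes per 16-bit sample). A tiny worked example, assuming `"g711_ulaw"` is a valid `RealtimeAudioFormat` value:

    from agents.realtime._util import calculate_audio_length_ms

    # 8,000 bytes of G.711 is exactly one second: (8000 / 8000) * 1000 == 1000.0
    assert calculate_audio_length_ms("g711_ulaw", b"\x00" * 8000) == 1000.0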
agents/realtime/events.py CHANGED
@@ -115,6 +115,12 @@ class RealtimeAudioEnd:
     info: RealtimeEventInfo
     """Common info for all events, such as the context."""
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     type: Literal["audio_end"] = "audio_end"
 
 
@@ -125,6 +131,12 @@ class RealtimeAudio:
     audio: RealtimeModelAudioEvent
     """The audio event from the model layer."""
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     info: RealtimeEventInfo
     """Common info for all events, such as the context."""
 
@@ -140,6 +152,12 @@ class RealtimeAudioInterrupted:
     info: RealtimeEventInfo
     """Common info for all events, such as the context."""
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     type: Literal["audio_interrupted"] = "audio_interrupted"
 
 
agents/realtime/model.py CHANGED
@@ -6,13 +6,95 @@ from typing import Callable
 from typing_extensions import NotRequired, TypedDict
 
 from ..util._types import MaybeAwaitable
+from ._util import calculate_audio_length_ms
 from .config import (
+    RealtimeAudioFormat,
     RealtimeSessionModelSettings,
 )
 from .model_events import RealtimeModelEvent
 from .model_inputs import RealtimeModelSendEvent
 
 
+class RealtimePlaybackState(TypedDict):
+    current_item_id: str | None
+    """The item ID of the current item being played."""
+
+    current_item_content_index: int | None
+    """The index of the current item content being played."""
+
+    elapsed_ms: float | None
+    """The number of milliseconds of audio that have been played."""
+
+
+class RealtimePlaybackTracker:
+    """If you have custom playback logic or expect that audio is played with delays or at different
+    speeds, create an instance of RealtimePlaybackTracker and pass it to the session. You are
+    responsible for tracking the audio playback progress and calling `on_play_bytes` or
+    `on_play_ms` when the user has played some audio."""
+
+    def __init__(self) -> None:
+        self._format: RealtimeAudioFormat | None = None
+        # (item_id, item_content_index)
+        self._current_item: tuple[str, int] | None = None
+        self._elapsed_ms: float | None = None
+
+    def on_play_bytes(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            bytes: The audio bytes that have been fully played.
+        """
+        ms = calculate_audio_length_ms(self._format, bytes)
+        self.on_play_ms(item_id, item_content_index, ms)
+
+    def on_play_ms(self, item_id: str, item_content_index: int, ms: float) -> None:
+        """Called by you when you have played some audio.
+
+        Args:
+            item_id: The item ID of the audio being played.
+            item_content_index: The index of the audio content in `item.content`
+            ms: The number of milliseconds of audio that have been played.
+        """
+        if self._current_item != (item_id, item_content_index):
+            self._current_item = (item_id, item_content_index)
+            self._elapsed_ms = ms
+        else:
+            assert self._elapsed_ms is not None
+            self._elapsed_ms += ms
+
+    def on_interrupted(self) -> None:
+        """Called by the model when the audio playback has been interrupted."""
+        self._current_item = None
+        self._elapsed_ms = None
+
+    def set_audio_format(self, format: RealtimeAudioFormat) -> None:
+        """Will be called by the model to set the audio format.
+
+        Args:
+            format: The audio format to use.
+        """
+        self._format = format
+
+    def get_state(self) -> RealtimePlaybackState:
+        """Will be called by the model to get the current playback state."""
+        if self._current_item is None:
+            return {
+                "current_item_id": None,
+                "current_item_content_index": None,
+                "elapsed_ms": None,
+            }
+        assert self._elapsed_ms is not None
+
+        item_id, item_content_index = self._current_item
+        return {
+            "current_item_id": item_id,
+            "current_item_content_index": item_content_index,
+            "elapsed_ms": self._elapsed_ms,
+        }
+
+
 class RealtimeModelListener(abc.ABC):
     """A listener for realtime transport events."""
 
@@ -39,6 +121,18 @@ class RealtimeModelConfig(TypedDict):
     initial_model_settings: NotRequired[RealtimeSessionModelSettings]
     """The initial model settings to use when connecting."""
 
+    playback_tracker: NotRequired[RealtimePlaybackTracker]
+    """The playback tracker to use when tracking audio playback progress. If not set, the model will
+    use a default implementation that assumes audio is played immediately, at realtime speed.
+
+    A playback tracker is useful for interruptions. The model generates audio much faster than
+    realtime playback speed. So if there's an interruption, its useful for the model to know how
+    much of the audio has been played by the user. In low-latency scenarios, it's fine to assume
+    that audio is played back immediately at realtime speed. But in scenarios like phone calls or
+    other remote interactions, you can set a playback tracker that lets the model know when audio
+    is played to the user.
+    """
+
 
 class RealtimeModel(abc.ABC):
     """Interface for connecting to a realtime model and sending/receiving events."""
agents/realtime/model_events.py CHANGED
@@ -40,6 +40,12 @@ class RealtimeModelAudioEvent:
     data: bytes
     response_id: str
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     type: Literal["audio"] = "audio"
 
 
@@ -47,6 +53,12 @@ class RealtimeModelAudioEvent:
 class RealtimeModelAudioInterruptedEvent:
     """Audio interrupted."""
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     type: Literal["audio_interrupted"] = "audio_interrupted"
 
 
@@ -54,6 +66,12 @@ class RealtimeModelAudioInterruptedEvent:
 class RealtimeModelAudioDoneEvent:
     """Audio done."""
 
+    item_id: str
+    """The ID of the item containing audio."""
+
+    content_index: int
+    """The index of the audio content in `item.content`"""
+
     type: Literal["audio_done"] = "audio_done"
 
 
@@ -138,6 +156,15 @@ class RealtimeModelExceptionEvent:
     type: Literal["exception"] = "exception"
 
 
+@dataclass
+class RealtimeModelRawServerEvent:
+    """Raw events forwarded from the server."""
+
+    data: Any
+
+    type: Literal["raw_server_event"] = "raw_server_event"
+
+
 # TODO (rm) Add usage events
 
 
@@ -156,4 +183,5 @@ RealtimeModelEvent: TypeAlias = Union[
     RealtimeModelTurnEndedEvent,
     RealtimeModelOtherEvent,
     RealtimeModelExceptionEvent,
+    RealtimeModelRawServerEvent,
 ]