openai-agents 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- agents/__init__.py +12 -0
- agents/_run_impl.py +18 -6
- agents/extensions/memory/__init__.py +1 -3
- agents/extensions/memory/sqlalchemy_session.py +25 -3
- agents/extensions/models/litellm_model.py +11 -6
- agents/items.py +103 -4
- agents/mcp/server.py +43 -11
- agents/mcp/util.py +17 -1
- agents/memory/openai_conversations_session.py +2 -2
- agents/models/chatcmpl_converter.py +44 -18
- agents/models/openai_chatcompletions.py +27 -26
- agents/models/openai_responses.py +31 -29
- agents/realtime/handoffs.py +1 -1
- agents/realtime/model_inputs.py +3 -0
- agents/realtime/openai_realtime.py +38 -29
- agents/realtime/session.py +1 -1
- agents/result.py +48 -11
- agents/run.py +223 -27
- agents/stream_events.py +1 -0
- agents/strict_schema.py +14 -0
- agents/tool.py +86 -3
- agents/voice/models/openai_stt.py +2 -1
- {openai_agents-0.3.3.dist-info → openai_agents-0.4.1.dist-info}/METADATA +2 -2
- {openai_agents-0.3.3.dist-info → openai_agents-0.4.1.dist-info}/RECORD +26 -26
- {openai_agents-0.3.3.dist-info → openai_agents-0.4.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.3.3.dist-info → openai_agents-0.4.1.dist-info}/licenses/LICENSE +0 -0
agents/models/chatcmpl_converter.py
CHANGED

@@ -2,12 +2,13 @@ from __future__ import annotations

 import json
 from collections.abc import Iterable
-from typing import Any, Literal, cast
+from typing import Any, Literal, Union, cast

-from openai import NOT_GIVEN, NotGiven
+from openai import Omit, omit
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -27,6 +28,7 @@ from openai.types.responses import (
     ResponseFileSearchToolCallParam,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseInputAudioParam,
     ResponseInputContentParam,
     ResponseInputFileParam,
     ResponseInputImageParam,
@@ -54,9 +56,9 @@ class Converter:
     @classmethod
     def convert_tool_choice(
         cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
-    ) -> ChatCompletionToolChoiceOptionParam | NotGiven:
+    ) -> ChatCompletionToolChoiceOptionParam | Omit:
         if tool_choice is None:
-            return NOT_GIVEN
+            return omit
         elif isinstance(tool_choice, MCPToolChoice):
             raise UserError("MCPToolChoice is not supported for Chat Completions models")
         elif tool_choice == "auto":
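The recurring change across these model files is the sentinel swap: `NOT_GIVEN`/`NotGiven` becomes `omit`/`Omit`. A minimal sketch of the pattern, assuming an openai-python version that exports both names (the diff's own imports confirm it); `tool_choice_param` is a hypothetical helper, not part of the package:

```python
from openai import Omit, omit


def tool_choice_param(tool_choice: str | None) -> str | Omit:
    # Returning `omit` tells the SDK to leave the field out of the request
    # body entirely, rather than sending an explicit null.
    if tool_choice is None:
        return omit
    return tool_choice


assert tool_choice_param(None) is omit
assert tool_choice_param("auto") == "auto"
```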
@@ -76,9 +78,9 @@ class Converter:
     @classmethod
     def convert_response_format(
         cls, final_output_schema: AgentOutputSchemaBase | None
-    ) -> ResponseFormat | NotGiven:
+    ) -> ResponseFormat | Omit:
         if not final_output_schema or final_output_schema.is_plain_text():
-            return NOT_GIVEN
+            return omit

         return {
             "type": "json_schema",
@@ -287,23 +289,44 @@ class Converter:
                         },
                     )
                 )
+            elif isinstance(c, dict) and c.get("type") == "input_audio":
+                casted_audio_param = cast(ResponseInputAudioParam, c)
+                audio_payload = casted_audio_param.get("input_audio")
+                if not audio_payload:
+                    raise UserError(
+                        f"Only audio data is supported for input_audio {casted_audio_param}"
+                    )
+                if not isinstance(audio_payload, dict):
+                    raise UserError(
+                        f"input_audio must provide audio data and format {casted_audio_param}"
+                    )
+                audio_data = audio_payload.get("data")
+                audio_format = audio_payload.get("format")
+                if not audio_data or not audio_format:
+                    raise UserError(
+                        f"input_audio requires both data and format {casted_audio_param}"
+                    )
+                out.append(
+                    ChatCompletionContentPartInputAudioParam(
+                        type="input_audio",
+                        input_audio={
+                            "data": audio_data,
+                            "format": audio_format,
+                        },
+                    )
+                )
             elif isinstance(c, dict) and c.get("type") == "input_file":
                 casted_file_param = cast(ResponseInputFileParam, c)
                 if "file_data" not in casted_file_param or not casted_file_param["file_data"]:
                     raise UserError(
                         f"Only file_data is supported for input_file {casted_file_param}"
                     )
-                out.append(
-                    File(
-                        type="file",
-                        file=FileFile(
-                            file_data=casted_file_param["file_data"],
-                            filename=casted_file_param["filename"],
-                        ),
-                    )
-                )
+                filedata = FileFile(file_data=casted_file_param["file_data"])
+
+                if "filename" in casted_file_param and casted_file_param["filename"]:
+                    filedata["filename"] = casted_file_param["filename"]
+
+                out.append(File(type="file", file=filedata))
             else:
                 raise UserError(f"Unknown content: {c}")
         return out
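For reference, the new branch accepts user content shaped like the following hypothetical input item (the base64 payload and "wav" format are placeholders):

```python
# Hypothetical item exercising the new input_audio branch; `data` must be
# base64-encoded audio and `format` a supported identifier such as "wav".
user_item = {
    "role": "user",
    "content": [
        {
            "type": "input_audio",
            "input_audio": {
                "data": "<base64-encoded audio>",
                "format": "wav",
            },
        }
    ],
}
```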
@@ -511,10 +534,13 @@ class Converter:
             # 5) function call output => tool message
             elif func_output := cls.maybe_function_tool_call_output(item):
                 flush_assistant_message()
+                output_content = cast(
+                    Union[str, Iterable[ResponseInputContentParam]], func_output["output"]
+                )
                 msg: ChatCompletionToolMessageParam = {
                     "role": "tool",
                     "tool_call_id": func_output["call_id"],
-                    "content": func_output["output"],
+                    "content": cls.extract_text_content(output_content),
                 }
                 result.append(msg)
agents/models/openai_chatcompletions.py
CHANGED

@@ -3,9 +3,9 @@ from __future__ import annotations

 import json
 import time
 from collections.abc import AsyncIterator
-from typing import TYPE_CHECKING, Any, Literal, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload

-from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream, NotGiven
+from openai import AsyncOpenAI, AsyncStream, Omit, omit
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
@@ -44,8 +44,8 @@ class OpenAIChatCompletionsModel(Model):
         self.model = model
         self._client = openai_client

-    def _non_null_or_not_given(self, value: Any) -> Any:
-        return value if value is not None else NOT_GIVEN
+    def _non_null_or_omit(self, value: Any) -> Any:
+        return value if value is not None else omit

     async def get_response(
         self,
@@ -243,13 +243,12 @@ class OpenAIChatCompletionsModel(Model):
         if tracing.include_data():
             span.span_data.input = converted_messages

-        parallel_tool_calls = (
-            True
-            if model_settings.parallel_tool_calls and tools
-            else False
-            if model_settings.parallel_tool_calls is False
-            else NOT_GIVEN
-        )
+        if model_settings.parallel_tool_calls and tools:
+            parallel_tool_calls: bool | Omit = True
+        elif model_settings.parallel_tool_calls is False:
+            parallel_tool_calls = False
+        else:
+            parallel_tool_calls = omit
         tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
         response_format = Converter.convert_response_format(output_schema)
@@ -259,6 +258,7 @@ class OpenAIChatCompletionsModel(Model):
             converted_tools.append(Converter.convert_handoff_tool(handoff))

         converted_tools = _to_dump_compatible(converted_tools)
+        tools_param = converted_tools if converted_tools else omit

         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Calling LLM")
@@ -288,28 +288,30 @@ class OpenAIChatCompletionsModel(Model):
             self._get_client(), model_settings, stream=stream
         )

+        stream_param: Literal[True] | Omit = True if stream else omit
+
         ret = await self._get_client().chat.completions.create(
             model=self.model,
             messages=converted_messages,
-            tools=converted_tools or NOT_GIVEN,
-            temperature=self._non_null_or_not_given(model_settings.temperature),
-            top_p=self._non_null_or_not_given(model_settings.top_p),
-            frequency_penalty=self._non_null_or_not_given(model_settings.frequency_penalty),
-            presence_penalty=self._non_null_or_not_given(model_settings.presence_penalty),
-            max_tokens=self._non_null_or_not_given(model_settings.max_tokens),
+            tools=tools_param,
+            temperature=self._non_null_or_omit(model_settings.temperature),
+            top_p=self._non_null_or_omit(model_settings.top_p),
+            frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty),
+            presence_penalty=self._non_null_or_omit(model_settings.presence_penalty),
+            max_tokens=self._non_null_or_omit(model_settings.max_tokens),
             tool_choice=tool_choice,
             response_format=response_format,
             parallel_tool_calls=parallel_tool_calls,
-            stream=stream,
-            stream_options=self._non_null_or_not_given(stream_options),
-            store=self._non_null_or_not_given(store),
-            reasoning_effort=self._non_null_or_not_given(reasoning_effort),
-            verbosity=self._non_null_or_not_given(model_settings.verbosity),
-            top_logprobs=self._non_null_or_not_given(model_settings.top_logprobs),
+            stream=cast(Any, stream_param),
+            stream_options=self._non_null_or_omit(stream_options),
+            store=self._non_null_or_omit(store),
+            reasoning_effort=self._non_null_or_omit(reasoning_effort),
+            verbosity=self._non_null_or_omit(model_settings.verbosity),
+            top_logprobs=self._non_null_or_omit(model_settings.top_logprobs),
             extra_headers=self._merge_headers(model_settings),
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
-            metadata=self._non_null_or_not_given(model_settings.metadata),
+            metadata=self._non_null_or_omit(model_settings.metadata),
             **(model_settings.extra_args or {}),
         )
@@ -319,14 +321,13 @@ class OpenAIChatCompletionsModel(Model):
         responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice(
             model_settings.tool_choice
         )
-        if responses_tool_choice is None or responses_tool_choice == NOT_GIVEN:
+        if responses_tool_choice is None or responses_tool_choice is omit:
             # For Responses API data compatibility with Chat Completions patterns,
             # we need to set "none" if tool_choice is absent.
             # Without this fix, you'll get the following error:
             # pydantic_core._pydantic_core.ValidationError: 4 validation errors for Response
             # tool_choice.literal['none','auto','required']
             # Input should be 'none', 'auto' or 'required'
-            # [type=literal_error, input_value=NOT_GIVEN, input_type=NotGiven]
             # see also: https://github.com/openai/openai-agents-python/issues/980
             responses_tool_choice = "auto"
agents/models/openai_responses.py
CHANGED

@@ -4,9 +4,9 @@ import json

 from collections.abc import AsyncIterator
 from contextvars import ContextVar
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, cast, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload

-from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
+from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit
 from openai.types import ChatModel
 from openai.types.responses import (
     Response,
@@ -69,8 +69,8 @@ class OpenAIResponsesModel(Model):
         self.model = model
         self._client = openai_client

-    def _non_null_or_not_given(self, value: Any) -> Any:
-        return value if value is not None else NOT_GIVEN
+    def _non_null_or_omit(self, value: Any) -> Any:
+        return value if value is not None else omit

     async def get_response(
         self,
@@ -249,13 +249,12 @@ class OpenAIResponsesModel(Model):
         list_input = ItemHelpers.input_to_new_input_list(input)
         list_input = _to_dump_compatible(list_input)

-        parallel_tool_calls = (
-            True
-            if model_settings.parallel_tool_calls and tools
-            else False
-            if model_settings.parallel_tool_calls is False
-            else NOT_GIVEN
-        )
+        if model_settings.parallel_tool_calls and tools:
+            parallel_tool_calls: bool | Omit = True
+        elif model_settings.parallel_tool_calls is False:
+            parallel_tool_calls = False
+        else:
+            parallel_tool_calls = omit

         tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
         converted_tools = Converter.convert_tools(tools, handoffs)
@@ -297,36 +296,39 @@ class OpenAIResponsesModel(Model):
         if model_settings.top_logprobs is not None:
             extra_args["top_logprobs"] = model_settings.top_logprobs
         if model_settings.verbosity is not None:
-            if response_format != NOT_GIVEN:
+            if response_format is not omit:
                 response_format["verbosity"] = model_settings.verbosity  # type: ignore [index]
             else:
                 response_format = {"verbosity": model_settings.verbosity}

-        return await self._client.responses.create(
-            previous_response_id=self._non_null_or_not_given(previous_response_id),
-            conversation=self._non_null_or_not_given(conversation_id),
-            instructions=self._non_null_or_not_given(system_instructions),
+        stream_param: Literal[True] | Omit = True if stream else omit
+
+        response = await self._client.responses.create(
+            previous_response_id=self._non_null_or_omit(previous_response_id),
+            conversation=self._non_null_or_omit(conversation_id),
+            instructions=self._non_null_or_omit(system_instructions),
             model=self.model,
             input=list_input,
             include=include,
             tools=converted_tools_payload,
-            prompt=self._non_null_or_not_given(prompt),
-            temperature=self._non_null_or_not_given(model_settings.temperature),
-            top_p=self._non_null_or_not_given(model_settings.top_p),
-            truncation=self._non_null_or_not_given(model_settings.truncation),
-            max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens),
+            prompt=self._non_null_or_omit(prompt),
+            temperature=self._non_null_or_omit(model_settings.temperature),
+            top_p=self._non_null_or_omit(model_settings.top_p),
+            truncation=self._non_null_or_omit(model_settings.truncation),
+            max_output_tokens=self._non_null_or_omit(model_settings.max_tokens),
             tool_choice=tool_choice,
             parallel_tool_calls=parallel_tool_calls,
-            stream=stream,
+            stream=cast(Any, stream_param),
             extra_headers=self._merge_headers(model_settings),
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             text=response_format,
-            store=self._non_null_or_not_given(model_settings.store),
-            reasoning=self._non_null_or_not_given(model_settings.reasoning),
-            metadata=self._non_null_or_not_given(model_settings.metadata),
+            store=self._non_null_or_omit(model_settings.store),
+            reasoning=self._non_null_or_omit(model_settings.reasoning),
+            metadata=self._non_null_or_omit(model_settings.metadata),
             **extra_args,
         )
+        return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response)

     def _get_client(self) -> AsyncOpenAI:
         if self._client is None:
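Both call sites now build the stream flag the same way. The likely motivation (an inference, not stated in the diff): the SDK's `create()` overloads distinguish streaming from non-streaming by `stream=True` versus the argument being absent, so a plain `bool` satisfies neither overload; the `Literal[True] | Omit` value plus `cast(Any, ...)` at the call site sidesteps that, and the final `cast` restores the return type the overloads would have narrowed. A sketch of the helper shape:

```python
from typing import Literal

from openai import Omit, omit


def stream_param_for(stream: bool) -> Literal[True] | Omit:
    # True selects the streaming overload; `omit` drops the argument so the
    # non-streaming overload applies.
    return True if stream else omit
```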
@@ -351,9 +353,9 @@ class Converter:
     @classmethod
     def convert_tool_choice(
         cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
-    ) -> response_create_params.ToolChoice | NotGiven:
+    ) -> response_create_params.ToolChoice | Omit:
         if tool_choice is None:
-            return NOT_GIVEN
+            return omit
         elif isinstance(tool_choice, MCPToolChoice):
             return {
                 "server_label": tool_choice.server_label,
@@ -404,9 +406,9 @@ class Converter:
     @classmethod
     def get_response_format(
         cls, output_schema: AgentOutputSchemaBase | None
-    ) -> ResponseTextConfigParam | NotGiven:
+    ) -> ResponseTextConfigParam | Omit:
         if output_schema is None or output_schema.is_plain_text():
-            return NOT_GIVEN
+            return omit
         else:
             return {
                 "format": {
agents/realtime/handoffs.py
CHANGED

@@ -13,10 +13,10 @@ from ..strict_schema import ensure_strict_json_schema
 from ..tracing.spans import SpanError
 from ..util import _error_tracing, _json
 from ..util._types import MaybeAwaitable
+from . import RealtimeAgent

 if TYPE_CHECKING:
     from ..agent import AgentBase
-    from . import RealtimeAgent


 # The handoff input type is the type of data passed when the agent is called via a handoff.
agents/realtime/model_inputs.py
CHANGED

@@ -95,6 +95,9 @@ class RealtimeModelSendToolOutput:
 class RealtimeModelSendInterrupt:
     """Send an interrupt to the model."""

+    force_response_cancel: bool = False
+    """Force sending a response.cancel event even if automatic cancellation is enabled."""
+

 @dataclass
 class RealtimeModelSendSessionUpdate:
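A hedged usage sketch of the new flag; `model` below stands in for any `RealtimeModel` implementation (an assumption, not shown in the diff), and the forced variant mirrors what `RealtimeSession` now sends after a guardrail trip (see agents/realtime/session.py below):

```python
from agents.realtime.model_inputs import RealtimeModelSendInterrupt


async def stop_speaking(model) -> None:  # `model` is any RealtimeModel (assumption)
    # Default interrupt: response.cancel is skipped when server-side turn
    # detection already cancels responses automatically.
    await model.send_event(RealtimeModelSendInterrupt())

    # Forced interrupt: always emits response.cancel, e.g. after a guardrail trip.
    await model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))
```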
agents/realtime/openai_realtime.py
CHANGED

@@ -266,7 +266,8 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):

     async def _emit_event(self, event: RealtimeModelEvent) -> None:
         """Emit an event to the listeners."""
-        for listener in self._listeners:
+        # Copy list to avoid modification during iteration
+        for listener in list(self._listeners):
             await listener.on_event(event)

     async def _listen_for_messages(self):
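The copy matters because a listener can unsubscribe itself (mutating `self._listeners`) while the dispatch loop is still iterating. A self-contained toy showing the failure mode the snapshot avoids:

```python
listeners: list = []


def on_event() -> None:
    # A listener that removes itself during dispatch.
    listeners.remove(on_event)


listeners.append(on_event)

# Iterating a snapshot is safe; iterating `listeners` directly while it is
# mutated can silently skip entries.
for listener in list(listeners):
    listener()
```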
@@ -394,6 +395,7 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
         current_item_id = playback_state.get("current_item_id")
         current_item_content_index = playback_state.get("current_item_content_index")
         elapsed_ms = playback_state.get("elapsed_ms")
+
         if current_item_id is None or elapsed_ms is None:
             logger.debug(
                 "Skipping interrupt. "
@@ -401,29 +403,28 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
                 f"elapsed ms: {elapsed_ms}, "
                 f"content index: {current_item_content_index}"
             )
-            return
-
-        current_item_content_index = current_item_content_index or 0
-        if elapsed_ms > 0:
-            await self._emit_event(
-                RealtimeModelAudioInterruptedEvent(
-                    item_id=current_item_id,
-                    content_index=current_item_content_index,
-                )
-            )
-            converted = _ConversionHelper.convert_interrupt(
-                current_item_id,
-                current_item_content_index,
-                int(elapsed_ms),
-            )
-            await self._send_raw_message(converted)
         else:
-            logger.debug(
-                "Didn't interrupt bc elapsed ms is < 0. "
-                f"Item id: {current_item_id}, "
-                f"elapsed ms: {elapsed_ms}, "
-                f"content index: {current_item_content_index}"
-            )
+            current_item_content_index = current_item_content_index or 0
+            if elapsed_ms > 0:
+                await self._emit_event(
+                    RealtimeModelAudioInterruptedEvent(
+                        item_id=current_item_id,
+                        content_index=current_item_content_index,
+                    )
+                )
+                converted = _ConversionHelper.convert_interrupt(
+                    current_item_id,
+                    current_item_content_index,
+                    int(elapsed_ms),
+                )
+                await self._send_raw_message(converted)
+            else:
+                logger.debug(
+                    "Didn't interrupt bc elapsed ms is < 0. "
+                    f"Item id: {current_item_id}, "
+                    f"elapsed ms: {elapsed_ms}, "
+                    f"content index: {current_item_content_index}"
+                )
@@ -431,14 +432,18 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
             and session.audio is not None
             and session.audio.input is not None
             and session.audio.input.turn_detection is not None
-            and session.audio.input.turn_detection.interrupt_response is True
+            and session.audio.input.turn_detection.interrupt_response is True
         )
-        if not automatic_response_cancellation_enabled:
+        should_cancel_response = event.force_response_cancel or (
+            not automatic_response_cancellation_enabled
+        )
+        if should_cancel_response:
             await self._cancel_response()

-        self._audio_state_tracker.on_interrupted()
-        if self._playback_tracker:
-            self._playback_tracker.on_interrupted()
+        if current_item_id is not None and elapsed_ms is not None:
+            self._audio_state_tracker.on_interrupted()
+            if self._playback_tracker:
+                self._playback_tracker.on_interrupted()
@@ -516,6 +521,10 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
         self._websocket = None
         if self._websocket_task:
             self._websocket_task.cancel()
+            try:
+                await self._websocket_task
+            except asyncio.CancelledError:
+                pass
             self._websocket_task = None

     async def _cancel_response(self) -> None:
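Awaiting the task after `cancel()` is the standard asyncio shutdown idiom: cancellation is only requested, and the task keeps running until the `CancelledError` propagates. A self-contained sketch:

```python
import asyncio


async def worker() -> None:
    # Stands in for the websocket listener task.
    await asyncio.sleep(3600)


async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)  # let the worker start
    task.cancel()           # request cancellation...
    try:
        await task          # ...and wait for it to actually unwind
    except asyncio.CancelledError:
        pass
    assert task.cancelled()


asyncio.run(main())
```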
@@ -616,7 +625,7 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
             and session.audio is not None
             and session.audio.input is not None
             and session.audio.input.turn_detection is not None
-            and session.audio.input.turn_detection.interrupt_response is True
+            and session.audio.input.turn_detection.interrupt_response is True
         )
         if not automatic_response_cancellation_enabled:
             await self._cancel_response()
agents/realtime/session.py
CHANGED

@@ -704,7 +704,7 @@ class RealtimeSession(RealtimeModelListener):
         )

         # Interrupt the model
-        await self._model.send_event(RealtimeModelSendInterrupt())
+        await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))

         # Send guardrail triggered message
         guardrail_names = [result.guardrail.get_name() for result in triggered_results]
agents/result.py
CHANGED

@@ -4,7 +4,7 @@ import abc
 import asyncio
 from collections.abc import AsyncIterator
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast

 from typing_extensions import TypeVar
@@ -164,6 +164,9 @@ class RunResultStreaming(RunResultBase):
     _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
     _stored_exception: Exception | None = field(default=None, repr=False)

+    # Soft cancel state
+    _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
+
     @property
     def last_agent(self) -> Agent[Any]:
         """The last agent that was run. Updates as the agent run progresses, so the true last agent
@@ -171,17 +174,51 @@ class RunResultStreaming(RunResultBase):
         """
         return self.current_agent

-    def cancel(self) -> None:
-        """Cancels the streaming run, stopping all background tasks and marking the run as
-        complete."""
-        self._cleanup_tasks()  # Cancel all running tasks
-        self.is_complete = True  # Mark the run as complete to stop event streaming
+    def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None:
+        """Cancel the streaming run.

-        # Optionally, clear the event queue to prevent processing stale events
-        while not self._event_queue.empty():
-            self._event_queue.get_nowait()
-        while not self._input_guardrail_queue.empty():
-            self._input_guardrail_queue.get_nowait()
+        Args:
+            mode: Cancellation strategy:
+                - "immediate": Stop immediately, cancel all tasks, clear queues (default)
+                - "after_turn": Complete current turn gracefully before stopping
+                    * Allows LLM response to finish
+                    * Executes pending tool calls
+                    * Saves session state properly
+                    * Tracks usage accurately
+                    * Stops before next turn begins
+
+        Example:
+            ```python
+            result = Runner.run_streamed(agent, "Task", session=session)
+
+            async for event in result.stream_events():
+                if user_interrupted():
+                    result.cancel(mode="after_turn")  # Graceful
+                    # result.cancel()  # Immediate (default)
+            ```
+
+        Note: After calling cancel(), you should continue consuming stream_events()
+        to allow the cancellation to complete properly.
+        """
+        # Store the cancel mode for the background task to check
+        self._cancel_mode = mode
+
+        if mode == "immediate":
+            # Existing behavior - immediate shutdown
+            self._cleanup_tasks()  # Cancel all running tasks
+            self.is_complete = True  # Mark the run as complete to stop event streaming
+
+            # Optionally, clear the event queue to prevent processing stale events
+            while not self._event_queue.empty():
+                self._event_queue.get_nowait()
+            while not self._input_guardrail_queue.empty():
+                self._input_guardrail_queue.get_nowait()
+
+        elif mode == "after_turn":
+            # Soft cancel - just set the flag
+            # The streaming loop will check this and stop gracefully
+            # Don't call _cleanup_tasks() or clear queues yet
+            pass

     async def stream_events(self) -> AsyncIterator[StreamEvent]:
         """Stream deltas for new items as they are generated. We're using the types from the
|