openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +62 -30
- agents/agent_output.py +2 -2
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/handoffs.py +32 -14
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +82 -11
- agents/mcp/util.py +16 -9
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +20 -3
- agents/models/chatcmpl_stream_handler.py +134 -43
- agents/models/openai_responses.py +12 -5
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +177 -0
- agents/realtime/agent.py +89 -0
- agents/realtime/config.py +188 -0
- agents/realtime/events.py +216 -0
- agents/realtime/handoffs.py +165 -0
- agents/realtime/items.py +184 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +670 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +535 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Minimal realtime session implementation for voice agents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
|
|
7
|
+
from ..run_context import RunContextWrapper, TContext
|
|
8
|
+
from .agent import RealtimeAgent
|
|
9
|
+
from .config import (
|
|
10
|
+
RealtimeRunConfig,
|
|
11
|
+
RealtimeSessionModelSettings,
|
|
12
|
+
)
|
|
13
|
+
from .model import (
|
|
14
|
+
RealtimeModel,
|
|
15
|
+
RealtimeModelConfig,
|
|
16
|
+
)
|
|
17
|
+
from .openai_realtime import OpenAIRealtimeWebSocketModel
|
|
18
|
+
from .session import RealtimeSession
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RealtimeRunner:
    """A `RealtimeRunner` is the equivalent of `Runner` for realtime agents. It automatically
    handles multiple turns by maintaining a persistent connection with the underlying model
    layer.

    The session manages the local history copy, executes tools, runs guardrails and facilitates
    handoffs between agents.

    Since this code runs on your server, it uses WebSockets by default. You can optionally create
    your own custom model layer by implementing the `RealtimeModel` interface.
    """

    def __init__(
        self,
        starting_agent: RealtimeAgent,
        *,
        model: RealtimeModel | None = None,
        config: RealtimeRunConfig | None = None,
    ) -> None:
        """Initialize the realtime runner.

        Args:
            starting_agent: The agent to start the session with.
            model: The model to use. If not provided, will use a default OpenAI realtime model.
            config: Override parameters to use for the entire run.
        """
        self._starting_agent = starting_agent
        self._config = config
        self._model = model or OpenAIRealtimeWebSocketModel()

    async def run(
        self, *, context: TContext | None = None, model_config: RealtimeModelConfig | None = None
    ) -> RealtimeSession:
        """Start and return a realtime session.

        The returned session is not connected yet; entering it (``async with`` or ``enter()``)
        connects to the model.

        Args:
            context: The context object for the run. It is used to resolve the starting agent's
                instructions, tools and handoffs, and is handed to the session.
            model_config: Optional model configuration. Its ``initial_model_settings`` form the
                base onto which the agent's instructions/tools/handoffs and the run config's
                ``model_settings`` overrides are layered. The caller's dict is not mutated.

        Returns:
            RealtimeSession: A session object that allows bidirectional communication with the
            realtime model.

        Example:
            ```python
            runner = RealtimeRunner(agent)
            async with await runner.run() as session:
                await session.send_message("Hello")
                async for event in session:
                    print(event)
            ```
        """
        model_settings = await self._get_model_settings(
            agent=self._starting_agent,
            disable_tracing=self._config.get("tracing_disabled", False) if self._config else False,
            # Resolve dynamic instructions/tools with the same context the session will use.
            context=context,
            initial_settings=model_config.get("initial_model_settings") if model_config else None,
            overrides=self._config.get("model_settings") if self._config else None,
        )

        # Copy so the caller's config is left untouched.
        model_config = model_config.copy() if model_config else {}
        model_config["initial_model_settings"] = model_settings

        # Create and return the connection
        session = RealtimeSession(
            model=self._model,
            agent=self._starting_agent,
            context=context,
            model_config=model_config,
            run_config=self._config,
        )

        return session

    async def _get_model_settings(
        self,
        agent: RealtimeAgent,
        disable_tracing: bool,
        context: TContext | None = None,
        initial_settings: RealtimeSessionModelSettings | None = None,
        overrides: RealtimeSessionModelSettings | None = None,
    ) -> RealtimeSessionModelSettings:
        """Build the initial session model settings for ``agent``.

        Precedence, lowest to highest: ``initial_settings`` < the agent's instructions, tools and
        enabled handoffs < ``overrides``. When ``disable_tracing`` is true, ``tracing`` is set to
        ``None`` last.
        """
        context_wrapper = RunContextWrapper(context)
        model_settings = initial_settings.copy() if initial_settings else {}

        instructions, tools, handoffs = await asyncio.gather(
            agent.get_system_prompt(context_wrapper),
            agent.get_all_tools(context_wrapper),
            # Same handoff resolution the session uses after a handoff, so the starting agent's
            # handoff tools are advertised from the very first turn.
            RealtimeSession._get_handoffs(agent, context_wrapper),
        )

        if instructions is not None:
            model_settings["instructions"] = instructions
        if tools is not None:
            model_settings["tools"] = tools
        if handoffs:
            model_settings["handoffs"] = handoffs

        if overrides:
            model_settings.update(overrides)

        if disable_tracing:
            model_settings["tracing"] = None

        return model_settings
|
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from typing_extensions import assert_never
|
|
9
|
+
|
|
10
|
+
from ..agent import Agent
|
|
11
|
+
from ..exceptions import ModelBehaviorError, UserError
|
|
12
|
+
from ..handoffs import Handoff
|
|
13
|
+
from ..run_context import RunContextWrapper, TContext
|
|
14
|
+
from ..tool import FunctionTool
|
|
15
|
+
from ..tool_context import ToolContext
|
|
16
|
+
from .agent import RealtimeAgent
|
|
17
|
+
from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
|
|
18
|
+
from .events import (
|
|
19
|
+
RealtimeAgentEndEvent,
|
|
20
|
+
RealtimeAgentStartEvent,
|
|
21
|
+
RealtimeAudio,
|
|
22
|
+
RealtimeAudioEnd,
|
|
23
|
+
RealtimeAudioInterrupted,
|
|
24
|
+
RealtimeError,
|
|
25
|
+
RealtimeEventInfo,
|
|
26
|
+
RealtimeGuardrailTripped,
|
|
27
|
+
RealtimeHandoffEvent,
|
|
28
|
+
RealtimeHistoryAdded,
|
|
29
|
+
RealtimeHistoryUpdated,
|
|
30
|
+
RealtimeRawModelEvent,
|
|
31
|
+
RealtimeSessionEvent,
|
|
32
|
+
RealtimeToolEnd,
|
|
33
|
+
RealtimeToolStart,
|
|
34
|
+
)
|
|
35
|
+
from .handoffs import realtime_handoff
|
|
36
|
+
from .items import InputAudio, InputText, RealtimeItem
|
|
37
|
+
from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener
|
|
38
|
+
from .model_events import (
|
|
39
|
+
RealtimeModelEvent,
|
|
40
|
+
RealtimeModelInputAudioTranscriptionCompletedEvent,
|
|
41
|
+
RealtimeModelToolCallEvent,
|
|
42
|
+
)
|
|
43
|
+
from .model_inputs import (
|
|
44
|
+
RealtimeModelSendAudio,
|
|
45
|
+
RealtimeModelSendInterrupt,
|
|
46
|
+
RealtimeModelSendSessionUpdate,
|
|
47
|
+
RealtimeModelSendToolOutput,
|
|
48
|
+
RealtimeModelSendUserInput,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class RealtimeSession(RealtimeModelListener):
    """A connection to a realtime model. It streams events from the model to you, and allows you to
    send messages and audio to the model.

    Example:
        ```python
        runner = RealtimeRunner(agent)
        async with await runner.run() as session:
            # Send messages
            await session.send_message("Hello")
            await session.send_audio(audio_bytes)

            # Stream events
            async for event in session:
                if event.type == "audio":
                    # Handle audio event
                    pass
        ```
    """

    def __init__(
        self,
        model: RealtimeModel,
        agent: RealtimeAgent,
        context: TContext | None,
        model_config: RealtimeModelConfig | None = None,
        run_config: RealtimeRunConfig | None = None,
    ) -> None:
        """Initialize the session.

        Args:
            model: The model to use.
            agent: The current agent.
            context: The context object.
            model_config: Model configuration.
            run_config: Runtime configuration including guardrails.
        """
        self._model = model
        self._current_agent = agent
        self._context_wrapper = RunContextWrapper(context)
        self._event_info = RealtimeEventInfo(context=self._context_wrapper)
        self._history: list[RealtimeItem] = []
        self._model_config = model_config or {}
        self._run_config = run_config or {}
        # Unbounded queue (no maxsize), so putting an event never blocks.
        self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
        self._closed = False
        # Set by an "exception" model event and raised from `__aiter__`.
        self._stored_exception: Exception | None = None

        # Guardrails state tracking
        self._interrupted_by_guardrail = False
        self._item_transcripts: dict[str, str] = {}  # item_id -> accumulated transcript
        self._item_guardrail_run_counts: dict[str, int] = {}  # item_id -> run count
        # Guardrails run each time an item's transcript grows past the next multiple of this.
        self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
            "debounce_text_length", 100
        )

        # In-flight guardrail tasks; tracked so they can be cancelled on cleanup.
        self._guardrail_tasks: set[asyncio.Task[Any]] = set()

    async def __aenter__(self) -> RealtimeSession:
        """Start the session by connecting to the model. After this, you will be able to stream
        events from the model and send messages and audio to the model.
        """
        # Add ourselves as a listener
        self._model.add_listener(self)

        # Connect to the model
        try:
            await self._model.connect(self._model_config)
        except BaseException:
            # `__aexit__` is not invoked when `__aenter__` raises, so undo the listener
            # registration here instead of leaving this session attached to the model.
            self._model.remove_listener(self)
            raise

        # Emit initial history update
        await self._put_event(
            RealtimeHistoryUpdated(
                history=self._history,
                info=self._event_info,
            )
        )

        return self

    async def enter(self) -> RealtimeSession:
        """Enter the async context manager. We strongly recommend using the async context manager
        pattern instead of this method. If you use this, you need to manually call `close()` when
        you are done.
        """
        return await self.__aenter__()

    async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """End the session."""
        await self.close()

    async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
        """Iterate over events from the session.

        Raises the exception carried by an "exception" model event (after cleaning up the
        session). Iteration stops when the session is closed or the iterating task is cancelled.
        """
        while not self._closed:
            try:
                # Check if there's a stored exception to raise
                if self._stored_exception is not None:
                    # Clean up resources before raising
                    await self._cleanup()
                    raise self._stored_exception

                event = await self._event_queue.get()
                yield event
            except asyncio.CancelledError:
                break

    async def close(self) -> None:
        """Close the session. Calling it more than once is a no-op."""
        await self._cleanup()

    async def send_message(self, message: RealtimeUserInput) -> None:
        """Send a message to the model."""
        await self._model.send_event(RealtimeModelSendUserInput(user_input=message))

    async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
        """Send a raw audio chunk to the model."""
        await self._model.send_event(RealtimeModelSendAudio(audio=audio, commit=commit))

    async def interrupt(self) -> None:
        """Interrupt the model."""
        await self._model.send_event(RealtimeModelSendInterrupt())

    async def on_event(self, event: RealtimeModelEvent) -> None:
        """Handle one event from the model.

        Every event is first forwarded verbatim as a `RealtimeRawModelEvent`. It is then translated
        into session events, history updates, tool execution or guardrail checks, depending on
        its type.
        """
        await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info))

        if event.type == "error":
            await self._put_event(RealtimeError(info=self._event_info, error=event.error))
        elif event.type == "function_call":
            await self._handle_tool_call(event)
        elif event.type == "audio":
            await self._put_event(RealtimeAudio(info=self._event_info, audio=event))
        elif event.type == "audio_interrupted":
            await self._put_event(RealtimeAudioInterrupted(info=self._event_info))
        elif event.type == "audio_done":
            await self._put_event(RealtimeAudioEnd(info=self._event_info))
        elif event.type == "input_audio_transcription_completed":
            self._history = RealtimeSession._get_new_history(self._history, event)
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "transcript_delta":
            # Accumulate transcript text for guardrail debouncing per item_id
            item_id = event.item_id
            if item_id not in self._item_transcripts:
                self._item_transcripts[item_id] = ""
                self._item_guardrail_run_counts[item_id] = 0

            self._item_transcripts[item_id] += event.delta

            # Check if we should run guardrails based on debounce threshold
            current_length = len(self._item_transcripts[item_id])
            threshold = self._debounce_text_length
            next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold

            if current_length >= next_run_threshold:
                self._item_guardrail_run_counts[item_id] += 1
                self._enqueue_guardrail_task(self._item_transcripts[item_id])
        elif event.type == "item_updated":
            is_new = not any(item.item_id == event.item.item_id for item in self._history)
            self._history = self._get_new_history(self._history, event.item)
            if is_new:
                new_item = next(
                    item for item in self._history if item.item_id == event.item.item_id
                )
                await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item))
            else:
                await self._put_event(
                    RealtimeHistoryUpdated(info=self._event_info, history=self._history)
                )
        elif event.type == "item_deleted":
            deleted_id = event.item_id
            self._history = [item for item in self._history if item.item_id != deleted_id]
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "connection_status":
            pass
        elif event.type == "turn_started":
            await self._put_event(
                RealtimeAgentStartEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "turn_ended":
            # Clear guardrail state for next turn
            self._item_transcripts.clear()
            self._item_guardrail_run_counts.clear()
            self._interrupted_by_guardrail = False

            await self._put_event(
                RealtimeAgentEndEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "exception":
            # Store the exception to be raised in __aiter__
            self._stored_exception = event.exception
        elif event.type == "other":
            pass
        else:
            assert_never(event)

    async def _put_event(self, event: RealtimeSessionEvent) -> None:
        """Put an event into the queue."""
        await self._event_queue.put(event)

    async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
        """Handle a tool call from the model: run a function tool or perform a handoff.

        Raises:
            UserError: If a handoff returns something other than a `RealtimeAgent`.
            ModelBehaviorError: If the named tool is neither a function tool nor an enabled
                handoff of the current agent.
        """
        tools, handoffs = await asyncio.gather(
            self._current_agent.get_all_tools(self._context_wrapper),
            self._get_handoffs(self._current_agent, self._context_wrapper),
        )
        function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
        handoff_map = {handoff.tool_name: handoff for handoff in handoffs}

        if event.name in function_map:
            await self._put_event(
                RealtimeToolStart(
                    info=self._event_info,
                    tool=function_map[event.name],
                    agent=self._current_agent,
                )
            )

            func_tool = function_map[event.name]
            tool_context = ToolContext(
                context=self._context_wrapper.context,
                usage=self._context_wrapper.usage,
                tool_name=event.name,
                tool_call_id=event.call_id,
            )
            result = await func_tool.on_invoke_tool(tool_context, event.arguments)

            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event, output=str(result), start_response=True
                )
            )

            await self._put_event(
                RealtimeToolEnd(
                    info=self._event_info,
                    tool=func_tool,
                    output=result,
                    agent=self._current_agent,
                )
            )
        elif event.name in handoff_map:
            handoff = handoff_map[event.name]

            # Execute the handoff to get the new agent
            result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
            if not isinstance(result, RealtimeAgent):
                raise UserError(
                    f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
                )

            # Store previous agent for event
            previous_agent = self._current_agent

            # Update current agent
            self._current_agent = result

            # Get updated model settings from new agent
            updated_settings = await self._get__updated_model_settings(self._current_agent)

            # Send handoff event
            await self._put_event(
                RealtimeHandoffEvent(
                    from_agent=previous_agent,
                    to_agent=self._current_agent,
                    info=self._event_info,
                )
            )

            # Send tool output to complete the handoff
            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event,
                    output=f"Handed off to {self._current_agent.name}",
                    start_response=True,
                )
            )

            # Send session update to model
            await self._model.send_event(
                RealtimeModelSendSessionUpdate(session_settings=updated_settings)
            )
        else:
            raise ModelBehaviorError(f"Tool {event.name} not found")

    @classmethod
    def _get_new_history(
        cls,
        old_history: list[RealtimeItem],
        event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem,
    ) -> list[RealtimeItem]:
        """Return a new history list with ``event`` applied; ``old_history`` is not mutated.

        - A transcription-completed event copies its transcript onto the ``input_audio`` entries
          of the matching user message and marks that message ``completed``.
        - An item whose ``item_id`` already exists replaces it in place.
        - A new item is inserted after its ``previous_item_id`` if that item is present.
          Otherwise it is appended.
        """
        # Merge transcript into placeholder input_audio message.
        if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent):
            new_history: list[RealtimeItem] = []
            for item in old_history:
                if item.item_id == event.item_id and item.type == "message" and item.role == "user":
                    content: list[InputText | InputAudio] = []
                    for entry in item.content:
                        if entry.type == "input_audio":
                            copied_entry = entry.model_copy(update={"transcript": event.transcript})
                            content.append(copied_entry)
                        else:
                            content.append(entry)  # type: ignore
                    new_history.append(
                        item.model_copy(update={"content": content, "status": "completed"})
                    )
                else:
                    new_history.append(item)
            return new_history

        # Otherwise it's just a new item
        # TODO (rm) Add support for audio storage config

        # If the item already exists, update it
        existing_index = next(
            (i for i, item in enumerate(old_history) if item.item_id == event.item_id), None
        )
        if existing_index is not None:
            new_history = old_history.copy()
            new_history[existing_index] = event
            return new_history
        # Otherwise, insert it after the previous_item_id if that is set
        elif event.previous_item_id:
            # Insert the new item after the previous item
            previous_index = next(
                (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id),
                None,
            )
            if previous_index is not None:
                new_history = old_history.copy()
                new_history.insert(previous_index + 1, event)
                return new_history

        # Otherwise, add it to the end
        return old_history + [event]

    async def _run_output_guardrails(self, text: str) -> bool:
        """Run output guardrails on the given text. Returns True if any guardrail was triggered.

        The check is skipped if the current turn was already interrupted by a guardrail. A
        guardrail that raises is skipped so the others can still run. On a trip, this emits
        `RealtimeGuardrailTripped`, interrupts the model and sends it a user message naming the
        tripped guardrails.
        """
        output_guardrails = self._run_config.get("output_guardrails", [])
        if not output_guardrails or self._interrupted_by_guardrail:
            return False

        triggered_results = []

        for guardrail in output_guardrails:
            try:
                result = await guardrail.run(
                    # TODO (rm) Remove this cast, it's wrong
                    self._context_wrapper,
                    cast(Agent[Any], self._current_agent),
                    text,
                )
                if result.output.tripwire_triggered:
                    triggered_results.append(result)
            except Exception:
                # Continue with other guardrails if one fails
                continue

        if triggered_results:
            # Mark as interrupted to prevent multiple interrupts
            self._interrupted_by_guardrail = True

            # Emit guardrail tripped event
            await self._put_event(
                RealtimeGuardrailTripped(
                    guardrail_results=triggered_results,
                    message=text,
                    info=self._event_info,
                )
            )

            # Interrupt the model
            await self._model.send_event(RealtimeModelSendInterrupt())

            # Send guardrail triggered message
            guardrail_names = [result.guardrail.get_name() for result in triggered_results]
            await self._model.send_event(
                RealtimeModelSendUserInput(
                    user_input=f"guardrail triggered: {', '.join(guardrail_names)}"
                )
            )

            return True

        return False

    def _enqueue_guardrail_task(self, text: str) -> None:
        """Schedule output guardrails for ``text`` without blocking event handling."""
        # Runs the guardrails in a separate task to avoid blocking the main loop
        task = asyncio.create_task(self._run_output_guardrails(text))
        self._guardrail_tasks.add(task)

        # Add callback to remove completed tasks and handle exceptions
        task.add_done_callback(self._on_guardrail_task_done)

    def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
        """Handle completion of a guardrail task.

        The task is removed from the tracking set. If the task failed, a `RealtimeError` event is
        emitted instead of the exception being raised.
        """
        # Remove from tracking set
        self._guardrail_tasks.discard(task)

        # Check for exceptions and propagate as events
        if not task.cancelled():
            exception = task.exception()
            if exception:
                # The event queue is unbounded, so this cannot raise `QueueFull`. Enqueuing
                # synchronously also avoids spawning an unreferenced task from this callback.
                self._event_queue.put_nowait(
                    RealtimeError(
                        info=self._event_info,
                        error={"message": f"Guardrail task failed: {str(exception)}"},
                    )
                )

    def _cleanup_guardrail_tasks(self) -> None:
        """Cancel all unfinished guardrail tasks and stop tracking them."""
        for task in self._guardrail_tasks:
            if not task.done():
                task.cancel()
        self._guardrail_tasks.clear()

    async def _cleanup(self) -> None:
        """Clean up all resources and mark session as closed. Repeated calls are no-ops."""
        # `__aiter__` cleans up before re-raising a stored exception, and `__aexit__`/`close()`
        # call this again; only the first call should touch the model.
        if self._closed:
            return
        # Mark as closed first so a repeated or concurrent close() during the awaits below
        # does not close the model a second time.
        self._closed = True

        # Cancel and cleanup guardrail tasks
        self._cleanup_guardrail_tasks()

        # Remove ourselves as a listener
        self._model.remove_listener(self)

        # Close the model connection
        await self._model.close()

    async def _get__updated_model_settings(
        self, new_agent: RealtimeAgent
    ) -> RealtimeSessionModelSettings:
        """Build the session settings update for ``new_agent``.

        The update holds only the agent's instructions, tools and enabled handoffs. Empty
        values become ``""`` or ``[]``.
        """
        updated_settings: RealtimeSessionModelSettings = {}
        instructions, tools, handoffs = await asyncio.gather(
            new_agent.get_system_prompt(self._context_wrapper),
            new_agent.get_all_tools(self._context_wrapper),
            self._get_handoffs(new_agent, self._context_wrapper),
        )
        updated_settings["instructions"] = instructions or ""
        updated_settings["tools"] = tools or []
        updated_settings["handoffs"] = handoffs or []

        return updated_settings

    @classmethod
    async def _get_handoffs(
        cls, agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any]
    ) -> list[Handoff[Any, RealtimeAgent[Any]]]:
        """Return ``agent``'s handoffs that are currently enabled, in declaration order.

        Bare `RealtimeAgent` entries are wrapped with `realtime_handoff`. Each handoff's
        ``is_enabled`` may be a bool, or a callable (sync or async) taking
        ``(context_wrapper, agent)``.
        """
        handoffs: list[Handoff[Any, RealtimeAgent[Any]]] = []
        for handoff_item in agent.handoffs:
            if isinstance(handoff_item, Handoff):
                handoffs.append(handoff_item)
            elif isinstance(handoff_item, RealtimeAgent):
                handoffs.append(realtime_handoff(handoff_item))

        async def _check_handoff_enabled(handoff_obj: Handoff[Any, RealtimeAgent[Any]]) -> bool:
            attr = handoff_obj.is_enabled
            if isinstance(attr, bool):
                return attr
            res = attr(context_wrapper, agent)
            if inspect.isawaitable(res):
                return await res
            return res

        results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
        enabled = [h for h, ok in zip(handoffs, results) if ok]
        return enabled
|