openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents has been flagged as potentially problematic; see the package's registry page for details.

Files changed (39):
  1. agents/__init__.py +5 -1
  2. agents/_run_impl.py +5 -1
  3. agents/agent.py +62 -30
  4. agents/agent_output.py +2 -2
  5. agents/function_schema.py +11 -1
  6. agents/guardrail.py +5 -1
  7. agents/handoffs.py +32 -14
  8. agents/lifecycle.py +26 -17
  9. agents/mcp/server.py +82 -11
  10. agents/mcp/util.py +16 -9
  11. agents/memory/__init__.py +3 -0
  12. agents/memory/session.py +369 -0
  13. agents/model_settings.py +15 -7
  14. agents/models/chatcmpl_converter.py +20 -3
  15. agents/models/chatcmpl_stream_handler.py +134 -43
  16. agents/models/openai_responses.py +12 -5
  17. agents/realtime/README.md +3 -0
  18. agents/realtime/__init__.py +177 -0
  19. agents/realtime/agent.py +89 -0
  20. agents/realtime/config.py +188 -0
  21. agents/realtime/events.py +216 -0
  22. agents/realtime/handoffs.py +165 -0
  23. agents/realtime/items.py +184 -0
  24. agents/realtime/model.py +69 -0
  25. agents/realtime/model_events.py +159 -0
  26. agents/realtime/model_inputs.py +100 -0
  27. agents/realtime/openai_realtime.py +670 -0
  28. agents/realtime/runner.py +118 -0
  29. agents/realtime/session.py +535 -0
  30. agents/run.py +106 -4
  31. agents/tool.py +6 -7
  32. agents/tool_context.py +16 -3
  33. agents/voice/models/openai_stt.py +1 -1
  34. agents/voice/pipeline.py +6 -0
  35. agents/voice/workflow.py +8 -0
  36. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
  37. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
  38. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
  39. {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,118 @@
1
+ """Minimal realtime session implementation for voice agents."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+
7
+ from ..run_context import RunContextWrapper, TContext
8
+ from .agent import RealtimeAgent
9
+ from .config import (
10
+ RealtimeRunConfig,
11
+ RealtimeSessionModelSettings,
12
+ )
13
+ from .model import (
14
+ RealtimeModel,
15
+ RealtimeModelConfig,
16
+ )
17
+ from .openai_realtime import OpenAIRealtimeWebSocketModel
18
+ from .session import RealtimeSession
19
+
20
+
21
class RealtimeRunner:
    """A `RealtimeRunner` is the equivalent of `Runner` for realtime agents. It automatically
    handles multiple turns by maintaining a persistent connection with the underlying model
    layer.

    The session manages the local history copy, executes tools, runs guardrails and facilitates
    handoffs between agents.

    Since this code runs on your server, it uses WebSockets by default. You can optionally create
    your own custom model layer by implementing the `RealtimeModel` interface.
    """

    def __init__(
        self,
        starting_agent: RealtimeAgent,
        *,
        model: RealtimeModel | None = None,
        config: RealtimeRunConfig | None = None,
    ) -> None:
        """Initialize the realtime runner.

        Args:
            starting_agent: The agent to start the session with.
            model: The model to use. If not provided, will use a default OpenAI realtime model.
            config: Override parameters to use for the entire run.
        """
        self._starting_agent = starting_agent
        self._config = config
        self._model = model or OpenAIRealtimeWebSocketModel()

    async def run(
        self, *, context: TContext | None = None, model_config: RealtimeModelConfig | None = None
    ) -> RealtimeSession:
        """Start and return a realtime session.

        The starting agent's instructions, tools and handoffs are resolved against `context`
        and stored in the session's `model_config["initial_model_settings"]`. The returned
        session is not connected yet; enter it (e.g. with `async with`) to connect.

        Args:
            context: The context object passed to the agent's dynamic instructions, tools and
                handoff checks, and stored on the session.
            model_config: Model configuration. It is copied, not mutated. Its
                `initial_model_settings`, if present, are the base of the computed settings.

        Returns:
            RealtimeSession: A session object that allows bidirectional communication with the
            realtime model.

        Example:
            ```python
            runner = RealtimeRunner(agent)
            async with await runner.run() as session:
                await session.send_message("Hello")
                async for event in session:
                    print(event)
            ```
        """
        model_settings = await self._get_model_settings(
            agent=self._starting_agent,
            disable_tracing=self._config.get("tracing_disabled", False) if self._config else False,
            # Resolve dynamic instructions/tools with the caller's context, not an empty one.
            context=context,
            initial_settings=model_config.get("initial_model_settings") if model_config else None,
            overrides=self._config.get("model_settings") if self._config else None,
        )

        model_config = model_config.copy() if model_config else {}
        model_config["initial_model_settings"] = model_settings

        # Create and return the (not yet connected) session
        session = RealtimeSession(
            model=self._model,
            agent=self._starting_agent,
            context=context,
            model_config=model_config,
            run_config=self._config,
        )

        return session

    async def _get_model_settings(
        self,
        agent: RealtimeAgent,
        disable_tracing: bool,
        context: TContext | None = None,
        initial_settings: RealtimeSessionModelSettings | None = None,
        overrides: RealtimeSessionModelSettings | None = None,
    ) -> RealtimeSessionModelSettings:
        """Build the initial session settings for `agent`.

        Precedence, lowest to highest:
        1. `initial_settings`.
        2. Values derived from the agent: instructions, tools and enabled handoffs.
        3. `overrides`.

        Finally, `tracing` is set to None when `disable_tracing` is true.
        """
        context_wrapper = RunContextWrapper(context)
        model_settings: RealtimeSessionModelSettings = (
            initial_settings.copy() if initial_settings else {}
        )

        instructions, tools, handoffs = await asyncio.gather(
            agent.get_system_prompt(context_wrapper),
            agent.get_all_tools(context_wrapper),
            # Use the same handoff resolution (incl. `is_enabled` checks) the session applies
            # after a handoff, so the starting agent's handoffs are offered from the first turn.
            RealtimeSession._get_handoffs(agent, context_wrapper),
        )

        if instructions is not None:
            model_settings["instructions"] = instructions
        if tools is not None:
            model_settings["tools"] = tools
        if handoffs is not None:
            model_settings["handoffs"] = handoffs

        if overrides:
            model_settings.update(overrides)

        if disable_tracing:
            model_settings["tracing"] = None

        return model_settings
@@ -0,0 +1,535 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import inspect
5
+ from collections.abc import AsyncIterator
6
+ from typing import Any, cast
7
+
8
+ from typing_extensions import assert_never
9
+
10
+ from ..agent import Agent
11
+ from ..exceptions import ModelBehaviorError, UserError
12
+ from ..handoffs import Handoff
13
+ from ..run_context import RunContextWrapper, TContext
14
+ from ..tool import FunctionTool
15
+ from ..tool_context import ToolContext
16
+ from .agent import RealtimeAgent
17
+ from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
18
+ from .events import (
19
+ RealtimeAgentEndEvent,
20
+ RealtimeAgentStartEvent,
21
+ RealtimeAudio,
22
+ RealtimeAudioEnd,
23
+ RealtimeAudioInterrupted,
24
+ RealtimeError,
25
+ RealtimeEventInfo,
26
+ RealtimeGuardrailTripped,
27
+ RealtimeHandoffEvent,
28
+ RealtimeHistoryAdded,
29
+ RealtimeHistoryUpdated,
30
+ RealtimeRawModelEvent,
31
+ RealtimeSessionEvent,
32
+ RealtimeToolEnd,
33
+ RealtimeToolStart,
34
+ )
35
+ from .handoffs import realtime_handoff
36
+ from .items import InputAudio, InputText, RealtimeItem
37
+ from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener
38
+ from .model_events import (
39
+ RealtimeModelEvent,
40
+ RealtimeModelInputAudioTranscriptionCompletedEvent,
41
+ RealtimeModelToolCallEvent,
42
+ )
43
+ from .model_inputs import (
44
+ RealtimeModelSendAudio,
45
+ RealtimeModelSendInterrupt,
46
+ RealtimeModelSendSessionUpdate,
47
+ RealtimeModelSendToolOutput,
48
+ RealtimeModelSendUserInput,
49
+ )
50
+
51
+
52
class RealtimeSession(RealtimeModelListener):
    """A connection to a realtime model. It streams events from the model to you, and allows you to
    send messages and audio to the model.

    Example:
        ```python
        runner = RealtimeRunner(agent)
        async with await runner.run() as session:
            # Send messages
            await session.send_message("Hello")
            await session.send_audio(audio_bytes)

            # Stream events
            async for event in session:
                if event.type == "audio":
                    # Handle audio event
                    pass
        ```
    """

    def __init__(
        self,
        model: RealtimeModel,
        agent: RealtimeAgent,
        context: TContext | None,
        model_config: RealtimeModelConfig | None = None,
        run_config: RealtimeRunConfig | None = None,
    ) -> None:
        """Initialize the session.

        Args:
            model: The model to use.
            agent: The current agent.
            context: The context object.
            model_config: Model configuration.
            run_config: Runtime configuration including guardrails.
        """
        self._model = model
        self._current_agent = agent
        self._context_wrapper = RunContextWrapper(context)
        self._event_info = RealtimeEventInfo(context=self._context_wrapper)
        self._history: list[RealtimeItem] = []
        self._model_config = model_config or {}
        self._run_config = run_config or {}
        # Unbounded queue: puts never block and `put_nowait` never raises QueueFull.
        self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
        self._closed = False
        # Set by an "exception" model event; raised to the consumer by `__aiter__`.
        self._stored_exception: Exception | None = None

        # Guardrails state tracking
        self._interrupted_by_guardrail = False
        self._item_transcripts: dict[str, str] = {}  # item_id -> accumulated transcript
        self._item_guardrail_run_counts: dict[str, int] = {}  # item_id -> run count
        self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
            "debounce_text_length", 100
        )

        self._guardrail_tasks: set[asyncio.Task[Any]] = set()

    async def __aenter__(self) -> RealtimeSession:
        """Start the session by connecting to the model. After this, you will be able to stream
        events from the model and send messages and audio to the model.
        """
        # Add ourselves as a listener
        self._model.add_listener(self)

        # Connect to the model. If this fails, `__aexit__` will not run, so unregister here;
        # otherwise a stale listener stays attached to a model that may be reused.
        try:
            await self._model.connect(self._model_config)
        except BaseException:
            self._model.remove_listener(self)
            raise

        # Emit initial history update
        await self._put_event(
            RealtimeHistoryUpdated(
                history=self._history,
                info=self._event_info,
            )
        )

        return self

    async def enter(self) -> RealtimeSession:
        """Enter the async context manager. We strongly recommend using the async context manager
        pattern instead of this method. If you use this, you need to manually call `close()` when
        you are done.
        """
        return await self.__aenter__()

    async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """End the session."""
        await self.close()

    async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
        """Yield session events until the session is closed.

        If the model reported an exception, resources are cleaned up and that exception is
        raised. Cancellation while waiting for an event ends the iteration instead of
        propagating.
        """
        while not self._closed:
            try:
                # Check if there's a stored exception to raise
                if self._stored_exception is not None:
                    # Clean up resources before raising
                    await self._cleanup()
                    raise self._stored_exception

                event = await self._event_queue.get()
                yield event
            except asyncio.CancelledError:
                break

    async def close(self) -> None:
        """Close the session. Safe to call more than once."""
        await self._cleanup()

    async def send_message(self, message: RealtimeUserInput) -> None:
        """Send a message to the model."""
        await self._model.send_event(RealtimeModelSendUserInput(user_input=message))

    async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
        """Send a raw audio chunk to the model."""
        await self._model.send_event(RealtimeModelSendAudio(audio=audio, commit=commit))

    async def interrupt(self) -> None:
        """Interrupt the model."""
        await self._model.send_event(RealtimeModelSendInterrupt())

    async def on_event(self, event: RealtimeModelEvent) -> None:
        """Handle an event from the model.

        The raw event is always forwarded first as a `RealtimeRawModelEvent`. It is then
        translated into session events, and local history and guardrail state are updated.
        """
        await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info))

        if event.type == "error":
            await self._put_event(RealtimeError(info=self._event_info, error=event.error))
        elif event.type == "function_call":
            await self._handle_tool_call(event)
        elif event.type == "audio":
            await self._put_event(RealtimeAudio(info=self._event_info, audio=event))
        elif event.type == "audio_interrupted":
            await self._put_event(RealtimeAudioInterrupted(info=self._event_info))
        elif event.type == "audio_done":
            await self._put_event(RealtimeAudioEnd(info=self._event_info))
        elif event.type == "input_audio_transcription_completed":
            self._history = RealtimeSession._get_new_history(self._history, event)
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "transcript_delta":
            # Accumulate transcript text for guardrail debouncing per item_id
            item_id = event.item_id
            if item_id not in self._item_transcripts:
                self._item_transcripts[item_id] = ""
                self._item_guardrail_run_counts[item_id] = 0

            self._item_transcripts[item_id] += event.delta

            # Run guardrails each time the transcript crosses the next multiple of the
            # debounce length (on the full accumulated text, not just the new delta).
            current_length = len(self._item_transcripts[item_id])
            threshold = self._debounce_text_length
            next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold

            if current_length >= next_run_threshold:
                self._item_guardrail_run_counts[item_id] += 1
                self._enqueue_guardrail_task(self._item_transcripts[item_id])
        elif event.type == "item_updated":
            is_new = not any(item.item_id == event.item.item_id for item in self._history)
            self._history = self._get_new_history(self._history, event.item)
            if is_new:
                new_item = next(
                    item for item in self._history if item.item_id == event.item.item_id
                )
                await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item))
            else:
                await self._put_event(
                    RealtimeHistoryUpdated(info=self._event_info, history=self._history)
                )
        elif event.type == "item_deleted":
            deleted_id = event.item_id
            self._history = [item for item in self._history if item.item_id != deleted_id]
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "connection_status":
            pass
        elif event.type == "turn_started":
            await self._put_event(
                RealtimeAgentStartEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "turn_ended":
            # Clear guardrail state for next turn
            self._item_transcripts.clear()
            self._item_guardrail_run_counts.clear()
            self._interrupted_by_guardrail = False

            await self._put_event(
                RealtimeAgentEndEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "exception":
            # Store the exception to be raised in __aiter__
            self._stored_exception = event.exception
        elif event.type == "other":
            pass
        else:
            assert_never(event)

    async def _put_event(self, event: RealtimeSessionEvent) -> None:
        """Put an event into the queue."""
        await self._event_queue.put(event)

    async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
        """Execute the function tool or handoff named by a model tool-call event.

        Raises:
            UserError: If a handoff returns something other than a `RealtimeAgent`.
            ModelBehaviorError: If the name matches neither a function tool nor an enabled
                handoff of the current agent.
        """
        tools, handoffs = await asyncio.gather(
            self._current_agent.get_all_tools(self._context_wrapper),
            self._get_handoffs(self._current_agent, self._context_wrapper),
        )
        function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
        handoff_map = {handoff.tool_name: handoff for handoff in handoffs}

        if event.name in function_map:
            func_tool = function_map[event.name]
            await self._put_event(
                RealtimeToolStart(
                    info=self._event_info,
                    tool=func_tool,
                    agent=self._current_agent,
                )
            )

            tool_context = ToolContext(
                context=self._context_wrapper.context,
                usage=self._context_wrapper.usage,
                tool_name=event.name,
                tool_call_id=event.call_id,
            )
            result = await func_tool.on_invoke_tool(tool_context, event.arguments)

            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event, output=str(result), start_response=True
                )
            )

            await self._put_event(
                RealtimeToolEnd(
                    info=self._event_info,
                    tool=func_tool,
                    output=result,
                    agent=self._current_agent,
                )
            )
        elif event.name in handoff_map:
            handoff = handoff_map[event.name]

            # Execute the handoff to get the new agent
            result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
            if not isinstance(result, RealtimeAgent):
                raise UserError(
                    f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
                )

            # Store previous agent for event
            previous_agent = self._current_agent

            # Update current agent
            self._current_agent = result

            # Get updated model settings from new agent
            updated_settings = await self._get__updated_model_settings(self._current_agent)

            # Send handoff event
            await self._put_event(
                RealtimeHandoffEvent(
                    from_agent=previous_agent,
                    to_agent=self._current_agent,
                    info=self._event_info,
                )
            )

            # Send tool output to complete the handoff
            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event,
                    output=f"Handed off to {self._current_agent.name}",
                    start_response=True,
                )
            )

            # Send session update to model
            await self._model.send_event(
                RealtimeModelSendSessionUpdate(session_settings=updated_settings)
            )
        else:
            raise ModelBehaviorError(f"Tool {event.name} not found")

    @classmethod
    def _get_new_history(
        cls,
        old_history: list[RealtimeItem],
        event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem,
    ) -> list[RealtimeItem]:
        """Return a new history list with `event` applied; `old_history` is not mutated.

        - A transcription-completed event copies its transcript into the `input_audio`
          entries of the matching user message and marks that message "completed".
        - An item whose id already exists replaces the existing entry in place.
        - A new item is inserted right after `previous_item_id` when that id is present in
          the history; otherwise it is appended.
        """
        # Merge transcript into placeholder input_audio message.
        if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent):
            new_history: list[RealtimeItem] = []
            for item in old_history:
                if item.item_id == event.item_id and item.type == "message" and item.role == "user":
                    content: list[InputText | InputAudio] = []
                    for entry in item.content:
                        if entry.type == "input_audio":
                            copied_entry = entry.model_copy(update={"transcript": event.transcript})
                            content.append(copied_entry)
                        else:
                            content.append(entry)  # type: ignore
                    new_history.append(
                        item.model_copy(update={"content": content, "status": "completed"})
                    )
                else:
                    new_history.append(item)
            return new_history

        # Otherwise it's just a new item
        # TODO (rm) Add support for audio storage config

        # If the item already exists, update it
        existing_index = next(
            (i for i, item in enumerate(old_history) if item.item_id == event.item_id), None
        )
        if existing_index is not None:
            new_history = old_history.copy()
            new_history[existing_index] = event
            return new_history
        # Otherwise, insert it after the previous_item_id if that is set
        elif event.previous_item_id:
            # Insert the new item after the previous item
            previous_index = next(
                (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id),
                None,
            )
            if previous_index is not None:
                new_history = old_history.copy()
                new_history.insert(previous_index + 1, event)
                return new_history

        # Otherwise, add it to the end
        return old_history + [event]

    async def _run_output_guardrails(self, text: str) -> bool:
        """Run output guardrails on the given text. Returns True if any guardrail was triggered.

        On the first trip of a turn this emits `RealtimeGuardrailTripped`, interrupts the
        model, and sends a user message naming the tripped guardrails. Later calls in the same
        turn return False until `turn_ended` resets the flag.
        """
        output_guardrails = self._run_config.get("output_guardrails", [])
        if not output_guardrails or self._interrupted_by_guardrail:
            return False

        triggered_results = []

        for guardrail in output_guardrails:
            try:
                result = await guardrail.run(
                    # TODO (rm) Remove this cast, it's wrong
                    self._context_wrapper,
                    cast(Agent[Any], self._current_agent),
                    text,
                )
                if result.output.tripwire_triggered:
                    triggered_results.append(result)
            except Exception:
                # Deliberately best-effort: a failing guardrail must not stop the others.
                continue

        if triggered_results:
            # Mark as interrupted to prevent multiple interrupts
            self._interrupted_by_guardrail = True

            # Emit guardrail tripped event
            await self._put_event(
                RealtimeGuardrailTripped(
                    guardrail_results=triggered_results,
                    message=text,
                    info=self._event_info,
                )
            )

            # Interrupt the model
            await self._model.send_event(RealtimeModelSendInterrupt())

            # Send guardrail triggered message
            guardrail_names = [result.guardrail.get_name() for result in triggered_results]
            await self._model.send_event(
                RealtimeModelSendUserInput(
                    user_input=f"guardrail triggered: {', '.join(guardrail_names)}"
                )
            )

            return True

        return False

    def _enqueue_guardrail_task(self, text: str) -> None:
        """Run output guardrails on `text` in a background task so event handling isn't blocked."""
        task = asyncio.create_task(self._run_output_guardrails(text))
        # Keep a strong reference so the task is not garbage-collected mid-run.
        self._guardrail_tasks.add(task)

        # Add callback to remove completed tasks and handle exceptions
        task.add_done_callback(self._on_guardrail_task_done)

    def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
        """Stop tracking a finished guardrail task and surface its failure as an error event."""
        # Remove from tracking set
        self._guardrail_tasks.discard(task)

        if task.cancelled():
            return

        exception = task.exception()
        if exception:
            # Done callbacks are synchronous. The event queue is unbounded, so enqueue
            # directly instead of spawning an untracked task that could be garbage-collected
            # before it runs.
            self._event_queue.put_nowait(
                RealtimeError(
                    info=self._event_info,
                    error={"message": f"Guardrail task failed: {str(exception)}"},
                )
            )

    def _cleanup_guardrail_tasks(self) -> None:
        """Cancel every still-running guardrail task and stop tracking all of them."""
        for task in self._guardrail_tasks:
            if not task.done():
                task.cancel()
        self._guardrail_tasks.clear()

    async def _cleanup(self) -> None:
        """Clean up all resources and mark session as closed.

        This is idempotent. It can run both from `__aiter__` (before re-raising a stored model
        exception) and from `close()`/`__aexit__`, and must not close the model twice.
        """
        if self._closed:
            return
        # Mark closed first so a concurrent or repeated call becomes a no-op.
        self._closed = True

        # Cancel and cleanup guardrail tasks
        self._cleanup_guardrail_tasks()

        # Remove ourselves as a listener
        self._model.remove_listener(self)

        # Close the model connection
        await self._model.close()

    async def _get__updated_model_settings(
        self, new_agent: RealtimeAgent
    ) -> RealtimeSessionModelSettings:
        """Build the session settings to send after handing off to `new_agent`.

        The run-wide `model_settings` and `tracing_disabled` from the run config are carried
        over, matching how the initial session settings are built. The new agent's
        instructions, tools and enabled handoffs always replace any values for those keys.
        """
        updated_settings: RealtimeSessionModelSettings = {}
        run_settings = self._run_config.get("model_settings")
        if run_settings:
            updated_settings.update(run_settings)

        instructions, tools, handoffs = await asyncio.gather(
            new_agent.get_system_prompt(self._context_wrapper),
            new_agent.get_all_tools(self._context_wrapper),
            self._get_handoffs(new_agent, self._context_wrapper),
        )
        updated_settings["instructions"] = instructions or ""
        updated_settings["tools"] = tools or []
        updated_settings["handoffs"] = handoffs or []

        if self._run_config.get("tracing_disabled", False):
            updated_settings["tracing"] = None

        return updated_settings

    @classmethod
    async def _get_handoffs(
        cls, agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any]
    ) -> list[Handoff[Any, RealtimeAgent[Any]]]:
        """Return `agent`'s enabled handoffs.

        Bare `RealtimeAgent` entries are wrapped with `realtime_handoff`. Each `is_enabled`
        (a bool, or a sync or async callable) is evaluated concurrently.
        """
        handoffs: list[Handoff[Any, RealtimeAgent[Any]]] = []
        for handoff_item in agent.handoffs:
            if isinstance(handoff_item, Handoff):
                handoffs.append(handoff_item)
            elif isinstance(handoff_item, RealtimeAgent):
                handoffs.append(realtime_handoff(handoff_item))

        async def _check_handoff_enabled(handoff_obj: Handoff[Any, RealtimeAgent[Any]]) -> bool:
            attr = handoff_obj.is_enabled
            if isinstance(attr, bool):
                return attr
            res = attr(context_wrapper, agent)
            if inspect.isawaitable(res):
                return await res
            return res

        results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
        enabled = [h for h, ok in zip(handoffs, results) if ok]
        return enabled