openai-agents 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

@@ -0,0 +1,502 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncIterator
5
+ from typing import Any, cast
6
+
7
+ from typing_extensions import assert_never
8
+
9
+ from ..agent import Agent
10
+ from ..exceptions import ModelBehaviorError, UserError
11
+ from ..handoffs import Handoff
12
+ from ..run_context import RunContextWrapper, TContext
13
+ from ..tool import FunctionTool
14
+ from ..tool_context import ToolContext
15
+ from .agent import RealtimeAgent
16
+ from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput
17
+ from .events import (
18
+ RealtimeAgentEndEvent,
19
+ RealtimeAgentStartEvent,
20
+ RealtimeAudio,
21
+ RealtimeAudioEnd,
22
+ RealtimeAudioInterrupted,
23
+ RealtimeError,
24
+ RealtimeEventInfo,
25
+ RealtimeGuardrailTripped,
26
+ RealtimeHandoffEvent,
27
+ RealtimeHistoryAdded,
28
+ RealtimeHistoryUpdated,
29
+ RealtimeRawModelEvent,
30
+ RealtimeSessionEvent,
31
+ RealtimeToolEnd,
32
+ RealtimeToolStart,
33
+ )
34
+ from .items import InputAudio, InputText, RealtimeItem
35
+ from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener
36
+ from .model_events import (
37
+ RealtimeModelEvent,
38
+ RealtimeModelInputAudioTranscriptionCompletedEvent,
39
+ RealtimeModelToolCallEvent,
40
+ )
41
+ from .model_inputs import (
42
+ RealtimeModelSendAudio,
43
+ RealtimeModelSendInterrupt,
44
+ RealtimeModelSendSessionUpdate,
45
+ RealtimeModelSendToolOutput,
46
+ RealtimeModelSendUserInput,
47
+ )
48
+
49
+
50
class RealtimeSession(RealtimeModelListener):
    """A connection to a realtime model. It streams events from the model to you, and allows you to
    send messages and audio to the model.

    Example:
        ```python
        runner = RealtimeRunner(agent)
        async with await runner.run() as session:
            # Send messages
            await session.send_message("Hello")
            await session.send_audio(audio_bytes)

            # Stream events
            async for event in session:
                if event.type == "audio":
                    # Handle audio event
                    pass
        ```
    """

    def __init__(
        self,
        model: RealtimeModel,
        agent: RealtimeAgent,
        context: TContext | None,
        model_config: RealtimeModelConfig | None = None,
        run_config: RealtimeRunConfig | None = None,
    ) -> None:
        """Initialize the session.

        Args:
            model: The model to use.
            agent: The current agent.
            context: The context object.
            model_config: Model configuration.
            run_config: Runtime configuration including guardrails.
        """
        self._model = model
        self._current_agent = agent
        self._context_wrapper = RunContextWrapper(context)
        self._event_info = RealtimeEventInfo(context=self._context_wrapper)
        self._history: list[RealtimeItem] = []
        self._model_config = model_config or {}
        self._run_config = run_config or {}
        self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
        self._closed = False
        self._stored_exception: Exception | None = None

        # Guardrails state tracking
        self._interrupted_by_guardrail = False
        self._item_transcripts: dict[str, str] = {}  # item_id -> accumulated transcript
        self._item_guardrail_run_counts: dict[str, int] = {}  # item_id -> run count
        self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
            "debounce_text_length", 100
        )

        # Strong references to in-flight guardrail tasks so they can be cancelled on cleanup.
        self._guardrail_tasks: set[asyncio.Task[Any]] = set()

    async def __aenter__(self) -> RealtimeSession:
        """Start the session by connecting to the model. After this, you will be able to stream
        events from the model and send messages and audio to the model.
        """
        # Add ourselves as a listener
        self._model.add_listener(self)

        # Connect to the model. `__aexit__` is never invoked when `__aenter__` fails, so undo the
        # listener registration here; otherwise a retry would register this session twice.
        try:
            await self._model.connect(self._model_config)
        except BaseException:
            self._model.remove_listener(self)
            raise

        # Emit initial history update
        await self._put_event(
            RealtimeHistoryUpdated(
                history=self._history,
                info=self._event_info,
            )
        )

        return self

    async def enter(self) -> RealtimeSession:
        """Enter the async context manager. We strongly recommend using the async context manager
        pattern instead of this method. If you use this, you need to manually call `close()` when
        you are done.
        """
        return await self.__aenter__()

    async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        """End the session."""
        await self.close()

    async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
        """Iterate over events from the session.

        Yields queued events until the session is closed. If the model reported an exception, the
        session is cleaned up and that exception is raised from the iterator.

        Note: the closed flag is only checked between events, so `close()` does not wake up a
        consumer that is currently waiting for the next event.
        """
        while not self._closed:
            try:
                # Check if there's a stored exception to raise
                if self._stored_exception is not None:
                    # Clean up resources before raising
                    await self._cleanup()
                    raise self._stored_exception

                event = await self._event_queue.get()
                yield event
            except asyncio.CancelledError:
                break

    async def close(self) -> None:
        """Close the session. Safe to call more than once."""
        await self._cleanup()

    async def send_message(self, message: RealtimeUserInput) -> None:
        """Send a message to the model."""
        await self._model.send_event(RealtimeModelSendUserInput(user_input=message))

    async def send_audio(self, audio: bytes, *, commit: bool = False) -> None:
        """Send a raw audio chunk to the model."""
        await self._model.send_event(RealtimeModelSendAudio(audio=audio, commit=commit))

    async def interrupt(self) -> None:
        """Interrupt the model."""
        await self._model.send_event(RealtimeModelSendInterrupt())

    async def on_event(self, event: RealtimeModelEvent) -> None:
        """Handle an event from the model and translate it into session events.

        Every model event is first forwarded as a `RealtimeRawModelEvent`; known event types then
        update session state and/or enqueue a higher-level session event.
        """
        await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info))

        if event.type == "error":
            await self._put_event(RealtimeError(info=self._event_info, error=event.error))
        elif event.type == "function_call":
            await self._handle_tool_call(event)
        elif event.type == "audio":
            await self._put_event(RealtimeAudio(info=self._event_info, audio=event))
        elif event.type == "audio_interrupted":
            await self._put_event(RealtimeAudioInterrupted(info=self._event_info))
        elif event.type == "audio_done":
            await self._put_event(RealtimeAudioEnd(info=self._event_info))
        elif event.type == "input_audio_transcription_completed":
            self._history = RealtimeSession._get_new_history(self._history, event)
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "transcript_delta":
            # Accumulate transcript text for guardrail debouncing per item_id
            item_id = event.item_id
            if item_id not in self._item_transcripts:
                self._item_transcripts[item_id] = ""
                self._item_guardrail_run_counts[item_id] = 0

            self._item_transcripts[item_id] += event.delta

            # Check if we should run guardrails based on debounce threshold: the n-th run for an
            # item happens once its transcript reaches n * debounce_text_length characters.
            current_length = len(self._item_transcripts[item_id])
            threshold = self._debounce_text_length
            next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold

            if current_length >= next_run_threshold:
                self._item_guardrail_run_counts[item_id] += 1
                self._enqueue_guardrail_task(self._item_transcripts[item_id])
        elif event.type == "item_updated":
            # Decide "new vs. updated" before the history is rewritten.
            is_new = not any(item.item_id == event.item.item_id for item in self._history)
            self._history = self._get_new_history(self._history, event.item)
            if is_new:
                new_item = next(
                    item for item in self._history if item.item_id == event.item.item_id
                )
                await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item))
            else:
                await self._put_event(
                    RealtimeHistoryUpdated(info=self._event_info, history=self._history)
                )
        elif event.type == "item_deleted":
            deleted_id = event.item_id
            self._history = [item for item in self._history if item.item_id != deleted_id]
            await self._put_event(
                RealtimeHistoryUpdated(info=self._event_info, history=self._history)
            )
        elif event.type == "connection_status":
            pass
        elif event.type == "turn_started":
            await self._put_event(
                RealtimeAgentStartEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "turn_ended":
            # Clear guardrail state for next turn
            self._item_transcripts.clear()
            self._item_guardrail_run_counts.clear()
            self._interrupted_by_guardrail = False

            await self._put_event(
                RealtimeAgentEndEvent(
                    agent=self._current_agent,
                    info=self._event_info,
                )
            )
        elif event.type == "exception":
            # Store the exception to be raised in __aiter__
            self._stored_exception = event.exception
        elif event.type == "other":
            pass
        else:
            assert_never(event)

    async def _put_event(self, event: RealtimeSessionEvent) -> None:
        """Put an event into the queue."""
        await self._event_queue.put(event)

    async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
        """Handle a tool call event.

        Function tools are invoked and their output is sent back to the model. Handoffs switch the
        current agent and push the new agent's settings to the model.

        Raises:
            UserError: If a handoff does not return a `RealtimeAgent`.
            ModelBehaviorError: If the requested tool is neither a function tool nor a handoff of
                the current agent.
        """
        all_tools = await self._current_agent.get_all_tools(self._context_wrapper)
        function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
        handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)}

        if event.name in function_map:
            func_tool = function_map[event.name]
            await self._put_event(
                RealtimeToolStart(
                    info=self._event_info,
                    tool=func_tool,
                    agent=self._current_agent,
                )
            )

            tool_context = ToolContext(
                context=self._context_wrapper.context,
                usage=self._context_wrapper.usage,
                tool_name=event.name,
                tool_call_id=event.call_id,
            )
            result = await func_tool.on_invoke_tool(tool_context, event.arguments)

            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event, output=str(result), start_response=True
                )
            )

            await self._put_event(
                RealtimeToolEnd(
                    info=self._event_info,
                    tool=func_tool,
                    output=result,
                    agent=self._current_agent,
                )
            )
        elif event.name in handoff_map:
            handoff = handoff_map[event.name]

            # Execute the handoff to get the new agent
            result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
            if not isinstance(result, RealtimeAgent):
                raise UserError(f"Handoff {handoff.name} returned invalid result: {type(result)}")

            # Store previous agent for event
            previous_agent = self._current_agent

            # Update current agent
            self._current_agent = result

            # Get updated model settings from new agent
            updated_settings = await self._get__updated_model_settings(self._current_agent)

            # Send handoff event
            await self._put_event(
                RealtimeHandoffEvent(
                    from_agent=previous_agent,
                    to_agent=self._current_agent,
                    info=self._event_info,
                )
            )

            # Send the session update first so the model already has the new agent's
            # instructions and tools when the tool output below starts the next response.
            await self._model.send_event(
                RealtimeModelSendSessionUpdate(session_settings=updated_settings)
            )

            # Then send tool output to complete the handoff (start_response=True)
            await self._model.send_event(
                RealtimeModelSendToolOutput(
                    tool_call=event,
                    output=f"Handed off to {self._current_agent.name}",
                    start_response=True,
                )
            )
        else:
            raise ModelBehaviorError(f"Tool {event.name} not found")

    @classmethod
    def _get_new_history(
        cls,
        old_history: list[RealtimeItem],
        event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem,
    ) -> list[RealtimeItem]:
        """Return a new history list with `event` applied; `old_history` is not mutated.

        A transcription-completed event fills the transcript into the matching user message.
        A plain item replaces the entry with the same `item_id`, is inserted right after
        `previous_item_id` when that item is present, or is appended at the end.
        """
        # Merge transcript into placeholder input_audio message.
        if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent):
            new_history: list[RealtimeItem] = []
            for item in old_history:
                if item.item_id == event.item_id and item.type == "message" and item.role == "user":
                    content: list[InputText | InputAudio] = []
                    for entry in item.content:
                        if entry.type == "input_audio":
                            copied_entry = entry.model_copy(update={"transcript": event.transcript})
                            content.append(copied_entry)
                        else:
                            content.append(entry)  # type: ignore
                    new_history.append(
                        item.model_copy(update={"content": content, "status": "completed"})
                    )
                else:
                    new_history.append(item)
            return new_history

        # Otherwise it's just a new item
        # TODO (rm) Add support for audio storage config

        # If the item already exists, update it
        existing_index = next(
            (i for i, item in enumerate(old_history) if item.item_id == event.item_id), None
        )
        if existing_index is not None:
            new_history = old_history.copy()
            new_history[existing_index] = event
            return new_history

        # Otherwise, insert it after the previous_item_id if that is set
        if event.previous_item_id:
            previous_index = next(
                (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id),
                None,
            )
            if previous_index is not None:
                new_history = old_history.copy()
                new_history.insert(previous_index + 1, event)
                return new_history

        # Otherwise (no previous item, or it is not in the history), add it to the end
        return old_history + [event]

    async def _run_output_guardrails(self, text: str) -> bool:
        """Run output guardrails on the given text.

        Returns:
            True if this call tripped at least one guardrail and interrupted the model; False if
            nothing tripped, no guardrails are configured, or the current turn was already
            interrupted by a guardrail.
        """
        output_guardrails = self._run_config.get("output_guardrails", [])
        if not output_guardrails or self._interrupted_by_guardrail:
            return False

        triggered_results = []

        for guardrail in output_guardrails:
            try:
                result = await guardrail.run(
                    # TODO (rm) Remove this cast, it's wrong
                    self._context_wrapper,
                    cast(Agent[Any], self._current_agent),
                    text,
                )
                if result.output.tripwire_triggered:
                    triggered_results.append(result)
            except Exception:
                # Deliberate best-effort: continue with other guardrails if one fails
                continue

        if not triggered_results:
            return False

        # Several guardrail tasks can be in flight at once (one per debounce threshold). Another
        # task may have tripped while this one was awaiting its guardrails, so re-check the flag
        # to avoid interrupting the model more than once per turn.
        if self._interrupted_by_guardrail:
            return False

        # Mark as interrupted to prevent multiple interrupts
        self._interrupted_by_guardrail = True

        # Emit guardrail tripped event
        await self._put_event(
            RealtimeGuardrailTripped(
                guardrail_results=triggered_results,
                message=text,
                info=self._event_info,
            )
        )

        # Interrupt the model
        await self._model.send_event(RealtimeModelSendInterrupt())

        # Send guardrail triggered message
        guardrail_names = [result.guardrail.get_name() for result in triggered_results]
        await self._model.send_event(
            RealtimeModelSendUserInput(
                user_input=f"guardrail triggered: {', '.join(guardrail_names)}"
            )
        )

        return True

    def _enqueue_guardrail_task(self, text: str) -> None:
        """Run the guardrails in a separate task to avoid blocking the main loop."""
        task = asyncio.create_task(self._run_output_guardrails(text))
        self._guardrail_tasks.add(task)

        # Add callback to remove completed tasks and handle exceptions
        task.add_done_callback(self._on_guardrail_task_done)

    def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
        """Handle completion of a guardrail task."""
        # Remove from tracking set
        self._guardrail_tasks.discard(task)

        # Check for exceptions and propagate as events
        if task.cancelled():
            return

        exception = task.exception()
        if exception:
            # Report the failure as an event instead of raising. The event queue is unbounded, so
            # it can be filled synchronously here; this avoids spawning an untracked
            # fire-and-forget task just to enqueue one event.
            self._event_queue.put_nowait(
                RealtimeError(
                    info=self._event_info,
                    error={"message": f"Guardrail task failed: {str(exception)}"},
                )
            )

    def _cleanup_guardrail_tasks(self) -> None:
        """Cancel all guardrail tasks that are still running and forget them."""
        for task in self._guardrail_tasks:
            if not task.done():
                task.cancel()
        self._guardrail_tasks.clear()

    async def _cleanup(self) -> None:
        """Clean up all resources and mark session as closed.

        Does nothing if the session is already closed. `__aiter__` cleans up before re-raising a
        stored exception and `__aexit__` then calls `close()` again, so a second call must not
        close the model connection twice.
        """
        if self._closed:
            return

        # Cancel and cleanup guardrail tasks
        self._cleanup_guardrail_tasks()

        # Remove ourselves as a listener
        self._model.remove_listener(self)

        # Close the model connection
        await self._model.close()

        # Mark as closed
        self._closed = True

    async def _get__updated_model_settings(
        self, new_agent: RealtimeAgent
    ) -> RealtimeSessionModelSettings:
        """Build the session settings (instructions and tools) for `new_agent`.

        The system prompt and the tool list are fetched concurrently; missing values fall back to
        an empty string and an empty list respectively.
        """
        updated_settings: RealtimeSessionModelSettings = {}
        instructions, tools = await asyncio.gather(
            new_agent.get_system_prompt(self._context_wrapper),
            new_agent.get_all_tools(self._context_wrapper),
        )
        updated_settings["instructions"] = instructions or ""
        updated_settings["tools"] = tools or []

        return updated_settings