openai-agents 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

@@ -57,11 +57,13 @@ from typing_extensions import assert_never
57
57
  from websockets.asyncio.client import ClientConnection
58
58
 
59
59
  from agents.handoffs import Handoff
60
+ from agents.realtime._default_tracker import ModelAudioTracker
60
61
  from agents.tool import FunctionTool, Tool
61
62
  from agents.util._types import MaybeAwaitable
62
63
 
63
64
  from ..exceptions import UserError
64
65
  from ..logger import logger
66
+ from ..version import __version__
65
67
  from .config import (
66
68
  RealtimeModelTracingConfig,
67
69
  RealtimeSessionModelSettings,
@@ -71,6 +73,8 @@ from .model import (
71
73
  RealtimeModel,
72
74
  RealtimeModelConfig,
73
75
  RealtimeModelListener,
76
+ RealtimePlaybackState,
77
+ RealtimePlaybackTracker,
74
78
  )
75
79
  from .model_events import (
76
80
  RealtimeModelAudioDoneEvent,
@@ -82,6 +86,7 @@ from .model_events import (
82
86
  RealtimeModelInputAudioTranscriptionCompletedEvent,
83
87
  RealtimeModelItemDeletedEvent,
84
88
  RealtimeModelItemUpdatedEvent,
89
+ RealtimeModelRawServerEvent,
85
90
  RealtimeModelToolCallEvent,
86
91
  RealtimeModelTranscriptDeltaEvent,
87
92
  RealtimeModelTurnEndedEvent,
@@ -97,6 +102,8 @@ from .model_inputs import (
97
102
  RealtimeModelSendUserInput,
98
103
  )
99
104
 
105
+ _USER_AGENT = f"Agents/Python {__version__}"
106
+
100
107
  DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = {
101
108
  "voice": "ash",
102
109
  "modalities": ["text", "audio"],
@@ -130,11 +137,11 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
130
137
  self._websocket_task: asyncio.Task[None] | None = None
131
138
  self._listeners: list[RealtimeModelListener] = []
132
139
  self._current_item_id: str | None = None
133
- self._audio_start_time: datetime | None = None
134
- self._audio_length_ms: float = 0.0
140
+ self._audio_state_tracker: ModelAudioTracker = ModelAudioTracker()
135
141
  self._ongoing_response: bool = False
136
- self._current_audio_content_index: int | None = None
137
142
  self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
143
+ self._playback_tracker: RealtimePlaybackTracker | None = None
144
+ self._created_session: OpenAISessionObject | None = None
138
145
 
139
146
  async def connect(self, options: RealtimeModelConfig) -> None:
140
147
  """Establish a connection to the model and keep it alive."""
@@ -143,6 +150,8 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
143
150
 
144
151
  model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {})
145
152
 
153
+ self._playback_tracker = options.get("playback_tracker", RealtimePlaybackTracker())
154
+
146
155
  self.model = model_settings.get("model_name", self.model)
147
156
  api_key = await get_api_key(options.get("api_key"))
148
157
 
@@ -160,7 +169,9 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
160
169
  "Authorization": f"Bearer {api_key}",
161
170
  "OpenAI-Beta": "realtime=v1",
162
171
  }
163
- self._websocket = await websockets.connect(url, additional_headers=headers)
172
+ self._websocket = await websockets.connect(
173
+ url, user_agent_header=_USER_AGENT, additional_headers=headers
174
+ )
164
175
  self._websocket_task = asyncio.create_task(self._listen_for_messages())
165
176
  await self._update_session_config(model_settings)
166
177
 
@@ -289,26 +300,69 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
289
300
  if event.start_response:
290
301
  await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
291
302
 
303
+ def _get_playback_state(self) -> RealtimePlaybackState:
304
+ if self._playback_tracker:
305
+ return self._playback_tracker.get_state()
306
+
307
+ if last_audio_item_id := self._audio_state_tracker.get_last_audio_item():
308
+ item_id, item_content_index = last_audio_item_id
309
+ audio_state = self._audio_state_tracker.get_state(item_id, item_content_index)
310
+ if audio_state:
311
+ elapsed_ms = (
312
+ datetime.now() - audio_state.initial_received_time
313
+ ).total_seconds() * 1000
314
+ return {
315
+ "current_item_id": item_id,
316
+ "current_item_content_index": item_content_index,
317
+ "elapsed_ms": elapsed_ms,
318
+ }
319
+
320
+ return {
321
+ "current_item_id": None,
322
+ "current_item_content_index": None,
323
+ "elapsed_ms": None,
324
+ }
325
+
292
326
  async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
293
- if not self._current_item_id or not self._audio_start_time:
327
+ playback_state = self._get_playback_state()
328
+ current_item_id = playback_state.get("current_item_id")
329
+ current_item_content_index = playback_state.get("current_item_content_index")
330
+ elapsed_ms = playback_state.get("elapsed_ms")
331
+ if current_item_id is None or elapsed_ms is None:
332
+ logger.info(
333
+ "Skipping interrupt. "
334
+ f"Item id: {current_item_id}, "
335
+ f"elapsed ms: {elapsed_ms}, "
336
+ f"content index: {current_item_content_index}"
337
+ )
294
338
  return
295
339
 
296
- await self._cancel_response()
297
-
298
- elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
299
- if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
300
- await self._emit_event(RealtimeModelAudioInterruptedEvent())
340
+ current_item_content_index = current_item_content_index or 0
341
+ if elapsed_ms > 0:
342
+ await self._emit_event(
343
+ RealtimeModelAudioInterruptedEvent(
344
+ item_id=current_item_id,
345
+ content_index=current_item_content_index,
346
+ )
347
+ )
301
348
  converted = _ConversionHelper.convert_interrupt(
302
- self._current_item_id,
303
- self._current_audio_content_index or 0,
304
- int(elapsed_time_ms),
349
+ current_item_id,
350
+ current_item_content_index,
351
+ int(elapsed_ms),
305
352
  )
306
353
  await self._send_raw_message(converted)
307
354
 
308
- self._current_item_id = None
309
- self._audio_start_time = None
310
- self._audio_length_ms = 0.0
311
- self._current_audio_content_index = None
355
+ automatic_response_cancellation_enabled = (
356
+ self._created_session
357
+ and self._created_session.turn_detection
358
+ and self._created_session.turn_detection.interrupt_response
359
+ )
360
+ if not automatic_response_cancellation_enabled:
361
+ await self._cancel_response()
362
+
363
+ self._audio_state_tracker.on_interrupted()
364
+ if self._playback_tracker:
365
+ self._playback_tracker.on_interrupted()
312
366
 
313
367
  async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
314
368
  """Send a session update to the model."""
@@ -316,23 +370,21 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
316
370
 
317
371
  async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
318
372
  """Handle audio delta events and update audio tracking state."""
319
- self._current_audio_content_index = parsed.content_index
320
373
  self._current_item_id = parsed.item_id
321
- if self._audio_start_time is None:
322
- self._audio_start_time = datetime.now()
323
- self._audio_length_ms = 0.0
324
374
 
325
375
  audio_bytes = base64.b64decode(parsed.delta)
326
- # Calculate audio length in ms using 24KHz pcm16le
327
- self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes)
376
+
377
+ self._audio_state_tracker.on_audio_delta(parsed.item_id, parsed.content_index, audio_bytes)
378
+
328
379
  await self._emit_event(
329
- RealtimeModelAudioEvent(data=audio_bytes, response_id=parsed.response_id)
380
+ RealtimeModelAudioEvent(
381
+ data=audio_bytes,
382
+ response_id=parsed.response_id,
383
+ item_id=parsed.item_id,
384
+ content_index=parsed.content_index,
385
+ )
330
386
  )
331
387
 
332
- def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
333
- """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
334
- return len(audio_bytes) / 24 / 2
335
-
336
388
  async def _handle_output_item(self, item: ConversationItem) -> None:
337
389
  """Handle response output item events (function calls and messages)."""
338
390
  if item.type == "function_call" and item.status == "completed":
@@ -396,6 +448,7 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
396
448
  self._ongoing_response = False
397
449
 
398
450
  async def _handle_ws_event(self, event: dict[str, Any]):
451
+ await self._emit_event(RealtimeModelRawServerEvent(data=event))
399
452
  try:
400
453
  if "previous_item_id" in event and event["previous_item_id"] is None:
401
454
  event["previous_item_id"] = "" # TODO (rm) remove
@@ -424,7 +477,12 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
424
477
  if parsed.type == "response.audio.delta":
425
478
  await self._handle_audio_delta(parsed)
426
479
  elif parsed.type == "response.audio.done":
427
- await self._emit_event(RealtimeModelAudioDoneEvent())
480
+ await self._emit_event(
481
+ RealtimeModelAudioDoneEvent(
482
+ item_id=parsed.item_id,
483
+ content_index=parsed.content_index,
484
+ )
485
+ )
428
486
  elif parsed.type == "input_audio_buffer.speech_started":
429
487
  await self._send_interrupt(RealtimeModelSendInterrupt())
430
488
  elif parsed.type == "response.created":
@@ -435,6 +493,9 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
435
493
  await self._emit_event(RealtimeModelTurnEndedEvent())
436
494
  elif parsed.type == "session.created":
437
495
  await self._send_tracing_config(self._tracing_config)
496
+ self._update_created_session(parsed.session) # type: ignore
497
+ elif parsed.type == "session.updated":
498
+ self._update_created_session(parsed.session) # type: ignore
438
499
  elif parsed.type == "error":
439
500
  await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
440
501
  elif parsed.type == "conversation.item.deleted":
@@ -484,6 +545,13 @@ class OpenAIRealtimeWebSocketModel(RealtimeModel):
484
545
  ):
485
546
  await self._handle_output_item(parsed.item)
486
547
 
548
+ def _update_created_session(self, session: OpenAISessionObject) -> None:
549
+ self._created_session = session
550
+ if session.output_audio_format:
551
+ self._audio_state_tracker.set_audio_format(session.output_audio_format)
552
+ if self._playback_tracker:
553
+ self._playback_tracker.set_audio_format(session.output_audio_format)
554
+
487
555
  async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
488
556
  session_config = self._get_session_config(model_settings)
489
557
  await self._send_raw_message(
@@ -107,6 +107,11 @@ class RealtimeSession(RealtimeModelListener):
107
107
 
108
108
  self._guardrail_tasks: set[asyncio.Task[Any]] = set()
109
109
 
110
+ @property
111
+ def model(self) -> RealtimeModel:
112
+ """Access the underlying model for adding listeners or other direct interaction."""
113
+ return self._model
114
+
110
115
  async def __aenter__(self) -> RealtimeSession:
111
116
  """Start the session by connecting to the model. After this, you will be able to stream
112
117
  events from the model and send messages and audio to the model.
@@ -116,7 +121,8 @@ class RealtimeSession(RealtimeModelListener):
116
121
 
117
122
  model_config = self._model_config.copy()
118
123
  model_config["initial_model_settings"] = await self._get_updated_model_settings_from_agent(
119
- self._current_agent
124
+ starting_settings=self._model_config.get("initial_model_settings", None),
125
+ agent=self._current_agent,
120
126
  )
121
127
 
122
128
  # Connect to the model
@@ -182,11 +188,26 @@ class RealtimeSession(RealtimeModelListener):
182
188
  elif event.type == "function_call":
183
189
  await self._handle_tool_call(event)
184
190
  elif event.type == "audio":
185
- await self._put_event(RealtimeAudio(info=self._event_info, audio=event))
191
+ await self._put_event(
192
+ RealtimeAudio(
193
+ info=self._event_info,
194
+ audio=event,
195
+ item_id=event.item_id,
196
+ content_index=event.content_index,
197
+ )
198
+ )
186
199
  elif event.type == "audio_interrupted":
187
- await self._put_event(RealtimeAudioInterrupted(info=self._event_info))
200
+ await self._put_event(
201
+ RealtimeAudioInterrupted(
202
+ info=self._event_info, item_id=event.item_id, content_index=event.content_index
203
+ )
204
+ )
188
205
  elif event.type == "audio_done":
189
- await self._put_event(RealtimeAudioEnd(info=self._event_info))
206
+ await self._put_event(
207
+ RealtimeAudioEnd(
208
+ info=self._event_info, item_id=event.item_id, content_index=event.content_index
209
+ )
210
+ )
190
211
  elif event.type == "input_audio_transcription_completed":
191
212
  self._history = RealtimeSession._get_new_history(self._history, event)
192
213
  await self._put_event(
@@ -253,6 +274,8 @@ class RealtimeSession(RealtimeModelListener):
253
274
  self._stored_exception = event.exception
254
275
  elif event.type == "other":
255
276
  pass
277
+ elif event.type == "raw_server_event":
278
+ pass
256
279
  else:
257
280
  assert_never(event)
258
281
 
@@ -325,7 +348,8 @@ class RealtimeSession(RealtimeModelListener):
325
348
 
326
349
  # Get updated model settings from new agent
327
350
  updated_settings = await self._get_updated_model_settings_from_agent(
328
- self._current_agent
351
+ starting_settings=None,
352
+ agent=self._current_agent,
329
353
  )
330
354
 
331
355
  # Send handoff event
@@ -504,9 +528,16 @@ class RealtimeSession(RealtimeModelListener):
504
528
 
505
529
  async def _get_updated_model_settings_from_agent(
506
530
  self,
531
+ starting_settings: RealtimeSessionModelSettings | None,
507
532
  agent: RealtimeAgent,
508
533
  ) -> RealtimeSessionModelSettings:
509
- updated_settings: RealtimeSessionModelSettings = {}
534
+ # Start with run config model settings as base
535
+ run_config_settings = self._run_config.get("model_settings", {})
536
+ updated_settings: RealtimeSessionModelSettings = run_config_settings.copy()
537
+ # Apply starting settings (from model config) next
538
+ if starting_settings:
539
+ updated_settings.update(starting_settings)
540
+
510
541
  instructions, tools, handoffs = await asyncio.gather(
511
542
  agent.get_system_prompt(self._context_wrapper),
512
543
  agent.get_all_tools(self._context_wrapper),
@@ -516,10 +547,6 @@ class RealtimeSession(RealtimeModelListener):
516
547
  updated_settings["tools"] = tools or []
517
548
  updated_settings["handoffs"] = handoffs or []
518
549
 
519
- # Override with initial settings
520
- initial_settings = self._model_config.get("initial_model_settings", {})
521
- updated_settings.update(initial_settings)
522
-
523
550
  disable_tracing = self._run_config.get("tracing_disabled", False)
524
551
  if disable_tracing:
525
552
  updated_settings["tracing"] = None
agents/tool.py CHANGED
@@ -24,6 +24,7 @@ from .function_schema import DocstringStyle, function_schema
24
24
  from .items import RunItem
25
25
  from .logger import logger
26
26
  from .run_context import RunContextWrapper
27
+ from .strict_schema import ensure_strict_json_schema
27
28
  from .tool_context import ToolContext
28
29
  from .tracing import SpanError
29
30
  from .util import _error_tracing
@@ -92,6 +93,10 @@ class FunctionTool:
92
93
  and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
93
94
  based on your context/state."""
94
95
 
96
+ def __post_init__(self):
97
+ if self.strict_json_schema:
98
+ self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
99
+
95
100
 
96
101
  @dataclass
97
102
  class FileSearchTool:
agents/tracing/create.py CHANGED
@@ -50,8 +50,7 @@ def trace(
50
50
  group_id: Optional grouping identifier to link multiple traces from the same conversation
51
51
  or process. For instance, you might use a chat thread ID.
52
52
  metadata: Optional dictionary of additional metadata to attach to the trace.
53
- disabled: If True, we will return a Trace but the Trace will not be recorded. This will
54
- not be checked if there's an existing trace and `even_if_trace_running` is True.
53
+ disabled: If True, we will return a Trace but the Trace will not be recorded.
55
54
 
56
55
  Returns:
57
56
  The newly created trace object.
@@ -22,7 +22,7 @@ class ConsoleSpanExporter(TracingExporter):
22
22
  def export(self, items: list[Trace | Span[Any]]) -> None:
23
23
  for item in items:
24
24
  if isinstance(item, Trace):
25
- print(f"[Exporter] Export trace_id={item.trace_id}, name={item.name}, ")
25
+ print(f"[Exporter] Export trace_id={item.trace_id}, name={item.name}")
26
26
  else:
27
27
  print(f"[Exporter] Export span: {item.export()}")
28
28
 
@@ -121,7 +121,7 @@ class BackendSpanExporter(TracingExporter):
121
121
  logger.debug(f"Exported {len(items)} items")
122
122
  return
123
123
 
124
- # If the response is a client error (4xx), we wont retry
124
+ # If the response is a client error (4xx), we won't retry
125
125
  if 400 <= response.status_code < 500:
126
126
  logger.error(
127
127
  f"[non-fatal] Tracing client error {response.status_code}: {response.text}"
@@ -183,7 +183,7 @@ class BatchTraceProcessor(TracingProcessor):
183
183
  self._shutdown_event = threading.Event()
184
184
 
185
185
  # The queue size threshold at which we export immediately.
186
- self._export_trigger_size = int(max_queue_size * export_trigger_ratio)
186
+ self._export_trigger_size = max(1, int(max_queue_size * export_trigger_ratio))
187
187
 
188
188
  # Track when we next *must* perform a scheduled export
189
189
  self._next_export_time = time.time() + self._schedule_delay
@@ -269,8 +269,7 @@ class BatchTraceProcessor(TracingProcessor):
269
269
 
270
270
  def _export_batches(self, force: bool = False):
271
271
  """Drains the queue and exports in batches. If force=True, export everything.
272
- Otherwise, export up to `max_batch_size` repeatedly until the queue is empty or below a
273
- certain threshold.
272
+ Otherwise, export up to `max_batch_size` repeatedly until the queue is completely empty.
274
273
  """
275
274
  while True:
276
275
  items_to_export: list[Span[Any] | Trace] = []
agents/tracing/traces.py CHANGED
@@ -10,7 +10,7 @@ from .processor_interface import TracingProcessor
10
10
  from .scope import Scope
11
11
 
12
12
 
13
- class Trace:
13
+ class Trace(abc.ABC):
14
14
  """
15
15
  A trace is the root level object that tracing creates. It represents a logical "workflow".
16
16
  """
agents/usage.py CHANGED
@@ -1,6 +1,7 @@
1
- from dataclasses import dataclass, field
1
+ from dataclasses import field
2
2
 
3
3
  from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
4
+ from pydantic.dataclasses import dataclass
4
5
 
5
6
 
6
7
  @dataclass