openai-agents 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of openai-agents might be problematic.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +62 -30
- agents/agent_output.py +2 -2
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/handoffs.py +32 -14
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +82 -11
- agents/mcp/util.py +16 -9
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +20 -3
- agents/models/chatcmpl_stream_handler.py +134 -43
- agents/models/openai_responses.py +12 -5
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +177 -0
- agents/realtime/agent.py +89 -0
- agents/realtime/config.py +188 -0
- agents/realtime/events.py +216 -0
- agents/realtime/handoffs.py +165 -0
- agents/realtime/items.py +184 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +670 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +535 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/METADATA +121 -4
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/RECORD +39 -24
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.1.dist-info}/licenses/LICENSE +0 -0
agents/realtime/openai_realtime.py

```diff
@@ -0,0 +1,670 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import inspect
+import json
+import os
+from datetime import datetime
+from typing import Any, Callable, Literal
+
+import pydantic
+import websockets
+from openai.types.beta.realtime.conversation_item import (
+    ConversationItem,
+    ConversationItem as OpenAIConversationItem,
+)
+from openai.types.beta.realtime.conversation_item_content import (
+    ConversationItemContent as OpenAIConversationItemContent,
+)
+from openai.types.beta.realtime.conversation_item_create_event import (
+    ConversationItemCreateEvent as OpenAIConversationItemCreateEvent,
+)
+from openai.types.beta.realtime.conversation_item_retrieve_event import (
+    ConversationItemRetrieveEvent as OpenAIConversationItemRetrieveEvent,
+)
+from openai.types.beta.realtime.conversation_item_truncate_event import (
+    ConversationItemTruncateEvent as OpenAIConversationItemTruncateEvent,
+)
+from openai.types.beta.realtime.input_audio_buffer_append_event import (
+    InputAudioBufferAppendEvent as OpenAIInputAudioBufferAppendEvent,
+)
+from openai.types.beta.realtime.input_audio_buffer_commit_event import (
+    InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent,
+)
+from openai.types.beta.realtime.realtime_client_event import (
+    RealtimeClientEvent as OpenAIRealtimeClientEvent,
+)
+from openai.types.beta.realtime.realtime_server_event import (
+    RealtimeServerEvent as OpenAIRealtimeServerEvent,
+)
+from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent
+from openai.types.beta.realtime.response_cancel_event import (
+    ResponseCancelEvent as OpenAIResponseCancelEvent,
+)
+from openai.types.beta.realtime.response_create_event import (
+    ResponseCreateEvent as OpenAIResponseCreateEvent,
+)
+from openai.types.beta.realtime.session_update_event import (
+    Session as OpenAISessionObject,
+    SessionTool as OpenAISessionTool,
+    SessionTracing as OpenAISessionTracing,
+    SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration,
+    SessionUpdateEvent as OpenAISessionUpdateEvent,
+)
+from pydantic import TypeAdapter
+from typing_extensions import assert_never
+from websockets.asyncio.client import ClientConnection
+
+from agents.handoffs import Handoff
+from agents.tool import FunctionTool, Tool
+from agents.util._types import MaybeAwaitable
+
+from ..exceptions import UserError
+from ..logger import logger
+from .config import (
+    RealtimeModelTracingConfig,
+    RealtimeSessionModelSettings,
+)
+from .items import RealtimeMessageItem, RealtimeToolCallItem
+from .model import (
+    RealtimeModel,
+    RealtimeModelConfig,
+    RealtimeModelListener,
+)
+from .model_events import (
+    RealtimeModelAudioDoneEvent,
+    RealtimeModelAudioEvent,
+    RealtimeModelAudioInterruptedEvent,
+    RealtimeModelErrorEvent,
+    RealtimeModelEvent,
+    RealtimeModelExceptionEvent,
+    RealtimeModelInputAudioTranscriptionCompletedEvent,
+    RealtimeModelItemDeletedEvent,
+    RealtimeModelItemUpdatedEvent,
+    RealtimeModelToolCallEvent,
+    RealtimeModelTranscriptDeltaEvent,
+    RealtimeModelTurnEndedEvent,
+    RealtimeModelTurnStartedEvent,
+)
+from .model_inputs import (
+    RealtimeModelSendAudio,
+    RealtimeModelSendEvent,
+    RealtimeModelSendInterrupt,
+    RealtimeModelSendRawMessage,
+    RealtimeModelSendSessionUpdate,
+    RealtimeModelSendToolOutput,
+    RealtimeModelSendUserInput,
+)
+
+DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = {
+    "voice": "ash",
+    "modalities": ["text", "audio"],
+    "input_audio_format": "pcm16",
+    "output_audio_format": "pcm16",
+    "input_audio_transcription": {
+        "model": "gpt-4o-mini-transcribe",
+    },
+    "turn_detection": {"type": "semantic_vad"},
+}
+
+
+async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> str | None:
+    if isinstance(key, str):
+        return key
+    elif callable(key):
+        result = key()
+        if inspect.isawaitable(result):
+            return await result
+        return result
+
+    return os.getenv("OPENAI_API_KEY")
+
+
+class OpenAIRealtimeWebSocketModel(RealtimeModel):
+    """A model that uses OpenAI's WebSocket API."""
+
+    def __init__(self) -> None:
+        self.model = "gpt-4o-realtime-preview"  # Default model
+        self._websocket: ClientConnection | None = None
+        self._websocket_task: asyncio.Task[None] | None = None
+        self._listeners: list[RealtimeModelListener] = []
+        self._current_item_id: str | None = None
+        self._audio_start_time: datetime | None = None
+        self._audio_length_ms: float = 0.0
+        self._ongoing_response: bool = False
+        self._current_audio_content_index: int | None = None
+        self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
+
+    async def connect(self, options: RealtimeModelConfig) -> None:
+        """Establish a connection to the model and keep it alive."""
+        assert self._websocket is None, "Already connected"
+        assert self._websocket_task is None, "Already connected"
+
+        model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {})
+
+        self.model = model_settings.get("model_name", self.model)
+        api_key = await get_api_key(options.get("api_key"))
+
+        if "tracing" in model_settings:
+            self._tracing_config = model_settings["tracing"]
+        else:
+            self._tracing_config = "auto"
+
+        if not api_key:
+            raise UserError("API key is required but was not provided.")
+
+        url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}")
+
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "OpenAI-Beta": "realtime=v1",
+        }
+        self._websocket = await websockets.connect(url, additional_headers=headers)
+        self._websocket_task = asyncio.create_task(self._listen_for_messages())
+        await self._update_session_config(model_settings)
+
+    async def _send_tracing_config(
+        self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
+    ) -> None:
+        """Update tracing configuration via session.update event."""
+        if tracing_config is not None:
+            converted_tracing_config = _ConversionHelper.convert_tracing_config(tracing_config)
+            await self._send_raw_message(
+                OpenAISessionUpdateEvent(
+                    session=OpenAISessionObject(tracing=converted_tracing_config),
+                    type="session.update",
+                )
+            )
+
+    def add_listener(self, listener: RealtimeModelListener) -> None:
+        """Add a listener to the model."""
+        if listener not in self._listeners:
+            self._listeners.append(listener)
+
+    def remove_listener(self, listener: RealtimeModelListener) -> None:
+        """Remove a listener from the model."""
+        if listener in self._listeners:
+            self._listeners.remove(listener)
+
+    async def _emit_event(self, event: RealtimeModelEvent) -> None:
+        """Emit an event to the listeners."""
+        for listener in self._listeners:
+            await listener.on_event(event)
+
+    async def _listen_for_messages(self):
+        assert self._websocket is not None, "Not connected"
+
+        try:
+            async for message in self._websocket:
+                try:
+                    parsed = json.loads(message)
+                    await self._handle_ws_event(parsed)
+                except json.JSONDecodeError as e:
+                    await self._emit_event(
+                        RealtimeModelExceptionEvent(
+                            exception=e, context="Failed to parse WebSocket message as JSON"
+                        )
+                    )
+                except Exception as e:
+                    await self._emit_event(
+                        RealtimeModelExceptionEvent(
+                            exception=e, context="Error handling WebSocket event"
+                        )
+                    )
+
+        except websockets.exceptions.ConnectionClosedOK:
+            # Normal connection closure - no exception event needed
+            logger.info("WebSocket connection closed normally")
+        except websockets.exceptions.ConnectionClosed as e:
+            await self._emit_event(
+                RealtimeModelExceptionEvent(
+                    exception=e, context="WebSocket connection closed unexpectedly"
+                )
+            )
+        except Exception as e:
+            await self._emit_event(
+                RealtimeModelExceptionEvent(
+                    exception=e, context="WebSocket error in message listener"
+                )
+            )
+
+    async def send_event(self, event: RealtimeModelSendEvent) -> None:
+        """Send an event to the model."""
+        if isinstance(event, RealtimeModelSendRawMessage):
+            converted = _ConversionHelper.try_convert_raw_message(event)
+            if converted is not None:
+                await self._send_raw_message(converted)
+            else:
+                logger.error(f"Failed to convert raw message: {event}")
+        elif isinstance(event, RealtimeModelSendUserInput):
+            await self._send_user_input(event)
+        elif isinstance(event, RealtimeModelSendAudio):
+            await self._send_audio(event)
+        elif isinstance(event, RealtimeModelSendToolOutput):
+            await self._send_tool_output(event)
+        elif isinstance(event, RealtimeModelSendInterrupt):
+            await self._send_interrupt(event)
+        elif isinstance(event, RealtimeModelSendSessionUpdate):
+            await self._send_session_update(event)
+        else:
+            assert_never(event)
+            raise ValueError(f"Unknown event type: {type(event)}")
+
+    async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
+        """Send a raw message to the model."""
+        assert self._websocket is not None, "Not connected"
+
+        await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True))
+
+    async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
+        converted = _ConversionHelper.convert_user_input_to_item_create(event)
+        await self._send_raw_message(converted)
+        await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
+
+    async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
+        converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event)
+        await self._send_raw_message(converted)
+        if event.commit:
+            await self._send_raw_message(
+                OpenAIInputAudioBufferCommitEvent(type="input_audio_buffer.commit")
+            )
+
+    async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
+        converted = _ConversionHelper.convert_tool_output(event)
+        await self._send_raw_message(converted)
+
+        tool_item = RealtimeToolCallItem(
+            item_id=event.tool_call.id or "",
+            previous_item_id=event.tool_call.previous_item_id,
+            call_id=event.tool_call.call_id,
+            type="function_call",
+            status="completed",
+            arguments=event.tool_call.arguments,
+            name=event.tool_call.name,
+            output=event.output,
+        )
+        await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item))
+
+        if event.start_response:
+            await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
+
+    async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
+        if not self._current_item_id or not self._audio_start_time:
+            return
+
+        await self._cancel_response()
+
+        elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
+        if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
+            await self._emit_event(RealtimeModelAudioInterruptedEvent())
+            converted = _ConversionHelper.convert_interrupt(
+                self._current_item_id,
+                self._current_audio_content_index or 0,
+                int(elapsed_time_ms),
+            )
+            await self._send_raw_message(converted)
+
+        self._current_item_id = None
+        self._audio_start_time = None
+        self._audio_length_ms = 0.0
+        self._current_audio_content_index = None
+
+    async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
+        """Send a session update to the model."""
+        await self._update_session_config(event.session_settings)
+
+    async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
+        """Handle audio delta events and update audio tracking state."""
+        self._current_audio_content_index = parsed.content_index
+        self._current_item_id = parsed.item_id
+        if self._audio_start_time is None:
+            self._audio_start_time = datetime.now()
+            self._audio_length_ms = 0.0
+
+        audio_bytes = base64.b64decode(parsed.delta)
+        # Calculate audio length in ms using 24KHz pcm16le
+        self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes)
+        await self._emit_event(
+            RealtimeModelAudioEvent(data=audio_bytes, response_id=parsed.response_id)
+        )
+
+    def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
+        """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
+        return len(audio_bytes) / 24 / 2
+
+    async def _handle_output_item(self, item: ConversationItem) -> None:
+        """Handle response output item events (function calls and messages)."""
+        if item.type == "function_call" and item.status == "completed":
+            tool_call = RealtimeToolCallItem(
+                item_id=item.id or "",
+                previous_item_id=None,
+                call_id=item.call_id,
+                type="function_call",
+                # We use the same item for tool call and output, so it will be completed by the
+                # output being added
+                status="in_progress",
+                arguments=item.arguments or "",
+                name=item.name or "",
+                output=None,
+            )
+            await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call))
+            await self._emit_event(
+                RealtimeModelToolCallEvent(
+                    call_id=item.call_id or "",
+                    name=item.name or "",
+                    arguments=item.arguments or "",
+                    id=item.id or "",
+                )
+            )
+        elif item.type == "message":
+            # Handle message items from output_item events (no previous_item_id)
+            message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python(
+                {
+                    "item_id": item.id or "",
+                    "type": item.type,
+                    "role": item.role,
+                    "content": (
+                        [content.model_dump() for content in item.content] if item.content else []
+                    ),
+                    "status": "in_progress",
+                }
+            )
+            await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))
+
+    async def _handle_conversation_item(
+        self, item: ConversationItem, previous_item_id: str | None
+    ) -> None:
+        """Handle conversation item creation/retrieval events."""
+        message_item = _ConversionHelper.conversation_item_to_realtime_message_item(
+            item, previous_item_id
+        )
+        await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item))
+
+    async def close(self) -> None:
+        """Close the session."""
+        if self._websocket:
+            await self._websocket.close()
+            self._websocket = None
+        if self._websocket_task:
+            self._websocket_task.cancel()
+            self._websocket_task = None
+
+    async def _cancel_response(self) -> None:
+        if self._ongoing_response:
+            await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
+            self._ongoing_response = False
+
+    async def _handle_ws_event(self, event: dict[str, Any]):
+        try:
+            if "previous_item_id" in event and event["previous_item_id"] is None:
+                event["previous_item_id"] = ""  # TODO (rm) remove
+            parsed: OpenAIRealtimeServerEvent = TypeAdapter(
+                OpenAIRealtimeServerEvent
+            ).validate_python(event)
+        except pydantic.ValidationError as e:
+            logger.error(f"Failed to validate server event: {event}", exc_info=True)
+            await self._emit_event(
+                RealtimeModelErrorEvent(
+                    error=e,
+                )
+            )
+            return
+        except Exception as e:
+            event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown"
+            logger.error(f"Failed to validate server event: {event}", exc_info=True)
+            await self._emit_event(
+                RealtimeModelExceptionEvent(
+                    exception=e,
+                    context=f"Failed to validate server event: {event_type}",
+                )
+            )
+            return
+
+        if parsed.type == "response.audio.delta":
+            await self._handle_audio_delta(parsed)
+        elif parsed.type == "response.audio.done":
+            await self._emit_event(RealtimeModelAudioDoneEvent())
+        elif parsed.type == "input_audio_buffer.speech_started":
+            await self._send_interrupt(RealtimeModelSendInterrupt())
+        elif parsed.type == "response.created":
+            self._ongoing_response = True
+            await self._emit_event(RealtimeModelTurnStartedEvent())
+        elif parsed.type == "response.done":
+            self._ongoing_response = False
+            await self._emit_event(RealtimeModelTurnEndedEvent())
+        elif parsed.type == "session.created":
+            await self._send_tracing_config(self._tracing_config)
+        elif parsed.type == "error":
+            await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
+        elif parsed.type == "conversation.item.deleted":
+            await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id))
+        elif (
+            parsed.type == "conversation.item.created"
+            or parsed.type == "conversation.item.retrieved"
+        ):
+            previous_item_id = (
+                parsed.previous_item_id if parsed.type == "conversation.item.created" else None
+            )
+            if parsed.item.type == "message":
+                await self._handle_conversation_item(parsed.item, previous_item_id)
+        elif (
+            parsed.type == "conversation.item.input_audio_transcription.completed"
+            or parsed.type == "conversation.item.truncated"
+        ):
+            if self._current_item_id:
+                await self._send_raw_message(
+                    OpenAIConversationItemRetrieveEvent(
+                        type="conversation.item.retrieve",
+                        item_id=self._current_item_id,
+                    )
+                )
+            if parsed.type == "conversation.item.input_audio_transcription.completed":
+                await self._emit_event(
+                    RealtimeModelInputAudioTranscriptionCompletedEvent(
+                        item_id=parsed.item_id, transcript=parsed.transcript
+                    )
+                )
+        elif parsed.type == "response.audio_transcript.delta":
+            await self._emit_event(
+                RealtimeModelTranscriptDeltaEvent(
+                    item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id
+                )
+            )
+        elif (
+            parsed.type == "conversation.item.input_audio_transcription.delta"
+            or parsed.type == "response.text.delta"
+            or parsed.type == "response.function_call_arguments.delta"
+        ):
+            # No support for partials yet
+            pass
+        elif (
+            parsed.type == "response.output_item.added"
+            or parsed.type == "response.output_item.done"
+        ):
+            await self._handle_output_item(parsed.item)
+
+    async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
+        session_config = self._get_session_config(model_settings)
+        await self._send_raw_message(
+            OpenAISessionUpdateEvent(session=session_config, type="session.update")
+        )
+
+    def _get_session_config(
+        self, model_settings: RealtimeSessionModelSettings
+    ) -> OpenAISessionObject:
+        """Get the session config."""
+        return OpenAISessionObject(
+            instructions=model_settings.get("instructions", None),
+            model=(
+                model_settings.get("model_name", self.model)  # type: ignore
+                or DEFAULT_MODEL_SETTINGS.get("model_name")
+            ),
+            voice=model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")),
+            modalities=model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")),
+            input_audio_format=model_settings.get(
+                "input_audio_format",
+                DEFAULT_MODEL_SETTINGS.get("input_audio_format"),  # type: ignore
+            ),
+            output_audio_format=model_settings.get(
+                "output_audio_format",
+                DEFAULT_MODEL_SETTINGS.get("output_audio_format"),  # type: ignore
+            ),
+            input_audio_transcription=model_settings.get(
+                "input_audio_transcription",
+                DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"),  # type: ignore
+            ),
+            turn_detection=model_settings.get(
+                "turn_detection",
+                DEFAULT_MODEL_SETTINGS.get("turn_detection"),  # type: ignore
+            ),
+            tool_choice=model_settings.get(
+                "tool_choice",
+                DEFAULT_MODEL_SETTINGS.get("tool_choice"),  # type: ignore
+            ),
+            tools=self._tools_to_session_tools(
+                tools=model_settings.get("tools", []), handoffs=model_settings.get("handoffs", [])
+            ),
+        )
+
+    def _tools_to_session_tools(
+        self, tools: list[Tool], handoffs: list[Handoff]
+    ) -> list[OpenAISessionTool]:
+        converted_tools: list[OpenAISessionTool] = []
+        for tool in tools:
+            if not isinstance(tool, FunctionTool):
+                raise UserError(f"Tool {tool.name} is unsupported. Must be a function tool.")
+            converted_tools.append(
+                OpenAISessionTool(
+                    name=tool.name,
+                    description=tool.description,
+                    parameters=tool.params_json_schema,
+                    type="function",
+                )
+            )
+
+        for handoff in handoffs:
+            converted_tools.append(
+                OpenAISessionTool(
+                    name=handoff.tool_name,
+                    description=handoff.tool_description,
+                    parameters=handoff.input_json_schema,
+                    type="function",
+                )
+            )
+
+        return converted_tools
+
+
+class _ConversionHelper:
+    @classmethod
+    def conversation_item_to_realtime_message_item(
+        cls, item: ConversationItem, previous_item_id: str | None
+    ) -> RealtimeMessageItem:
+        return TypeAdapter(RealtimeMessageItem).validate_python(
+            {
+                "item_id": item.id or "",
+                "previous_item_id": previous_item_id,
+                "type": item.type,
+                "role": item.role,
+                "content": (
+                    [content.model_dump() for content in item.content] if item.content else []
+                ),
+                "status": "in_progress",
+            },
+        )
+
+    @classmethod
+    def try_convert_raw_message(
+        cls, message: RealtimeModelSendRawMessage
+    ) -> OpenAIRealtimeClientEvent | None:
+        try:
+            data = {}
+            data["type"] = message.message["type"]
+            data.update(message.message.get("other_data", {}))
+            return TypeAdapter(OpenAIRealtimeClientEvent).validate_python(data)
+        except Exception:
+            return None
+
+    @classmethod
+    def convert_tracing_config(
+        cls, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
+    ) -> OpenAISessionTracing | None:
+        if tracing_config is None:
+            return None
+        elif tracing_config == "auto":
+            return "auto"
+        return OpenAISessionTracingConfiguration(
+            group_id=tracing_config.get("group_id"),
+            metadata=tracing_config.get("metadata"),
+            workflow_name=tracing_config.get("workflow_name"),
+        )
+
+    @classmethod
+    def convert_user_input_to_conversation_item(
+        cls, event: RealtimeModelSendUserInput
+    ) -> OpenAIConversationItem:
+        user_input = event.user_input
+
+        if isinstance(user_input, dict):
+            return OpenAIConversationItem(
+                type="message",
+                role="user",
+                content=[
+                    OpenAIConversationItemContent(
+                        type="input_text",
+                        text=item.get("text"),
+                    )
+                    for item in user_input.get("content", [])
+                ],
+            )
+        else:
+            return OpenAIConversationItem(
+                type="message",
+                role="user",
+                content=[OpenAIConversationItemContent(type="input_text", text=user_input)],
+            )
+
+    @classmethod
+    def convert_user_input_to_item_create(
+        cls, event: RealtimeModelSendUserInput
+    ) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemCreateEvent(
+            type="conversation.item.create",
+            item=cls.convert_user_input_to_conversation_item(event),
+        )
+
+    @classmethod
+    def convert_audio_to_input_audio_buffer_append(
+        cls, event: RealtimeModelSendAudio
+    ) -> OpenAIRealtimeClientEvent:
+        base64_audio = base64.b64encode(event.audio).decode("utf-8")
+        return OpenAIInputAudioBufferAppendEvent(
+            type="input_audio_buffer.append",
+            audio=base64_audio,
+        )
+
+    @classmethod
+    def convert_tool_output(cls, event: RealtimeModelSendToolOutput) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemCreateEvent(
+            type="conversation.item.create",
+            item=OpenAIConversationItem(
+                type="function_call_output",
+                output=event.output,
+                call_id=event.tool_call.call_id,
+            ),
+        )
+
+    @classmethod
+    def convert_interrupt(
+        cls,
+        current_item_id: str,
+        current_audio_content_index: int,
+        elapsed_time_ms: int,
+    ) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemTruncateEvent(
+            type="conversation.item.truncate",
+            item_id=current_item_id,
+            content_index=current_audio_content_index,
+            audio_end_ms=elapsed_time_ms,
+        )
```