openai-agents 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/__init__.py +5 -1
- agents/_run_impl.py +5 -1
- agents/agent.py +61 -29
- agents/function_schema.py +11 -1
- agents/guardrail.py +5 -1
- agents/lifecycle.py +26 -17
- agents/mcp/server.py +43 -11
- agents/mcp/util.py +5 -6
- agents/memory/__init__.py +3 -0
- agents/memory/session.py +369 -0
- agents/model_settings.py +15 -7
- agents/models/chatcmpl_converter.py +19 -2
- agents/models/chatcmpl_stream_handler.py +1 -1
- agents/models/openai_responses.py +11 -4
- agents/realtime/README.md +3 -0
- agents/realtime/__init__.py +174 -0
- agents/realtime/agent.py +80 -0
- agents/realtime/config.py +128 -0
- agents/realtime/events.py +216 -0
- agents/realtime/items.py +91 -0
- agents/realtime/model.py +69 -0
- agents/realtime/model_events.py +159 -0
- agents/realtime/model_inputs.py +100 -0
- agents/realtime/openai_realtime.py +584 -0
- agents/realtime/runner.py +118 -0
- agents/realtime/session.py +502 -0
- agents/run.py +106 -4
- agents/tool.py +6 -7
- agents/tool_context.py +16 -3
- agents/voice/models/openai_stt.py +1 -1
- agents/voice/pipeline.py +6 -0
- agents/voice/workflow.py +8 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/METADATA +120 -3
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/RECORD +36 -22
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/WHEEL +0 -0
- {openai_agents-0.1.0.dist-info → openai_agents-0.2.0.dist-info}/licenses/LICENSE +0 -0
agents/memory/session.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import sqlite3
|
|
6
|
+
import threading
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..items import TResponseInputItem
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class Session(Protocol):
    """Structural interface implemented by every conversation-session store.

    A session persists the conversation history that belongs to one session id,
    letting agents carry context across runs without the caller having to
    manage memory by hand. Any object exposing these members satisfies the
    protocol; no inheritance is required.
    """

    session_id: str

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return the stored conversation history.

        Args:
            limit: Upper bound on the number of items returned. ``None`` returns
                the entire history; otherwise only the newest ``limit`` items are
                returned, ordered oldest to newest.

        Returns:
            The conversation history as a list of input items.
        """

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append items to the end of the conversation history.

        Args:
            items: The input items to store.
        """

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove the newest item from the session and hand it back.

        Returns:
            The removed item, or ``None`` when the session holds no items.
        """

    async def clear_session(self) -> None:
        """Delete every item stored for this session."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class SessionABC(ABC):
    """Nominal abstract base for conversation-session stores.

    A session keeps the conversation history for one session id so agents can
    reuse context without callers managing memory manually. Subclasses must
    override all four async methods below before they can be instantiated.

    Intended for the SDK's own concrete implementations. Third-party libraries
    should satisfy the ``Session`` protocol rather than inherit from this class.
    """

    session_id: str

    @abstractmethod
    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return the stored conversation history.

        Args:
            limit: Upper bound on the number of items returned. ``None`` returns
                the entire history; otherwise only the newest ``limit`` items are
                returned, ordered oldest to newest.

        Returns:
            The conversation history as a list of input items.
        """

    @abstractmethod
    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append items to the end of the conversation history.

        Args:
            items: The input items to store.
        """

    @abstractmethod
    async def pop_item(self) -> TResponseInputItem | None:
        """Remove the newest item from the session and hand it back.

        Returns:
            The removed item, or ``None`` when the session holds no items.
        """

    @abstractmethod
    async def clear_session(self) -> None:
        """Delete every item stored for this session."""
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class SQLiteSession(SessionABC):
    """SQLite-based implementation of session storage.

    This implementation stores conversation history in a SQLite database.
    By default, uses an in-memory database that is lost when the process ends.
    For persistent storage, provide a file path.

    Every blocking SQLite call runs in a worker thread via ``asyncio.to_thread``.
    An in-memory database is private to the connection that created it, so it is
    served by one shared connection; a file database gets one connection per
    worker thread. Each operation holds ``self._lock`` so operations on this
    session never interleave, and every connection opened is remembered so that
    ``close()`` can release all of them (not only the calling thread's).

    Items are read and popped in insertion order using the AUTOINCREMENT ``id``
    column. ``created_at`` is not used for ordering because its
    ``CURRENT_TIMESTAMP`` default has one-second resolution, so items added
    within the same second would tie.
    """

    def __init__(
        self,
        session_id: str,
        db_path: str | Path = ":memory:",
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initialize the SQLite session.

        Args:
            session_id: Unique identifier for the conversation session.
            db_path: Path to the SQLite database file. Defaults to ':memory:'
                (in-memory database).
            sessions_table: Name of the table to store session metadata. Defaults
                to 'agent_sessions'.
            messages_table: Name of the table to store message data. Defaults to
                'agent_messages'.

        Note:
            Table names are interpolated directly into SQL statements, so they
            must be trusted identifiers, never user-supplied input.
        """
        self.session_id = session_id
        self.db_path = db_path
        self.sessions_table = sessions_table
        self.messages_table = messages_table
        self._local = threading.local()
        self._lock = threading.Lock()
        # Per-thread connections opened for a file database. They are tracked here
        # because close() runs on one thread but must release the connections
        # owned by asyncio.to_thread worker threads as well.
        self._connections: list[sqlite3.Connection] = []

        self._is_memory_db = str(db_path) == ":memory:"
        if self._is_memory_db:
            # An in-memory database only exists inside its connection, so every
            # thread must share this single connection.
            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self._shared_connection.execute("PRAGMA journal_mode=WAL")
            self._init_db_for_connection(self._shared_connection)
        else:
            # The schema persists in the file, so create it once with a
            # short-lived connection that is closed even if creation fails.
            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            try:
                init_conn.execute("PRAGMA journal_mode=WAL")
                self._init_db_for_connection(init_conn)
            finally:
                init_conn.close()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the connection the current thread should use.

        In-memory databases always use the shared connection. File databases
        lazily open one connection per thread and register it for ``close()``.
        Must not be called while ``self._lock`` is held (it acquires the lock
        to register a new connection).
        """
        if self._is_memory_db:
            return self._shared_connection

        conn: sqlite3.Connection | None = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            self._local.connection = conn
            with self._lock:
                self._connections.append(conn)
        return conn

    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
        """Create the sessions table, messages table and index if missing."""
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )

        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.messages_table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
                    ON DELETE CASCADE
            )
            """
        )

        conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
            ON {self.messages_table} (session_id, created_at)
            """
        )

        conn.commit()

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history. Rows whose
            stored payload is not valid JSON are skipped.
        """

        def _get_items_sync() -> list[TResponseInputItem]:
            conn = self._get_connection()
            with self._lock:
                if limit is None:
                    # Whole history, oldest first (id follows insertion order).
                    rows = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id ASC
                        """,
                        (self.session_id,),
                    ).fetchall()
                else:
                    # Newest `limit` rows, then flipped back to oldest first.
                    rows = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT ?
                        """,
                        (self.session_id, limit),
                    ).fetchall()
                    rows.reverse()

            items: list[TResponseInputItem] = []
            for (message_data,) in rows:
                try:
                    items.append(json.loads(message_data))
                except json.JSONDecodeError:
                    # Best effort: skip invalid JSON entries.
                    continue
            return items

        return await asyncio.to_thread(_get_items_sync)

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        The session row is created if needed, the items are inserted, and the
        session's ``updated_at`` is refreshed in a single transaction: either all
        of it is committed or, on error, all of it is rolled back.

        Args:
            items: List of input items to add to the history.
        """
        if not items:
            return

        def _add_items_sync() -> None:
            # Serialize first so an unserializable item fails before any write.
            rows = [(self.session_id, json.dumps(item)) for item in items]
            conn = self._get_connection()
            # `with conn` commits on success and rolls back on any exception.
            with self._lock, conn:
                conn.execute(
                    f"""
                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
                    """,
                    (self.session_id,),
                )
                conn.executemany(
                    f"""
                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
                    """,
                    rows,
                )
                conn.execute(
                    f"""
                    UPDATE {self.sessions_table}
                    SET updated_at = CURRENT_TIMESTAMP
                    WHERE session_id = ?
                    """,
                    (self.session_id,),
                )

        await asyncio.to_thread(_add_items_sync)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If the
            removed row's payload is not valid JSON it is still deleted and None is
            returned.
        """

        def _pop_item_sync() -> TResponseInputItem | None:
            conn = self._get_connection()
            with self._lock, conn:
                # DELETE ... RETURNING removes the newest row and returns its payload
                # in one statement (RETURNING requires SQLite >= 3.35). fetchall()
                # runs the statement to completion before the commit.
                rows = conn.execute(
                    f"""
                    DELETE FROM {self.messages_table}
                    WHERE id = (
                        SELECT id FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT 1
                    )
                    RETURNING message_data
                    """,
                    (self.session_id,),
                ).fetchall()

            if not rows:
                return None
            try:
                return json.loads(rows[0][0])
            except json.JSONDecodeError:
                # Corrupted entry: it has already been deleted; nothing to return.
                return None

        return await asyncio.to_thread(_pop_item_sync)

    async def clear_session(self) -> None:
        """Clear all items for this session and remove its session row."""

        def _clear_session_sync() -> None:
            conn = self._get_connection()
            with self._lock, conn:
                conn.execute(
                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.execute(
                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
                    (self.session_id,),
                )

        await asyncio.to_thread(_clear_session_sync)

    def close(self) -> None:
        """Close every database connection this session has opened.

        For file databases this includes the per-thread connections created on
        ``asyncio.to_thread`` worker threads. Closing happens under ``self._lock``
        so no operation is mid-flight on a connection being closed.
        """
        if self._is_memory_db:
            if hasattr(self, "_shared_connection"):
                self._shared_connection.close()
            return

        with self._lock:
            for conn in self._connections:
                conn.close()
            self._connections.clear()
|
agents/model_settings.py
CHANGED
|
@@ -17,9 +17,9 @@ from typing_extensions import TypeAlias
|
|
|
17
17
|
class _OmitTypeAnnotation:
|
|
18
18
|
@classmethod
|
|
19
19
|
def __get_pydantic_core_schema__(
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
cls,
|
|
21
|
+
_source_type: Any,
|
|
22
|
+
_handler: GetCoreSchemaHandler,
|
|
23
23
|
) -> core_schema.CoreSchema:
|
|
24
24
|
def validate_from_none(value: None) -> _Omit:
|
|
25
25
|
return _Omit()
|
|
@@ -39,12 +39,20 @@ class _OmitTypeAnnotation:
|
|
|
39
39
|
from_none_schema,
|
|
40
40
|
]
|
|
41
41
|
),
|
|
42
|
-
serialization=core_schema.plain_serializer_function_ser_schema(
|
|
43
|
-
lambda instance: None
|
|
44
|
-
),
|
|
42
|
+
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
|
|
45
43
|
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class MCPToolChoice:
    """Tool choice that targets one specific tool on an MCP server.

    The Responses converter turns this into an ``{"type": "mcp", ...}`` tool
    choice; the Chat Completions converter rejects it with ``UserError``.
    """

    # Label identifying the MCP server that exposes the tool.
    server_label: str
    # Name of the tool on that server.
    name: str
|
|
50
|
+
|
|
51
|
+
|
|
46
52
|
Omit = Annotated[_Omit, _OmitTypeAnnotation]
|
|
47
53
|
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
|
|
54
|
+
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
|
|
55
|
+
|
|
48
56
|
|
|
49
57
|
@dataclass
|
|
50
58
|
class ModelSettings:
|
|
@@ -69,7 +77,7 @@ class ModelSettings:
|
|
|
69
77
|
presence_penalty: float | None = None
|
|
70
78
|
"""The presence penalty to use when calling the model."""
|
|
71
79
|
|
|
72
|
-
tool_choice:
|
|
80
|
+
tool_choice: ToolChoice | None = None
|
|
73
81
|
"""The tool choice to use when calling the model."""
|
|
74
82
|
|
|
75
83
|
parallel_tool_calls: bool | None = None
|
|
@@ -19,6 +19,7 @@ from openai.types.chat import (
|
|
|
19
19
|
ChatCompletionToolMessageParam,
|
|
20
20
|
ChatCompletionUserMessageParam,
|
|
21
21
|
)
|
|
22
|
+
from openai.types.chat.chat_completion_content_part_param import File, FileFile
|
|
22
23
|
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
|
23
24
|
from openai.types.chat.completion_create_params import ResponseFormat
|
|
24
25
|
from openai.types.responses import (
|
|
@@ -27,6 +28,7 @@ from openai.types.responses import (
|
|
|
27
28
|
ResponseFunctionToolCall,
|
|
28
29
|
ResponseFunctionToolCallParam,
|
|
29
30
|
ResponseInputContentParam,
|
|
31
|
+
ResponseInputFileParam,
|
|
30
32
|
ResponseInputImageParam,
|
|
31
33
|
ResponseInputTextParam,
|
|
32
34
|
ResponseOutputMessage,
|
|
@@ -42,6 +44,7 @@ from ..agent_output import AgentOutputSchemaBase
|
|
|
42
44
|
from ..exceptions import AgentsException, UserError
|
|
43
45
|
from ..handoffs import Handoff
|
|
44
46
|
from ..items import TResponseInputItem, TResponseOutputItem
|
|
47
|
+
from ..model_settings import MCPToolChoice
|
|
45
48
|
from ..tool import FunctionTool, Tool
|
|
46
49
|
from .fake_id import FAKE_RESPONSES_ID
|
|
47
50
|
|
|
@@ -49,10 +52,12 @@ from .fake_id import FAKE_RESPONSES_ID
|
|
|
49
52
|
class Converter:
|
|
50
53
|
@classmethod
|
|
51
54
|
def convert_tool_choice(
|
|
52
|
-
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
|
55
|
+
cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
|
|
53
56
|
) -> ChatCompletionToolChoiceOptionParam | NotGiven:
|
|
54
57
|
if tool_choice is None:
|
|
55
58
|
return NOT_GIVEN
|
|
59
|
+
elif isinstance(tool_choice, MCPToolChoice):
|
|
60
|
+
raise UserError("MCPToolChoice is not supported for Chat Completions models")
|
|
56
61
|
elif tool_choice == "auto":
|
|
57
62
|
return "auto"
|
|
58
63
|
elif tool_choice == "required":
|
|
@@ -251,7 +256,19 @@ class Converter:
|
|
|
251
256
|
)
|
|
252
257
|
)
|
|
253
258
|
elif isinstance(c, dict) and c.get("type") == "input_file":
|
|
254
|
-
|
|
259
|
+
casted_file_param = cast(ResponseInputFileParam, c)
|
|
260
|
+
if "file_data" not in casted_file_param or not casted_file_param["file_data"]:
|
|
261
|
+
raise UserError(
|
|
262
|
+
f"Only file_data is supported for input_file {casted_file_param}"
|
|
263
|
+
)
|
|
264
|
+
out.append(
|
|
265
|
+
File(
|
|
266
|
+
type="file",
|
|
267
|
+
file=FileFile(
|
|
268
|
+
file_data=casted_file_param["file_data"],
|
|
269
|
+
),
|
|
270
|
+
)
|
|
271
|
+
)
|
|
255
272
|
else:
|
|
256
273
|
raise UserError(f"Unknown content: {c}")
|
|
257
274
|
return out
|
|
@@ -276,7 +276,7 @@ class ChatCmplStreamHandler:
|
|
|
276
276
|
state.function_calls[tc_delta.index].name += (
|
|
277
277
|
tc_function.name if tc_function else ""
|
|
278
278
|
) or ""
|
|
279
|
-
state.function_calls[tc_delta.index].call_id
|
|
279
|
+
state.function_calls[tc_delta.index].call_id = tc_delta.id or ""
|
|
280
280
|
|
|
281
281
|
if state.reasoning_content_index_and_output:
|
|
282
282
|
yield ResponseReasoningSummaryPartDoneEvent(
|
|
@@ -25,6 +25,7 @@ from ..exceptions import UserError
|
|
|
25
25
|
from ..handoffs import Handoff
|
|
26
26
|
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
|
|
27
27
|
from ..logger import logger
|
|
28
|
+
from ..model_settings import MCPToolChoice
|
|
28
29
|
from ..tool import (
|
|
29
30
|
CodeInterpreterTool,
|
|
30
31
|
ComputerTool,
|
|
@@ -303,10 +304,16 @@ class ConvertedTools:
|
|
|
303
304
|
class Converter:
|
|
304
305
|
@classmethod
|
|
305
306
|
def convert_tool_choice(
|
|
306
|
-
cls, tool_choice: Literal["auto", "required", "none"] | str | None
|
|
307
|
+
cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
|
|
307
308
|
) -> response_create_params.ToolChoice | NotGiven:
|
|
308
309
|
if tool_choice is None:
|
|
309
310
|
return NOT_GIVEN
|
|
311
|
+
elif isinstance(tool_choice, MCPToolChoice):
|
|
312
|
+
return {
|
|
313
|
+
"server_label": tool_choice.server_label,
|
|
314
|
+
"type": "mcp",
|
|
315
|
+
"name": tool_choice.name,
|
|
316
|
+
}
|
|
310
317
|
elif tool_choice == "required":
|
|
311
318
|
return "required"
|
|
312
319
|
elif tool_choice == "auto":
|
|
@@ -334,9 +341,9 @@ class Converter:
|
|
|
334
341
|
"type": "code_interpreter",
|
|
335
342
|
}
|
|
336
343
|
elif tool_choice == "mcp":
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
}
|
|
344
|
+
# Note that this is still here for backwards compatibility,
|
|
345
|
+
# but migrating to MCPToolChoice is recommended.
|
|
346
|
+
return {"type": "mcp"} # type: ignore [typeddict-item]
|
|
340
347
|
else:
|
|
341
348
|
return {
|
|
342
349
|
"type": "function",
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from .agent import RealtimeAgent, RealtimeAgentHooks, RealtimeRunHooks
|
|
2
|
+
from .config import (
|
|
3
|
+
RealtimeAudioFormat,
|
|
4
|
+
RealtimeClientMessage,
|
|
5
|
+
RealtimeGuardrailsSettings,
|
|
6
|
+
RealtimeInputAudioTranscriptionConfig,
|
|
7
|
+
RealtimeModelName,
|
|
8
|
+
RealtimeModelTracingConfig,
|
|
9
|
+
RealtimeRunConfig,
|
|
10
|
+
RealtimeSessionModelSettings,
|
|
11
|
+
RealtimeTurnDetectionConfig,
|
|
12
|
+
RealtimeUserInput,
|
|
13
|
+
RealtimeUserInputMessage,
|
|
14
|
+
RealtimeUserInputText,
|
|
15
|
+
)
|
|
16
|
+
from .events import (
|
|
17
|
+
RealtimeAgentEndEvent,
|
|
18
|
+
RealtimeAgentStartEvent,
|
|
19
|
+
RealtimeAudio,
|
|
20
|
+
RealtimeAudioEnd,
|
|
21
|
+
RealtimeAudioInterrupted,
|
|
22
|
+
RealtimeError,
|
|
23
|
+
RealtimeEventInfo,
|
|
24
|
+
RealtimeGuardrailTripped,
|
|
25
|
+
RealtimeHandoffEvent,
|
|
26
|
+
RealtimeHistoryAdded,
|
|
27
|
+
RealtimeHistoryUpdated,
|
|
28
|
+
RealtimeRawModelEvent,
|
|
29
|
+
RealtimeSessionEvent,
|
|
30
|
+
RealtimeToolEnd,
|
|
31
|
+
RealtimeToolStart,
|
|
32
|
+
)
|
|
33
|
+
from .items import (
|
|
34
|
+
AssistantMessageItem,
|
|
35
|
+
AssistantText,
|
|
36
|
+
InputAudio,
|
|
37
|
+
InputText,
|
|
38
|
+
RealtimeItem,
|
|
39
|
+
RealtimeMessageItem,
|
|
40
|
+
RealtimeResponse,
|
|
41
|
+
RealtimeToolCallItem,
|
|
42
|
+
SystemMessageItem,
|
|
43
|
+
UserMessageItem,
|
|
44
|
+
)
|
|
45
|
+
from .model import (
|
|
46
|
+
RealtimeModel,
|
|
47
|
+
RealtimeModelConfig,
|
|
48
|
+
RealtimeModelListener,
|
|
49
|
+
)
|
|
50
|
+
from .model_events import (
|
|
51
|
+
RealtimeConnectionStatus,
|
|
52
|
+
RealtimeModelAudioDoneEvent,
|
|
53
|
+
RealtimeModelAudioEvent,
|
|
54
|
+
RealtimeModelAudioInterruptedEvent,
|
|
55
|
+
RealtimeModelConnectionStatusEvent,
|
|
56
|
+
RealtimeModelErrorEvent,
|
|
57
|
+
RealtimeModelEvent,
|
|
58
|
+
RealtimeModelExceptionEvent,
|
|
59
|
+
RealtimeModelInputAudioTranscriptionCompletedEvent,
|
|
60
|
+
RealtimeModelItemDeletedEvent,
|
|
61
|
+
RealtimeModelItemUpdatedEvent,
|
|
62
|
+
RealtimeModelOtherEvent,
|
|
63
|
+
RealtimeModelToolCallEvent,
|
|
64
|
+
RealtimeModelTranscriptDeltaEvent,
|
|
65
|
+
RealtimeModelTurnEndedEvent,
|
|
66
|
+
RealtimeModelTurnStartedEvent,
|
|
67
|
+
)
|
|
68
|
+
from .model_inputs import (
|
|
69
|
+
RealtimeModelInputTextContent,
|
|
70
|
+
RealtimeModelRawClientMessage,
|
|
71
|
+
RealtimeModelSendAudio,
|
|
72
|
+
RealtimeModelSendEvent,
|
|
73
|
+
RealtimeModelSendInterrupt,
|
|
74
|
+
RealtimeModelSendRawMessage,
|
|
75
|
+
RealtimeModelSendSessionUpdate,
|
|
76
|
+
RealtimeModelSendToolOutput,
|
|
77
|
+
RealtimeModelSendUserInput,
|
|
78
|
+
RealtimeModelUserInput,
|
|
79
|
+
RealtimeModelUserInputMessage,
|
|
80
|
+
)
|
|
81
|
+
from .openai_realtime import (
|
|
82
|
+
DEFAULT_MODEL_SETTINGS,
|
|
83
|
+
OpenAIRealtimeWebSocketModel,
|
|
84
|
+
get_api_key,
|
|
85
|
+
)
|
|
86
|
+
from .runner import RealtimeRunner
|
|
87
|
+
from .session import RealtimeSession
|
|
88
|
+
|
|
89
|
+
# Public API of the realtime package, grouped by the submodule each name is
# imported from.
__all__ = [
    # From .agent and .runner
    "RealtimeAgent",
    "RealtimeAgentHooks",
    "RealtimeRunHooks",
    "RealtimeRunner",
    # From .config
    "RealtimeAudioFormat",
    "RealtimeClientMessage",
    "RealtimeGuardrailsSettings",
    "RealtimeInputAudioTranscriptionConfig",
    "RealtimeModelName",
    "RealtimeModelTracingConfig",
    "RealtimeRunConfig",
    "RealtimeSessionModelSettings",
    "RealtimeTurnDetectionConfig",
    "RealtimeUserInput",
    "RealtimeUserInputMessage",
    "RealtimeUserInputText",
    # From .events
    "RealtimeAgentEndEvent",
    "RealtimeAgentStartEvent",
    "RealtimeAudio",
    "RealtimeAudioEnd",
    "RealtimeAudioInterrupted",
    "RealtimeError",
    "RealtimeEventInfo",
    "RealtimeGuardrailTripped",
    "RealtimeHandoffEvent",
    "RealtimeHistoryAdded",
    "RealtimeHistoryUpdated",
    "RealtimeRawModelEvent",
    "RealtimeSessionEvent",
    "RealtimeToolEnd",
    "RealtimeToolStart",
    # From .items
    "AssistantMessageItem",
    "AssistantText",
    "InputAudio",
    "InputText",
    "RealtimeItem",
    "RealtimeMessageItem",
    "RealtimeResponse",
    "RealtimeToolCallItem",
    "SystemMessageItem",
    "UserMessageItem",
    # From .model
    "RealtimeModel",
    "RealtimeModelConfig",
    "RealtimeModelListener",
    # From .model_events
    "RealtimeConnectionStatus",
    "RealtimeModelAudioDoneEvent",
    "RealtimeModelAudioEvent",
    "RealtimeModelAudioInterruptedEvent",
    "RealtimeModelConnectionStatusEvent",
    "RealtimeModelErrorEvent",
    "RealtimeModelEvent",
    "RealtimeModelExceptionEvent",
    "RealtimeModelInputAudioTranscriptionCompletedEvent",
    "RealtimeModelItemDeletedEvent",
    "RealtimeModelItemUpdatedEvent",
    "RealtimeModelOtherEvent",
    "RealtimeModelToolCallEvent",
    "RealtimeModelTranscriptDeltaEvent",
    "RealtimeModelTurnEndedEvent",
    "RealtimeModelTurnStartedEvent",
    # From .model_inputs
    "RealtimeModelInputTextContent",
    "RealtimeModelRawClientMessage",
    "RealtimeModelSendAudio",
    "RealtimeModelSendEvent",
    "RealtimeModelSendInterrupt",
    "RealtimeModelSendRawMessage",
    "RealtimeModelSendSessionUpdate",
    "RealtimeModelSendToolOutput",
    "RealtimeModelSendUserInput",
    "RealtimeModelUserInput",
    "RealtimeModelUserInputMessage",
    # From .openai_realtime
    "DEFAULT_MODEL_SETTINGS",
    "OpenAIRealtimeWebSocketModel",
    "get_api_key",
    # From .session
    "RealtimeSession",
]
|