openai-agents 0.2.8__py3-none-any.whl → 0.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agents/__init__.py +105 -4
- agents/_debug.py +15 -4
- agents/_run_impl.py +1203 -96
- agents/agent.py +164 -19
- agents/apply_diff.py +329 -0
- agents/editor.py +47 -0
- agents/exceptions.py +35 -0
- agents/extensions/experimental/__init__.py +6 -0
- agents/extensions/experimental/codex/__init__.py +92 -0
- agents/extensions/experimental/codex/codex.py +89 -0
- agents/extensions/experimental/codex/codex_options.py +35 -0
- agents/extensions/experimental/codex/codex_tool.py +1142 -0
- agents/extensions/experimental/codex/events.py +162 -0
- agents/extensions/experimental/codex/exec.py +263 -0
- agents/extensions/experimental/codex/items.py +245 -0
- agents/extensions/experimental/codex/output_schema_file.py +50 -0
- agents/extensions/experimental/codex/payloads.py +31 -0
- agents/extensions/experimental/codex/thread.py +214 -0
- agents/extensions/experimental/codex/thread_options.py +54 -0
- agents/extensions/experimental/codex/turn_options.py +36 -0
- agents/extensions/handoff_filters.py +13 -1
- agents/extensions/memory/__init__.py +120 -0
- agents/extensions/memory/advanced_sqlite_session.py +1285 -0
- agents/extensions/memory/async_sqlite_session.py +239 -0
- agents/extensions/memory/dapr_session.py +423 -0
- agents/extensions/memory/encrypt_session.py +185 -0
- agents/extensions/memory/redis_session.py +261 -0
- agents/extensions/memory/sqlalchemy_session.py +334 -0
- agents/extensions/models/litellm_model.py +449 -36
- agents/extensions/models/litellm_provider.py +3 -1
- agents/function_schema.py +47 -5
- agents/guardrail.py +16 -2
- agents/{handoffs.py → handoffs/__init__.py} +89 -47
- agents/handoffs/history.py +268 -0
- agents/items.py +237 -11
- agents/lifecycle.py +75 -14
- agents/mcp/server.py +280 -37
- agents/mcp/util.py +24 -3
- agents/memory/__init__.py +22 -2
- agents/memory/openai_conversations_session.py +91 -0
- agents/memory/openai_responses_compaction_session.py +249 -0
- agents/memory/session.py +19 -261
- agents/memory/sqlite_session.py +275 -0
- agents/memory/util.py +20 -0
- agents/model_settings.py +14 -3
- agents/models/__init__.py +13 -0
- agents/models/chatcmpl_converter.py +303 -50
- agents/models/chatcmpl_helpers.py +63 -0
- agents/models/chatcmpl_stream_handler.py +290 -68
- agents/models/default_models.py +58 -0
- agents/models/interface.py +4 -0
- agents/models/openai_chatcompletions.py +103 -49
- agents/models/openai_provider.py +10 -4
- agents/models/openai_responses.py +162 -46
- agents/realtime/__init__.py +4 -0
- agents/realtime/_util.py +14 -3
- agents/realtime/agent.py +7 -0
- agents/realtime/audio_formats.py +53 -0
- agents/realtime/config.py +78 -10
- agents/realtime/events.py +18 -0
- agents/realtime/handoffs.py +2 -2
- agents/realtime/items.py +17 -1
- agents/realtime/model.py +13 -0
- agents/realtime/model_events.py +12 -0
- agents/realtime/model_inputs.py +18 -1
- agents/realtime/openai_realtime.py +696 -150
- agents/realtime/session.py +243 -23
- agents/repl.py +7 -3
- agents/result.py +197 -38
- agents/run.py +949 -168
- agents/run_context.py +13 -2
- agents/stream_events.py +1 -0
- agents/strict_schema.py +14 -0
- agents/tool.py +413 -15
- agents/tool_context.py +22 -1
- agents/tool_guardrails.py +279 -0
- agents/tracing/__init__.py +2 -0
- agents/tracing/config.py +9 -0
- agents/tracing/create.py +4 -0
- agents/tracing/processor_interface.py +84 -11
- agents/tracing/processors.py +65 -54
- agents/tracing/provider.py +64 -7
- agents/tracing/spans.py +105 -0
- agents/tracing/traces.py +116 -16
- agents/usage.py +134 -12
- agents/util/_json.py +19 -1
- agents/util/_transforms.py +12 -2
- agents/voice/input.py +5 -4
- agents/voice/models/openai_stt.py +17 -9
- agents/voice/pipeline.py +2 -0
- agents/voice/pipeline_config.py +4 -0
- {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/METADATA +44 -19
- openai_agents-0.6.8.dist-info/RECORD +134 -0
- {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/WHEEL +1 -1
- openai_agents-0.2.8.dist-info/RECORD +0 -103
- {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1285 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
from contextlib import closing
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Union, cast
|
|
10
|
+
|
|
11
|
+
from agents.result import RunResult
|
|
12
|
+
from agents.usage import Usage
|
|
13
|
+
|
|
14
|
+
from ...items import TResponseInputItem
|
|
15
|
+
from ...memory import SQLiteSession
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AdvancedSQLiteSession(SQLiteSession):
|
|
19
|
+
"""Enhanced SQLite session with conversation branching and usage analytics."""
|
|
20
|
+
|
|
21
|
+
def __init__(
    self,
    *,
    session_id: str,
    db_path: str | Path = ":memory:",
    create_tables: bool = False,
    logger: logging.Logger | None = None,
    **kwargs,
):
    """Set up an AdvancedSQLiteSession on top of the plain SQLiteSession.

    Args:
        session_id: Identifier of the conversation session.
        db_path: SQLite database file path; the default `:memory:` keeps all
            data in an in-memory database.
        create_tables: When True, create the branching/usage tables
            (`message_structure`, `turn_usage`) and their indexes.
        logger: Logger used for diagnostics; falls back to this module's logger.
        **kwargs: Extra keyword arguments forwarded to `SQLiteSession.__init__`.
    """
    super().__init__(session_id, db_path, **kwargs)

    # A fresh session always starts on the "main" branch.
    self._current_branch_id = "main"
    self._logger = logger or logging.getLogger(__name__)

    if create_tables:
        self._init_structure_tables()
|
|
44
|
+
|
|
45
|
+
def _init_structure_tables(self):
    """Create the branching and usage-analytics tables and their indexes.

    Creates ``message_structure`` and ``turn_usage``:

    - ``message_structure`` holds one row per stored message, with its
      branch, session-wide sequence number and turn numbers.
    - ``turn_usage`` holds token usage per user turn and branch.

    Every statement uses ``IF NOT EXISTS``, so calling this more than once is
    harmless.

    The foreign keys reference the tables this session was configured with,
    not hard-coded default names. A custom ``messages_table`` (or
    ``sessions_table``) forwarded to ``SQLiteSession`` via ``**kwargs`` is
    therefore honoured. SQLite only enforces foreign keys when
    ``PRAGMA foreign_keys`` is enabled.
    """
    conn = self._get_connection()
    messages_table = self.messages_table
    # NOTE(review): assumes SQLiteSession exposes ``sessions_table`` next to
    # ``messages_table``; fall back to the historical default name otherwise.
    sessions_table = getattr(self, "sessions_table", "agent_sessions")

    # Message structure with branch support
    conn.execute(f"""
        CREATE TABLE IF NOT EXISTS message_structure (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            message_id INTEGER NOT NULL,
            branch_id TEXT NOT NULL DEFAULT 'main',
            message_type TEXT NOT NULL,
            sequence_number INTEGER NOT NULL,
            user_turn_number INTEGER,
            branch_turn_number INTEGER,
            tool_name TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (session_id) REFERENCES {sessions_table}(session_id) ON DELETE CASCADE,
            FOREIGN KEY (message_id) REFERENCES {messages_table}(id) ON DELETE CASCADE
        )
    """)

    # Turn-level usage tracking with branch support and full JSON details
    conn.execute(f"""
        CREATE TABLE IF NOT EXISTS turn_usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            branch_id TEXT NOT NULL DEFAULT 'main',
            user_turn_number INTEGER NOT NULL,
            requests INTEGER DEFAULT 0,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            input_tokens_details JSON,
            output_tokens_details JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (session_id) REFERENCES {sessions_table}(session_id) ON DELETE CASCADE,
            UNIQUE(session_id, branch_id, user_turn_number)
        )
    """)

    # Indexes backing the per-session / per-branch / per-turn lookups.
    index_statements = (
        "CREATE INDEX IF NOT EXISTS idx_structure_session_seq "
        "ON message_structure(session_id, sequence_number)",
        "CREATE INDEX IF NOT EXISTS idx_structure_branch "
        "ON message_structure(session_id, branch_id)",
        "CREATE INDEX IF NOT EXISTS idx_structure_turn "
        "ON message_structure(session_id, branch_id, user_turn_number)",
        "CREATE INDEX IF NOT EXISTS idx_structure_branch_seq "
        "ON message_structure(session_id, branch_id, sequence_number)",
        "CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn "
        "ON turn_usage(session_id, branch_id, user_turn_number)",
    )
    for statement in index_statements:
        conn.execute(statement)

    conn.commit()
|
|
113
|
+
|
|
114
|
+
async def add_items(self, items: list[TResponseInputItem]) -> None:
    """Persist *items* and record their branch/turn structure.

    The items are always written through the parent SQLiteSession first.
    Structure metadata is recorded only when at least one item was supplied.

    Args:
        items: Conversation items to append to the session.
    """
    await super().add_items(items)

    if not items:
        return

    # Attach branch, sequence and turn numbers to the rows just inserted.
    await self._add_structure_metadata(items)
|
|
126
|
+
|
|
127
|
+
async def get_items(
    self,
    limit: int | None = None,
    branch_id: str | None = None,
) -> list[TResponseInputItem]:
    """Get items from the current or a specified branch, oldest first.

    Args:
        limit: Maximum number of items to return. When given, the most recent
            ``limit`` items of the branch are returned (still in chronological
            order). If None, returns all items.
        branch_id: Branch to read from. If None, uses the current branch.

    Returns:
        The decoded conversation items of the branch, ordered by their
        session-wide sequence number. Rows whose stored payload is not valid
        JSON are skipped.
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    def _get_items_sync() -> list[TResponseInputItem]:
        """Fetch and decode the branch's messages (runs in a worker thread)."""
        conn = self._get_connection()

        # Use the configured messages table, consistent with _add_structure_metadata.
        query = (
            f"SELECT m.message_data FROM {self.messages_table} m "
            "JOIN message_structure s ON m.id = s.message_id "
            "WHERE m.session_id = ? AND s.branch_id = ? "
        )
        params: tuple[Any, ...]
        if limit is None:
            query += "ORDER BY s.sequence_number ASC"
            params = (self.session_id, branch_id)
        else:
            # Take the newest `limit` rows; chronological order is restored below.
            query += "ORDER BY s.sequence_number DESC LIMIT ?"
            params = (self.session_id, branch_id, limit)

        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
        # NOTE(review): for file-backed databases this enters a brand-new Lock,
        # which excludes nothing; presumably safe only because the parent hands
        # out per-thread connections for file databases - confirm.
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                cursor.execute(query, params)
                rows = cursor.fetchall()

        if limit is not None:
            rows.reverse()

        items: list[TResponseInputItem] = []
        for (message_data,) in rows:
            try:
                items.append(json.loads(message_data))
            except json.JSONDecodeError:
                # Skip corrupt rows instead of failing the whole read.
                continue
        return items

    return await asyncio.to_thread(_get_items_sync)
|
|
235
|
+
|
|
236
|
+
async def store_run_usage(self, result: RunResult) -> None:
    """Record a finished run's token usage against the current user turn.

    Meant to be called after `Runner.run()` completes. Only the per-turn row
    is written; session-wide totals are aggregated from turn rows on demand.
    Any failure is logged and swallowed, so usage tracking never breaks a run.

    Args:
        result: The run result whose ``context_wrapper.usage`` is recorded.
    """
    try:
        usage = result.context_wrapper.usage
        if usage is None:
            return
        # Usage is attributed to the latest user turn on the active branch.
        turn_number = self._get_current_turn_number()
        await self._update_turn_usage_internal(turn_number, usage)
    except Exception as e:
        self._logger.error(f"Failed to store usage for session {self.session_id}: {e}")
|
|
253
|
+
|
|
254
|
+
def _get_next_turn_number(self, branch_id: str) -> int:
    """Return the user turn number the next user message on a branch would get.

    Args:
        branch_id: Branch whose turn counter is inspected.

    Returns:
        One more than the highest ``user_turn_number`` recorded for the branch
        in this session, i.e. 1 when the branch has no rows yet.
    """
    sql = """
        SELECT COALESCE(MAX(user_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(sql, (self.session_id, branch_id))
        row = cur.fetchone()
        highest = row[0] if row else 0
        return highest + 1
|
|
276
|
+
|
|
277
|
+
def _get_next_branch_turn_number(self, branch_id: str) -> int:
    """Return the branch turn number the next user message on a branch would get.

    Args:
        branch_id: Branch whose branch-turn counter is inspected.

    Returns:
        One more than the highest ``branch_turn_number`` recorded for the
        branch in this session, i.e. 1 when the branch has no rows yet.
    """
    sql = """
        SELECT COALESCE(MAX(branch_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(sql, (self.session_id, branch_id))
        row = cur.fetchone()
        highest = row[0] if row else 0
        return highest + 1
|
|
299
|
+
|
|
300
|
+
def _get_current_turn_number(self) -> int:
    """Return the latest user turn number on the active branch.

    Returns:
        The highest ``user_turn_number`` stored for this session on the
        current branch, or 0 when that branch has no rows.
    """
    sql = """
        SELECT COALESCE(MAX(user_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(sql, (self.session_id, self._current_branch_id))
        row = cur.fetchone()
        return row[0] if row else 0
|
|
318
|
+
|
|
319
|
+
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
    """Record branch, ordering and turn metadata for freshly stored items.

    For each item a ``message_structure`` row is inserted. The row holds:

    - the message id,
    - the active branch,
    - the classified message type,
    - a session-wide sequence number,
    - user/branch turn numbers,
    - an optional tool name.

    Every user message in the batch opens a new turn. Non-user items inherit
    the turn of the most recent user message, or the branch's latest turn
    when no user message has appeared yet in the batch. Both turn counters
    start from the maxima stored for the active branch.

    Failures are logged, followed by a best-effort orphan cleanup. They are
    never re-raised, because the structure data is supplementary.

    Args:
        items: The items that were just added to the session.
    """

    def _insert_structure_rows() -> None:
        """Write one ``message_structure`` row per item (runs in a worker thread)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cur:
                # Assumes the newest len(items) ids of this session are exactly the
                # rows super().add_items() just inserted; fetched newest-first, then
                # flipped back into insertion order.
                cur.execute(
                    f"SELECT id FROM {self.messages_table} "
                    "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
                    (self.session_id, len(items)),
                )
                new_ids = [row[0] for row in cur.fetchall()][::-1]

                # Sequence numbers are allocated across the whole session (all branches).
                cur.execute(
                    """
                    SELECT COALESCE(MAX(sequence_number), 0)
                    FROM message_structure
                    WHERE session_id = ?
                    """,
                    (self.session_id,),
                )
                last_seq = cur.fetchone()[0]

                # Both turn counters are read from the active branch's rows.
                cur.execute(
                    """
                    SELECT
                        COALESCE(MAX(user_turn_number), 0) as max_global_turn,
                        COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    (self.session_id, self._current_branch_id),
                )
                counters = cur.fetchone()
                base_turn = counters[0] if counters else 0
                base_branch_turn = counters[1] if counters else 0

            rows = []
            users_seen = 0
            for offset, (item, msg_id) in enumerate(zip(items, new_ids), start=1):
                if self._is_user_message(item):
                    users_seen += 1
                rows.append(
                    (
                        self.session_id,
                        msg_id,
                        self._current_branch_id,
                        self._classify_message_type(item),
                        last_seq + offset,
                        base_turn + users_seen,
                        base_branch_turn + users_seen,
                        self._extract_tool_name(item),
                    )
                )

            with closing(conn.cursor()) as cur:
                cur.executemany(
                    """
                    INSERT INTO message_structure
                    (session_id, message_id, branch_id, message_type, sequence_number,
                     user_turn_number, branch_turn_number, tool_name)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    rows,
                )
            conn.commit()

    try:
        await asyncio.to_thread(_insert_structure_rows)
    except Exception as e:
        self._logger.error(
            f"Failed to add structure metadata for session {self.session_id}: {e}"
        )
        # Best effort: delete every message of this session that has no structure
        # row, so the base table and the branch-aware reads stay consistent.
        try:
            await self._cleanup_orphaned_messages()
        except Exception as cleanup_error:
            self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
        # Don't re-raise - structure metadata is supplementary
|
|
430
|
+
|
|
431
|
+
async def _cleanup_orphaned_messages(self) -> int:
    """Delete this session's messages that have no ``message_structure`` row.

    Such messages appear when ``_add_structure_metadata`` fails after
    ``super().add_items()`` has already stored them. Removing them keeps the
    messages table consistent with the branch-aware reads.

    This uses a single ``DELETE ... WHERE NOT EXISTS`` statement rather than
    an ``IN (?, ?, ...)`` list. That avoids SQLite's limit on the number of
    bound parameters when many rows are orphaned. It also targets the
    session's configured messages table instead of a hard-coded name.

    Returns:
        The number of message rows deleted (0 when nothing was orphaned).
    """

    def _cleanup_sync() -> int:
        """Synchronous helper to cleanup orphaned messages."""
        conn = self._get_connection()
        messages_table = self.messages_table
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"""
                    DELETE FROM {messages_table}
                    WHERE session_id = ?
                    AND NOT EXISTS (
                        SELECT 1 FROM message_structure ms
                        WHERE ms.message_id = {messages_table}.id
                    )
                    """,
                    (self.session_id,),
                )
                deleted_count = cursor.rowcount
            conn.commit()

        if deleted_count > 0:
            self._logger.info("Cleaned up %d orphaned messages", deleted_count)
        return max(deleted_count, 0)

    return await asyncio.to_thread(_cleanup_sync)
|
|
473
|
+
|
|
474
|
+
def _classify_message_type(self, item: TResponseInputItem) -> str:
|
|
475
|
+
"""Classify the type of a message item.
|
|
476
|
+
|
|
477
|
+
Args:
|
|
478
|
+
item: The message item to classify.
|
|
479
|
+
|
|
480
|
+
Returns:
|
|
481
|
+
String representing the message type (user, assistant, etc.).
|
|
482
|
+
"""
|
|
483
|
+
if isinstance(item, dict):
|
|
484
|
+
if item.get("role") == "user":
|
|
485
|
+
return "user"
|
|
486
|
+
elif item.get("role") == "assistant":
|
|
487
|
+
return "assistant"
|
|
488
|
+
elif item.get("type"):
|
|
489
|
+
return str(item.get("type"))
|
|
490
|
+
return "other"
|
|
491
|
+
|
|
492
|
+
def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
|
|
493
|
+
"""Extract tool name if this is a tool call/output.
|
|
494
|
+
|
|
495
|
+
Args:
|
|
496
|
+
item: The message item to extract tool name from.
|
|
497
|
+
|
|
498
|
+
Returns:
|
|
499
|
+
Tool name if item is a tool call, None otherwise.
|
|
500
|
+
"""
|
|
501
|
+
if isinstance(item, dict):
|
|
502
|
+
item_type = item.get("type")
|
|
503
|
+
|
|
504
|
+
# For MCP tools, try to extract from server_label if available
|
|
505
|
+
if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
|
|
506
|
+
server_label = item.get("server_label")
|
|
507
|
+
tool_name = item.get("name")
|
|
508
|
+
if tool_name and server_label:
|
|
509
|
+
return f"{server_label}.{tool_name}"
|
|
510
|
+
elif server_label:
|
|
511
|
+
return str(server_label)
|
|
512
|
+
elif tool_name:
|
|
513
|
+
return str(tool_name)
|
|
514
|
+
|
|
515
|
+
# For tool types without a 'name' field, derive from the type
|
|
516
|
+
elif item_type in {
|
|
517
|
+
"computer_call",
|
|
518
|
+
"file_search_call",
|
|
519
|
+
"web_search_call",
|
|
520
|
+
"code_interpreter_call",
|
|
521
|
+
}:
|
|
522
|
+
return item_type
|
|
523
|
+
|
|
524
|
+
# Most other tool calls have a 'name' field
|
|
525
|
+
elif "name" in item:
|
|
526
|
+
name = item.get("name")
|
|
527
|
+
return str(name) if name is not None else None
|
|
528
|
+
|
|
529
|
+
return None
|
|
530
|
+
|
|
531
|
+
def _is_user_message(self, item: TResponseInputItem) -> bool:
|
|
532
|
+
"""Check if this is a user message.
|
|
533
|
+
|
|
534
|
+
Args:
|
|
535
|
+
item: The message item to check.
|
|
536
|
+
|
|
537
|
+
Returns:
|
|
538
|
+
True if the item is a user message, False otherwise.
|
|
539
|
+
"""
|
|
540
|
+
return isinstance(item, dict) and item.get("role") == "user"
|
|
541
|
+
|
|
542
|
+
async def create_branch_from_turn(
    self, turn_number: int, branch_name: str | None = None
) -> str:
    """Create a new branch starting from a specific user message turn.

    Messages before the branch point are copied into the new branch, and the
    session then switches to it.

    Args:
        turn_number: The branch turn number (on the current branch) of the
            user message to branch from.
        branch_name: Optional name for the branch. When None, the name is
            generated as ``branch_from_turn_<turn>_<unix timestamp>``.

    Returns:
        The branch_id of the newly created branch.

    Raises:
        ValueError: If the turn doesn't exist or doesn't contain a user message
            on the current branch.
    """
    import time

    def _validate_turn() -> str:
        """Ensure the turn holds a user message; return a short text preview of it."""
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                f"""
                SELECT am.message_data
                FROM message_structure ms
                JOIN {self.messages_table} am ON ms.message_id = am.id
                WHERE ms.session_id = ? AND ms.branch_id = ?
                AND ms.branch_turn_number = ? AND ms.message_type = 'user'
                """,
                (self.session_id, self._current_branch_id, turn_number),
            )
            result = cursor.fetchone()

        if not result:
            raise ValueError(
                f"Turn {turn_number} does not contain a user message "
                f"in branch '{self._current_branch_id}'"
            )

        try:
            content = json.loads(result[0]).get("content", "")
        except Exception:
            return "Unable to parse content"

        if isinstance(content, list):
            # Structured content (list of parts): preview the text parts only,
            # instead of slicing the list and concatenating it with a string.
            content = " ".join(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        elif not isinstance(content, str):
            content = str(content)

        return content[:50] + "..." if len(content) > 50 else content

    turn_content = await asyncio.to_thread(_validate_turn)

    # Generate branch name if not provided
    if branch_name is None:
        timestamp = int(time.time())
        branch_name = f"branch_from_turn_{turn_number}_{timestamp}"

    # Copy messages before the branch point to the new branch
    await self._copy_messages_to_new_branch(branch_name, turn_number)

    # Switch to new branch
    old_branch = self._current_branch_id
    self._current_branch_id = branch_name

    self._logger.debug(
        f"Created branch '{branch_name}' from turn {turn_number} ('{turn_content}') in '{old_branch}'"  # noqa: E501
    )
    return branch_name
|
|
607
|
+
|
|
608
|
+
async def create_branch_from_content(
    self, search_term: str, branch_name: str | None = None
) -> str:
    """Create a branch from the first user turn whose text matches a search term.

    Args:
        search_term: Text to look for in user messages.
        branch_name: Optional name for the branch (auto-generated if None).

    Returns:
        The branch_id of the newly created branch.

    Raises:
        ValueError: If no user turn matches the search term.
    """
    matches = await self.find_turns_by_content(search_term)
    if not matches:
        raise ValueError(f"No user turns found containing '{search_term}'")

    # Branch from the first match returned (presumably the earliest turn,
    # assuming find_turns_by_content returns turns in ascending order).
    first_turn = matches[0]["turn"]
    return await self.create_branch_from_turn(first_turn, branch_name)
|
|
630
|
+
|
|
631
|
+
async def switch_to_branch(self, branch_id: str) -> None:
    """Make *branch_id* the active branch of this session.

    A branch counts as existing when it has at least one ``message_structure``
    row for this session.

    Args:
        branch_id: The branch to activate.

    Raises:
        ValueError: If the branch has no rows in this session.
    """

    def _branch_exists() -> bool:
        """Return True when the branch has structure rows (runs in a worker thread)."""
        with closing(self._get_connection().cursor()) as cur:
            cur.execute(
                """
                SELECT COUNT(*) FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                """,
                (self.session_id, branch_id),
            )
            (row_count,) = cur.fetchone()
        return row_count > 0

    if not await asyncio.to_thread(_branch_exists):
        raise ValueError(f"Branch '{branch_id}' does not exist")

    previous_branch = self._current_branch_id
    self._current_branch_id = branch_id
    self._logger.info(f"Switched from branch '{previous_branch}' to '{branch_id}'")
|
|
663
|
+
|
|
664
|
+
async def delete_branch(self, branch_id: str, force: bool = False) -> None:
    """Remove a branch's structure rows and its per-turn usage rows.

    The branch id is stripped of surrounding whitespace before use.

    Args:
        branch_id: The branch to delete.
        force: If True, allow deleting the active branch; the session first
            switches to 'main'.

    Raises:
        ValueError: If the id is empty, names 'main', names the active branch
            without ``force``, or the branch has no rows in this session.
    """
    branch_id = (branch_id or "").strip()
    if not branch_id:
        raise ValueError("Branch ID cannot be empty")

    if branch_id == "main":
        raise ValueError("Cannot delete the 'main' branch")

    deleting_active = branch_id == self._current_branch_id
    if deleting_active and not force:
        raise ValueError(
            f"Cannot delete current branch '{branch_id}'. Use force=True or switch branches first"  # noqa: E501
        )
    if deleting_active:
        # force=True: step off the branch before removing it.
        await self.switch_to_branch("main")

    def _delete_branch_rows() -> tuple[int, int]:
        """Delete usage then structure rows; return both row counts (worker thread)."""
        conn = self._get_connection()
        key = (self.session_id, branch_id)
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cur:
                cur.execute(
                    """
                    SELECT COUNT(*) FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    key,
                )
                if cur.fetchone()[0] == 0:
                    raise ValueError(f"Branch '{branch_id}' does not exist")

                # Usage rows go first, then the message structure rows.
                cur.execute(
                    """
                    DELETE FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    key,
                )
                usage_rows = cur.rowcount

                cur.execute(
                    """
                    DELETE FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    key,
                )
                structure_rows = cur.rowcount

                conn.commit()

        return usage_rows, structure_rows

    usage_deleted, structure_deleted = await asyncio.to_thread(_delete_branch_rows)

    self._logger.info(
        f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
    )
|
|
743
|
+
|
|
744
|
+
async def list_branches(self) -> list[dict[str, Any]]:
    """Summarize every branch recorded for this session.

    Branches are discovered from ``message_structure`` rows, so a branch is
    listed only once it has at least one row there. Results are ordered by
    each branch's earliest ``created_at`` value.

    Returns:
        One dict per branch with the keys:
            - 'branch_id': Branch identifier
            - 'message_count': Number of message_structure rows in the branch
            - 'user_turns': Number of those rows whose message_type is 'user'
            - 'is_current': Whether the branch is the session's current branch
            - 'created_at': Earliest created_at among the branch's rows
    """
    aggregate_sql = """
        SELECT
            ms.branch_id,
            COUNT(*) as message_count,
            COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
            MIN(ms.created_at) as created_at
        FROM message_structure ms
        WHERE ms.session_id = ?
        GROUP BY ms.branch_id
        ORDER BY created_at
    """

    def _summarize_branches():
        """Run the aggregate query (in a worker thread) and shape each row."""
        connection = self._get_connection()
        with closing(connection.cursor()) as cur:
            cur.execute(aggregate_sql, (self.session_id,))
            grouped_rows = cur.fetchall()

        return [
            {
                "branch_id": name,
                "message_count": total_rows,
                "user_turns": user_rows,
                "is_current": name == self._current_branch_id,
                "created_at": first_seen,
            }
            for name, total_rows, user_rows, first_seen in grouped_rows
        ]

    return await asyncio.to_thread(_summarize_branches)
|
|
791
|
+
|
|
792
|
+
async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
    """Copy messages before the branch point to the new branch.

    Copies the current branch's ``message_structure`` rows whose
    ``branch_turn_number`` is below ``from_turn_number``. The copies reuse the
    original ``message_id`` values (the message payloads are shared, not
    duplicated) and keep their user/branch turn numbers. They get fresh
    sequence numbers placed after the session's current maximum, in the same
    relative order as the source rows.

    Args:
        new_branch_id: The ID of the new branch to copy messages to.
        from_turn_number: The turn number to copy messages up to (exclusive).
    """
    # Resolve the source branch now, as the query methods do, so a branch
    # switch happening while the worker thread is pending cannot change it.
    source_branch_id = self._current_branch_id

    def _copy_sync() -> None:
        """Synchronous helper to copy messages to new branch."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                # Get all messages before the branch point
                cursor.execute(
                    """
                    SELECT
                        ms.message_id,
                        ms.message_type,
                        ms.sequence_number,
                        ms.user_turn_number,
                        ms.branch_turn_number,
                        ms.tool_name
                    FROM message_structure ms
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.branch_turn_number < ?
                    ORDER BY ms.sequence_number
                    """,
                    (self.session_id, source_branch_id, from_turn_number),
                )
                messages_to_copy = cursor.fetchall()

                try:
                    if messages_to_copy:
                        # Get the max sequence number for the new inserts
                        cursor.execute(
                            """
                            SELECT COALESCE(MAX(sequence_number), 0)
                            FROM message_structure
                            WHERE session_id = ?
                            """,
                            (self.session_id,),
                        )
                        seq_start = cursor.fetchone()[0]

                        # Same message_id (sharing the actual message data), new
                        # branch_id and sequence number, same turn numbers.
                        new_structure_data = [
                            (
                                self.session_id,
                                msg_id,
                                new_branch_id,
                                msg_type,
                                seq_start + offset,
                                user_turn,
                                branch_turn,
                                tool_name,
                            )
                            for offset, (
                                msg_id,
                                msg_type,
                                _old_sequence,
                                user_turn,
                                branch_turn,
                                tool_name,
                            ) in enumerate(messages_to_copy, start=1)
                        ]

                        cursor.executemany(
                            """
                            INSERT INTO message_structure
                            (session_id, message_id, branch_id, message_type, sequence_number,
                             user_turn_number, branch_turn_number, tool_name)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                            """,
                            new_structure_data,
                        )

                    conn.commit()
                except Exception:
                    # Do not leave a partially copied branch in an open
                    # transaction for some later commit to persist.
                    conn.rollback()
                    raise

    await asyncio.to_thread(_copy_sync)
|
|
875
|
+
|
|
876
|
+
async def get_conversation_turns(self, branch_id: str | None = None) -> list[dict[str, Any]]:
|
|
877
|
+
"""Get user turns with content for easy browsing and branching decisions.
|
|
878
|
+
|
|
879
|
+
Args:
|
|
880
|
+
branch_id: Branch to get turns from (current branch if None).
|
|
881
|
+
|
|
882
|
+
Returns:
|
|
883
|
+
List of dicts with turn info containing:
|
|
884
|
+
- 'turn': Branch turn number
|
|
885
|
+
- 'content': User message content (truncated)
|
|
886
|
+
- 'full_content': Full user message content
|
|
887
|
+
- 'timestamp': When the turn was created
|
|
888
|
+
- 'can_branch': Always True (all user messages can branch)
|
|
889
|
+
"""
|
|
890
|
+
if branch_id is None:
|
|
891
|
+
branch_id = self._current_branch_id
|
|
892
|
+
|
|
893
|
+
def _get_turns_sync():
|
|
894
|
+
"""Synchronous helper to get conversation turns."""
|
|
895
|
+
conn = self._get_connection()
|
|
896
|
+
with closing(conn.cursor()) as cursor:
|
|
897
|
+
cursor.execute(
|
|
898
|
+
"""
|
|
899
|
+
SELECT
|
|
900
|
+
ms.branch_turn_number,
|
|
901
|
+
am.message_data,
|
|
902
|
+
ms.created_at
|
|
903
|
+
FROM message_structure ms
|
|
904
|
+
JOIN agent_messages am ON ms.message_id = am.id
|
|
905
|
+
WHERE ms.session_id = ? AND ms.branch_id = ?
|
|
906
|
+
AND ms.message_type = 'user'
|
|
907
|
+
ORDER BY ms.branch_turn_number
|
|
908
|
+
""",
|
|
909
|
+
(self.session_id, branch_id),
|
|
910
|
+
)
|
|
911
|
+
|
|
912
|
+
turns = []
|
|
913
|
+
for row in cursor.fetchall():
|
|
914
|
+
turn_num, message_data, created_at = row
|
|
915
|
+
try:
|
|
916
|
+
content = json.loads(message_data).get("content", "")
|
|
917
|
+
turns.append(
|
|
918
|
+
{
|
|
919
|
+
"turn": turn_num,
|
|
920
|
+
"content": content[:100] + "..." if len(content) > 100 else content,
|
|
921
|
+
"full_content": content,
|
|
922
|
+
"timestamp": created_at,
|
|
923
|
+
"can_branch": True,
|
|
924
|
+
}
|
|
925
|
+
)
|
|
926
|
+
except (json.JSONDecodeError, AttributeError):
|
|
927
|
+
continue
|
|
928
|
+
|
|
929
|
+
return turns
|
|
930
|
+
|
|
931
|
+
return await asyncio.to_thread(_get_turns_sync)
|
|
932
|
+
|
|
933
|
+
async def find_turns_by_content(
|
|
934
|
+
self, search_term: str, branch_id: str | None = None
|
|
935
|
+
) -> list[dict[str, Any]]:
|
|
936
|
+
"""Find user turns containing specific content.
|
|
937
|
+
|
|
938
|
+
Args:
|
|
939
|
+
search_term: Text to search for in user messages.
|
|
940
|
+
branch_id: Branch to search in (current branch if None).
|
|
941
|
+
|
|
942
|
+
Returns:
|
|
943
|
+
List of matching turns with same format as get_conversation_turns().
|
|
944
|
+
"""
|
|
945
|
+
if branch_id is None:
|
|
946
|
+
branch_id = self._current_branch_id
|
|
947
|
+
|
|
948
|
+
def _search_sync():
|
|
949
|
+
"""Synchronous helper to search turns by content."""
|
|
950
|
+
conn = self._get_connection()
|
|
951
|
+
with closing(conn.cursor()) as cursor:
|
|
952
|
+
cursor.execute(
|
|
953
|
+
"""
|
|
954
|
+
SELECT
|
|
955
|
+
ms.branch_turn_number,
|
|
956
|
+
am.message_data,
|
|
957
|
+
ms.created_at
|
|
958
|
+
FROM message_structure ms
|
|
959
|
+
JOIN agent_messages am ON ms.message_id = am.id
|
|
960
|
+
WHERE ms.session_id = ? AND ms.branch_id = ?
|
|
961
|
+
AND ms.message_type = 'user'
|
|
962
|
+
AND am.message_data LIKE ?
|
|
963
|
+
ORDER BY ms.branch_turn_number
|
|
964
|
+
""",
|
|
965
|
+
(self.session_id, branch_id, f"%{search_term}%"),
|
|
966
|
+
)
|
|
967
|
+
|
|
968
|
+
matches = []
|
|
969
|
+
for row in cursor.fetchall():
|
|
970
|
+
turn_num, message_data, created_at = row
|
|
971
|
+
try:
|
|
972
|
+
content = json.loads(message_data).get("content", "")
|
|
973
|
+
matches.append(
|
|
974
|
+
{
|
|
975
|
+
"turn": turn_num,
|
|
976
|
+
"content": content,
|
|
977
|
+
"full_content": content,
|
|
978
|
+
"timestamp": created_at,
|
|
979
|
+
"can_branch": True,
|
|
980
|
+
}
|
|
981
|
+
)
|
|
982
|
+
except (json.JSONDecodeError, AttributeError):
|
|
983
|
+
continue
|
|
984
|
+
|
|
985
|
+
return matches
|
|
986
|
+
|
|
987
|
+
return await asyncio.to_thread(_search_sync)
|
|
988
|
+
|
|
989
|
+
async def get_conversation_by_turns(
|
|
990
|
+
self, branch_id: str | None = None
|
|
991
|
+
) -> dict[int, list[dict[str, str | None]]]:
|
|
992
|
+
"""Get conversation grouped by user turns for specified branch.
|
|
993
|
+
|
|
994
|
+
Args:
|
|
995
|
+
branch_id: Branch to get conversation from (current branch if None).
|
|
996
|
+
|
|
997
|
+
Returns:
|
|
998
|
+
Dictionary mapping turn numbers to lists of message metadata.
|
|
999
|
+
"""
|
|
1000
|
+
if branch_id is None:
|
|
1001
|
+
branch_id = self._current_branch_id
|
|
1002
|
+
|
|
1003
|
+
def _get_conversation_sync():
|
|
1004
|
+
"""Synchronous helper to get conversation by turns."""
|
|
1005
|
+
conn = self._get_connection()
|
|
1006
|
+
with closing(conn.cursor()) as cursor:
|
|
1007
|
+
cursor.execute(
|
|
1008
|
+
"""
|
|
1009
|
+
SELECT user_turn_number, message_type, tool_name
|
|
1010
|
+
FROM message_structure
|
|
1011
|
+
WHERE session_id = ? AND branch_id = ?
|
|
1012
|
+
ORDER BY sequence_number
|
|
1013
|
+
""",
|
|
1014
|
+
(self.session_id, branch_id),
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
turns: dict[int, list[dict[str, str | None]]] = {}
|
|
1018
|
+
for row in cursor.fetchall():
|
|
1019
|
+
turn_num, msg_type, tool_name = row
|
|
1020
|
+
if turn_num not in turns:
|
|
1021
|
+
turns[turn_num] = []
|
|
1022
|
+
turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
|
|
1023
|
+
return turns
|
|
1024
|
+
|
|
1025
|
+
return await asyncio.to_thread(_get_conversation_sync)
|
|
1026
|
+
|
|
1027
|
+
async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
|
|
1028
|
+
"""Get all tool usage by turn for specified branch.
|
|
1029
|
+
|
|
1030
|
+
Args:
|
|
1031
|
+
branch_id: Branch to get tool usage from (current branch if None).
|
|
1032
|
+
|
|
1033
|
+
Returns:
|
|
1034
|
+
List of tuples containing (tool_name, usage_count, turn_number).
|
|
1035
|
+
"""
|
|
1036
|
+
if branch_id is None:
|
|
1037
|
+
branch_id = self._current_branch_id
|
|
1038
|
+
|
|
1039
|
+
def _get_tool_usage_sync():
|
|
1040
|
+
"""Synchronous helper to get tool usage statistics."""
|
|
1041
|
+
conn = self._get_connection()
|
|
1042
|
+
with closing(conn.cursor()) as cursor:
|
|
1043
|
+
cursor.execute(
|
|
1044
|
+
"""
|
|
1045
|
+
SELECT tool_name, COUNT(*), user_turn_number
|
|
1046
|
+
FROM message_structure
|
|
1047
|
+
WHERE session_id = ? AND branch_id = ? AND message_type IN (
|
|
1048
|
+
'tool_call', 'function_call', 'computer_call', 'file_search_call',
|
|
1049
|
+
'web_search_call', 'code_interpreter_call', 'custom_tool_call',
|
|
1050
|
+
'mcp_call', 'mcp_approval_request'
|
|
1051
|
+
)
|
|
1052
|
+
GROUP BY tool_name, user_turn_number
|
|
1053
|
+
ORDER BY user_turn_number
|
|
1054
|
+
""",
|
|
1055
|
+
(self.session_id, branch_id),
|
|
1056
|
+
)
|
|
1057
|
+
return cursor.fetchall()
|
|
1058
|
+
|
|
1059
|
+
return await asyncio.to_thread(_get_tool_usage_sync)
|
|
1060
|
+
|
|
1061
|
+
async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
|
|
1062
|
+
"""Get cumulative usage for session or specific branch.
|
|
1063
|
+
|
|
1064
|
+
Args:
|
|
1065
|
+
branch_id: If provided, only get usage for that branch. If None, get all branches.
|
|
1066
|
+
|
|
1067
|
+
Returns:
|
|
1068
|
+
Dictionary with usage statistics or None if no usage data found.
|
|
1069
|
+
"""
|
|
1070
|
+
|
|
1071
|
+
def _get_usage_sync():
|
|
1072
|
+
"""Synchronous helper to get session usage data."""
|
|
1073
|
+
conn = self._get_connection()
|
|
1074
|
+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
|
|
1075
|
+
with self._lock if self._is_memory_db else threading.Lock():
|
|
1076
|
+
if branch_id:
|
|
1077
|
+
# Branch-specific usage
|
|
1078
|
+
query = """
|
|
1079
|
+
SELECT
|
|
1080
|
+
SUM(requests) as total_requests,
|
|
1081
|
+
SUM(input_tokens) as total_input_tokens,
|
|
1082
|
+
SUM(output_tokens) as total_output_tokens,
|
|
1083
|
+
SUM(total_tokens) as total_total_tokens,
|
|
1084
|
+
COUNT(*) as total_turns
|
|
1085
|
+
FROM turn_usage
|
|
1086
|
+
WHERE session_id = ? AND branch_id = ?
|
|
1087
|
+
"""
|
|
1088
|
+
params: tuple[str, ...] = (self.session_id, branch_id)
|
|
1089
|
+
else:
|
|
1090
|
+
# All branches
|
|
1091
|
+
query = """
|
|
1092
|
+
SELECT
|
|
1093
|
+
SUM(requests) as total_requests,
|
|
1094
|
+
SUM(input_tokens) as total_input_tokens,
|
|
1095
|
+
SUM(output_tokens) as total_output_tokens,
|
|
1096
|
+
SUM(total_tokens) as total_total_tokens,
|
|
1097
|
+
COUNT(*) as total_turns
|
|
1098
|
+
FROM turn_usage
|
|
1099
|
+
WHERE session_id = ?
|
|
1100
|
+
"""
|
|
1101
|
+
params = (self.session_id,)
|
|
1102
|
+
|
|
1103
|
+
with closing(conn.cursor()) as cursor:
|
|
1104
|
+
cursor.execute(query, params)
|
|
1105
|
+
row = cursor.fetchone()
|
|
1106
|
+
|
|
1107
|
+
if row and row[0] is not None:
|
|
1108
|
+
return {
|
|
1109
|
+
"requests": row[0] or 0,
|
|
1110
|
+
"input_tokens": row[1] or 0,
|
|
1111
|
+
"output_tokens": row[2] or 0,
|
|
1112
|
+
"total_tokens": row[3] or 0,
|
|
1113
|
+
"total_turns": row[4] or 0,
|
|
1114
|
+
}
|
|
1115
|
+
return None
|
|
1116
|
+
|
|
1117
|
+
result = await asyncio.to_thread(_get_usage_sync)
|
|
1118
|
+
|
|
1119
|
+
return cast(Union[dict[str, int], None], result)
|
|
1120
|
+
|
|
1121
|
+
async def get_turn_usage(
|
|
1122
|
+
self,
|
|
1123
|
+
user_turn_number: int | None = None,
|
|
1124
|
+
branch_id: str | None = None,
|
|
1125
|
+
) -> list[dict[str, Any]] | dict[str, Any]:
|
|
1126
|
+
"""Get usage statistics by turn with full JSON token details.
|
|
1127
|
+
|
|
1128
|
+
Args:
|
|
1129
|
+
user_turn_number: Specific turn to get usage for. If None, returns all turns.
|
|
1130
|
+
branch_id: Branch to get usage from (current branch if None).
|
|
1131
|
+
|
|
1132
|
+
Returns:
|
|
1133
|
+
Dictionary with usage data for specific turn, or list of dictionaries for all turns.
|
|
1134
|
+
"""
|
|
1135
|
+
|
|
1136
|
+
if branch_id is None:
|
|
1137
|
+
branch_id = self._current_branch_id
|
|
1138
|
+
|
|
1139
|
+
def _get_turn_usage_sync():
|
|
1140
|
+
"""Synchronous helper to get turn usage statistics."""
|
|
1141
|
+
conn = self._get_connection()
|
|
1142
|
+
|
|
1143
|
+
if user_turn_number is not None:
|
|
1144
|
+
query = """
|
|
1145
|
+
SELECT requests, input_tokens, output_tokens, total_tokens,
|
|
1146
|
+
input_tokens_details, output_tokens_details
|
|
1147
|
+
FROM turn_usage
|
|
1148
|
+
WHERE session_id = ? AND branch_id = ? AND user_turn_number = ?
|
|
1149
|
+
"""
|
|
1150
|
+
|
|
1151
|
+
with closing(conn.cursor()) as cursor:
|
|
1152
|
+
cursor.execute(query, (self.session_id, branch_id, user_turn_number))
|
|
1153
|
+
row = cursor.fetchone()
|
|
1154
|
+
|
|
1155
|
+
if row:
|
|
1156
|
+
# Parse JSON details if present
|
|
1157
|
+
input_details = None
|
|
1158
|
+
output_details = None
|
|
1159
|
+
|
|
1160
|
+
if row[4]: # input_tokens_details
|
|
1161
|
+
try:
|
|
1162
|
+
input_details = json.loads(row[4])
|
|
1163
|
+
except json.JSONDecodeError:
|
|
1164
|
+
pass
|
|
1165
|
+
|
|
1166
|
+
if row[5]: # output_tokens_details
|
|
1167
|
+
try:
|
|
1168
|
+
output_details = json.loads(row[5])
|
|
1169
|
+
except json.JSONDecodeError:
|
|
1170
|
+
pass
|
|
1171
|
+
|
|
1172
|
+
return {
|
|
1173
|
+
"requests": row[0],
|
|
1174
|
+
"input_tokens": row[1],
|
|
1175
|
+
"output_tokens": row[2],
|
|
1176
|
+
"total_tokens": row[3],
|
|
1177
|
+
"input_tokens_details": input_details,
|
|
1178
|
+
"output_tokens_details": output_details,
|
|
1179
|
+
}
|
|
1180
|
+
return {}
|
|
1181
|
+
else:
|
|
1182
|
+
query = """
|
|
1183
|
+
SELECT user_turn_number, requests, input_tokens, output_tokens,
|
|
1184
|
+
total_tokens, input_tokens_details, output_tokens_details
|
|
1185
|
+
FROM turn_usage
|
|
1186
|
+
WHERE session_id = ? AND branch_id = ?
|
|
1187
|
+
ORDER BY user_turn_number
|
|
1188
|
+
"""
|
|
1189
|
+
|
|
1190
|
+
with closing(conn.cursor()) as cursor:
|
|
1191
|
+
cursor.execute(query, (self.session_id, branch_id))
|
|
1192
|
+
results = []
|
|
1193
|
+
for row in cursor.fetchall():
|
|
1194
|
+
# Parse JSON details if present
|
|
1195
|
+
input_details = None
|
|
1196
|
+
output_details = None
|
|
1197
|
+
|
|
1198
|
+
if row[5]: # input_tokens_details
|
|
1199
|
+
try:
|
|
1200
|
+
input_details = json.loads(row[5])
|
|
1201
|
+
except json.JSONDecodeError:
|
|
1202
|
+
pass
|
|
1203
|
+
|
|
1204
|
+
if row[6]: # output_tokens_details
|
|
1205
|
+
try:
|
|
1206
|
+
output_details = json.loads(row[6])
|
|
1207
|
+
except json.JSONDecodeError:
|
|
1208
|
+
pass
|
|
1209
|
+
|
|
1210
|
+
results.append(
|
|
1211
|
+
{
|
|
1212
|
+
"user_turn_number": row[0],
|
|
1213
|
+
"requests": row[1],
|
|
1214
|
+
"input_tokens": row[2],
|
|
1215
|
+
"output_tokens": row[3],
|
|
1216
|
+
"total_tokens": row[4],
|
|
1217
|
+
"input_tokens_details": input_details,
|
|
1218
|
+
"output_tokens_details": output_details,
|
|
1219
|
+
}
|
|
1220
|
+
)
|
|
1221
|
+
return results
|
|
1222
|
+
|
|
1223
|
+
result = await asyncio.to_thread(_get_turn_usage_sync)
|
|
1224
|
+
|
|
1225
|
+
return cast(Union[list[dict[str, Any]], dict[str, Any]], result)
|
|
1226
|
+
|
|
1227
|
+
async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
    """Internal method to update usage for a specific turn with full JSON details.

    Writes (``INSERT OR REPLACE``) one ``turn_usage`` row for this session, the
    branch that is current when this coroutine is called, and
    ``user_turn_number``. Missing counters are stored as 0. Token-detail
    objects are stored as JSON. When one cannot be serialized, a warning is
    logged and NULL is stored for it.

    Args:
        user_turn_number: The turn number to update usage for.
        usage_data: The usage data to store.
    """
    # Capture the branch now, as the query methods do. A branch switch that
    # happens while the worker thread is pending then cannot attribute this
    # turn's usage to the wrong branch.
    branch_id = self._current_branch_id

    def _details_json(details: Any, label: str) -> str | None:
        """Serialize a token-details object to JSON; None if absent or unserializable."""
        if not details:
            return None
        try:
            # Prefer a model_dump() method when the object provides one (e.g.
            # pydantic models); otherwise use its attribute dict. vars() raises
            # TypeError (handled below) for objects without a __dict__.
            dump = getattr(details, "model_dump", None)
            payload = dump() if callable(dump) else vars(details)
            return json.dumps(payload)
        except (TypeError, ValueError) as e:
            self._logger.warning(f"Failed to serialize {label} tokens details: {e}")
            return None

    def _update_sync() -> None:
        """Synchronous helper to update turn usage data."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            # Serialize token details as JSON
            input_details_json = _details_json(
                getattr(usage_data, "input_tokens_details", None), "input"
            )
            output_details_json = _details_json(
                getattr(usage_data, "output_tokens_details", None), "output"
            )

            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO turn_usage
                    (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
                     total_tokens, input_tokens_details, output_tokens_details)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,  # noqa: E501
                    (
                        self.session_id,
                        branch_id,
                        user_turn_number,
                        usage_data.requests or 0,
                        usage_data.input_tokens or 0,
                        usage_data.output_tokens or 0,
                        usage_data.total_tokens or 0,
                        input_details_json,
                        output_details_json,
                    ),
                )
                conn.commit()

    await asyncio.to_thread(_update_sync)
|