openai-agents 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. Click here for more details.

@@ -0,0 +1,1285 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import threading
7
+ from contextlib import closing
8
+ from pathlib import Path
9
+ from typing import Any, Union, cast
10
+
11
+ from agents.result import RunResult
12
+ from agents.usage import Usage
13
+
14
+ from ...items import TResponseInputItem
15
+ from ...memory import SQLiteSession
16
+
17
+
18
+ class AdvancedSQLiteSession(SQLiteSession):
19
+ """Enhanced SQLite session with conversation branching and usage analytics."""
20
+
21
def __init__(
    self,
    *,
    session_id: str,
    db_path: str | Path = ":memory:",
    create_tables: bool = False,
    logger: logging.Logger | None = None,
    **kwargs,
):
    """Set up a session that adds branching and usage analytics to SQLiteSession.

    The base ``SQLiteSession`` is initialised first. After that, the structure
    tables are optionally created and the session is placed on the ``"main"``
    branch.

    Args:
        session_id: Identifier of the session.
        db_path: Location of the SQLite database file; the default
            ``":memory:"`` keeps everything in memory.
        create_tables: When true, create the ``message_structure`` and
            ``turn_usage`` tables together with their indexes.
        logger: Logger used by this session; falls back to the module logger.
        **kwargs: Extra keyword arguments forwarded to ``SQLiteSession``.
    """
    super().__init__(session_id, db_path, **kwargs)

    if create_tables:
        self._init_structure_tables()

    # Every session starts out on the default branch.
    self._current_branch_id = "main"
    chosen_logger = logger or logging.getLogger(__name__)
    self._logger = chosen_logger
44
+
45
def _init_structure_tables(self):
    """Create the structure and usage-tracking tables plus their indexes.

    ``message_structure`` records, for every message in every branch, its
    ordering (``sequence_number``), turn numbers and optional tool name.
    ``turn_usage`` records token usage per (session, branch, user turn).
    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    All changes are committed once at the end.
    """
    connection = self._get_connection()

    schema_statements = (
        # Per-branch message layout.
        """
        CREATE TABLE IF NOT EXISTS message_structure (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            message_id INTEGER NOT NULL,
            branch_id TEXT NOT NULL DEFAULT 'main',
            message_type TEXT NOT NULL,
            sequence_number INTEGER NOT NULL,
            user_turn_number INTEGER,
            branch_turn_number INTEGER,
            tool_name TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
            FOREIGN KEY (message_id) REFERENCES agent_messages(id) ON DELETE CASCADE
        )
        """,
        # Token usage per user turn and branch; the *_details columns hold JSON.
        """
        CREATE TABLE IF NOT EXISTS turn_usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            branch_id TEXT NOT NULL DEFAULT 'main',
            user_turn_number INTEGER NOT NULL,
            requests INTEGER DEFAULT 0,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            input_tokens_details JSON,
            output_tokens_details JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (session_id) REFERENCES agent_sessions(session_id) ON DELETE CASCADE,
            UNIQUE(session_id, branch_id, user_turn_number)
        )
        """,
        # Lookup indexes for ordering, branch filtering and turn access.
        """
        CREATE INDEX IF NOT EXISTS idx_structure_session_seq
        ON message_structure(session_id, sequence_number)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_structure_branch
        ON message_structure(session_id, branch_id)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_structure_turn
        ON message_structure(session_id, branch_id, user_turn_number)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_structure_branch_seq
        ON message_structure(session_id, branch_id, sequence_number)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn
        ON turn_usage(session_id, branch_id, user_turn_number)
        """,
    )

    for statement in schema_statements:
        connection.execute(statement)

    connection.commit()
113
+
114
async def add_items(self, items: list[TResponseInputItem]) -> None:
    """Persist *items* and then record their branch/turn structure.

    The items are written through the base ``SQLiteSession`` first, so that
    the structure step can look up the message IDs they were assigned. For
    an empty list only the base call is made.

    Args:
        items: Conversation items to append to the session.
    """
    await super().add_items(items)

    if not items:
        return
    await self._add_structure_metadata(items)
126
+
127
async def get_items(
    self,
    limit: int | None = None,
    branch_id: str | None = None,
) -> list[TResponseInputItem]:
    """Get items from the current or a specified branch.

    Items are returned in branch order (ascending ``sequence_number``).
    With a *limit*, the newest ``limit`` rows of the branch are selected and
    then returned oldest-first. Rows whose stored JSON cannot be decoded are
    skipped, so fewer than ``limit`` items may be returned.

    Args:
        limit: Maximum number of items to return. If None, returns all items.
        branch_id: Branch to get items from. If None, uses the current branch.

    Returns:
        List of conversation items from the specified branch.
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    def _get_items_sync() -> list[TResponseInputItem]:
        """Read and decode the branch's items (runs in a worker thread)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        # NOTE(review): for file-backed databases a brand-new Lock is created on
        # every call, so it serializes nothing; this presumably relies on the base
        # SQLiteSession handing out per-thread connections — confirm.
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                if limit is None:
                    cursor.execute(
                        """
                        SELECT m.message_data
                        FROM agent_messages m
                        JOIN message_structure s ON m.id = s.message_id
                        WHERE m.session_id = ? AND s.branch_id = ?
                        ORDER BY s.sequence_number ASC
                        """,
                        (self.session_id, branch_id),
                    )
                else:
                    # Newest `limit` rows first; flipped back to chronological below.
                    cursor.execute(
                        """
                        SELECT m.message_data
                        FROM agent_messages m
                        JOIN message_structure s ON m.id = s.message_id
                        WHERE m.session_id = ? AND s.branch_id = ?
                        ORDER BY s.sequence_number DESC
                        LIMIT ?
                        """,
                        (self.session_id, branch_id, limit),
                    )
                rows = cursor.fetchall()

        # Decoding does not touch the connection, so it happens outside the lock.
        if limit is not None:
            rows.reverse()

        items: list[TResponseInputItem] = []
        for (message_data,) in rows:
            try:
                items.append(json.loads(message_data))
            except json.JSONDecodeError:
                # Skip corrupt rows rather than failing the whole read.
                continue
        return items

    return await asyncio.to_thread(_get_items_sync)
235
+
236
async def store_run_usage(self, result: RunResult) -> None:
    """Store usage data for the current conversation turn.

    This is designed to be called after `Runner.run()` completes.
    Session-level usage can be aggregated from turn data when needed.
    Nothing is stored when the run carries no usage. Any failure is logged
    at ERROR level and swallowed; it is never raised to the caller.

    Args:
        result: The result from the run
    """
    try:
        usage = result.context_wrapper.usage
        if usage is None:
            return
        # _get_current_turn_number runs a synchronous SQLite query; keep it off
        # the event loop, like the other database helpers in this class.
        current_turn = await asyncio.to_thread(self._get_current_turn_number)
        # Only update turn-level usage - session usage is aggregated on demand
        await self._update_turn_usage_internal(current_turn, usage)
    except Exception as e:
        self._logger.error(f"Failed to store usage for session {self.session_id}: {e}")
253
+
254
def _get_next_turn_number(self, branch_id: str) -> int:
    """Return the user turn number the next user message on *branch_id* gets.

    Args:
        branch_id: Branch whose turn counter is inspected.

    Returns:
        One more than the highest ``user_turn_number`` recorded for the
        branch in this session, so ``1`` for a branch with no rows.
    """
    query = """
        SELECT COALESCE(MAX(user_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(query, (self.session_id, branch_id))
        row = cur.fetchone()
    highest = row[0] if row else 0
    return highest + 1
276
+
277
def _get_next_branch_turn_number(self, branch_id: str) -> int:
    """Return the branch turn number the next user message on *branch_id* gets.

    Args:
        branch_id: Branch whose branch-turn counter is inspected.

    Returns:
        One more than the highest ``branch_turn_number`` recorded for the
        branch in this session, so ``1`` for a branch with no rows.
    """
    query = """
        SELECT COALESCE(MAX(branch_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(query, (self.session_id, branch_id))
        row = cur.fetchone()
    highest = row[0] if row else 0
    return highest + 1
299
+
300
def _get_current_turn_number(self) -> int:
    """Return the highest user turn number recorded on the active branch.

    Returns:
        ``MAX(user_turn_number)`` over this session's rows on
        ``self._current_branch_id``, or ``0`` when the branch has none.
    """
    query = """
        SELECT COALESCE(MAX(user_turn_number), 0)
        FROM message_structure
        WHERE session_id = ? AND branch_id = ?
    """
    with closing(self._get_connection().cursor()) as cur:
        cur.execute(query, (self.session_id, self._current_branch_id))
        row = cur.fetchone()
    if not row:
        return 0
    return row[0]
318
+
319
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
    """Record branch/turn structure for items that were just stored.

    This method:
    - Assigns turn numbers per branch (not globally)
    - Assigns explicit sequence numbers for precise ordering
    - Links messages to their database IDs for structure tracking
    - Handles multiple user messages in a single batch correctly

    Each user message opens a new turn. Every other item inherits the turn of
    the most recent user message, or the branch's current turn when the batch
    does not start with a user message. Failures are logged and followed by a
    best-effort cleanup of orphaned messages; they are never re-raised.

    Args:
        items: The items that were just added, in insertion order.
    """

    def _add_structure_sync():
        """Insert one ``message_structure`` row per item (worker thread)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            # IDs of the most recently inserted messages, reordered oldest-first
            # so they line up with `items`.
            # NOTE(review): assumes nothing else was appended to this session
            # between super().add_items() and this query — confirm.
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    f"SELECT id FROM {self.messages_table} "
                    "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
                    (self.session_id, len(items)),
                )
                message_ids = [row[0] for row in cursor.fetchall()]
            message_ids.reverse()

            if len(message_ids) != len(items):
                # zip() below silently drops the unmatched tail; make that visible.
                self._logger.warning(
                    f"Expected {len(items)} message IDs for session {self.session_id}, "
                    f"found {len(message_ids)}; structure metadata covers only the "
                    f"matched items"
                )

            # Sequence numbers are shared by all branches of the session.
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT COALESCE(MAX(sequence_number), 0)
                    FROM message_structure
                    WHERE session_id = ?
                    """,
                    (self.session_id,),
                )
                seq_start = cursor.fetchone()[0]

            # Both turn counters are read for the current branch only.
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        COALESCE(MAX(user_turn_number), 0) as max_global_turn,
                        COALESCE(MAX(branch_turn_number), 0) as max_branch_turn
                    FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    (self.session_id, self._current_branch_id),
                )
                result = cursor.fetchone()
            current_turn = result[0] if result else 0
            current_branch_turn = result[1] if result else 0

            structure_data = []
            user_message_count = 0
            for offset, (item, msg_id) in enumerate(zip(items, message_ids), start=1):
                # A user message starts a new turn; all other items inherit the
                # turn of the most recent user message.
                if self._is_user_message(item):
                    user_message_count += 1
                structure_data.append(
                    (
                        self.session_id,
                        msg_id,
                        self._current_branch_id,
                        self._classify_message_type(item),
                        seq_start + offset,  # session-wide sequence
                        current_turn + user_message_count,  # user turn number
                        current_branch_turn + user_message_count,  # branch turn number
                        self._extract_tool_name(item),
                    )
                )

            with closing(conn.cursor()) as cursor:
                cursor.executemany(
                    """
                    INSERT INTO message_structure
                    (session_id, message_id, branch_id, message_type, sequence_number,
                     user_turn_number, branch_turn_number, tool_name)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    structure_data,
                )
            conn.commit()

    try:
        await asyncio.to_thread(_add_structure_sync)
    except Exception as e:
        self._logger.error(
            f"Failed to add structure metadata for session {self.session_id}: {e}"
        )
        # Try to clean up any orphaned messages to maintain consistency
        try:
            await self._cleanup_orphaned_messages()
        except Exception as cleanup_error:
            self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
        # Don't re-raise - structure metadata is supplementary
430
+
431
async def _cleanup_orphaned_messages(self) -> int:
    """Remove messages that exist in agent_messages but not in message_structure.

    This can happen if _add_structure_metadata fails after super().add_items()
    succeeds. Only this session's ``agent_messages`` rows are deleted, and only
    those that no ``message_structure`` row references. Used for maintaining
    data consistency.

    Returns:
        The number of deleted ``agent_messages`` rows (``0`` when none).
    """

    def _cleanup_sync() -> int:
        """Delete this session's orphaned messages (worker thread)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                # One set-based DELETE replaces "SELECT ids" followed by
                # "DELETE ... IN (?, ?, ...)". The old form binds one parameter per
                # orphan and can exceed SQLite's bound-parameter limit.
                cursor.execute(
                    """
                    DELETE FROM agent_messages
                    WHERE session_id = ?
                    AND NOT EXISTS (
                        SELECT 1 FROM message_structure ms
                        WHERE ms.message_id = agent_messages.id
                    )
                    """,
                    (self.session_id,),
                )
                deleted_count = cursor.rowcount
            # Always commit: the DELETE opens an implicit transaction even when it
            # matches nothing.
            conn.commit()

        if deleted_count > 0:
            self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
        return deleted_count

    return await asyncio.to_thread(_cleanup_sync)
473
+
474
def _classify_message_type(self, item: TResponseInputItem) -> str:
    """Return a short label describing what kind of item *item* is.

    A ``"user"`` or ``"assistant"`` role wins over anything else. Otherwise a
    truthy ``type`` value is used, converted to ``str``. Non-dict items and
    dicts with neither fall back to ``"other"``.

    Args:
        item: The message item to classify.

    Returns:
        ``"user"``, ``"assistant"``, the item's ``type`` as a string, or ``"other"``.
    """
    if not isinstance(item, dict):
        return "other"

    role = item.get("role")
    for known_role in ("user", "assistant"):
        if role == known_role:
            return known_role

    item_type = item.get("type")
    return str(item_type) if item_type else "other"
491
+
492
def _extract_tool_name(self, item: TResponseInputItem) -> str | None:
    """Derive a tool name for tool-call/output items.

    The rules are applied in this order:

    * MCP items (``mcp_call`` / ``mcp_approval_request``) that carry a
      ``server_label`` key give ``"label.name"``, or whichever of the two is
      truthy, or ``None``.
    * Built-in tool calls without a ``name`` field return their ``type``.
    * Any other dict with a ``name`` key returns that name as a string (or
      ``None`` when the value is ``None``).

    Args:
        item: The message item to extract tool name from.

    Returns:
        Tool name if one can be derived, None otherwise.
    """
    if not isinstance(item, dict):
        return None

    item_type = item.get("type")

    if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item:
        label = item.get("server_label")
        name = item.get("name")
        if name and label:
            return f"{label}.{name}"
        if label:
            return str(label)
        if name:
            return str(name)
        return None

    if item_type in {
        "computer_call",
        "file_search_call",
        "web_search_call",
        "code_interpreter_call",
    }:
        return item_type

    if "name" not in item:
        return None
    name = item.get("name")
    return None if name is None else str(name)
530
+
531
def _is_user_message(self, item: TResponseInputItem) -> bool:
    """Tell whether *item* is a dict whose ``role`` is ``"user"``.

    Args:
        item: The message item to check.

    Returns:
        True for user messages, False for everything else (including
        non-dict items).
    """
    if not isinstance(item, dict):
        return False
    return item.get("role") == "user"
541
+
542
async def create_branch_from_turn(
    self, turn_number: int, branch_name: str | None = None
) -> str:
    """Create a new branch from a user turn of the current branch and switch to it.

    Structure entries before *turn_number* are copied into the new branch
    (via ``_copy_messages_to_new_branch``). The session then makes the new
    branch its current branch.

    Args:
        turn_number: The branch turn number of the user message to branch from
        branch_name: Optional name for the branch (auto-generated if None)

    Returns:
        The branch_id of the newly created branch

    Raises:
        ValueError: If turn doesn't exist or doesn't contain a user message
    """
    import time

    def _validate_turn():
        """Return a short preview of the turn's user message, or raise ValueError."""
        with closing(self._get_connection().cursor()) as cursor:
            cursor.execute(
                """
                SELECT am.message_data
                FROM message_structure ms
                JOIN agent_messages am ON ms.message_id = am.id
                WHERE ms.session_id = ? AND ms.branch_id = ?
                AND ms.branch_turn_number = ? AND ms.message_type = 'user'
                """,
                (self.session_id, self._current_branch_id, turn_number),
            )
            row = cursor.fetchone()

        if not row:
            raise ValueError(
                f"Turn {turn_number} does not contain a user message "
                f"in branch '{self._current_branch_id}'"
            )

        raw = row[0]
        try:
            content = json.loads(raw).get("content", "")
            if len(content) > 50:
                return content[:50] + "..."
            return content
        except Exception:
            # The preview is only used for logging, so any parse problem is tolerated.
            return "Unable to parse content"

    preview = await asyncio.to_thread(_validate_turn)

    if branch_name is None:
        branch_name = f"branch_from_turn_{turn_number}_{int(time.time())}"

    await self._copy_messages_to_new_branch(branch_name, turn_number)

    previous_branch, self._current_branch_id = self._current_branch_id, branch_name

    self._logger.debug(
        f"Created branch '{branch_name}' from turn {turn_number} "
        f"('{preview}') in '{previous_branch}'"
    )
    return branch_name
607
+
608
async def create_branch_from_content(
    self, search_term: str, branch_name: str | None = None
) -> str:
    """Branch from the earliest user turn whose content matches *search_term*.

    The search is delegated to ``find_turns_by_content``. The first match it
    returns is handed to ``create_branch_from_turn``.

    Args:
        search_term: Text to search for in user messages.
        branch_name: Optional name for the branch (auto-generated if None).

    Returns:
        The branch_id of the newly created branch.

    Raises:
        ValueError: If no matching turns are found.
    """
    candidates = await self.find_turns_by_content(search_term)
    if not candidates:
        raise ValueError(f"No user turns found containing '{search_term}'")

    earliest = candidates[0]
    return await self.create_branch_from_turn(earliest["turn"], branch_name)
630
+
631
async def switch_to_branch(self, branch_id: str) -> None:
    """Make *branch_id* the active branch of this session.

    A branch counts as existing when this session has at least one
    ``message_structure`` row for it.

    Args:
        branch_id: The branch to switch to.

    Raises:
        ValueError: If the branch doesn't exist.
    """

    def _branch_exists() -> bool:
        """Check for at least one structure row on the branch (worker thread)."""
        with closing(self._get_connection().cursor()) as cursor:
            cursor.execute(
                """
                SELECT COUNT(*) FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                """,
                (self.session_id, branch_id),
            )
            (row_count,) = cursor.fetchone()
        return row_count > 0

    if not await asyncio.to_thread(_branch_exists):
        raise ValueError(f"Branch '{branch_id}' does not exist")

    previous = self._current_branch_id
    self._current_branch_id = branch_id
    self._logger.info(f"Switched from branch '{previous}' to '{branch_id}'")
663
+
664
async def delete_branch(self, branch_id: str, force: bool = False) -> None:
    """Delete a branch's structure and usage rows.

    Only ``turn_usage`` and ``message_structure`` rows are removed. The shared
    ``agent_messages`` rows are left untouched, since other branches may
    reference the same messages. Surrounding whitespace in *branch_id* is
    ignored.

    Args:
        branch_id: The branch to delete.
        force: If True, allows deleting the current branch (will switch to 'main').

    Raises:
        ValueError: If branch doesn't exist, is 'main', or is current branch without force.
    """
    if not branch_id or not branch_id.strip():
        raise ValueError("Branch ID cannot be empty")

    target = branch_id.strip()

    if target == "main":
        raise ValueError("Cannot delete the 'main' branch")

    if target == self._current_branch_id:
        if not force:
            raise ValueError(
                f"Cannot delete current branch '{target}'. "
                "Use force=True or switch branches first"
            )
        # Move off the branch before it disappears.
        await self.switch_to_branch("main")

    def _delete_sync() -> tuple[int, int]:
        """Remove the branch rows and return (usage_rows, structure_rows)."""
        conn = self._get_connection()
        params = (self.session_id, target)
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT COUNT(*) FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    params,
                )
                if cursor.fetchone()[0] == 0:
                    raise ValueError(f"Branch '{target}' does not exist")

                # Usage rows first, then the structure rows of the branch.
                cursor.execute(
                    """
                    DELETE FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    params,
                )
                usage_rows = cursor.rowcount

                cursor.execute(
                    """
                    DELETE FROM message_structure
                    WHERE session_id = ? AND branch_id = ?
                    """,
                    params,
                )
                structure_rows = cursor.rowcount

                conn.commit()

        return usage_rows, structure_rows

    usage_count, structure_count = await asyncio.to_thread(_delete_sync)

    self._logger.info(
        f"Deleted branch '{target}': {structure_count} message entries, "
        f"{usage_count} usage entries"
    )
743
+
744
async def list_branches(self) -> list[dict[str, Any]]:
    """Summarise every branch of this session, oldest first.

    Branches are derived from ``message_structure`` rows grouped by
    ``branch_id`` and ordered by their earliest ``created_at``.

    Returns:
        List of dicts with branch info containing:
        - 'branch_id': Branch identifier
        - 'message_count': Number of messages in branch
        - 'user_turns': Number of user turns in branch
        - 'is_current': Whether this is the current branch
        - 'created_at': When the branch was first created
    """

    def _list_branches_sync():
        """Run the grouping query and shape the rows (worker thread)."""
        with closing(self._get_connection().cursor()) as cursor:
            cursor.execute(
                """
                SELECT
                    ms.branch_id,
                    COUNT(*) as message_count,
                    COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
                    MIN(ms.created_at) as created_at
                FROM message_structure ms
                WHERE ms.session_id = ?
                GROUP BY ms.branch_id
                ORDER BY created_at
                """,
                (self.session_id,),
            )
            summary_rows = cursor.fetchall()

        return [
            {
                "branch_id": name,
                "message_count": total,
                "user_turns": user_count,
                "is_current": name == self._current_branch_id,
                "created_at": first_seen,
            }
            for name, total, user_count, first_seen in summary_rows
        ]

    return await asyncio.to_thread(_list_branches_sync)
791
+
792
async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
    """Copy the current branch's structure rows before a turn into a new branch.

    Only ``message_structure`` rows are duplicated; they keep pointing at the
    same ``agent_messages`` rows. The copies keep their turn numbers but get
    fresh sequence numbers after the session's current maximum.

    Args:
        new_branch_id: The ID of the new branch to copy messages to.
        from_turn_number: The turn number to copy messages up to (exclusive).
    """

    def _copy_sync():
        """Duplicate the qualifying structure rows (worker thread)."""
        conn = self._get_connection()
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
        with self._lock if self._is_memory_db else threading.Lock():
            with closing(conn.cursor()) as cursor:
                cursor.execute(
                    """
                    SELECT
                        ms.message_id,
                        ms.message_type,
                        ms.sequence_number,
                        ms.user_turn_number,
                        ms.branch_turn_number,
                        ms.tool_name
                    FROM message_structure ms
                    WHERE ms.session_id = ? AND ms.branch_id = ?
                    AND ms.branch_turn_number < ?
                    ORDER BY ms.sequence_number
                    """,
                    (self.session_id, self._current_branch_id, from_turn_number),
                )
                source_rows = cursor.fetchall()

                if source_rows:
                    # New sequence numbers continue after the session-wide maximum.
                    cursor.execute(
                        """
                        SELECT COALESCE(MAX(sequence_number), 0)
                        FROM message_structure
                        WHERE session_id = ?
                        """,
                        (self.session_id,),
                    )
                    (base_seq,) = cursor.fetchone()

                    copies = [
                        (
                            self.session_id,
                            message_id,  # shared message row
                            new_branch_id,
                            message_type,
                            base_seq + position,
                            user_turn,
                            branch_turn,
                            tool_name,
                        )
                        for position, (
                            message_id,
                            message_type,
                            _old_sequence,
                            user_turn,
                            branch_turn,
                            tool_name,
                        ) in enumerate(source_rows, start=1)
                    ]

                    cursor.executemany(
                        """
                        INSERT INTO message_structure
                        (session_id, message_id, branch_id, message_type, sequence_number,
                         user_turn_number, branch_turn_number, tool_name)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                        """,
                        copies,
                    )

                conn.commit()

    await asyncio.to_thread(_copy_sync)
875
+
876
async def get_conversation_turns(self, branch_id: str | None = None) -> list[dict[str, Any]]:
    """Get user turns with content for easy browsing and branching decisions.

    Only ``message_type = 'user'`` rows are listed, ordered by branch turn.
    A row is skipped, instead of aborting the whole listing, when its stored
    data is not valid JSON, is not a JSON object, or has a ``content`` that
    cannot be previewed (e.g. ``null``).

    Args:
        branch_id: Branch to get turns from (current branch if None).

    Returns:
        List of dicts with turn info containing:
        - 'turn': Branch turn number
        - 'content': User message content (truncated to 100 items + "...")
        - 'full_content': Full user message content
        - 'timestamp': When the turn was created
        - 'can_branch': Always True (all user messages can branch)
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    def _get_turns_sync():
        """Synchronous helper to get conversation turns."""
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT
                    ms.branch_turn_number,
                    am.message_data,
                    ms.created_at
                FROM message_structure ms
                JOIN agent_messages am ON ms.message_id = am.id
                WHERE ms.session_id = ? AND ms.branch_id = ?
                AND ms.message_type = 'user'
                ORDER BY ms.branch_turn_number
                """,
                (self.session_id, branch_id),
            )
            rows = cursor.fetchall()

        turns = []
        for turn_num, message_data, created_at in rows:
            try:
                content = json.loads(message_data).get("content", "")
                preview = content[:100] + "..." if len(content) > 100 else content
            except (json.JSONDecodeError, AttributeError, TypeError):
                # AttributeError: data is not a JSON object.
                # TypeError: content is None/unsized, or a long list that cannot be
                # concatenated with "...". Skip the row like other undecodable ones.
                continue
            turns.append(
                {
                    "turn": turn_num,
                    "content": preview,
                    "full_content": content,
                    "timestamp": created_at,
                    "can_branch": True,
                }
            )

        return turns

    return await asyncio.to_thread(_get_turns_sync)
932
+
933
async def find_turns_by_content(
    self, search_term: str, branch_id: str | None = None
) -> list[dict[str, Any]]:
    """Find user turns containing specific content.

    The search is a SQL ``LIKE`` substring match against the *stored JSON
    text* of each user message, so it can also hit JSON keys or other fields.
    As with SQLite's default ``LIKE``, ASCII letters match case-insensitively.
    ``%``, ``_`` and backslash characters in *search_term* are escaped, so
    they match literally instead of acting as wildcards.

    Args:
        search_term: Text to search for in user messages.
        branch_id: Branch to search in (current branch if None).

    Returns:
        List of matching turns with same format as get_conversation_turns(),
        except that 'content' is not truncated.
    """
    if branch_id is None:
        branch_id = self._current_branch_id

    # Escape LIKE metacharacters (backslash first) so the term is matched literally.
    escaped_term = (
        search_term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    )
    like_pattern = f"%{escaped_term}%"

    def _search_sync():
        """Synchronous helper to search turns by content."""
        conn = self._get_connection()
        with closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT
                    ms.branch_turn_number,
                    am.message_data,
                    ms.created_at
                FROM message_structure ms
                JOIN agent_messages am ON ms.message_id = am.id
                WHERE ms.session_id = ? AND ms.branch_id = ?
                AND ms.message_type = 'user'
                AND am.message_data LIKE ? ESCAPE '\\'
                ORDER BY ms.branch_turn_number
                """,
                (self.session_id, branch_id, like_pattern),
            )
            rows = cursor.fetchall()

        matches = []
        for turn_num, message_data, created_at in rows:
            try:
                content = json.loads(message_data).get("content", "")
            except (json.JSONDecodeError, AttributeError):
                continue
            matches.append(
                {
                    "turn": turn_num,
                    "content": content,
                    "full_content": content,
                    "timestamp": created_at,
                    "can_branch": True,
                }
            )

        return matches

    return await asyncio.to_thread(_search_sync)
988
+
989
async def get_conversation_by_turns(
    self, branch_id: str | None = None
) -> dict[int, list[dict[str, str | None]]]:
    """Group a branch's message metadata by user turn.

    Args:
        branch_id: Branch whose messages are grouped. When None, the currently
            active branch is used.

    Returns:
        Mapping from user turn number to that turn's messages, in
        ``sequence_number`` order. Each message is described as
        ``{"type": <message_type>, "tool_name": <tool_name>}``.
    """
    target_branch = self._current_branch_id if branch_id is None else branch_id

    def _get_conversation_sync():
        """Read message_structure rows and bucket them by user turn."""
        with closing(self._get_connection().cursor()) as cur:
            cur.execute(
                """
                SELECT user_turn_number, message_type, tool_name
                FROM message_structure
                WHERE session_id = ? AND branch_id = ?
                ORDER BY sequence_number
                """,
                (self.session_id, target_branch),
            )
            grouped: dict[int, list[dict[str, str | None]]] = {}
            for turn_no, kind, tool in cur.fetchall():
                grouped.setdefault(turn_no, []).append({"type": kind, "tool_name": tool})
            return grouped

    return await asyncio.to_thread(_get_conversation_sync)
1026
+
1027
async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
    """Count tool-type messages per tool name and user turn.

    Args:
        branch_id: Branch to inspect. When None, the currently active branch
            is used.

    Returns:
        Rows of ``(tool_name, usage_count, turn_number)`` ordered by user turn
        number. Only messages whose ``message_type`` is one of the tool or
        call types listed in the query are counted.
    """
    target_branch = branch_id if branch_id is not None else self._current_branch_id
    sql = """
        SELECT tool_name, COUNT(*), user_turn_number
        FROM message_structure
        WHERE session_id = ? AND branch_id = ? AND message_type IN (
            'tool_call', 'function_call', 'computer_call', 'file_search_call',
            'web_search_call', 'code_interpreter_call', 'custom_tool_call',
            'mcp_call', 'mcp_approval_request'
        )
        GROUP BY tool_name, user_turn_number
        ORDER BY user_turn_number
    """

    def _get_tool_usage_sync():
        """Run the aggregation query and return all result rows."""
        cur = self._get_connection().cursor()
        with closing(cur):
            cur.execute(sql, (self.session_id, target_branch))
            return cur.fetchall()

    return await asyncio.to_thread(_get_tool_usage_sync)
1060
+
1061
async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
    """Sum the usage rows recorded in ``turn_usage`` for this session.

    Args:
        branch_id: When truthy, only that branch's rows are summed. Otherwise
            (None or an empty string), rows from every branch of the session
            are summed.

    Returns:
        Dict with ``requests``, ``input_tokens``, ``output_tokens``,
        ``total_tokens`` and ``total_turns`` totals, or None when no usage
        rows match.
    """
    result_keys = ("requests", "input_tokens", "output_tokens", "total_tokens", "total_turns")

    def _get_usage_sync():
        """Run the aggregate query under the session's locking scheme."""
        conn = self._get_connection()
        # NOTE: when the database is not in-memory this is a freshly created
        # threading.Lock, so it serializes nothing; only the in-memory case
        # shares self._lock.
        # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
        guard = self._lock if self._is_memory_db else threading.Lock()
        with guard:
            if not branch_id:
                # All branches
                query = """
                    SELECT
                        SUM(requests) as total_requests,
                        SUM(input_tokens) as total_input_tokens,
                        SUM(output_tokens) as total_output_tokens,
                        SUM(total_tokens) as total_total_tokens,
                        COUNT(*) as total_turns
                    FROM turn_usage
                    WHERE session_id = ?
                """
                params: tuple[str, ...] = (self.session_id,)
            else:
                # Branch-specific usage
                query = """
                    SELECT
                        SUM(requests) as total_requests,
                        SUM(input_tokens) as total_input_tokens,
                        SUM(output_tokens) as total_output_tokens,
                        SUM(total_tokens) as total_total_tokens,
                        COUNT(*) as total_turns
                    FROM turn_usage
                    WHERE session_id = ? AND branch_id = ?
                """
                params = (self.session_id, branch_id)

            with closing(conn.cursor()) as cursor:
                cursor.execute(query, params)
                row = cursor.fetchone()

            # SUM() over zero rows yields NULL, which means "no usage recorded".
            if not row or row[0] is None:
                return None
            return dict(zip(result_keys, (value or 0 for value in row)))

    result = await asyncio.to_thread(_get_usage_sync)

    return cast(Union[dict[str, int], None], result)
1120
+
1121
async def get_turn_usage(
    self,
    user_turn_number: int | None = None,
    branch_id: str | None = None,
) -> list[dict[str, Any]] | dict[str, Any]:
    """Return recorded usage per user turn, with decoded token-detail JSON.

    Args:
        user_turn_number: Turn to fetch. When None, every turn of the branch
            is returned.
        branch_id: Branch to read. When None, the currently active branch is
            used.

    Returns:
        When ``user_turn_number`` is given, a single dict for that turn, or
        ``{}`` if no row exists. Otherwise, a list of dicts ordered by user
        turn number, each also carrying ``user_turn_number``. The details
        fields are None when the stored JSON is empty or malformed.
    """

    if branch_id is None:
        branch_id = self._current_branch_id

    def _decode_details(raw):
        """Decode a stored JSON details column; None if empty or malformed."""
        if not raw:
            return None
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return None

    def _get_turn_usage_sync():
        """Run the per-turn or all-turns query and shape the result."""
        conn = self._get_connection()

        if user_turn_number is None:
            all_turns_query = """
                SELECT user_turn_number, requests, input_tokens, output_tokens,
                       total_tokens, input_tokens_details, output_tokens_details
                FROM turn_usage
                WHERE session_id = ? AND branch_id = ?
                ORDER BY user_turn_number
            """
            with closing(conn.cursor()) as cursor:
                cursor.execute(all_turns_query, (self.session_id, branch_id))
                return [
                    {
                        "user_turn_number": turn,
                        "requests": requests,
                        "input_tokens": in_tokens,
                        "output_tokens": out_tokens,
                        "total_tokens": total,
                        "input_tokens_details": _decode_details(in_details),
                        "output_tokens_details": _decode_details(out_details),
                    }
                    for (
                        turn,
                        requests,
                        in_tokens,
                        out_tokens,
                        total,
                        in_details,
                        out_details,
                    ) in cursor.fetchall()
                ]

        single_turn_query = """
            SELECT requests, input_tokens, output_tokens, total_tokens,
                   input_tokens_details, output_tokens_details
            FROM turn_usage
            WHERE session_id = ? AND branch_id = ? AND user_turn_number = ?
        """
        with closing(conn.cursor()) as cursor:
            cursor.execute(single_turn_query, (self.session_id, branch_id, user_turn_number))
            found = cursor.fetchone()

        if not found:
            return {}
        requests, in_tokens, out_tokens, total, in_details, out_details = found
        return {
            "requests": requests,
            "input_tokens": in_tokens,
            "output_tokens": out_tokens,
            "total_tokens": total,
            "input_tokens_details": _decode_details(in_details),
            "output_tokens_details": _decode_details(out_details),
        }

    result = await asyncio.to_thread(_get_turn_usage_sync)

    return cast(Union[list[dict[str, Any]], dict[str, Any]], result)
1226
+
1227
+ async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: Usage) -> None:
1228
+ """Internal method to update usage for a specific turn with full JSON details.
1229
+
1230
+ Args:
1231
+ user_turn_number: The turn number to update usage for.
1232
+ usage_data: The usage data to store.
1233
+ """
1234
+
1235
+ def _update_sync():
1236
+ """Synchronous helper to update turn usage data."""
1237
+ conn = self._get_connection()
1238
+ # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1239
+ with self._lock if self._is_memory_db else threading.Lock():
1240
+ # Serialize token details as JSON
1241
+ input_details_json = None
1242
+ output_details_json = None
1243
+
1244
+ if hasattr(usage_data, "input_tokens_details") and usage_data.input_tokens_details:
1245
+ try:
1246
+ input_details_json = json.dumps(usage_data.input_tokens_details.__dict__)
1247
+ except (TypeError, ValueError) as e:
1248
+ self._logger.warning(f"Failed to serialize input tokens details: {e}")
1249
+ input_details_json = None
1250
+
1251
+ if (
1252
+ hasattr(usage_data, "output_tokens_details")
1253
+ and usage_data.output_tokens_details
1254
+ ):
1255
+ try:
1256
+ output_details_json = json.dumps(
1257
+ usage_data.output_tokens_details.__dict__
1258
+ )
1259
+ except (TypeError, ValueError) as e:
1260
+ self._logger.warning(f"Failed to serialize output tokens details: {e}")
1261
+ output_details_json = None
1262
+
1263
+ with closing(conn.cursor()) as cursor:
1264
+ cursor.execute(
1265
+ """
1266
+ INSERT OR REPLACE INTO turn_usage
1267
+ (session_id, branch_id, user_turn_number, requests, input_tokens, output_tokens,
1268
+ total_tokens, input_tokens_details, output_tokens_details)
1269
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
1270
+ """, # noqa: E501
1271
+ (
1272
+ self.session_id,
1273
+ self._current_branch_id,
1274
+ user_turn_number,
1275
+ usage_data.requests or 0,
1276
+ usage_data.input_tokens or 0,
1277
+ usage_data.output_tokens or 0,
1278
+ usage_data.total_tokens or 0,
1279
+ input_details_json,
1280
+ output_details_json,
1281
+ ),
1282
+ )
1283
+ conn.commit()
1284
+
1285
+ await asyncio.to_thread(_update_sync)