sqlsaber 0.16.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

sqlsaber/agents/base.py CHANGED
@@ -5,7 +5,6 @@ import json
5
5
  from abc import ABC, abstractmethod
6
6
  from typing import Any, AsyncIterator
7
7
 
8
- from sqlsaber.conversation.manager import ConversationManager
9
8
  from sqlsaber.database.connection import (
10
9
  BaseDatabaseConnection,
11
10
  CSVConnection,
@@ -23,12 +22,6 @@ class BaseSQLAgent(ABC):
23
22
  def __init__(self, db_connection: BaseDatabaseConnection):
24
23
  self.db = db_connection
25
24
  self.schema_manager = SchemaManager(db_connection)
26
- self.conversation_history: list[dict[str, Any]] = []
27
-
28
- # Conversation persistence
29
- self._conv_manager = ConversationManager()
30
- self._conversation_id: str | None = None
31
- self._msg_index: int = 0
32
25
 
33
26
  # Initialize SQL tools with database connection
34
27
  self._init_tools()
@@ -49,14 +42,6 @@ class BaseSQLAgent(ABC):
49
42
  """
50
43
  pass
51
44
 
52
- async def clear_history(self):
53
- """Clear conversation history."""
54
- # End current conversation in storage
55
- await self._end_conversation()
56
-
57
- # Clear in-memory history
58
- self.conversation_history = []
59
-
60
45
  def _get_database_type_name(self) -> str:
61
46
  """Get the human-readable database type name."""
62
47
  if isinstance(self.db, PostgreSQLConnection):
@@ -91,96 +76,3 @@ class BaseSQLAgent(ABC):
91
76
  return json.dumps(
92
77
  {"error": f"Error executing tool '{tool_name}': {str(e)}"}
93
78
  )
94
-
95
- # Conversation persistence helpers
96
-
97
- async def _ensure_conversation(self) -> None:
98
- """Ensure a conversation is active for storing messages."""
99
- if self._conversation_id is None:
100
- db_name = getattr(self, "database_name", "unknown")
101
- self._conversation_id = await self._conv_manager.start_conversation(db_name)
102
- self._msg_index = 0
103
-
104
- async def _store_user_message(self, content: str | dict[str, Any]) -> None:
105
- """Store a user message in conversation history."""
106
- if self._conversation_id is None:
107
- return
108
-
109
- await self._conv_manager.add_user_message(
110
- self._conversation_id, content, self._msg_index
111
- )
112
- self._msg_index += 1
113
-
114
- async def _store_assistant_message(
115
- self, content: list[dict[str, Any]] | dict[str, Any]
116
- ) -> None:
117
- """Store an assistant message in conversation history."""
118
- if self._conversation_id is None:
119
- return
120
-
121
- await self._conv_manager.add_assistant_message(
122
- self._conversation_id, content, self._msg_index
123
- )
124
- self._msg_index += 1
125
-
126
- async def _store_tool_message(
127
- self, content: list[dict[str, Any]] | dict[str, Any]
128
- ) -> None:
129
- """Store a tool/system message in conversation history."""
130
- if self._conversation_id is None:
131
- return
132
-
133
- await self._conv_manager.add_tool_message(
134
- self._conversation_id, content, self._msg_index
135
- )
136
- self._msg_index += 1
137
-
138
- async def _end_conversation(self) -> None:
139
- """End the current conversation."""
140
- if self._conversation_id:
141
- await self._conv_manager.end_conversation(self._conversation_id)
142
- self._conversation_id = None
143
- self._msg_index = 0
144
-
145
- async def restore_conversation(self, conversation_id: str) -> bool:
146
- """Restore a conversation from storage to in-memory history.
147
-
148
- Args:
149
- conversation_id: ID of the conversation to restore
150
-
151
- Returns:
152
- True if successfully restored, False otherwise
153
- """
154
- success = await self._conv_manager.restore_conversation_to_agent(
155
- conversation_id, self.conversation_history
156
- )
157
-
158
- if success:
159
- # Set up for continuing this conversation
160
- self._conversation_id = conversation_id
161
- self._msg_index = len(self.conversation_history)
162
-
163
- return success
164
-
165
- async def list_conversations(self, limit: int = 50) -> list:
166
- """List conversations for this agent's database.
167
-
168
- Args:
169
- limit: Maximum number of conversations to return
170
-
171
- Returns:
172
- List of conversation data
173
- """
174
- db_name = getattr(self, "database_name", None)
175
- conversations = await self._conv_manager.list_conversations(db_name, limit)
176
-
177
- return [
178
- {
179
- "id": conv.id,
180
- "database_name": conv.database_name,
181
- "started_at": conv.formatted_start_time(),
182
- "ended_at": conv.formatted_end_time(),
183
- "duration": conv.duration_seconds(),
184
- }
185
- for conv in conversations
186
- ]
sqlsaber/cli/commands.py CHANGED
@@ -7,22 +7,8 @@ from typing import Annotated
7
7
  import cyclopts
8
8
  from rich.console import Console
9
9
 
10
- from sqlsaber.agents import build_sqlsaber_agent
11
- from sqlsaber.cli.auth import create_auth_app
12
- from sqlsaber.cli.database import create_db_app
13
- from sqlsaber.cli.interactive import InteractiveSession
14
- from sqlsaber.cli.memory import create_memory_app
15
- from sqlsaber.cli.models import create_models_app
16
- from sqlsaber.cli.streaming import StreamingQueryHandler
10
+ # Lazy imports - only import what's needed for CLI parsing
17
11
  from sqlsaber.config.database import DatabaseConfigManager
18
- from sqlsaber.database.connection import (
19
- CSVConnection,
20
- DatabaseConnection,
21
- MySQLConnection,
22
- PostgreSQLConnection,
23
- SQLiteConnection,
24
- )
25
- from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
26
12
 
27
13
 
28
14
  class CLIError(Exception):
@@ -106,6 +92,21 @@ def query(
106
92
  """
107
93
 
108
94
  async def run_session():
95
+ # Import heavy dependencies only when actually running a query
96
+ # This is only done to speed up startup time
97
+ from sqlsaber.agents import build_sqlsaber_agent
98
+ from sqlsaber.cli.interactive import InteractiveSession
99
+ from sqlsaber.cli.streaming import StreamingQueryHandler
100
+ from sqlsaber.database.connection import (
101
+ CSVConnection,
102
+ DatabaseConnection,
103
+ MySQLConnection,
104
+ PostgreSQLConnection,
105
+ SQLiteConnection,
106
+ )
107
+ from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
108
+ from sqlsaber.threads import ThreadStorage
109
+
109
110
  # Check if query_text is None and stdin has data
110
111
  actual_query = query_text
111
112
  if query_text is None and not sys.stdin.isatty():
@@ -149,7 +150,33 @@ def query(
149
150
  console.print(
150
151
  f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
151
152
  )
152
- await streaming_handler.execute_streaming_query(actual_query, agent)
153
+ run = await streaming_handler.execute_streaming_query(
154
+ actual_query, agent
155
+ )
156
+ # Persist non-interactive run as a thread snapshot so it can be resumed later
157
+ try:
158
+ if run is not None:
159
+ threads = ThreadStorage()
160
+ # Extract title and model name
161
+ title = actual_query
162
+ model_name: str | None = agent.model.model_name
163
+
164
+ thread_id = await threads.save_snapshot(
165
+ messages_json=run.all_messages_json(),
166
+ database_name=db_name,
167
+ )
168
+ await threads.save_metadata(
169
+ thread_id=thread_id, title=title, model_name=model_name
170
+ )
171
+ await threads.end_thread(thread_id)
172
+ console.print(
173
+ f"[dim]You can continue this thread using:[/dim] saber threads resume {thread_id}"
174
+ )
175
+ except Exception:
176
+ # best-effort persistence; don't fail the CLI on storage errors
177
+ pass
178
+ finally:
179
+ await threads.prune_threads()
153
180
  else:
154
181
  # Interactive mode
155
182
  session = InteractiveSession(console, agent, db_conn, db_name)
@@ -168,21 +195,45 @@ def query(
168
195
  sys.exit(e.exit_code)
169
196
 
170
197
 
171
- # Add authentication management commands
172
- auth_app = create_auth_app()
173
- app.command(auth_app, name="auth")
198
+ # Use lazy imports for fast CLI startup time
199
+ @app.command(name="auth")
200
+ def auth(*args, **kwargs):
201
+ """Manage authentication configuration."""
202
+ from sqlsaber.cli.auth import create_auth_app
203
+
204
+ return create_auth_app()(*args, **kwargs)
205
+
206
+
207
+ @app.command(name="db")
208
+ def db(*args, **kwargs):
209
+ """Manage database connections."""
210
+ from sqlsaber.cli.database import create_db_app
211
+
212
+ return create_db_app()(*args, **kwargs)
213
+
214
+
215
+ @app.command(name="memory")
216
+ def memory(*args, **kwargs):
217
+ """Manage database-specific memories."""
218
+ from sqlsaber.cli.memory import create_memory_app
219
+
220
+ return create_memory_app()(*args, **kwargs)
221
+
222
+
223
+ @app.command(name="models")
224
+ def models(*args, **kwargs):
225
+ """Select and manage models."""
226
+ from sqlsaber.cli.models import create_models_app
227
+
228
+ return create_models_app()(*args, **kwargs)
174
229
 
175
- # Add database management commands after main callback is defined
176
- db_app = create_db_app()
177
- app.command(db_app, name="db")
178
230
 
179
- # Add memory management commands
180
- memory_app = create_memory_app()
181
- app.command(memory_app, name="memory")
231
+ @app.command(name="threads")
232
+ def threads(*args, **kwargs):
233
+ """Manage SQLsaber threads."""
234
+ from sqlsaber.cli.threads import create_threads_app
182
235
 
183
- # Add model management commands
184
- models_app = create_models_app()
185
- app.command(models_app, name="models")
236
+ return create_threads_app()(*args, **kwargs)
186
237
 
187
238
 
188
239
  def main():
sqlsaber/cli/display.py CHANGED
@@ -209,36 +209,6 @@ class DisplayManager:
209
209
  except Exception as e:
210
210
  self.show_error(f"Error displaying schema information: {str(e)}")
211
211
 
212
- def show_plot(self, plot_data: dict):
213
- """Display plot information and status."""
214
- try:
215
- # Parse the result if it's a string
216
- if isinstance(plot_data.get("result"), str):
217
- result = json.loads(plot_data["result"])
218
- else:
219
- result = plot_data.get("result", {})
220
-
221
- # Check if there was an error
222
- if "error" in result:
223
- self.show_error(f"Plot error: {result['error']}")
224
- return
225
-
226
- # If plot was successful, show plot info
227
- if result.get("success") and result.get("plot_rendered"):
228
- plot_info = result.get("plot_info", {})
229
- self.console.print(
230
- f"\n[bold green]✓ Plot rendered:[/bold green] {plot_info.get('title', 'Plot')}"
231
- )
232
- self.console.print(
233
- f"[dim] Type: {plot_info.get('type', 'unknown')}, "
234
- f"Data points: {plot_info.get('data_points', 0)}[/dim]"
235
- )
236
-
237
- except json.JSONDecodeError:
238
- self.show_error("Failed to parse plot result")
239
- except Exception as e:
240
- self.show_error(f"Error displaying plot: {str(e)}")
241
-
242
212
  def show_markdown_response(self, content: list):
243
213
  """Display the assistant's response as rich markdown in a panel."""
244
214
  if not content:
sqlsaber/cli/interactive.py CHANGED
@@ -14,13 +14,29 @@ from sqlsaber.cli.completers import (
14
14
  )
15
15
  from sqlsaber.cli.display import DisplayManager
16
16
  from sqlsaber.cli.streaming import StreamingQueryHandler
17
+ from sqlsaber.database.connection import (
18
+ CSVConnection,
19
+ MySQLConnection,
20
+ PostgreSQLConnection,
21
+ SQLiteConnection,
22
+ )
17
23
  from sqlsaber.database.schema import SchemaManager
24
+ from sqlsaber.threads import ThreadStorage
18
25
 
19
26
 
20
27
  class InteractiveSession:
21
28
  """Manages interactive CLI sessions."""
22
29
 
23
- def __init__(self, console: Console, agent: Agent, db_conn, database_name: str):
30
+ def __init__(
31
+ self,
32
+ console: Console,
33
+ agent: Agent,
34
+ db_conn,
35
+ database_name: str,
36
+ *,
37
+ initial_thread_id: str | None = None,
38
+ initial_history: list | None = None,
39
+ ):
24
40
  self.console = console
25
41
  self.agent = agent
26
42
  self.db_conn = db_conn
@@ -30,19 +46,16 @@ class InteractiveSession:
30
46
  self.current_task: asyncio.Task | None = None
31
47
  self.cancellation_token: asyncio.Event | None = None
32
48
  self.table_completer = TableNameCompleter()
33
- self.message_history: list | None = []
49
+ self.message_history: list | None = initial_history or []
50
+ # Conversation Thread persistence
51
+ self._threads = ThreadStorage()
52
+ self._thread_id: str | None = initial_thread_id
53
+ self.first_message = not self._thread_id
34
54
 
35
55
  def show_welcome_message(self):
36
56
  """Display welcome message for interactive mode."""
37
57
  # Show database information
38
58
  db_name = self.database_name or "Unknown"
39
- from sqlsaber.database.connection import (
40
- CSVConnection,
41
- MySQLConnection,
42
- PostgreSQLConnection,
43
- SQLiteConnection,
44
- )
45
-
46
59
  db_type = (
47
60
  "PostgreSQL"
48
61
  if isinstance(self.db_conn, PostgreSQLConnection)
@@ -53,32 +66,36 @@ class InteractiveSession:
53
66
  else "database"
54
67
  )
55
68
 
56
- self.console.print(
57
- Panel.fit(
58
- """
69
+ if self.first_message:
70
+ self.console.print(
71
+ Panel.fit(
72
+ """
59
73
  ███████ ██████ ██ ███████ █████ ██████ ███████ ██████
60
74
  ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
61
75
  ███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
62
76
  ██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
63
77
  ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
64
78
  ▀▀
65
- """
79
+ """
80
+ )
81
+ )
82
+ self.console.print(
83
+ "\n",
84
+ "[dim] > Use '/clear' to reset conversation",
85
+ "[dim] > Use '/exit' or '/quit' to leave[/dim]",
86
+ "[dim] > Use 'Ctrl+C' to interrupt and return to prompt\n\n",
87
+ "[dim] > Start message with '#' to add something to agent's memory for this database",
88
+ "[dim] > Type '@' to get table name completions",
89
+ "[dim] > Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
90
+ sep="\n",
66
91
  )
67
- )
68
- self.console.print(
69
- "\n",
70
- "[dim] ≥ Use '/clear' to reset conversation",
71
- "[dim] ≥ Use '/exit' or '/quit' to leave[/dim]",
72
- "[dim] ≥ Use 'Ctrl+C' to interrupt and return to prompt\n\n",
73
- "[dim] ≥ Start message with '#' to add something to agent's memory for this database",
74
- "[dim] ≥ Type '@' to get table name completions",
75
- "[dim] ≥ Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
76
- sep="\n",
77
- )
78
92
 
79
93
  self.console.print(
80
94
  f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
81
95
  )
96
+ # If resuming a thread, show a notice
97
+ if self._thread_id:
98
+ self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
82
99
 
83
100
  async def _update_table_cache(self):
84
101
  """Update the table completer cache with fresh data."""
@@ -132,8 +149,29 @@ class InteractiveSession:
132
149
  try:
133
150
  # Use all_messages() so the system prompt and all prior turns are preserved
134
151
  self.message_history = run_result.all_messages()
152
+
153
+ # Extract title (first user prompt) and model name
154
+ if not self._thread_id:
155
+ title = user_query
156
+ model_name = self.agent.model.model_name
157
+
158
+ # Persist snapshot to thread storage (create or overwrite)
159
+ self._thread_id = await self._threads.save_snapshot(
160
+ messages_json=run_result.all_messages_json(),
161
+ database_name=self.database_name,
162
+ thread_id=self._thread_id,
163
+ )
164
+ # Save metadata separately (only if its the first message)
165
+ if self.first_message:
166
+ await self._threads.save_metadata(
167
+ thread_id=self._thread_id,
168
+ title=title,
169
+ model_name=model_name,
170
+ )
135
171
  except Exception:
136
172
  pass
173
+ finally:
174
+ await self._threads.prune_threads()
137
175
  finally:
138
176
  self.current_task = None
139
177
  self.cancellation_token = None
@@ -165,12 +203,26 @@ class InteractiveSession:
165
203
  or user_query.startswith("/exit")
166
204
  or user_query.startswith("/quit")
167
205
  ):
206
+ # Print resume hint if there is an active thread
207
+ if self._thread_id:
208
+ await self._threads.end_thread(self._thread_id)
209
+ self.console.print(
210
+ f"[dim]You can continue this thread using:[/dim] saber threads resume {self._thread_id}"
211
+ )
168
212
  break
169
213
 
170
214
  if user_query == "/clear":
171
215
  # Reset local history (pydantic-ai call will receive empty history on next run)
172
216
  self.message_history = []
217
+ # End current thread (if any) so the next turn creates a fresh one
218
+ try:
219
+ if self._thread_id:
220
+ await self._threads.end_thread(self._thread_id)
221
+ except Exception:
222
+ pass
173
223
  self.console.print("[green]Conversation history cleared.[/green]\n")
224
+ # Do not print resume hint when clearing; a new thread will be created on next turn
225
+ self._thread_id = None
174
226
  continue
175
227
 
176
228
  if memory_text := user_query.strip():
sqlsaber/cli/streaming.py CHANGED
@@ -72,10 +72,6 @@ class StreamingQueryHandler:
72
72
  except json.JSONDecodeError:
73
73
  # If not JSON, ignore here
74
74
  pass
75
- elif tool_name == "plot_data":
76
- self.display.show_plot(
77
- {"tool_name": tool_name, "result": content, "input": {}}
78
- )
79
75
 
80
76
  async def execute_streaming_query(
81
77
  self,