sqlsaber 0.16.1__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/base.py +0 -108
- sqlsaber/cli/commands.py +33 -1
- sqlsaber/cli/display.py +0 -30
- sqlsaber/cli/interactive.py +76 -24
- sqlsaber/cli/streaming.py +0 -4
- sqlsaber/cli/threads.py +301 -0
- sqlsaber/database/schema.py +30 -2
- sqlsaber/threads/__init__.py +5 -0
- sqlsaber/threads/storage.py +303 -0
- sqlsaber/tools/__init__.py +0 -2
- sqlsaber/tools/base.py +0 -12
- sqlsaber/tools/enums.py +0 -2
- sqlsaber/tools/instructions.py +3 -23
- sqlsaber/tools/registry.py +0 -12
- {sqlsaber-0.16.1.dist-info → sqlsaber-0.17.0.dist-info}/METADATA +12 -3
- {sqlsaber-0.16.1.dist-info → sqlsaber-0.17.0.dist-info}/RECORD +19 -23
- sqlsaber/conversation/__init__.py +0 -12
- sqlsaber/conversation/manager.py +0 -224
- sqlsaber/conversation/models.py +0 -120
- sqlsaber/conversation/storage.py +0 -362
- sqlsaber/models/__init__.py +0 -10
- sqlsaber/models/types.py +0 -40
- sqlsaber/tools/visualization_tools.py +0 -144
- {sqlsaber-0.16.1.dist-info → sqlsaber-0.17.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.16.1.dist-info → sqlsaber-0.17.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.16.1.dist-info → sqlsaber-0.17.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/base.py
CHANGED
|
@@ -5,7 +5,6 @@ import json
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from typing import Any, AsyncIterator
|
|
7
7
|
|
|
8
|
-
from sqlsaber.conversation.manager import ConversationManager
|
|
9
8
|
from sqlsaber.database.connection import (
|
|
10
9
|
BaseDatabaseConnection,
|
|
11
10
|
CSVConnection,
|
|
@@ -23,12 +22,6 @@ class BaseSQLAgent(ABC):
|
|
|
23
22
|
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
24
23
|
self.db = db_connection
|
|
25
24
|
self.schema_manager = SchemaManager(db_connection)
|
|
26
|
-
self.conversation_history: list[dict[str, Any]] = []
|
|
27
|
-
|
|
28
|
-
# Conversation persistence
|
|
29
|
-
self._conv_manager = ConversationManager()
|
|
30
|
-
self._conversation_id: str | None = None
|
|
31
|
-
self._msg_index: int = 0
|
|
32
25
|
|
|
33
26
|
# Initialize SQL tools with database connection
|
|
34
27
|
self._init_tools()
|
|
@@ -49,14 +42,6 @@ class BaseSQLAgent(ABC):
|
|
|
49
42
|
"""
|
|
50
43
|
pass
|
|
51
44
|
|
|
52
|
-
async def clear_history(self):
|
|
53
|
-
"""Clear conversation history."""
|
|
54
|
-
# End current conversation in storage
|
|
55
|
-
await self._end_conversation()
|
|
56
|
-
|
|
57
|
-
# Clear in-memory history
|
|
58
|
-
self.conversation_history = []
|
|
59
|
-
|
|
60
45
|
def _get_database_type_name(self) -> str:
|
|
61
46
|
"""Get the human-readable database type name."""
|
|
62
47
|
if isinstance(self.db, PostgreSQLConnection):
|
|
@@ -91,96 +76,3 @@ class BaseSQLAgent(ABC):
|
|
|
91
76
|
return json.dumps(
|
|
92
77
|
{"error": f"Error executing tool '{tool_name}': {str(e)}"}
|
|
93
78
|
)
|
|
94
|
-
|
|
95
|
-
# Conversation persistence helpers
|
|
96
|
-
|
|
97
|
-
async def _ensure_conversation(self) -> None:
|
|
98
|
-
"""Ensure a conversation is active for storing messages."""
|
|
99
|
-
if self._conversation_id is None:
|
|
100
|
-
db_name = getattr(self, "database_name", "unknown")
|
|
101
|
-
self._conversation_id = await self._conv_manager.start_conversation(db_name)
|
|
102
|
-
self._msg_index = 0
|
|
103
|
-
|
|
104
|
-
async def _store_user_message(self, content: str | dict[str, Any]) -> None:
|
|
105
|
-
"""Store a user message in conversation history."""
|
|
106
|
-
if self._conversation_id is None:
|
|
107
|
-
return
|
|
108
|
-
|
|
109
|
-
await self._conv_manager.add_user_message(
|
|
110
|
-
self._conversation_id, content, self._msg_index
|
|
111
|
-
)
|
|
112
|
-
self._msg_index += 1
|
|
113
|
-
|
|
114
|
-
async def _store_assistant_message(
|
|
115
|
-
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
116
|
-
) -> None:
|
|
117
|
-
"""Store an assistant message in conversation history."""
|
|
118
|
-
if self._conversation_id is None:
|
|
119
|
-
return
|
|
120
|
-
|
|
121
|
-
await self._conv_manager.add_assistant_message(
|
|
122
|
-
self._conversation_id, content, self._msg_index
|
|
123
|
-
)
|
|
124
|
-
self._msg_index += 1
|
|
125
|
-
|
|
126
|
-
async def _store_tool_message(
|
|
127
|
-
self, content: list[dict[str, Any]] | dict[str, Any]
|
|
128
|
-
) -> None:
|
|
129
|
-
"""Store a tool/system message in conversation history."""
|
|
130
|
-
if self._conversation_id is None:
|
|
131
|
-
return
|
|
132
|
-
|
|
133
|
-
await self._conv_manager.add_tool_message(
|
|
134
|
-
self._conversation_id, content, self._msg_index
|
|
135
|
-
)
|
|
136
|
-
self._msg_index += 1
|
|
137
|
-
|
|
138
|
-
async def _end_conversation(self) -> None:
|
|
139
|
-
"""End the current conversation."""
|
|
140
|
-
if self._conversation_id:
|
|
141
|
-
await self._conv_manager.end_conversation(self._conversation_id)
|
|
142
|
-
self._conversation_id = None
|
|
143
|
-
self._msg_index = 0
|
|
144
|
-
|
|
145
|
-
async def restore_conversation(self, conversation_id: str) -> bool:
|
|
146
|
-
"""Restore a conversation from storage to in-memory history.
|
|
147
|
-
|
|
148
|
-
Args:
|
|
149
|
-
conversation_id: ID of the conversation to restore
|
|
150
|
-
|
|
151
|
-
Returns:
|
|
152
|
-
True if successfully restored, False otherwise
|
|
153
|
-
"""
|
|
154
|
-
success = await self._conv_manager.restore_conversation_to_agent(
|
|
155
|
-
conversation_id, self.conversation_history
|
|
156
|
-
)
|
|
157
|
-
|
|
158
|
-
if success:
|
|
159
|
-
# Set up for continuing this conversation
|
|
160
|
-
self._conversation_id = conversation_id
|
|
161
|
-
self._msg_index = len(self.conversation_history)
|
|
162
|
-
|
|
163
|
-
return success
|
|
164
|
-
|
|
165
|
-
async def list_conversations(self, limit: int = 50) -> list:
|
|
166
|
-
"""List conversations for this agent's database.
|
|
167
|
-
|
|
168
|
-
Args:
|
|
169
|
-
limit: Maximum number of conversations to return
|
|
170
|
-
|
|
171
|
-
Returns:
|
|
172
|
-
List of conversation data
|
|
173
|
-
"""
|
|
174
|
-
db_name = getattr(self, "database_name", None)
|
|
175
|
-
conversations = await self._conv_manager.list_conversations(db_name, limit)
|
|
176
|
-
|
|
177
|
-
return [
|
|
178
|
-
{
|
|
179
|
-
"id": conv.id,
|
|
180
|
-
"database_name": conv.database_name,
|
|
181
|
-
"started_at": conv.formatted_start_time(),
|
|
182
|
-
"ended_at": conv.formatted_end_time(),
|
|
183
|
-
"duration": conv.duration_seconds(),
|
|
184
|
-
}
|
|
185
|
-
for conv in conversations
|
|
186
|
-
]
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -14,6 +14,7 @@ from sqlsaber.cli.interactive import InteractiveSession
|
|
|
14
14
|
from sqlsaber.cli.memory import create_memory_app
|
|
15
15
|
from sqlsaber.cli.models import create_models_app
|
|
16
16
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
17
|
+
from sqlsaber.cli.threads import create_threads_app
|
|
17
18
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
18
19
|
from sqlsaber.database.connection import (
|
|
19
20
|
CSVConnection,
|
|
@@ -23,6 +24,7 @@ from sqlsaber.database.connection import (
|
|
|
23
24
|
SQLiteConnection,
|
|
24
25
|
)
|
|
25
26
|
from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
|
|
27
|
+
from sqlsaber.threads import ThreadStorage
|
|
26
28
|
|
|
27
29
|
|
|
28
30
|
class CLIError(Exception):
|
|
@@ -149,7 +151,33 @@ def query(
|
|
|
149
151
|
console.print(
|
|
150
152
|
f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
|
|
151
153
|
)
|
|
152
|
-
await streaming_handler.execute_streaming_query(
|
|
154
|
+
run = await streaming_handler.execute_streaming_query(
|
|
155
|
+
actual_query, agent
|
|
156
|
+
)
|
|
157
|
+
# Persist non-interactive run as a thread snapshot so it can be resumed later
|
|
158
|
+
try:
|
|
159
|
+
if run is not None:
|
|
160
|
+
threads = ThreadStorage()
|
|
161
|
+
# Extract title and model name
|
|
162
|
+
title = actual_query
|
|
163
|
+
model_name: str | None = agent.model.model_name
|
|
164
|
+
|
|
165
|
+
thread_id = await threads.save_snapshot(
|
|
166
|
+
messages_json=run.all_messages_json(),
|
|
167
|
+
database_name=db_name,
|
|
168
|
+
)
|
|
169
|
+
await threads.save_metadata(
|
|
170
|
+
thread_id=thread_id, title=title, model_name=model_name
|
|
171
|
+
)
|
|
172
|
+
await threads.end_thread(thread_id)
|
|
173
|
+
console.print(
|
|
174
|
+
f"[dim]You can continue this thread using:[/dim] saber threads resume {thread_id}"
|
|
175
|
+
)
|
|
176
|
+
except Exception:
|
|
177
|
+
# best-effort persistence; don't fail the CLI on storage errors
|
|
178
|
+
pass
|
|
179
|
+
finally:
|
|
180
|
+
await threads.prune_threads()
|
|
153
181
|
else:
|
|
154
182
|
# Interactive mode
|
|
155
183
|
session = InteractiveSession(console, agent, db_conn, db_name)
|
|
@@ -184,6 +212,10 @@ app.command(memory_app, name="memory")
|
|
|
184
212
|
models_app = create_models_app()
|
|
185
213
|
app.command(models_app, name="models")
|
|
186
214
|
|
|
215
|
+
# Add threads management commands
|
|
216
|
+
threads_app = create_threads_app()
|
|
217
|
+
app.command(threads_app, name="threads")
|
|
218
|
+
|
|
187
219
|
|
|
188
220
|
def main():
|
|
189
221
|
"""Entry point for the CLI application."""
|
sqlsaber/cli/display.py
CHANGED
|
@@ -209,36 +209,6 @@ class DisplayManager:
|
|
|
209
209
|
except Exception as e:
|
|
210
210
|
self.show_error(f"Error displaying schema information: {str(e)}")
|
|
211
211
|
|
|
212
|
-
def show_plot(self, plot_data: dict):
|
|
213
|
-
"""Display plot information and status."""
|
|
214
|
-
try:
|
|
215
|
-
# Parse the result if it's a string
|
|
216
|
-
if isinstance(plot_data.get("result"), str):
|
|
217
|
-
result = json.loads(plot_data["result"])
|
|
218
|
-
else:
|
|
219
|
-
result = plot_data.get("result", {})
|
|
220
|
-
|
|
221
|
-
# Check if there was an error
|
|
222
|
-
if "error" in result:
|
|
223
|
-
self.show_error(f"Plot error: {result['error']}")
|
|
224
|
-
return
|
|
225
|
-
|
|
226
|
-
# If plot was successful, show plot info
|
|
227
|
-
if result.get("success") and result.get("plot_rendered"):
|
|
228
|
-
plot_info = result.get("plot_info", {})
|
|
229
|
-
self.console.print(
|
|
230
|
-
f"\n[bold green]✓ Plot rendered:[/bold green] {plot_info.get('title', 'Plot')}"
|
|
231
|
-
)
|
|
232
|
-
self.console.print(
|
|
233
|
-
f"[dim] Type: {plot_info.get('type', 'unknown')}, "
|
|
234
|
-
f"Data points: {plot_info.get('data_points', 0)}[/dim]"
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
except json.JSONDecodeError:
|
|
238
|
-
self.show_error("Failed to parse plot result")
|
|
239
|
-
except Exception as e:
|
|
240
|
-
self.show_error(f"Error displaying plot: {str(e)}")
|
|
241
|
-
|
|
242
212
|
def show_markdown_response(self, content: list):
|
|
243
213
|
"""Display the assistant's response as rich markdown in a panel."""
|
|
244
214
|
if not content:
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -14,13 +14,29 @@ from sqlsaber.cli.completers import (
|
|
|
14
14
|
)
|
|
15
15
|
from sqlsaber.cli.display import DisplayManager
|
|
16
16
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
17
|
+
from sqlsaber.database.connection import (
|
|
18
|
+
CSVConnection,
|
|
19
|
+
MySQLConnection,
|
|
20
|
+
PostgreSQLConnection,
|
|
21
|
+
SQLiteConnection,
|
|
22
|
+
)
|
|
17
23
|
from sqlsaber.database.schema import SchemaManager
|
|
24
|
+
from sqlsaber.threads import ThreadStorage
|
|
18
25
|
|
|
19
26
|
|
|
20
27
|
class InteractiveSession:
|
|
21
28
|
"""Manages interactive CLI sessions."""
|
|
22
29
|
|
|
23
|
-
def __init__(
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
console: Console,
|
|
33
|
+
agent: Agent,
|
|
34
|
+
db_conn,
|
|
35
|
+
database_name: str,
|
|
36
|
+
*,
|
|
37
|
+
initial_thread_id: str | None = None,
|
|
38
|
+
initial_history: list | None = None,
|
|
39
|
+
):
|
|
24
40
|
self.console = console
|
|
25
41
|
self.agent = agent
|
|
26
42
|
self.db_conn = db_conn
|
|
@@ -30,19 +46,16 @@ class InteractiveSession:
|
|
|
30
46
|
self.current_task: asyncio.Task | None = None
|
|
31
47
|
self.cancellation_token: asyncio.Event | None = None
|
|
32
48
|
self.table_completer = TableNameCompleter()
|
|
33
|
-
self.message_history: list | None = []
|
|
49
|
+
self.message_history: list | None = initial_history or []
|
|
50
|
+
# Conversation Thread persistence
|
|
51
|
+
self._threads = ThreadStorage()
|
|
52
|
+
self._thread_id: str | None = initial_thread_id
|
|
53
|
+
self.first_message = not self._thread_id
|
|
34
54
|
|
|
35
55
|
def show_welcome_message(self):
|
|
36
56
|
"""Display welcome message for interactive mode."""
|
|
37
57
|
# Show database information
|
|
38
58
|
db_name = self.database_name or "Unknown"
|
|
39
|
-
from sqlsaber.database.connection import (
|
|
40
|
-
CSVConnection,
|
|
41
|
-
MySQLConnection,
|
|
42
|
-
PostgreSQLConnection,
|
|
43
|
-
SQLiteConnection,
|
|
44
|
-
)
|
|
45
|
-
|
|
46
59
|
db_type = (
|
|
47
60
|
"PostgreSQL"
|
|
48
61
|
if isinstance(self.db_conn, PostgreSQLConnection)
|
|
@@ -53,32 +66,36 @@ class InteractiveSession:
|
|
|
53
66
|
else "database"
|
|
54
67
|
)
|
|
55
68
|
|
|
56
|
-
self.
|
|
57
|
-
|
|
58
|
-
|
|
69
|
+
if self.first_message:
|
|
70
|
+
self.console.print(
|
|
71
|
+
Panel.fit(
|
|
72
|
+
"""
|
|
59
73
|
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
60
74
|
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
61
75
|
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
62
76
|
██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
63
77
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
64
78
|
▀▀
|
|
65
|
-
"""
|
|
79
|
+
"""
|
|
80
|
+
)
|
|
81
|
+
)
|
|
82
|
+
self.console.print(
|
|
83
|
+
"\n",
|
|
84
|
+
"[dim] > Use '/clear' to reset conversation",
|
|
85
|
+
"[dim] > Use '/exit' or '/quit' to leave[/dim]",
|
|
86
|
+
"[dim] > Use 'Ctrl+C' to interrupt and return to prompt\n\n",
|
|
87
|
+
"[dim] > Start message with '#' to add something to agent's memory for this database",
|
|
88
|
+
"[dim] > Type '@' to get table name completions",
|
|
89
|
+
"[dim] > Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
|
|
90
|
+
sep="\n",
|
|
66
91
|
)
|
|
67
|
-
)
|
|
68
|
-
self.console.print(
|
|
69
|
-
"\n",
|
|
70
|
-
"[dim] ≥ Use '/clear' to reset conversation",
|
|
71
|
-
"[dim] ≥ Use '/exit' or '/quit' to leave[/dim]",
|
|
72
|
-
"[dim] ≥ Use 'Ctrl+C' to interrupt and return to prompt\n\n",
|
|
73
|
-
"[dim] ≥ Start message with '#' to add something to agent's memory for this database",
|
|
74
|
-
"[dim] ≥ Type '@' to get table name completions",
|
|
75
|
-
"[dim] ≥ Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
|
|
76
|
-
sep="\n",
|
|
77
|
-
)
|
|
78
92
|
|
|
79
93
|
self.console.print(
|
|
80
94
|
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
|
|
81
95
|
)
|
|
96
|
+
# If resuming a thread, show a notice
|
|
97
|
+
if self._thread_id:
|
|
98
|
+
self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
|
|
82
99
|
|
|
83
100
|
async def _update_table_cache(self):
|
|
84
101
|
"""Update the table completer cache with fresh data."""
|
|
@@ -132,8 +149,29 @@ class InteractiveSession:
|
|
|
132
149
|
try:
|
|
133
150
|
# Use all_messages() so the system prompt and all prior turns are preserved
|
|
134
151
|
self.message_history = run_result.all_messages()
|
|
152
|
+
|
|
153
|
+
# Extract title (first user prompt) and model name
|
|
154
|
+
if not self._thread_id:
|
|
155
|
+
title = user_query
|
|
156
|
+
model_name = self.agent.model.model_name
|
|
157
|
+
|
|
158
|
+
# Persist snapshot to thread storage (create or overwrite)
|
|
159
|
+
self._thread_id = await self._threads.save_snapshot(
|
|
160
|
+
messages_json=run_result.all_messages_json(),
|
|
161
|
+
database_name=self.database_name,
|
|
162
|
+
thread_id=self._thread_id,
|
|
163
|
+
)
|
|
164
|
+
# Save metadata separately (only if its the first message)
|
|
165
|
+
if self.first_message:
|
|
166
|
+
await self._threads.save_metadata(
|
|
167
|
+
thread_id=self._thread_id,
|
|
168
|
+
title=title,
|
|
169
|
+
model_name=model_name,
|
|
170
|
+
)
|
|
135
171
|
except Exception:
|
|
136
172
|
pass
|
|
173
|
+
finally:
|
|
174
|
+
await self._threads.prune_threads()
|
|
137
175
|
finally:
|
|
138
176
|
self.current_task = None
|
|
139
177
|
self.cancellation_token = None
|
|
@@ -165,12 +203,26 @@ class InteractiveSession:
|
|
|
165
203
|
or user_query.startswith("/exit")
|
|
166
204
|
or user_query.startswith("/quit")
|
|
167
205
|
):
|
|
206
|
+
# Print resume hint if there is an active thread
|
|
207
|
+
if self._thread_id:
|
|
208
|
+
await self._threads.end_thread(self._thread_id)
|
|
209
|
+
self.console.print(
|
|
210
|
+
f"[dim]You can continue this thread using:[/dim] saber threads resume {self._thread_id}"
|
|
211
|
+
)
|
|
168
212
|
break
|
|
169
213
|
|
|
170
214
|
if user_query == "/clear":
|
|
171
215
|
# Reset local history (pydantic-ai call will receive empty history on next run)
|
|
172
216
|
self.message_history = []
|
|
217
|
+
# End current thread (if any) so the next turn creates a fresh one
|
|
218
|
+
try:
|
|
219
|
+
if self._thread_id:
|
|
220
|
+
await self._threads.end_thread(self._thread_id)
|
|
221
|
+
except Exception:
|
|
222
|
+
pass
|
|
173
223
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
224
|
+
# Do not print resume hint when clearing; a new thread will be created on next turn
|
|
225
|
+
self._thread_id = None
|
|
174
226
|
continue
|
|
175
227
|
|
|
176
228
|
if memory_text := user_query.strip():
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -72,10 +72,6 @@ class StreamingQueryHandler:
|
|
|
72
72
|
except json.JSONDecodeError:
|
|
73
73
|
# If not JSON, ignore here
|
|
74
74
|
pass
|
|
75
|
-
elif tool_name == "plot_data":
|
|
76
|
-
self.display.show_plot(
|
|
77
|
-
{"tool_name": tool_name, "result": content, "input": {}}
|
|
78
|
-
)
|
|
79
75
|
|
|
80
76
|
async def execute_streaming_query(
|
|
81
77
|
self,
|