sqlsaber 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +55 -17
- sqlsaber/agents/base.py +13 -3
- sqlsaber/cli/completers.py +172 -0
- sqlsaber/cli/display.py +33 -3
- sqlsaber/cli/interactive.py +97 -8
- sqlsaber/cli/streaming.py +29 -2
- sqlsaber/database/schema.py +17 -0
- {sqlsaber-0.5.0.dist-info → sqlsaber-0.7.0.dist-info}/METADATA +9 -8
- {sqlsaber-0.5.0.dist-info → sqlsaber-0.7.0.dist-info}/RECORD +12 -11
- {sqlsaber-0.5.0.dist-info → sqlsaber-0.7.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.5.0.dist-info → sqlsaber-0.7.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.5.0.dist-info → sqlsaber-0.7.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Anthropic-specific SQL agent implementation."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
|
-
from typing import Any, AsyncIterator, Dict, List
|
|
5
|
+
from typing import Any, AsyncIterator, Dict, List
|
|
5
6
|
|
|
6
7
|
from anthropic import AsyncAnthropic
|
|
7
8
|
|
|
@@ -21,7 +22,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
21
22
|
"""SQL Agent using Anthropic SDK directly."""
|
|
22
23
|
|
|
23
24
|
def __init__(
|
|
24
|
-
self, db_connection: BaseDatabaseConnection, database_name:
|
|
25
|
+
self, db_connection: BaseDatabaseConnection, database_name: str | None = None
|
|
25
26
|
):
|
|
26
27
|
super().__init__(db_connection)
|
|
27
28
|
|
|
@@ -164,7 +165,7 @@ Guidelines:
|
|
|
164
165
|
|
|
165
166
|
return base_prompt
|
|
166
167
|
|
|
167
|
-
def add_memory(self, content: str) ->
|
|
168
|
+
def add_memory(self, content: str) -> str | None:
|
|
168
169
|
"""Add a memory for the current database."""
|
|
169
170
|
if not self.database_name:
|
|
170
171
|
return None
|
|
@@ -174,7 +175,7 @@ Guidelines:
|
|
|
174
175
|
self.system_prompt = self._build_system_prompt()
|
|
175
176
|
return memory.id
|
|
176
177
|
|
|
177
|
-
async def execute_sql(self, query: str, limit:
|
|
178
|
+
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
178
179
|
"""Execute a SQL query against the database with streaming support."""
|
|
179
180
|
# Call parent implementation for core functionality
|
|
180
181
|
result = await super().execute_sql(query, limit)
|
|
@@ -203,10 +204,18 @@ Guidelines:
|
|
|
203
204
|
return await super().process_tool_call(tool_name, tool_input)
|
|
204
205
|
|
|
205
206
|
async def _process_stream_events(
|
|
206
|
-
self,
|
|
207
|
+
self,
|
|
208
|
+
stream,
|
|
209
|
+
content_blocks: List[Dict],
|
|
210
|
+
tool_use_blocks: List[Dict],
|
|
211
|
+
cancellation_token: asyncio.Event | None = None,
|
|
207
212
|
) -> AsyncIterator[StreamEvent]:
|
|
208
213
|
"""Process stream events and yield appropriate StreamEvents."""
|
|
209
214
|
async for event in stream:
|
|
215
|
+
# Only check cancellation if token is provided
|
|
216
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
217
|
+
return
|
|
218
|
+
|
|
210
219
|
if event.type == "content_block_start":
|
|
211
220
|
if hasattr(event.content_block, "type"):
|
|
212
221
|
if event.content_block.type == "tool_use":
|
|
@@ -253,11 +262,17 @@ Guidelines:
|
|
|
253
262
|
return "stop"
|
|
254
263
|
|
|
255
264
|
async def _process_tool_results(
|
|
256
|
-
self,
|
|
265
|
+
self,
|
|
266
|
+
response: StreamingResponse,
|
|
267
|
+
cancellation_token: asyncio.Event | None = None,
|
|
257
268
|
) -> AsyncIterator[StreamEvent]:
|
|
258
269
|
"""Process tool results and yield appropriate events."""
|
|
259
270
|
tool_results = []
|
|
260
271
|
for block in response.content:
|
|
272
|
+
# Only check cancellation if token is provided
|
|
273
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
274
|
+
return
|
|
275
|
+
|
|
261
276
|
if block.get("type") == "tool_use":
|
|
262
277
|
yield StreamEvent(
|
|
263
278
|
"tool_use",
|
|
@@ -304,7 +319,10 @@ Guidelines:
|
|
|
304
319
|
yield StreamEvent("tool_result_data", tool_results)
|
|
305
320
|
|
|
306
321
|
async def query_stream(
|
|
307
|
-
self,
|
|
322
|
+
self,
|
|
323
|
+
user_query: str,
|
|
324
|
+
use_history: bool = True,
|
|
325
|
+
cancellation_token: asyncio.Event | None = None,
|
|
308
326
|
) -> AsyncIterator[StreamEvent]:
|
|
309
327
|
"""Process a user query and stream responses."""
|
|
310
328
|
# Initialize for tracking state
|
|
@@ -322,7 +340,11 @@ Guidelines:
|
|
|
322
340
|
try:
|
|
323
341
|
# Create initial stream and get response
|
|
324
342
|
response = None
|
|
325
|
-
async for event in self._create_and_process_stream(
|
|
343
|
+
async for event in self._create_and_process_stream(
|
|
344
|
+
messages, cancellation_token
|
|
345
|
+
):
|
|
346
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
347
|
+
return
|
|
326
348
|
if event.type == "response_ready":
|
|
327
349
|
response = event.data
|
|
328
350
|
else:
|
|
@@ -332,14 +354,21 @@ Guidelines:
|
|
|
332
354
|
|
|
333
355
|
# Process tool calls if needed
|
|
334
356
|
while response is not None and response.stop_reason == "tool_use":
|
|
357
|
+
# Check for cancellation at the start of tool cycle
|
|
358
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
359
|
+
return
|
|
360
|
+
|
|
335
361
|
# Add assistant's response to conversation
|
|
336
362
|
collected_content.append(
|
|
337
363
|
{"role": "assistant", "content": response.content}
|
|
338
364
|
)
|
|
339
365
|
|
|
340
|
-
# Process tool results
|
|
366
|
+
# Process tool results - DO NOT check cancellation during tool execution
|
|
367
|
+
# as this would break the tool_use -> tool_result API contract
|
|
341
368
|
tool_results = []
|
|
342
|
-
async for event in self._process_tool_results(
|
|
369
|
+
async for event in self._process_tool_results(
|
|
370
|
+
response, None
|
|
371
|
+
): # Pass None to disable cancellation checks
|
|
343
372
|
if event.type == "tool_result_data":
|
|
344
373
|
tool_results = event.data
|
|
345
374
|
else:
|
|
@@ -347,6 +376,12 @@ Guidelines:
|
|
|
347
376
|
|
|
348
377
|
# Continue conversation with tool results
|
|
349
378
|
collected_content.append({"role": "user", "content": tool_results})
|
|
379
|
+
if use_history:
|
|
380
|
+
self.conversation_history.extend(collected_content)
|
|
381
|
+
|
|
382
|
+
# Check for cancellation AFTER tool results are complete
|
|
383
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
384
|
+
return
|
|
350
385
|
|
|
351
386
|
# Signal that we're processing the tool results
|
|
352
387
|
yield StreamEvent("processing", "Analyzing results...")
|
|
@@ -354,8 +389,10 @@ Guidelines:
|
|
|
354
389
|
# Get next response
|
|
355
390
|
response = None
|
|
356
391
|
async for event in self._create_and_process_stream(
|
|
357
|
-
messages + collected_content
|
|
392
|
+
messages + collected_content, cancellation_token
|
|
358
393
|
):
|
|
394
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
395
|
+
return
|
|
359
396
|
if event.type == "response_ready":
|
|
360
397
|
response = event.data
|
|
361
398
|
else:
|
|
@@ -363,21 +400,19 @@ Guidelines:
|
|
|
363
400
|
|
|
364
401
|
# Update conversation history if using history
|
|
365
402
|
if use_history:
|
|
366
|
-
self.conversation_history.append(
|
|
367
|
-
{"role": "user", "content": user_query}
|
|
368
|
-
)
|
|
369
|
-
self.conversation_history.extend(collected_content)
|
|
370
403
|
# Add final assistant response
|
|
371
404
|
if response is not None:
|
|
372
405
|
self.conversation_history.append(
|
|
373
406
|
{"role": "assistant", "content": response.content}
|
|
374
407
|
)
|
|
375
408
|
|
|
409
|
+
except asyncio.CancelledError:
|
|
410
|
+
return
|
|
376
411
|
except Exception as e:
|
|
377
412
|
yield StreamEvent("error", str(e))
|
|
378
413
|
|
|
379
414
|
async def _create_and_process_stream(
|
|
380
|
-
self, messages: List[Dict]
|
|
415
|
+
self, messages: List[Dict], cancellation_token: asyncio.Event | None = None
|
|
381
416
|
) -> AsyncIterator[StreamEvent]:
|
|
382
417
|
"""Create a stream and yield events while building response."""
|
|
383
418
|
stream = await self.client.messages.create(
|
|
@@ -393,8 +428,11 @@ Guidelines:
|
|
|
393
428
|
tool_use_blocks = []
|
|
394
429
|
|
|
395
430
|
async for event in self._process_stream_events(
|
|
396
|
-
stream, content_blocks, tool_use_blocks
|
|
431
|
+
stream, content_blocks, tool_use_blocks, cancellation_token
|
|
397
432
|
):
|
|
433
|
+
# Only check cancellation if token is provided
|
|
434
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
435
|
+
return
|
|
398
436
|
yield event
|
|
399
437
|
|
|
400
438
|
# Finalize tool blocks and create response
|
sqlsaber/agents/base.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Abstract base class for SQL agents."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
@@ -27,9 +28,18 @@ class BaseSQLAgent(ABC):
|
|
|
27
28
|
|
|
28
29
|
@abstractmethod
|
|
29
30
|
async def query_stream(
|
|
30
|
-
self,
|
|
31
|
+
self,
|
|
32
|
+
user_query: str,
|
|
33
|
+
use_history: bool = True,
|
|
34
|
+
cancellation_token: asyncio.Event | None = None,
|
|
31
35
|
) -> AsyncIterator[StreamEvent]:
|
|
32
|
-
"""Process a user query and stream responses.
|
|
36
|
+
"""Process a user query and stream responses.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
user_query: The user's query to process
|
|
40
|
+
use_history: Whether to include conversation history
|
|
41
|
+
cancellation_token: Optional event to signal cancellation
|
|
42
|
+
"""
|
|
33
43
|
pass
|
|
34
44
|
|
|
35
45
|
def clear_history(self):
|
|
@@ -86,7 +96,7 @@ class BaseSQLAgent(ABC):
|
|
|
86
96
|
except Exception as e:
|
|
87
97
|
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
88
98
|
|
|
89
|
-
async def execute_sql(self, query: str, limit: Optional[int] =
|
|
99
|
+
async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
|
|
90
100
|
"""Execute a SQL query against the database."""
|
|
91
101
|
try:
|
|
92
102
|
# Security check - only allow SELECT queries unless write is enabled
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Command line completers for the CLI interface."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
from prompt_toolkit.completion import Completer, Completion
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SlashCommandCompleter(Completer):
    """Completer that offers the interactive session's slash commands."""

    def get_completions(self, document, complete_event):
        """Yield one completion per slash command matching the typed prefix.

        Nothing is yielded unless the buffer text begins with "/". The text
        after the slash is compared, as a prefix, against each known command.

        Args:
            document: prompt_toolkit document; only ``document.text`` is read.
            complete_event: Unused; part of the ``Completer`` interface.
        """
        buffer_text = document.text
        if not buffer_text.startswith("/"):
            return

        # Everything typed after the leading slash
        typed = buffer_text[1:]

        # (command, description shown next to the completion)
        described_commands = (
            ("clear", "Clear conversation history"),
            ("exit", "Exit the interactive session"),
            ("quit", "Exit the interactive session"),
        )

        # Replace exactly the characters typed after the slash
        replace_from = -len(typed)

        for command, meta in described_commands:
            if not command.startswith(typed):
                continue
            yield Completion(
                command,
                start_position=replace_from,
                display_meta=meta,
            )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TableNameCompleter(Completer):
    """Completer that suggests table names after an "@" marker.

    Table names are served from an in-memory cache populated through
    ``update_cache``; this class does not query anything itself.
    """

    def __init__(self):
        # (table_name, description) pairs offered as completions
        self._table_cache: List[Tuple[str, str]] = []

    def update_cache(self, tables_data: List[Tuple[str, str]]):
        """Replace the cached table list with ``tables_data``.

        The list is stored by reference, not copied.
        """
        self._table_cache = tables_data

    def _get_table_names(self) -> List[Tuple[str, str]]:
        """Return the cached (table_name, description) pairs."""
        return self._table_cache

    def get_completions(self, document, complete_event):
        """Yield table-name completions for the "@" reference at the cursor.

        The text between the last "@" before the cursor and the cursor is
        lower-cased and scored against every cached table name. Tables with
        a positive score are yielded best-first; ties keep cache order
        because the sort is stable.

        Args:
            document: prompt_toolkit document; ``text`` and
                ``cursor_position`` are read.
            complete_event: Unused; part of the ``Completer`` interface.
        """
        text = document.text
        cursor_position = document.cursor_position

        # Find the last "@" before the cursor position
        at_pos = text.rfind("@", 0, cursor_position)
        if at_pos < 0:
            return

        # Skip "@" inside quotes or glued to an unrelated word
        if not self._is_valid_table_context(text, at_pos, cursor_position):
            return

        # Text typed after the "@" up to the cursor
        partial_table = text[at_pos + 1 : cursor_position].lower()

        matches = []
        for table_name, description in self._get_table_names():
            score = self._calculate_match_score(
                partial_table, table_name, table_name.lower()
            )
            if score > 0:
                matches.append((score, table_name, description))

        # Higher score first
        matches.sort(key=lambda match: match[0], reverse=True)

        # Negative offset from the cursor back to just after the "@"
        start_position = at_pos + 1 - cursor_position

        for _score, table_name, description in matches:
            yield Completion(
                table_name,
                start_position=start_position,
                display_meta=description if description else None,
            )

    def _is_valid_table_context(self, text: str, at_pos: int, cursor_pos: int) -> bool:
        """Check if the "@" at ``at_pos`` is in a valid completion context.

        Heuristics only:

        * An odd count of unescaped single or double quotes before the "@"
          is taken to mean it sits inside a quoted string, so it is rejected.
        * If the character right after the cursor is alphanumeric or "_",
          the cursor is inside a word. The reference is then accepted only
          when the first whitespace-delimited token after the "@" contains
          "." or "_".

        Returns:
            True when completions should be offered, False otherwise.
        """
        before_at = text[:at_pos]

        # Quote count minus backslash-escaped quotes
        single_quotes = before_at.count("'") - before_at.count("\\'")
        double_quotes = before_at.count('"') - before_at.count('\\"')

        # If we're inside quotes, don't complete
        if single_quotes % 2 == 1 or double_quotes % 2 == 1:
            return False

        # Avoid breaking an existing word that follows the cursor
        if cursor_pos < len(text):
            next_char = text[cursor_pos]
            if next_char.isalnum() or next_char == "_":
                tokens = text[at_pos + 1 :].split()
                partial = tokens[0] if tokens else ""
                if "." not in partial and "_" not in partial:
                    return False

        return True

    def _calculate_match_score(
        self, partial: str, table_name: str, table_lower: str
    ) -> int:
        """Calculate the fuzzy-match score for a table (higher is better).

        Tiers, checked in order:

        * 1   - ``partial`` is empty (everything matches, lowest rank)
        * 100 - the full lower-cased name starts with ``partial``
        * 90  - the part after the last "." starts with ``partial``
        * 50  - ``partial`` occurs anywhere in the part after the last "."
        * 30  - ``partial`` occurs anywhere in the full lower-cased name
        * 0   - no match

        An exact match and a word-boundary match ("user" against
        "user_accounts") are both special cases of a prefix match, so they
        score 100 or 90. Separate lower tiers for them could never be
        reached and are intentionally not present.

        Args:
            partial: Lower-cased text typed after the "@".
            table_name: Table name as cached, possibly schema-qualified.
            table_lower: Lower-cased form of the table name, supplied by
                the caller.
        """
        if not partial:
            return 1  # Empty search matches everything with low score

        # Score 100: full-name prefix match
        if table_lower.startswith(partial):
            return 100

        if "." in table_name:
            # Unqualified table name (text after the last "."), computed once
            table_part = table_name.split(".")[-1].lower()

            # Score 90: table name (after schema) prefix match
            if table_part.startswith(partial):
                return 90

            # Score 50: substring match in table name part
            if partial in table_part:
                return 50

        # Score 30: substring match in full name
        if partial in table_lower:
            return 30

        # Score 0: no match
        return 0
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class CompositeCompleter(Completer):
    """Completer that chains the results of several other completers."""

    def __init__(self, *completers: Completer):
        # Consulted in the order they were passed
        self.completers = completers

    def get_completions(self, document, complete_event):
        """Yield every completion from each wrapped completer, in order.

        Args:
            document: Passed through unchanged to each completer.
            complete_event: Passed through unchanged to each completer.
        """
        for source in self.completers:
            for suggestion in source.get_completions(document, complete_event):
                yield suggestion
|
sqlsaber/cli/display.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
|
4
4
|
from typing import Optional
|
|
5
5
|
|
|
6
6
|
from rich.console import Console
|
|
7
|
+
from rich.markdown import Markdown
|
|
7
8
|
from rich.syntax import Syntax
|
|
8
9
|
from rich.table import Table
|
|
9
10
|
|
|
@@ -62,12 +63,20 @@ class DisplayManager:
|
|
|
62
63
|
)
|
|
63
64
|
|
|
64
65
|
# Create table with columns from first result
|
|
65
|
-
|
|
66
|
-
|
|
66
|
+
all_columns = list(results[0].keys())
|
|
67
|
+
display_columns = all_columns[:15] # Limit to first 15 columns
|
|
68
|
+
|
|
69
|
+
# Show warning if columns were truncated
|
|
70
|
+
if len(all_columns) > 15:
|
|
71
|
+
self.console.print(
|
|
72
|
+
f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
table = self._create_table(display_columns)
|
|
67
76
|
|
|
68
77
|
# Add rows (show first 20 rows)
|
|
69
78
|
for row in results[:20]:
|
|
70
|
-
table.add_row(*[str(row[key]) for key in
|
|
79
|
+
table.add_row(*[str(row[key]) for key in display_columns])
|
|
71
80
|
|
|
72
81
|
self.console.print(table)
|
|
73
82
|
|
|
@@ -235,3 +244,24 @@ class DisplayManager:
|
|
|
235
244
|
self.show_error("Failed to parse plot result")
|
|
236
245
|
except Exception as e:
|
|
237
246
|
self.show_error(f"Error displaying plot: {str(e)}")
|
|
247
|
+
|
|
248
|
+
def show_markdown_response(self, content: list):
    """Render the text blocks of an assistant response as rich markdown.

    Only dict blocks whose "type" is "text" and whose "text" is truthy
    contribute. Their texts are concatenated without a separator and
    stripped. If nothing remains, nothing is printed.

    Args:
        content: List of content blocks; non-dict entries are ignored.
    """
    if not content:
        return

    # Keep only dict blocks that are tagged as text
    text_blocks = (
        block
        for block in content
        if isinstance(block, dict) and block.get("type") == "text"
    )
    # Drop empty/missing text values
    fragments = [
        fragment
        for fragment in (block.get("text", "") for block in text_blocks)
        if fragment
    ]

    rendered = "".join(fragments).strip()
    if not rendered:
        return

    self.console.print()  # Add spacing before markdown
    self.console.print(Markdown(rendered))
    self.console.print()  # Add spacing after markdown
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -1,10 +1,18 @@
|
|
|
1
1
|
"""Interactive mode handling for the CLI."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
3
6
|
import questionary
|
|
4
7
|
from rich.console import Console
|
|
5
8
|
from rich.panel import Panel
|
|
6
9
|
|
|
7
10
|
from sqlsaber.agents.base import BaseSQLAgent
|
|
11
|
+
from sqlsaber.cli.completers import (
|
|
12
|
+
CompositeCompleter,
|
|
13
|
+
SlashCommandCompleter,
|
|
14
|
+
TableNameCompleter,
|
|
15
|
+
)
|
|
8
16
|
from sqlsaber.cli.display import DisplayManager
|
|
9
17
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
10
18
|
|
|
@@ -17,6 +25,9 @@ class InteractiveSession:
|
|
|
17
25
|
self.agent = agent
|
|
18
26
|
self.display = DisplayManager(console)
|
|
19
27
|
self.streaming_handler = StreamingQueryHandler(console)
|
|
28
|
+
self.current_task: Optional[asyncio.Task] = None
|
|
29
|
+
self.cancellation_token: Optional[asyncio.Event] = None
|
|
30
|
+
self.table_completer = TableNameCompleter()
|
|
20
31
|
|
|
21
32
|
def show_welcome_message(self):
|
|
22
33
|
"""Display welcome message for interactive mode."""
|
|
@@ -28,8 +39,9 @@ class InteractiveSession:
|
|
|
28
39
|
Panel.fit(
|
|
29
40
|
"[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
|
|
30
41
|
"[bold]Your agentic SQL assistant.[/bold]\n\n\n"
|
|
31
|
-
"[dim]Use 'clear' to reset conversation, 'exit' or 'quit' to leave.[/dim]\n\n"
|
|
32
|
-
"[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]"
|
|
42
|
+
"[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
|
|
43
|
+
"[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
|
|
44
|
+
"[dim]Type '@' to get table name completions.[/dim]",
|
|
33
45
|
border_style="green",
|
|
34
46
|
)
|
|
35
47
|
)
|
|
@@ -38,12 +50,71 @@ class InteractiveSession:
|
|
|
38
50
|
)
|
|
39
51
|
self.console.print(
|
|
40
52
|
"[dim]Press Esc-Enter or Meta-Enter to submit your query.[/dim]\n"
|
|
53
|
+
"[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
|
|
41
54
|
)
|
|
42
55
|
|
|
56
|
+
async def _update_table_cache(self):
    """Refresh the table completer's cache from the agent's schema manager.

    Awaits ``self.agent.schema_manager.list_tables()`` and, when the result
    is a dict with a "tables" key, turns each dict entry into a
    ``(table_name, "")`` pair. Any exception along the way results in an
    empty cache instead of propagating.
    """

    def display_name(entry):
        # Prefer the ready-made qualified name; otherwise build one,
        # leaving out an empty schema or the "main" schema.
        full_name = entry.get("full_name", "")
        if full_name:
            return full_name
        schema = entry.get("schema", "")
        name = entry.get("name", "")
        return f"{schema}.{name}" if schema and schema != "main" else name

    try:
        # list_tables() is expected to do its own caching
        listing = await self.agent.schema_manager.list_tables()

        has_tables = isinstance(listing, dict) and "tables" in listing
        entries = listing["tables"] if has_tables else ()

        # Descriptions are left empty for cleaner completions
        pairs = [
            (display_name(entry), "")
            for entry in entries
            if isinstance(entry, dict)
        ]

        self.table_completer.update_cache(pairs)

    except Exception:
        # Best effort: on any failure fall back to an empty cache
        self.table_completer.update_cache([])
|
|
88
|
+
|
|
89
|
+
async def _execute_query_with_cancellation(self, user_query: str):
    """Run a streaming query as a task that can be cancelled from outside.

    A fresh ``asyncio.Event`` is stored in ``self.cancellation_token`` and
    passed to the streaming handler. The task is stored in
    ``self.current_task`` while it runs. Both attributes are reset to
    ``None`` when the task finishes or an exception propagates.

    Args:
        user_query: The query text forwarded to the streaming handler.
    """
    token = asyncio.Event()
    self.cancellation_token = token

    pending = asyncio.create_task(
        self.streaming_handler.execute_streaming_query(
            user_query, self.agent, token
        )
    )
    self.current_task = pending

    try:
        # Interruption (Ctrl+C) is left to the caller to handle
        await pending
    finally:
        # NOTE(review): this cleanup runs while an exception unwinds, i.e.
        # before a caller's handler can inspect current_task or
        # cancellation_token - confirm the caller does not rely on them.
        self.current_task = self.cancellation_token = None
|
|
110
|
+
|
|
43
111
|
async def run(self):
|
|
44
112
|
"""Run the interactive session loop."""
|
|
45
113
|
self.show_welcome_message()
|
|
46
114
|
|
|
115
|
+
# Initialize table cache
|
|
116
|
+
await self._update_table_cache()
|
|
117
|
+
|
|
47
118
|
while True:
|
|
48
119
|
try:
|
|
49
120
|
user_query = await questionary.text(
|
|
@@ -51,12 +122,18 @@ class InteractiveSession:
|
|
|
51
122
|
qmark="",
|
|
52
123
|
multiline=True,
|
|
53
124
|
instruction="",
|
|
125
|
+
completer=CompositeCompleter(
|
|
126
|
+
SlashCommandCompleter(), self.table_completer
|
|
127
|
+
),
|
|
54
128
|
).ask_async()
|
|
55
129
|
|
|
56
|
-
if user_query
|
|
130
|
+
if not user_query:
|
|
131
|
+
continue
|
|
132
|
+
|
|
133
|
+
if user_query in ["/exit", "/quit"]:
|
|
57
134
|
break
|
|
58
135
|
|
|
59
|
-
if user_query
|
|
136
|
+
if user_query == "/clear":
|
|
60
137
|
self.agent.clear_history()
|
|
61
138
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
62
139
|
continue
|
|
@@ -85,12 +162,24 @@ class InteractiveSession:
|
|
|
85
162
|
)
|
|
86
163
|
continue
|
|
87
164
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
)
|
|
165
|
+
# Execute query with cancellation support
|
|
166
|
+
await self._execute_query_with_cancellation(user_query)
|
|
91
167
|
self.display.show_newline() # Empty line for readability
|
|
92
168
|
|
|
93
169
|
except KeyboardInterrupt:
|
|
94
|
-
|
|
170
|
+
# Handle Ctrl+C - cancel current task if running
|
|
171
|
+
if self.current_task and not self.current_task.done():
|
|
172
|
+
if self.cancellation_token is not None:
|
|
173
|
+
self.cancellation_token.set()
|
|
174
|
+
self.current_task.cancel()
|
|
175
|
+
try:
|
|
176
|
+
await self.current_task
|
|
177
|
+
except asyncio.CancelledError:
|
|
178
|
+
pass
|
|
179
|
+
self.console.print("\n[yellow]Query interrupted[/yellow]")
|
|
180
|
+
else:
|
|
181
|
+
self.console.print(
|
|
182
|
+
"\n[yellow]Use '/exit' or '/quit' to leave.[/yellow]"
|
|
183
|
+
)
|
|
95
184
|
except Exception as e:
|
|
96
185
|
self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Streaming query handling for the CLI."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
3
5
|
from rich.console import Console
|
|
4
6
|
|
|
5
7
|
from sqlsaber.agents.base import BaseSQLAgent
|
|
@@ -13,7 +15,12 @@ class StreamingQueryHandler:
|
|
|
13
15
|
self.console = console
|
|
14
16
|
self.display = DisplayManager(console)
|
|
15
17
|
|
|
16
|
-
async def execute_streaming_query(
|
|
18
|
+
async def execute_streaming_query(
|
|
19
|
+
self,
|
|
20
|
+
user_query: str,
|
|
21
|
+
agent: BaseSQLAgent,
|
|
22
|
+
cancellation_token: asyncio.Event | None = None,
|
|
23
|
+
):
|
|
17
24
|
"""Execute a query with streaming display."""
|
|
18
25
|
|
|
19
26
|
has_content = False
|
|
@@ -24,7 +31,12 @@ class StreamingQueryHandler:
|
|
|
24
31
|
status.start()
|
|
25
32
|
|
|
26
33
|
try:
|
|
27
|
-
async for event in agent.query_stream(
|
|
34
|
+
async for event in agent.query_stream(
|
|
35
|
+
user_query, cancellation_token=cancellation_token
|
|
36
|
+
):
|
|
37
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
38
|
+
break
|
|
39
|
+
|
|
28
40
|
if event.type == "tool_use":
|
|
29
41
|
# Stop any ongoing status, but don't mark has_content yet
|
|
30
42
|
self._stop_status(status)
|
|
@@ -83,6 +95,13 @@ class StreamingQueryHandler:
|
|
|
83
95
|
has_content = True
|
|
84
96
|
self.display.show_error(event.data)
|
|
85
97
|
|
|
98
|
+
except asyncio.CancelledError:
|
|
99
|
+
# Handle cancellation gracefully
|
|
100
|
+
self._stop_status(status)
|
|
101
|
+
if explanation_started:
|
|
102
|
+
self.display.show_newline()
|
|
103
|
+
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
104
|
+
return
|
|
86
105
|
finally:
|
|
87
106
|
# Make sure status is stopped
|
|
88
107
|
self._stop_status(status)
|
|
@@ -91,6 +110,14 @@ class StreamingQueryHandler:
|
|
|
91
110
|
if explanation_started:
|
|
92
111
|
self.display.show_newline() # Empty line for better readability
|
|
93
112
|
|
|
113
|
+
# Display the last assistant response as markdown
|
|
114
|
+
if hasattr(agent, "conversation_history") and agent.conversation_history:
|
|
115
|
+
last_message = agent.conversation_history[-1]
|
|
116
|
+
if last_message.get("role") == "assistant" and last_message.get(
|
|
117
|
+
"content"
|
|
118
|
+
):
|
|
119
|
+
self.display.show_markdown_response(last_message["content"])
|
|
120
|
+
|
|
94
121
|
def _stop_status(self, status):
|
|
95
122
|
"""Safely stop a status spinner."""
|
|
96
123
|
try:
|
sqlsaber/database/schema.py
CHANGED
|
@@ -683,6 +683,13 @@ class SchemaManager:
|
|
|
683
683
|
|
|
684
684
|
async def list_tables(self) -> Dict[str, Any]:
|
|
685
685
|
"""Get a list of all tables with basic information like row counts."""
|
|
686
|
+
# Check cache first
|
|
687
|
+
cache_key = "list_tables"
|
|
688
|
+
cached_data = self._get_cached_tables(cache_key)
|
|
689
|
+
if cached_data is not None:
|
|
690
|
+
return cached_data
|
|
691
|
+
|
|
692
|
+
# Fetch from database if not cached
|
|
686
693
|
tables = await self.introspector.list_tables_info(self.db)
|
|
687
694
|
|
|
688
695
|
# Format the result
|
|
@@ -699,4 +706,14 @@ class SchemaManager:
|
|
|
699
706
|
}
|
|
700
707
|
)
|
|
701
708
|
|
|
709
|
+
# Cache the result
|
|
710
|
+
self._schema_cache[cache_key] = (time.time(), result)
|
|
702
711
|
return result
|
|
712
|
+
|
|
713
|
+
def _get_cached_tables(self, cache_key: str) -> Optional[Dict[str, Any]]:
    """Return the cached value for ``cache_key`` if present and still fresh.

    Cache entries are ``(timestamp, data)`` tuples. An entry is fresh while
    its age, measured with ``time.time()``, is below ``self.cache_ttl``.

    Returns:
        The cached data, or None when the key is absent or has expired.
    """
    if cache_key not in self._schema_cache:
        return None

    stored_at, payload = self._schema_cache[cache_key]
    age = time.time() - stored_at
    return payload if age < self.cache_ttl else None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sqlsaber
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.7.0
|
|
4
4
|
Summary: SQLSaber - Agentic SQL assistant like Claude Code
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -212,23 +212,24 @@ The MCP server uses your existing SQLSaber database configurations, so make sure
|
|
|
212
212
|
|
|
213
213
|
## How It Works
|
|
214
214
|
|
|
215
|
-
SQLSaber uses
|
|
215
|
+
SQLSaber uses a multi-step process to gather the right context, provide it to the model, and execute SQL queries to get the right answers:
|
|
216
|
+
|
|
217
|
+

|
|
216
218
|
|
|
217
219
|
### 🔍 Discovery Phase
|
|
218
220
|
|
|
219
221
|
1. **List Tables Tool**: Quickly discovers available tables with row counts
|
|
220
|
-
2. **Pattern Matching**: Identifies relevant tables based on your query
|
|
222
|
+
2. **Pattern Matching**: Identifies relevant tables based on your query
|
|
221
223
|
|
|
222
224
|
### 📋 Schema Analysis
|
|
223
225
|
|
|
224
|
-
3. **Smart Introspection**: Analyzes only the specific table structures needed for your query
|
|
225
|
-
4. **Selective Loading**: Fetches schema information only for relevant tables
|
|
226
|
+
3. **Smart Schema Introspection**: Analyzes only the specific table structures needed for your query
|
|
226
227
|
|
|
227
228
|
### ⚡ Execution Phase
|
|
228
229
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
230
|
+
4. **SQL Generation**: Creates optimized SQL queries based on natural language input
|
|
231
|
+
5. **Safe Execution**: Runs read-only queries with built-in protections against destructive operations
|
|
232
|
+
6. **Result Formatting**: Presents results with explanations in tables and, optionally, visualizes them using plots
|
|
232
233
|
|
|
233
234
|
## Contributing
|
|
234
235
|
|
|
@@ -1,25 +1,26 @@
|
|
|
1
1
|
sqlsaber/__init__.py,sha256=QCFi8xTVMohelfi7zOV1-6oLCcGoiXoOcKQY-HNBCk8,66
|
|
2
2
|
sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
|
|
3
3
|
sqlsaber/agents/__init__.py,sha256=LWeSeEUE4BhkyAYFF3TE-fx8TtLud3oyEtyB8ojFJgo,167
|
|
4
|
-
sqlsaber/agents/anthropic.py,sha256=
|
|
5
|
-
sqlsaber/agents/base.py,sha256=
|
|
4
|
+
sqlsaber/agents/anthropic.py,sha256=FLVET2HvFmsEuFln9Hu4SaBs-Tnk-GestOgnDnUp3ps,17885
|
|
5
|
+
sqlsaber/agents/base.py,sha256=DAnezHl5RLYoef8XQ-n3KA9PowdrMbQrkjdGKPPnFsI,10570
|
|
6
6
|
sqlsaber/agents/mcp.py,sha256=FKtXgDrPZ2-xqUYCw2baI5JzrWekXaC5fjkYW1_Mg50,827
|
|
7
7
|
sqlsaber/agents/streaming.py,sha256=_EO390-FHUrL1fRCNfibtE9QuJz3LGQygbwG3CB2ViY,533
|
|
8
8
|
sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
|
|
9
9
|
sqlsaber/cli/commands.py,sha256=Dw24W0jij-8t1lpk99C4PBTgzFSag6vU-FZcjAYGG54,5074
|
|
10
|
+
sqlsaber/cli/completers.py,sha256=JWOCKAm0Prpy_O2QJsf_VbPWfy2lQQh6KutyG8FU4us,6462
|
|
10
11
|
sqlsaber/cli/database.py,sha256=DUfyvNBDp47oFM_VAC_hXHQy_qyE7JbXtowflJpwwH8,12643
|
|
11
|
-
sqlsaber/cli/display.py,sha256=
|
|
12
|
-
sqlsaber/cli/interactive.py,sha256=
|
|
12
|
+
sqlsaber/cli/display.py,sha256=NIBWHUrX_8ZhDu6iW9v4fzx0zncnXa5WdQ9wfTrjKIM,10017
|
|
13
|
+
sqlsaber/cli/interactive.py,sha256=FvgtT45U-yblhbRImKqJ4jgBRNs0u7NhE2PcgoVUaVA,7429
|
|
13
14
|
sqlsaber/cli/memory.py,sha256=LW4ZF2V6Gw6hviUFGZ4ym9ostFCwucgBTIMZ3EANO-I,7671
|
|
14
15
|
sqlsaber/cli/models.py,sha256=3IcXeeU15IQvemSv-V-RQzVytJ3wuQ4YmWk89nTDcSE,7813
|
|
15
|
-
sqlsaber/cli/streaming.py,sha256=
|
|
16
|
+
sqlsaber/cli/streaming.py,sha256=DfwygmjEzAh9hZGKjrW9kS1A7MG5W9Ky_kCTzxziODQ,4970
|
|
16
17
|
sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
|
|
17
18
|
sqlsaber/config/api_keys.py,sha256=kLdoExF_My9ojmdhO5Ca7-ZeowsO0v1GVa_QT5jjUPo,3658
|
|
18
19
|
sqlsaber/config/database.py,sha256=vKFOxPjVakjQhj1uoLcfzhS9ZFr6Z2F5b4MmYALQZoA,11421
|
|
19
20
|
sqlsaber/config/settings.py,sha256=zjQ7nS3ybcCb88Ea0tmwJox5-q0ettChZw89ZqRVpX8,3975
|
|
20
21
|
sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
|
|
21
22
|
sqlsaber/database/connection.py,sha256=s8GSFZebB8be8sVUr-N0x88-20YfkfljJFRyfoB1gH0,15154
|
|
22
|
-
sqlsaber/database/schema.py,sha256=
|
|
23
|
+
sqlsaber/database/schema.py,sha256=3CfkyhxgD6SmiUoz7MQPlQLrrA007HOQLnGCvvsdJx0,28647
|
|
23
24
|
sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
|
|
24
25
|
sqlsaber/mcp/mcp.py,sha256=ACm1P1TnicjOptQgeLNhXg5xgZf4MYq2kqdfVdj6wh0,4477
|
|
25
26
|
sqlsaber/memory/__init__.py,sha256=GiWkU6f6YYVV0EvvXDmFWe_CxarmDCql05t70MkTEWs,63
|
|
@@ -28,8 +29,8 @@ sqlsaber/memory/storage.py,sha256=DvZBsSPaAfk_DqrNEn86uMD-TQsWUI6rQLfNw6PSCB8,57
|
|
|
28
29
|
sqlsaber/models/__init__.py,sha256=RJ7p3WtuSwwpFQ1Iw4_DHV2zzCtHqIzsjJzxv8kUjUE,287
|
|
29
30
|
sqlsaber/models/events.py,sha256=q2FackB60J9-7vegYIjzElLwKebIh7nxnV5AFoZc67c,752
|
|
30
31
|
sqlsaber/models/types.py,sha256=3U_30n91EB3IglBTHipwiW4MqmmaA2qfshfraMZyPps,896
|
|
31
|
-
sqlsaber-0.
|
|
32
|
-
sqlsaber-0.
|
|
33
|
-
sqlsaber-0.
|
|
34
|
-
sqlsaber-0.
|
|
35
|
-
sqlsaber-0.
|
|
32
|
+
sqlsaber-0.7.0.dist-info/METADATA,sha256=tUV3WHkVZEXissVrKAaOooaZyn7e_NmMV_e-nNaoLVE,5986
|
|
33
|
+
sqlsaber-0.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
34
|
+
sqlsaber-0.7.0.dist-info/entry_points.txt,sha256=jmFo96Ylm0zIKXJBwhv_P5wQ7SXP9qdaBcnTp8iCEe8,195
|
|
35
|
+
sqlsaber-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
36
|
+
sqlsaber-0.7.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|