sqlsaber 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +105 -18
- sqlsaber/agents/base.py +105 -3
- sqlsaber/cli/display.py +41 -3
- sqlsaber/cli/interactive.py +80 -7
- sqlsaber/cli/streaming.py +26 -2
- sqlsaber/models/events.py +1 -1
- {sqlsaber-0.4.1.dist-info → sqlsaber-0.6.0.dist-info}/METADATA +2 -1
- {sqlsaber-0.4.1.dist-info → sqlsaber-0.6.0.dist-info}/RECORD +11 -11
- {sqlsaber-0.4.1.dist-info → sqlsaber-0.6.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.4.1.dist-info → sqlsaber-0.6.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.4.1.dist-info → sqlsaber-0.6.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Anthropic-specific SQL agent implementation."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
|
-
from typing import Any, AsyncIterator, Dict, List
|
|
5
|
+
from typing import Any, AsyncIterator, Dict, List
|
|
5
6
|
|
|
6
7
|
from anthropic import AsyncAnthropic
|
|
7
8
|
|
|
@@ -21,7 +22,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
21
22
|
"""SQL Agent using Anthropic SDK directly."""
|
|
22
23
|
|
|
23
24
|
def __init__(
|
|
24
|
-
self, db_connection: BaseDatabaseConnection, database_name:
|
|
25
|
+
self, db_connection: BaseDatabaseConnection, database_name: str | None = None
|
|
25
26
|
):
|
|
26
27
|
super().__init__(db_connection)
|
|
27
28
|
|
|
@@ -82,6 +83,44 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
82
83
|
"required": ["query"],
|
|
83
84
|
},
|
|
84
85
|
},
|
|
86
|
+
{
|
|
87
|
+
"name": "plot_data",
|
|
88
|
+
"description": "Create a plot of query results.",
|
|
89
|
+
"input_schema": {
|
|
90
|
+
"type": "object",
|
|
91
|
+
"properties": {
|
|
92
|
+
"y_values": {
|
|
93
|
+
"type": "array",
|
|
94
|
+
"items": {"type": ["number", "null"]},
|
|
95
|
+
"description": "Y-axis data points (required)",
|
|
96
|
+
},
|
|
97
|
+
"x_values": {
|
|
98
|
+
"type": "array",
|
|
99
|
+
"items": {"type": ["number", "null"]},
|
|
100
|
+
"description": "X-axis data points (optional, will use indices if not provided)",
|
|
101
|
+
},
|
|
102
|
+
"plot_type": {
|
|
103
|
+
"type": "string",
|
|
104
|
+
"enum": ["line", "scatter", "histogram"],
|
|
105
|
+
"description": "Type of plot to create (default: line)",
|
|
106
|
+
"default": "line",
|
|
107
|
+
},
|
|
108
|
+
"title": {
|
|
109
|
+
"type": "string",
|
|
110
|
+
"description": "Title for the plot",
|
|
111
|
+
},
|
|
112
|
+
"x_label": {
|
|
113
|
+
"type": "string",
|
|
114
|
+
"description": "Label for X-axis",
|
|
115
|
+
},
|
|
116
|
+
"y_label": {
|
|
117
|
+
"type": "string",
|
|
118
|
+
"description": "Label for Y-axis",
|
|
119
|
+
},
|
|
120
|
+
},
|
|
121
|
+
"required": ["y_values"],
|
|
122
|
+
},
|
|
123
|
+
},
|
|
85
124
|
]
|
|
86
125
|
|
|
87
126
|
# Build system prompt with memories if available
|
|
@@ -96,13 +135,15 @@ Your responsibilities:
|
|
|
96
135
|
1. Understand user's natural language requests, think and convert them to SQL
|
|
97
136
|
2. Use the provided tools efficiently to explore database schema
|
|
98
137
|
3. Generate appropriate SQL queries
|
|
99
|
-
4. Execute queries safely
|
|
138
|
+
4. Execute queries safely - queries that modify the database are not allowed
|
|
100
139
|
5. Format and explain results clearly
|
|
140
|
+
6. Create visualizations when requested or when they would be helpful
|
|
101
141
|
|
|
102
142
|
IMPORTANT - Schema Discovery Strategy:
|
|
103
143
|
1. ALWAYS start with 'list_tables' to see available tables and row counts
|
|
104
144
|
2. Based on the user's query, identify which specific tables are relevant
|
|
105
145
|
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
|
|
146
|
+
4. Timestamp columns must be converted to text when you write queries
|
|
106
147
|
|
|
107
148
|
Guidelines:
|
|
108
149
|
- Use list_tables first, then introspect_schema for specific tables only
|
|
@@ -124,7 +165,7 @@ Guidelines:
|
|
|
124
165
|
|
|
125
166
|
return base_prompt
|
|
126
167
|
|
|
127
|
-
def add_memory(self, content: str) ->
|
|
168
|
+
def add_memory(self, content: str) -> str | None:
|
|
128
169
|
"""Add a memory for the current database."""
|
|
129
170
|
if not self.database_name:
|
|
130
171
|
return None
|
|
@@ -134,7 +175,7 @@ Guidelines:
|
|
|
134
175
|
self.system_prompt = self._build_system_prompt()
|
|
135
176
|
return memory.id
|
|
136
177
|
|
|
137
|
-
async def execute_sql(self, query: str, limit:
|
|
178
|
+
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
138
179
|
"""Execute a SQL query against the database with streaming support."""
|
|
139
180
|
# Call parent implementation for core functionality
|
|
140
181
|
result = await super().execute_sql(query, limit)
|
|
@@ -163,10 +204,18 @@ Guidelines:
|
|
|
163
204
|
return await super().process_tool_call(tool_name, tool_input)
|
|
164
205
|
|
|
165
206
|
async def _process_stream_events(
|
|
166
|
-
self,
|
|
207
|
+
self,
|
|
208
|
+
stream,
|
|
209
|
+
content_blocks: List[Dict],
|
|
210
|
+
tool_use_blocks: List[Dict],
|
|
211
|
+
cancellation_token: asyncio.Event | None = None,
|
|
167
212
|
) -> AsyncIterator[StreamEvent]:
|
|
168
213
|
"""Process stream events and yield appropriate StreamEvents."""
|
|
169
214
|
async for event in stream:
|
|
215
|
+
# Only check cancellation if token is provided
|
|
216
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
217
|
+
return
|
|
218
|
+
|
|
170
219
|
if event.type == "content_block_start":
|
|
171
220
|
if hasattr(event.content_block, "type"):
|
|
172
221
|
if event.content_block.type == "tool_use":
|
|
@@ -213,11 +262,17 @@ Guidelines:
|
|
|
213
262
|
return "stop"
|
|
214
263
|
|
|
215
264
|
async def _process_tool_results(
|
|
216
|
-
self,
|
|
265
|
+
self,
|
|
266
|
+
response: StreamingResponse,
|
|
267
|
+
cancellation_token: asyncio.Event | None = None,
|
|
217
268
|
) -> AsyncIterator[StreamEvent]:
|
|
218
269
|
"""Process tool results and yield appropriate events."""
|
|
219
270
|
tool_results = []
|
|
220
271
|
for block in response.content:
|
|
272
|
+
# Only check cancellation if token is provided
|
|
273
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
274
|
+
return
|
|
275
|
+
|
|
221
276
|
if block.get("type") == "tool_use":
|
|
222
277
|
yield StreamEvent(
|
|
223
278
|
"tool_use",
|
|
@@ -249,13 +304,25 @@ Guidelines:
|
|
|
249
304
|
"result": tool_result,
|
|
250
305
|
},
|
|
251
306
|
)
|
|
307
|
+
elif block["name"] == "plot_data":
|
|
308
|
+
yield StreamEvent(
|
|
309
|
+
"plot_result",
|
|
310
|
+
{
|
|
311
|
+
"tool_name": block["name"],
|
|
312
|
+
"input": block["input"],
|
|
313
|
+
"result": tool_result,
|
|
314
|
+
},
|
|
315
|
+
)
|
|
252
316
|
|
|
253
317
|
tool_results.append(build_tool_result_block(block["id"], tool_result))
|
|
254
318
|
|
|
255
319
|
yield StreamEvent("tool_result_data", tool_results)
|
|
256
320
|
|
|
257
321
|
async def query_stream(
|
|
258
|
-
self,
|
|
322
|
+
self,
|
|
323
|
+
user_query: str,
|
|
324
|
+
use_history: bool = True,
|
|
325
|
+
cancellation_token: asyncio.Event | None = None,
|
|
259
326
|
) -> AsyncIterator[StreamEvent]:
|
|
260
327
|
"""Process a user query and stream responses."""
|
|
261
328
|
# Initialize for tracking state
|
|
@@ -273,7 +340,11 @@ Guidelines:
|
|
|
273
340
|
try:
|
|
274
341
|
# Create initial stream and get response
|
|
275
342
|
response = None
|
|
276
|
-
async for event in self._create_and_process_stream(
|
|
343
|
+
async for event in self._create_and_process_stream(
|
|
344
|
+
messages, cancellation_token
|
|
345
|
+
):
|
|
346
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
347
|
+
return
|
|
277
348
|
if event.type == "response_ready":
|
|
278
349
|
response = event.data
|
|
279
350
|
else:
|
|
@@ -283,14 +354,21 @@ Guidelines:
|
|
|
283
354
|
|
|
284
355
|
# Process tool calls if needed
|
|
285
356
|
while response is not None and response.stop_reason == "tool_use":
|
|
357
|
+
# Check for cancellation at the start of tool cycle
|
|
358
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
359
|
+
return
|
|
360
|
+
|
|
286
361
|
# Add assistant's response to conversation
|
|
287
362
|
collected_content.append(
|
|
288
363
|
{"role": "assistant", "content": response.content}
|
|
289
364
|
)
|
|
290
365
|
|
|
291
|
-
# Process tool results
|
|
366
|
+
# Process tool results - DO NOT check cancellation during tool execution
|
|
367
|
+
# as this would break the tool_use -> tool_result API contract
|
|
292
368
|
tool_results = []
|
|
293
|
-
async for event in self._process_tool_results(
|
|
369
|
+
async for event in self._process_tool_results(
|
|
370
|
+
response, None
|
|
371
|
+
): # Pass None to disable cancellation checks
|
|
294
372
|
if event.type == "tool_result_data":
|
|
295
373
|
tool_results = event.data
|
|
296
374
|
else:
|
|
@@ -298,6 +376,12 @@ Guidelines:
|
|
|
298
376
|
|
|
299
377
|
# Continue conversation with tool results
|
|
300
378
|
collected_content.append({"role": "user", "content": tool_results})
|
|
379
|
+
if use_history:
|
|
380
|
+
self.conversation_history.extend(collected_content)
|
|
381
|
+
|
|
382
|
+
# Check for cancellation AFTER tool results are complete
|
|
383
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
384
|
+
return
|
|
301
385
|
|
|
302
386
|
# Signal that we're processing the tool results
|
|
303
387
|
yield StreamEvent("processing", "Analyzing results...")
|
|
@@ -305,8 +389,10 @@ Guidelines:
|
|
|
305
389
|
# Get next response
|
|
306
390
|
response = None
|
|
307
391
|
async for event in self._create_and_process_stream(
|
|
308
|
-
messages + collected_content
|
|
392
|
+
messages + collected_content, cancellation_token
|
|
309
393
|
):
|
|
394
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
395
|
+
return
|
|
310
396
|
if event.type == "response_ready":
|
|
311
397
|
response = event.data
|
|
312
398
|
else:
|
|
@@ -314,21 +400,19 @@ Guidelines:
|
|
|
314
400
|
|
|
315
401
|
# Update conversation history if using history
|
|
316
402
|
if use_history:
|
|
317
|
-
self.conversation_history.append(
|
|
318
|
-
{"role": "user", "content": user_query}
|
|
319
|
-
)
|
|
320
|
-
self.conversation_history.extend(collected_content)
|
|
321
403
|
# Add final assistant response
|
|
322
404
|
if response is not None:
|
|
323
405
|
self.conversation_history.append(
|
|
324
406
|
{"role": "assistant", "content": response.content}
|
|
325
407
|
)
|
|
326
408
|
|
|
409
|
+
except asyncio.CancelledError:
|
|
410
|
+
return
|
|
327
411
|
except Exception as e:
|
|
328
412
|
yield StreamEvent("error", str(e))
|
|
329
413
|
|
|
330
414
|
async def _create_and_process_stream(
|
|
331
|
-
self, messages: List[Dict]
|
|
415
|
+
self, messages: List[Dict], cancellation_token: asyncio.Event | None = None
|
|
332
416
|
) -> AsyncIterator[StreamEvent]:
|
|
333
417
|
"""Create a stream and yield events while building response."""
|
|
334
418
|
stream = await self.client.messages.create(
|
|
@@ -344,8 +428,11 @@ Guidelines:
|
|
|
344
428
|
tool_use_blocks = []
|
|
345
429
|
|
|
346
430
|
async for event in self._process_stream_events(
|
|
347
|
-
stream, content_blocks, tool_use_blocks
|
|
431
|
+
stream, content_blocks, tool_use_blocks, cancellation_token
|
|
348
432
|
):
|
|
433
|
+
# Only check cancellation if token is provided
|
|
434
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
435
|
+
return
|
|
349
436
|
yield event
|
|
350
437
|
|
|
351
438
|
# Finalize tool blocks and create response
|
sqlsaber/agents/base.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
"""Abstract base class for SQL agents."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
6
7
|
|
|
8
|
+
from uniplot import histogram, plot
|
|
9
|
+
|
|
7
10
|
from sqlsaber.database.connection import (
|
|
8
11
|
BaseDatabaseConnection,
|
|
9
12
|
CSVConnection,
|
|
@@ -25,9 +28,18 @@ class BaseSQLAgent(ABC):
|
|
|
25
28
|
|
|
26
29
|
@abstractmethod
|
|
27
30
|
async def query_stream(
|
|
28
|
-
self,
|
|
31
|
+
self,
|
|
32
|
+
user_query: str,
|
|
33
|
+
use_history: bool = True,
|
|
34
|
+
cancellation_token: asyncio.Event | None = None,
|
|
29
35
|
) -> AsyncIterator[StreamEvent]:
|
|
30
|
-
"""Process a user query and stream responses.
|
|
36
|
+
"""Process a user query and stream responses.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
user_query: The user's query to process
|
|
40
|
+
use_history: Whether to include conversation history
|
|
41
|
+
cancellation_token: Optional event to signal cancellation
|
|
42
|
+
"""
|
|
31
43
|
pass
|
|
32
44
|
|
|
33
45
|
def clear_history(self):
|
|
@@ -84,7 +96,7 @@ class BaseSQLAgent(ABC):
|
|
|
84
96
|
except Exception as e:
|
|
85
97
|
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
86
98
|
|
|
87
|
-
async def execute_sql(self, query: str, limit: Optional[int] =
|
|
99
|
+
async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
|
|
88
100
|
"""Execute a SQL query against the database."""
|
|
89
101
|
try:
|
|
90
102
|
# Security check - only allow SELECT queries unless write is enabled
|
|
@@ -146,6 +158,15 @@ class BaseSQLAgent(ABC):
|
|
|
146
158
|
return await self.execute_sql(
|
|
147
159
|
tool_input["query"], tool_input.get("limit", 100)
|
|
148
160
|
)
|
|
161
|
+
elif tool_name == "plot_data":
|
|
162
|
+
return await self.plot_data(
|
|
163
|
+
y_values=tool_input["y_values"],
|
|
164
|
+
x_values=tool_input.get("x_values"),
|
|
165
|
+
plot_type=tool_input.get("plot_type", "line"),
|
|
166
|
+
title=tool_input.get("title"),
|
|
167
|
+
x_label=tool_input.get("x_label"),
|
|
168
|
+
y_label=tool_input.get("y_label"),
|
|
169
|
+
)
|
|
149
170
|
else:
|
|
150
171
|
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
151
172
|
|
|
@@ -182,3 +203,84 @@ class BaseSQLAgent(ABC):
|
|
|
182
203
|
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
183
204
|
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
184
205
|
return query
|
|
206
|
+
|
|
207
|
+
async def plot_data(
|
|
208
|
+
self,
|
|
209
|
+
y_values: List[float],
|
|
210
|
+
x_values: Optional[List[float]] = None,
|
|
211
|
+
plot_type: str = "line",
|
|
212
|
+
title: Optional[str] = None,
|
|
213
|
+
x_label: Optional[str] = None,
|
|
214
|
+
y_label: Optional[str] = None,
|
|
215
|
+
) -> str:
|
|
216
|
+
"""Create a terminal plot using uniplot.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
y_values: Y-axis data points
|
|
220
|
+
x_values: X-axis data points (optional)
|
|
221
|
+
plot_type: Type of plot - "line", "scatter", or "histogram"
|
|
222
|
+
title: Plot title
|
|
223
|
+
x_label: X-axis label
|
|
224
|
+
y_label: Y-axis label
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
JSON string with success status and plot details
|
|
228
|
+
"""
|
|
229
|
+
try:
|
|
230
|
+
# Validate inputs
|
|
231
|
+
if not y_values:
|
|
232
|
+
return json.dumps({"error": "No data provided for plotting"})
|
|
233
|
+
|
|
234
|
+
# Convert to floats if needed
|
|
235
|
+
try:
|
|
236
|
+
y_values = [float(v) if v is not None else None for v in y_values]
|
|
237
|
+
if x_values:
|
|
238
|
+
x_values = [float(v) if v is not None else None for v in x_values]
|
|
239
|
+
except (ValueError, TypeError) as e:
|
|
240
|
+
return json.dumps({"error": f"Invalid data format: {str(e)}"})
|
|
241
|
+
|
|
242
|
+
# Create the plot
|
|
243
|
+
if plot_type == "histogram":
|
|
244
|
+
# For histogram, we only need y_values
|
|
245
|
+
histogram(
|
|
246
|
+
y_values,
|
|
247
|
+
title=title,
|
|
248
|
+
bins=min(20, len(set(y_values))), # Adaptive bin count
|
|
249
|
+
)
|
|
250
|
+
plot_info = {
|
|
251
|
+
"type": "histogram",
|
|
252
|
+
"data_points": len(y_values),
|
|
253
|
+
"title": title or "Histogram",
|
|
254
|
+
}
|
|
255
|
+
elif plot_type in ["line", "scatter"]:
|
|
256
|
+
# For line/scatter plots
|
|
257
|
+
plot_kwargs = {
|
|
258
|
+
"ys": y_values,
|
|
259
|
+
"title": title,
|
|
260
|
+
"lines": plot_type == "line",
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
if x_values:
|
|
264
|
+
plot_kwargs["xs"] = x_values
|
|
265
|
+
if x_label:
|
|
266
|
+
plot_kwargs["x_unit"] = x_label
|
|
267
|
+
if y_label:
|
|
268
|
+
plot_kwargs["y_unit"] = y_label
|
|
269
|
+
|
|
270
|
+
plot(**plot_kwargs)
|
|
271
|
+
|
|
272
|
+
plot_info = {
|
|
273
|
+
"type": plot_type,
|
|
274
|
+
"data_points": len(y_values),
|
|
275
|
+
"title": title or f"{plot_type.capitalize()} Plot",
|
|
276
|
+
"has_x_values": x_values is not None,
|
|
277
|
+
}
|
|
278
|
+
else:
|
|
279
|
+
return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
|
|
280
|
+
|
|
281
|
+
return json.dumps(
|
|
282
|
+
{"success": True, "plot_rendered": True, "plot_info": plot_info}
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
except Exception as e:
|
|
286
|
+
return json.dumps({"error": f"Error creating plot: {str(e)}"})
|
sqlsaber/cli/display.py
CHANGED
|
@@ -62,12 +62,20 @@ class DisplayManager:
|
|
|
62
62
|
)
|
|
63
63
|
|
|
64
64
|
# Create table with columns from first result
|
|
65
|
-
|
|
66
|
-
|
|
65
|
+
all_columns = list(results[0].keys())
|
|
66
|
+
display_columns = all_columns[:15] # Limit to first 15 columns
|
|
67
|
+
|
|
68
|
+
# Show warning if columns were truncated
|
|
69
|
+
if len(all_columns) > 15:
|
|
70
|
+
self.console.print(
|
|
71
|
+
f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
table = self._create_table(display_columns)
|
|
67
75
|
|
|
68
76
|
# Add rows (show first 20 rows)
|
|
69
77
|
for row in results[:20]:
|
|
70
|
-
table.add_row(*[str(row[key]) for key in
|
|
78
|
+
table.add_row(*[str(row[key]) for key in display_columns])
|
|
71
79
|
|
|
72
80
|
self.console.print(table)
|
|
73
81
|
|
|
@@ -205,3 +213,33 @@ class DisplayManager:
|
|
|
205
213
|
self.show_error("Failed to parse schema data")
|
|
206
214
|
except Exception as e:
|
|
207
215
|
self.show_error(f"Error displaying schema information: {str(e)}")
|
|
216
|
+
|
|
217
|
+
def show_plot(self, plot_data: dict):
|
|
218
|
+
"""Display plot information and status."""
|
|
219
|
+
try:
|
|
220
|
+
# Parse the result if it's a string
|
|
221
|
+
if isinstance(plot_data.get("result"), str):
|
|
222
|
+
result = json.loads(plot_data["result"])
|
|
223
|
+
else:
|
|
224
|
+
result = plot_data.get("result", {})
|
|
225
|
+
|
|
226
|
+
# Check if there was an error
|
|
227
|
+
if "error" in result:
|
|
228
|
+
self.show_error(f"Plot error: {result['error']}")
|
|
229
|
+
return
|
|
230
|
+
|
|
231
|
+
# If plot was successful, show plot info
|
|
232
|
+
if result.get("success") and result.get("plot_rendered"):
|
|
233
|
+
plot_info = result.get("plot_info", {})
|
|
234
|
+
self.console.print(
|
|
235
|
+
f"\n[bold green]✓ Plot rendered:[/bold green] {plot_info.get('title', 'Plot')}"
|
|
236
|
+
)
|
|
237
|
+
self.console.print(
|
|
238
|
+
f"[dim] Type: {plot_info.get('type', 'unknown')}, "
|
|
239
|
+
f"Data points: {plot_info.get('data_points', 0)}[/dim]"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
except json.JSONDecodeError:
|
|
243
|
+
self.show_error("Failed to parse plot result")
|
|
244
|
+
except Exception as e:
|
|
245
|
+
self.show_error(f"Error displaying plot: {str(e)}")
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
"""Interactive mode handling for the CLI."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
3
6
|
import questionary
|
|
7
|
+
from prompt_toolkit.completion import Completer, Completion
|
|
4
8
|
from rich.console import Console
|
|
5
9
|
from rich.panel import Panel
|
|
6
10
|
|
|
@@ -9,6 +13,34 @@ from sqlsaber.cli.display import DisplayManager
|
|
|
9
13
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
10
14
|
|
|
11
15
|
|
|
16
|
+
class SlashCommandCompleter(Completer):
|
|
17
|
+
"""Custom completer for slash commands."""
|
|
18
|
+
|
|
19
|
+
def get_completions(self, document, complete_event):
|
|
20
|
+
"""Get completions for slash commands."""
|
|
21
|
+
# Only provide completions if the line starts with "/"
|
|
22
|
+
text = document.text
|
|
23
|
+
if text.startswith("/"):
|
|
24
|
+
# Get the partial command after the slash
|
|
25
|
+
partial_cmd = text[1:]
|
|
26
|
+
|
|
27
|
+
# Define available commands with descriptions
|
|
28
|
+
commands = [
|
|
29
|
+
("clear", "Clear conversation history"),
|
|
30
|
+
("exit", "Exit the interactive session"),
|
|
31
|
+
("quit", "Exit the interactive session"),
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
# Yield completions that match the partial command
|
|
35
|
+
for cmd, description in commands:
|
|
36
|
+
if cmd.startswith(partial_cmd):
|
|
37
|
+
yield Completion(
|
|
38
|
+
cmd,
|
|
39
|
+
start_position=-len(partial_cmd),
|
|
40
|
+
display_meta=description,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
12
44
|
class InteractiveSession:
|
|
13
45
|
"""Manages interactive CLI sessions."""
|
|
14
46
|
|
|
@@ -17,6 +49,8 @@ class InteractiveSession:
|
|
|
17
49
|
self.agent = agent
|
|
18
50
|
self.display = DisplayManager(console)
|
|
19
51
|
self.streaming_handler = StreamingQueryHandler(console)
|
|
52
|
+
self.current_task: Optional[asyncio.Task] = None
|
|
53
|
+
self.cancellation_token: Optional[asyncio.Event] = None
|
|
20
54
|
|
|
21
55
|
def show_welcome_message(self):
|
|
22
56
|
"""Display welcome message for interactive mode."""
|
|
@@ -28,7 +62,7 @@ class InteractiveSession:
|
|
|
28
62
|
Panel.fit(
|
|
29
63
|
"[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
|
|
30
64
|
"[bold]Your agentic SQL assistant.[/bold]\n\n\n"
|
|
31
|
-
"[dim]Use 'clear' to reset conversation, 'exit' or 'quit' to leave.[/dim]\n\n"
|
|
65
|
+
"[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
|
|
32
66
|
"[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]",
|
|
33
67
|
border_style="green",
|
|
34
68
|
)
|
|
@@ -38,8 +72,31 @@ class InteractiveSession:
|
|
|
38
72
|
)
|
|
39
73
|
self.console.print(
|
|
40
74
|
"[dim]Press Esc-Enter or Meta-Enter to submit your query.[/dim]\n"
|
|
75
|
+
"[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
|
|
41
76
|
)
|
|
42
77
|
|
|
78
|
+
async def _execute_query_with_cancellation(self, user_query: str):
|
|
79
|
+
"""Execute a query with cancellation support."""
|
|
80
|
+
# Create cancellation token
|
|
81
|
+
self.cancellation_token = asyncio.Event()
|
|
82
|
+
|
|
83
|
+
# Create the query task
|
|
84
|
+
query_task = asyncio.create_task(
|
|
85
|
+
self.streaming_handler.execute_streaming_query(
|
|
86
|
+
user_query, self.agent, self.cancellation_token
|
|
87
|
+
)
|
|
88
|
+
)
|
|
89
|
+
self.current_task = query_task
|
|
90
|
+
|
|
91
|
+
try:
|
|
92
|
+
# Simply await the query task
|
|
93
|
+
# Ctrl+C will be handled by the KeyboardInterrupt exception in run()
|
|
94
|
+
await query_task
|
|
95
|
+
|
|
96
|
+
finally:
|
|
97
|
+
self.current_task = None
|
|
98
|
+
self.cancellation_token = None
|
|
99
|
+
|
|
43
100
|
async def run(self):
|
|
44
101
|
"""Run the interactive session loop."""
|
|
45
102
|
self.show_welcome_message()
|
|
@@ -51,12 +108,16 @@ class InteractiveSession:
|
|
|
51
108
|
qmark="",
|
|
52
109
|
multiline=True,
|
|
53
110
|
instruction="",
|
|
111
|
+
completer=SlashCommandCompleter(),
|
|
54
112
|
).ask_async()
|
|
55
113
|
|
|
56
|
-
if user_query
|
|
114
|
+
if not user_query:
|
|
115
|
+
continue
|
|
116
|
+
|
|
117
|
+
if user_query in ["/exit", "/quit"]:
|
|
57
118
|
break
|
|
58
119
|
|
|
59
|
-
if user_query
|
|
120
|
+
if user_query == "/clear":
|
|
60
121
|
self.agent.clear_history()
|
|
61
122
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
62
123
|
continue
|
|
@@ -85,12 +146,24 @@ class InteractiveSession:
|
|
|
85
146
|
)
|
|
86
147
|
continue
|
|
87
148
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
)
|
|
149
|
+
# Execute query with cancellation support
|
|
150
|
+
await self._execute_query_with_cancellation(user_query)
|
|
91
151
|
self.display.show_newline() # Empty line for readability
|
|
92
152
|
|
|
93
153
|
except KeyboardInterrupt:
|
|
94
|
-
|
|
154
|
+
# Handle Ctrl+C - cancel current task if running
|
|
155
|
+
if self.current_task and not self.current_task.done():
|
|
156
|
+
if self.cancellation_token is not None:
|
|
157
|
+
self.cancellation_token.set()
|
|
158
|
+
self.current_task.cancel()
|
|
159
|
+
try:
|
|
160
|
+
await self.current_task
|
|
161
|
+
except asyncio.CancelledError:
|
|
162
|
+
pass
|
|
163
|
+
self.console.print("\n[yellow]Query interrupted[/yellow]")
|
|
164
|
+
else:
|
|
165
|
+
self.console.print(
|
|
166
|
+
"\n[yellow]Use '/exit' or '/quit' to leave.[/yellow]"
|
|
167
|
+
)
|
|
95
168
|
except Exception as e:
|
|
96
169
|
self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Streaming query handling for the CLI."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
3
5
|
from rich.console import Console
|
|
4
6
|
|
|
5
7
|
from sqlsaber.agents.base import BaseSQLAgent
|
|
@@ -13,7 +15,12 @@ class StreamingQueryHandler:
|
|
|
13
15
|
self.console = console
|
|
14
16
|
self.display = DisplayManager(console)
|
|
15
17
|
|
|
16
|
-
async def execute_streaming_query(
|
|
18
|
+
async def execute_streaming_query(
|
|
19
|
+
self,
|
|
20
|
+
user_query: str,
|
|
21
|
+
agent: BaseSQLAgent,
|
|
22
|
+
cancellation_token: asyncio.Event | None = None,
|
|
23
|
+
):
|
|
17
24
|
"""Execute a query with streaming display."""
|
|
18
25
|
|
|
19
26
|
has_content = False
|
|
@@ -24,7 +31,12 @@ class StreamingQueryHandler:
|
|
|
24
31
|
status.start()
|
|
25
32
|
|
|
26
33
|
try:
|
|
27
|
-
async for event in agent.query_stream(
|
|
34
|
+
async for event in agent.query_stream(
|
|
35
|
+
user_query, cancellation_token=cancellation_token
|
|
36
|
+
):
|
|
37
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
38
|
+
break
|
|
39
|
+
|
|
28
40
|
if event.type == "tool_use":
|
|
29
41
|
# Stop any ongoing status, but don't mark has_content yet
|
|
30
42
|
self._stop_status(status)
|
|
@@ -63,6 +75,11 @@ class StreamingQueryHandler:
|
|
|
63
75
|
self.display.show_schema_info(event.data["result"])
|
|
64
76
|
has_content = True
|
|
65
77
|
|
|
78
|
+
elif event.type == "plot_result":
|
|
79
|
+
# Handle plot results
|
|
80
|
+
self.display.show_plot(event.data)
|
|
81
|
+
has_content = True
|
|
82
|
+
|
|
66
83
|
elif event.type == "processing":
|
|
67
84
|
# Show status when processing tool results
|
|
68
85
|
if explanation_started:
|
|
@@ -78,6 +95,13 @@ class StreamingQueryHandler:
|
|
|
78
95
|
has_content = True
|
|
79
96
|
self.display.show_error(event.data)
|
|
80
97
|
|
|
98
|
+
except asyncio.CancelledError:
|
|
99
|
+
# Handle cancellation gracefully
|
|
100
|
+
self._stop_status(status)
|
|
101
|
+
if explanation_started:
|
|
102
|
+
self.display.show_newline()
|
|
103
|
+
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
104
|
+
return
|
|
81
105
|
finally:
|
|
82
106
|
# Make sure status is stopped
|
|
83
107
|
self._stop_status(status)
|
sqlsaber/models/events.py
CHANGED
|
@@ -7,7 +7,7 @@ class StreamEvent:
|
|
|
7
7
|
"""Event emitted during streaming processing."""
|
|
8
8
|
|
|
9
9
|
def __init__(self, event_type: str, data: Any = None):
|
|
10
|
-
# 'tool_use', 'text', 'query_result', 'error', 'processing'
|
|
10
|
+
# 'tool_use', 'text', 'query_result', 'plot_result', 'error', 'processing'
|
|
11
11
|
self.type = event_type
|
|
12
12
|
self.data = data
|
|
13
13
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sqlsaber
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: SQLSaber - Agentic SQL assistant like Claude Code
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -16,6 +16,7 @@ Requires-Dist: platformdirs>=4.0.0
|
|
|
16
16
|
Requires-Dist: questionary>=2.1.0
|
|
17
17
|
Requires-Dist: rich>=13.7.0
|
|
18
18
|
Requires-Dist: typer>=0.16.0
|
|
19
|
+
Requires-Dist: uniplot>=0.21.2
|
|
19
20
|
Description-Content-Type: text/markdown
|
|
20
21
|
|
|
21
22
|
# SQLSaber
|
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
sqlsaber/__init__.py,sha256=QCFi8xTVMohelfi7zOV1-6oLCcGoiXoOcKQY-HNBCk8,66
|
|
2
2
|
sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
|
|
3
3
|
sqlsaber/agents/__init__.py,sha256=LWeSeEUE4BhkyAYFF3TE-fx8TtLud3oyEtyB8ojFJgo,167
|
|
4
|
-
sqlsaber/agents/anthropic.py,sha256=
|
|
5
|
-
sqlsaber/agents/base.py,sha256=
|
|
4
|
+
sqlsaber/agents/anthropic.py,sha256=FLVET2HvFmsEuFln9Hu4SaBs-Tnk-GestOgnDnUp3ps,17885
|
|
5
|
+
sqlsaber/agents/base.py,sha256=DAnezHl5RLYoef8XQ-n3KA9PowdrMbQrkjdGKPPnFsI,10570
|
|
6
6
|
sqlsaber/agents/mcp.py,sha256=FKtXgDrPZ2-xqUYCw2baI5JzrWekXaC5fjkYW1_Mg50,827
|
|
7
7
|
sqlsaber/agents/streaming.py,sha256=_EO390-FHUrL1fRCNfibtE9QuJz3LGQygbwG3CB2ViY,533
|
|
8
8
|
sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
|
|
9
9
|
sqlsaber/cli/commands.py,sha256=Dw24W0jij-8t1lpk99C4PBTgzFSag6vU-FZcjAYGG54,5074
|
|
10
10
|
sqlsaber/cli/database.py,sha256=DUfyvNBDp47oFM_VAC_hXHQy_qyE7JbXtowflJpwwH8,12643
|
|
11
|
-
sqlsaber/cli/display.py,sha256=
|
|
12
|
-
sqlsaber/cli/interactive.py,sha256=
|
|
11
|
+
sqlsaber/cli/display.py,sha256=lZW7BI2LusU5lJhov7u9kKWwfsqcXGSfjrw-kNO3FQA,9200
|
|
12
|
+
sqlsaber/cli/interactive.py,sha256=RNEWyCM1gLLUsXaXaCFzXu0PFDVuAK8NBOOYYFTHTUE,6716
|
|
13
13
|
sqlsaber/cli/memory.py,sha256=LW4ZF2V6Gw6hviUFGZ4ym9ostFCwucgBTIMZ3EANO-I,7671
|
|
14
14
|
sqlsaber/cli/models.py,sha256=3IcXeeU15IQvemSv-V-RQzVytJ3wuQ4YmWk89nTDcSE,7813
|
|
15
|
-
sqlsaber/cli/streaming.py,sha256=
|
|
15
|
+
sqlsaber/cli/streaming.py,sha256=2vLCYqqziQTO52erfgvnEk_hM3BoDM1TMBAXgT7KKfo,4548
|
|
16
16
|
sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
|
|
17
17
|
sqlsaber/config/api_keys.py,sha256=kLdoExF_My9ojmdhO5Ca7-ZeowsO0v1GVa_QT5jjUPo,3658
|
|
18
18
|
sqlsaber/config/database.py,sha256=vKFOxPjVakjQhj1uoLcfzhS9ZFr6Z2F5b4MmYALQZoA,11421
|
|
@@ -26,10 +26,10 @@ sqlsaber/memory/__init__.py,sha256=GiWkU6f6YYVV0EvvXDmFWe_CxarmDCql05t70MkTEWs,6
|
|
|
26
26
|
sqlsaber/memory/manager.py,sha256=ML2NEO5Z4Aw36sEI9eOvWVnjl-qT2VOTojViJAj7Seo,2777
|
|
27
27
|
sqlsaber/memory/storage.py,sha256=DvZBsSPaAfk_DqrNEn86uMD-TQsWUI6rQLfNw6PSCB8,5788
|
|
28
28
|
sqlsaber/models/__init__.py,sha256=RJ7p3WtuSwwpFQ1Iw4_DHV2zzCtHqIzsjJzxv8kUjUE,287
|
|
29
|
-
sqlsaber/models/events.py,sha256=
|
|
29
|
+
sqlsaber/models/events.py,sha256=q2FackB60J9-7vegYIjzElLwKebIh7nxnV5AFoZc67c,752
|
|
30
30
|
sqlsaber/models/types.py,sha256=3U_30n91EB3IglBTHipwiW4MqmmaA2qfshfraMZyPps,896
|
|
31
|
-
sqlsaber-0.
|
|
32
|
-
sqlsaber-0.
|
|
33
|
-
sqlsaber-0.
|
|
34
|
-
sqlsaber-0.
|
|
35
|
-
sqlsaber-0.
|
|
31
|
+
sqlsaber-0.6.0.dist-info/METADATA,sha256=Mvou6xXxA8T2Cfwq4y1bd0DYW46zBs4Qw7oXWNfihfE,5969
|
|
32
|
+
sqlsaber-0.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
33
|
+
sqlsaber-0.6.0.dist-info/entry_points.txt,sha256=jmFo96Ylm0zIKXJBwhv_P5wQ7SXP9qdaBcnTp8iCEe8,195
|
|
34
|
+
sqlsaber-0.6.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
35
|
+
sqlsaber-0.6.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|