sqlsaber 0.4.1__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. See the package's registry page for more details.

Files changed (50)
  1. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/CHANGELOG.md +30 -0
  2. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/PKG-INFO +2 -1
  3. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/pyproject.toml +2 -1
  4. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/agents/anthropic.py +105 -18
  5. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/agents/base.py +105 -3
  6. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/display.py +41 -3
  7. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/interactive.py +80 -7
  8. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/streaming.py +26 -2
  9. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/models/events.py +1 -1
  10. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/uv.lock +25 -1
  11. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/.github/workflows/publish.yml +0 -0
  12. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/.gitignore +0 -0
  13. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/.python-version +0 -0
  14. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/CLAUDE.md +0 -0
  15. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/LICENSE +0 -0
  16. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/README.md +0 -0
  17. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/pytest.ini +0 -0
  18. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/__init__.py +0 -0
  19. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/__main__.py +0 -0
  20. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/agents/__init__.py +0 -0
  21. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/agents/mcp.py +0 -0
  22. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/agents/streaming.py +0 -0
  23. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/__init__.py +0 -0
  24. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/commands.py +0 -0
  25. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/database.py +0 -0
  26. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/memory.py +0 -0
  27. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/cli/models.py +0 -0
  28. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/config/__init__.py +0 -0
  29. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/config/api_keys.py +0 -0
  30. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/config/database.py +0 -0
  31. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/config/settings.py +0 -0
  32. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/database/__init__.py +0 -0
  33. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/database/connection.py +0 -0
  34. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/database/schema.py +0 -0
  35. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/mcp/__init__.py +0 -0
  36. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/mcp/mcp.py +0 -0
  37. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/memory/__init__.py +0 -0
  38. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/memory/manager.py +0 -0
  39. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/memory/storage.py +0 -0
  40. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/models/__init__.py +0 -0
  41. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/src/sqlsaber/models/types.py +0 -0
  42. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/__init__.py +0 -0
  43. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/conftest.py +0 -0
  44. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_cli/__init__.py +0 -0
  45. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_cli/test_commands.py +0 -0
  46. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_config/__init__.py +0 -0
  47. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_config/test_database.py +0 -0
  48. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_config/test_settings.py +0 -0
  49. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_database/__init__.py +0 -0
  50. {sqlsaber-0.4.1 → sqlsaber-0.6.0}/tests/test_database/test_connection.py +0 -0
@@ -4,6 +4,36 @@ All notable changes to SQLSaber will be documented in this file.
4
4
 
5
5
  ## [Unreleased]
6
6
 
7
+ ## [0.6.0] - 2025-06-30
8
+
9
+ ### Added
10
+
11
+ - Slash command autocomplete in interactive mode
12
+ - Commands now use slash prefix: `/clear`, `/exit`, `/quit`
13
+ - Autocomplete shows when typing `/` at the start of a line
14
+ - Press Tab to select suggestion
15
+ - Query interruption with Ctrl+C in interactive mode
16
+ - Press Ctrl+C during query execution to gracefully cancel ongoing operations
17
+ - Preserves conversation history up to the interruption point
18
+
19
+ ### Changed
20
+
21
+ - Updated table display for better readability: limit to first 15 columns on wide tables
22
+ - Shows warning when columns are truncated
23
+ - Interactive commands now require slash prefix (breaking change)
24
+ - `clear` → `/clear`
25
+ - `exit` → `/exit`
26
+ - `quit` → `/quit`
27
+ - Removed default limit of 100. Now model will decide it.
28
+
29
+ ## [0.5.0] - 2025-06-27
30
+
31
+ ### Added
32
+
33
+ - Added support for plotting data from query results.
34
+ - The agent can decide if plotting will useful and create a plot with query results.
35
+ - Small updates to system prompt
36
+
7
37
  ## [0.4.1] - 2025-06-26
8
38
 
9
39
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlsaber
3
- Version: 0.4.1
3
+ Version: 0.6.0
4
4
  Summary: SQLSaber - Agentic SQL assistant like Claude Code
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -16,6 +16,7 @@ Requires-Dist: platformdirs>=4.0.0
16
16
  Requires-Dist: questionary>=2.1.0
17
17
  Requires-Dist: rich>=13.7.0
18
18
  Requires-Dist: typer>=0.16.0
19
+ Requires-Dist: uniplot>=0.21.2
19
20
  Description-Content-Type: text/markdown
20
21
 
21
22
  # SQLSaber
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sqlsaber"
3
- version = "0.4.1"
3
+ version = "0.6.0"
4
4
  description = "SQLSaber - Agentic SQL assistant like Claude Code"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -17,6 +17,7 @@ dependencies = [
17
17
  "aiosqlite>=0.21.0",
18
18
  "pandas>=2.0.0",
19
19
  "fastmcp>=2.9.0",
20
+ "uniplot>=0.21.2",
20
21
  ]
21
22
 
22
23
  [tool.uv]
@@ -1,7 +1,8 @@
1
1
  """Anthropic-specific SQL agent implementation."""
2
2
 
3
+ import asyncio
3
4
  import json
4
- from typing import Any, AsyncIterator, Dict, List, Optional
5
+ from typing import Any, AsyncIterator, Dict, List
5
6
 
6
7
  from anthropic import AsyncAnthropic
7
8
 
@@ -21,7 +22,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
21
22
  """SQL Agent using Anthropic SDK directly."""
22
23
 
23
24
  def __init__(
24
- self, db_connection: BaseDatabaseConnection, database_name: Optional[str] = None
25
+ self, db_connection: BaseDatabaseConnection, database_name: str | None = None
25
26
  ):
26
27
  super().__init__(db_connection)
27
28
 
@@ -82,6 +83,44 @@ class AnthropicSQLAgent(BaseSQLAgent):
82
83
  "required": ["query"],
83
84
  },
84
85
  },
86
+ {
87
+ "name": "plot_data",
88
+ "description": "Create a plot of query results.",
89
+ "input_schema": {
90
+ "type": "object",
91
+ "properties": {
92
+ "y_values": {
93
+ "type": "array",
94
+ "items": {"type": ["number", "null"]},
95
+ "description": "Y-axis data points (required)",
96
+ },
97
+ "x_values": {
98
+ "type": "array",
99
+ "items": {"type": ["number", "null"]},
100
+ "description": "X-axis data points (optional, will use indices if not provided)",
101
+ },
102
+ "plot_type": {
103
+ "type": "string",
104
+ "enum": ["line", "scatter", "histogram"],
105
+ "description": "Type of plot to create (default: line)",
106
+ "default": "line",
107
+ },
108
+ "title": {
109
+ "type": "string",
110
+ "description": "Title for the plot",
111
+ },
112
+ "x_label": {
113
+ "type": "string",
114
+ "description": "Label for X-axis",
115
+ },
116
+ "y_label": {
117
+ "type": "string",
118
+ "description": "Label for Y-axis",
119
+ },
120
+ },
121
+ "required": ["y_values"],
122
+ },
123
+ },
85
124
  ]
86
125
 
87
126
  # Build system prompt with memories if available
@@ -96,13 +135,15 @@ Your responsibilities:
96
135
  1. Understand user's natural language requests, think and convert them to SQL
97
136
  2. Use the provided tools efficiently to explore database schema
98
137
  3. Generate appropriate SQL queries
99
- 4. Execute queries safely (only SELECT queries unless explicitly allowed)
138
+ 4. Execute queries safely - queries that modify the database are not allowed
100
139
  5. Format and explain results clearly
140
+ 6. Create visualizations when requested or when they would be helpful
101
141
 
102
142
  IMPORTANT - Schema Discovery Strategy:
103
143
  1. ALWAYS start with 'list_tables' to see available tables and row counts
104
144
  2. Based on the user's query, identify which specific tables are relevant
105
145
  3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
146
+ 4. Timestamp columns must be converted to text when you write queries
106
147
 
107
148
  Guidelines:
108
149
  - Use list_tables first, then introspect_schema for specific tables only
@@ -124,7 +165,7 @@ Guidelines:
124
165
 
125
166
  return base_prompt
126
167
 
127
- def add_memory(self, content: str) -> Optional[str]:
168
+ def add_memory(self, content: str) -> str | None:
128
169
  """Add a memory for the current database."""
129
170
  if not self.database_name:
130
171
  return None
@@ -134,7 +175,7 @@ Guidelines:
134
175
  self.system_prompt = self._build_system_prompt()
135
176
  return memory.id
136
177
 
137
- async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
178
+ async def execute_sql(self, query: str, limit: int | None = None) -> str:
138
179
  """Execute a SQL query against the database with streaming support."""
139
180
  # Call parent implementation for core functionality
140
181
  result = await super().execute_sql(query, limit)
@@ -163,10 +204,18 @@ Guidelines:
163
204
  return await super().process_tool_call(tool_name, tool_input)
164
205
 
165
206
  async def _process_stream_events(
166
- self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
207
+ self,
208
+ stream,
209
+ content_blocks: List[Dict],
210
+ tool_use_blocks: List[Dict],
211
+ cancellation_token: asyncio.Event | None = None,
167
212
  ) -> AsyncIterator[StreamEvent]:
168
213
  """Process stream events and yield appropriate StreamEvents."""
169
214
  async for event in stream:
215
+ # Only check cancellation if token is provided
216
+ if cancellation_token is not None and cancellation_token.is_set():
217
+ return
218
+
170
219
  if event.type == "content_block_start":
171
220
  if hasattr(event.content_block, "type"):
172
221
  if event.content_block.type == "tool_use":
@@ -213,11 +262,17 @@ Guidelines:
213
262
  return "stop"
214
263
 
215
264
  async def _process_tool_results(
216
- self, response: StreamingResponse
265
+ self,
266
+ response: StreamingResponse,
267
+ cancellation_token: asyncio.Event | None = None,
217
268
  ) -> AsyncIterator[StreamEvent]:
218
269
  """Process tool results and yield appropriate events."""
219
270
  tool_results = []
220
271
  for block in response.content:
272
+ # Only check cancellation if token is provided
273
+ if cancellation_token is not None and cancellation_token.is_set():
274
+ return
275
+
221
276
  if block.get("type") == "tool_use":
222
277
  yield StreamEvent(
223
278
  "tool_use",
@@ -249,13 +304,25 @@ Guidelines:
249
304
  "result": tool_result,
250
305
  },
251
306
  )
307
+ elif block["name"] == "plot_data":
308
+ yield StreamEvent(
309
+ "plot_result",
310
+ {
311
+ "tool_name": block["name"],
312
+ "input": block["input"],
313
+ "result": tool_result,
314
+ },
315
+ )
252
316
 
253
317
  tool_results.append(build_tool_result_block(block["id"], tool_result))
254
318
 
255
319
  yield StreamEvent("tool_result_data", tool_results)
256
320
 
257
321
  async def query_stream(
258
- self, user_query: str, use_history: bool = True
322
+ self,
323
+ user_query: str,
324
+ use_history: bool = True,
325
+ cancellation_token: asyncio.Event | None = None,
259
326
  ) -> AsyncIterator[StreamEvent]:
260
327
  """Process a user query and stream responses."""
261
328
  # Initialize for tracking state
@@ -273,7 +340,11 @@ Guidelines:
273
340
  try:
274
341
  # Create initial stream and get response
275
342
  response = None
276
- async for event in self._create_and_process_stream(messages):
343
+ async for event in self._create_and_process_stream(
344
+ messages, cancellation_token
345
+ ):
346
+ if cancellation_token is not None and cancellation_token.is_set():
347
+ return
277
348
  if event.type == "response_ready":
278
349
  response = event.data
279
350
  else:
@@ -283,14 +354,21 @@ Guidelines:
283
354
 
284
355
  # Process tool calls if needed
285
356
  while response is not None and response.stop_reason == "tool_use":
357
+ # Check for cancellation at the start of tool cycle
358
+ if cancellation_token is not None and cancellation_token.is_set():
359
+ return
360
+
286
361
  # Add assistant's response to conversation
287
362
  collected_content.append(
288
363
  {"role": "assistant", "content": response.content}
289
364
  )
290
365
 
291
- # Process tool results
366
+ # Process tool results - DO NOT check cancellation during tool execution
367
+ # as this would break the tool_use -> tool_result API contract
292
368
  tool_results = []
293
- async for event in self._process_tool_results(response):
369
+ async for event in self._process_tool_results(
370
+ response, None
371
+ ): # Pass None to disable cancellation checks
294
372
  if event.type == "tool_result_data":
295
373
  tool_results = event.data
296
374
  else:
@@ -298,6 +376,12 @@ Guidelines:
298
376
 
299
377
  # Continue conversation with tool results
300
378
  collected_content.append({"role": "user", "content": tool_results})
379
+ if use_history:
380
+ self.conversation_history.extend(collected_content)
381
+
382
+ # Check for cancellation AFTER tool results are complete
383
+ if cancellation_token is not None and cancellation_token.is_set():
384
+ return
301
385
 
302
386
  # Signal that we're processing the tool results
303
387
  yield StreamEvent("processing", "Analyzing results...")
@@ -305,8 +389,10 @@ Guidelines:
305
389
  # Get next response
306
390
  response = None
307
391
  async for event in self._create_and_process_stream(
308
- messages + collected_content
392
+ messages + collected_content, cancellation_token
309
393
  ):
394
+ if cancellation_token is not None and cancellation_token.is_set():
395
+ return
310
396
  if event.type == "response_ready":
311
397
  response = event.data
312
398
  else:
@@ -314,21 +400,19 @@ Guidelines:
314
400
 
315
401
  # Update conversation history if using history
316
402
  if use_history:
317
- self.conversation_history.append(
318
- {"role": "user", "content": user_query}
319
- )
320
- self.conversation_history.extend(collected_content)
321
403
  # Add final assistant response
322
404
  if response is not None:
323
405
  self.conversation_history.append(
324
406
  {"role": "assistant", "content": response.content}
325
407
  )
326
408
 
409
+ except asyncio.CancelledError:
410
+ return
327
411
  except Exception as e:
328
412
  yield StreamEvent("error", str(e))
329
413
 
330
414
  async def _create_and_process_stream(
331
- self, messages: List[Dict]
415
+ self, messages: List[Dict], cancellation_token: asyncio.Event | None = None
332
416
  ) -> AsyncIterator[StreamEvent]:
333
417
  """Create a stream and yield events while building response."""
334
418
  stream = await self.client.messages.create(
@@ -344,8 +428,11 @@ Guidelines:
344
428
  tool_use_blocks = []
345
429
 
346
430
  async for event in self._process_stream_events(
347
- stream, content_blocks, tool_use_blocks
431
+ stream, content_blocks, tool_use_blocks, cancellation_token
348
432
  ):
433
+ # Only check cancellation if token is provided
434
+ if cancellation_token is not None and cancellation_token.is_set():
435
+ return
349
436
  yield event
350
437
 
351
438
  # Finalize tool blocks and create response
@@ -1,9 +1,12 @@
1
1
  """Abstract base class for SQL agents."""
2
2
 
3
+ import asyncio
3
4
  import json
4
5
  from abc import ABC, abstractmethod
5
6
  from typing import Any, AsyncIterator, Dict, List, Optional
6
7
 
8
+ from uniplot import histogram, plot
9
+
7
10
  from sqlsaber.database.connection import (
8
11
  BaseDatabaseConnection,
9
12
  CSVConnection,
@@ -25,9 +28,18 @@ class BaseSQLAgent(ABC):
25
28
 
26
29
  @abstractmethod
27
30
  async def query_stream(
28
- self, user_query: str, use_history: bool = True
31
+ self,
32
+ user_query: str,
33
+ use_history: bool = True,
34
+ cancellation_token: asyncio.Event | None = None,
29
35
  ) -> AsyncIterator[StreamEvent]:
30
- """Process a user query and stream responses."""
36
+ """Process a user query and stream responses.
37
+
38
+ Args:
39
+ user_query: The user's query to process
40
+ use_history: Whether to include conversation history
41
+ cancellation_token: Optional event to signal cancellation
42
+ """
31
43
  pass
32
44
 
33
45
  def clear_history(self):
@@ -84,7 +96,7 @@ class BaseSQLAgent(ABC):
84
96
  except Exception as e:
85
97
  return json.dumps({"error": f"Error listing tables: {str(e)}"})
86
98
 
87
- async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
99
+ async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
88
100
  """Execute a SQL query against the database."""
89
101
  try:
90
102
  # Security check - only allow SELECT queries unless write is enabled
@@ -146,6 +158,15 @@ class BaseSQLAgent(ABC):
146
158
  return await self.execute_sql(
147
159
  tool_input["query"], tool_input.get("limit", 100)
148
160
  )
161
+ elif tool_name == "plot_data":
162
+ return await self.plot_data(
163
+ y_values=tool_input["y_values"],
164
+ x_values=tool_input.get("x_values"),
165
+ plot_type=tool_input.get("plot_type", "line"),
166
+ title=tool_input.get("title"),
167
+ x_label=tool_input.get("x_label"),
168
+ y_label=tool_input.get("y_label"),
169
+ )
149
170
  else:
150
171
  return json.dumps({"error": f"Unknown tool: {tool_name}"})
151
172
 
@@ -182,3 +203,84 @@ class BaseSQLAgent(ABC):
182
203
  if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
183
204
  return f"{query.rstrip(';')} LIMIT {limit};"
184
205
  return query
206
+
207
+ async def plot_data(
208
+ self,
209
+ y_values: List[float],
210
+ x_values: Optional[List[float]] = None,
211
+ plot_type: str = "line",
212
+ title: Optional[str] = None,
213
+ x_label: Optional[str] = None,
214
+ y_label: Optional[str] = None,
215
+ ) -> str:
216
+ """Create a terminal plot using uniplot.
217
+
218
+ Args:
219
+ y_values: Y-axis data points
220
+ x_values: X-axis data points (optional)
221
+ plot_type: Type of plot - "line", "scatter", or "histogram"
222
+ title: Plot title
223
+ x_label: X-axis label
224
+ y_label: Y-axis label
225
+
226
+ Returns:
227
+ JSON string with success status and plot details
228
+ """
229
+ try:
230
+ # Validate inputs
231
+ if not y_values:
232
+ return json.dumps({"error": "No data provided for plotting"})
233
+
234
+ # Convert to floats if needed
235
+ try:
236
+ y_values = [float(v) if v is not None else None for v in y_values]
237
+ if x_values:
238
+ x_values = [float(v) if v is not None else None for v in x_values]
239
+ except (ValueError, TypeError) as e:
240
+ return json.dumps({"error": f"Invalid data format: {str(e)}"})
241
+
242
+ # Create the plot
243
+ if plot_type == "histogram":
244
+ # For histogram, we only need y_values
245
+ histogram(
246
+ y_values,
247
+ title=title,
248
+ bins=min(20, len(set(y_values))), # Adaptive bin count
249
+ )
250
+ plot_info = {
251
+ "type": "histogram",
252
+ "data_points": len(y_values),
253
+ "title": title or "Histogram",
254
+ }
255
+ elif plot_type in ["line", "scatter"]:
256
+ # For line/scatter plots
257
+ plot_kwargs = {
258
+ "ys": y_values,
259
+ "title": title,
260
+ "lines": plot_type == "line",
261
+ }
262
+
263
+ if x_values:
264
+ plot_kwargs["xs"] = x_values
265
+ if x_label:
266
+ plot_kwargs["x_unit"] = x_label
267
+ if y_label:
268
+ plot_kwargs["y_unit"] = y_label
269
+
270
+ plot(**plot_kwargs)
271
+
272
+ plot_info = {
273
+ "type": plot_type,
274
+ "data_points": len(y_values),
275
+ "title": title or f"{plot_type.capitalize()} Plot",
276
+ "has_x_values": x_values is not None,
277
+ }
278
+ else:
279
+ return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
280
+
281
+ return json.dumps(
282
+ {"success": True, "plot_rendered": True, "plot_info": plot_info}
283
+ )
284
+
285
+ except Exception as e:
286
+ return json.dumps({"error": f"Error creating plot: {str(e)}"})
@@ -62,12 +62,20 @@ class DisplayManager:
62
62
  )
63
63
 
64
64
  # Create table with columns from first result
65
- columns = list(results[0].keys())
66
- table = self._create_table(columns)
65
+ all_columns = list(results[0].keys())
66
+ display_columns = all_columns[:15] # Limit to first 15 columns
67
+
68
+ # Show warning if columns were truncated
69
+ if len(all_columns) > 15:
70
+ self.console.print(
71
+ f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
72
+ )
73
+
74
+ table = self._create_table(display_columns)
67
75
 
68
76
  # Add rows (show first 20 rows)
69
77
  for row in results[:20]:
70
- table.add_row(*[str(row[key]) for key in columns])
78
+ table.add_row(*[str(row[key]) for key in display_columns])
71
79
 
72
80
  self.console.print(table)
73
81
 
@@ -205,3 +213,33 @@ class DisplayManager:
205
213
  self.show_error("Failed to parse schema data")
206
214
  except Exception as e:
207
215
  self.show_error(f"Error displaying schema information: {str(e)}")
216
+
217
+ def show_plot(self, plot_data: dict):
218
+ """Display plot information and status."""
219
+ try:
220
+ # Parse the result if it's a string
221
+ if isinstance(plot_data.get("result"), str):
222
+ result = json.loads(plot_data["result"])
223
+ else:
224
+ result = plot_data.get("result", {})
225
+
226
+ # Check if there was an error
227
+ if "error" in result:
228
+ self.show_error(f"Plot error: {result['error']}")
229
+ return
230
+
231
+ # If plot was successful, show plot info
232
+ if result.get("success") and result.get("plot_rendered"):
233
+ plot_info = result.get("plot_info", {})
234
+ self.console.print(
235
+ f"\n[bold green]✓ Plot rendered:[/bold green] {plot_info.get('title', 'Plot')}"
236
+ )
237
+ self.console.print(
238
+ f"[dim] Type: {plot_info.get('type', 'unknown')}, "
239
+ f"Data points: {plot_info.get('data_points', 0)}[/dim]"
240
+ )
241
+
242
+ except json.JSONDecodeError:
243
+ self.show_error("Failed to parse plot result")
244
+ except Exception as e:
245
+ self.show_error(f"Error displaying plot: {str(e)}")
@@ -1,6 +1,10 @@
1
1
  """Interactive mode handling for the CLI."""
2
2
 
3
+ import asyncio
4
+ from typing import Optional
5
+
3
6
  import questionary
7
+ from prompt_toolkit.completion import Completer, Completion
4
8
  from rich.console import Console
5
9
  from rich.panel import Panel
6
10
 
@@ -9,6 +13,34 @@ from sqlsaber.cli.display import DisplayManager
9
13
  from sqlsaber.cli.streaming import StreamingQueryHandler
10
14
 
11
15
 
16
+ class SlashCommandCompleter(Completer):
17
+ """Custom completer for slash commands."""
18
+
19
+ def get_completions(self, document, complete_event):
20
+ """Get completions for slash commands."""
21
+ # Only provide completions if the line starts with "/"
22
+ text = document.text
23
+ if text.startswith("/"):
24
+ # Get the partial command after the slash
25
+ partial_cmd = text[1:]
26
+
27
+ # Define available commands with descriptions
28
+ commands = [
29
+ ("clear", "Clear conversation history"),
30
+ ("exit", "Exit the interactive session"),
31
+ ("quit", "Exit the interactive session"),
32
+ ]
33
+
34
+ # Yield completions that match the partial command
35
+ for cmd, description in commands:
36
+ if cmd.startswith(partial_cmd):
37
+ yield Completion(
38
+ cmd,
39
+ start_position=-len(partial_cmd),
40
+ display_meta=description,
41
+ )
42
+
43
+
12
44
  class InteractiveSession:
13
45
  """Manages interactive CLI sessions."""
14
46
 
@@ -17,6 +49,8 @@ class InteractiveSession:
17
49
  self.agent = agent
18
50
  self.display = DisplayManager(console)
19
51
  self.streaming_handler = StreamingQueryHandler(console)
52
+ self.current_task: Optional[asyncio.Task] = None
53
+ self.cancellation_token: Optional[asyncio.Event] = None
20
54
 
21
55
  def show_welcome_message(self):
22
56
  """Display welcome message for interactive mode."""
@@ -28,7 +62,7 @@ class InteractiveSession:
28
62
  Panel.fit(
29
63
  "[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
30
64
  "[bold]Your agentic SQL assistant.[/bold]\n\n\n"
31
- "[dim]Use 'clear' to reset conversation, 'exit' or 'quit' to leave.[/dim]\n\n"
65
+ "[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
32
66
  "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]",
33
67
  border_style="green",
34
68
  )
@@ -38,8 +72,31 @@ class InteractiveSession:
38
72
  )
39
73
  self.console.print(
40
74
  "[dim]Press Esc-Enter or Meta-Enter to submit your query.[/dim]\n"
75
+ "[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
41
76
  )
42
77
 
78
+ async def _execute_query_with_cancellation(self, user_query: str):
79
+ """Execute a query with cancellation support."""
80
+ # Create cancellation token
81
+ self.cancellation_token = asyncio.Event()
82
+
83
+ # Create the query task
84
+ query_task = asyncio.create_task(
85
+ self.streaming_handler.execute_streaming_query(
86
+ user_query, self.agent, self.cancellation_token
87
+ )
88
+ )
89
+ self.current_task = query_task
90
+
91
+ try:
92
+ # Simply await the query task
93
+ # Ctrl+C will be handled by the KeyboardInterrupt exception in run()
94
+ await query_task
95
+
96
+ finally:
97
+ self.current_task = None
98
+ self.cancellation_token = None
99
+
43
100
  async def run(self):
44
101
  """Run the interactive session loop."""
45
102
  self.show_welcome_message()
@@ -51,12 +108,16 @@ class InteractiveSession:
51
108
  qmark="",
52
109
  multiline=True,
53
110
  instruction="",
111
+ completer=SlashCommandCompleter(),
54
112
  ).ask_async()
55
113
 
56
- if user_query.lower() in ["exit", "quit", "q"]:
114
+ if not user_query:
115
+ continue
116
+
117
+ if user_query in ["/exit", "/quit"]:
57
118
  break
58
119
 
59
- if user_query.lower() == "clear":
120
+ if user_query == "/clear":
60
121
  self.agent.clear_history()
61
122
  self.console.print("[green]Conversation history cleared.[/green]\n")
62
123
  continue
@@ -85,12 +146,24 @@ class InteractiveSession:
85
146
  )
86
147
  continue
87
148
 
88
- await self.streaming_handler.execute_streaming_query(
89
- user_query, self.agent
90
- )
149
+ # Execute query with cancellation support
150
+ await self._execute_query_with_cancellation(user_query)
91
151
  self.display.show_newline() # Empty line for readability
92
152
 
93
153
  except KeyboardInterrupt:
94
- self.console.print("\n[yellow]Use 'exit' or 'quit' to leave.[/yellow]")
154
+ # Handle Ctrl+C - cancel current task if running
155
+ if self.current_task and not self.current_task.done():
156
+ if self.cancellation_token is not None:
157
+ self.cancellation_token.set()
158
+ self.current_task.cancel()
159
+ try:
160
+ await self.current_task
161
+ except asyncio.CancelledError:
162
+ pass
163
+ self.console.print("\n[yellow]Query interrupted[/yellow]")
164
+ else:
165
+ self.console.print(
166
+ "\n[yellow]Use '/exit' or '/quit' to leave.[/yellow]"
167
+ )
95
168
  except Exception as e:
96
169
  self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
@@ -1,5 +1,7 @@
1
1
  """Streaming query handling for the CLI."""
2
2
 
3
+ import asyncio
4
+
3
5
  from rich.console import Console
4
6
 
5
7
  from sqlsaber.agents.base import BaseSQLAgent
@@ -13,7 +15,12 @@ class StreamingQueryHandler:
13
15
  self.console = console
14
16
  self.display = DisplayManager(console)
15
17
 
16
- async def execute_streaming_query(self, user_query: str, agent: BaseSQLAgent):
18
+ async def execute_streaming_query(
19
+ self,
20
+ user_query: str,
21
+ agent: BaseSQLAgent,
22
+ cancellation_token: asyncio.Event | None = None,
23
+ ):
17
24
  """Execute a query with streaming display."""
18
25
 
19
26
  has_content = False
@@ -24,7 +31,12 @@ class StreamingQueryHandler:
24
31
  status.start()
25
32
 
26
33
  try:
27
- async for event in agent.query_stream(user_query):
34
+ async for event in agent.query_stream(
35
+ user_query, cancellation_token=cancellation_token
36
+ ):
37
+ if cancellation_token is not None and cancellation_token.is_set():
38
+ break
39
+
28
40
  if event.type == "tool_use":
29
41
  # Stop any ongoing status, but don't mark has_content yet
30
42
  self._stop_status(status)
@@ -63,6 +75,11 @@ class StreamingQueryHandler:
63
75
  self.display.show_schema_info(event.data["result"])
64
76
  has_content = True
65
77
 
78
+ elif event.type == "plot_result":
79
+ # Handle plot results
80
+ self.display.show_plot(event.data)
81
+ has_content = True
82
+
66
83
  elif event.type == "processing":
67
84
  # Show status when processing tool results
68
85
  if explanation_started:
@@ -78,6 +95,13 @@ class StreamingQueryHandler:
78
95
  has_content = True
79
96
  self.display.show_error(event.data)
80
97
 
98
+ except asyncio.CancelledError:
99
+ # Handle cancellation gracefully
100
+ self._stop_status(status)
101
+ if explanation_started:
102
+ self.display.show_newline()
103
+ self.console.print("[yellow]Query interrupted[/yellow]")
104
+ return
81
105
  finally:
82
106
  # Make sure status is stopped
83
107
  self._stop_status(status)
@@ -7,7 +7,7 @@ class StreamEvent:
7
7
  """Event emitted during streaming processing."""
8
8
 
9
9
  def __init__(self, event_type: str, data: Any = None):
10
- # 'tool_use', 'text', 'query_result', 'error', 'processing'
10
+ # 'tool_use', 'text', 'query_result', 'plot_result', 'error', 'processing'
11
11
  self.type = event_type
12
12
  self.data = data
13
13
 
@@ -774,6 +774,15 @@ wheels = [
774
774
  { url = "https://files.pythonhosted.org/packages/ad/3f/11dd4cd4f39e05128bfd20138faea57bec56f9ffba6185d276e3107ba5b2/questionary-2.1.0-py3-none-any.whl", hash = "sha256:44174d237b68bc828e4878c763a9ad6790ee61990e0ae72927694ead57bab8ec", size = 36747 },
775
775
  ]
776
776
 
777
+ [[package]]
778
+ name = "readchar"
779
+ version = "4.2.1"
780
+ source = { registry = "https://pypi.org/simple" }
781
+ sdist = { url = "https://files.pythonhosted.org/packages/dd/f8/8657b8cbb4ebeabfbdf991ac40eca8a1d1bd012011bd44ad1ed10f5cb494/readchar-4.2.1.tar.gz", hash = "sha256:91ce3faf07688de14d800592951e5575e9c7a3213738ed01d394dcc949b79adb", size = 9685 }
782
+ wheels = [
783
+ { url = "https://files.pythonhosted.org/packages/a9/10/e4b1e0e5b6b6745c8098c275b69bc9d73e9542d5c7da4f137542b499ed44/readchar-4.2.1-py3-none-any.whl", hash = "sha256:a769305cd3994bb5fa2764aa4073452dc105a4ec39068ffe6efd3c20c60acc77", size = 9350 },
784
+ ]
785
+
777
786
  [[package]]
778
787
  name = "rich"
779
788
  version = "14.0.0"
@@ -854,7 +863,7 @@ wheels = [
854
863
 
855
864
  [[package]]
856
865
  name = "sqlsaber"
857
- version = "0.4.1"
866
+ version = "0.6.0"
858
867
  source = { editable = "." }
859
868
  dependencies = [
860
869
  { name = "aiomysql" },
@@ -869,6 +878,7 @@ dependencies = [
869
878
  { name = "questionary" },
870
879
  { name = "rich" },
871
880
  { name = "typer" },
881
+ { name = "uniplot" },
872
882
  ]
873
883
 
874
884
  [package.dev-dependencies]
@@ -892,6 +902,7 @@ requires-dist = [
892
902
  { name = "questionary", specifier = ">=2.1.0" },
893
903
  { name = "rich", specifier = ">=13.7.0" },
894
904
  { name = "typer", specifier = ">=0.16.0" },
905
+ { name = "uniplot", specifier = ">=0.21.2" },
895
906
  ]
896
907
 
897
908
  [package.metadata.requires-dev]
@@ -971,6 +982,19 @@ wheels = [
971
982
  { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 },
972
983
  ]
973
984
 
985
+ [[package]]
986
+ name = "uniplot"
987
+ version = "0.21.2"
988
+ source = { registry = "https://pypi.org/simple" }
989
+ dependencies = [
990
+ { name = "numpy" },
991
+ { name = "readchar" },
992
+ ]
993
+ sdist = { url = "https://files.pythonhosted.org/packages/87/65/b9db385152a5283c88f955710123c6539a7c79436d2de377b3449995b041/uniplot-0.21.2.tar.gz", hash = "sha256:fc350d6e0f2352822747a3426fef7f521d1b3973585ad2e2967c702dfc6e8440", size = 33412 }
994
+ wheels = [
995
+ { url = "https://files.pythonhosted.org/packages/3a/0e/0b2e41841eb18017e7e125bc8294180d2597a4ca049641068f55355bcc69/uniplot-0.21.2-py3-none-any.whl", hash = "sha256:cae5875eac0d06fd75cbb7076ea3fa49565ef1d71140f0b9a39f7be96085536b", size = 36419 },
996
+ ]
997
+
974
998
  [[package]]
975
999
  name = "uvicorn"
976
1000
  version = "0.34.3"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes