sqlsaber 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -1,7 +1,8 @@
1
1
  """Anthropic-specific SQL agent implementation."""
2
2
 
3
+ import asyncio
3
4
  import json
4
- from typing import Any, AsyncIterator, Dict, List, Optional
5
+ from typing import Any, AsyncIterator, Dict, List
5
6
 
6
7
  from anthropic import AsyncAnthropic
7
8
 
@@ -21,7 +22,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
21
22
  """SQL Agent using Anthropic SDK directly."""
22
23
 
23
24
  def __init__(
24
- self, db_connection: BaseDatabaseConnection, database_name: Optional[str] = None
25
+ self, db_connection: BaseDatabaseConnection, database_name: str | None = None
25
26
  ):
26
27
  super().__init__(db_connection)
27
28
 
@@ -164,7 +165,7 @@ Guidelines:
164
165
 
165
166
  return base_prompt
166
167
 
167
- def add_memory(self, content: str) -> Optional[str]:
168
+ def add_memory(self, content: str) -> str | None:
168
169
  """Add a memory for the current database."""
169
170
  if not self.database_name:
170
171
  return None
@@ -174,7 +175,7 @@ Guidelines:
174
175
  self.system_prompt = self._build_system_prompt()
175
176
  return memory.id
176
177
 
177
- async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
178
+ async def execute_sql(self, query: str, limit: int | None = None) -> str:
178
179
  """Execute a SQL query against the database with streaming support."""
179
180
  # Call parent implementation for core functionality
180
181
  result = await super().execute_sql(query, limit)
@@ -203,10 +204,18 @@ Guidelines:
203
204
  return await super().process_tool_call(tool_name, tool_input)
204
205
 
205
206
  async def _process_stream_events(
206
- self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
207
+ self,
208
+ stream,
209
+ content_blocks: List[Dict],
210
+ tool_use_blocks: List[Dict],
211
+ cancellation_token: asyncio.Event | None = None,
207
212
  ) -> AsyncIterator[StreamEvent]:
208
213
  """Process stream events and yield appropriate StreamEvents."""
209
214
  async for event in stream:
215
+ # Only check cancellation if token is provided
216
+ if cancellation_token is not None and cancellation_token.is_set():
217
+ return
218
+
210
219
  if event.type == "content_block_start":
211
220
  if hasattr(event.content_block, "type"):
212
221
  if event.content_block.type == "tool_use":
@@ -253,11 +262,17 @@ Guidelines:
253
262
  return "stop"
254
263
 
255
264
  async def _process_tool_results(
256
- self, response: StreamingResponse
265
+ self,
266
+ response: StreamingResponse,
267
+ cancellation_token: asyncio.Event | None = None,
257
268
  ) -> AsyncIterator[StreamEvent]:
258
269
  """Process tool results and yield appropriate events."""
259
270
  tool_results = []
260
271
  for block in response.content:
272
+ # Only check cancellation if token is provided
273
+ if cancellation_token is not None and cancellation_token.is_set():
274
+ return
275
+
261
276
  if block.get("type") == "tool_use":
262
277
  yield StreamEvent(
263
278
  "tool_use",
@@ -304,7 +319,10 @@ Guidelines:
304
319
  yield StreamEvent("tool_result_data", tool_results)
305
320
 
306
321
  async def query_stream(
307
- self, user_query: str, use_history: bool = True
322
+ self,
323
+ user_query: str,
324
+ use_history: bool = True,
325
+ cancellation_token: asyncio.Event | None = None,
308
326
  ) -> AsyncIterator[StreamEvent]:
309
327
  """Process a user query and stream responses."""
310
328
  # Initialize for tracking state
@@ -322,7 +340,11 @@ Guidelines:
322
340
  try:
323
341
  # Create initial stream and get response
324
342
  response = None
325
- async for event in self._create_and_process_stream(messages):
343
+ async for event in self._create_and_process_stream(
344
+ messages, cancellation_token
345
+ ):
346
+ if cancellation_token is not None and cancellation_token.is_set():
347
+ return
326
348
  if event.type == "response_ready":
327
349
  response = event.data
328
350
  else:
@@ -332,14 +354,21 @@ Guidelines:
332
354
 
333
355
  # Process tool calls if needed
334
356
  while response is not None and response.stop_reason == "tool_use":
357
+ # Check for cancellation at the start of tool cycle
358
+ if cancellation_token is not None and cancellation_token.is_set():
359
+ return
360
+
335
361
  # Add assistant's response to conversation
336
362
  collected_content.append(
337
363
  {"role": "assistant", "content": response.content}
338
364
  )
339
365
 
340
- # Process tool results
366
+ # Process tool results - DO NOT check cancellation during tool execution
367
+ # as this would break the tool_use -> tool_result API contract
341
368
  tool_results = []
342
- async for event in self._process_tool_results(response):
369
+ async for event in self._process_tool_results(
370
+ response, None
371
+ ): # Pass None to disable cancellation checks
343
372
  if event.type == "tool_result_data":
344
373
  tool_results = event.data
345
374
  else:
@@ -347,6 +376,12 @@ Guidelines:
347
376
 
348
377
  # Continue conversation with tool results
349
378
  collected_content.append({"role": "user", "content": tool_results})
379
+ if use_history:
380
+ self.conversation_history.extend(collected_content)
381
+
382
+ # Check for cancellation AFTER tool results are complete
383
+ if cancellation_token is not None and cancellation_token.is_set():
384
+ return
350
385
 
351
386
  # Signal that we're processing the tool results
352
387
  yield StreamEvent("processing", "Analyzing results...")
@@ -354,8 +389,10 @@ Guidelines:
354
389
  # Get next response
355
390
  response = None
356
391
  async for event in self._create_and_process_stream(
357
- messages + collected_content
392
+ messages + collected_content, cancellation_token
358
393
  ):
394
+ if cancellation_token is not None and cancellation_token.is_set():
395
+ return
359
396
  if event.type == "response_ready":
360
397
  response = event.data
361
398
  else:
@@ -363,21 +400,19 @@ Guidelines:
363
400
 
364
401
  # Update conversation history if using history
365
402
  if use_history:
366
- self.conversation_history.append(
367
- {"role": "user", "content": user_query}
368
- )
369
- self.conversation_history.extend(collected_content)
370
403
  # Add final assistant response
371
404
  if response is not None:
372
405
  self.conversation_history.append(
373
406
  {"role": "assistant", "content": response.content}
374
407
  )
375
408
 
409
+ except asyncio.CancelledError:
410
+ return
376
411
  except Exception as e:
377
412
  yield StreamEvent("error", str(e))
378
413
 
379
414
  async def _create_and_process_stream(
380
- self, messages: List[Dict]
415
+ self, messages: List[Dict], cancellation_token: asyncio.Event | None = None
381
416
  ) -> AsyncIterator[StreamEvent]:
382
417
  """Create a stream and yield events while building response."""
383
418
  stream = await self.client.messages.create(
@@ -393,8 +428,11 @@ Guidelines:
393
428
  tool_use_blocks = []
394
429
 
395
430
  async for event in self._process_stream_events(
396
- stream, content_blocks, tool_use_blocks
431
+ stream, content_blocks, tool_use_blocks, cancellation_token
397
432
  ):
433
+ # Only check cancellation if token is provided
434
+ if cancellation_token is not None and cancellation_token.is_set():
435
+ return
398
436
  yield event
399
437
 
400
438
  # Finalize tool blocks and create response
sqlsaber/agents/base.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Abstract base class for SQL agents."""
2
2
 
3
+ import asyncio
3
4
  import json
4
5
  from abc import ABC, abstractmethod
5
6
  from typing import Any, AsyncIterator, Dict, List, Optional
@@ -27,9 +28,18 @@ class BaseSQLAgent(ABC):
27
28
 
28
29
  @abstractmethod
29
30
  async def query_stream(
30
- self, user_query: str, use_history: bool = True
31
+ self,
32
+ user_query: str,
33
+ use_history: bool = True,
34
+ cancellation_token: asyncio.Event | None = None,
31
35
  ) -> AsyncIterator[StreamEvent]:
32
- """Process a user query and stream responses."""
36
+ """Process a user query and stream responses.
37
+
38
+ Args:
39
+ user_query: The user's query to process
40
+ use_history: Whether to include conversation history
41
+ cancellation_token: Optional event to signal cancellation
42
+ """
33
43
  pass
34
44
 
35
45
  def clear_history(self):
@@ -86,7 +96,7 @@ class BaseSQLAgent(ABC):
86
96
  except Exception as e:
87
97
  return json.dumps({"error": f"Error listing tables: {str(e)}"})
88
98
 
89
- async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
99
+ async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
90
100
  """Execute a SQL query against the database."""
91
101
  try:
92
102
  # Security check - only allow SELECT queries unless write is enabled
@@ -0,0 +1,172 @@
1
+ """Command line completers for the CLI interface."""
2
+
3
+ from typing import List, Tuple
4
+
5
+ from prompt_toolkit.completion import Completer, Completion
6
+
7
+
8
+ class SlashCommandCompleter(Completer):
9
+ """Custom completer for slash commands."""
10
+
11
+ def get_completions(self, document, complete_event):
12
+ """Get completions for slash commands."""
13
+ # Only provide completions if the line starts with "/"
14
+ text = document.text
15
+ if text.startswith("/"):
16
+ # Get the partial command after the slash
17
+ partial_cmd = text[1:]
18
+
19
+ # Define available commands with descriptions
20
+ commands = [
21
+ ("clear", "Clear conversation history"),
22
+ ("exit", "Exit the interactive session"),
23
+ ("quit", "Exit the interactive session"),
24
+ ]
25
+
26
+ # Yield completions that match the partial command
27
+ for cmd, description in commands:
28
+ if cmd.startswith(partial_cmd):
29
+ yield Completion(
30
+ cmd,
31
+ start_position=-len(partial_cmd),
32
+ display_meta=description,
33
+ )
34
+
35
+
36
+ class TableNameCompleter(Completer):
37
+ """Custom completer for table names."""
38
+
39
+ def __init__(self):
40
+ self._table_cache: List[Tuple[str, str]] = []
41
+
42
+ def update_cache(self, tables_data: List[Tuple[str, str]]):
43
+ """Update the cache with fresh table data."""
44
+ self._table_cache = tables_data
45
+
46
+ def _get_table_names(self) -> List[Tuple[str, str]]:
47
+ """Get table names from cache."""
48
+ return self._table_cache
49
+
50
+ def get_completions(self, document, complete_event):
51
+ """Get completions for table names with fuzzy matching."""
52
+ text = document.text
53
+ cursor_position = document.cursor_position
54
+
55
+ # Find the last "@" before the cursor position
56
+ at_pos = text.rfind("@", 0, cursor_position)
57
+
58
+ if at_pos >= 0:
59
+ # Extract text after the "@" up to the cursor
60
+ partial_table = text[at_pos + 1 : cursor_position].lower()
61
+
62
+ # Check if this looks like a valid table reference context
63
+ # (not inside quotes, and followed by word characters or end of input)
64
+ if self._is_valid_table_context(text, at_pos, cursor_position):
65
+ # Get table names
66
+ tables = self._get_table_names()
67
+
68
+ # Collect matches with scores for ranking
69
+ matches = []
70
+
71
+ for table_name, description in tables:
72
+ table_lower = table_name.lower()
73
+ score = self._calculate_match_score(
74
+ partial_table, table_name, table_lower
75
+ )
76
+
77
+ if score > 0:
78
+ matches.append((score, table_name, description))
79
+
80
+ # Sort by score (higher is better) and yield completions
81
+ matches.sort(key=lambda x: x[0], reverse=True)
82
+
83
+ for score, table_name, description in matches:
84
+ yield Completion(
85
+ table_name,
86
+ start_position=at_pos
87
+ + 1
88
+ - cursor_position, # Start from after the @
89
+ display_meta=description if description else None,
90
+ )
91
+
92
+ def _is_valid_table_context(self, text: str, at_pos: int, cursor_pos: int) -> bool:
93
+ """Check if the @ is in a valid context for table completion."""
94
+ # Simple heuristic: avoid completion inside quoted strings
95
+
96
+ # Count quotes before the @ position
97
+ single_quotes = text[:at_pos].count("'") - text[:at_pos].count("\\'")
98
+ double_quotes = text[:at_pos].count('"') - text[:at_pos].count('\\"')
99
+
100
+ # If we're inside quotes, don't complete
101
+ if single_quotes % 2 == 1 or double_quotes % 2 == 1:
102
+ return False
103
+
104
+ # Check if the character after the cursor (if any) is part of a word
105
+ # This helps avoid breaking existing words
106
+ if cursor_pos < len(text):
107
+ next_char = text[cursor_pos]
108
+ if next_char.isalnum() or next_char == "_":
109
+ # We're in the middle of a word, check if it looks like a table name
110
+ partial = (
111
+ text[at_pos + 1 :].split()[0] if text[at_pos + 1 :].split() else ""
112
+ )
113
+ if not any(c in partial for c in [".", "_"]):
114
+ return False
115
+
116
+ return True
117
+
118
+ def _calculate_match_score(
119
+ self, partial: str, table_name: str, table_lower: str
120
+ ) -> int:
121
+ """Calculate match score for fuzzy matching (higher is better)."""
122
+ if not partial:
123
+ return 1 # Empty search matches everything with low score
124
+
125
+ # Score 100: Exact full name prefix match
126
+ if table_lower.startswith(partial):
127
+ return 100
128
+
129
+ # Score 90: Table name (after schema) prefix match
130
+ if "." in table_name:
131
+ table_part = table_name.split(".")[-1].lower()
132
+ if table_part.startswith(partial):
133
+ return 90
134
+
135
+ # Score 80: Exact table name match (for short names)
136
+ if "." in table_name:
137
+ table_part = table_name.split(".")[-1].lower()
138
+ if table_part == partial:
139
+ return 80
140
+
141
+ # Score 70: Word boundary matches (e.g., "user" matches "user_accounts")
142
+ if "." in table_name:
143
+ table_part = table_name.split(".")[-1].lower()
144
+ if table_part.startswith(partial + "_") or table_part.startswith(
145
+ partial + "-"
146
+ ):
147
+ return 70
148
+
149
+ # Score 50: Substring match in table name part
150
+ if "." in table_name:
151
+ table_part = table_name.split(".")[-1].lower()
152
+ if partial in table_part:
153
+ return 50
154
+
155
+ # Score 30: Substring match in full name
156
+ if partial in table_lower:
157
+ return 30
158
+
159
+ # Score 0: No match
160
+ return 0
161
+
162
+
163
+ class CompositeCompleter(Completer):
164
+ """Combines multiple completers."""
165
+
166
+ def __init__(self, *completers: Completer):
167
+ self.completers = completers
168
+
169
+ def get_completions(self, document, complete_event):
170
+ """Get completions from all registered completers."""
171
+ for completer in self.completers:
172
+ yield from completer.get_completions(document, complete_event)
sqlsaber/cli/display.py CHANGED
@@ -4,6 +4,7 @@ import json
4
4
  from typing import Optional
5
5
 
6
6
  from rich.console import Console
7
+ from rich.markdown import Markdown
7
8
  from rich.syntax import Syntax
8
9
  from rich.table import Table
9
10
 
@@ -62,12 +63,20 @@ class DisplayManager:
62
63
  )
63
64
 
64
65
  # Create table with columns from first result
65
- columns = list(results[0].keys())
66
- table = self._create_table(columns)
66
+ all_columns = list(results[0].keys())
67
+ display_columns = all_columns[:15] # Limit to first 15 columns
68
+
69
+ # Show warning if columns were truncated
70
+ if len(all_columns) > 15:
71
+ self.console.print(
72
+ f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
73
+ )
74
+
75
+ table = self._create_table(display_columns)
67
76
 
68
77
  # Add rows (show first 20 rows)
69
78
  for row in results[:20]:
70
- table.add_row(*[str(row[key]) for key in columns])
79
+ table.add_row(*[str(row[key]) for key in display_columns])
71
80
 
72
81
  self.console.print(table)
73
82
 
@@ -235,3 +244,24 @@ class DisplayManager:
235
244
  self.show_error("Failed to parse plot result")
236
245
  except Exception as e:
237
246
  self.show_error(f"Error displaying plot: {str(e)}")
247
+
248
+ def show_markdown_response(self, content: list):
249
+ """Display the assistant's response as rich markdown."""
250
+ if not content:
251
+ return
252
+
253
+ # Extract text from content blocks
254
+ text_parts = []
255
+ for block in content:
256
+ if isinstance(block, dict) and block.get("type") == "text":
257
+ text = block.get("text", "")
258
+ if text:
259
+ text_parts.append(text)
260
+
261
+ # Join all text parts and display as markdown
262
+ full_text = "".join(text_parts).strip()
263
+ if full_text:
264
+ self.console.print() # Add spacing before markdown
265
+ markdown = Markdown(full_text)
266
+ self.console.print(markdown)
267
+ self.console.print() # Add spacing after markdown
@@ -1,10 +1,18 @@
1
1
  """Interactive mode handling for the CLI."""
2
2
 
3
+ import asyncio
4
+ from typing import Optional
5
+
3
6
  import questionary
4
7
  from rich.console import Console
5
8
  from rich.panel import Panel
6
9
 
7
10
  from sqlsaber.agents.base import BaseSQLAgent
11
+ from sqlsaber.cli.completers import (
12
+ CompositeCompleter,
13
+ SlashCommandCompleter,
14
+ TableNameCompleter,
15
+ )
8
16
  from sqlsaber.cli.display import DisplayManager
9
17
  from sqlsaber.cli.streaming import StreamingQueryHandler
10
18
 
@@ -17,6 +25,9 @@ class InteractiveSession:
17
25
  self.agent = agent
18
26
  self.display = DisplayManager(console)
19
27
  self.streaming_handler = StreamingQueryHandler(console)
28
+ self.current_task: Optional[asyncio.Task] = None
29
+ self.cancellation_token: Optional[asyncio.Event] = None
30
+ self.table_completer = TableNameCompleter()
20
31
 
21
32
  def show_welcome_message(self):
22
33
  """Display welcome message for interactive mode."""
@@ -28,8 +39,9 @@ class InteractiveSession:
28
39
  Panel.fit(
29
40
  "[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
30
41
  "[bold]Your agentic SQL assistant.[/bold]\n\n\n"
31
- "[dim]Use 'clear' to reset conversation, 'exit' or 'quit' to leave.[/dim]\n\n"
32
- "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]",
42
+ "[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
43
+ "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
44
+ "[dim]Type '@' to get table name completions.[/dim]",
33
45
  border_style="green",
34
46
  )
35
47
  )
@@ -38,12 +50,71 @@ class InteractiveSession:
38
50
  )
39
51
  self.console.print(
40
52
  "[dim]Press Esc-Enter or Meta-Enter to submit your query.[/dim]\n"
53
+ "[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
41
54
  )
42
55
 
56
+ async def _update_table_cache(self):
57
+ """Update the table completer cache with fresh data."""
58
+ try:
59
+ # Use the schema manager directly which has built-in caching
60
+ tables_data = await self.agent.schema_manager.list_tables()
61
+
62
+ # Parse the table information
63
+ table_list = []
64
+ if isinstance(tables_data, dict) and "tables" in tables_data:
65
+ for table in tables_data["tables"]:
66
+ if isinstance(table, dict):
67
+ name = table.get("name", "")
68
+ schema = table.get("schema", "")
69
+ full_name = table.get("full_name", "")
70
+
71
+ # Use full_name if available, otherwise construct it
72
+ if full_name:
73
+ table_name = full_name
74
+ elif schema and schema != "main":
75
+ table_name = f"{schema}.{name}"
76
+ else:
77
+ table_name = name
78
+
79
+ # No description needed - cleaner completions
80
+ table_list.append((table_name, ""))
81
+
82
+ # Update the completer cache
83
+ self.table_completer.update_cache(table_list)
84
+
85
+ except Exception:
86
+ # If there's an error, just use empty cache
87
+ self.table_completer.update_cache([])
88
+
89
+ async def _execute_query_with_cancellation(self, user_query: str):
90
+ """Execute a query with cancellation support."""
91
+ # Create cancellation token
92
+ self.cancellation_token = asyncio.Event()
93
+
94
+ # Create the query task
95
+ query_task = asyncio.create_task(
96
+ self.streaming_handler.execute_streaming_query(
97
+ user_query, self.agent, self.cancellation_token
98
+ )
99
+ )
100
+ self.current_task = query_task
101
+
102
+ try:
103
+ # Simply await the query task
104
+ # Ctrl+C will be handled by the KeyboardInterrupt exception in run()
105
+ await query_task
106
+
107
+ finally:
108
+ self.current_task = None
109
+ self.cancellation_token = None
110
+
43
111
  async def run(self):
44
112
  """Run the interactive session loop."""
45
113
  self.show_welcome_message()
46
114
 
115
+ # Initialize table cache
116
+ await self._update_table_cache()
117
+
47
118
  while True:
48
119
  try:
49
120
  user_query = await questionary.text(
@@ -51,12 +122,18 @@ class InteractiveSession:
51
122
  qmark="",
52
123
  multiline=True,
53
124
  instruction="",
125
+ completer=CompositeCompleter(
126
+ SlashCommandCompleter(), self.table_completer
127
+ ),
54
128
  ).ask_async()
55
129
 
56
- if user_query.lower() in ["exit", "quit", "q"]:
130
+ if not user_query:
131
+ continue
132
+
133
+ if user_query in ["/exit", "/quit"]:
57
134
  break
58
135
 
59
- if user_query.lower() == "clear":
136
+ if user_query == "/clear":
60
137
  self.agent.clear_history()
61
138
  self.console.print("[green]Conversation history cleared.[/green]\n")
62
139
  continue
@@ -85,12 +162,24 @@ class InteractiveSession:
85
162
  )
86
163
  continue
87
164
 
88
- await self.streaming_handler.execute_streaming_query(
89
- user_query, self.agent
90
- )
165
+ # Execute query with cancellation support
166
+ await self._execute_query_with_cancellation(user_query)
91
167
  self.display.show_newline() # Empty line for readability
92
168
 
93
169
  except KeyboardInterrupt:
94
- self.console.print("\n[yellow]Use 'exit' or 'quit' to leave.[/yellow]")
170
+ # Handle Ctrl+C - cancel current task if running
171
+ if self.current_task and not self.current_task.done():
172
+ if self.cancellation_token is not None:
173
+ self.cancellation_token.set()
174
+ self.current_task.cancel()
175
+ try:
176
+ await self.current_task
177
+ except asyncio.CancelledError:
178
+ pass
179
+ self.console.print("\n[yellow]Query interrupted[/yellow]")
180
+ else:
181
+ self.console.print(
182
+ "\n[yellow]Use '/exit' or '/quit' to leave.[/yellow]"
183
+ )
95
184
  except Exception as e:
96
185
  self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
sqlsaber/cli/streaming.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """Streaming query handling for the CLI."""
2
2
 
3
+ import asyncio
4
+
3
5
  from rich.console import Console
4
6
 
5
7
  from sqlsaber.agents.base import BaseSQLAgent
@@ -13,7 +15,12 @@ class StreamingQueryHandler:
13
15
  self.console = console
14
16
  self.display = DisplayManager(console)
15
17
 
16
- async def execute_streaming_query(self, user_query: str, agent: BaseSQLAgent):
18
+ async def execute_streaming_query(
19
+ self,
20
+ user_query: str,
21
+ agent: BaseSQLAgent,
22
+ cancellation_token: asyncio.Event | None = None,
23
+ ):
17
24
  """Execute a query with streaming display."""
18
25
 
19
26
  has_content = False
@@ -24,7 +31,12 @@ class StreamingQueryHandler:
24
31
  status.start()
25
32
 
26
33
  try:
27
- async for event in agent.query_stream(user_query):
34
+ async for event in agent.query_stream(
35
+ user_query, cancellation_token=cancellation_token
36
+ ):
37
+ if cancellation_token is not None and cancellation_token.is_set():
38
+ break
39
+
28
40
  if event.type == "tool_use":
29
41
  # Stop any ongoing status, but don't mark has_content yet
30
42
  self._stop_status(status)
@@ -83,6 +95,13 @@ class StreamingQueryHandler:
83
95
  has_content = True
84
96
  self.display.show_error(event.data)
85
97
 
98
+ except asyncio.CancelledError:
99
+ # Handle cancellation gracefully
100
+ self._stop_status(status)
101
+ if explanation_started:
102
+ self.display.show_newline()
103
+ self.console.print("[yellow]Query interrupted[/yellow]")
104
+ return
86
105
  finally:
87
106
  # Make sure status is stopped
88
107
  self._stop_status(status)
@@ -91,6 +110,14 @@ class StreamingQueryHandler:
91
110
  if explanation_started:
92
111
  self.display.show_newline() # Empty line for better readability
93
112
 
113
+ # Display the last assistant response as markdown
114
+ if hasattr(agent, "conversation_history") and agent.conversation_history:
115
+ last_message = agent.conversation_history[-1]
116
+ if last_message.get("role") == "assistant" and last_message.get(
117
+ "content"
118
+ ):
119
+ self.display.show_markdown_response(last_message["content"])
120
+
94
121
  def _stop_status(self, status):
95
122
  """Safely stop a status spinner."""
96
123
  try:
@@ -683,6 +683,13 @@ class SchemaManager:
683
683
 
684
684
  async def list_tables(self) -> Dict[str, Any]:
685
685
  """Get a list of all tables with basic information like row counts."""
686
+ # Check cache first
687
+ cache_key = "list_tables"
688
+ cached_data = self._get_cached_tables(cache_key)
689
+ if cached_data is not None:
690
+ return cached_data
691
+
692
+ # Fetch from database if not cached
686
693
  tables = await self.introspector.list_tables_info(self.db)
687
694
 
688
695
  # Format the result
@@ -699,4 +706,14 @@ class SchemaManager:
699
706
  }
700
707
  )
701
708
 
709
+ # Cache the result
710
+ self._schema_cache[cache_key] = (time.time(), result)
702
711
  return result
712
+
713
+ def _get_cached_tables(self, cache_key: str) -> Optional[Dict[str, Any]]:
714
+ """Get table list from cache if available and not expired."""
715
+ if cache_key in self._schema_cache:
716
+ cached_time, cached_data = self._schema_cache[cache_key]
717
+ if time.time() - cached_time < self.cache_ttl:
718
+ return cached_data
719
+ return None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlsaber
3
- Version: 0.5.0
3
+ Version: 0.7.0
4
4
  Summary: SQLSaber - Agentic SQL assistant like Claude Code
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -212,23 +212,24 @@ The MCP server uses your existing SQLSaber database configurations, so make sure
212
212
 
213
213
  ## How It Works
214
214
 
215
- SQLSaber uses an intelligent three-step process optimized for minimal token usage:
215
+ SQLSaber uses a multi-step process to gather the right context, provide it to the model, and execute SQL queries to get the right answers:
216
+
217
+ ![](./sqlsaber.svg)
216
218
 
217
219
  ### 🔍 Discovery Phase
218
220
 
219
221
  1. **List Tables Tool**: Quickly discovers available tables with row counts
220
- 2. **Pattern Matching**: Identifies relevant tables based on your query using SQL LIKE patterns
222
+ 2. **Pattern Matching**: Identifies relevant tables based on your query
221
223
 
222
224
  ### 📋 Schema Analysis
223
225
 
224
- 3. **Smart Introspection**: Analyzes only the specific table structures needed for your query
225
- 4. **Selective Loading**: Fetches schema information only for relevant tables
226
+ 3. **Smart Schema Introspection**: Analyzes only the specific table structures needed for your query
226
227
 
227
228
  ### ⚡ Execution Phase
228
229
 
229
- 5. **SQL Generation**: Creates optimized SQL queries based on natural language input
230
- 6. **Safe Execution**: Runs queries with built-in protections against destructive operations
231
- 7. **Result Formatting**: Presents results with syntax highlighting and explanations
230
+ 4. **SQL Generation**: Creates optimized SQL queries based on natural language input
231
+ 5. **Safe Execution**: Runs read-only queries with built-in protections against destructive operations
232
+ 6. **Result Formatting**: Presents results with explanations in tables and optionally, visualizes using plots
232
233
 
233
234
  ## Contributing
234
235
 
@@ -1,25 +1,26 @@
1
1
  sqlsaber/__init__.py,sha256=QCFi8xTVMohelfi7zOV1-6oLCcGoiXoOcKQY-HNBCk8,66
2
2
  sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
3
3
  sqlsaber/agents/__init__.py,sha256=LWeSeEUE4BhkyAYFF3TE-fx8TtLud3oyEtyB8ojFJgo,167
4
- sqlsaber/agents/anthropic.py,sha256=HxpMV5uiOmxPAUR-ZZw8ncK7lM0aCUvj8mLiyFaUFwE,16268
5
- sqlsaber/agents/base.py,sha256=IYVYYdQgSOQV6o1_3bzOizRuVMDw_aXfiPdFMNHwpg0,10269
4
+ sqlsaber/agents/anthropic.py,sha256=FLVET2HvFmsEuFln9Hu4SaBs-Tnk-GestOgnDnUp3ps,17885
5
+ sqlsaber/agents/base.py,sha256=DAnezHl5RLYoef8XQ-n3KA9PowdrMbQrkjdGKPPnFsI,10570
6
6
  sqlsaber/agents/mcp.py,sha256=FKtXgDrPZ2-xqUYCw2baI5JzrWekXaC5fjkYW1_Mg50,827
7
7
  sqlsaber/agents/streaming.py,sha256=_EO390-FHUrL1fRCNfibtE9QuJz3LGQygbwG3CB2ViY,533
8
8
  sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
9
9
  sqlsaber/cli/commands.py,sha256=Dw24W0jij-8t1lpk99C4PBTgzFSag6vU-FZcjAYGG54,5074
10
+ sqlsaber/cli/completers.py,sha256=JWOCKAm0Prpy_O2QJsf_VbPWfy2lQQh6KutyG8FU4us,6462
10
11
  sqlsaber/cli/database.py,sha256=DUfyvNBDp47oFM_VAC_hXHQy_qyE7JbXtowflJpwwH8,12643
11
- sqlsaber/cli/display.py,sha256=LhsUSAFbiPBQRtW2JFf8PnpDnF2_kYdVTsB9HYgvxT4,8888
12
- sqlsaber/cli/interactive.py,sha256=Kqe7kN9mhUiY_5z1Ki6apZ9ahs8uzhHp3xMZGiyTXpY,3912
12
+ sqlsaber/cli/display.py,sha256=NIBWHUrX_8ZhDu6iW9v4fzx0zncnXa5WdQ9wfTrjKIM,10017
13
+ sqlsaber/cli/interactive.py,sha256=FvgtT45U-yblhbRImKqJ4jgBRNs0u7NhE2PcgoVUaVA,7429
13
14
  sqlsaber/cli/memory.py,sha256=LW4ZF2V6Gw6hviUFGZ4ym9ostFCwucgBTIMZ3EANO-I,7671
14
15
  sqlsaber/cli/models.py,sha256=3IcXeeU15IQvemSv-V-RQzVytJ3wuQ4YmWk89nTDcSE,7813
15
- sqlsaber/cli/streaming.py,sha256=EpltnkdokN42bczULbP9u_t8zduwhGyV-TWm1h8H-jc,3975
16
+ sqlsaber/cli/streaming.py,sha256=DfwygmjEzAh9hZGKjrW9kS1A7MG5W9Ky_kCTzxziODQ,4970
16
17
  sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
17
18
  sqlsaber/config/api_keys.py,sha256=kLdoExF_My9ojmdhO5Ca7-ZeowsO0v1GVa_QT5jjUPo,3658
18
19
  sqlsaber/config/database.py,sha256=vKFOxPjVakjQhj1uoLcfzhS9ZFr6Z2F5b4MmYALQZoA,11421
19
20
  sqlsaber/config/settings.py,sha256=zjQ7nS3ybcCb88Ea0tmwJox5-q0ettChZw89ZqRVpX8,3975
20
21
  sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
21
22
  sqlsaber/database/connection.py,sha256=s8GSFZebB8be8sVUr-N0x88-20YfkfljJFRyfoB1gH0,15154
22
- sqlsaber/database/schema.py,sha256=9QoH-gADzWlepq-tGz3nPU3miSUU0koWmpDaoWvz8Q0,27951
23
+ sqlsaber/database/schema.py,sha256=3CfkyhxgD6SmiUoz7MQPlQLrrA007HOQLnGCvvsdJx0,28647
23
24
  sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
24
25
  sqlsaber/mcp/mcp.py,sha256=ACm1P1TnicjOptQgeLNhXg5xgZf4MYq2kqdfVdj6wh0,4477
25
26
  sqlsaber/memory/__init__.py,sha256=GiWkU6f6YYVV0EvvXDmFWe_CxarmDCql05t70MkTEWs,63
@@ -28,8 +29,8 @@ sqlsaber/memory/storage.py,sha256=DvZBsSPaAfk_DqrNEn86uMD-TQsWUI6rQLfNw6PSCB8,57
28
29
  sqlsaber/models/__init__.py,sha256=RJ7p3WtuSwwpFQ1Iw4_DHV2zzCtHqIzsjJzxv8kUjUE,287
29
30
  sqlsaber/models/events.py,sha256=q2FackB60J9-7vegYIjzElLwKebIh7nxnV5AFoZc67c,752
30
31
  sqlsaber/models/types.py,sha256=3U_30n91EB3IglBTHipwiW4MqmmaA2qfshfraMZyPps,896
31
- sqlsaber-0.5.0.dist-info/METADATA,sha256=npxser6DO4GaHpYKsdcHvosseRr7RU0pZ0c81h2t2KA,5969
32
- sqlsaber-0.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
33
- sqlsaber-0.5.0.dist-info/entry_points.txt,sha256=jmFo96Ylm0zIKXJBwhv_P5wQ7SXP9qdaBcnTp8iCEe8,195
34
- sqlsaber-0.5.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
- sqlsaber-0.5.0.dist-info/RECORD,,
32
+ sqlsaber-0.7.0.dist-info/METADATA,sha256=tUV3WHkVZEXissVrKAaOooaZyn7e_NmMV_e-nNaoLVE,5986
33
+ sqlsaber-0.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
34
+ sqlsaber-0.7.0.dist-info/entry_points.txt,sha256=jmFo96Ylm0zIKXJBwhv_P5wQ7SXP9qdaBcnTp8iCEe8,195
35
+ sqlsaber-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
36
+ sqlsaber-0.7.0.dist-info/RECORD,,