sqlsaber-0.6.0-py3-none-any.whl → sqlsaber-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

sqlsaber/agents/base.py CHANGED
@@ -3,7 +3,7 @@
3
3
  import asyncio
4
4
  import json
5
5
  from abc import ABC, abstractmethod
6
- from typing import Any, AsyncIterator, Dict, List, Optional
6
+ from typing import Any, AsyncIterator
7
7
 
8
8
  from uniplot import histogram, plot
9
9
 
@@ -24,7 +24,7 @@ class BaseSQLAgent(ABC):
24
24
  def __init__(self, db_connection: BaseDatabaseConnection):
25
25
  self.db = db_connection
26
26
  self.schema_manager = SchemaManager(db_connection)
27
- self.conversation_history: List[Dict[str, Any]] = []
27
+ self.conversation_history: list[dict[str, Any]] = []
28
28
 
29
29
  @abstractmethod
30
30
  async def query_stream(
@@ -59,7 +59,7 @@ class BaseSQLAgent(ABC):
59
59
  else:
60
60
  return "database" # Fallback
61
61
 
62
- async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
62
+ async def introspect_schema(self, table_pattern: str | None = None) -> str:
63
63
  """Introspect database schema to understand table structures."""
64
64
  try:
65
65
  # Pass table_pattern to get_schema_info for efficient filtering at DB level
@@ -96,7 +96,7 @@ class BaseSQLAgent(ABC):
96
96
  except Exception as e:
97
97
  return json.dumps({"error": f"Error listing tables: {str(e)}"})
98
98
 
99
- async def execute_sql(self, query: str, limit: Optional[int] = None) -> str:
99
+ async def execute_sql(self, query: str, limit: int | None = None) -> str:
100
100
  """Execute a SQL query against the database."""
101
101
  try:
102
102
  # Security check - only allow SELECT queries unless write is enabled
@@ -147,7 +147,7 @@ class BaseSQLAgent(ABC):
147
147
  return json.dumps({"error": error_msg, "suggestions": suggestions})
148
148
 
149
149
  async def process_tool_call(
150
- self, tool_name: str, tool_input: Dict[str, Any]
150
+ self, tool_name: str, tool_input: dict[str, Any]
151
151
  ) -> str:
152
152
  """Process a tool call and return the result."""
153
153
  if tool_name == "list_tables":
@@ -170,7 +170,7 @@ class BaseSQLAgent(ABC):
170
170
  else:
171
171
  return json.dumps({"error": f"Unknown tool: {tool_name}"})
172
172
 
173
- def _validate_write_operation(self, query: str) -> Optional[str]:
173
+ def _validate_write_operation(self, query: str) -> str | None:
174
174
  """Validate if a write operation is allowed.
175
175
 
176
176
  Returns:
@@ -206,12 +206,12 @@ class BaseSQLAgent(ABC):
206
206
 
207
207
  async def plot_data(
208
208
  self,
209
- y_values: List[float],
210
- x_values: Optional[List[float]] = None,
209
+ y_values: list[float],
210
+ x_values: list[float] | None = None,
211
211
  plot_type: str = "line",
212
- title: Optional[str] = None,
213
- x_label: Optional[str] = None,
214
- y_label: Optional[str] = None,
212
+ title: str | None = None,
213
+ x_label: str | None = None,
214
+ y_label: str | None = None,
215
215
  ) -> str:
216
216
  """Create a terminal plot using uniplot.
217
217
 
@@ -1,16 +1,16 @@
1
1
  """Streaming utilities for agents."""
2
2
 
3
- from typing import Any, Dict, List
3
+ from typing import Any
4
4
 
5
5
 
6
6
  class StreamingResponse:
7
7
  """Helper class to manage streaming response construction."""
8
8
 
9
- def __init__(self, content: List[Dict[str, Any]], stop_reason: str):
9
+ def __init__(self, content: list[dict[str, Any]], stop_reason: str):
10
10
  self.content = content
11
11
  self.stop_reason = stop_reason
12
12
 
13
13
 
14
- def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
14
+ def build_tool_result_block(tool_use_id: str, content: str) -> dict[str, Any]:
15
15
  """Build a tool result block for the conversation."""
16
16
  return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
sqlsaber/cli/auth.py ADDED
@@ -0,0 +1,142 @@
1
+ """Authentication CLI commands."""
2
+
3
+ import questionary
4
+ import typer
5
+ from rich.console import Console
6
+
7
+ from sqlsaber.config.auth import AuthConfigManager, AuthMethod
8
+ from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
9
+
10
# Shared objects reused by every `auth` sub-command in this module
console = Console()
config_manager = AuthConfigManager()

# Typer sub-application that groups the authentication commands
auth_app = typer.Typer(
    help="Manage authentication configuration",
    name="auth",
    add_completion=True,
)
20
+
21
+
22
@auth_app.command("setup")
def setup_auth():
    """Interactively choose and store the authentication method.

    Prompts with a questionary select list. Choosing the API key option
    stores the method immediately; choosing the subscription option runs
    ``AnthropicOAuthFlow.authenticate()`` first and stores the method only
    when that call reports success. Cancelling the prompt, a failed OAuth
    attempt, or an exception during the OAuth branch ends the command after
    printing a message.
    """
    console.print("\n[bold]SQLSaber Authentication Setup[/bold]\n")

    # Build the two menu entries up front so the select call stays short
    api_key_option = questionary.Choice(
        title="Anthropic API Key",
        value=AuthMethod.API_KEY,
        description="You can create one by visiting https://console.anthropic.com",
    )
    subscription_option = questionary.Choice(
        title="Claude Pro or Max Subscription",
        value=AuthMethod.CLAUDE_PRO,
        description="This does not require creating an API Key, but requires a subscription at https://claude.ai",
    )
    selected = questionary.select(
        "Choose your authentication method:",
        choices=[api_key_option, subscription_option],
    ).ask()

    # .ask() yields None when the prompt is aborted
    if selected is None:
        console.print("[yellow]Setup cancelled.[/yellow]")
        return

    if selected == AuthMethod.API_KEY:
        console.print("\nTo configure your API key, you can either:")
        console.print("• Set the ANTHROPIC_API_KEY environment variable")
        console.print(
            "• Let SQLsaber prompt you for the key when needed (stored securely)"
        )

        config_manager.set_auth_method(selected)
        console.print("\n[bold green]Authentication method saved![/bold green]")

    elif selected == AuthMethod.CLAUDE_PRO:
        flow = AnthropicOAuthFlow()
        try:
            if not flow.authenticate():
                console.print(
                    "\n[yellow]OAuth authentication failed. Please try again.[/yellow]"
                )
                return

            # Persist the choice only after the OAuth flow succeeded
            config_manager.set_auth_method(selected)
            console.print(
                "\n[bold green]Authentication setup complete![/bold green]"
            )
        except Exception as exc:
            console.print(f"\n[red]Authentication setup failed: {exc!s}[/red]")
            return

    console.print(
        "You can change this anytime by running [cyan]saber auth setup[/cyan] again."
    )
80
+
81
+
82
@auth_app.command("status")
def show_auth_status():
    """Print which authentication method is configured, if any.

    For the subscription method it additionally asks ``AnthropicOAuthFlow``
    whether a valid authentication is present and reports the answer.
    """
    configured = config_manager.get_auth_method()

    console.print("\n[bold blue]Authentication Status[/bold blue]")

    # Nothing stored yet: point the user at the setup command
    if configured is None:
        console.print("[yellow]No authentication method configured[/yellow]")
        console.print("Run [cyan]saber auth setup[/cyan] to configure authentication.")
        return

    if configured == AuthMethod.API_KEY:
        console.print("[green]✓ API Key authentication configured[/green]")
        console.print("Using Anthropic API key for authentication")
        return

    if configured == AuthMethod.CLAUDE_PRO:
        console.print("[green]✓ Claude Pro/Max subscription configured[/green]")

        # Report whether the stored OAuth credentials are still usable
        flow = AnthropicOAuthFlow()
        token_message = (
            "OAuth token is valid and ready to use"
            if flow.has_valid_authentication()
            else "[yellow]OAuth token missing or expired[/yellow]"
        )
        console.print(token_message)
105
+
106
+
107
@auth_app.command("reset")
def reset_auth():
    """Clear the stored authentication method after user confirmation.

    When the current method is the subscription one, the OAuth
    authentication is removed as well before the config entry is cleared.
    """
    if not config_manager.has_auth_configured():
        console.print("[yellow]No authentication configuration to reset.[/yellow]")
        return

    current_method = config_manager.get_auth_method()
    if current_method == AuthMethod.API_KEY:
        method_name = "API Key"
    else:
        method_name = "Claude Pro/Max"

    confirmed = questionary.confirm(
        f"Are you sure you want to reset the current authentication method ({method_name})?",
        default=False,
    ).ask()

    # A declined or aborted prompt leaves everything untouched
    if not confirmed:
        console.print("Reset cancelled.")
        return

    # Subscription auth also has OAuth state that must be removed
    if current_method == AuthMethod.CLAUDE_PRO:
        AnthropicOAuthFlow().remove_authentication()

    # Overwrite the stored method with None and write the config back
    stored = config_manager._load_config()
    stored["auth_method"] = None
    config_manager._save_config(stored)

    console.print("[green]Authentication configuration reset.[/green]")
    console.print(
        "Run [cyan]saber auth setup[/cyan] to configure authentication again."
    )
138
+
139
+
140
def create_auth_app() -> typer.Typer:
    """Expose the module-level ``auth`` Typer app so callers can mount it."""
    app_for_auth = auth_app
    return app_for_auth
sqlsaber/cli/commands.py CHANGED
@@ -2,12 +2,12 @@
2
2
 
3
3
  import asyncio
4
4
  from pathlib import Path
5
- from typing import Optional
6
5
 
7
6
  import typer
8
7
  from rich.console import Console
9
8
 
10
9
  from sqlsaber.agents.anthropic import AnthropicSQLAgent
10
+ from sqlsaber.cli.auth import create_auth_app
11
11
  from sqlsaber.cli.database import create_db_app
12
12
  from sqlsaber.cli.interactive import InteractiveSession
13
13
  from sqlsaber.cli.memory import create_memory_app
@@ -29,7 +29,7 @@ config_manager = DatabaseConfigManager()
29
29
 
30
30
  @app.callback()
31
31
  def main_callback(
32
- database: Optional[str] = typer.Option(
32
+ database: str | None = typer.Option(
33
33
  None,
34
34
  "--database",
35
35
  "-d",
@@ -49,11 +49,11 @@ def main_callback(
49
49
 
50
50
  @app.command()
51
51
  def query(
52
- query_text: Optional[str] = typer.Argument(
52
+ query_text: str | None = typer.Argument(
53
53
  None,
54
54
  help="SQL query in natural language (if not provided, starts interactive mode)",
55
55
  ),
56
- database: Optional[str] = typer.Option(
56
+ database: str | None = typer.Option(
57
57
  None,
58
58
  "--database",
59
59
  "-d",
@@ -128,6 +128,7 @@ def query(
128
128
 
129
129
  finally:
130
130
  # Clean up
131
+ await agent.close() # Close the agent's HTTP client
131
132
  await db_conn.close()
132
133
  console.print("\n[green]Goodbye![/green]")
133
134
 
@@ -135,6 +136,10 @@ def query(
135
136
  asyncio.run(run_session())
136
137
 
137
138
 
139
+ # Add authentication management commands
140
+ auth_app = create_auth_app()
141
+ app.add_typer(auth_app, name="auth")
142
+
138
143
  # Add database management commands after main callback is defined
139
144
  db_app = create_db_app()
140
145
  app.add_typer(db_app, name="db")
@@ -0,0 +1,170 @@
1
+ """Command line completers for the CLI interface."""
2
+
3
+ from prompt_toolkit.completion import Completer, Completion
4
+
5
+
6
class SlashCommandCompleter(Completer):
    """Completer for slash commands typed at the start of the input."""

    # (command, description) pairs offered after a leading "/".
    _COMMANDS = (
        ("clear", "Clear conversation history"),
        ("exit", "Exit the interactive session"),
        ("quit", "Exit the interactive session"),
    )

    def get_completions(self, document, complete_event):
        """Yield completions for a partially typed slash command.

        Only the text before the cursor is matched. ``start_position`` is
        relative to the cursor, so matching against the whole buffer would
        produce an offset that reaches past the text actually in front of
        the cursor whenever the cursor is not at the end of the input.

        Args:
            document: The prompt_toolkit document being edited.
            complete_event: The triggering event (unused).

        Yields:
            One ``Completion`` per command whose name starts with the text
            typed after the leading "/".
        """
        typed = document.text_before_cursor
        if not typed.startswith("/"):
            return

        # Everything after the slash is the partial command name
        partial_cmd = typed[1:]

        for cmd, description in self._COMMANDS:
            if cmd.startswith(partial_cmd):
                yield Completion(
                    cmd,
                    start_position=-len(partial_cmd),
                    display_meta=description,
                )
32
+
33
+
34
class TableNameCompleter(Completer):
    """Completer that suggests table names after an "@" in the input.

    Candidates come from an in-memory cache of ``(table_name, description)``
    tuples that the owner refreshes through ``update_cache``.
    """

    def __init__(self):
        self._table_cache: list[tuple[str, str]] = []

    def update_cache(self, tables_data: list[tuple[str, str]]):
        """Replace the cached table list with ``tables_data``."""
        self._table_cache = tables_data

    def _get_table_names(self) -> list[tuple[str, str]]:
        """Return the cached ``(table_name, description)`` tuples."""
        return self._table_cache

    def get_completions(self, document, complete_event):
        """Yield table-name completions for the last "@" before the cursor.

        Matches are ranked by ``_calculate_match_score`` (highest first);
        tables scoring 0 are omitted. Each completion replaces the text
        between the "@" and the cursor.
        """
        text = document.text
        cursor_position = document.cursor_position

        # Locate the last "@" before the cursor; without one there is
        # nothing to complete.
        at_pos = text.rfind("@", 0, cursor_position)
        if at_pos < 0:
            return

        # Skip "@" characters inside quotes or glued to an ordinary word
        if not self._is_valid_table_context(text, at_pos, cursor_position):
            return

        partial_table = text[at_pos + 1 : cursor_position].lower()

        # Collect (score, name, description) for every matching table
        matches = []
        for table_name, description in self._get_table_names():
            score = self._calculate_match_score(
                partial_table, table_name, table_name.lower()
            )
            if score > 0:
                matches.append((score, table_name, description))

        # Stable sort on score only: equal scores keep cache order
        matches.sort(key=lambda match: match[0], reverse=True)

        for _score, table_name, description in matches:
            yield Completion(
                table_name,
                # Negative offset: replace from just after the "@"
                start_position=at_pos + 1 - cursor_position,
                display_meta=description if description else None,
            )

    def _is_valid_table_context(self, text: str, at_pos: int, cursor_pos: int) -> bool:
        """Return whether the "@" at ``at_pos`` should trigger completion.

        Heuristics:
        * an odd number of unescaped single or double quotes before the "@"
          means it sits inside a quoted string, so no completion;
        * if the character under the cursor continues a word, completion is
          allowed only when the word after the "@" contains "." or "_".
        """
        before_at = text[:at_pos]

        # Count quotes before the "@", discounting backslash-escaped ones
        single_quotes = before_at.count("'") - before_at.count("\\'")
        double_quotes = before_at.count('"') - before_at.count('\\"')
        if single_quotes % 2 == 1 or double_quotes % 2 == 1:
            return False

        if cursor_pos < len(text):
            next_char = text[cursor_pos]
            if next_char.isalnum() or next_char == "_":
                # Cursor is in the middle of a word; only continue when the
                # word after the "@" looks like a table name.
                words = text[at_pos + 1 :].split()
                partial = words[0] if words else ""
                if not any(c in partial for c in [".", "_"]):
                    return False

        return True

    def _calculate_match_score(
        self, partial: str, table_name: str, table_lower: str
    ) -> int:
        """Score how well ``partial`` matches a table (higher is better).

        Args:
            partial: Text typed after the "@" (callers pass it lower-cased).
            table_name: Table name as stored, possibly schema-qualified.
            table_lower: Lower-cased form of ``table_name``.

        Returns:
            1 for an empty ``partial``; 100 when the full name starts with
            ``partial``; 90 when the part after the last "." starts with it;
            50 when that part contains it; 30 when the full name contains
            it; 0 otherwise.
        """
        if not partial:
            return 1  # Empty search matches everything with low score

        # Prefix of the full (possibly schema-qualified) name
        if table_lower.startswith(partial):
            return 100

        if "." in table_name:
            table_part = table_name.split(".")[-1].lower()

            # Prefix of the bare table name. This also covers an exact
            # match and "partial_"/"partial-" word-boundary matches, which
            # therefore never needed separate (lower) scores.
            if table_part.startswith(partial):
                return 90

            # Substring of the bare table name
            if partial in table_part:
                return 50

        # Substring anywhere in the full name
        if partial in table_lower:
            return 30

        return 0
159
+
160
+
161
class CompositeCompleter(Completer):
    """Fan one completion request out to several completers, in order."""

    def __init__(self, *completers: Completer):
        self.completers = completers

    def get_completions(self, document, complete_event):
        """Yield every completion from each wrapped completer in turn."""
        for delegate in self.completers:
            for suggestion in delegate.get_completions(document, complete_event):
                yield suggestion
sqlsaber/cli/database.py CHANGED
@@ -3,7 +3,6 @@
3
3
  import asyncio
4
4
  import getpass
5
5
  from pathlib import Path
6
- from typing import Optional
7
6
 
8
7
  import questionary
9
8
  import typer
@@ -34,24 +33,24 @@ def add_database(
34
33
  "-t",
35
34
  help="Database type (postgresql, mysql, sqlite)",
36
35
  ),
37
- host: Optional[str] = typer.Option(None, "--host", "-h", help="Database host"),
38
- port: Optional[int] = typer.Option(None, "--port", "-p", help="Database port"),
39
- database: Optional[str] = typer.Option(
36
+ host: str | None = typer.Option(None, "--host", "-h", help="Database host"),
37
+ port: int | None = typer.Option(None, "--port", "-p", help="Database port"),
38
+ database: str | None = typer.Option(
40
39
  None, "--database", "--db", help="Database name"
41
40
  ),
42
- username: Optional[str] = typer.Option(None, "--username", "-u", help="Username"),
43
- ssl_mode: Optional[str] = typer.Option(
41
+ username: str | None = typer.Option(None, "--username", "-u", help="Username"),
42
+ ssl_mode: str | None = typer.Option(
44
43
  None,
45
44
  "--ssl-mode",
46
45
  help="SSL mode (disable, allow, prefer, require, verify-ca, verify-full for PostgreSQL; DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY for MySQL)",
47
46
  ),
48
- ssl_ca: Optional[str] = typer.Option(
47
+ ssl_ca: str | None = typer.Option(
49
48
  None, "--ssl-ca", help="SSL CA certificate file path"
50
49
  ),
51
- ssl_cert: Optional[str] = typer.Option(
50
+ ssl_cert: str | None = typer.Option(
52
51
  None, "--ssl-cert", help="SSL client certificate file path"
53
52
  ),
54
- ssl_key: Optional[str] = typer.Option(
53
+ ssl_key: str | None = typer.Option(
55
54
  None, "--ssl-key", help="SSL client private key file path"
56
55
  ),
57
56
  interactive: bool = typer.Option(
@@ -310,7 +309,7 @@ def set_default_database(
310
309
 
311
310
  @db_app.command("test")
312
311
  def test_database(
313
- name: Optional[str] = typer.Argument(
312
+ name: str | None = typer.Argument(
314
313
  None,
315
314
  help="Name of the database connection to test (uses default if not specified)",
316
315
  ),
sqlsaber/cli/display.py CHANGED
@@ -1,9 +1,9 @@
1
1
  """Display utilities for the CLI interface."""
2
2
 
3
3
  import json
4
- from typing import Optional
5
4
 
6
5
  from rich.console import Console
6
+ from rich.markdown import Markdown
7
7
  from rich.syntax import Syntax
8
8
  from rich.table import Table
9
9
 
@@ -18,7 +18,7 @@ class DisplayManager:
18
18
  self,
19
19
  columns: list,
20
20
  header_style: str = "bold blue",
21
- title: Optional[str] = None,
21
+ title: str | None = None,
22
22
  ) -> Table:
23
23
  """Create a Rich table with specified columns."""
24
24
  table = Table(show_header=True, header_style=header_style, title=title)
@@ -31,12 +31,9 @@ class DisplayManager:
31
31
  table.add_column(col)
32
32
  return table
33
33
 
34
- def show_tool_started(self, tool_name: str):
35
- """Display tool started message."""
36
- self.console.print(f"\n[yellow]🔧 Using tool: {tool_name}[/yellow]")
37
-
38
34
  def show_tool_executing(self, tool_name: str, tool_input: dict):
39
35
  """Display tool execution details."""
36
+ self.console.print(f"\n[yellow]🔧 Using tool: {tool_name}[/yellow]")
40
37
  if tool_name == "list_tables":
41
38
  self.console.print("[dim] → Discovering available tables[/dim]")
42
39
  elif tool_name == "introspect_schema":
@@ -45,12 +42,14 @@ class DisplayManager:
45
42
  elif tool_name == "execute_sql":
46
43
  query = tool_input.get("query", "")
47
44
  self.console.print("\n[bold green]Executing SQL:[/bold green]")
45
+ self.show_newline()
48
46
  syntax = Syntax(query, "sql")
49
47
  self.console.print(syntax)
50
48
 
51
49
  def show_text_stream(self, text: str):
52
50
  """Display streaming text."""
53
- self.console.print(text, end="", markup=False)
51
+ if text is not None: # Extra safety check
52
+ self.console.print(text, end="", markup=False)
54
53
 
55
54
  def show_query_results(self, results: list):
56
55
  """Display query results in a formatted table."""
@@ -243,3 +242,24 @@ class DisplayManager:
243
242
  self.show_error("Failed to parse plot result")
244
243
  except Exception as e:
245
244
  self.show_error(f"Error displaying plot: {str(e)}")
245
+
246
def show_markdown_response(self, content: list):
    """Render the text blocks of an assistant response as Rich markdown.

    Only dict items whose ``"type"`` is ``"text"`` and whose ``"text"`` is
    truthy contribute. Their texts are concatenated and stripped; when
    nothing remains, nothing is printed. Otherwise the markdown is printed
    with an empty line before and after it.
    """
    if not content:
        return

    # Keep only the text-type blocks, then pull out their non-empty text
    text_blocks = (
        block
        for block in content
        if isinstance(block, dict) and block.get("type") == "text"
    )
    pieces = [block.get("text", "") for block in text_blocks]
    rendered_source = "".join(piece for piece in pieces if piece).strip()

    if not rendered_source:
        return

    self.console.print()  # Blank line before the markdown
    self.console.print(Markdown(rendered_source))
    self.console.print()  # Blank line after the markdown