sqlsaber 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic; see the package's registry page for more details.

@@ -0,0 +1,132 @@
1
+ """CLI command definitions and handlers."""
2
+
3
+ import asyncio
4
+ from typing import Optional
5
+
6
+ import typer
7
+ from rich.console import Console
8
+
9
+ from sqlsaber.agents.anthropic import AnthropicSQLAgent
10
+ from sqlsaber.cli.database import create_db_app
11
+ from sqlsaber.cli.interactive import InteractiveSession
12
+ from sqlsaber.cli.memory import create_memory_app
13
+ from sqlsaber.cli.models import create_models_app
14
+ from sqlsaber.cli.streaming import StreamingQueryHandler
15
+ from sqlsaber.config.database import DatabaseConfigManager
16
+ from sqlsaber.database.connection import DatabaseConnection
17
+
18
# Root Typer application for the `sqlsaber` CLI; sub-command groups are
# attached to it with `app.add_typer(...)` elsewhere in this module.
app = typer.Typer(
    add_completion=True,
    name="sqlsaber",
    help=(
        "SQLSaber - Use the agent Luke!\n"
        "\n"
        "SQL assistant for your database"
    ),
)


# Module-wide singletons used by the commands below: the Rich console all
# output goes to, and the manager that looks up saved connection configs.
console = Console()
config_manager = DatabaseConfigManager()
27
+
28
+
29
@app.callback()
def main_callback(
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """
    Query your database using natural language.

    Examples:
    sb query # Start interactive mode
    sb query "show me all users" # Run a single query with default database
    sb query -d mydb "show me users" # Run a query with specific database
    """
    # The docstring above is rendered by Typer as the top-level help text, so
    # it is kept verbatim; the callback itself does no work.
    # NOTE(review): the callback-level --database/-d value is accepted but never
    # stored or forwarded; `query` only honours its own --database/-d option.
47
+
48
+
49
@app.command()
def query(
    query_text: Optional[str] = typer.Argument(
        None,
        help="SQL query in natural language (if not provided, starts interactive mode)",
    ),
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """Run a query against the database or start interactive mode."""
    # The docstring above is this command's --help text, so details live here.
    #
    # With `query_text`, a single question is answered through streaming
    # output; without it, an interactive session is started. `database` names
    # a saved connection; when omitted, the configured default is used.
    # Exits with status 1 when no matching configuration exists or the
    # connection object cannot be created.

    async def run_session():
        # Resolve which saved connection configuration to use.
        if database:
            db_config = config_manager.get_database(database)
            if not db_config:
                console.print(
                    f"[bold red]Error:[/bold red] Database connection '{database}' not found."
                )
                console.print("Use 'sqlsaber db list' to see available connections.")
                raise typer.Exit(1)
        else:
            db_config = config_manager.get_default_database()
            if not db_config:
                console.print(
                    "[bold red]Error:[/bold red] No database connections configured."
                )
                console.print(
                    "Use 'sqlsaber db add <name>' to add a database connection."
                )
                raise typer.Exit(1)

        # Create the database connection.
        try:
            connection_string = db_config.to_connection_string()
            db_conn = DatabaseConnection(connection_string)
        except Exception as e:
            console.print(
                f"[bold red]Error creating database connection:[/bold red] {e}"
            )
            raise typer.Exit(1)

        # Everything after the connection exists runs under try/finally. This
        # includes building the agent, so the connection is closed even if
        # agent construction fails.
        try:
            # The database name gives the agent its memory context.
            agent = AnthropicSQLAgent(db_conn, db_config.name)
            if query_text:
                # Single query mode with streaming output.
                streaming_handler = StreamingQueryHandler(console)
                await streaming_handler.execute_streaming_query(query_text, agent)
            else:
                # Interactive mode.
                session = InteractiveSession(console, agent)
                await session.run()
        finally:
            await db_conn.close()
            console.print("\n[green]Goodbye![/green]")

    asyncio.run(run_session())
115
+
116
+
117
# Attach the sub-command groups to the root app. The original note says this
# runs after the main callback is defined; presumably ordering matters for how
# Typer assembles the group -- NOTE(review): confirm.
# create_db_app() returns the `db` module's Typer app; the memory/models
# factories presumably do the same for their modules (verify).
# Database connection management: `sqlsaber db ...`
db_app = create_db_app()
app.add_typer(db_app, name="db")

# Memory management: `sqlsaber memory ...`
memory_app = create_memory_app()
app.add_typer(memory_app, name="memory")

# Model management: `sqlsaber models ...`
models_app = create_models_app()
app.add_typer(models_app, name="models")
128
+
129
+
130
def main():
    """Entry point for the CLI application."""
    # Invoking the Typer app with no arguments runs the CLI. Typer/Click then
    # read the arguments from the process command line and dispatch to the
    # registered commands.
    app()
@@ -0,0 +1,275 @@
1
+ """Database management CLI commands."""
2
+
3
+ import asyncio
4
+ import getpass
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import questionary
9
+ import typer
10
+ from rich.console import Console
11
+ from rich.table import Table
12
+
13
+ from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
14
+ from sqlsaber.database.connection import DatabaseConnection
15
+
16
# Singletons shared by every `db` sub-command: the Rich console used for
# output and the manager for saved database connection configurations.
console = Console()
config_manager = DatabaseConfigManager()

# Typer sub-application that holds the `db` commands; handed out by
# create_db_app() so the main CLI can mount it.
db_app = typer.Typer(
    add_completion=True,
    help="Manage database connections",
    name="db",
)
26
+
27
+
28
# Database types this command knows how to configure.
_SUPPORTED_DB_TYPES = ("postgresql", "mysql", "sqlite")


def _default_port(db_type: str) -> int:
    """Return the conventional server port for a non-SQLite database type."""
    return 5432 if db_type == "postgresql" else 3306


def _require_answer(answer):
    """Return a questionary answer, aborting the command if it was cancelled.

    questionary's ``.ask()`` returns ``None`` when the user cancels a prompt
    (e.g. with Ctrl-C). Continuing would crash later on ``Path(None)`` or
    ``int(None)``, so the command exits with status 1 instead.
    """
    if answer is None:
        console.print("Operation cancelled")
        raise typer.Exit(1)
    return answer


@db_app.command("add")
def add_database(
    name: str = typer.Argument(..., help="Name for the database connection"),
    type: str = typer.Option(
        "postgresql",
        "--type",
        "-t",
        help="Database type (postgresql, mysql, sqlite)",
    ),
    host: Optional[str] = typer.Option(None, "--host", "-h", help="Database host"),
    port: Optional[int] = typer.Option(None, "--port", "-p", help="Database port"),
    database: Optional[str] = typer.Option(
        None, "--database", "--db", help="Database name"
    ),
    username: Optional[str] = typer.Option(None, "--username", "-u", help="Username"),
    interactive: bool = typer.Option(
        True, "--interactive/--no-interactive", help="Use interactive mode"
    ),
):
    """Add a new database connection."""
    # The docstring above is this command's --help text, so details live here.
    #
    # Interactive mode prompts for any missing field. Non-interactive mode
    # expects the fields as options (a file path for SQLite; host, database
    # and username otherwise) and only asks whether to enter a password.
    # Passwords are always read with getpass, never taken from an option.
    # Exits with status 1 on invalid input, a cancelled prompt, or an error
    # raised while saving the configuration.

    if interactive:
        console.print(f"[bold]Adding database connection: {name}[/bold]")

        # "postgresql" is the option's default, so an explicit
        # `--type postgresql` looks the same as no flag; ask in both cases.
        if not type or type == "postgresql":
            type = _require_answer(
                questionary.select(
                    "Database type:",
                    choices=list(_SUPPORTED_DB_TYPES),
                    default="postgresql",
                ).ask()
            )

    if type not in _SUPPORTED_DB_TYPES:
        console.print(
            f"[bold red]Error:[/bold red] Unsupported database type '{type}' "
            f"(expected one of: {', '.join(_SUPPORTED_DB_TYPES)})"
        )
        raise typer.Exit(1)

    if type == "sqlite":
        # SQLite only needs a database file path.
        if not database and interactive:
            database = _require_answer(questionary.path("Database file path:").ask())
        if not database:
            console.print(
                "[bold red]Error:[/bold red] Database file path is required for SQLite"
            )
            raise typer.Exit(1)
        # Expand "~" in both modes so the stored path is usable as-is.
        database = str(Path(database).expanduser())
        # Placeholder server fields, as DatabaseConfig requires them.
        host = "localhost"
        port = 0
        username = "sqlite"
        password = ""
    else:
        # PostgreSQL/MySQL need full connection details.
        if interactive:
            host = host or _require_answer(
                questionary.text("Host:", default="localhost").ask()
            )
            if not port:
                raw_port = _require_answer(
                    questionary.text("Port:", default=str(_default_port(type))).ask()
                )
                try:
                    port = int(raw_port)
                except ValueError:
                    console.print(
                        f"[bold red]Error:[/bold red] Invalid port '{raw_port}'"
                    )
                    raise typer.Exit(1) from None
            database = database or _require_answer(
                questionary.text("Database name:").ask()
            )
            username = username or _require_answer(
                questionary.text("Username:").ask()
            )

        # Explicit check (not `assert`, which is stripped under `python -O`).
        # It runs before any password prompt and also rejects empty answers.
        if not all([host, database, username]):
            console.print(
                "[bold red]Error:[/bold red] Host, database, and username are required"
            )
            raise typer.Exit(1)

        if port is None:
            port = _default_port(type)

        if interactive:
            password = getpass.getpass("Password (stored in your OS keychain): ")
        else:
            password = (
                getpass.getpass("Password (stored in your OS keychain): ")
                if questionary.confirm("Enter password?").ask()
                else ""
            )

    db_config = DatabaseConfig(
        name=name,
        type=type,
        host=host,
        port=port,
        database=database,
        username=username,
    )

    try:
        # An empty password is stored as "no password".
        config_manager.add_database(db_config, password if password else None)
        console.print(f"[green]Successfully added database connection '{name}'[/green]")

        # The first saved connection becomes the default. Set it explicitly
        # (unless it already is the default) so the message is only printed
        # when it is actually true.
        if len(config_manager.list_databases()) == 1:
            if (
                config_manager.get_default_name() == name
                or config_manager.set_default_database(name)
            ):
                console.print(f"[blue]Set '{name}' as default database[/blue]")

    except Exception as e:
        console.print(f"[bold red]Error adding database:[/bold red] {e}")
        raise typer.Exit(1)
140
+
141
+
142
@db_app.command("list")
def list_databases():
    """List all configured database connections."""
    # Fetch both up front: every saved config, and which one is the default.
    saved_configs = config_manager.list_databases()
    default_name = config_manager.get_default_name()

    if not saved_configs:
        console.print("[yellow]No database connections configured[/yellow]")
        console.print("Use 'sqlsaber db add <name>' to add a database connection")
        return

    # (header, style) pairs, in display order.
    column_specs = (
        ("Name", "cyan"),
        ("Type", "magenta"),
        ("Host", "green"),
        ("Port", "yellow"),
        ("Database", "blue"),
        ("Username", "white"),
        ("Default", "bold red"),
    )
    table = Table(title="Database Connections")
    for header, colour in column_specs:
        table.add_column(header, style=colour)

    for entry in saved_configs:
        # A falsy port (e.g. SQLite's 0) is shown as an empty cell.
        port_cell = str(entry.port) if entry.port else ""
        default_marker = "✓" if entry.name == default_name else ""
        table.add_row(
            entry.name,
            entry.type,
            entry.host,
            port_cell,
            entry.database,
            entry.username,
            default_marker,
        )

    console.print(table)
175
+
176
+
177
@db_app.command("remove")
def remove_database(
    name: str = typer.Argument(..., help="Name of the database connection to remove"),
):
    """Remove a database connection."""
    # Refuse early when no connection by that name exists.
    if not config_manager.get_database(name):
        console.print(
            f"[bold red]Error:[/bold red] Database connection '{name}' not found"
        )
        raise typer.Exit(1)

    prompt = f"Are you sure you want to remove database connection '{name}'?"
    confirmed = questionary.confirm(prompt).ask()
    if not confirmed:
        # Declined or cancelled: leave the configuration untouched.
        console.print("Operation cancelled")
        return

    removed = config_manager.remove_database(name)
    if not removed:
        console.print(
            f"[bold red]Error:[/bold red] Failed to remove database connection '{name}'"
        )
        raise typer.Exit(1)

    console.print(f"[green]Successfully removed database connection '{name}'[/green]")
202
+
203
+
204
@db_app.command("set-default")
def set_default_database(
    name: str = typer.Argument(
        ..., help="Name of the database connection to set as default"
    ),
):
    """Set the default database connection."""
    # Only an existing saved connection can become the default.
    known = config_manager.get_database(name)
    if not known:
        console.print(
            f"[bold red]Error:[/bold red] Database connection '{name}' not found"
        )
        raise typer.Exit(1)

    # The config manager reports success as a boolean.
    updated = config_manager.set_default_database(name)
    if not updated:
        console.print(f"[bold red]Error:[/bold red] Failed to set '{name}' as default")
        raise typer.Exit(1)

    console.print(f"[green]Successfully set '{name}' as default database[/green]")
222
+
223
+
224
@db_app.command("test")
def test_database(
    name: Optional[str] = typer.Argument(
        None,
        help="Name of the database connection to test (uses default if not specified)",
    ),
):
    """Test a database connection."""
    # The docstring above is this command's --help text, so details live here.
    #
    # Resolves the named (or default) connection, runs `SELECT 1 as test`
    # against it and reports the outcome. Exits with status 1 when no
    # configuration is found or the probe fails.

    async def test_connection():
        if name:
            db_config = config_manager.get_database(name)
            if not db_config:
                console.print(
                    f"[bold red]Error:[/bold red] Database connection '{name}' not found"
                )
                raise typer.Exit(1)
        else:
            db_config = config_manager.get_default_database()
            if not db_config:
                console.print(
                    "[bold red]Error:[/bold red] No default database configured"
                )
                console.print(
                    "Use 'sqlsaber db add <name>' to add a database connection"
                )
                raise typer.Exit(1)

        console.print(f"[blue]Testing connection to '{db_config.name}'...[/blue]")

        try:
            connection_string = db_config.to_connection_string()
            db_conn = DatabaseConnection(connection_string)
            try:
                # Minimal round trip to prove the server accepts queries.
                await db_conn.execute_query("SELECT 1 as test")
            finally:
                # Close even when the probe fails so the connection is not leaked.
                await db_conn.close()
        except Exception as e:
            console.print(f"[bold red]✗ Connection failed:[/bold red] {e}")
            raise typer.Exit(1)

        console.print(
            f"[green]✓ Connection to '{db_config.name}' successful[/green]"
        )

    asyncio.run(test_connection())
271
+
272
+
273
def create_db_app() -> typer.Typer:
    """Return the database management CLI app."""
    # Returns the module-level `db_app` singleton rather than a fresh
    # instance, so every caller gets the same app with the `db` commands.
    return db_app
@@ -0,0 +1,207 @@
1
+ """Display utilities for the CLI interface."""
2
+
3
+ import json
4
+ from typing import Optional
5
+
6
+ from rich.console import Console
7
+ from rich.syntax import Syntax
8
+ from rich.table import Table
9
+
10
+
11
class DisplayManager:
    """Manages display formatting and output for the CLI.

    Text that comes from outside this class is escaped before printing. This
    covers query result values, table and column names, tool inputs and
    error messages. Rich parses ``[...]`` sequences in plain strings (table
    cells included) as console markup, so unescaped data could be restyled
    or raise ``MarkupError``.
    """

    def __init__(self, console: Console):
        # All output is written to this Rich console.
        self.console = console

    @staticmethod
    def _escape(value) -> str:
        """Return ``str(value)`` with Rich markup escaped so it prints literally."""
        from rich.markup import escape

        return escape(str(value))

    def _create_table(
        self,
        columns: list,
        header_style: str = "bold blue",
        title: Optional[str] = None,
    ) -> Table:
        """Create a Rich table with the given column definitions.

        Each entry in ``columns`` is either a plain header string or a dict
        with a required ``"name"`` key and optional ``"style"`` and
        ``"justify"`` keys.
        """
        table = Table(show_header=True, header_style=header_style, title=title)
        for col in columns:
            if isinstance(col, dict):
                table.add_column(
                    col["name"],
                    style=col.get("style"),
                    # Fall back to Rich's own default instead of passing None.
                    justify=col.get("justify", "left"),
                )
            else:
                table.add_column(col)
        return table

    def show_tool_started(self, tool_name: str):
        """Display tool started message."""
        self.console.print(
            f"\n[yellow]🔧 Using tool: {self._escape(tool_name)}[/yellow]"
        )

    def show_tool_executing(self, tool_name: str, tool_input: dict):
        """Display tool execution details.

        Only ``list_tables``, ``introspect_schema`` and ``execute_sql`` produce
        output; any other tool name prints nothing.
        """
        if tool_name == "list_tables":
            self.console.print("[dim] → Discovering available tables[/dim]")
        elif tool_name == "introspect_schema":
            pattern = tool_input.get("table_pattern", "all tables")
            self.console.print(
                f"[dim] → Examining schema for: {self._escape(pattern)}[/dim]"
            )
        elif tool_name == "execute_sql":
            query = tool_input.get("query", "")
            self.console.print("\n[bold green]Executing SQL:[/bold green]")
            # Syntax is a renderable (highlighted SQL), not a markup string.
            syntax = Syntax(query, "sql")
            self.console.print(syntax)

    def show_text_stream(self, text: str):
        """Display streaming text."""
        # markup=False: streamed model text is printed verbatim.
        self.console.print(text, end="", markup=False)

    def show_query_results(self, results: list, max_rows: int = 20):
        """Display query results in a formatted table.

        Args:
            results: Rows as mappings from column name to value. The first
                row's keys define the table columns.
            max_rows: Maximum number of rows to render. Any remainder is
                summarised as a count.
        """
        if not results:
            return

        self.console.print(
            f"\n[bold magenta]Results ({len(results)} rows):[/bold magenta]"
        )

        max_rows = max(max_rows, 0)

        # Column headers are escaped for display; raw keys are kept for lookups.
        columns = list(results[0].keys())
        table = self._create_table([self._escape(col) for col in columns])

        for row in results[:max_rows]:
            table.add_row(*[self._escape(row[key]) for key in columns])

        self.console.print(table)

        hidden = len(results) - max_rows
        if hidden > 0:
            self.console.print(f"[yellow]... and {hidden} more rows[/yellow]")

    def show_error(self, error_message: str):
        """Display error message."""
        self.console.print(
            f"\n[bold red]Error:[/bold red] {self._escape(error_message)}"
        )

    def show_processing(self, message: str):
        """Print a blank line and return a Rich status spinner for ``message``.

        The returned object is a context manager; the message is passed
        through as markup, so callers may style it.
        """
        self.console.print()  # Add newline
        return self.console.status(
            f"[yellow]{message}[/yellow]", spinner="bouncingBall"
        )

    def show_newline(self):
        """Display a newline for spacing."""
        self.console.print()

    def show_table_list(self, tables_data: str):
        """Display the results from the list_tables tool.

        ``tables_data`` is a JSON string. It holds either an ``"error"`` key
        or a ``"tables"`` list (each item with optional ``schema``, ``name``,
        ``type`` and ``row_count``) plus an optional ``"total_tables"`` count.
        """
        try:
            data = json.loads(tables_data)

            # Handle error case
            if "error" in data:
                self.show_error(data["error"])
                return

            tables = data.get("tables", [])
            # Fall back to the number of listed tables when no total is given.
            total_tables = data.get("total_tables", len(tables))

            if not tables:
                self.console.print("[yellow]No tables found in the database.[/yellow]")
                return

            self.console.print(
                f"\n[bold green]Database Tables ({total_tables} total):[/bold green]"
            )

            columns = [
                {"name": "Schema", "style": "cyan"},
                {"name": "Table Name", "style": "white"},
                {"name": "Type", "style": "yellow"},
                {"name": "Row Count", "justify": "right", "style": "magenta"},
            ]
            table = self._create_table(columns)

            for table_info in tables:
                schema = table_info.get("schema", "")
                name = table_info.get("name", "")
                table_type = table_info.get("type", "")
                row_count = table_info.get("row_count", 0)

                # Thousands separators for numeric counts; anything else
                # (e.g. a string placeholder) is shown as-is instead of
                # aborting the whole listing with a format error.
                if not row_count:
                    formatted_count = "0"
                elif isinstance(row_count, (int, float)):
                    formatted_count = f"{row_count:,}"
                else:
                    formatted_count = self._escape(row_count)

                table.add_row(
                    self._escape(schema),
                    self._escape(name),
                    self._escape(table_type),
                    formatted_count,
                )

            self.console.print(table)

        except json.JSONDecodeError:
            self.show_error("Failed to parse table list data")
        except Exception as e:
            self.show_error(f"Error displaying table list: {str(e)}")

    def show_schema_info(self, schema_data: str):
        """Display the results from the introspect_schema tool.

        ``schema_data`` is a JSON object. It maps table names to dicts with
        optional ``columns`` (name -> info dict), ``primary_keys`` and
        ``foreign_keys`` entries, or it carries an ``"error"`` message.
        """
        try:
            data = json.loads(schema_data)

            # An "error" key whose value is not a table-info dict is an error
            # payload; a table literally named "error" still displays normally.
            if "error" in data and not isinstance(data["error"], dict):
                self.show_error(data["error"])
                return

            if not data:
                self.console.print("[yellow]No schema information found.[/yellow]")
                return

            self.console.print(
                f"\n[bold green]Schema Information ({len(data)} tables):[/bold green]"
            )

            for table_name, table_info in data.items():
                self.console.print(
                    f"\n[bold cyan]Table: {self._escape(table_name)}[/bold cyan]"
                )

                # Columns table
                table_columns = table_info.get("columns", {})
                if table_columns:
                    columns = [
                        {"name": "Column Name", "style": "white"},
                        {"name": "Type", "style": "yellow"},
                        {"name": "Nullable", "style": "cyan"},
                        {"name": "Default", "style": "dim"},
                    ]
                    col_table = self._create_table(columns, title="Columns")

                    for col_name, col_info in table_columns.items():
                        nullable = "✓" if col_info.get("nullable", False) else "✗"
                        default = (
                            self._escape(col_info["default"])
                            if col_info.get("default")
                            else ""
                        )
                        col_table.add_row(
                            self._escape(col_name),
                            self._escape(col_info.get("type", "")),
                            nullable,
                            default,
                        )

                    self.console.print(col_table)

                # Primary keys
                primary_keys = table_info.get("primary_keys", [])
                if primary_keys:
                    keys_text = ", ".join(self._escape(pk) for pk in primary_keys)
                    self.console.print(
                        f"[bold yellow]Primary Keys:[/bold yellow] {keys_text}"
                    )

                # Foreign keys
                foreign_keys = table_info.get("foreign_keys", [])
                if foreign_keys:
                    self.console.print("[bold magenta]Foreign Keys:[/bold magenta]")
                    for fk in foreign_keys:
                        self.console.print(f" • {self._escape(fk)}")

        except json.JSONDecodeError:
            self.show_error("Failed to parse schema data")
        except Exception as e:
            self.show_error(f"Error displaying schema information: {str(e)}")