sqlsaber 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,325 @@
1
+ """Interactive onboarding flow for first-time SQLSaber users."""
2
+
3
+ import sys
4
+
5
+ from rich.panel import Panel
6
+
7
+ from sqlsaber.cli.models import ModelManager
8
+ from sqlsaber.config.api_keys import APIKeyManager
9
+ from sqlsaber.config.auth import AuthConfigManager
10
+ from sqlsaber.config.database import DatabaseConfigManager
11
+ from sqlsaber.theme.manager import create_console
12
+
13
+ console = create_console()
14
+
15
+
16
+ def needs_onboarding(database_arg: str | None = None) -> bool:
17
+ """Check if user needs onboarding.
18
+
19
+ Onboarding is needed if:
20
+ - No database is configured AND no database connection string provided via CLI
21
+ """
22
+ # If user provided a database argument, skip onboarding
23
+ if database_arg:
24
+ return False
25
+
26
+ # Check if databases are configured
27
+ db_manager = DatabaseConfigManager()
28
+ has_db = db_manager.has_databases()
29
+
30
+ return not has_db
31
+
32
+
33
def welcome_screen() -> None:
    """Print the SQLsaber banner followed by a short welcome panel."""
    # NOTE(review): the banner text is copied as it appears in the reviewed
    # source; confirm its internal spacing against the packaged file.
    logo = """
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
▀▀
"""
    logo_panel = Panel.fit(logo, style="bold blue")
    console.print(logo_panel)
    console.print()

    greeting = """
[bold]Welcome to SQLsaber! 🎉[/bold]

SQLsaber is an agentic SQL assistant that lets you query your database using natural language.

Let's get you set up in just a few steps.
"""
    # Surrounding blank lines are stripped so the panel padding controls spacing.
    greeting_panel = Panel(greeting.strip(), border_style="blue", padding=(1, 2))
    console.print(greeting_panel)
    console.print()
57
+
58
+
59
async def setup_database_guided() -> str | None:
    """Guide the user through configuring a database connection.

    Prompts for a connection name, collects connection details via
    ``collect_db_input`` (called with ``db_type="postgresql"`` and
    ``include_ssl=False``), tests the connection and saves it. If the test
    fails the user may retry; every retry starts again from the name prompt.

    Returns:
        The connection name when the connection was saved, or when a
        connection with that name already exists. ``None`` when the user
        cancelled a prompt, declined to retry after a failed test, pressed
        Ctrl+C, or an error occurred.
    """
    from rich.markup import escape

    from sqlsaber.application.db_setup import (
        build_config,
        collect_db_input,
        save_database,
        test_connection,
    )
    from sqlsaber.application.prompts import AsyncPrompter

    # Retry by looping rather than by re-entering this coroutine, so repeated
    # failed attempts do not pile up nested frames and exception handlers.
    while True:
        console.print("━" * 80, style="dim")
        console.print("[bold cyan]Step 1 of 2: Database Connection[/bold cyan]")
        console.print("━" * 80, style="dim")
        console.print()

        try:
            # Ask for connection name
            prompter = AsyncPrompter()
            name = await prompter.text(
                "What would you like to name this connection?",
                default="mydb",
                validate=lambda x: bool(x.strip()) or "Name cannot be empty",
            )

            if name is None:
                return None

            name = name.strip()
            # The name is free-form user input: escape it before embedding it
            # in console markup so bracketed text is printed literally.
            shown_name = escape(name)

            # Check if name already exists
            db_manager = DatabaseConfigManager()
            if db_manager.get_database(name):
                console.print(
                    f"[yellow]Database connection '{shown_name}' already exists.[/yellow]"
                )
                return name

            # Collect database input (simplified - no SSL in onboarding)
            db_input = await collect_db_input(
                prompter=prompter, name=name, db_type="postgresql", include_ssl=False
            )

            if db_input is None:
                return None

            # Build config
            db_config = build_config(db_input)

            # Test the connection
            console.print(f"[dim]Testing connection to '{shown_name}'...[/dim]")
            connection_success = await test_connection(db_config, db_input.password)

            if not connection_success:
                retry = await prompter.confirm(
                    "Would you like to try again with different settings?", default=True
                )
                if retry:
                    continue
                console.print(
                    "[yellow]You can add a database later using 'saber db add'[/yellow]"
                )
                return None

            # Save the configuration
            try:
                save_database(db_manager, db_config, db_input.password)
                console.print(
                    f"[green]✓ Connection to '{shown_name}' successful![/green]"
                )
                console.print()
                return name
            except Exception as e:
                # Exception text may contain brackets; print it verbatim.
                console.print(
                    f"[bold red]Error saving database:[/bold red] {escape(str(e))}"
                )
                return None

        except KeyboardInterrupt:
            console.print("\n[yellow]Setup cancelled.[/yellow]")
            return None
        except Exception as e:
            console.print(f"[bold red]Unexpected error:[/bold red] {escape(str(e))}")
            return None
142
+
143
+
144
async def select_model_for_provider(provider: str) -> str | None:
    """Fetch the provider's models and let the user pick one.

    Returns:
        The value returned by ``choose_model`` when models could be fetched.
        If no models are available, or an unexpected error occurs, a fallback
        is returned instead: ``"<provider>:<recommended model>"`` when the
        provider has an entry in ``ModelManager.RECOMMENDED_MODELS``,
        otherwise ``ModelManager.DEFAULT_MODEL``. ``None`` is returned when
        the user interrupts the selection with Ctrl+C.
    """
    from rich.markup import escape

    from sqlsaber.application.model_selection import choose_model, fetch_models
    from sqlsaber.application.prompts import AsyncPrompter

    def _fallback_model() -> str:
        # Single definition of the provider default, shared by the
        # "no models fetched" path and the error path.
        if provider in ModelManager.RECOMMENDED_MODELS:
            return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
        return ModelManager.DEFAULT_MODEL

    try:
        console.print()
        console.print(f"[dim]Fetching available {provider.title()} models...[/dim]")

        model_manager = ModelManager()
        models = await fetch_models(model_manager, providers=[provider])

        if not models:
            console.print(
                f"[yellow]Could not fetch models for {provider}. Using default.[/yellow]"
            )
            return _fallback_model()

        prompter = AsyncPrompter()
        console.print()
        return await choose_model(
            prompter, models, restrict_provider=provider, use_search_filter=True
        )

    except KeyboardInterrupt:
        console.print("\n[yellow]Model selection cancelled.[/yellow]")
        return None
    except Exception as e:
        # Escape the exception text so brackets in it are not parsed as markup.
        console.print(
            f"[yellow]Error selecting model: {escape(str(e))}. Using default.[/yellow]"
        )
        return _fallback_model()
189
+
190
+
191
async def setup_auth_guided() -> tuple[bool, str | None]:
    """Guide the user through authentication setup and model selection.

    Runs ``setup_auth`` (OAuth allowed, ``"anthropic"`` as the default
    provider). When it succeeds and reports a provider, the user is then
    asked to choose a model, which is stored through
    ``ModelManager.set_model``.

    Returns:
        ``(success, selected_model)``. ``selected_model`` is ``None`` when
        authentication was not completed, when ``setup_auth`` reported no
        provider, or when no model was chosen. On Ctrl+C or an unexpected
        error the result is ``(False, None)``.
    """
    from rich.markup import escape

    from sqlsaber.application.auth_setup import setup_auth
    from sqlsaber.application.prompts import AsyncPrompter

    console.print("━" * 80, style="dim")
    console.print("[bold cyan]Step 2 of 2: Authentication[/bold cyan]")
    console.print("━" * 80, style="dim")
    console.print()

    try:
        # Run auth setup
        prompter = AsyncPrompter()
        auth_manager = AuthConfigManager()
        api_key_manager = APIKeyManager()

        success, provider = await setup_auth(
            prompter=prompter,
            auth_manager=auth_manager,
            api_key_manager=api_key_manager,
            allow_oauth=True,
            default_provider="anthropic",
            run_oauth_in_thread=True,
        )

        if not success:
            console.print(
                "[yellow]You can set it up later using 'saber auth setup'[/yellow]"
            )
            console.print()
            return False, None

        # If auth configured but we don't know the provider (already configured case)
        if provider is None:
            console.print()
            return True, None

        # Select model for this provider
        selected_model = await select_model_for_provider(provider)
        if selected_model:
            model_manager = ModelManager()
            model_manager.set_model(selected_model)
            console.print(f"[green]✓ Model set to: {escape(selected_model)}[/green]")
        # NOTE(review): assumed to run whether or not a model was chosen —
        # confirm against the packaged source.
        console.print()
        return True, selected_model

    except KeyboardInterrupt:
        console.print("\n[yellow]Setup cancelled.[/yellow]")
        console.print()
        return False, None
    except Exception as e:
        # Escape the exception text so brackets in it are not parsed as markup.
        console.print(f"[bold red]Unexpected error:[/bold red] {escape(str(e))}")
        console.print()
        return False, None
248
+
249
+
250
def success_screen(
    database_name: str | None, auth_configured: bool, model_name: str | None = None
) -> None:
    """Print the closing summary shown at the end of onboarding.

    Args:
        database_name: Name of the configured connection, or ``None``/empty
            if no database was set up.
        auth_configured: Whether authentication was configured.
        model_name: Selected model id; only shown when authentication is
            configured and a model name is given.

    When neither a database nor authentication is configured, only the
    header and the "Starting interactive session..." footer are printed.
    """
    from rich.markup import escape

    console.print("━" * 80, style="dim")
    console.print("[bold green]You're all set! 🚀[/bold green]")
    console.print("━" * 80, style="dim")
    console.print()

    # Dynamic values are escaped so bracketed text in a user-chosen
    # connection name or model id is printed literally, not parsed as markup.
    model_line = (
        f"[green]✓ Model: {escape(model_name)}[/green]" if model_name else None
    )

    if database_name:
        console.print(
            f"[green]✓ Database '{escape(database_name)}' connected and ready to use[/green]"
        )
        if auth_configured:
            console.print("[green]✓ Authentication configured[/green]")
            if model_line:
                console.print(model_line)
        else:
            console.print(
                "[yellow]⚠ AI authentication not configured - you'll be prompted when needed[/yellow]"
            )
    elif auth_configured:
        console.print("[green]✓ AI authentication configured[/green]")
        if model_line:
            console.print(model_line)
        console.print(
            "[yellow]⚠ No database configured - you'll need to provide one via -d flag[/yellow]"
        )

    console.print()
    console.print("[dim]Starting interactive session...[/dim]")
    console.print()
284
+
285
+
286
async def run_onboarding() -> bool:
    """Run the complete onboarding flow.

    Shows the welcome screen, then the guided database setup, then the
    guided authentication setup, and finally the success screen.

    Returns:
        ``True`` once a database has been configured (authentication is
        optional). ``False`` when database setup yielded no connection or an
        unexpected error occurred.

    Note:
        A ``KeyboardInterrupt`` that reaches this function does not return:
        a hint about the manual setup commands is printed and the process
        exits via ``sys.exit(0)``.
    """
    from rich.markup import escape

    try:
        # Welcome screen
        welcome_screen()

        # Database setup
        database_name = await setup_database_guided()

        # If user cancelled database setup, exit
        if database_name is None:
            console.print("[yellow]Database setup is required to continue.[/yellow]")
            console.print(
                "[dim]You can also provide a connection string using: saber -d <connection-string>[/dim]"
            )
            return False

        # Auth setup
        auth_configured, model_name = await setup_auth_guided()

        # Show success screen
        success_screen(database_name, auth_configured, model_name)

        return True

    except KeyboardInterrupt:
        console.print("\n[yellow]Onboarding cancelled.[/yellow]")
        console.print(
            "[dim]You can run setup commands manually:[/dim]\n"
            "[dim] - saber db add <name> # Add database connection[/dim]\n"
            "[dim] - saber auth setup # Configure authentication[/dim]"
        )
        sys.exit(0)
    except Exception as e:
        # Escape the exception text so brackets in it are not parsed as markup.
        console.print(f"[bold red]Onboarding failed:[/bold red] {escape(str(e))}")
        return False
sqlsaber/cli/streaming.py CHANGED
@@ -170,7 +170,7 @@ class StreamingQueryHandler:
170
170
  except asyncio.CancelledError:
171
171
  # Show interruption message outside of Live
172
172
  self.display.show_newline()
173
- self.console.print("[yellow]Query interrupted[/yellow]")
173
+ self.console.print("[warning]Query interrupted[/warning]")
174
174
  return None
175
175
  finally:
176
176
  # End any active status and live markdown segments
sqlsaber/cli/threads.py CHANGED
@@ -12,10 +12,12 @@ from rich.markdown import Markdown
12
12
  from rich.panel import Panel
13
13
  from rich.table import Table
14
14
 
15
+ from sqlsaber.theme.manager import create_console, get_theme_manager
15
16
  from sqlsaber.threads import ThreadStorage
16
17
 
17
18
  # Globals consistent with other CLI modules
18
- console = Console()
19
+ console = create_console()
20
+ tm = get_theme_manager()
19
21
 
20
22
 
21
23
  threads_app = cyclopts.App(
@@ -84,13 +86,23 @@ def _render_transcript(
84
86
  console.print(f"**User:**\n\n{text}\n")
85
87
  else:
86
88
  console.print(
87
- Panel.fit(Markdown(text), title="User", border_style="cyan")
89
+ Panel.fit(
90
+ Markdown(text, code_theme=tm.pygments_style_name),
91
+ title="User",
92
+ border_style=tm.style("panel.border.user"),
93
+ )
88
94
  )
89
95
  return
90
96
  if is_redirected:
91
97
  console.print("**User:** (no content)\n")
92
98
  else:
93
- console.print(Panel.fit("(no content)", title="User", border_style="cyan"))
99
+ console.print(
100
+ Panel.fit(
101
+ "(no content)",
102
+ title="User",
103
+ border_style=tm.style("panel.border.user"),
104
+ )
105
+ )
94
106
 
95
107
  def _render_response(message: ModelMessage) -> None:
96
108
  for part in getattr(message, "parts", []):
@@ -103,7 +115,9 @@ def _render_transcript(
103
115
  else:
104
116
  console.print(
105
117
  Panel.fit(
106
- Markdown(text), title="Assistant", border_style="green"
118
+ Markdown(text, code_theme=tm.pygments_style_name),
119
+ title="Assistant",
120
+ border_style=tm.style("panel.border.assistant"),
107
121
  )
108
122
  )
109
123
  elif kind in ("tool-call", "builtin-tool-call"):
@@ -211,11 +225,11 @@ def list_threads(
211
225
  console.print("No threads found.")
212
226
  return
213
227
  table = Table(title="Threads")
214
- table.add_column("ID", style="cyan")
215
- table.add_column("Database", style="magenta")
216
- table.add_column("Title", style="green")
217
- table.add_column("Last Activity", style="dim")
218
- table.add_column("Model", style="yellow")
228
+ table.add_column("ID", style=tm.style("info"))
229
+ table.add_column("Database", style=tm.style("accent"))
230
+ table.add_column("Title", style=tm.style("success"))
231
+ table.add_column("Last Activity", style=tm.style("muted"))
232
+ table.add_column("Model", style=tm.style("warning"))
219
233
  for t in threads:
220
234
  table.add_row(
221
235
  t.id,
@@ -235,7 +249,7 @@ def show(
235
249
  store = ThreadStorage()
236
250
  thread = asyncio.run(store.get_thread(thread_id))
237
251
  if not thread:
238
- console.print(f"[red]Thread not found:[/red] {thread_id}")
252
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
239
253
  return
240
254
  msgs = asyncio.run(store.get_thread_messages(thread_id))
241
255
  console.print(f"[bold]Thread: {thread.id}[/bold]")
@@ -273,12 +287,12 @@ def resume(
273
287
 
274
288
  thread = await store.get_thread(thread_id)
275
289
  if not thread:
276
- console.print(f"[red]Thread not found:[/red] {thread_id}")
290
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
277
291
  return
278
292
  db_selector = database or thread.database_name
279
293
  if not db_selector:
280
294
  console.print(
281
- "[red]No database specified or stored with this thread.[/red]"
295
+ "[error]No database specified or stored with this thread.[/error]"
282
296
  )
283
297
  return
284
298
  try:
@@ -287,7 +301,7 @@ def resume(
287
301
  connection_string = resolved.connection_string
288
302
  db_name = resolved.name
289
303
  except DatabaseResolutionError as e:
290
- console.print(f"[red]Database resolution error:[/red] {e}")
304
+ console.print(f"[error]Database resolution error:[/error] {e}")
291
305
  return
292
306
 
293
307
  db_conn = DatabaseConnection(connection_string)
@@ -295,7 +309,12 @@ def resume(
295
309
  sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
296
310
  history = await store.get_thread_messages(thread_id)
297
311
  if console.is_terminal:
298
- console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
312
+ console.print(
313
+ Panel.fit(
314
+ f"Thread: {thread.id}",
315
+ border_style=tm.style("panel.border.thread"),
316
+ )
317
+ )
299
318
  else:
300
319
  console.print(f"# Thread: {thread.id}\n")
301
320
  _render_transcript(console, history, None)
@@ -310,7 +329,7 @@ def resume(
310
329
  await session.run()
311
330
  finally:
312
331
  await db_conn.close()
313
- console.print("\n[green]Goodbye![/green]")
332
+ console.print("\n[success]Goodbye![/success]")
314
333
 
315
334
  asyncio.run(_run())
316
335
 
@@ -329,7 +348,7 @@ def prune(
329
348
 
330
349
  async def _run() -> None:
331
350
  deleted = await store.prune_threads(older_than_days=days)
332
- console.print(f"[green]✓ Pruned {deleted} thread(s).[/green]")
351
+ console.print(f"[success]✓ Pruned {deleted} thread(s).[/success]")
333
352
 
334
353
  asyncio.run(_run())
335
354
 
@@ -4,11 +4,11 @@ import getpass
4
4
  import os
5
5
 
6
6
  import keyring
7
- from rich.console import Console
8
7
 
9
8
  from sqlsaber.config import providers
9
+ from sqlsaber.theme.manager import create_console
10
10
 
11
- console = Console()
11
+ console = create_console()
12
12
 
13
13
 
14
14
  class APIKeyManager:
@@ -19,7 +19,7 @@ class APIKeyManager:
19
19
 
20
20
  def get_api_key(self, provider: str) -> str | None:
21
21
  """Get API key for the specified provider using cascading logic."""
22
- env_var_name = self._get_env_var_name(provider)
22
+ env_var_name = self.get_env_var_name(provider)
23
23
  service_name = self._get_service_name(provider)
24
24
 
25
25
  # 1. Check environment variable first
@@ -41,7 +41,7 @@ class APIKeyManager:
41
41
  # 3. Prompt user for API key
42
42
  return self._prompt_and_store_key(provider, env_var_name, service_name)
43
43
 
44
- def _get_env_var_name(self, provider: str) -> str:
44
+ def get_env_var_name(self, provider: str) -> str:
45
45
  """Get the expected environment variable name for a provider."""
46
46
  # Normalize aliases to canonical provider keys
47
47
  key = providers.canonical(provider) or provider
@@ -10,12 +10,13 @@ from datetime import datetime, timezone
10
10
 
11
11
  import httpx
12
12
  import questionary
13
- from rich.console import Console
14
13
  from rich.progress import Progress, SpinnerColumn, TextColumn
15
14
 
15
+ from sqlsaber.theme.manager import create_console
16
+
16
17
  from .oauth_tokens import OAuthToken, OAuthTokenManager
17
18
 
18
- console = Console()
19
+ console = create_console()
19
20
  logger = logging.getLogger(__name__)
20
21
 
21
22
 
@@ -6,9 +6,10 @@ from datetime import datetime, timedelta, timezone
6
6
  from typing import Any
7
7
 
8
8
  import keyring
9
- from rich.console import Console
10
9
 
11
- console = Console()
10
+ from sqlsaber.theme.manager import create_console
11
+
12
+ console = create_console()
12
13
  logger = logging.getLogger(__name__)
13
14
 
14
15
 
@@ -158,9 +159,6 @@ class OAuthTokenManager:
158
159
  keyring.delete_password(service_name, provider)
159
160
  console.print(f"OAuth token for {provider} removed", style="green")
160
161
  return True
161
- except keyring.errors.PasswordDeleteError:
162
- # Token doesn't exist
163
- return True
164
162
  except Exception as e:
165
163
  logger.error(f"Failed to remove OAuth token for {provider}: {e}")
166
164
  console.print(f"Warning: Could not remove OAuth token: {e}", style="yellow")
sqlsaber/database/base.py CHANGED
@@ -61,6 +61,12 @@ class BaseDatabaseConnection(ABC):
61
61
  self.connection_string = connection_string
62
62
  self._pool = None
63
63
 
64
+ @property
65
+ @abstractmethod
66
+ def sqlglot_dialect(self) -> str:
67
+ """Return the sqlglot dialect name for this database."""
68
+ pass
69
+
64
70
  @abstractmethod
65
71
  async def get_pool(self):
66
72
  """Get or create connection pool."""
sqlsaber/database/csv.py CHANGED
@@ -58,6 +58,11 @@ class CSVConnection(BaseDatabaseConnection):
58
58
 
59
59
  self.table_name = Path(self.csv_path).stem or "csv_table"
60
60
 
61
+ @property
62
+ def sqlglot_dialect(self) -> str:
63
+ """Return the sqlglot dialect name."""
64
+ return "duckdb"
65
+
61
66
  async def get_pool(self):
62
67
  """CSV connections do not maintain a pool."""
63
68
  return None
@@ -52,6 +52,11 @@ class DuckDBConnection(BaseDatabaseConnection):
52
52
 
53
53
  self.database_path = db_path or ":memory:"
54
54
 
55
+ @property
56
+ def sqlglot_dialect(self) -> str:
57
+ """Return the sqlglot dialect name."""
58
+ return "duckdb"
59
+
55
60
  async def get_pool(self):
56
61
  """DuckDB creates connections per query, return database path."""
57
62
  return self.database_path
@@ -23,6 +23,11 @@ class MySQLConnection(BaseDatabaseConnection):
23
23
  self._pool: aiomysql.Pool | None = None
24
24
  self._parse_connection_string()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "mysql"
30
+
26
31
  def _parse_connection_string(self):
27
32
  """Parse MySQL connection string into components."""
28
33
  parsed = urlparse(self.connection_string)
@@ -23,6 +23,11 @@ class PostgreSQLConnection(BaseDatabaseConnection):
23
23
  self._pool: asyncpg.Pool | None = None
24
24
  self._ssl_context = self._create_ssl_context()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "postgres"
30
+
26
31
  def _create_ssl_context(self) -> ssl.SSLContext | None:
27
32
  """Create SSL context from connection string parameters."""
28
33
  parsed = urlparse(self.connection_string)
@@ -21,6 +21,11 @@ class SQLiteConnection(BaseDatabaseConnection):
21
21
  # Extract database path from sqlite:///path format
22
22
  self.database_path = connection_string.replace("sqlite:///", "")
23
23
 
24
+ @property
25
+ def sqlglot_dialect(self) -> str:
26
+ """Return the sqlglot dialect name."""
27
+ return "sqlite"
28
+
24
29
  async def get_pool(self):
25
30
  """SQLite doesn't use connection pooling, return database path."""
26
31
  return self.database_path
@@ -0,0 +1,5 @@
1
+ """Theme management for SQLSaber."""
2
+
3
+ from sqlsaber.theme.manager import ThemeManager, create_console, get_theme_manager
4
+
5
+ __all__ = ["ThemeManager", "create_console", "get_theme_manager"]