sqlsaber 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlsaber/cli/auth.py CHANGED
@@ -5,16 +5,15 @@ import os
5
5
  import cyclopts
6
6
  import keyring
7
7
  import questionary
8
- from rich.console import Console
9
8
 
10
9
  from sqlsaber.config import providers
11
10
  from sqlsaber.config.api_keys import APIKeyManager
12
11
  from sqlsaber.config.auth import AuthConfigManager, AuthMethod
13
- from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
14
12
  from sqlsaber.config.oauth_tokens import OAuthTokenManager
13
+ from sqlsaber.theme.manager import create_console
15
14
 
16
15
  # Global instances for CLI commands
17
- console = Console()
16
+ console = create_console()
18
17
  config_manager = AuthConfigManager()
19
18
 
20
19
  # Create the authentication management CLI app
@@ -27,60 +26,33 @@ auth_app = cyclopts.App(
27
26
  @auth_app.command
28
27
  def setup():
29
28
  """Configure authentication for SQLsaber (API keys and Anthropic OAuth)."""
30
- console.print("\n[bold]SQLsaber Authentication Setup[/bold]\n")
29
+ import asyncio
31
30
 
32
- provider = questionary.select(
33
- "Select provider to configure:",
34
- choices=providers.all_keys(),
35
- ).ask()
31
+ from sqlsaber.application.auth_setup import setup_auth
32
+ from sqlsaber.application.prompts import AsyncPrompter
36
33
 
37
- if provider is None:
38
- console.print("[yellow]Setup cancelled.[/yellow]")
39
- return
34
+ console.print("\n[bold]SQLsaber Authentication Setup[/bold]\n")
40
35
 
41
- if provider == "anthropic":
42
- # Let user choose API key or OAuth
43
- method_choice = questionary.select(
44
- "Select Anthropic authentication method:",
45
- choices=[
46
- {"name": "API key", "value": AuthMethod.API_KEY},
47
- {"name": "Claude Pro/Max (OAuth)", "value": AuthMethod.CLAUDE_PRO},
48
- ],
49
- ).ask()
50
-
51
- if method_choice == AuthMethod.CLAUDE_PRO:
52
- flow = AnthropicOAuthFlow()
53
- if flow.authenticate():
54
- config_manager.set_auth_method(AuthMethod.CLAUDE_PRO)
55
- console.print(
56
- "\n[bold green]✓ Anthropic OAuth configured successfully![/bold green]"
57
- )
58
- else:
59
- console.print("\n[red]✗ Anthropic OAuth setup failed.[/red]")
60
- console.print(
61
- "You can change this anytime by running [cyan]saber auth setup[/cyan] again."
62
- )
63
- return
64
-
65
- # API key flow (all providers + Anthropic when selected above)
66
- api_key_manager = APIKeyManager()
67
- env_var = api_key_manager._get_env_var_name(provider)
68
- console.print("\nTo configure your API key, you can either:")
69
- console.print(f"• Set the {env_var} environment variable")
70
- console.print("• Let SQLsaber prompt you for the key when needed (stored securely)")
71
-
72
- # Fetch/store key (cascades env -> keyring -> prompt)
73
- api_key = api_key_manager.get_api_key(provider)
74
- if api_key:
75
- config_manager.set_auth_method(AuthMethod.API_KEY)
76
- console.print(
77
- f"\n[bold green]✓ {provider.title()} API key configured successfully![/bold green]"
36
+ async def run_setup():
37
+ prompter = AsyncPrompter()
38
+ api_key_manager = APIKeyManager()
39
+ success, provider = await setup_auth(
40
+ prompter=prompter,
41
+ auth_manager=config_manager,
42
+ api_key_manager=api_key_manager,
43
+ allow_oauth=True,
44
+ default_provider="anthropic",
45
+ run_oauth_in_thread=False,
78
46
  )
79
- else:
80
- console.print("\n[yellow]No API key configured.[/yellow]")
47
+ return success, provider
48
+
49
+ success, _ = asyncio.run(run_setup())
50
+
51
+ if not success:
52
+ console.print("\n[yellow]No authentication configured.[/yellow]")
81
53
 
82
54
  console.print(
83
- "You can change this anytime by running [cyan]saber auth setup[/cyan] again."
55
+ "\nYou can change this anytime by running [cyan]saber auth setup[/cyan] again."
84
56
  )
85
57
 
86
58
 
@@ -109,7 +81,7 @@ def status():
109
81
  # Include OAuth status
110
82
  if OAuthTokenManager().has_oauth_token("anthropic"):
111
83
  console.print("> anthropic (oauth): [green]configured[/green]")
112
- env_var = api_key_manager._get_env_var_name(provider)
84
+ env_var = api_key_manager.get_env_var_name(provider)
113
85
  service = api_key_manager._get_service_name(provider)
114
86
  from_env = bool(os.getenv(env_var))
115
87
  from_keyring = bool(keyring.get_password(service, provider))
sqlsaber/cli/commands.py CHANGED
@@ -5,16 +5,17 @@ import sys
5
5
  from typing import Annotated
6
6
 
7
7
  import cyclopts
8
- from rich.console import Console
9
8
 
10
9
  from sqlsaber.cli.auth import create_auth_app
11
10
  from sqlsaber.cli.database import create_db_app
12
11
  from sqlsaber.cli.memory import create_memory_app
13
12
  from sqlsaber.cli.models import create_models_app
13
+ from sqlsaber.cli.onboarding import needs_onboarding, run_onboarding
14
14
  from sqlsaber.cli.threads import create_threads_app
15
15
 
16
16
  # Lazy imports - only import what's needed for CLI parsing
17
17
  from sqlsaber.config.database import DatabaseConfigManager
18
+ from sqlsaber.theme.manager import create_console
18
19
 
19
20
 
20
21
  class CLIError(Exception):
@@ -36,7 +37,7 @@ app.command(create_memory_app(), name="memory")
36
37
  app.command(create_models_app(), name="models")
37
38
  app.command(create_threads_app(), name="threads")
38
39
 
39
- console = Console()
40
+ console = create_console()
40
41
  config_manager = DatabaseConfigManager()
41
42
 
42
43
 
@@ -128,6 +129,16 @@ def query(
128
129
  # If stdin was empty, fall back to interactive mode
129
130
  actual_query = None
130
131
 
132
+ # Check if onboarding is needed (only for interactive mode or when no database is configured)
133
+ if needs_onboarding(database):
134
+ # Run onboarding flow
135
+ onboarding_success = await run_onboarding()
136
+ if not onboarding_success:
137
+ # User cancelled or onboarding failed
138
+ raise CLIError(
139
+ "Setup incomplete. Please configure your database and try again."
140
+ )
141
+
131
142
  # Resolve database from CLI input
132
143
  try:
133
144
  resolved = resolve_database(database, config_manager)
sqlsaber/cli/database.py CHANGED
@@ -8,13 +8,13 @@ from typing import Annotated
8
8
 
9
9
  import cyclopts
10
10
  import questionary
11
- from rich.console import Console
12
11
  from rich.table import Table
13
12
 
14
13
  from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
14
+ from sqlsaber.theme.manager import create_console
15
15
 
16
16
  # Global instances for CLI commands
17
- console = Console()
17
+ console = create_console()
18
18
  config_manager = DatabaseConfigManager()
19
19
 
20
20
  # Create the database management CLI app
@@ -81,95 +81,34 @@ def add(
81
81
 
82
82
  if interactive:
83
83
  # Interactive mode - prompt for all required fields
84
- console.print(f"[bold]Adding database connection: {name}[/bold]")
84
+ from sqlsaber.application.db_setup import collect_db_input
85
+ from sqlsaber.application.prompts import AsyncPrompter
85
86
 
86
- # Database type
87
- if not type or type == "postgresql":
88
- type = questionary.select(
89
- "Database type:",
90
- choices=["postgresql", "mysql", "sqlite", "duckdb"],
91
- default="postgresql",
92
- ).ask()
93
-
94
- if type in {"sqlite", "duckdb"}:
95
- # SQLite/DuckDB only need database file path
96
- database = database or questionary.path("Database file path:").ask()
97
- database = str(Path(database).expanduser().resolve())
98
- host = "localhost"
99
- port = 0
100
- username = type
101
- password = ""
102
- else:
103
- # PostgreSQL/MySQL need connection details
104
- host = host or questionary.text("Host:", default="localhost").ask()
87
+ console.print(f"[bold]Adding database connection: {name}[/bold]")
105
88
 
106
- default_port = 5432 if type == "postgresql" else 3306
107
- port = port or int(
108
- questionary.text("Port:", default=str(default_port)).ask()
89
+ async def collect_input():
90
+ prompter = AsyncPrompter()
91
+ return await collect_db_input(
92
+ prompter=prompter, name=name, db_type=type, include_ssl=True
109
93
  )
110
94
 
111
- database = database or questionary.text("Database name:").ask()
112
- username = username or questionary.text("Username:").ask()
113
-
114
- # Ask for password
115
- password = getpass.getpass("Password (stored in your OS keychain): ")
116
-
117
- # Ask for SSL configuration
118
- if questionary.confirm("Configure SSL/TLS settings?", default=False).ask():
119
- if type == "postgresql":
120
- ssl_mode = (
121
- ssl_mode
122
- or questionary.select(
123
- "SSL mode for PostgreSQL:",
124
- choices=[
125
- "disable",
126
- "allow",
127
- "prefer",
128
- "require",
129
- "verify-ca",
130
- "verify-full",
131
- ],
132
- default="prefer",
133
- ).ask()
134
- )
135
- elif type == "mysql":
136
- ssl_mode = (
137
- ssl_mode
138
- or questionary.select(
139
- "SSL mode for MySQL:",
140
- choices=[
141
- "DISABLED",
142
- "PREFERRED",
143
- "REQUIRED",
144
- "VERIFY_CA",
145
- "VERIFY_IDENTITY",
146
- ],
147
- default="PREFERRED",
148
- ).ask()
149
- )
150
-
151
- if ssl_mode and ssl_mode not in ["disable", "DISABLED"]:
152
- if questionary.confirm(
153
- "Specify SSL certificate files?", default=False
154
- ).ask():
155
- ssl_ca = (
156
- ssl_ca or questionary.path("SSL CA certificate file:").ask()
157
- )
158
- if questionary.confirm(
159
- "Specify client certificate?", default=False
160
- ).ask():
161
- ssl_cert = (
162
- ssl_cert
163
- or questionary.path(
164
- "SSL client certificate file:"
165
- ).ask()
166
- )
167
- ssl_key = (
168
- ssl_key
169
- or questionary.path(
170
- "SSL client private key file:"
171
- ).ask()
172
- )
95
+ db_input = asyncio.run(collect_input())
96
+
97
+ if db_input is None:
98
+ console.print("[yellow]Operation cancelled[/yellow]")
99
+ return
100
+
101
+ # Extract values from db_input
102
+ type = db_input.type
103
+ host = db_input.host
104
+ port = db_input.port
105
+ database = db_input.database
106
+ username = db_input.username
107
+ password = db_input.password
108
+ ssl_mode = db_input.ssl_mode
109
+ ssl_ca = db_input.ssl_ca
110
+ ssl_cert = db_input.ssl_cert
111
+ ssl_key = db_input.ssl_key
173
112
  else:
174
113
  # Non-interactive mode - use provided values or defaults
175
114
  if type == "sqlite":
sqlsaber/cli/display.py CHANGED
@@ -19,6 +19,8 @@ from rich.syntax import Syntax
19
19
  from rich.table import Table
20
20
  from rich.text import Text
21
21
 
22
+ from sqlsaber.theme.manager import get_theme_manager
23
+
22
24
 
23
25
  class _SimpleCodeBlock(CodeBlock):
24
26
  def __rich_console__(
@@ -46,6 +48,7 @@ class LiveMarkdownRenderer:
46
48
 
47
49
  def __init__(self, console: Console):
48
50
  self.console = console
51
+ self.tm = get_theme_manager()
49
52
  self._live: Live | None = None
50
53
  self._status_live: Live | None = None
51
54
  self._buffer: str = ""
@@ -90,10 +93,14 @@ class LiveMarkdownRenderer:
90
93
 
91
94
  # Apply dim styling for thinking segments
92
95
  if self._current_kind == ThinkingPart:
93
- content = Markdown(self._buffer, style="dim")
96
+ content = Markdown(
97
+ self._buffer, style="muted", code_theme=self.tm.pygments_style_name
98
+ )
94
99
  self._live.update(content)
95
100
  else:
96
- self._live.update(Markdown(self._buffer))
101
+ self._live.update(
102
+ Markdown(self._buffer, code_theme=self.tm.pygments_style_name)
103
+ )
97
104
 
98
105
  def end(self) -> None:
99
106
  """Finalize and stop the current Live segment, if any."""
@@ -109,9 +116,13 @@ class LiveMarkdownRenderer:
109
116
  # Print the complete markdown to scroll-back for permanent reference
110
117
  if buf:
111
118
  if kind == ThinkingPart:
112
- self.console.print(Text(buf, style="dim"))
119
+ self.console.print(
120
+ Markdown(buf, style="muted", code_theme=self.tm.pygments_style_name)
121
+ )
113
122
  else:
114
- self.console.print(Markdown(buf))
123
+ self.console.print(
124
+ Markdown(buf, code_theme=self.tm.pygments_style_name)
125
+ )
115
126
 
116
127
  def end_if_active(self) -> None:
117
128
  self.end()
@@ -129,7 +140,7 @@ class LiveMarkdownRenderer:
129
140
  self._buffer = f"```sql\n{sql}\n```"
130
141
  # Use context manager to auto-stop and persist final render
131
142
  with Live(
132
- Markdown(self._buffer),
143
+ Markdown(self._buffer, code_theme=self.tm.pygments_style_name),
133
144
  console=self.console,
134
145
  vertical_overflow="visible",
135
146
  refresh_per_second=12,
@@ -159,8 +170,8 @@ class LiveMarkdownRenderer:
159
170
  self._status_live = None
160
171
 
161
172
  def _status_renderable(self, message: str):
162
- spinner = Spinner("dots", style="yellow")
163
- text = Text(f" {message}", style="yellow")
173
+ spinner = Spinner("dots", style=self.tm.style("spinner"))
174
+ text = Text(f" {message}", style=self.tm.style("status"))
164
175
  return Columns([spinner, text], expand=False)
165
176
 
166
177
  def _start(
@@ -173,14 +184,14 @@ class LiveMarkdownRenderer:
173
184
  # Add visual styling for thinking segments
174
185
  if kind == ThinkingPart:
175
186
  if self.console.is_terminal:
176
- self.console.print("[dim]💭 Thinking...[/dim]")
187
+ self.console.print("[muted]💭 Thinking...[/muted]")
177
188
  else:
178
189
  self.console.print("*Thinking...*\n")
179
190
 
180
191
  # NOTE: Use transient=True so the live widget disappears on exit,
181
192
  # giving a clean transition to the final printed result.
182
193
  live = Live(
183
- Markdown(self._buffer),
194
+ Markdown(self._buffer, code_theme=self.tm.pygments_style_name),
184
195
  console=self.console,
185
196
  transient=True,
186
197
  refresh_per_second=12,
@@ -195,14 +206,16 @@ class DisplayManager:
195
206
  def __init__(self, console: Console):
196
207
  self.console = console
197
208
  self.live = LiveMarkdownRenderer(console)
209
+ self.tm = get_theme_manager()
198
210
 
199
211
  def _create_table(
200
212
  self,
201
213
  columns: Sequence[str | dict[str, str]],
202
- header_style: str = "bold blue",
214
+ header_style: str | None = None,
203
215
  title: str | None = None,
204
216
  ) -> Table:
205
217
  """Create a Rich table with specified columns."""
218
+ header_style = header_style or self.tm.style("table.header")
206
219
  table = Table(show_header=True, header_style=header_style, title=title)
207
220
  for col in columns:
208
221
  if isinstance(col, dict):
@@ -220,7 +233,7 @@ class DisplayManager:
220
233
  if tool_name == "list_tables":
221
234
  if self.console.is_terminal:
222
235
  self.console.print(
223
- "[dim bold]:gear: Discovering available tables[/dim bold]"
236
+ "[muted bold]:gear: Discovering available tables[/muted bold]"
224
237
  )
225
238
  else:
226
239
  self.console.print("**Discovering available tables**\n")
@@ -228,7 +241,7 @@ class DisplayManager:
228
241
  pattern = tool_input.get("table_pattern", "all tables")
229
242
  if self.console.is_terminal:
230
243
  self.console.print(
231
- f"[dim bold]:gear: Examining schema for: {pattern}[/dim bold]"
244
+ f"[muted bold]:gear: Examining schema for: {pattern}[/muted bold]"
232
245
  )
233
246
  else:
234
247
  self.console.print(f"**Examining schema for:** {pattern}\n")
@@ -237,10 +250,14 @@ class DisplayManager:
237
250
  # rendering for threads show/resume. Controlled by include_sql flag.
238
251
  query = tool_input.get("query", "")
239
252
  if self.console.is_terminal:
240
- self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
253
+ self.console.print("[muted bold]:gear: Executing SQL:[/muted bold]")
241
254
  self.show_newline()
242
255
  syntax = Syntax(
243
- query, "sql", background_color="default", word_wrap=True
256
+ query,
257
+ "sql",
258
+ theme=self.tm.pygments_style_name,
259
+ background_color="default",
260
+ word_wrap=True,
244
261
  )
245
262
  self.console.print(syntax)
246
263
  else:
@@ -258,9 +275,7 @@ class DisplayManager:
258
275
  return
259
276
 
260
277
  if self.console.is_terminal:
261
- self.console.print(
262
- f"\n[bold magenta]Results ({len(results)} rows):[/bold magenta]"
263
- )
278
+ self.console.print(f"\n[section]Results ({len(results)} rows):[/section]")
264
279
  else:
265
280
  self.console.print(f"\n**Results ({len(results)} rows):**\n")
266
281
 
@@ -272,7 +287,7 @@ class DisplayManager:
272
287
  if len(all_columns) > 15:
273
288
  if self.console.is_terminal:
274
289
  self.console.print(
275
- f"[yellow]Note: Showing first 15 of {len(all_columns)} columns[/yellow]"
290
+ f"[warning]Note: Showing first 15 of {len(all_columns)} columns[/warning]"
276
291
  )
277
292
  else:
278
293
  self.console.print(
@@ -290,21 +305,21 @@ class DisplayManager:
290
305
  if len(results) > 20:
291
306
  if self.console.is_terminal:
292
307
  self.console.print(
293
- f"[yellow]... and {len(results) - 20} more rows[/yellow]"
308
+ f"[warning]... and {len(results) - 20} more rows[/warning]"
294
309
  )
295
310
  else:
296
311
  self.console.print(f"*... and {len(results) - 20} more rows*\n")
297
312
 
298
313
  def show_error(self, error_message: str):
299
314
  """Display error message."""
300
- self.console.print(f"\n[bold red]Error:[/bold red] {error_message}")
315
+ self.console.print(f"\n[error]Error:[/error] {error_message}")
301
316
 
302
317
  def show_sql_error(self, error_message: str, suggestions: list[str] | None = None):
303
318
  """Display SQL-specific error with optional suggestions."""
304
319
  self.show_newline()
305
- self.console.print(f"[bold red]SQL error:[/bold red] {error_message}")
320
+ self.console.print(f"[error]SQL error:[/error] {error_message}")
306
321
  if suggestions:
307
- self.console.print("[yellow]Hints:[/yellow]")
322
+ self.console.print("[warning]Hints:[/warning]")
308
323
  for suggestion in suggestions:
309
324
  self.console.print(f" • {suggestion}")
310
325
 
@@ -312,7 +327,7 @@ class DisplayManager:
312
327
  """Display processing message."""
313
328
  self.console.print() # Add newline
314
329
  return self.console.status(
315
- f"[yellow]{message}[/yellow]", spinner="bouncingBall"
330
+ f"[status]{message}[/status]", spinner="bouncingBall"
316
331
  )
317
332
 
318
333
  def show_newline(self):
@@ -335,18 +350,20 @@ class DisplayManager:
335
350
  total_tables = data.get("total_tables", 0)
336
351
 
337
352
  if not tables:
338
- self.console.print("[yellow]No tables found in the database.[/yellow]")
353
+ self.console.print(
354
+ "[warning]No tables found in the database.[/warning]"
355
+ )
339
356
  return
340
357
 
341
358
  self.console.print(
342
- f"\n[bold green]Database Tables ({total_tables} total):[/bold green]"
359
+ f"\n[title]Database Tables ({total_tables} total):[/title]"
343
360
  )
344
361
 
345
362
  # Create a rich table for displaying table information
346
363
  columns = [
347
- {"name": "Schema", "style": "cyan"},
348
- {"name": "Table Name", "style": "white"},
349
- {"name": "Type", "style": "yellow"},
364
+ {"name": "Schema", "style": "column.schema"},
365
+ {"name": "Table Name", "style": "column.name"},
366
+ {"name": "Type", "style": "column.type"},
350
367
  ]
351
368
  table = self._create_table(columns)
352
369
 
@@ -378,26 +395,26 @@ class DisplayManager:
378
395
  return
379
396
 
380
397
  if not data:
381
- self.console.print("[yellow]No schema information found.[/yellow]")
398
+ self.console.print("[warning]No schema information found.[/warning]")
382
399
  return
383
400
 
384
401
  self.console.print(
385
- f"\n[bold green]Schema Information ({len(data)} tables):[/bold green]"
402
+ f"\n[title]Schema Information ({len(data)} tables):[/title]"
386
403
  )
387
404
 
388
405
  # Display each table's schema
389
406
  for table_name, table_info in data.items():
390
- self.console.print(f"\n[bold cyan]Table: {table_name}[/bold cyan]")
407
+ self.console.print(f"\n[heading]Table: {table_name}[/heading]")
391
408
 
392
409
  # Show columns
393
410
  table_columns = table_info.get("columns", {})
394
411
  if table_columns:
395
412
  # Create a table for columns
396
413
  columns = [
397
- {"name": "Column Name", "style": "white"},
398
- {"name": "Type", "style": "yellow"},
399
- {"name": "Nullable", "style": "cyan"},
400
- {"name": "Default", "style": "dim"},
414
+ {"name": "Column Name", "style": "column.name"},
415
+ {"name": "Type", "style": "column.type"},
416
+ {"name": "Nullable", "style": "info"},
417
+ {"name": "Default", "style": "muted"},
401
418
  ]
402
419
  col_table = self._create_table(columns, title="Columns")
403
420
 
@@ -418,20 +435,20 @@ class DisplayManager:
418
435
  primary_keys = table_info.get("primary_keys", [])
419
436
  if primary_keys:
420
437
  self.console.print(
421
- f"[bold yellow]Primary Keys:[/bold yellow] {', '.join(primary_keys)}"
438
+ f"[key.primary]Primary Keys:[/key.primary] {', '.join(primary_keys)}"
422
439
  )
423
440
 
424
441
  # Show foreign keys
425
442
  foreign_keys = table_info.get("foreign_keys", [])
426
443
  if foreign_keys:
427
- self.console.print("[bold magenta]Foreign Keys:[/bold magenta]")
444
+ self.console.print("[key.foreign]Foreign Keys:[/key.foreign]")
428
445
  for fk in foreign_keys:
429
446
  self.console.print(f" • {fk}")
430
447
 
431
448
  # Show indexes
432
449
  indexes = table_info.get("indexes", [])
433
450
  if indexes:
434
- self.console.print("[bold blue]Indexes:[/bold blue]")
451
+ self.console.print("[key.index]Indexes:[/key.index]")
435
452
  for idx in indexes:
436
453
  self.console.print(f" • {idx}")
437
454
 
@@ -457,7 +474,9 @@ class DisplayManager:
457
474
  full_text = "".join(text_parts).strip()
458
475
  if full_text:
459
476
  self.console.print() # Add spacing before panel
460
- markdown = Markdown(full_text)
461
- panel = Panel.fit(markdown, border_style="green")
477
+ markdown = Markdown(full_text, code_theme=self.tm.pygments_style_name)
478
+ panel = Panel.fit(
479
+ markdown, border_style=self.tm.style("panel.border.assistant")
480
+ )
462
481
  self.console.print(panel)
463
482
  self.console.print() # Add spacing after panel