sqlsaber 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +222 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +24 -52
- sqlsaber/cli/commands.py +13 -2
- sqlsaber/cli/database.py +26 -87
- sqlsaber/cli/display.py +59 -40
- sqlsaber/cli/interactive.py +138 -131
- sqlsaber/cli/memory.py +2 -2
- sqlsaber/cli/models.py +20 -30
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +1 -1
- sqlsaber/cli/threads.py +35 -16
- sqlsaber/config/api_keys.py +4 -4
- sqlsaber/config/oauth_flow.py +3 -2
- sqlsaber/config/oauth_tokens.py +3 -5
- sqlsaber/database/base.py +6 -0
- sqlsaber/database/csv.py +5 -0
- sqlsaber/database/duckdb.py +5 -0
- sqlsaber/database/mysql.py +5 -0
- sqlsaber/database/postgresql.py +5 -0
- sqlsaber/database/sqlite.py +5 -0
- sqlsaber/theme/__init__.py +5 -0
- sqlsaber/theme/manager.py +219 -0
- sqlsaber/tools/sql_guard.py +225 -0
- sqlsaber/tools/sql_tools.py +10 -35
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/METADATA +2 -1
- sqlsaber-0.28.0.dist-info/RECORD +61 -0
- sqlsaber-0.26.0.dist-info/RECORD +0 -52
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""Interactive onboarding flow for first-time SQLSaber users."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from rich.panel import Panel
|
|
6
|
+
|
|
7
|
+
from sqlsaber.cli.models import ModelManager
|
|
8
|
+
from sqlsaber.config.api_keys import APIKeyManager
|
|
9
|
+
from sqlsaber.config.auth import AuthConfigManager
|
|
10
|
+
from sqlsaber.config.database import DatabaseConfigManager
|
|
11
|
+
from sqlsaber.theme.manager import create_console
|
|
12
|
+
|
|
13
|
+
# Shared Rich console used by every onboarding screen in this module; it is
# created through the theme manager's create_console() helper.
console = create_console()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def needs_onboarding(database_arg: str | None = None) -> bool:
|
|
17
|
+
"""Check if user needs onboarding.
|
|
18
|
+
|
|
19
|
+
Onboarding is needed if:
|
|
20
|
+
- No database is configured AND no database connection string provided via CLI
|
|
21
|
+
"""
|
|
22
|
+
# If user provided a database argument, skip onboarding
|
|
23
|
+
if database_arg:
|
|
24
|
+
return False
|
|
25
|
+
|
|
26
|
+
# Check if databases are configured
|
|
27
|
+
db_manager = DatabaseConfigManager()
|
|
28
|
+
has_db = db_manager.has_databases()
|
|
29
|
+
|
|
30
|
+
return not has_db
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def welcome_screen() -> None:
    """Display welcome screen to new users.

    Prints the SQLsaber ASCII-art banner and an introductory panel to the
    module-level ``console``. Nothing is returned.
    """
    # ASCII-art wordmark. Panel.fit (instead of Panel) keeps the panel sized to
    # the art rather than expanding to the full terminal width.
    banner = """
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
▀▀
"""

    console.print(Panel.fit(banner, style="bold blue"))
    console.print()

    # Contains Rich markup ([bold]...[/bold]) that the console renders.
    welcome_message = """
[bold]Welcome to SQLsaber! 🎉[/bold]

SQLsaber is an agentic SQL assistant that lets you query your database using natural language.

Let's get you set up in just a few steps.
"""

    # strip() removes the leading/trailing newlines of the triple-quoted literal
    # so the panel has no empty first/last line.
    console.print(Panel(welcome_message.strip(), border_style="blue", padding=(1, 2)))
    console.print()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def setup_database_guided() -> str | None:
    """Guide the user through configuring their first database connection.

    The flow runs as follows:

    1. Prompt for a connection name.
    2. Collect PostgreSQL connection details. SSL options are skipped during
       onboarding.
    3. Test the connection and, on success, save it.

    If the connection test fails, the user may retry with different settings.
    Retries are handled by a loop rather than recursion. The step header is
    therefore printed once, and repeated failures do not grow the call stack.

    Returns:
        The name of the configured connection. If the chosen name already
        exists, that existing name is returned without testing or saving.
        Returns None if the user cancelled, declined to retry, or an error
        occurred.
    """
    from sqlsaber.application.db_setup import (
        build_config,
        collect_db_input,
        save_database,
        test_connection,
    )
    from sqlsaber.application.prompts import AsyncPrompter

    console.print("━" * 80, style="dim")
    console.print("[bold cyan]Step 1 of 2: Database Connection[/bold cyan]")
    console.print("━" * 80, style="dim")
    console.print()

    try:
        prompter = AsyncPrompter()
        db_manager = DatabaseConfigManager()

        while True:
            # Ask for connection name
            name = await prompter.text(
                "What would you like to name this connection?",
                default="mydb",
                validate=lambda x: bool(x.strip()) or "Name cannot be empty",
            )
            # A None answer is treated as the user backing out of the prompt.
            if name is None:
                return None

            name = name.strip()

            # Reuse an existing connection with the same name as-is.
            if db_manager.get_database(name):
                console.print(
                    f"[yellow]Database connection '{name}' already exists.[/yellow]"
                )
                return name

            # Collect database input (simplified - no SSL in onboarding)
            db_input = await collect_db_input(
                prompter=prompter, name=name, db_type="postgresql", include_ssl=False
            )
            if db_input is None:
                return None

            db_config = build_config(db_input)

            console.print(f"[dim]Testing connection to '{name}'...[/dim]")
            if await test_connection(db_config, db_input.password):
                break

            retry = await prompter.confirm(
                "Would you like to try again with different settings?", default=True
            )
            if not retry:
                console.print(
                    "[yellow]You can add a database later using 'saber db add'[/yellow]"
                )
                return None
            # Visual separation before re-prompting from the connection name.
            console.print()

        # Save the configuration (reached only after a successful test).
        try:
            save_database(db_manager, db_config, db_input.password)
        except Exception as e:
            console.print(f"[bold red]Error saving database:[/bold red] {e}")
            return None

        console.print(f"[green]✓ Connection to '{name}' successful![/green]")
        console.print()
        return name

    except KeyboardInterrupt:
        console.print("\n[yellow]Setup cancelled.[/yellow]")
        return None
    except Exception as e:
        console.print(f"[bold red]Unexpected error:[/bold red] {e}")
        return None
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _default_model_for_provider(provider: str) -> str:
    """Return the fallback model ID used when interactive selection is unavailable.

    If *provider* has an entry in ``ModelManager.RECOMMENDED_MODELS``, the
    result is ``"<provider>:<recommended model>"``. Otherwise it is
    ``ModelManager.DEFAULT_MODEL``.
    """
    if provider in ModelManager.RECOMMENDED_MODELS:
        return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
    return ModelManager.DEFAULT_MODEL


async def select_model_for_provider(provider: str) -> str | None:
    """Fetch and let user select a model for the given provider.

    Args:
        provider: Provider key, such as the one returned by the auth setup flow.

    Returns:
        One of the following:

        - the value returned by ``choose_model``, i.e. the model the user
          picked (presumably None if they back out of the picker; unverified);
        - the provider's fallback model when no models could be fetched, or
          when fetching/selection raised an error;
        - None on KeyboardInterrupt.
    """
    from sqlsaber.application.model_selection import choose_model, fetch_models
    from sqlsaber.application.prompts import AsyncPrompter

    try:
        console.print()
        console.print(f"[dim]Fetching available {provider.title()} models...[/dim]")

        models = await fetch_models(ModelManager(), providers=[provider])

        if not models:
            console.print(
                f"[yellow]Could not fetch models for {provider}. Using default.[/yellow]"
            )
            return _default_model_for_provider(provider)

        prompter = AsyncPrompter()
        console.print()
        return await choose_model(
            prompter, models, restrict_provider=provider, use_search_filter=True
        )

    except KeyboardInterrupt:
        console.print("\n[yellow]Model selection cancelled.[/yellow]")
        return None
    except Exception as e:
        console.print(f"[yellow]Error selecting model: {e}. Using default.[/yellow]")
        return _default_model_for_provider(provider)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
async def setup_auth_guided() -> tuple[bool, str | None]:
    """Guide user through auth setup.

    Runs the shared ``setup_auth`` flow with OAuth allowed and "anthropic"
    as the default provider. If that flow reports a provider, a model for it
    is obtained via ``select_model_for_provider`` and saved with
    ``ModelManager.set_model``.

    Returns:
        A tuple ``(success, selected_model)``.

        - ``success`` is True whenever authentication itself is configured.
          This holds even if no model was selected or the model preference
          could not be saved.
        - ``selected_model`` is the saved model ID, or None.
    """
    from sqlsaber.application.auth_setup import setup_auth
    from sqlsaber.application.prompts import AsyncPrompter

    console.print("━" * 80, style="dim")
    console.print("[bold cyan]Step 2 of 2: Authentication[/bold cyan]")
    console.print("━" * 80, style="dim")
    console.print()

    try:
        success, provider = await setup_auth(
            prompter=AsyncPrompter(),
            auth_manager=AuthConfigManager(),
            api_key_manager=APIKeyManager(),
            allow_oauth=True,
            default_provider="anthropic",
            run_oauth_in_thread=True,
        )

        if not success:
            console.print(
                "[yellow]You can set it up later using 'saber auth setup'[/yellow]"
            )
            console.print()
            return False, None

        # If auth configured but we don't know the provider (already configured case)
        if provider is None:
            console.print()
            return True, None

        selected_model = await select_model_for_provider(provider)
        if selected_model:
            try:
                ModelManager().set_model(selected_model)
            except Exception as e:
                # Authentication already succeeded; failing to persist the model
                # preference must not be reported as an auth failure.
                console.print(
                    f"[yellow]Could not save model selection: {e}[/yellow]"
                )
                console.print()
                return True, None
            console.print(f"[green]✓ Model set to: {selected_model}[/green]")
        console.print()
        return True, selected_model

    except KeyboardInterrupt:
        console.print("\n[yellow]Setup cancelled.[/yellow]")
        console.print()
        return False, None
    except Exception as e:
        console.print(f"[bold red]Unexpected error:[/bold red] {e}")
        console.print()
        return False, None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def success_screen(
    database_name: str | None, auth_configured: bool, model_name: str | None = None
) -> None:
    """Print the closing summary of the onboarding flow.

    Args:
        database_name: Name of the configured database connection, if any.
        auth_configured: Whether AI authentication was set up.
        model_name: Model chosen during auth setup, if any. It is listed only
            when ``auth_configured`` is true.

    The output consists of:

    - a "You're all set!" header;
    - one line per configured item, plus a warning for whichever of
      database/authentication is missing;
    - a note that the interactive session is starting.

    When neither a database nor authentication is configured, no status
    lines are printed.
    """
    divider = "━" * 80
    console.print(divider, style="dim")
    console.print("[bold green]You're all set! 🚀[/bold green]")
    console.print(divider, style="dim")
    console.print()

    # None entries stand for "nothing to show" and are skipped when printing.
    model_line = f"[green]✓ Model: {model_name}[/green]" if model_name else None
    status_lines: list[str | None] = []

    if database_name:
        status_lines.append(
            f"[green]✓ Database '{database_name}' connected and ready to use[/green]"
        )
        if auth_configured:
            status_lines += ["[green]✓ Authentication configured[/green]", model_line]
        else:
            status_lines.append(
                "[yellow]⚠ AI authentication not configured - you'll be prompted when needed[/yellow]"
            )
    elif auth_configured:
        status_lines = [
            "[green]✓ AI authentication configured[/green]",
            model_line,
            "[yellow]⚠ No database configured - you'll need to provide one via -d flag[/yellow]",
        ]

    for line in status_lines:
        if line is not None:
            console.print(line)

    console.print()
    console.print("[dim]Starting interactive session...[/dim]")
    console.print()
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
async def run_onboarding() -> bool:
    """Run the complete first-time setup flow.

    The steps are, in order:

    1. The welcome screen.
    2. Guided database setup.
    3. Guided authentication/model setup.
    4. The success summary.

    Returns:
        True once a database has been configured. Authentication is optional
        and may still be missing. False if database setup produced no
        connection or an unexpected error occurred.

    Note:
        A KeyboardInterrupt that propagates up to this function does not
        return. It prints manual setup hints and ends the process with
        ``sys.exit(0)``.
    """
    try:
        welcome_screen()

        database_name = await setup_database_guided()
        if database_name is None:
            # Onboarding cannot complete without a database connection.
            console.print("[yellow]Database setup is required to continue.[/yellow]")
            console.print(
                "[dim]You can also provide a connection string using: saber -d <connection-string>[/dim]"
            )
            return False

        auth_configured, model_name = await setup_auth_guided()
        success_screen(database_name, auth_configured, model_name)
    except KeyboardInterrupt:
        console.print("\n[yellow]Onboarding cancelled.[/yellow]")
        console.print(
            "[dim]You can run setup commands manually:[/dim]\n"
            "[dim] - saber db add <name> # Add database connection[/dim]\n"
            "[dim] - saber auth setup # Configure authentication[/dim]"
        )
        sys.exit(0)
    except Exception as e:
        console.print(f"[bold red]Onboarding failed:[/bold red] {e}")
        return False
    else:
        # Reached only when the try body finished without returning or raising.
        return True
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -170,7 +170,7 @@ class StreamingQueryHandler:
|
|
|
170
170
|
except asyncio.CancelledError:
|
|
171
171
|
# Show interruption message outside of Live
|
|
172
172
|
self.display.show_newline()
|
|
173
|
-
self.console.print("[
|
|
173
|
+
self.console.print("[warning]Query interrupted[/warning]")
|
|
174
174
|
return None
|
|
175
175
|
finally:
|
|
176
176
|
# End any active status and live markdown segments
|
sqlsaber/cli/threads.py
CHANGED
|
@@ -12,10 +12,12 @@ from rich.markdown import Markdown
|
|
|
12
12
|
from rich.panel import Panel
|
|
13
13
|
from rich.table import Table
|
|
14
14
|
|
|
15
|
+
from sqlsaber.theme.manager import create_console, get_theme_manager
|
|
15
16
|
from sqlsaber.threads import ThreadStorage
|
|
16
17
|
|
|
17
18
|
# Globals consistent with other CLI modules
|
|
18
|
-
console =
|
|
19
|
+
console = create_console()
|
|
20
|
+
tm = get_theme_manager()
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
threads_app = cyclopts.App(
|
|
@@ -84,13 +86,23 @@ def _render_transcript(
|
|
|
84
86
|
console.print(f"**User:**\n\n{text}\n")
|
|
85
87
|
else:
|
|
86
88
|
console.print(
|
|
87
|
-
Panel.fit(
|
|
89
|
+
Panel.fit(
|
|
90
|
+
Markdown(text, code_theme=tm.pygments_style_name),
|
|
91
|
+
title="User",
|
|
92
|
+
border_style=tm.style("panel.border.user"),
|
|
93
|
+
)
|
|
88
94
|
)
|
|
89
95
|
return
|
|
90
96
|
if is_redirected:
|
|
91
97
|
console.print("**User:** (no content)\n")
|
|
92
98
|
else:
|
|
93
|
-
console.print(
|
|
99
|
+
console.print(
|
|
100
|
+
Panel.fit(
|
|
101
|
+
"(no content)",
|
|
102
|
+
title="User",
|
|
103
|
+
border_style=tm.style("panel.border.user"),
|
|
104
|
+
)
|
|
105
|
+
)
|
|
94
106
|
|
|
95
107
|
def _render_response(message: ModelMessage) -> None:
|
|
96
108
|
for part in getattr(message, "parts", []):
|
|
@@ -103,7 +115,9 @@ def _render_transcript(
|
|
|
103
115
|
else:
|
|
104
116
|
console.print(
|
|
105
117
|
Panel.fit(
|
|
106
|
-
Markdown(text
|
|
118
|
+
Markdown(text, code_theme=tm.pygments_style_name),
|
|
119
|
+
title="Assistant",
|
|
120
|
+
border_style=tm.style("panel.border.assistant"),
|
|
107
121
|
)
|
|
108
122
|
)
|
|
109
123
|
elif kind in ("tool-call", "builtin-tool-call"):
|
|
@@ -211,11 +225,11 @@ def list_threads(
|
|
|
211
225
|
console.print("No threads found.")
|
|
212
226
|
return
|
|
213
227
|
table = Table(title="Threads")
|
|
214
|
-
table.add_column("ID", style="
|
|
215
|
-
table.add_column("Database", style="
|
|
216
|
-
table.add_column("Title", style="
|
|
217
|
-
table.add_column("Last Activity", style="
|
|
218
|
-
table.add_column("Model", style="
|
|
228
|
+
table.add_column("ID", style=tm.style("info"))
|
|
229
|
+
table.add_column("Database", style=tm.style("accent"))
|
|
230
|
+
table.add_column("Title", style=tm.style("success"))
|
|
231
|
+
table.add_column("Last Activity", style=tm.style("muted"))
|
|
232
|
+
table.add_column("Model", style=tm.style("warning"))
|
|
219
233
|
for t in threads:
|
|
220
234
|
table.add_row(
|
|
221
235
|
t.id,
|
|
@@ -235,7 +249,7 @@ def show(
|
|
|
235
249
|
store = ThreadStorage()
|
|
236
250
|
thread = asyncio.run(store.get_thread(thread_id))
|
|
237
251
|
if not thread:
|
|
238
|
-
console.print(f"[
|
|
252
|
+
console.print(f"[error]Thread not found:[/error] {thread_id}")
|
|
239
253
|
return
|
|
240
254
|
msgs = asyncio.run(store.get_thread_messages(thread_id))
|
|
241
255
|
console.print(f"[bold]Thread: {thread.id}[/bold]")
|
|
@@ -273,12 +287,12 @@ def resume(
|
|
|
273
287
|
|
|
274
288
|
thread = await store.get_thread(thread_id)
|
|
275
289
|
if not thread:
|
|
276
|
-
console.print(f"[
|
|
290
|
+
console.print(f"[error]Thread not found:[/error] {thread_id}")
|
|
277
291
|
return
|
|
278
292
|
db_selector = database or thread.database_name
|
|
279
293
|
if not db_selector:
|
|
280
294
|
console.print(
|
|
281
|
-
"[
|
|
295
|
+
"[error]No database specified or stored with this thread.[/error]"
|
|
282
296
|
)
|
|
283
297
|
return
|
|
284
298
|
try:
|
|
@@ -287,7 +301,7 @@ def resume(
|
|
|
287
301
|
connection_string = resolved.connection_string
|
|
288
302
|
db_name = resolved.name
|
|
289
303
|
except DatabaseResolutionError as e:
|
|
290
|
-
console.print(f"[
|
|
304
|
+
console.print(f"[error]Database resolution error:[/error] {e}")
|
|
291
305
|
return
|
|
292
306
|
|
|
293
307
|
db_conn = DatabaseConnection(connection_string)
|
|
@@ -295,7 +309,12 @@ def resume(
|
|
|
295
309
|
sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
|
|
296
310
|
history = await store.get_thread_messages(thread_id)
|
|
297
311
|
if console.is_terminal:
|
|
298
|
-
console.print(
|
|
312
|
+
console.print(
|
|
313
|
+
Panel.fit(
|
|
314
|
+
f"Thread: {thread.id}",
|
|
315
|
+
border_style=tm.style("panel.border.thread"),
|
|
316
|
+
)
|
|
317
|
+
)
|
|
299
318
|
else:
|
|
300
319
|
console.print(f"# Thread: {thread.id}\n")
|
|
301
320
|
_render_transcript(console, history, None)
|
|
@@ -310,7 +329,7 @@ def resume(
|
|
|
310
329
|
await session.run()
|
|
311
330
|
finally:
|
|
312
331
|
await db_conn.close()
|
|
313
|
-
console.print("\n[
|
|
332
|
+
console.print("\n[success]Goodbye![/success]")
|
|
314
333
|
|
|
315
334
|
asyncio.run(_run())
|
|
316
335
|
|
|
@@ -329,7 +348,7 @@ def prune(
|
|
|
329
348
|
|
|
330
349
|
async def _run() -> None:
|
|
331
350
|
deleted = await store.prune_threads(older_than_days=days)
|
|
332
|
-
console.print(f"[
|
|
351
|
+
console.print(f"[success]✓ Pruned {deleted} thread(s).[/success]")
|
|
333
352
|
|
|
334
353
|
asyncio.run(_run())
|
|
335
354
|
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -4,11 +4,11 @@ import getpass
|
|
|
4
4
|
import os
|
|
5
5
|
|
|
6
6
|
import keyring
|
|
7
|
-
from rich.console import Console
|
|
8
7
|
|
|
9
8
|
from sqlsaber.config import providers
|
|
9
|
+
from sqlsaber.theme.manager import create_console
|
|
10
10
|
|
|
11
|
-
console =
|
|
11
|
+
console = create_console()
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class APIKeyManager:
|
|
@@ -19,7 +19,7 @@ class APIKeyManager:
|
|
|
19
19
|
|
|
20
20
|
def get_api_key(self, provider: str) -> str | None:
|
|
21
21
|
"""Get API key for the specified provider using cascading logic."""
|
|
22
|
-
env_var_name = self.
|
|
22
|
+
env_var_name = self.get_env_var_name(provider)
|
|
23
23
|
service_name = self._get_service_name(provider)
|
|
24
24
|
|
|
25
25
|
# 1. Check environment variable first
|
|
@@ -41,7 +41,7 @@ class APIKeyManager:
|
|
|
41
41
|
# 3. Prompt user for API key
|
|
42
42
|
return self._prompt_and_store_key(provider, env_var_name, service_name)
|
|
43
43
|
|
|
44
|
-
def
|
|
44
|
+
def get_env_var_name(self, provider: str) -> str:
|
|
45
45
|
"""Get the expected environment variable name for a provider."""
|
|
46
46
|
# Normalize aliases to canonical provider keys
|
|
47
47
|
key = providers.canonical(provider) or provider
|
sqlsaber/config/oauth_flow.py
CHANGED
|
@@ -10,12 +10,13 @@ from datetime import datetime, timezone
|
|
|
10
10
|
|
|
11
11
|
import httpx
|
|
12
12
|
import questionary
|
|
13
|
-
from rich.console import Console
|
|
14
13
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
15
14
|
|
|
15
|
+
from sqlsaber.theme.manager import create_console
|
|
16
|
+
|
|
16
17
|
from .oauth_tokens import OAuthToken, OAuthTokenManager
|
|
17
18
|
|
|
18
|
-
console =
|
|
19
|
+
console = create_console()
|
|
19
20
|
logger = logging.getLogger(__name__)
|
|
20
21
|
|
|
21
22
|
|
sqlsaber/config/oauth_tokens.py
CHANGED
|
@@ -6,9 +6,10 @@ from datetime import datetime, timedelta, timezone
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
import keyring
|
|
9
|
-
from rich.console import Console
|
|
10
9
|
|
|
11
|
-
|
|
10
|
+
from sqlsaber.theme.manager import create_console
|
|
11
|
+
|
|
12
|
+
console = create_console()
|
|
12
13
|
logger = logging.getLogger(__name__)
|
|
13
14
|
|
|
14
15
|
|
|
@@ -158,9 +159,6 @@ class OAuthTokenManager:
|
|
|
158
159
|
keyring.delete_password(service_name, provider)
|
|
159
160
|
console.print(f"OAuth token for {provider} removed", style="green")
|
|
160
161
|
return True
|
|
161
|
-
except keyring.errors.PasswordDeleteError:
|
|
162
|
-
# Token doesn't exist
|
|
163
|
-
return True
|
|
164
162
|
except Exception as e:
|
|
165
163
|
logger.error(f"Failed to remove OAuth token for {provider}: {e}")
|
|
166
164
|
console.print(f"Warning: Could not remove OAuth token: {e}", style="yellow")
|
sqlsaber/database/base.py
CHANGED
|
@@ -61,6 +61,12 @@ class BaseDatabaseConnection(ABC):
|
|
|
61
61
|
self.connection_string = connection_string
|
|
62
62
|
self._pool = None
|
|
63
63
|
|
|
64
|
+
@property
|
|
65
|
+
@abstractmethod
|
|
66
|
+
def sqlglot_dialect(self) -> str:
|
|
67
|
+
"""Return the sqlglot dialect name for this database."""
|
|
68
|
+
pass
|
|
69
|
+
|
|
64
70
|
@abstractmethod
|
|
65
71
|
async def get_pool(self):
|
|
66
72
|
"""Get or create connection pool."""
|
sqlsaber/database/csv.py
CHANGED
|
@@ -58,6 +58,11 @@ class CSVConnection(BaseDatabaseConnection):
|
|
|
58
58
|
|
|
59
59
|
self.table_name = Path(self.csv_path).stem or "csv_table"
|
|
60
60
|
|
|
61
|
+
@property
|
|
62
|
+
def sqlglot_dialect(self) -> str:
|
|
63
|
+
"""Return the sqlglot dialect name."""
|
|
64
|
+
return "duckdb"
|
|
65
|
+
|
|
61
66
|
async def get_pool(self):
|
|
62
67
|
"""CSV connections do not maintain a pool."""
|
|
63
68
|
return None
|
sqlsaber/database/duckdb.py
CHANGED
|
@@ -52,6 +52,11 @@ class DuckDBConnection(BaseDatabaseConnection):
|
|
|
52
52
|
|
|
53
53
|
self.database_path = db_path or ":memory:"
|
|
54
54
|
|
|
55
|
+
@property
|
|
56
|
+
def sqlglot_dialect(self) -> str:
|
|
57
|
+
"""Return the sqlglot dialect name."""
|
|
58
|
+
return "duckdb"
|
|
59
|
+
|
|
55
60
|
async def get_pool(self):
|
|
56
61
|
"""DuckDB creates connections per query, return database path."""
|
|
57
62
|
return self.database_path
|
sqlsaber/database/mysql.py
CHANGED
|
@@ -23,6 +23,11 @@ class MySQLConnection(BaseDatabaseConnection):
|
|
|
23
23
|
self._pool: aiomysql.Pool | None = None
|
|
24
24
|
self._parse_connection_string()
|
|
25
25
|
|
|
26
|
+
@property
|
|
27
|
+
def sqlglot_dialect(self) -> str:
|
|
28
|
+
"""Return the sqlglot dialect name."""
|
|
29
|
+
return "mysql"
|
|
30
|
+
|
|
26
31
|
def _parse_connection_string(self):
|
|
27
32
|
"""Parse MySQL connection string into components."""
|
|
28
33
|
parsed = urlparse(self.connection_string)
|
sqlsaber/database/postgresql.py
CHANGED
|
@@ -23,6 +23,11 @@ class PostgreSQLConnection(BaseDatabaseConnection):
|
|
|
23
23
|
self._pool: asyncpg.Pool | None = None
|
|
24
24
|
self._ssl_context = self._create_ssl_context()
|
|
25
25
|
|
|
26
|
+
@property
|
|
27
|
+
def sqlglot_dialect(self) -> str:
|
|
28
|
+
"""Return the sqlglot dialect name."""
|
|
29
|
+
return "postgres"
|
|
30
|
+
|
|
26
31
|
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
|
27
32
|
"""Create SSL context from connection string parameters."""
|
|
28
33
|
parsed = urlparse(self.connection_string)
|
sqlsaber/database/sqlite.py
CHANGED
|
@@ -21,6 +21,11 @@ class SQLiteConnection(BaseDatabaseConnection):
|
|
|
21
21
|
# Extract database path from sqlite:///path format
|
|
22
22
|
self.database_path = connection_string.replace("sqlite:///", "")
|
|
23
23
|
|
|
24
|
+
@property
|
|
25
|
+
def sqlglot_dialect(self) -> str:
|
|
26
|
+
"""Return the sqlglot dialect name."""
|
|
27
|
+
return "sqlite"
|
|
28
|
+
|
|
24
29
|
async def get_pool(self):
|
|
25
30
|
"""SQLite doesn't use connection pooling, return database path."""
|
|
26
31
|
return self.database_path
|