sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +223 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +22 -50
- sqlsaber/cli/commands.py +22 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +25 -86
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +150 -127
- sqlsaber/cli/models.py +18 -28
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
- sqlsaber-0.27.0.dist-info/RECORD +58 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/models.py
CHANGED
|
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
|
|
|
25
25
|
class ModelManager:
|
|
26
26
|
"""Manages AI model configuration and fetching."""
|
|
27
27
|
|
|
28
|
-
DEFAULT_MODEL = "anthropic:claude-sonnet-4-
|
|
28
|
+
DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
|
|
29
29
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
30
|
# Providers come from central registry
|
|
31
31
|
SUPPORTED_PROVIDERS = providers.all_keys()
|
|
32
32
|
|
|
33
|
+
RECOMMENDED_MODELS = {
|
|
34
|
+
"anthropic": "claude-sonnet-4-5-20250929",
|
|
35
|
+
"openai": "gpt-5",
|
|
36
|
+
"google": "gemini-2.5-pro",
|
|
37
|
+
"groq": "llama-3-3-70b-versatile",
|
|
38
|
+
"mistral": "mistral-large-latest",
|
|
39
|
+
"cohere": "command-r-plus",
|
|
40
|
+
}
|
|
41
|
+
|
|
33
42
|
async def fetch_available_models(
|
|
34
43
|
self, providers: list[str] | None = None
|
|
35
44
|
) -> list[dict]:
|
|
@@ -180,39 +189,20 @@ def set():
|
|
|
180
189
|
"""Set the AI model to use."""
|
|
181
190
|
|
|
182
191
|
async def interactive_set():
|
|
192
|
+
from sqlsaber.application.model_selection import choose_model, fetch_models
|
|
193
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
194
|
+
|
|
183
195
|
console.print("[blue]Fetching available models...[/blue]")
|
|
184
|
-
models = await model_manager
|
|
196
|
+
models = await fetch_models(model_manager)
|
|
185
197
|
|
|
186
198
|
if not models:
|
|
187
199
|
console.print("[red]Failed to fetch models. Cannot set model.[/red]")
|
|
188
200
|
sys.exit(1)
|
|
189
201
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
prov = model.get("provider", "?")
|
|
195
|
-
choice_text = f"[{prov}] {model['id']} - {model['name']}"
|
|
196
|
-
if model["description"]:
|
|
197
|
-
choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
|
|
198
|
-
|
|
199
|
-
choices.append({"name": choice_text, "value": model["id"]})
|
|
200
|
-
|
|
201
|
-
# Get current model to set as default
|
|
202
|
-
current_model = model_manager.get_current_model()
|
|
203
|
-
default_index = 0
|
|
204
|
-
for i, choice in enumerate(choices):
|
|
205
|
-
if choice["value"] == current_model:
|
|
206
|
-
default_index = i
|
|
207
|
-
break
|
|
208
|
-
|
|
209
|
-
selected_model = await questionary.select(
|
|
210
|
-
"Select a model:",
|
|
211
|
-
choices=choices,
|
|
212
|
-
use_search_filter=True,
|
|
213
|
-
use_jk_keys=False, # Disable j/k keys when using search filter
|
|
214
|
-
default=choices[default_index] if choices else None,
|
|
215
|
-
).ask_async()
|
|
202
|
+
prompter = AsyncPrompter()
|
|
203
|
+
selected_model = await choose_model(
|
|
204
|
+
prompter, models, restrict_provider=None, use_search_filter=True
|
|
205
|
+
)
|
|
216
206
|
|
|
217
207
|
if selected_model:
|
|
218
208
|
if model_manager.set_model(selected_model):
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""Interactive onboarding flow for first-time SQLSaber users."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
from rich.panel import Panel
|
|
7
|
+
|
|
8
|
+
from sqlsaber.cli.models import ModelManager
|
|
9
|
+
from sqlsaber.config.api_keys import APIKeyManager
|
|
10
|
+
from sqlsaber.config.auth import AuthConfigManager
|
|
11
|
+
from sqlsaber.config.database import DatabaseConfigManager
|
|
12
|
+
|
|
13
|
+
console = Console()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def needs_onboarding(database_arg: str | None = None) -> bool:
|
|
17
|
+
"""Check if user needs onboarding.
|
|
18
|
+
|
|
19
|
+
Onboarding is needed if:
|
|
20
|
+
- No database is configured AND no database connection string provided via CLI
|
|
21
|
+
"""
|
|
22
|
+
# If user provided a database argument, skip onboarding
|
|
23
|
+
if database_arg:
|
|
24
|
+
return False
|
|
25
|
+
|
|
26
|
+
# Check if databases are configured
|
|
27
|
+
db_manager = DatabaseConfigManager()
|
|
28
|
+
has_db = db_manager.has_databases()
|
|
29
|
+
|
|
30
|
+
return not has_db
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def welcome_screen() -> None:
|
|
34
|
+
"""Display welcome screen to new users."""
|
|
35
|
+
banner = """
|
|
36
|
+
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
37
|
+
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
38
|
+
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
39
|
+
██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
40
|
+
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
41
|
+
▀▀
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
console.print(Panel.fit(banner, style="bold blue"))
|
|
45
|
+
console.print()
|
|
46
|
+
|
|
47
|
+
welcome_message = """
|
|
48
|
+
[bold]Welcome to SQLsaber! 🎉[/bold]
|
|
49
|
+
|
|
50
|
+
SQLsaber is an agentic SQL assistant that lets you query your database using natural language.
|
|
51
|
+
|
|
52
|
+
Let's get you set up in just a few steps.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
console.print(Panel(welcome_message.strip(), border_style="blue", padding=(1, 2)))
|
|
56
|
+
console.print()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def setup_database_guided() -> str | None:
|
|
60
|
+
"""Guide user through database setup.
|
|
61
|
+
|
|
62
|
+
Returns the name of the configured database or None if cancelled.
|
|
63
|
+
"""
|
|
64
|
+
from sqlsaber.application.db_setup import (
|
|
65
|
+
build_config,
|
|
66
|
+
collect_db_input,
|
|
67
|
+
save_database,
|
|
68
|
+
test_connection,
|
|
69
|
+
)
|
|
70
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
71
|
+
|
|
72
|
+
console.print("━" * 80, style="dim")
|
|
73
|
+
console.print("[bold cyan]Step 1 of 2: Database Connection[/bold cyan]")
|
|
74
|
+
console.print("━" * 80, style="dim")
|
|
75
|
+
console.print()
|
|
76
|
+
|
|
77
|
+
try:
|
|
78
|
+
# Ask for connection name
|
|
79
|
+
prompter = AsyncPrompter()
|
|
80
|
+
name = await prompter.text(
|
|
81
|
+
"What would you like to name this connection?",
|
|
82
|
+
default="mydb",
|
|
83
|
+
validate=lambda x: bool(x.strip()) or "Name cannot be empty",
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if name is None:
|
|
87
|
+
return None
|
|
88
|
+
|
|
89
|
+
name = name.strip()
|
|
90
|
+
|
|
91
|
+
# Check if name already exists
|
|
92
|
+
db_manager = DatabaseConfigManager()
|
|
93
|
+
if db_manager.get_database(name):
|
|
94
|
+
console.print(
|
|
95
|
+
f"[yellow]Database connection '{name}' already exists.[/yellow]"
|
|
96
|
+
)
|
|
97
|
+
return name
|
|
98
|
+
|
|
99
|
+
# Collect database input (simplified - no SSL in onboarding)
|
|
100
|
+
db_input = await collect_db_input(
|
|
101
|
+
prompter=prompter, name=name, db_type="postgresql", include_ssl=False
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
if db_input is None:
|
|
105
|
+
return None
|
|
106
|
+
|
|
107
|
+
# Build config
|
|
108
|
+
db_config = build_config(db_input)
|
|
109
|
+
|
|
110
|
+
# Test the connection
|
|
111
|
+
console.print(f"[dim]Testing connection to '{name}'...[/dim]")
|
|
112
|
+
connection_success = await test_connection(db_config, db_input.password)
|
|
113
|
+
|
|
114
|
+
if not connection_success:
|
|
115
|
+
retry = await prompter.confirm(
|
|
116
|
+
"Would you like to try again with different settings?", default=True
|
|
117
|
+
)
|
|
118
|
+
if retry:
|
|
119
|
+
return await setup_database_guided()
|
|
120
|
+
else:
|
|
121
|
+
console.print(
|
|
122
|
+
"[yellow]You can add a database later using 'saber db add'[/yellow]"
|
|
123
|
+
)
|
|
124
|
+
return None
|
|
125
|
+
|
|
126
|
+
# Save the configuration
|
|
127
|
+
try:
|
|
128
|
+
save_database(db_manager, db_config, db_input.password)
|
|
129
|
+
console.print(f"[green]✓ Connection to '{name}' successful![/green]")
|
|
130
|
+
console.print()
|
|
131
|
+
return name
|
|
132
|
+
except Exception as e:
|
|
133
|
+
console.print(f"[bold red]Error saving database:[/bold red] {e}")
|
|
134
|
+
return None
|
|
135
|
+
|
|
136
|
+
except KeyboardInterrupt:
|
|
137
|
+
console.print("\n[yellow]Setup cancelled.[/yellow]")
|
|
138
|
+
return None
|
|
139
|
+
except Exception as e:
|
|
140
|
+
console.print(f"[bold red]Unexpected error:[/bold red] {e}")
|
|
141
|
+
return None
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
async def select_model_for_provider(provider: str) -> str | None:
|
|
145
|
+
"""Fetch and let user select a model for the given provider.
|
|
146
|
+
|
|
147
|
+
Returns the selected model ID or None if cancelled/failed.
|
|
148
|
+
"""
|
|
149
|
+
from sqlsaber.application.model_selection import choose_model, fetch_models
|
|
150
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
151
|
+
|
|
152
|
+
try:
|
|
153
|
+
console.print()
|
|
154
|
+
console.print(f"[dim]Fetching available {provider.title()} models...[/dim]")
|
|
155
|
+
|
|
156
|
+
model_manager = ModelManager()
|
|
157
|
+
models = await fetch_models(model_manager, providers=[provider])
|
|
158
|
+
|
|
159
|
+
if not models:
|
|
160
|
+
console.print(
|
|
161
|
+
f"[yellow]Could not fetch models for {provider}. Using default.[/yellow]"
|
|
162
|
+
)
|
|
163
|
+
# Use provider-specific default or fallback to Anthropic
|
|
164
|
+
default_model_id = ModelManager.RECOMMENDED_MODELS.get(
|
|
165
|
+
provider, ModelManager.DEFAULT_MODEL
|
|
166
|
+
)
|
|
167
|
+
# Format it properly if we have a recommended model for this provider
|
|
168
|
+
if provider in ModelManager.RECOMMENDED_MODELS:
|
|
169
|
+
return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
|
|
170
|
+
return default_model_id
|
|
171
|
+
|
|
172
|
+
prompter = AsyncPrompter()
|
|
173
|
+
console.print()
|
|
174
|
+
selected_model = await choose_model(
|
|
175
|
+
prompter, models, restrict_provider=provider, use_search_filter=True
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
return selected_model
|
|
179
|
+
|
|
180
|
+
except KeyboardInterrupt:
|
|
181
|
+
console.print("\n[yellow]Model selection cancelled.[/yellow]")
|
|
182
|
+
return None
|
|
183
|
+
except Exception as e:
|
|
184
|
+
console.print(f"[yellow]Error selecting model: {e}. Using default.[/yellow]")
|
|
185
|
+
# Fallback to provider default
|
|
186
|
+
if provider in ModelManager.RECOMMENDED_MODELS:
|
|
187
|
+
return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
|
|
188
|
+
return ModelManager.DEFAULT_MODEL
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
async def setup_auth_guided() -> tuple[bool, str | None]:
|
|
192
|
+
"""Guide user through auth setup.
|
|
193
|
+
|
|
194
|
+
Returns tuple of (success: bool, selected_model: str | None).
|
|
195
|
+
"""
|
|
196
|
+
from sqlsaber.application.auth_setup import setup_auth
|
|
197
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
198
|
+
|
|
199
|
+
console.print("━" * 80, style="dim")
|
|
200
|
+
console.print("[bold cyan]Step 2 of 2: Authentication[/bold cyan]")
|
|
201
|
+
console.print("━" * 80, style="dim")
|
|
202
|
+
console.print()
|
|
203
|
+
|
|
204
|
+
try:
|
|
205
|
+
# Run auth setup
|
|
206
|
+
prompter = AsyncPrompter()
|
|
207
|
+
auth_manager = AuthConfigManager()
|
|
208
|
+
api_key_manager = APIKeyManager()
|
|
209
|
+
|
|
210
|
+
success, provider = await setup_auth(
|
|
211
|
+
prompter=prompter,
|
|
212
|
+
auth_manager=auth_manager,
|
|
213
|
+
api_key_manager=api_key_manager,
|
|
214
|
+
allow_oauth=True,
|
|
215
|
+
default_provider="anthropic",
|
|
216
|
+
run_oauth_in_thread=True,
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
if not success:
|
|
220
|
+
console.print(
|
|
221
|
+
"[yellow]You can set it up later using 'saber auth setup'[/yellow]"
|
|
222
|
+
)
|
|
223
|
+
console.print()
|
|
224
|
+
return False, None
|
|
225
|
+
|
|
226
|
+
# If auth configured but we don't know the provider (already configured case)
|
|
227
|
+
if provider is None:
|
|
228
|
+
console.print()
|
|
229
|
+
return True, None
|
|
230
|
+
|
|
231
|
+
# Select model for this provider
|
|
232
|
+
selected_model = await select_model_for_provider(provider)
|
|
233
|
+
if selected_model:
|
|
234
|
+
model_manager = ModelManager()
|
|
235
|
+
model_manager.set_model(selected_model)
|
|
236
|
+
console.print(f"[green]✓ Model set to: {selected_model}[/green]")
|
|
237
|
+
console.print()
|
|
238
|
+
return True, selected_model
|
|
239
|
+
|
|
240
|
+
except KeyboardInterrupt:
|
|
241
|
+
console.print("\n[yellow]Setup cancelled.[/yellow]")
|
|
242
|
+
console.print()
|
|
243
|
+
return False, None
|
|
244
|
+
except Exception as e:
|
|
245
|
+
console.print(f"[bold red]Unexpected error:[/bold red] {e}")
|
|
246
|
+
console.print()
|
|
247
|
+
return False, None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def success_screen(
|
|
251
|
+
database_name: str | None, auth_configured: bool, model_name: str | None = None
|
|
252
|
+
) -> None:
|
|
253
|
+
"""Display success screen after onboarding."""
|
|
254
|
+
console.print("━" * 80, style="dim")
|
|
255
|
+
console.print("[bold green]You're all set! 🚀[/bold green]")
|
|
256
|
+
console.print("━" * 80, style="dim")
|
|
257
|
+
console.print()
|
|
258
|
+
|
|
259
|
+
if database_name and auth_configured:
|
|
260
|
+
console.print(
|
|
261
|
+
f"[green]✓ Database '{database_name}' connected and ready to use[/green]"
|
|
262
|
+
)
|
|
263
|
+
console.print("[green]✓ Authentication configured[/green]")
|
|
264
|
+
if model_name:
|
|
265
|
+
console.print(f"[green]✓ Model: {model_name}[/green]")
|
|
266
|
+
elif database_name:
|
|
267
|
+
console.print(
|
|
268
|
+
f"[green]✓ Database '{database_name}' connected and ready to use[/green]"
|
|
269
|
+
)
|
|
270
|
+
console.print(
|
|
271
|
+
"[yellow]⚠ AI authentication not configured - you'll be prompted when needed[/yellow]"
|
|
272
|
+
)
|
|
273
|
+
elif auth_configured:
|
|
274
|
+
console.print("[green]✓ AI authentication configured[/green]")
|
|
275
|
+
if model_name:
|
|
276
|
+
console.print(f"[green]✓ Model: {model_name}[/green]")
|
|
277
|
+
console.print(
|
|
278
|
+
"[yellow]⚠ No database configured - you'll need to provide one via -d flag[/yellow]"
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
console.print()
|
|
282
|
+
console.print("[dim]Starting interactive session...[/dim]")
|
|
283
|
+
console.print()
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
async def run_onboarding() -> bool:
|
|
287
|
+
"""Run the complete onboarding flow.
|
|
288
|
+
|
|
289
|
+
Returns True if onboarding completed successfully (at least database configured),
|
|
290
|
+
False if user cancelled or onboarding failed.
|
|
291
|
+
"""
|
|
292
|
+
try:
|
|
293
|
+
# Welcome screen
|
|
294
|
+
welcome_screen()
|
|
295
|
+
|
|
296
|
+
# Database setup
|
|
297
|
+
database_name = await setup_database_guided()
|
|
298
|
+
|
|
299
|
+
# If user cancelled database setup, exit
|
|
300
|
+
if database_name is None:
|
|
301
|
+
console.print("[yellow]Database setup is required to continue.[/yellow]")
|
|
302
|
+
console.print(
|
|
303
|
+
"[dim]You can also provide a connection string using: saber -d <connection-string>[/dim]"
|
|
304
|
+
)
|
|
305
|
+
return False
|
|
306
|
+
|
|
307
|
+
# Auth setup
|
|
308
|
+
auth_configured, model_name = await setup_auth_guided()
|
|
309
|
+
|
|
310
|
+
# Show success screen
|
|
311
|
+
success_screen(database_name, auth_configured, model_name)
|
|
312
|
+
|
|
313
|
+
return True
|
|
314
|
+
|
|
315
|
+
except KeyboardInterrupt:
|
|
316
|
+
console.print("\n[yellow]Onboarding cancelled.[/yellow]")
|
|
317
|
+
console.print(
|
|
318
|
+
"[dim]You can run setup commands manually:[/dim]\n"
|
|
319
|
+
"[dim] - saber db add <name> # Add database connection[/dim]\n"
|
|
320
|
+
"[dim] - saber auth setup # Configure authentication[/dim]"
|
|
321
|
+
)
|
|
322
|
+
sys.exit(0)
|
|
323
|
+
except Exception as e:
|
|
324
|
+
console.print(f"[bold red]Onboarding failed:[/bold red] {e}")
|
|
325
|
+
return False
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -8,9 +8,9 @@ rendered via DisplayManager helpers.
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
10
|
from functools import singledispatchmethod
|
|
11
|
-
from typing import AsyncIterable
|
|
11
|
+
from typing import TYPE_CHECKING, AsyncIterable
|
|
12
12
|
|
|
13
|
-
from pydantic_ai import
|
|
13
|
+
from pydantic_ai import RunContext
|
|
14
14
|
from pydantic_ai.messages import (
|
|
15
15
|
AgentStreamEvent,
|
|
16
16
|
FunctionToolCallEvent,
|
|
@@ -26,6 +26,9 @@ from rich.console import Console
|
|
|
26
26
|
|
|
27
27
|
from sqlsaber.cli.display import DisplayManager
|
|
28
28
|
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
|
|
31
|
+
|
|
29
32
|
|
|
30
33
|
class StreamingQueryHandler:
|
|
31
34
|
"""
|
|
@@ -130,7 +133,7 @@ class StreamingQueryHandler:
|
|
|
130
133
|
async def execute_streaming_query(
|
|
131
134
|
self,
|
|
132
135
|
user_query: str,
|
|
133
|
-
|
|
136
|
+
sqlsaber_agent: "SQLSaberAgent",
|
|
134
137
|
cancellation_token: asyncio.Event | None = None,
|
|
135
138
|
message_history: list | None = None,
|
|
136
139
|
):
|
|
@@ -139,21 +142,16 @@ class StreamingQueryHandler:
|
|
|
139
142
|
try:
|
|
140
143
|
# If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
|
|
141
144
|
prepared_prompt: str | list[str] = user_query
|
|
142
|
-
is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
|
|
143
145
|
no_history = not message_history
|
|
144
|
-
if is_oauth and no_history:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
db_type = getattr(agent, "_sqlsaber_db_type", "database")
|
|
148
|
-
db_name = getattr(agent, "_sqlsaber_database_name", None)
|
|
149
|
-
instructions = (
|
|
150
|
-
ib.build_instructions(db_type=db_type) if ib is not None else ""
|
|
151
|
-
)
|
|
152
|
-
mem = (
|
|
153
|
-
mm.format_memories_for_prompt(db_name)
|
|
154
|
-
if (mm is not None and db_name)
|
|
155
|
-
else ""
|
|
146
|
+
if sqlsaber_agent.is_oauth and no_history:
|
|
147
|
+
instructions = sqlsaber_agent.instruction_builder.build_instructions(
|
|
148
|
+
db_type=sqlsaber_agent.db_type
|
|
156
149
|
)
|
|
150
|
+
mem = ""
|
|
151
|
+
if sqlsaber_agent.database_name:
|
|
152
|
+
mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
|
|
153
|
+
sqlsaber_agent.database_name
|
|
154
|
+
)
|
|
157
155
|
parts = [p for p in (instructions, mem) if p and str(p).strip()]
|
|
158
156
|
if parts:
|
|
159
157
|
injected = "\n\n".join(parts)
|
|
@@ -163,7 +161,7 @@ class StreamingQueryHandler:
|
|
|
163
161
|
self.display.live.start_status("Crunching data...")
|
|
164
162
|
|
|
165
163
|
# Run the agent with our event stream handler
|
|
166
|
-
run = await agent.run(
|
|
164
|
+
run = await sqlsaber_agent.agent.run(
|
|
167
165
|
prepared_prompt,
|
|
168
166
|
message_history=message_history,
|
|
169
167
|
event_stream_handler=self._event_stream_handler,
|
sqlsaber/cli/threads.py
CHANGED
|
@@ -148,7 +148,9 @@ def _render_transcript(
|
|
|
148
148
|
)
|
|
149
149
|
else:
|
|
150
150
|
if is_redirected:
|
|
151
|
-
console.print(
|
|
151
|
+
console.print(
|
|
152
|
+
f"**Tool result ({name}):**\n\n{content_str}\n"
|
|
153
|
+
)
|
|
152
154
|
else:
|
|
153
155
|
console.print(
|
|
154
156
|
Panel.fit(
|
|
@@ -159,7 +161,9 @@ def _render_transcript(
|
|
|
159
161
|
)
|
|
160
162
|
except Exception:
|
|
161
163
|
if is_redirected:
|
|
162
|
-
console.print(
|
|
164
|
+
console.print(
|
|
165
|
+
f"**Tool result ({name}):**\n\n{content_str}\n"
|
|
166
|
+
)
|
|
163
167
|
else:
|
|
164
168
|
console.print(
|
|
165
169
|
Panel.fit(
|
|
@@ -258,10 +262,10 @@ def resume(
|
|
|
258
262
|
|
|
259
263
|
async def _run() -> None:
|
|
260
264
|
# Lazy imports to avoid heavy modules at CLI startup
|
|
261
|
-
from sqlsaber.agents import
|
|
265
|
+
from sqlsaber.agents import SQLSaberAgent
|
|
262
266
|
from sqlsaber.cli.interactive import InteractiveSession
|
|
263
267
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
264
|
-
from sqlsaber.database
|
|
268
|
+
from sqlsaber.database import DatabaseConnection
|
|
265
269
|
from sqlsaber.database.resolver import (
|
|
266
270
|
DatabaseResolutionError,
|
|
267
271
|
resolve_database,
|
|
@@ -288,7 +292,7 @@ def resume(
|
|
|
288
292
|
|
|
289
293
|
db_conn = DatabaseConnection(connection_string)
|
|
290
294
|
try:
|
|
291
|
-
|
|
295
|
+
sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
|
|
292
296
|
history = await store.get_thread_messages(thread_id)
|
|
293
297
|
if console.is_terminal:
|
|
294
298
|
console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
|
|
@@ -297,7 +301,7 @@ def resume(
|
|
|
297
301
|
_render_transcript(console, history, None)
|
|
298
302
|
session = InteractiveSession(
|
|
299
303
|
console=console,
|
|
300
|
-
|
|
304
|
+
sqlsaber_agent=sqlsaber_agent,
|
|
301
305
|
db_conn=db_conn,
|
|
302
306
|
database_name=db_name,
|
|
303
307
|
initial_thread_id=thread_id,
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -19,7 +19,7 @@ class APIKeyManager:
|
|
|
19
19
|
|
|
20
20
|
def get_api_key(self, provider: str) -> str | None:
|
|
21
21
|
"""Get API key for the specified provider using cascading logic."""
|
|
22
|
-
env_var_name = self.
|
|
22
|
+
env_var_name = self.get_env_var_name(provider)
|
|
23
23
|
service_name = self._get_service_name(provider)
|
|
24
24
|
|
|
25
25
|
# 1. Check environment variable first
|
|
@@ -41,7 +41,7 @@ class APIKeyManager:
|
|
|
41
41
|
# 3. Prompt user for API key
|
|
42
42
|
return self._prompt_and_store_key(provider, env_var_name, service_name)
|
|
43
43
|
|
|
44
|
-
def
|
|
44
|
+
def get_env_var_name(self, provider: str) -> str:
|
|
45
45
|
"""Get the expected environment variable name for a provider."""
|
|
46
46
|
# Normalize aliases to canonical provider keys
|
|
47
47
|
key = providers.canonical(provider) or provider
|
sqlsaber/config/settings.py
CHANGED
|
@@ -46,7 +46,10 @@ class ModelConfigManager:
|
|
|
46
46
|
def _load_config(self) -> dict[str, Any]:
|
|
47
47
|
"""Load configuration from file."""
|
|
48
48
|
if not self.config_file.exists():
|
|
49
|
-
return {
|
|
49
|
+
return {
|
|
50
|
+
"model": self.DEFAULT_MODEL,
|
|
51
|
+
"thinking_enabled": False,
|
|
52
|
+
}
|
|
50
53
|
|
|
51
54
|
try:
|
|
52
55
|
with open(self.config_file, "r") as f:
|
|
@@ -54,9 +57,15 @@ class ModelConfigManager:
|
|
|
54
57
|
# Ensure we have a model set
|
|
55
58
|
if "model" not in config:
|
|
56
59
|
config["model"] = self.DEFAULT_MODEL
|
|
60
|
+
# Set defaults for thinking if not present
|
|
61
|
+
if "thinking_enabled" not in config:
|
|
62
|
+
config["thinking_enabled"] = False
|
|
57
63
|
return config
|
|
58
64
|
except (json.JSONDecodeError, IOError):
|
|
59
|
-
return {
|
|
65
|
+
return {
|
|
66
|
+
"model": self.DEFAULT_MODEL,
|
|
67
|
+
"thinking_enabled": False,
|
|
68
|
+
}
|
|
60
69
|
|
|
61
70
|
def _save_config(self, config: dict[str, Any]) -> None:
|
|
62
71
|
"""Save configuration to file."""
|
|
@@ -76,6 +85,17 @@ class ModelConfigManager:
|
|
|
76
85
|
config["model"] = model
|
|
77
86
|
self._save_config(config)
|
|
78
87
|
|
|
88
|
+
def get_thinking_enabled(self) -> bool:
|
|
89
|
+
"""Get whether thinking is enabled."""
|
|
90
|
+
config = self._load_config()
|
|
91
|
+
return config.get("thinking_enabled", False)
|
|
92
|
+
|
|
93
|
+
def set_thinking_enabled(self, enabled: bool) -> None:
|
|
94
|
+
"""Set whether thinking is enabled."""
|
|
95
|
+
config = self._load_config()
|
|
96
|
+
config["thinking_enabled"] = enabled
|
|
97
|
+
self._save_config(config)
|
|
98
|
+
|
|
79
99
|
|
|
80
100
|
class Config:
|
|
81
101
|
"""Configuration class for SQLSaber."""
|
|
@@ -86,6 +106,9 @@ class Config:
|
|
|
86
106
|
self.api_key_manager = APIKeyManager()
|
|
87
107
|
self.auth_config_manager = AuthConfigManager()
|
|
88
108
|
|
|
109
|
+
# Thinking configuration
|
|
110
|
+
self.thinking_enabled = self.model_config_manager.get_thinking_enabled()
|
|
111
|
+
|
|
89
112
|
# Authentication method (API key or Anthropic OAuth)
|
|
90
113
|
self.auth_method = self.auth_config_manager.get_auth_method()
|
|
91
114
|
|
sqlsaber/database/__init__.py
CHANGED
|
@@ -1,9 +1,63 @@
|
|
|
1
1
|
"""Database module for SQLSaber."""
|
|
2
2
|
|
|
3
|
-
from .
|
|
3
|
+
from .base import (
|
|
4
|
+
DEFAULT_QUERY_TIMEOUT,
|
|
5
|
+
BaseDatabaseConnection,
|
|
6
|
+
BaseSchemaIntrospector,
|
|
7
|
+
ColumnInfo,
|
|
8
|
+
ForeignKeyInfo,
|
|
9
|
+
IndexInfo,
|
|
10
|
+
QueryTimeoutError,
|
|
11
|
+
SchemaInfo,
|
|
12
|
+
)
|
|
13
|
+
from .csv import CSVConnection, CSVSchemaIntrospector
|
|
14
|
+
from .duckdb import DuckDBConnection, DuckDBSchemaIntrospector
|
|
15
|
+
from .mysql import MySQLConnection, MySQLSchemaIntrospector
|
|
16
|
+
from .postgresql import PostgreSQLConnection, PostgreSQLSchemaIntrospector
|
|
4
17
|
from .schema import SchemaManager
|
|
18
|
+
from .sqlite import SQLiteConnection, SQLiteSchemaIntrospector
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
|
|
22
|
+
"""Factory function to create appropriate database connection based on connection string."""
|
|
23
|
+
if connection_string.startswith("postgresql://"):
|
|
24
|
+
return PostgreSQLConnection(connection_string)
|
|
25
|
+
elif connection_string.startswith("mysql://"):
|
|
26
|
+
return MySQLConnection(connection_string)
|
|
27
|
+
elif connection_string.startswith("sqlite:///"):
|
|
28
|
+
return SQLiteConnection(connection_string)
|
|
29
|
+
elif connection_string.startswith("duckdb://"):
|
|
30
|
+
return DuckDBConnection(connection_string)
|
|
31
|
+
elif connection_string.startswith("csv:///"):
|
|
32
|
+
return CSVConnection(connection_string)
|
|
33
|
+
else:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
f"Unsupported database type in connection string: {connection_string}"
|
|
36
|
+
)
|
|
37
|
+
|
|
5
38
|
|
|
6
39
|
__all__ = [
|
|
40
|
+
# Base classes and types
|
|
41
|
+
"BaseDatabaseConnection",
|
|
42
|
+
"BaseSchemaIntrospector",
|
|
43
|
+
"ColumnInfo",
|
|
44
|
+
"DEFAULT_QUERY_TIMEOUT",
|
|
45
|
+
"ForeignKeyInfo",
|
|
46
|
+
"IndexInfo",
|
|
47
|
+
"QueryTimeoutError",
|
|
48
|
+
"SchemaInfo",
|
|
49
|
+
# Concrete implementations
|
|
50
|
+
"PostgreSQLConnection",
|
|
51
|
+
"MySQLConnection",
|
|
52
|
+
"SQLiteConnection",
|
|
53
|
+
"DuckDBConnection",
|
|
54
|
+
"CSVConnection",
|
|
55
|
+
"PostgreSQLSchemaIntrospector",
|
|
56
|
+
"MySQLSchemaIntrospector",
|
|
57
|
+
"SQLiteSchemaIntrospector",
|
|
58
|
+
"DuckDBSchemaIntrospector",
|
|
59
|
+
"CSVSchemaIntrospector",
|
|
60
|
+
# Factory function and manager
|
|
7
61
|
"DatabaseConnection",
|
|
8
62
|
"SchemaManager",
|
|
9
63
|
]
|