sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +223 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +22 -50
- sqlsaber/cli/commands.py +22 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +25 -86
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +150 -127
- sqlsaber/cli/models.py +18 -28
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
- sqlsaber-0.27.0.dist-info/RECORD +58 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Shared database setup logic for onboarding and CLI."""
|
|
2
|
+
|
|
3
|
+
import getpass
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
|
|
9
|
+
from sqlsaber.application.prompts import Prompter
|
|
10
|
+
from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
|
|
11
|
+
|
|
12
|
+
# Module-level Rich console for the user-facing messages printed by this module
# (invalid-port notice, connection-failure report).
console = Console()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class DatabaseInput:
|
|
17
|
+
"""Input data for database configuration."""
|
|
18
|
+
|
|
19
|
+
name: str
|
|
20
|
+
type: str
|
|
21
|
+
host: str
|
|
22
|
+
port: int
|
|
23
|
+
database: str
|
|
24
|
+
username: str
|
|
25
|
+
password: str | None
|
|
26
|
+
ssl_mode: str | None = None
|
|
27
|
+
ssl_ca: str | None = None
|
|
28
|
+
ssl_cert: str | None = None
|
|
29
|
+
ssl_key: str | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _expand_path(raw: str | None) -> str | None:
    """Return *raw* as an absolute path with ``~`` expanded, or None if blank/None."""
    if raw is None or not raw.strip():
        return None
    return str(Path(raw).expanduser().resolve())


def _parse_port(port_str: str, default_port: int) -> int:
    """Parse a TCP port number (1-65535), falling back to *default_port* if invalid."""
    try:
        port = int(port_str)
    except ValueError:
        port = 0  # not a number; rejected by the range check below
    if not 1 <= port <= 65535:
        console.print("[red]Invalid port number. Using default.[/red]")
        return default_port
    return port


async def _collect_ssl_settings(
    prompter: Prompter, db_type: str
) -> tuple[str | None, str | None, str | None, str | None]:
    """Interactively ask for optional SSL/TLS settings.

    Returns:
        ``(ssl_mode, ssl_ca, ssl_cert, ssl_key)``. An element is None when the
        user declines, cancels, or leaves the corresponding prompt blank.
    """
    ssl_mode = ssl_ca = ssl_cert = ssl_key = None

    # A cancelled confirm (None) is treated like "no".
    if not await prompter.confirm("Configure SSL/TLS settings?", default=False):
        return ssl_mode, ssl_ca, ssl_cert, ssl_key

    if db_type == "postgresql":
        ssl_mode = await prompter.select(
            "SSL mode for PostgreSQL:",
            choices=[
                "disable",
                "allow",
                "prefer",
                "require",
                "verify-ca",
                "verify-full",
            ],
            default="prefer",
        )
    elif db_type == "mysql":
        ssl_mode = await prompter.select(
            "SSL mode for MySQL:",
            choices=[
                "DISABLED",
                "PREFERRED",
                "REQUIRED",
                "VERIFY_CA",
                "VERIFY_IDENTITY",
            ],
            default="PREFERRED",
        )

    # Certificate files only matter when SSL is not disabled.
    if ssl_mode and ssl_mode not in ("disable", "DISABLED"):
        if await prompter.confirm("Specify SSL certificate files?", default=False):
            # Store absolute, ~-expanded paths (as is done for sqlite/duckdb
            # files) so the saved config does not depend on the current
            # directory. Blank answers become None rather than "".
            ssl_ca = _expand_path(await prompter.path("SSL CA certificate file:"))
            if await prompter.confirm("Specify client certificate?", default=False):
                ssl_cert = _expand_path(
                    await prompter.path("SSL client certificate file:")
                )
                ssl_key = _expand_path(
                    await prompter.path("SSL client private key file:")
                )

    return ssl_mode, ssl_ca, ssl_cert, ssl_key


async def collect_db_input(
    prompter: Prompter,
    name: str,
    db_type: str = "postgresql",
    include_ssl: bool = True,
) -> DatabaseInput | None:
    """Collect database connection details interactively.

    Args:
        prompter: Prompter instance for interaction
        name: Database connection name
        db_type: Initial database type (can be changed via prompt)
        include_ssl: Whether to prompt for SSL configuration (server databases only)

    Returns:
        DatabaseInput with collected values, or None if a required prompt was
        cancelled or an empty database file path was given.
    """
    db_type = await prompter.select(
        "Database type:",
        choices=["postgresql", "mysql", "sqlite", "duckdb"],
        default=db_type,
    )
    if db_type is None:
        return None

    ssl_mode = ssl_ca = ssl_cert = ssl_key = None

    if db_type in {"sqlite", "duckdb"}:
        # File-based databases only need a path; the other fields get fixed values.
        database_path = await prompter.path(
            f"{db_type.upper()} file path:", only_directories=False
        )
        if database_path is None:
            return None

        database = _expand_path(database_path)
        if database is None:
            # Path("").resolve() would otherwise silently become the current
            # working directory, which is not a database file.
            console.print("[red]A database file path is required.[/red]")
            return None

        host = "localhost"
        port = 0
        username = db_type
        password = ""
    else:
        # PostgreSQL/MySQL need connection details
        host = await prompter.text("Host:", default="localhost")
        if host is None:
            return None

        default_port = 5432 if db_type == "postgresql" else 3306
        port_str = await prompter.text("Port:", default=str(default_port))
        if port_str is None:
            return None
        port = _parse_port(port_str, default_port)

        database = await prompter.text("Database name:")
        if database is None:
            return None

        username = await prompter.text("Username:")
        if username is None:
            return None

        # getpass keeps the input hidden; note it bypasses the Prompter abstraction.
        password = getpass.getpass("Password (stored in your OS keychain): ")

        if include_ssl:
            ssl_mode, ssl_ca, ssl_cert, ssl_key = await _collect_ssl_settings(
                prompter, db_type
            )

    return DatabaseInput(
        name=name,
        type=db_type,
        host=host,
        port=port,
        database=database,
        username=username,
        password=password,
        ssl_mode=ssl_mode,
        ssl_ca=ssl_ca,
        ssl_cert=ssl_cert,
        ssl_key=ssl_key,
    )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def build_config(db_input: DatabaseInput) -> DatabaseConfig:
    """Build a DatabaseConfig from collected DatabaseInput.

    Every connection field is copied across unchanged. ``db_input.password``
    is not copied; it is handled separately (see ``save_database``).
    """
    copied_fields = (
        "name",
        "type",
        "host",
        "port",
        "database",
        "username",
        "ssl_mode",
        "ssl_ca",
        "ssl_cert",
        "ssl_key",
    )
    config_kwargs = {attr: getattr(db_input, attr) for attr in copied_fields}
    return DatabaseConfig(**config_kwargs)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
async def test_connection(config: DatabaseConfig, password: str | None) -> bool:
    """Test a database connection by running ``SELECT 1``.

    Args:
        config: DatabaseConfig to test
        password: Password for connection (not stored in config yet).
            NOTE(review): this argument is currently unused. The connection
            string comes solely from ``config.to_connection_string()``.
            Confirm that method can authenticate before the password has been
            saved, or thread the password through here.

    Returns:
        True if the test query succeeded, False otherwise. Failures are
        reported on the console rather than raised.
    """
    import contextlib

    from rich.markup import escape

    from sqlsaber.database import DatabaseConnection

    db_conn = None
    try:
        db_conn = DatabaseConnection(config.to_connection_string())
        await db_conn.execute_query("SELECT 1 as test")
        return True
    except Exception as e:
        # escape() prevents "[...]" fragments in driver error text from being
        # interpreted as Rich markup.
        console.print(
            f"[bold red]Connection failed:[/bold red] {escape(str(e))}", style="red"
        )
        return False
    finally:
        # Always release the connection, including when the test query fails.
        # A failure while closing does not change the test's verdict.
        if db_conn is not None:
            with contextlib.suppress(Exception):
                await db_conn.close()


# The name starts with "test_", so stop pytest from collecting it when a test
# module imports it.
test_connection.__test__ = False
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def save_database(
    config_manager: DatabaseConfigManager, config: DatabaseConfig, password: str | None
) -> None:
    """Persist a database configuration through the config manager.

    Args:
        config_manager: DatabaseConfigManager that stores the configuration
        config: DatabaseConfig to add
        password: Password to store in keyring (if provided). An empty string
            is forwarded as None, the same as no password.
    """
    stored_password = password or None
    config_manager.add_database(config, stored_password)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Shared model selection logic for onboarding and CLI."""
|
|
2
|
+
|
|
3
|
+
from questionary import Choice
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
|
|
6
|
+
from sqlsaber.application.prompts import Prompter
|
|
7
|
+
from sqlsaber.cli.models import ModelManager
|
|
8
|
+
|
|
9
|
+
# Module-level Rich console for this module's user-facing notices
# (e.g. "No models available").
console = Console()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
async def fetch_models(
    model_manager: ModelManager, providers: list[str] | None = None
) -> list[dict]:
    """Fetch the available models (from the models.dev API) via *model_manager*.

    ``providers`` is forwarded unchanged to
    ``ModelManager.fetch_available_models``.
    """
    available = await model_manager.fetch_available_models(providers=providers)
    return available
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def choose_model(
    prompter: Prompter,
    models: list[dict],
    restrict_provider: str | None = None,
    use_search_filter: bool = True,
) -> str | None:
    """Interactive model selection with the recommended model listed first.

    Args:
        prompter: Prompter instance for interaction
        models: List of model dicts from fetch_models. Each needs ``"id"``
            (normally ``"provider:model_id"``) and ``"name"``. The keys
            ``"provider"`` and ``"description"`` are optional.
        restrict_provider: If set, only show models from this provider and use
            the provider-specific recommendation
        use_search_filter: Enable search filter for large lists

    Returns:
        The selected model ID (``provider:model_id``). If the prompt is
        cancelled, this falls back to the provider's recommended model (when
        ``restrict_provider`` has one) or to the first listed model. None is
        returned only when no models are available.
    """
    if not models:
        console.print("[yellow]No models available[/yellow]")
        return None

    # Filter by provider if restricted
    if restrict_provider:
        models = [m for m in models if m.get("provider") == restrict_provider]
        if not models:
            console.print(
                f"[yellow]No models available for {restrict_provider}[/yellow]"
            )
            return None

    # Get recommended model for the provider
    recommended_id = None
    if restrict_provider and restrict_provider in ModelManager.RECOMMENDED_MODELS:
        recommended_id = ModelManager.RECOMMENDED_MODELS[restrict_provider]

    choices = []
    recommended_index = 0

    for i, model in enumerate(models):
        # IDs are "provider:model_id". An ID without a provider prefix is
        # compared as-is instead of raising IndexError.
        model_id_without_provider = model["id"].split(":", 1)[-1]
        is_recommended = recommended_id == model_id_without_provider

        choice_text = model["name"]
        description = model.get("description")
        if is_recommended:
            choice_text += " (Recommended)"
            recommended_index = i
        elif description:
            # Keep list entries short: at most 40 description characters.
            ellipsis = "..." if len(description) > 40 else ""
            choice_text += f" ({description[:40]}{ellipsis})"

        choices.append(Choice(choice_text, value=model["id"]))

    # Move recommended model to top if it exists
    if recommended_index > 0:
        choices.insert(0, choices.pop(recommended_index))

    selected_model = await prompter.select(
        "Select a model:",
        choices=choices,
        use_search_filter=use_search_filter,
    )

    if selected_model:
        return selected_model

    # User cancelled: fall back to the recommended model, else the first one.
    if recommended_id and restrict_provider:
        return f"{restrict_provider}:{recommended_id}"
    return models[0]["id"]  # models is non-empty here (checked above)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def set_model(model_manager: ModelManager, model_id: str) -> bool:
    """Make *model_id* the current model.

    Returns whatever ``ModelManager.set_model`` reports.
    """
    succeeded = model_manager.set_model(model_id)
    return succeeded
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Prompter abstraction for sync/async questionary interactions."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
import questionary
|
|
7
|
+
from questionary import Choice
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Prompter(ABC):
    """Abstract base class for interactive prompting.

    All methods are coroutines. A ``None`` result signals that the prompt was
    cancelled; callers in this package (e.g. ``collect_db_input``) abort on it.
    The concrete implementations are ``AsyncPrompter`` and ``SyncPrompter``.
    """

    @abstractmethod
    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        """Prompt for text input.

        ``validate`` is forwarded to questionary by the concrete implementations.
        """

    @abstractmethod
    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        """Prompt for selection from choices; return the chosen value.

        The defaults match both concrete implementations. questionary rejects
        enabling the search filter together with j/k navigation, so both flags
        must not be True at the same time.
        """

    @abstractmethod
    async def confirm(self, message: str, default: bool = False) -> bool | None:
        """Prompt for yes/no confirmation."""

    @abstractmethod
    async def path(self, message: str, only_directories: bool = False) -> str | None:
        """Prompt for a file or directory path."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AsyncPrompter(Prompter):
    """Prompter that awaits questionary's ``ask_async()`` (used for onboarding).

    Each method builds a questionary question and then awaits its answer.
    A ``None`` answer is passed straight back to the caller.
    """

    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        question = questionary.text(message, default=default, validate=validate)
        return await question.ask_async()

    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        question = questionary.select(
            message,
            choices=choices,
            default=default,
            use_search_filter=use_search_filter,
            use_jk_keys=use_jk_keys,
        )
        return await question.ask_async()

    async def confirm(self, message: str, default: bool = False) -> bool | None:
        question = questionary.confirm(message, default=default)
        return await question.ask_async()

    async def path(self, message: str, only_directories: bool = False) -> str | None:
        question = questionary.path(message, only_directories=only_directories)
        return await question.ask_async()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class SyncPrompter(Prompter):
    """Prompter that uses questionary's blocking ``ask()`` (for CLI commands).

    The methods are declared ``async`` only to satisfy the Prompter interface.
    Each one runs the prompt synchronously, so nothing else on the event loop
    progresses while a prompt is open.
    """

    async def text(
        self,
        message: str,
        default: str = "",
        validate: Callable[[str], bool | str] | None = None,
    ) -> str | None:
        question = questionary.text(message, default=default, validate=validate)
        return question.ask()

    async def select(
        self,
        message: str,
        choices: list[str] | list[Choice] | list[dict],
        default: Any = None,
        use_search_filter: bool = True,
        use_jk_keys: bool = False,
    ) -> Any:
        question = questionary.select(
            message,
            choices=choices,
            default=default,
            use_search_filter=use_search_filter,
            use_jk_keys=use_jk_keys,
        )
        return question.ask()

    async def confirm(self, message: str, default: bool = False) -> bool | None:
        question = questionary.confirm(message, default=default)
        return question.ask()

    async def path(self, message: str, only_directories: bool = False) -> str | None:
        question = questionary.path(message, only_directories=only_directories)
        return question.ask()
|
sqlsaber/cli/auth.py
CHANGED
|
@@ -10,7 +10,6 @@ from rich.console import Console
|
|
|
10
10
|
from sqlsaber.config import providers
|
|
11
11
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
12
12
|
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
13
|
-
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
14
13
|
from sqlsaber.config.oauth_tokens import OAuthTokenManager
|
|
15
14
|
|
|
16
15
|
# Global instances for CLI commands
|
|
@@ -27,60 +26,33 @@ auth_app = cyclopts.App(
|
|
|
27
26
|
@auth_app.command
|
|
28
27
|
def setup():
|
|
29
28
|
"""Configure authentication for SQLsaber (API keys and Anthropic OAuth)."""
|
|
30
|
-
|
|
29
|
+
import asyncio
|
|
31
30
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
choices=providers.all_keys(),
|
|
35
|
-
).ask()
|
|
31
|
+
from sqlsaber.application.auth_setup import setup_auth
|
|
32
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
36
33
|
|
|
37
|
-
|
|
38
|
-
console.print("[yellow]Setup cancelled.[/yellow]")
|
|
39
|
-
return
|
|
34
|
+
console.print("\n[bold]SQLsaber Authentication Setup[/bold]\n")
|
|
40
35
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
if method_choice == AuthMethod.CLAUDE_PRO:
|
|
52
|
-
flow = AnthropicOAuthFlow()
|
|
53
|
-
if flow.authenticate():
|
|
54
|
-
config_manager.set_auth_method(AuthMethod.CLAUDE_PRO)
|
|
55
|
-
console.print(
|
|
56
|
-
"\n[bold green]✓ Anthropic OAuth configured successfully![/bold green]"
|
|
57
|
-
)
|
|
58
|
-
else:
|
|
59
|
-
console.print("\n[red]✗ Anthropic OAuth setup failed.[/red]")
|
|
60
|
-
console.print(
|
|
61
|
-
"You can change this anytime by running [cyan]saber auth setup[/cyan] again."
|
|
62
|
-
)
|
|
63
|
-
return
|
|
64
|
-
|
|
65
|
-
# API key flow (all providers + Anthropic when selected above)
|
|
66
|
-
api_key_manager = APIKeyManager()
|
|
67
|
-
env_var = api_key_manager._get_env_var_name(provider)
|
|
68
|
-
console.print("\nTo configure your API key, you can either:")
|
|
69
|
-
console.print(f"• Set the {env_var} environment variable")
|
|
70
|
-
console.print("• Let SQLsaber prompt you for the key when needed (stored securely)")
|
|
71
|
-
|
|
72
|
-
# Fetch/store key (cascades env -> keyring -> prompt)
|
|
73
|
-
api_key = api_key_manager.get_api_key(provider)
|
|
74
|
-
if api_key:
|
|
75
|
-
config_manager.set_auth_method(AuthMethod.API_KEY)
|
|
76
|
-
console.print(
|
|
77
|
-
f"\n[bold green]✓ {provider.title()} API key configured successfully![/bold green]"
|
|
36
|
+
async def run_setup():
|
|
37
|
+
prompter = AsyncPrompter()
|
|
38
|
+
api_key_manager = APIKeyManager()
|
|
39
|
+
success, provider = await setup_auth(
|
|
40
|
+
prompter=prompter,
|
|
41
|
+
auth_manager=config_manager,
|
|
42
|
+
api_key_manager=api_key_manager,
|
|
43
|
+
allow_oauth=True,
|
|
44
|
+
default_provider="anthropic",
|
|
45
|
+
run_oauth_in_thread=False,
|
|
78
46
|
)
|
|
79
|
-
|
|
80
|
-
|
|
47
|
+
return success, provider
|
|
48
|
+
|
|
49
|
+
success, _ = asyncio.run(run_setup())
|
|
50
|
+
|
|
51
|
+
if not success:
|
|
52
|
+
console.print("\n[yellow]No authentication configured.[/yellow]")
|
|
81
53
|
|
|
82
54
|
console.print(
|
|
83
|
-
"
|
|
55
|
+
"\nYou can change this anytime by running [cyan]saber auth setup[/cyan] again."
|
|
84
56
|
)
|
|
85
57
|
|
|
86
58
|
|
|
@@ -109,7 +81,7 @@ def status():
|
|
|
109
81
|
# Include OAuth status
|
|
110
82
|
if OAuthTokenManager().has_oauth_token("anthropic"):
|
|
111
83
|
console.print("> anthropic (oauth): [green]configured[/green]")
|
|
112
|
-
env_var = api_key_manager.
|
|
84
|
+
env_var = api_key_manager.get_env_var_name(provider)
|
|
113
85
|
service = api_key_manager._get_service_name(provider)
|
|
114
86
|
from_env = bool(os.getenv(env_var))
|
|
115
87
|
from_keyring = bool(keyring.get_password(service, provider))
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -11,6 +11,7 @@ from sqlsaber.cli.auth import create_auth_app
|
|
|
11
11
|
from sqlsaber.cli.database import create_db_app
|
|
12
12
|
from sqlsaber.cli.memory import create_memory_app
|
|
13
13
|
from sqlsaber.cli.models import create_models_app
|
|
14
|
+
from sqlsaber.cli.onboarding import needs_onboarding, run_onboarding
|
|
14
15
|
from sqlsaber.cli.threads import create_threads_app
|
|
15
16
|
|
|
16
17
|
# Lazy imports - only import what's needed for CLI parsing
|
|
@@ -75,7 +76,7 @@ def query(
|
|
|
75
76
|
query_text: Annotated[
|
|
76
77
|
str | None,
|
|
77
78
|
cyclopts.Parameter(
|
|
78
|
-
help="
|
|
79
|
+
help="Question in natural language (if not provided, reads from stdin or starts interactive mode)",
|
|
79
80
|
),
|
|
80
81
|
] = None,
|
|
81
82
|
database: Annotated[
|
|
@@ -85,6 +86,7 @@ def query(
|
|
|
85
86
|
help="Database connection name, file path (CSV/SQLite/DuckDB), or connection string (postgresql://, mysql://, duckdb://) (uses default if not specified)",
|
|
86
87
|
),
|
|
87
88
|
] = None,
|
|
89
|
+
thinking: bool = False,
|
|
88
90
|
):
|
|
89
91
|
"""Run a query against the database or start interactive mode.
|
|
90
92
|
|
|
@@ -109,16 +111,11 @@ def query(
|
|
|
109
111
|
async def run_session():
|
|
110
112
|
# Import heavy dependencies only when actually running a query
|
|
111
113
|
# This is only done to speed up startup time
|
|
112
|
-
from sqlsaber.agents import
|
|
114
|
+
from sqlsaber.agents import SQLSaberAgent
|
|
113
115
|
from sqlsaber.cli.interactive import InteractiveSession
|
|
114
116
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
115
|
-
from sqlsaber.database
|
|
116
|
-
CSVConnection,
|
|
117
|
+
from sqlsaber.database import (
|
|
117
118
|
DatabaseConnection,
|
|
118
|
-
DuckDBConnection,
|
|
119
|
-
MySQLConnection,
|
|
120
|
-
PostgreSQLConnection,
|
|
121
|
-
SQLiteConnection,
|
|
122
119
|
)
|
|
123
120
|
from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
|
|
124
121
|
from sqlsaber.threads import ThreadStorage
|
|
@@ -132,6 +129,16 @@ def query(
|
|
|
132
129
|
# If stdin was empty, fall back to interactive mode
|
|
133
130
|
actual_query = None
|
|
134
131
|
|
|
132
|
+
# Check if onboarding is needed (only for interactive mode or when no database is configured)
|
|
133
|
+
if needs_onboarding(database):
|
|
134
|
+
# Run onboarding flow
|
|
135
|
+
onboarding_success = await run_onboarding()
|
|
136
|
+
if not onboarding_success:
|
|
137
|
+
# User cancelled or onboarding failed
|
|
138
|
+
raise CLIError(
|
|
139
|
+
"Setup incomplete. Please configure your database and try again."
|
|
140
|
+
)
|
|
141
|
+
|
|
135
142
|
# Resolve database from CLI input
|
|
136
143
|
try:
|
|
137
144
|
resolved = resolve_database(database, config_manager)
|
|
@@ -147,45 +154,32 @@ def query(
|
|
|
147
154
|
raise CLIError(f"Error creating database connection: {e}")
|
|
148
155
|
|
|
149
156
|
# Create pydantic-ai agent instance with database name for memory context
|
|
150
|
-
|
|
157
|
+
sqlsaber_agent = SQLSaberAgent(db_conn, db_name, thinking_enabled=thinking)
|
|
151
158
|
|
|
152
159
|
try:
|
|
153
160
|
if actual_query:
|
|
154
161
|
# Single query mode with streaming
|
|
155
162
|
streaming_handler = StreamingQueryHandler(console)
|
|
156
|
-
|
|
157
|
-
if isinstance(db_conn, PostgreSQLConnection):
|
|
158
|
-
db_type = "PostgreSQL"
|
|
159
|
-
elif isinstance(db_conn, MySQLConnection):
|
|
160
|
-
db_type = "MySQL"
|
|
161
|
-
elif isinstance(db_conn, DuckDBConnection):
|
|
162
|
-
db_type = "DuckDB"
|
|
163
|
-
elif isinstance(db_conn, SQLiteConnection):
|
|
164
|
-
db_type = "SQLite"
|
|
165
|
-
elif isinstance(db_conn, CSVConnection):
|
|
166
|
-
db_type = "DuckDB"
|
|
167
|
-
else:
|
|
168
|
-
db_type = "database"
|
|
163
|
+
db_type = sqlsaber_agent.db_type
|
|
169
164
|
console.print(
|
|
170
165
|
f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
|
|
171
166
|
)
|
|
172
167
|
run = await streaming_handler.execute_streaming_query(
|
|
173
|
-
actual_query,
|
|
168
|
+
actual_query, sqlsaber_agent
|
|
174
169
|
)
|
|
175
170
|
# Persist non-interactive run as a thread snapshot so it can be resumed later
|
|
176
171
|
try:
|
|
177
172
|
if run is not None:
|
|
178
173
|
threads = ThreadStorage()
|
|
179
|
-
# Extract title and model name
|
|
180
|
-
title = actual_query
|
|
181
|
-
model_name: str | None = agent.model.model_name
|
|
182
174
|
|
|
183
175
|
thread_id = await threads.save_snapshot(
|
|
184
176
|
messages_json=run.all_messages_json(),
|
|
185
177
|
database_name=db_name,
|
|
186
178
|
)
|
|
187
179
|
await threads.save_metadata(
|
|
188
|
-
thread_id=thread_id,
|
|
180
|
+
thread_id=thread_id,
|
|
181
|
+
title=actual_query,
|
|
182
|
+
model_name=sqlsaber_agent.agent.model.model_name,
|
|
189
183
|
)
|
|
190
184
|
await threads.end_thread(thread_id)
|
|
191
185
|
console.print(
|
|
@@ -198,7 +192,7 @@ def query(
|
|
|
198
192
|
await threads.prune_threads()
|
|
199
193
|
else:
|
|
200
194
|
# Interactive mode
|
|
201
|
-
session = InteractiveSession(console,
|
|
195
|
+
session = InteractiveSession(console, sqlsaber_agent, db_conn, db_name)
|
|
202
196
|
await session.run()
|
|
203
197
|
|
|
204
198
|
finally:
|
sqlsaber/cli/completers.py
CHANGED
|
@@ -19,6 +19,8 @@ class SlashCommandCompleter(Completer):
|
|
|
19
19
|
("clear", "Clear conversation history"),
|
|
20
20
|
("exit", "Exit the interactive session"),
|
|
21
21
|
("quit", "Exit the interactive session"),
|
|
22
|
+
("thinking on", "Enable extended thinking/reasoning"),
|
|
23
|
+
("thinking off", "Disable extended thinking/reasoning"),
|
|
22
24
|
]
|
|
23
25
|
|
|
24
26
|
# Yield completions that match the partial command
|