sqlsaber 0.26.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +223 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +22 -50
- sqlsaber/cli/commands.py +11 -0
- sqlsaber/cli/database.py +24 -85
- sqlsaber/cli/display.py +1 -1
- sqlsaber/cli/interactive.py +140 -124
- sqlsaber/cli/models.py +18 -28
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/config/api_keys.py +2 -2
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.27.0.dist-info}/RECORD +18 -12
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/commands.py
CHANGED
|
@@ -11,6 +11,7 @@ from sqlsaber.cli.auth import create_auth_app
|
|
|
11
11
|
from sqlsaber.cli.database import create_db_app
|
|
12
12
|
from sqlsaber.cli.memory import create_memory_app
|
|
13
13
|
from sqlsaber.cli.models import create_models_app
|
|
14
|
+
from sqlsaber.cli.onboarding import needs_onboarding, run_onboarding
|
|
14
15
|
from sqlsaber.cli.threads import create_threads_app
|
|
15
16
|
|
|
16
17
|
# Lazy imports - only import what's needed for CLI parsing
|
|
@@ -128,6 +129,16 @@ def query(
|
|
|
128
129
|
# If stdin was empty, fall back to interactive mode
|
|
129
130
|
actual_query = None
|
|
130
131
|
|
|
132
|
+
# Check if onboarding is needed (only for interactive mode or when no database is configured)
|
|
133
|
+
if needs_onboarding(database):
|
|
134
|
+
# Run onboarding flow
|
|
135
|
+
onboarding_success = await run_onboarding()
|
|
136
|
+
if not onboarding_success:
|
|
137
|
+
# User cancelled or onboarding failed
|
|
138
|
+
raise CLIError(
|
|
139
|
+
"Setup incomplete. Please configure your database and try again."
|
|
140
|
+
)
|
|
141
|
+
|
|
131
142
|
# Resolve database from CLI input
|
|
132
143
|
try:
|
|
133
144
|
resolved = resolve_database(database, config_manager)
|
sqlsaber/cli/database.py
CHANGED
|
@@ -81,95 +81,34 @@ def add(
|
|
|
81
81
|
|
|
82
82
|
if interactive:
|
|
83
83
|
# Interactive mode - prompt for all required fields
|
|
84
|
-
|
|
84
|
+
from sqlsaber.application.db_setup import collect_db_input
|
|
85
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
85
86
|
|
|
86
|
-
|
|
87
|
-
if not type or type == "postgresql":
|
|
88
|
-
type = questionary.select(
|
|
89
|
-
"Database type:",
|
|
90
|
-
choices=["postgresql", "mysql", "sqlite", "duckdb"],
|
|
91
|
-
default="postgresql",
|
|
92
|
-
).ask()
|
|
93
|
-
|
|
94
|
-
if type in {"sqlite", "duckdb"}:
|
|
95
|
-
# SQLite/DuckDB only need database file path
|
|
96
|
-
database = database or questionary.path("Database file path:").ask()
|
|
97
|
-
database = str(Path(database).expanduser().resolve())
|
|
98
|
-
host = "localhost"
|
|
99
|
-
port = 0
|
|
100
|
-
username = type
|
|
101
|
-
password = ""
|
|
102
|
-
else:
|
|
103
|
-
# PostgreSQL/MySQL need connection details
|
|
104
|
-
host = host or questionary.text("Host:", default="localhost").ask()
|
|
87
|
+
console.print(f"[bold]Adding database connection: {name}[/bold]")
|
|
105
88
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
89
|
+
async def collect_input():
|
|
90
|
+
prompter = AsyncPrompter()
|
|
91
|
+
return await collect_db_input(
|
|
92
|
+
prompter=prompter, name=name, db_type=type, include_ssl=True
|
|
109
93
|
)
|
|
110
94
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
"require",
|
|
129
|
-
"verify-ca",
|
|
130
|
-
"verify-full",
|
|
131
|
-
],
|
|
132
|
-
default="prefer",
|
|
133
|
-
).ask()
|
|
134
|
-
)
|
|
135
|
-
elif type == "mysql":
|
|
136
|
-
ssl_mode = (
|
|
137
|
-
ssl_mode
|
|
138
|
-
or questionary.select(
|
|
139
|
-
"SSL mode for MySQL:",
|
|
140
|
-
choices=[
|
|
141
|
-
"DISABLED",
|
|
142
|
-
"PREFERRED",
|
|
143
|
-
"REQUIRED",
|
|
144
|
-
"VERIFY_CA",
|
|
145
|
-
"VERIFY_IDENTITY",
|
|
146
|
-
],
|
|
147
|
-
default="PREFERRED",
|
|
148
|
-
).ask()
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
if ssl_mode and ssl_mode not in ["disable", "DISABLED"]:
|
|
152
|
-
if questionary.confirm(
|
|
153
|
-
"Specify SSL certificate files?", default=False
|
|
154
|
-
).ask():
|
|
155
|
-
ssl_ca = (
|
|
156
|
-
ssl_ca or questionary.path("SSL CA certificate file:").ask()
|
|
157
|
-
)
|
|
158
|
-
if questionary.confirm(
|
|
159
|
-
"Specify client certificate?", default=False
|
|
160
|
-
).ask():
|
|
161
|
-
ssl_cert = (
|
|
162
|
-
ssl_cert
|
|
163
|
-
or questionary.path(
|
|
164
|
-
"SSL client certificate file:"
|
|
165
|
-
).ask()
|
|
166
|
-
)
|
|
167
|
-
ssl_key = (
|
|
168
|
-
ssl_key
|
|
169
|
-
or questionary.path(
|
|
170
|
-
"SSL client private key file:"
|
|
171
|
-
).ask()
|
|
172
|
-
)
|
|
95
|
+
db_input = asyncio.run(collect_input())
|
|
96
|
+
|
|
97
|
+
if db_input is None:
|
|
98
|
+
console.print("[yellow]Operation cancelled[/yellow]")
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
# Extract values from db_input
|
|
102
|
+
type = db_input.type
|
|
103
|
+
host = db_input.host
|
|
104
|
+
port = db_input.port
|
|
105
|
+
database = db_input.database
|
|
106
|
+
username = db_input.username
|
|
107
|
+
password = db_input.password
|
|
108
|
+
ssl_mode = db_input.ssl_mode
|
|
109
|
+
ssl_ca = db_input.ssl_ca
|
|
110
|
+
ssl_cert = db_input.ssl_cert
|
|
111
|
+
ssl_key = db_input.ssl_key
|
|
173
112
|
else:
|
|
174
113
|
# Non-interactive mode - use provided values or defaults
|
|
175
114
|
if type == "sqlite":
|
sqlsaber/cli/display.py
CHANGED
|
@@ -109,7 +109,7 @@ class LiveMarkdownRenderer:
|
|
|
109
109
|
# Print the complete markdown to scroll-back for permanent reference
|
|
110
110
|
if buf:
|
|
111
111
|
if kind == ThinkingPart:
|
|
112
|
-
self.console.print(
|
|
112
|
+
self.console.print(Markdown(buf, style="dim"))
|
|
113
113
|
else:
|
|
114
114
|
self.console.print(Markdown(buf))
|
|
115
115
|
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -35,26 +35,12 @@ if TYPE_CHECKING:
|
|
|
35
35
|
from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
def bottom_toolbar():
|
|
39
|
-
return [
|
|
40
|
-
(
|
|
41
|
-
"class:bottom-toolbar",
|
|
42
|
-
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
43
|
-
)
|
|
44
|
-
]
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
style = Style.from_dict(
|
|
48
|
-
{
|
|
49
|
-
"frame.border": "#ebbcba",
|
|
50
|
-
"bottom-toolbar": "#ebbcba bg:#21202e",
|
|
51
|
-
}
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
|
|
55
38
|
class InteractiveSession:
|
|
56
39
|
"""Manages interactive CLI sessions."""
|
|
57
40
|
|
|
41
|
+
exit_commands = {"/exit", "/quit", "exit", "quit"}
|
|
42
|
+
resume_command_template = "saber threads resume {thread_id}"
|
|
43
|
+
|
|
58
44
|
def __init__(
|
|
59
45
|
self,
|
|
60
46
|
console: Console,
|
|
@@ -80,28 +66,33 @@ class InteractiveSession:
|
|
|
80
66
|
self._thread_id: str | None = initial_thread_id
|
|
81
67
|
self.first_message = not self._thread_id
|
|
82
68
|
|
|
83
|
-
def
|
|
84
|
-
"""
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
else "SQLite"
|
|
97
|
-
if isinstance(self.db_conn, SQLiteConnection)
|
|
98
|
-
else "database"
|
|
69
|
+
def _history_path(self) -> Path:
|
|
70
|
+
"""Get the history file path, ensuring directory exists."""
|
|
71
|
+
history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
|
|
72
|
+
history_dir.mkdir(parents=True, exist_ok=True)
|
|
73
|
+
return history_dir / "history"
|
|
74
|
+
|
|
75
|
+
def _prompt_style(self) -> Style:
|
|
76
|
+
"""Get the prompt style configuration."""
|
|
77
|
+
return Style.from_dict(
|
|
78
|
+
{
|
|
79
|
+
"frame.border": "gray",
|
|
80
|
+
"bottom-toolbar": "white bg:#21202e",
|
|
81
|
+
}
|
|
99
82
|
)
|
|
100
83
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
84
|
+
def _bottom_toolbar(self):
|
|
85
|
+
"""Get the bottom toolbar text."""
|
|
86
|
+
return [
|
|
87
|
+
(
|
|
88
|
+
"class:bottom-toolbar",
|
|
89
|
+
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
90
|
+
)
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
def _banner(self) -> str:
|
|
94
|
+
"""Get the ASCII banner."""
|
|
95
|
+
return """
|
|
105
96
|
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
106
97
|
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
107
98
|
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
@@ -109,35 +100,110 @@ class InteractiveSession:
|
|
|
109
100
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
110
101
|
▀▀
|
|
111
102
|
"""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
dedent("""
|
|
103
|
+
|
|
104
|
+
def _instructions(self) -> str:
|
|
105
|
+
"""Get the instruction text."""
|
|
106
|
+
return dedent("""
|
|
117
107
|
- Use `/` for slash commands
|
|
118
108
|
- Type `@` to get table name completions
|
|
119
109
|
- Start message with `#` to add something to agent's memory
|
|
120
110
|
- Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
|
|
121
111
|
""")
|
|
122
|
-
)
|
|
123
|
-
)
|
|
124
112
|
|
|
113
|
+
def _db_type_name(self) -> str:
|
|
114
|
+
"""Get human-readable database type name."""
|
|
115
|
+
mapping = {
|
|
116
|
+
PostgreSQLConnection: "PostgreSQL",
|
|
117
|
+
MySQLConnection: "MySQL",
|
|
118
|
+
DuckDBConnection: "DuckDB",
|
|
119
|
+
CSVConnection: "DuckDB",
|
|
120
|
+
SQLiteConnection: "SQLite",
|
|
121
|
+
}
|
|
122
|
+
for cls, name in mapping.items():
|
|
123
|
+
if isinstance(self.db_conn, cls):
|
|
124
|
+
return name
|
|
125
|
+
return "database"
|
|
126
|
+
|
|
127
|
+
def _resume_hint(self, thread_id: str) -> str:
|
|
128
|
+
"""Build resume command hint."""
|
|
129
|
+
return self.resume_command_template.format(thread_id=thread_id)
|
|
130
|
+
|
|
131
|
+
def show_welcome_message(self):
|
|
132
|
+
"""Display welcome message for interactive mode."""
|
|
133
|
+
if self.first_message:
|
|
134
|
+
self.console.print(Panel.fit(self._banner()))
|
|
135
|
+
self.console.print(Markdown(self._instructions()))
|
|
136
|
+
|
|
137
|
+
db_name = self.database_name or "Unknown"
|
|
125
138
|
self.console.print(
|
|
126
|
-
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({
|
|
139
|
+
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({self._db_type_name()})\n"
|
|
127
140
|
)
|
|
128
|
-
|
|
141
|
+
|
|
129
142
|
if self._thread_id:
|
|
130
143
|
self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
|
|
131
144
|
|
|
132
|
-
async def
|
|
133
|
-
"""End thread and display
|
|
134
|
-
# Print resume hint if there is an active thread
|
|
145
|
+
async def _end_thread(self):
|
|
146
|
+
"""End thread and display resume hint."""
|
|
135
147
|
if self._thread_id:
|
|
136
148
|
await self._threads.end_thread(self._thread_id)
|
|
137
149
|
self.console.print(
|
|
138
|
-
f"[dim]You can continue this thread using:[/dim]
|
|
150
|
+
f"[dim]You can continue this thread using:[/dim] {self._resume_hint(self._thread_id)}"
|
|
139
151
|
)
|
|
140
152
|
|
|
153
|
+
async def _handle_memory(self, content: str):
|
|
154
|
+
"""Handle memory addition command."""
|
|
155
|
+
if not content:
|
|
156
|
+
self.console.print("[yellow]Empty memory content after '#'[/yellow]\n")
|
|
157
|
+
return
|
|
158
|
+
|
|
159
|
+
try:
|
|
160
|
+
mm = self.sqlsaber_agent.memory_manager
|
|
161
|
+
if mm and self.database_name:
|
|
162
|
+
memory = mm.add_memory(self.database_name, content)
|
|
163
|
+
self.console.print(f"[green]✓ Memory added:[/green] {content}")
|
|
164
|
+
self.console.print(f"[dim]Memory ID: {memory.id}[/dim]\n")
|
|
165
|
+
else:
|
|
166
|
+
self.console.print(
|
|
167
|
+
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
168
|
+
)
|
|
169
|
+
except Exception as exc:
|
|
170
|
+
self.console.print(f"[yellow]Could not add memory:[/yellow] {exc}\n")
|
|
171
|
+
|
|
172
|
+
async def _cmd_clear(self):
|
|
173
|
+
"""Clear conversation history."""
|
|
174
|
+
self.message_history = []
|
|
175
|
+
try:
|
|
176
|
+
if self._thread_id:
|
|
177
|
+
await self._threads.end_thread(self._thread_id)
|
|
178
|
+
except Exception:
|
|
179
|
+
pass
|
|
180
|
+
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
181
|
+
self._thread_id = None
|
|
182
|
+
self.first_message = True
|
|
183
|
+
|
|
184
|
+
async def _cmd_thinking_on(self):
|
|
185
|
+
"""Enable thinking mode."""
|
|
186
|
+
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
187
|
+
self.console.print("[green]✓ Thinking enabled[/green]\n")
|
|
188
|
+
|
|
189
|
+
async def _cmd_thinking_off(self):
|
|
190
|
+
"""Disable thinking mode."""
|
|
191
|
+
self.sqlsaber_agent.set_thinking(enabled=False)
|
|
192
|
+
self.console.print("[green]✓ Thinking disabled[/green]\n")
|
|
193
|
+
|
|
194
|
+
async def _handle_command(self, user_query: str) -> bool:
|
|
195
|
+
"""Handle slash commands. Returns True if command was handled."""
|
|
196
|
+
if user_query == "/clear":
|
|
197
|
+
await self._cmd_clear()
|
|
198
|
+
return True
|
|
199
|
+
if user_query == "/thinking on":
|
|
200
|
+
await self._cmd_thinking_on()
|
|
201
|
+
return True
|
|
202
|
+
if user_query == "/thinking off":
|
|
203
|
+
await self._cmd_thinking_off()
|
|
204
|
+
return True
|
|
205
|
+
return False
|
|
206
|
+
|
|
141
207
|
async def _update_table_cache(self):
|
|
142
208
|
"""Update the table completer cache with fresh data."""
|
|
143
209
|
try:
|
|
@@ -170,6 +236,10 @@ class InteractiveSession:
|
|
|
170
236
|
# If there's an error, just use empty cache
|
|
171
237
|
self.table_completer.update_cache([])
|
|
172
238
|
|
|
239
|
+
async def before_prompt_loop(self):
|
|
240
|
+
"""Hook to refresh context before prompt loop."""
|
|
241
|
+
await self._update_table_cache()
|
|
242
|
+
|
|
173
243
|
async def _execute_query_with_cancellation(self, user_query: str):
|
|
174
244
|
"""Execute a query with cancellation support."""
|
|
175
245
|
# Create cancellation token
|
|
@@ -207,6 +277,7 @@ class InteractiveSession:
|
|
|
207
277
|
title=user_query,
|
|
208
278
|
model_name=self.sqlsaber_agent.agent.model.model_name,
|
|
209
279
|
)
|
|
280
|
+
self.first_message = False
|
|
210
281
|
except Exception:
|
|
211
282
|
pass
|
|
212
283
|
finally:
|
|
@@ -218,15 +289,9 @@ class InteractiveSession:
|
|
|
218
289
|
async def run(self):
|
|
219
290
|
"""Run the interactive session loop."""
|
|
220
291
|
self.show_welcome_message()
|
|
292
|
+
await self.before_prompt_loop()
|
|
221
293
|
|
|
222
|
-
|
|
223
|
-
await self._update_table_cache()
|
|
224
|
-
|
|
225
|
-
session = PromptSession(
|
|
226
|
-
history=FileHistory(
|
|
227
|
-
Path(platformdirs.user_config_dir("sqlsaber")) / "history"
|
|
228
|
-
)
|
|
229
|
-
)
|
|
294
|
+
session = PromptSession(history=FileHistory(self._history_path()))
|
|
230
295
|
|
|
231
296
|
while True:
|
|
232
297
|
try:
|
|
@@ -238,81 +303,32 @@ class InteractiveSession:
|
|
|
238
303
|
SlashCommandCompleter(), self.table_completer
|
|
239
304
|
),
|
|
240
305
|
show_frame=True,
|
|
241
|
-
bottom_toolbar=
|
|
242
|
-
style=
|
|
306
|
+
bottom_toolbar=self._bottom_toolbar,
|
|
307
|
+
style=self._prompt_style(),
|
|
243
308
|
)
|
|
244
309
|
|
|
245
310
|
if not user_query:
|
|
246
311
|
continue
|
|
247
312
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
or user_query.startswith("/quit")
|
|
313
|
+
# Handle exit commands
|
|
314
|
+
if user_query in self.exit_commands or any(
|
|
315
|
+
user_query.startswith(cmd) for cmd in self.exit_commands
|
|
252
316
|
):
|
|
253
|
-
await self.
|
|
317
|
+
await self._end_thread()
|
|
254
318
|
break
|
|
255
319
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
self.message_history = []
|
|
259
|
-
# End current thread (if any) so the next turn creates a fresh one
|
|
260
|
-
try:
|
|
261
|
-
if self._thread_id:
|
|
262
|
-
await self._threads.end_thread(self._thread_id)
|
|
263
|
-
except Exception:
|
|
264
|
-
pass
|
|
265
|
-
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
266
|
-
# Do not print resume hint when clearing; a new thread will be created on next turn
|
|
267
|
-
self._thread_id = None
|
|
268
|
-
continue
|
|
269
|
-
|
|
270
|
-
# Thinking commands
|
|
271
|
-
if user_query == "/thinking on":
|
|
272
|
-
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
273
|
-
self.console.print("[green]✓ Thinking enabled[/green]\n")
|
|
320
|
+
# Handle slash commands
|
|
321
|
+
if await self._handle_command(user_query):
|
|
274
322
|
continue
|
|
275
323
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
self.
|
|
324
|
+
# Handle memory addition
|
|
325
|
+
if user_query.strip().startswith("#"):
|
|
326
|
+
await self._handle_memory(user_query.strip()[1:].strip())
|
|
279
327
|
continue
|
|
280
328
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
memory_content = memory_text[1:].strip() # Remove # and trim
|
|
285
|
-
if memory_content:
|
|
286
|
-
# Add memory via the agent's memory manager
|
|
287
|
-
try:
|
|
288
|
-
mm = self.sqlsaber_agent.memory_manager
|
|
289
|
-
if mm and self.database_name:
|
|
290
|
-
memory = mm.add_memory(
|
|
291
|
-
self.database_name, memory_content
|
|
292
|
-
)
|
|
293
|
-
self.console.print(
|
|
294
|
-
f"[green]✓ Memory added:[/green] {memory_content}"
|
|
295
|
-
)
|
|
296
|
-
self.console.print(
|
|
297
|
-
f"[dim]Memory ID: {memory.id}[/dim]\n"
|
|
298
|
-
)
|
|
299
|
-
else:
|
|
300
|
-
self.console.print(
|
|
301
|
-
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
302
|
-
)
|
|
303
|
-
except Exception:
|
|
304
|
-
self.console.print(
|
|
305
|
-
"[yellow]Could not add memory[/yellow]\n"
|
|
306
|
-
)
|
|
307
|
-
else:
|
|
308
|
-
self.console.print(
|
|
309
|
-
"[yellow]Empty memory content after '#'[/yellow]\n"
|
|
310
|
-
)
|
|
311
|
-
continue
|
|
312
|
-
|
|
313
|
-
# Execute query with cancellation support
|
|
314
|
-
await self._execute_query_with_cancellation(user_query)
|
|
315
|
-
self.display.show_newline() # Empty line for readability
|
|
329
|
+
# Execute query with cancellation support
|
|
330
|
+
await self._execute_query_with_cancellation(user_query)
|
|
331
|
+
self.display.show_newline()
|
|
316
332
|
|
|
317
333
|
except KeyboardInterrupt:
|
|
318
334
|
# Handle Ctrl+C - cancel current task if running
|
|
@@ -331,7 +347,7 @@ class InteractiveSession:
|
|
|
331
347
|
)
|
|
332
348
|
except EOFError:
|
|
333
349
|
# Exit when Ctrl+D is pressed
|
|
334
|
-
await self.
|
|
350
|
+
await self._end_thread()
|
|
335
351
|
break
|
|
336
|
-
except Exception as
|
|
337
|
-
self.console.print(f"[bold red]Error:[/bold red] {
|
|
352
|
+
except Exception as exc:
|
|
353
|
+
self.console.print(f"[bold red]Error:[/bold red] {exc}")
|
sqlsaber/cli/models.py
CHANGED
|
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
|
|
|
25
25
|
class ModelManager:
|
|
26
26
|
"""Manages AI model configuration and fetching."""
|
|
27
27
|
|
|
28
|
-
DEFAULT_MODEL = "anthropic:claude-sonnet-4-
|
|
28
|
+
DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
|
|
29
29
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
30
|
# Providers come from central registry
|
|
31
31
|
SUPPORTED_PROVIDERS = providers.all_keys()
|
|
32
32
|
|
|
33
|
+
RECOMMENDED_MODELS = {
|
|
34
|
+
"anthropic": "claude-sonnet-4-5-20250929",
|
|
35
|
+
"openai": "gpt-5",
|
|
36
|
+
"google": "gemini-2.5-pro",
|
|
37
|
+
"groq": "llama-3-3-70b-versatile",
|
|
38
|
+
"mistral": "mistral-large-latest",
|
|
39
|
+
"cohere": "command-r-plus",
|
|
40
|
+
}
|
|
41
|
+
|
|
33
42
|
async def fetch_available_models(
|
|
34
43
|
self, providers: list[str] | None = None
|
|
35
44
|
) -> list[dict]:
|
|
@@ -180,39 +189,20 @@ def set():
|
|
|
180
189
|
"""Set the AI model to use."""
|
|
181
190
|
|
|
182
191
|
async def interactive_set():
|
|
192
|
+
from sqlsaber.application.model_selection import choose_model, fetch_models
|
|
193
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
194
|
+
|
|
183
195
|
console.print("[blue]Fetching available models...[/blue]")
|
|
184
|
-
models = await model_manager
|
|
196
|
+
models = await fetch_models(model_manager)
|
|
185
197
|
|
|
186
198
|
if not models:
|
|
187
199
|
console.print("[red]Failed to fetch models. Cannot set model.[/red]")
|
|
188
200
|
sys.exit(1)
|
|
189
201
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
prov = model.get("provider", "?")
|
|
195
|
-
choice_text = f"[{prov}] {model['id']} - {model['name']}"
|
|
196
|
-
if model["description"]:
|
|
197
|
-
choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
|
|
198
|
-
|
|
199
|
-
choices.append({"name": choice_text, "value": model["id"]})
|
|
200
|
-
|
|
201
|
-
# Get current model to set as default
|
|
202
|
-
current_model = model_manager.get_current_model()
|
|
203
|
-
default_index = 0
|
|
204
|
-
for i, choice in enumerate(choices):
|
|
205
|
-
if choice["value"] == current_model:
|
|
206
|
-
default_index = i
|
|
207
|
-
break
|
|
208
|
-
|
|
209
|
-
selected_model = await questionary.select(
|
|
210
|
-
"Select a model:",
|
|
211
|
-
choices=choices,
|
|
212
|
-
use_search_filter=True,
|
|
213
|
-
use_jk_keys=False, # Disable j/k keys when using search filter
|
|
214
|
-
default=choices[default_index] if choices else None,
|
|
215
|
-
).ask_async()
|
|
202
|
+
prompter = AsyncPrompter()
|
|
203
|
+
selected_model = await choose_model(
|
|
204
|
+
prompter, models, restrict_provider=None, use_search_filter=True
|
|
205
|
+
)
|
|
216
206
|
|
|
217
207
|
if selected_model:
|
|
218
208
|
if model_manager.set_model(selected_model):
|