sqlsaber 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +222 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +24 -52
- sqlsaber/cli/commands.py +13 -2
- sqlsaber/cli/database.py +26 -87
- sqlsaber/cli/display.py +59 -40
- sqlsaber/cli/interactive.py +138 -131
- sqlsaber/cli/memory.py +2 -2
- sqlsaber/cli/models.py +20 -30
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +1 -1
- sqlsaber/cli/threads.py +35 -16
- sqlsaber/config/api_keys.py +4 -4
- sqlsaber/config/oauth_flow.py +3 -2
- sqlsaber/config/oauth_tokens.py +3 -5
- sqlsaber/database/base.py +6 -0
- sqlsaber/database/csv.py +5 -0
- sqlsaber/database/duckdb.py +5 -0
- sqlsaber/database/mysql.py +5 -0
- sqlsaber/database/postgresql.py +5 -0
- sqlsaber/database/sqlite.py +5 -0
- sqlsaber/theme/__init__.py +5 -0
- sqlsaber/theme/manager.py +219 -0
- sqlsaber/tools/sql_guard.py +225 -0
- sqlsaber/tools/sql_tools.py +10 -35
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/METADATA +2 -1
- sqlsaber-0.28.0.dist-info/RECORD +61 -0
- sqlsaber-0.26.0.dist-info/RECORD +0 -52
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.26.0.dist-info → sqlsaber-0.28.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/interactive.py
CHANGED
|
@@ -9,7 +9,6 @@ import platformdirs
|
|
|
9
9
|
from prompt_toolkit import PromptSession
|
|
10
10
|
from prompt_toolkit.history import FileHistory
|
|
11
11
|
from prompt_toolkit.patch_stdout import patch_stdout
|
|
12
|
-
from prompt_toolkit.styles import Style
|
|
13
12
|
from rich.console import Console
|
|
14
13
|
from rich.markdown import Markdown
|
|
15
14
|
from rich.panel import Panel
|
|
@@ -29,32 +28,19 @@ from sqlsaber.database import (
|
|
|
29
28
|
SQLiteConnection,
|
|
30
29
|
)
|
|
31
30
|
from sqlsaber.database.schema import SchemaManager
|
|
31
|
+
from sqlsaber.theme.manager import get_theme_manager
|
|
32
32
|
from sqlsaber.threads import ThreadStorage
|
|
33
33
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
35
35
|
from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
def bottom_toolbar():
|
|
39
|
-
return [
|
|
40
|
-
(
|
|
41
|
-
"class:bottom-toolbar",
|
|
42
|
-
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
43
|
-
)
|
|
44
|
-
]
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
style = Style.from_dict(
|
|
48
|
-
{
|
|
49
|
-
"frame.border": "#ebbcba",
|
|
50
|
-
"bottom-toolbar": "#ebbcba bg:#21202e",
|
|
51
|
-
}
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
|
|
55
38
|
class InteractiveSession:
|
|
56
39
|
"""Manages interactive CLI sessions."""
|
|
57
40
|
|
|
41
|
+
exit_commands = {"/exit", "/quit", "exit", "quit"}
|
|
42
|
+
resume_command_template = "saber threads resume {thread_id}"
|
|
43
|
+
|
|
58
44
|
def __init__(
|
|
59
45
|
self,
|
|
60
46
|
console: Console,
|
|
@@ -75,33 +61,30 @@ class InteractiveSession:
|
|
|
75
61
|
self.cancellation_token: asyncio.Event | None = None
|
|
76
62
|
self.table_completer = TableNameCompleter()
|
|
77
63
|
self.message_history: list | None = initial_history or []
|
|
64
|
+
self.tm = get_theme_manager()
|
|
78
65
|
# Conversation Thread persistence
|
|
79
66
|
self._threads = ThreadStorage()
|
|
80
67
|
self._thread_id: str | None = initial_thread_id
|
|
81
68
|
self.first_message = not self._thread_id
|
|
82
69
|
|
|
83
|
-
def
|
|
84
|
-
"""
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if isinstance(self.db_conn, SQLiteConnection)
|
|
98
|
-
else "database"
|
|
99
|
-
)
|
|
70
|
+
def _history_path(self) -> Path:
|
|
71
|
+
"""Get the history file path, ensuring directory exists."""
|
|
72
|
+
history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
|
|
73
|
+
history_dir.mkdir(parents=True, exist_ok=True)
|
|
74
|
+
return history_dir / "history"
|
|
75
|
+
|
|
76
|
+
def _bottom_toolbar(self):
|
|
77
|
+
"""Get the bottom toolbar text."""
|
|
78
|
+
return [
|
|
79
|
+
(
|
|
80
|
+
"class:bottom-toolbar",
|
|
81
|
+
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
82
|
+
)
|
|
83
|
+
]
|
|
100
84
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
"""
|
|
85
|
+
def _banner(self) -> str:
|
|
86
|
+
"""Get the ASCII banner."""
|
|
87
|
+
return """
|
|
105
88
|
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
106
89
|
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
107
90
|
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
@@ -109,35 +92,110 @@ class InteractiveSession:
|
|
|
109
92
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
110
93
|
▀▀
|
|
111
94
|
"""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
dedent("""
|
|
95
|
+
|
|
96
|
+
def _instructions(self) -> str:
|
|
97
|
+
"""Get the instruction text."""
|
|
98
|
+
return dedent("""
|
|
117
99
|
- Use `/` for slash commands
|
|
118
100
|
- Type `@` to get table name completions
|
|
119
101
|
- Start message with `#` to add something to agent's memory
|
|
120
102
|
- Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
|
|
121
103
|
""")
|
|
122
|
-
)
|
|
123
|
-
)
|
|
124
104
|
|
|
105
|
+
def _db_type_name(self) -> str:
|
|
106
|
+
"""Get human-readable database type name."""
|
|
107
|
+
mapping = {
|
|
108
|
+
PostgreSQLConnection: "PostgreSQL",
|
|
109
|
+
MySQLConnection: "MySQL",
|
|
110
|
+
DuckDBConnection: "DuckDB",
|
|
111
|
+
CSVConnection: "DuckDB",
|
|
112
|
+
SQLiteConnection: "SQLite",
|
|
113
|
+
}
|
|
114
|
+
for cls, name in mapping.items():
|
|
115
|
+
if isinstance(self.db_conn, cls):
|
|
116
|
+
return name
|
|
117
|
+
return "database"
|
|
118
|
+
|
|
119
|
+
def _resume_hint(self, thread_id: str) -> str:
|
|
120
|
+
"""Build resume command hint."""
|
|
121
|
+
return self.resume_command_template.format(thread_id=thread_id)
|
|
122
|
+
|
|
123
|
+
def show_welcome_message(self):
|
|
124
|
+
"""Display welcome message for interactive mode."""
|
|
125
|
+
if self.first_message:
|
|
126
|
+
self.console.print(Panel.fit(self._banner()))
|
|
127
|
+
self.console.print(Markdown(self._instructions()))
|
|
128
|
+
|
|
129
|
+
db_name = self.database_name or "Unknown"
|
|
125
130
|
self.console.print(
|
|
126
|
-
f"[
|
|
131
|
+
f"[heading]\n\nConnected to:[/heading] {db_name} ({self._db_type_name()})\n"
|
|
127
132
|
)
|
|
128
|
-
|
|
133
|
+
|
|
129
134
|
if self._thread_id:
|
|
130
|
-
self.console.print(f"[
|
|
135
|
+
self.console.print(f"[muted]Resuming thread:[/muted] {self._thread_id}\n")
|
|
131
136
|
|
|
132
|
-
async def
|
|
133
|
-
"""End thread and display
|
|
134
|
-
# Print resume hint if there is an active thread
|
|
137
|
+
async def _end_thread(self):
|
|
138
|
+
"""End thread and display resume hint."""
|
|
135
139
|
if self._thread_id:
|
|
136
140
|
await self._threads.end_thread(self._thread_id)
|
|
137
141
|
self.console.print(
|
|
138
|
-
f"[
|
|
142
|
+
f"[muted]You can continue this thread using:[/muted] {self._resume_hint(self._thread_id)}"
|
|
139
143
|
)
|
|
140
144
|
|
|
145
|
+
async def _handle_memory(self, content: str):
|
|
146
|
+
"""Handle memory addition command."""
|
|
147
|
+
if not content:
|
|
148
|
+
self.console.print("[warning]Empty memory content after '#'[/warning]\n")
|
|
149
|
+
return
|
|
150
|
+
|
|
151
|
+
try:
|
|
152
|
+
mm = self.sqlsaber_agent.memory_manager
|
|
153
|
+
if mm and self.database_name:
|
|
154
|
+
memory = mm.add_memory(self.database_name, content)
|
|
155
|
+
self.console.print(f"[success]✓ Memory added:[/success] {content}")
|
|
156
|
+
self.console.print(f"[muted]Memory ID: {memory.id}[/muted]\n")
|
|
157
|
+
else:
|
|
158
|
+
self.console.print(
|
|
159
|
+
"[warning]Could not add memory (no database context)[/warning]\n"
|
|
160
|
+
)
|
|
161
|
+
except Exception as exc:
|
|
162
|
+
self.console.print(f"[warning]Could not add memory:[/warning] {exc}\n")
|
|
163
|
+
|
|
164
|
+
async def _cmd_clear(self):
|
|
165
|
+
"""Clear conversation history."""
|
|
166
|
+
self.message_history = []
|
|
167
|
+
try:
|
|
168
|
+
if self._thread_id:
|
|
169
|
+
await self._threads.end_thread(self._thread_id)
|
|
170
|
+
except Exception:
|
|
171
|
+
pass
|
|
172
|
+
self.console.print("[success]Conversation history cleared.[/success]\n")
|
|
173
|
+
self._thread_id = None
|
|
174
|
+
self.first_message = True
|
|
175
|
+
|
|
176
|
+
async def _cmd_thinking_on(self):
|
|
177
|
+
"""Enable thinking mode."""
|
|
178
|
+
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
179
|
+
self.console.print("[success]✓ Thinking enabled[/success]\n")
|
|
180
|
+
|
|
181
|
+
async def _cmd_thinking_off(self):
|
|
182
|
+
"""Disable thinking mode."""
|
|
183
|
+
self.sqlsaber_agent.set_thinking(enabled=False)
|
|
184
|
+
self.console.print("[success]✓ Thinking disabled[/success]\n")
|
|
185
|
+
|
|
186
|
+
async def _handle_command(self, user_query: str) -> bool:
|
|
187
|
+
"""Handle slash commands. Returns True if command was handled."""
|
|
188
|
+
if user_query == "/clear":
|
|
189
|
+
await self._cmd_clear()
|
|
190
|
+
return True
|
|
191
|
+
if user_query == "/thinking on":
|
|
192
|
+
await self._cmd_thinking_on()
|
|
193
|
+
return True
|
|
194
|
+
if user_query == "/thinking off":
|
|
195
|
+
await self._cmd_thinking_off()
|
|
196
|
+
return True
|
|
197
|
+
return False
|
|
198
|
+
|
|
141
199
|
async def _update_table_cache(self):
|
|
142
200
|
"""Update the table completer cache with fresh data."""
|
|
143
201
|
try:
|
|
@@ -170,6 +228,10 @@ class InteractiveSession:
|
|
|
170
228
|
# If there's an error, just use empty cache
|
|
171
229
|
self.table_completer.update_cache([])
|
|
172
230
|
|
|
231
|
+
async def before_prompt_loop(self):
|
|
232
|
+
"""Hook to refresh context before prompt loop."""
|
|
233
|
+
await self._update_table_cache()
|
|
234
|
+
|
|
173
235
|
async def _execute_query_with_cancellation(self, user_query: str):
|
|
174
236
|
"""Execute a query with cancellation support."""
|
|
175
237
|
# Create cancellation token
|
|
@@ -207,6 +269,7 @@ class InteractiveSession:
|
|
|
207
269
|
title=user_query,
|
|
208
270
|
model_name=self.sqlsaber_agent.agent.model.model_name,
|
|
209
271
|
)
|
|
272
|
+
self.first_message = False
|
|
210
273
|
except Exception:
|
|
211
274
|
pass
|
|
212
275
|
finally:
|
|
@@ -218,101 +281,45 @@ class InteractiveSession:
|
|
|
218
281
|
async def run(self):
|
|
219
282
|
"""Run the interactive session loop."""
|
|
220
283
|
self.show_welcome_message()
|
|
284
|
+
await self.before_prompt_loop()
|
|
221
285
|
|
|
222
|
-
|
|
223
|
-
await self._update_table_cache()
|
|
224
|
-
|
|
225
|
-
session = PromptSession(
|
|
226
|
-
history=FileHistory(
|
|
227
|
-
Path(platformdirs.user_config_dir("sqlsaber")) / "history"
|
|
228
|
-
)
|
|
229
|
-
)
|
|
286
|
+
session = PromptSession(history=FileHistory(self._history_path()))
|
|
230
287
|
|
|
231
288
|
while True:
|
|
232
289
|
try:
|
|
233
290
|
with patch_stdout():
|
|
234
291
|
user_query = await session.prompt_async(
|
|
235
|
-
"",
|
|
292
|
+
"> ",
|
|
236
293
|
multiline=True,
|
|
237
294
|
completer=CompositeCompleter(
|
|
238
295
|
SlashCommandCompleter(), self.table_completer
|
|
239
296
|
),
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
style=style,
|
|
297
|
+
bottom_toolbar=self._bottom_toolbar,
|
|
298
|
+
style=self.tm.pt_style(),
|
|
243
299
|
)
|
|
244
300
|
|
|
245
301
|
if not user_query:
|
|
246
302
|
continue
|
|
247
303
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
or user_query.startswith("/quit")
|
|
304
|
+
# Handle exit commands
|
|
305
|
+
if user_query in self.exit_commands or any(
|
|
306
|
+
user_query.startswith(cmd) for cmd in self.exit_commands
|
|
252
307
|
):
|
|
253
|
-
await self.
|
|
308
|
+
await self._end_thread()
|
|
254
309
|
break
|
|
255
310
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
self.message_history = []
|
|
259
|
-
# End current thread (if any) so the next turn creates a fresh one
|
|
260
|
-
try:
|
|
261
|
-
if self._thread_id:
|
|
262
|
-
await self._threads.end_thread(self._thread_id)
|
|
263
|
-
except Exception:
|
|
264
|
-
pass
|
|
265
|
-
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
266
|
-
# Do not print resume hint when clearing; a new thread will be created on next turn
|
|
267
|
-
self._thread_id = None
|
|
268
|
-
continue
|
|
269
|
-
|
|
270
|
-
# Thinking commands
|
|
271
|
-
if user_query == "/thinking on":
|
|
272
|
-
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
273
|
-
self.console.print("[green]✓ Thinking enabled[/green]\n")
|
|
311
|
+
# Handle slash commands
|
|
312
|
+
if await self._handle_command(user_query):
|
|
274
313
|
continue
|
|
275
314
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
self.
|
|
315
|
+
# Handle memory addition
|
|
316
|
+
if user_query.strip().startswith("#"):
|
|
317
|
+
await self._handle_memory(user_query.strip()[1:].strip())
|
|
279
318
|
continue
|
|
280
319
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
memory_content = memory_text[1:].strip() # Remove # and trim
|
|
285
|
-
if memory_content:
|
|
286
|
-
# Add memory via the agent's memory manager
|
|
287
|
-
try:
|
|
288
|
-
mm = self.sqlsaber_agent.memory_manager
|
|
289
|
-
if mm and self.database_name:
|
|
290
|
-
memory = mm.add_memory(
|
|
291
|
-
self.database_name, memory_content
|
|
292
|
-
)
|
|
293
|
-
self.console.print(
|
|
294
|
-
f"[green]✓ Memory added:[/green] {memory_content}"
|
|
295
|
-
)
|
|
296
|
-
self.console.print(
|
|
297
|
-
f"[dim]Memory ID: {memory.id}[/dim]\n"
|
|
298
|
-
)
|
|
299
|
-
else:
|
|
300
|
-
self.console.print(
|
|
301
|
-
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
302
|
-
)
|
|
303
|
-
except Exception:
|
|
304
|
-
self.console.print(
|
|
305
|
-
"[yellow]Could not add memory[/yellow]\n"
|
|
306
|
-
)
|
|
307
|
-
else:
|
|
308
|
-
self.console.print(
|
|
309
|
-
"[yellow]Empty memory content after '#'[/yellow]\n"
|
|
310
|
-
)
|
|
311
|
-
continue
|
|
312
|
-
|
|
313
|
-
# Execute query with cancellation support
|
|
314
|
-
await self._execute_query_with_cancellation(user_query)
|
|
315
|
-
self.display.show_newline() # Empty line for readability
|
|
320
|
+
# Execute query with cancellation support
|
|
321
|
+
await self._execute_query_with_cancellation(user_query)
|
|
322
|
+
self.display.show_newline()
|
|
316
323
|
|
|
317
324
|
except KeyboardInterrupt:
|
|
318
325
|
# Handle Ctrl+C - cancel current task if running
|
|
@@ -324,14 +331,14 @@ class InteractiveSession:
|
|
|
324
331
|
await self.current_task
|
|
325
332
|
except asyncio.CancelledError:
|
|
326
333
|
pass
|
|
327
|
-
self.console.print("\n[
|
|
334
|
+
self.console.print("\n[warning]Query interrupted[/warning]")
|
|
328
335
|
else:
|
|
329
336
|
self.console.print(
|
|
330
|
-
"\n[
|
|
337
|
+
"\n[warning]Press Ctrl+D to exit. Or use '/exit' or '/quit' slash command.[/warning]"
|
|
331
338
|
)
|
|
332
339
|
except EOFError:
|
|
333
340
|
# Exit when Ctrl+D is pressed
|
|
334
|
-
await self.
|
|
341
|
+
await self._end_thread()
|
|
335
342
|
break
|
|
336
|
-
except Exception as
|
|
337
|
-
self.console.print(f"[
|
|
343
|
+
except Exception as exc:
|
|
344
|
+
self.console.print(f"[error]Error:[/error] {exc}")
|
sqlsaber/cli/memory.py
CHANGED
|
@@ -5,14 +5,14 @@ from typing import Annotated
|
|
|
5
5
|
|
|
6
6
|
import cyclopts
|
|
7
7
|
import questionary
|
|
8
|
-
from rich.console import Console
|
|
9
8
|
from rich.table import Table
|
|
10
9
|
|
|
11
10
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
12
11
|
from sqlsaber.memory.manager import MemoryManager
|
|
12
|
+
from sqlsaber.theme.manager import create_console
|
|
13
13
|
|
|
14
14
|
# Global instances for CLI commands
|
|
15
|
-
console =
|
|
15
|
+
console = create_console()
|
|
16
16
|
config_manager = DatabaseConfigManager()
|
|
17
17
|
memory_manager = MemoryManager()
|
|
18
18
|
|
sqlsaber/cli/models.py
CHANGED
|
@@ -6,14 +6,14 @@ import sys
|
|
|
6
6
|
import cyclopts
|
|
7
7
|
import httpx
|
|
8
8
|
import questionary
|
|
9
|
-
from rich.console import Console
|
|
10
9
|
from rich.table import Table
|
|
11
10
|
|
|
12
11
|
from sqlsaber.config import providers
|
|
13
12
|
from sqlsaber.config.settings import Config
|
|
13
|
+
from sqlsaber.theme.manager import create_console
|
|
14
14
|
|
|
15
15
|
# Global instances for CLI commands
|
|
16
|
-
console =
|
|
16
|
+
console = create_console()
|
|
17
17
|
|
|
18
18
|
# Create the model management CLI app
|
|
19
19
|
models_app = cyclopts.App(
|
|
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
|
|
|
25
25
|
class ModelManager:
|
|
26
26
|
"""Manages AI model configuration and fetching."""
|
|
27
27
|
|
|
28
|
-
DEFAULT_MODEL = "anthropic:claude-sonnet-4-
|
|
28
|
+
DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
|
|
29
29
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
30
|
# Providers come from central registry
|
|
31
31
|
SUPPORTED_PROVIDERS = providers.all_keys()
|
|
32
32
|
|
|
33
|
+
RECOMMENDED_MODELS = {
|
|
34
|
+
"anthropic": "claude-sonnet-4-5-20250929",
|
|
35
|
+
"openai": "gpt-5",
|
|
36
|
+
"google": "gemini-2.5-pro",
|
|
37
|
+
"groq": "llama-3-3-70b-versatile",
|
|
38
|
+
"mistral": "mistral-large-latest",
|
|
39
|
+
"cohere": "command-r-plus",
|
|
40
|
+
}
|
|
41
|
+
|
|
33
42
|
async def fetch_available_models(
|
|
34
43
|
self, providers: list[str] | None = None
|
|
35
44
|
) -> list[dict]:
|
|
@@ -180,39 +189,20 @@ def set():
|
|
|
180
189
|
"""Set the AI model to use."""
|
|
181
190
|
|
|
182
191
|
async def interactive_set():
|
|
192
|
+
from sqlsaber.application.model_selection import choose_model, fetch_models
|
|
193
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
194
|
+
|
|
183
195
|
console.print("[blue]Fetching available models...[/blue]")
|
|
184
|
-
models = await model_manager
|
|
196
|
+
models = await fetch_models(model_manager)
|
|
185
197
|
|
|
186
198
|
if not models:
|
|
187
199
|
console.print("[red]Failed to fetch models. Cannot set model.[/red]")
|
|
188
200
|
sys.exit(1)
|
|
189
201
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
prov = model.get("provider", "?")
|
|
195
|
-
choice_text = f"[{prov}] {model['id']} - {model['name']}"
|
|
196
|
-
if model["description"]:
|
|
197
|
-
choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
|
|
198
|
-
|
|
199
|
-
choices.append({"name": choice_text, "value": model["id"]})
|
|
200
|
-
|
|
201
|
-
# Get current model to set as default
|
|
202
|
-
current_model = model_manager.get_current_model()
|
|
203
|
-
default_index = 0
|
|
204
|
-
for i, choice in enumerate(choices):
|
|
205
|
-
if choice["value"] == current_model:
|
|
206
|
-
default_index = i
|
|
207
|
-
break
|
|
208
|
-
|
|
209
|
-
selected_model = await questionary.select(
|
|
210
|
-
"Select a model:",
|
|
211
|
-
choices=choices,
|
|
212
|
-
use_search_filter=True,
|
|
213
|
-
use_jk_keys=False, # Disable j/k keys when using search filter
|
|
214
|
-
default=choices[default_index] if choices else None,
|
|
215
|
-
).ask_async()
|
|
202
|
+
prompter = AsyncPrompter()
|
|
203
|
+
selected_model = await choose_model(
|
|
204
|
+
prompter, models, restrict_provider=None, use_search_filter=True
|
|
205
|
+
)
|
|
216
206
|
|
|
217
207
|
if selected_model:
|
|
218
208
|
if model_manager.set_model(selected_model):
|