sqlsaber 0.26.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

sqlsaber/cli/commands.py CHANGED
@@ -11,6 +11,7 @@ from sqlsaber.cli.auth import create_auth_app
11
11
  from sqlsaber.cli.database import create_db_app
12
12
  from sqlsaber.cli.memory import create_memory_app
13
13
  from sqlsaber.cli.models import create_models_app
14
+ from sqlsaber.cli.onboarding import needs_onboarding, run_onboarding
14
15
  from sqlsaber.cli.threads import create_threads_app
15
16
 
16
17
  # Lazy imports - only import what's needed for CLI parsing
@@ -128,6 +129,16 @@ def query(
128
129
  # If stdin was empty, fall back to interactive mode
129
130
  actual_query = None
130
131
 
132
+ # Check if onboarding is needed (only for interactive mode or when no database is configured)
133
+ if needs_onboarding(database):
134
+ # Run onboarding flow
135
+ onboarding_success = await run_onboarding()
136
+ if not onboarding_success:
137
+ # User cancelled or onboarding failed
138
+ raise CLIError(
139
+ "Setup incomplete. Please configure your database and try again."
140
+ )
141
+
131
142
  # Resolve database from CLI input
132
143
  try:
133
144
  resolved = resolve_database(database, config_manager)
sqlsaber/cli/database.py CHANGED
@@ -81,95 +81,34 @@ def add(
81
81
 
82
82
  if interactive:
83
83
  # Interactive mode - prompt for all required fields
84
- console.print(f"[bold]Adding database connection: {name}[/bold]")
84
+ from sqlsaber.application.db_setup import collect_db_input
85
+ from sqlsaber.application.prompts import AsyncPrompter
85
86
 
86
- # Database type
87
- if not type or type == "postgresql":
88
- type = questionary.select(
89
- "Database type:",
90
- choices=["postgresql", "mysql", "sqlite", "duckdb"],
91
- default="postgresql",
92
- ).ask()
93
-
94
- if type in {"sqlite", "duckdb"}:
95
- # SQLite/DuckDB only need database file path
96
- database = database or questionary.path("Database file path:").ask()
97
- database = str(Path(database).expanduser().resolve())
98
- host = "localhost"
99
- port = 0
100
- username = type
101
- password = ""
102
- else:
103
- # PostgreSQL/MySQL need connection details
104
- host = host or questionary.text("Host:", default="localhost").ask()
87
+ console.print(f"[bold]Adding database connection: {name}[/bold]")
105
88
 
106
- default_port = 5432 if type == "postgresql" else 3306
107
- port = port or int(
108
- questionary.text("Port:", default=str(default_port)).ask()
89
+ async def collect_input():
90
+ prompter = AsyncPrompter()
91
+ return await collect_db_input(
92
+ prompter=prompter, name=name, db_type=type, include_ssl=True
109
93
  )
110
94
 
111
- database = database or questionary.text("Database name:").ask()
112
- username = username or questionary.text("Username:").ask()
113
-
114
- # Ask for password
115
- password = getpass.getpass("Password (stored in your OS keychain): ")
116
-
117
- # Ask for SSL configuration
118
- if questionary.confirm("Configure SSL/TLS settings?", default=False).ask():
119
- if type == "postgresql":
120
- ssl_mode = (
121
- ssl_mode
122
- or questionary.select(
123
- "SSL mode for PostgreSQL:",
124
- choices=[
125
- "disable",
126
- "allow",
127
- "prefer",
128
- "require",
129
- "verify-ca",
130
- "verify-full",
131
- ],
132
- default="prefer",
133
- ).ask()
134
- )
135
- elif type == "mysql":
136
- ssl_mode = (
137
- ssl_mode
138
- or questionary.select(
139
- "SSL mode for MySQL:",
140
- choices=[
141
- "DISABLED",
142
- "PREFERRED",
143
- "REQUIRED",
144
- "VERIFY_CA",
145
- "VERIFY_IDENTITY",
146
- ],
147
- default="PREFERRED",
148
- ).ask()
149
- )
150
-
151
- if ssl_mode and ssl_mode not in ["disable", "DISABLED"]:
152
- if questionary.confirm(
153
- "Specify SSL certificate files?", default=False
154
- ).ask():
155
- ssl_ca = (
156
- ssl_ca or questionary.path("SSL CA certificate file:").ask()
157
- )
158
- if questionary.confirm(
159
- "Specify client certificate?", default=False
160
- ).ask():
161
- ssl_cert = (
162
- ssl_cert
163
- or questionary.path(
164
- "SSL client certificate file:"
165
- ).ask()
166
- )
167
- ssl_key = (
168
- ssl_key
169
- or questionary.path(
170
- "SSL client private key file:"
171
- ).ask()
172
- )
95
+ db_input = asyncio.run(collect_input())
96
+
97
+ if db_input is None:
98
+ console.print("[yellow]Operation cancelled[/yellow]")
99
+ return
100
+
101
+ # Extract values from db_input
102
+ type = db_input.type
103
+ host = db_input.host
104
+ port = db_input.port
105
+ database = db_input.database
106
+ username = db_input.username
107
+ password = db_input.password
108
+ ssl_mode = db_input.ssl_mode
109
+ ssl_ca = db_input.ssl_ca
110
+ ssl_cert = db_input.ssl_cert
111
+ ssl_key = db_input.ssl_key
173
112
  else:
174
113
  # Non-interactive mode - use provided values or defaults
175
114
  if type == "sqlite":
sqlsaber/cli/display.py CHANGED
@@ -109,7 +109,7 @@ class LiveMarkdownRenderer:
109
109
  # Print the complete markdown to scroll-back for permanent reference
110
110
  if buf:
111
111
  if kind == ThinkingPart:
112
- self.console.print(Text(buf, style="dim"))
112
+ self.console.print(Markdown(buf, style="dim"))
113
113
  else:
114
114
  self.console.print(Markdown(buf))
115
115
 
@@ -35,26 +35,12 @@ if TYPE_CHECKING:
35
35
  from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
36
36
 
37
37
 
38
- def bottom_toolbar():
39
- return [
40
- (
41
- "class:bottom-toolbar",
42
- " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
43
- )
44
- ]
45
-
46
-
47
- style = Style.from_dict(
48
- {
49
- "frame.border": "#ebbcba",
50
- "bottom-toolbar": "#ebbcba bg:#21202e",
51
- }
52
- )
53
-
54
-
55
38
  class InteractiveSession:
56
39
  """Manages interactive CLI sessions."""
57
40
 
41
+ exit_commands = {"/exit", "/quit", "exit", "quit"}
42
+ resume_command_template = "saber threads resume {thread_id}"
43
+
58
44
  def __init__(
59
45
  self,
60
46
  console: Console,
@@ -80,28 +66,33 @@ class InteractiveSession:
80
66
  self._thread_id: str | None = initial_thread_id
81
67
  self.first_message = not self._thread_id
82
68
 
83
- def show_welcome_message(self):
84
- """Display welcome message for interactive mode."""
85
- # Show database information
86
- db_name = self.database_name or "Unknown"
87
- db_type = (
88
- "PostgreSQL"
89
- if isinstance(self.db_conn, PostgreSQLConnection)
90
- else "MySQL"
91
- if isinstance(self.db_conn, MySQLConnection)
92
- else "DuckDB"
93
- if isinstance(self.db_conn, DuckDBConnection)
94
- else "DuckDB"
95
- if isinstance(self.db_conn, CSVConnection)
96
- else "SQLite"
97
- if isinstance(self.db_conn, SQLiteConnection)
98
- else "database"
69
+ def _history_path(self) -> Path:
70
+ """Get the history file path, ensuring directory exists."""
71
+ history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
72
+ history_dir.mkdir(parents=True, exist_ok=True)
73
+ return history_dir / "history"
74
+
75
+ def _prompt_style(self) -> Style:
76
+ """Get the prompt style configuration."""
77
+ return Style.from_dict(
78
+ {
79
+ "frame.border": "gray",
80
+ "bottom-toolbar": "white bg:#21202e",
81
+ }
99
82
  )
100
83
 
101
- if self.first_message:
102
- self.console.print(
103
- Panel.fit(
104
- """
84
+ def _bottom_toolbar(self):
85
+ """Get the bottom toolbar text."""
86
+ return [
87
+ (
88
+ "class:bottom-toolbar",
89
+ " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
90
+ )
91
+ ]
92
+
93
+ def _banner(self) -> str:
94
+ """Get the ASCII banner."""
95
+ return """
105
96
  ███████ ██████ ██ ███████ █████ ██████ ███████ ██████
106
97
  ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
107
98
  ███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
@@ -109,35 +100,110 @@ class InteractiveSession:
109
100
  ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
110
101
  ▀▀
111
102
  """
112
- )
113
- )
114
- self.console.print(
115
- Markdown(
116
- dedent("""
103
+
104
+ def _instructions(self) -> str:
105
+ """Get the instruction text."""
106
+ return dedent("""
117
107
  - Use `/` for slash commands
118
108
  - Type `@` to get table name completions
119
109
  - Start message with `#` to add something to agent's memory
120
110
  - Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
121
111
  """)
122
- )
123
- )
124
112
 
113
+ def _db_type_name(self) -> str:
114
+ """Get human-readable database type name."""
115
+ mapping = {
116
+ PostgreSQLConnection: "PostgreSQL",
117
+ MySQLConnection: "MySQL",
118
+ DuckDBConnection: "DuckDB",
119
+ CSVConnection: "DuckDB",
120
+ SQLiteConnection: "SQLite",
121
+ }
122
+ for cls, name in mapping.items():
123
+ if isinstance(self.db_conn, cls):
124
+ return name
125
+ return "database"
126
+
127
+ def _resume_hint(self, thread_id: str) -> str:
128
+ """Build resume command hint."""
129
+ return self.resume_command_template.format(thread_id=thread_id)
130
+
131
+ def show_welcome_message(self):
132
+ """Display welcome message for interactive mode."""
133
+ if self.first_message:
134
+ self.console.print(Panel.fit(self._banner()))
135
+ self.console.print(Markdown(self._instructions()))
136
+
137
+ db_name = self.database_name or "Unknown"
125
138
  self.console.print(
126
- f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
139
+ f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({self._db_type_name()})\n"
127
140
  )
128
- # If resuming a thread, show a notice
141
+
129
142
  if self._thread_id:
130
143
  self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
131
144
 
132
- async def _end_thread_and_display_resume_hint(self):
133
- """End thread and display command to resume thread"""
134
- # Print resume hint if there is an active thread
145
+ async def _end_thread(self):
146
+ """End thread and display resume hint."""
135
147
  if self._thread_id:
136
148
  await self._threads.end_thread(self._thread_id)
137
149
  self.console.print(
138
- f"[dim]You can continue this thread using:[/dim] saber threads resume {self._thread_id}"
150
+ f"[dim]You can continue this thread using:[/dim] {self._resume_hint(self._thread_id)}"
139
151
  )
140
152
 
153
+ async def _handle_memory(self, content: str):
154
+ """Handle memory addition command."""
155
+ if not content:
156
+ self.console.print("[yellow]Empty memory content after '#'[/yellow]\n")
157
+ return
158
+
159
+ try:
160
+ mm = self.sqlsaber_agent.memory_manager
161
+ if mm and self.database_name:
162
+ memory = mm.add_memory(self.database_name, content)
163
+ self.console.print(f"[green]✓ Memory added:[/green] {content}")
164
+ self.console.print(f"[dim]Memory ID: {memory.id}[/dim]\n")
165
+ else:
166
+ self.console.print(
167
+ "[yellow]Could not add memory (no database context)[/yellow]\n"
168
+ )
169
+ except Exception as exc:
170
+ self.console.print(f"[yellow]Could not add memory:[/yellow] {exc}\n")
171
+
172
+ async def _cmd_clear(self):
173
+ """Clear conversation history."""
174
+ self.message_history = []
175
+ try:
176
+ if self._thread_id:
177
+ await self._threads.end_thread(self._thread_id)
178
+ except Exception:
179
+ pass
180
+ self.console.print("[green]Conversation history cleared.[/green]\n")
181
+ self._thread_id = None
182
+ self.first_message = True
183
+
184
+ async def _cmd_thinking_on(self):
185
+ """Enable thinking mode."""
186
+ self.sqlsaber_agent.set_thinking(enabled=True)
187
+ self.console.print("[green]✓ Thinking enabled[/green]\n")
188
+
189
+ async def _cmd_thinking_off(self):
190
+ """Disable thinking mode."""
191
+ self.sqlsaber_agent.set_thinking(enabled=False)
192
+ self.console.print("[green]✓ Thinking disabled[/green]\n")
193
+
194
+ async def _handle_command(self, user_query: str) -> bool:
195
+ """Handle slash commands. Returns True if command was handled."""
196
+ if user_query == "/clear":
197
+ await self._cmd_clear()
198
+ return True
199
+ if user_query == "/thinking on":
200
+ await self._cmd_thinking_on()
201
+ return True
202
+ if user_query == "/thinking off":
203
+ await self._cmd_thinking_off()
204
+ return True
205
+ return False
206
+
141
207
  async def _update_table_cache(self):
142
208
  """Update the table completer cache with fresh data."""
143
209
  try:
@@ -170,6 +236,10 @@ class InteractiveSession:
170
236
  # If there's an error, just use empty cache
171
237
  self.table_completer.update_cache([])
172
238
 
239
+ async def before_prompt_loop(self):
240
+ """Hook to refresh context before prompt loop."""
241
+ await self._update_table_cache()
242
+
173
243
  async def _execute_query_with_cancellation(self, user_query: str):
174
244
  """Execute a query with cancellation support."""
175
245
  # Create cancellation token
@@ -207,6 +277,7 @@ class InteractiveSession:
207
277
  title=user_query,
208
278
  model_name=self.sqlsaber_agent.agent.model.model_name,
209
279
  )
280
+ self.first_message = False
210
281
  except Exception:
211
282
  pass
212
283
  finally:
@@ -218,15 +289,9 @@ class InteractiveSession:
218
289
  async def run(self):
219
290
  """Run the interactive session loop."""
220
291
  self.show_welcome_message()
292
+ await self.before_prompt_loop()
221
293
 
222
- # Initialize table cache
223
- await self._update_table_cache()
224
-
225
- session = PromptSession(
226
- history=FileHistory(
227
- Path(platformdirs.user_config_dir("sqlsaber")) / "history"
228
- )
229
- )
294
+ session = PromptSession(history=FileHistory(self._history_path()))
230
295
 
231
296
  while True:
232
297
  try:
@@ -238,81 +303,32 @@ class InteractiveSession:
238
303
  SlashCommandCompleter(), self.table_completer
239
304
  ),
240
305
  show_frame=True,
241
- bottom_toolbar=bottom_toolbar,
242
- style=style,
306
+ bottom_toolbar=self._bottom_toolbar,
307
+ style=self._prompt_style(),
243
308
  )
244
309
 
245
310
  if not user_query:
246
311
  continue
247
312
 
248
- if (
249
- user_query in ["/exit", "/quit", "exit", "quit"]
250
- or user_query.startswith("/exit")
251
- or user_query.startswith("/quit")
313
+ # Handle exit commands
314
+ if user_query in self.exit_commands or any(
315
+ user_query.startswith(cmd) for cmd in self.exit_commands
252
316
  ):
253
- await self._end_thread_and_display_resume_hint()
317
+ await self._end_thread()
254
318
  break
255
319
 
256
- if user_query == "/clear":
257
- # Reset local history (pydantic-ai call will receive empty history on next run)
258
- self.message_history = []
259
- # End current thread (if any) so the next turn creates a fresh one
260
- try:
261
- if self._thread_id:
262
- await self._threads.end_thread(self._thread_id)
263
- except Exception:
264
- pass
265
- self.console.print("[green]Conversation history cleared.[/green]\n")
266
- # Do not print resume hint when clearing; a new thread will be created on next turn
267
- self._thread_id = None
268
- continue
269
-
270
- # Thinking commands
271
- if user_query == "/thinking on":
272
- self.sqlsaber_agent.set_thinking(enabled=True)
273
- self.console.print("[green]✓ Thinking enabled[/green]\n")
320
+ # Handle slash commands
321
+ if await self._handle_command(user_query):
274
322
  continue
275
323
 
276
- if user_query == "/thinking off":
277
- self.sqlsaber_agent.set_thinking(enabled=False)
278
- self.console.print("[green]✓ Thinking disabled[/green]\n")
324
+ # Handle memory addition
325
+ if user_query.strip().startswith("#"):
326
+ await self._handle_memory(user_query.strip()[1:].strip())
279
327
  continue
280
328
 
281
- if memory_text := user_query.strip():
282
- # Check if query starts with # for memory addition
283
- if memory_text.startswith("#"):
284
- memory_content = memory_text[1:].strip() # Remove # and trim
285
- if memory_content:
286
- # Add memory via the agent's memory manager
287
- try:
288
- mm = self.sqlsaber_agent.memory_manager
289
- if mm and self.database_name:
290
- memory = mm.add_memory(
291
- self.database_name, memory_content
292
- )
293
- self.console.print(
294
- f"[green]✓ Memory added:[/green] {memory_content}"
295
- )
296
- self.console.print(
297
- f"[dim]Memory ID: {memory.id}[/dim]\n"
298
- )
299
- else:
300
- self.console.print(
301
- "[yellow]Could not add memory (no database context)[/yellow]\n"
302
- )
303
- except Exception:
304
- self.console.print(
305
- "[yellow]Could not add memory[/yellow]\n"
306
- )
307
- else:
308
- self.console.print(
309
- "[yellow]Empty memory content after '#'[/yellow]\n"
310
- )
311
- continue
312
-
313
- # Execute query with cancellation support
314
- await self._execute_query_with_cancellation(user_query)
315
- self.display.show_newline() # Empty line for readability
329
+ # Execute query with cancellation support
330
+ await self._execute_query_with_cancellation(user_query)
331
+ self.display.show_newline()
316
332
 
317
333
  except KeyboardInterrupt:
318
334
  # Handle Ctrl+C - cancel current task if running
@@ -331,7 +347,7 @@ class InteractiveSession:
331
347
  )
332
348
  except EOFError:
333
349
  # Exit when Ctrl+D is pressed
334
- await self._end_thread_and_display_resume_hint()
350
+ await self._end_thread()
335
351
  break
336
- except Exception as e:
337
- self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
352
+ except Exception as exc:
353
+ self.console.print(f"[bold red]Error:[/bold red] {exc}")
sqlsaber/cli/models.py CHANGED
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
25
25
  class ModelManager:
26
26
  """Manages AI model configuration and fetching."""
27
27
 
28
- DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
28
+ DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
29
29
  MODELS_API_URL = "https://models.dev/api.json"
30
30
  # Providers come from central registry
31
31
  SUPPORTED_PROVIDERS = providers.all_keys()
32
32
 
33
+ RECOMMENDED_MODELS = {
34
+ "anthropic": "claude-sonnet-4-5-20250929",
35
+ "openai": "gpt-5",
36
+ "google": "gemini-2.5-pro",
37
+ "groq": "llama-3-3-70b-versatile",
38
+ "mistral": "mistral-large-latest",
39
+ "cohere": "command-r-plus",
40
+ }
41
+
33
42
  async def fetch_available_models(
34
43
  self, providers: list[str] | None = None
35
44
  ) -> list[dict]:
@@ -180,39 +189,20 @@ def set():
180
189
  """Set the AI model to use."""
181
190
 
182
191
  async def interactive_set():
192
+ from sqlsaber.application.model_selection import choose_model, fetch_models
193
+ from sqlsaber.application.prompts import AsyncPrompter
194
+
183
195
  console.print("[blue]Fetching available models...[/blue]")
184
- models = await model_manager.fetch_available_models()
196
+ models = await fetch_models(model_manager)
185
197
 
186
198
  if not models:
187
199
  console.print("[red]Failed to fetch models. Cannot set model.[/red]")
188
200
  sys.exit(1)
189
201
 
190
- # Create choices for questionary
191
- choices = []
192
- for model in models:
193
- # Format: "[provider] ID - Name (Description)"
194
- prov = model.get("provider", "?")
195
- choice_text = f"[{prov}] {model['id']} - {model['name']}"
196
- if model["description"]:
197
- choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
198
-
199
- choices.append({"name": choice_text, "value": model["id"]})
200
-
201
- # Get current model to set as default
202
- current_model = model_manager.get_current_model()
203
- default_index = 0
204
- for i, choice in enumerate(choices):
205
- if choice["value"] == current_model:
206
- default_index = i
207
- break
208
-
209
- selected_model = await questionary.select(
210
- "Select a model:",
211
- choices=choices,
212
- use_search_filter=True,
213
- use_jk_keys=False, # Disable j/k keys when using search filter
214
- default=choices[default_index] if choices else None,
215
- ).ask_async()
202
+ prompter = AsyncPrompter()
203
+ selected_model = await choose_model(
204
+ prompter, models, restrict_provider=None, use_search_filter=True
205
+ )
216
206
 
217
207
  if selected_model:
218
208
  if model_manager.set_model(selected_model):