sqlsaber 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,6 @@ import platformdirs
9
9
  from prompt_toolkit import PromptSession
10
10
  from prompt_toolkit.history import FileHistory
11
11
  from prompt_toolkit.patch_stdout import patch_stdout
12
- from prompt_toolkit.styles import Style
13
12
  from rich.console import Console
14
13
  from rich.markdown import Markdown
15
14
  from rich.panel import Panel
@@ -29,32 +28,19 @@ from sqlsaber.database import (
29
28
  SQLiteConnection,
30
29
  )
31
30
  from sqlsaber.database.schema import SchemaManager
31
+ from sqlsaber.theme.manager import get_theme_manager
32
32
  from sqlsaber.threads import ThreadStorage
33
33
 
34
34
  if TYPE_CHECKING:
35
35
  from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
36
36
 
37
37
 
38
- def bottom_toolbar():
39
- return [
40
- (
41
- "class:bottom-toolbar",
42
- " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
43
- )
44
- ]
45
-
46
-
47
- style = Style.from_dict(
48
- {
49
- "frame.border": "#ebbcba",
50
- "bottom-toolbar": "#ebbcba bg:#21202e",
51
- }
52
- )
53
-
54
-
55
38
  class InteractiveSession:
56
39
  """Manages interactive CLI sessions."""
57
40
 
41
+ exit_commands = {"/exit", "/quit", "exit", "quit"}
42
+ resume_command_template = "saber threads resume {thread_id}"
43
+
58
44
  def __init__(
59
45
  self,
60
46
  console: Console,
@@ -75,33 +61,30 @@ class InteractiveSession:
75
61
  self.cancellation_token: asyncio.Event | None = None
76
62
  self.table_completer = TableNameCompleter()
77
63
  self.message_history: list | None = initial_history or []
64
+ self.tm = get_theme_manager()
78
65
  # Conversation Thread persistence
79
66
  self._threads = ThreadStorage()
80
67
  self._thread_id: str | None = initial_thread_id
81
68
  self.first_message = not self._thread_id
82
69
 
83
- def show_welcome_message(self):
84
- """Display welcome message for interactive mode."""
85
- # Show database information
86
- db_name = self.database_name or "Unknown"
87
- db_type = (
88
- "PostgreSQL"
89
- if isinstance(self.db_conn, PostgreSQLConnection)
90
- else "MySQL"
91
- if isinstance(self.db_conn, MySQLConnection)
92
- else "DuckDB"
93
- if isinstance(self.db_conn, DuckDBConnection)
94
- else "DuckDB"
95
- if isinstance(self.db_conn, CSVConnection)
96
- else "SQLite"
97
- if isinstance(self.db_conn, SQLiteConnection)
98
- else "database"
99
- )
70
+ def _history_path(self) -> Path:
71
+ """Get the history file path, ensuring directory exists."""
72
+ history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
73
+ history_dir.mkdir(parents=True, exist_ok=True)
74
+ return history_dir / "history"
75
+
76
+ def _bottom_toolbar(self):
77
+ """Get the bottom toolbar text."""
78
+ return [
79
+ (
80
+ "class:bottom-toolbar",
81
+ " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
82
+ )
83
+ ]
100
84
 
101
- if self.first_message:
102
- self.console.print(
103
- Panel.fit(
104
- """
85
+ def _banner(self) -> str:
86
+ """Get the ASCII banner."""
87
+ return """
105
88
  ███████ ██████ ██ ███████ █████ ██████ ███████ ██████
106
89
  ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
107
90
  ███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
@@ -109,35 +92,110 @@ class InteractiveSession:
109
92
  ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
110
93
  ▀▀
111
94
  """
112
- )
113
- )
114
- self.console.print(
115
- Markdown(
116
- dedent("""
95
+
96
+ def _instructions(self) -> str:
97
+ """Get the instruction text."""
98
+ return dedent("""
117
99
  - Use `/` for slash commands
118
100
  - Type `@` to get table name completions
119
101
  - Start message with `#` to add something to agent's memory
120
102
  - Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
121
103
  """)
122
- )
123
- )
124
104
 
105
+ def _db_type_name(self) -> str:
106
+ """Get human-readable database type name."""
107
+ mapping = {
108
+ PostgreSQLConnection: "PostgreSQL",
109
+ MySQLConnection: "MySQL",
110
+ DuckDBConnection: "DuckDB",
111
+ CSVConnection: "DuckDB",
112
+ SQLiteConnection: "SQLite",
113
+ }
114
+ for cls, name in mapping.items():
115
+ if isinstance(self.db_conn, cls):
116
+ return name
117
+ return "database"
118
+
119
+ def _resume_hint(self, thread_id: str) -> str:
120
+ """Build resume command hint."""
121
+ return self.resume_command_template.format(thread_id=thread_id)
122
+
123
+ def show_welcome_message(self):
124
+ """Display welcome message for interactive mode."""
125
+ if self.first_message:
126
+ self.console.print(Panel.fit(self._banner()))
127
+ self.console.print(Markdown(self._instructions()))
128
+
129
+ db_name = self.database_name or "Unknown"
125
130
  self.console.print(
126
- f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
131
+ f"[heading]\n\nConnected to:[/heading] {db_name} ({self._db_type_name()})\n"
127
132
  )
128
- # If resuming a thread, show a notice
133
+
129
134
  if self._thread_id:
130
- self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
135
+ self.console.print(f"[muted]Resuming thread:[/muted] {self._thread_id}\n")
131
136
 
132
- async def _end_thread_and_display_resume_hint(self):
133
- """End thread and display command to resume thread"""
134
- # Print resume hint if there is an active thread
137
+ async def _end_thread(self):
138
+ """End thread and display resume hint."""
135
139
  if self._thread_id:
136
140
  await self._threads.end_thread(self._thread_id)
137
141
  self.console.print(
138
- f"[dim]You can continue this thread using:[/dim] saber threads resume {self._thread_id}"
142
+ f"[muted]You can continue this thread using:[/muted] {self._resume_hint(self._thread_id)}"
139
143
  )
140
144
 
145
+ async def _handle_memory(self, content: str):
146
+ """Handle memory addition command."""
147
+ if not content:
148
+ self.console.print("[warning]Empty memory content after '#'[/warning]\n")
149
+ return
150
+
151
+ try:
152
+ mm = self.sqlsaber_agent.memory_manager
153
+ if mm and self.database_name:
154
+ memory = mm.add_memory(self.database_name, content)
155
+ self.console.print(f"[success]✓ Memory added:[/success] {content}")
156
+ self.console.print(f"[muted]Memory ID: {memory.id}[/muted]\n")
157
+ else:
158
+ self.console.print(
159
+ "[warning]Could not add memory (no database context)[/warning]\n"
160
+ )
161
+ except Exception as exc:
162
+ self.console.print(f"[warning]Could not add memory:[/warning] {exc}\n")
163
+
164
+ async def _cmd_clear(self):
165
+ """Clear conversation history."""
166
+ self.message_history = []
167
+ try:
168
+ if self._thread_id:
169
+ await self._threads.end_thread(self._thread_id)
170
+ except Exception:
171
+ pass
172
+ self.console.print("[success]Conversation history cleared.[/success]\n")
173
+ self._thread_id = None
174
+ self.first_message = True
175
+
176
+ async def _cmd_thinking_on(self):
177
+ """Enable thinking mode."""
178
+ self.sqlsaber_agent.set_thinking(enabled=True)
179
+ self.console.print("[success]✓ Thinking enabled[/success]\n")
180
+
181
+ async def _cmd_thinking_off(self):
182
+ """Disable thinking mode."""
183
+ self.sqlsaber_agent.set_thinking(enabled=False)
184
+ self.console.print("[success]✓ Thinking disabled[/success]\n")
185
+
186
+ async def _handle_command(self, user_query: str) -> bool:
187
+ """Handle slash commands. Returns True if command was handled."""
188
+ if user_query == "/clear":
189
+ await self._cmd_clear()
190
+ return True
191
+ if user_query == "/thinking on":
192
+ await self._cmd_thinking_on()
193
+ return True
194
+ if user_query == "/thinking off":
195
+ await self._cmd_thinking_off()
196
+ return True
197
+ return False
198
+
141
199
  async def _update_table_cache(self):
142
200
  """Update the table completer cache with fresh data."""
143
201
  try:
@@ -170,6 +228,10 @@ class InteractiveSession:
170
228
  # If there's an error, just use empty cache
171
229
  self.table_completer.update_cache([])
172
230
 
231
+ async def before_prompt_loop(self):
232
+ """Hook to refresh context before prompt loop."""
233
+ await self._update_table_cache()
234
+
173
235
  async def _execute_query_with_cancellation(self, user_query: str):
174
236
  """Execute a query with cancellation support."""
175
237
  # Create cancellation token
@@ -207,6 +269,7 @@ class InteractiveSession:
207
269
  title=user_query,
208
270
  model_name=self.sqlsaber_agent.agent.model.model_name,
209
271
  )
272
+ self.first_message = False
210
273
  except Exception:
211
274
  pass
212
275
  finally:
@@ -218,101 +281,45 @@ class InteractiveSession:
218
281
  async def run(self):
219
282
  """Run the interactive session loop."""
220
283
  self.show_welcome_message()
284
+ await self.before_prompt_loop()
221
285
 
222
- # Initialize table cache
223
- await self._update_table_cache()
224
-
225
- session = PromptSession(
226
- history=FileHistory(
227
- Path(platformdirs.user_config_dir("sqlsaber")) / "history"
228
- )
229
- )
286
+ session = PromptSession(history=FileHistory(self._history_path()))
230
287
 
231
288
  while True:
232
289
  try:
233
290
  with patch_stdout():
234
291
  user_query = await session.prompt_async(
235
- "",
292
+ "> ",
236
293
  multiline=True,
237
294
  completer=CompositeCompleter(
238
295
  SlashCommandCompleter(), self.table_completer
239
296
  ),
240
- show_frame=True,
241
- bottom_toolbar=bottom_toolbar,
242
- style=style,
297
+ bottom_toolbar=self._bottom_toolbar,
298
+ style=self.tm.pt_style(),
243
299
  )
244
300
 
245
301
  if not user_query:
246
302
  continue
247
303
 
248
- if (
249
- user_query in ["/exit", "/quit", "exit", "quit"]
250
- or user_query.startswith("/exit")
251
- or user_query.startswith("/quit")
304
+ # Handle exit commands
305
+ if user_query in self.exit_commands or any(
306
+ user_query.startswith(cmd) for cmd in self.exit_commands
252
307
  ):
253
- await self._end_thread_and_display_resume_hint()
308
+ await self._end_thread()
254
309
  break
255
310
 
256
- if user_query == "/clear":
257
- # Reset local history (pydantic-ai call will receive empty history on next run)
258
- self.message_history = []
259
- # End current thread (if any) so the next turn creates a fresh one
260
- try:
261
- if self._thread_id:
262
- await self._threads.end_thread(self._thread_id)
263
- except Exception:
264
- pass
265
- self.console.print("[green]Conversation history cleared.[/green]\n")
266
- # Do not print resume hint when clearing; a new thread will be created on next turn
267
- self._thread_id = None
268
- continue
269
-
270
- # Thinking commands
271
- if user_query == "/thinking on":
272
- self.sqlsaber_agent.set_thinking(enabled=True)
273
- self.console.print("[green]✓ Thinking enabled[/green]\n")
311
+ # Handle slash commands
312
+ if await self._handle_command(user_query):
274
313
  continue
275
314
 
276
- if user_query == "/thinking off":
277
- self.sqlsaber_agent.set_thinking(enabled=False)
278
- self.console.print("[green]✓ Thinking disabled[/green]\n")
315
+ # Handle memory addition
316
+ if user_query.strip().startswith("#"):
317
+ await self._handle_memory(user_query.strip()[1:].strip())
279
318
  continue
280
319
 
281
- if memory_text := user_query.strip():
282
- # Check if query starts with # for memory addition
283
- if memory_text.startswith("#"):
284
- memory_content = memory_text[1:].strip() # Remove # and trim
285
- if memory_content:
286
- # Add memory via the agent's memory manager
287
- try:
288
- mm = self.sqlsaber_agent.memory_manager
289
- if mm and self.database_name:
290
- memory = mm.add_memory(
291
- self.database_name, memory_content
292
- )
293
- self.console.print(
294
- f"[green]✓ Memory added:[/green] {memory_content}"
295
- )
296
- self.console.print(
297
- f"[dim]Memory ID: {memory.id}[/dim]\n"
298
- )
299
- else:
300
- self.console.print(
301
- "[yellow]Could not add memory (no database context)[/yellow]\n"
302
- )
303
- except Exception:
304
- self.console.print(
305
- "[yellow]Could not add memory[/yellow]\n"
306
- )
307
- else:
308
- self.console.print(
309
- "[yellow]Empty memory content after '#'[/yellow]\n"
310
- )
311
- continue
312
-
313
- # Execute query with cancellation support
314
- await self._execute_query_with_cancellation(user_query)
315
- self.display.show_newline() # Empty line for readability
320
+ # Execute query with cancellation support
321
+ await self._execute_query_with_cancellation(user_query)
322
+ self.display.show_newline()
316
323
 
317
324
  except KeyboardInterrupt:
318
325
  # Handle Ctrl+C - cancel current task if running
@@ -324,14 +331,14 @@ class InteractiveSession:
324
331
  await self.current_task
325
332
  except asyncio.CancelledError:
326
333
  pass
327
- self.console.print("\n[yellow]Query interrupted[/yellow]")
334
+ self.console.print("\n[warning]Query interrupted[/warning]")
328
335
  else:
329
336
  self.console.print(
330
- "\n[yellow]Press Ctrl+D to exit. Or use '/exit' or '/quit' slash command.[/yellow]"
337
+ "\n[warning]Press Ctrl+D to exit. Or use '/exit' or '/quit' slash command.[/warning]"
331
338
  )
332
339
  except EOFError:
333
340
  # Exit when Ctrl+D is pressed
334
- await self._end_thread_and_display_resume_hint()
341
+ await self._end_thread()
335
342
  break
336
- except Exception as e:
337
- self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
343
+ except Exception as exc:
344
+ self.console.print(f"[error]Error:[/error] {exc}")
sqlsaber/cli/memory.py CHANGED
@@ -5,14 +5,14 @@ from typing import Annotated
5
5
 
6
6
  import cyclopts
7
7
  import questionary
8
- from rich.console import Console
9
8
  from rich.table import Table
10
9
 
11
10
  from sqlsaber.config.database import DatabaseConfigManager
12
11
  from sqlsaber.memory.manager import MemoryManager
12
+ from sqlsaber.theme.manager import create_console
13
13
 
14
14
  # Global instances for CLI commands
15
- console = Console()
15
+ console = create_console()
16
16
  config_manager = DatabaseConfigManager()
17
17
  memory_manager = MemoryManager()
18
18
 
sqlsaber/cli/models.py CHANGED
@@ -6,14 +6,14 @@ import sys
6
6
  import cyclopts
7
7
  import httpx
8
8
  import questionary
9
- from rich.console import Console
10
9
  from rich.table import Table
11
10
 
12
11
  from sqlsaber.config import providers
13
12
  from sqlsaber.config.settings import Config
13
+ from sqlsaber.theme.manager import create_console
14
14
 
15
15
  # Global instances for CLI commands
16
- console = Console()
16
+ console = create_console()
17
17
 
18
18
  # Create the model management CLI app
19
19
  models_app = cyclopts.App(
@@ -25,11 +25,20 @@ models_app = cyclopts.App(
25
25
  class ModelManager:
26
26
  """Manages AI model configuration and fetching."""
27
27
 
28
- DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
28
+ DEFAULT_MODEL = "anthropic:claude-sonnet-4-5-20250929"
29
29
  MODELS_API_URL = "https://models.dev/api.json"
30
30
  # Providers come from central registry
31
31
  SUPPORTED_PROVIDERS = providers.all_keys()
32
32
 
33
+ RECOMMENDED_MODELS = {
34
+ "anthropic": "claude-sonnet-4-5-20250929",
35
+ "openai": "gpt-5",
36
+ "google": "gemini-2.5-pro",
37
+ "groq": "llama-3-3-70b-versatile",
38
+ "mistral": "mistral-large-latest",
39
+ "cohere": "command-r-plus",
40
+ }
41
+
33
42
  async def fetch_available_models(
34
43
  self, providers: list[str] | None = None
35
44
  ) -> list[dict]:
@@ -180,39 +189,20 @@ def set():
180
189
  """Set the AI model to use."""
181
190
 
182
191
  async def interactive_set():
192
+ from sqlsaber.application.model_selection import choose_model, fetch_models
193
+ from sqlsaber.application.prompts import AsyncPrompter
194
+
183
195
  console.print("[blue]Fetching available models...[/blue]")
184
- models = await model_manager.fetch_available_models()
196
+ models = await fetch_models(model_manager)
185
197
 
186
198
  if not models:
187
199
  console.print("[red]Failed to fetch models. Cannot set model.[/red]")
188
200
  sys.exit(1)
189
201
 
190
- # Create choices for questionary
191
- choices = []
192
- for model in models:
193
- # Format: "[provider] ID - Name (Description)"
194
- prov = model.get("provider", "?")
195
- choice_text = f"[{prov}] {model['id']} - {model['name']}"
196
- if model["description"]:
197
- choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
198
-
199
- choices.append({"name": choice_text, "value": model["id"]})
200
-
201
- # Get current model to set as default
202
- current_model = model_manager.get_current_model()
203
- default_index = 0
204
- for i, choice in enumerate(choices):
205
- if choice["value"] == current_model:
206
- default_index = i
207
- break
208
-
209
- selected_model = await questionary.select(
210
- "Select a model:",
211
- choices=choices,
212
- use_search_filter=True,
213
- use_jk_keys=False, # Disable j/k keys when using search filter
214
- default=choices[default_index] if choices else None,
215
- ).ask_async()
202
+ prompter = AsyncPrompter()
203
+ selected_model = await choose_model(
204
+ prompter, models, restrict_provider=None, use_search_filter=True
205
+ )
216
206
 
217
207
  if selected_model:
218
208
  if model_manager.set_model(selected_model):