sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38) hide show
  1. sqlsaber/agents/__init__.py +2 -2
  2. sqlsaber/agents/base.py +1 -1
  3. sqlsaber/agents/mcp.py +1 -1
  4. sqlsaber/agents/pydantic_ai_agent.py +207 -135
  5. sqlsaber/application/__init__.py +1 -0
  6. sqlsaber/application/auth_setup.py +164 -0
  7. sqlsaber/application/db_setup.py +223 -0
  8. sqlsaber/application/model_selection.py +98 -0
  9. sqlsaber/application/prompts.py +115 -0
  10. sqlsaber/cli/auth.py +22 -50
  11. sqlsaber/cli/commands.py +22 -28
  12. sqlsaber/cli/completers.py +2 -0
  13. sqlsaber/cli/database.py +25 -86
  14. sqlsaber/cli/display.py +29 -9
  15. sqlsaber/cli/interactive.py +150 -127
  16. sqlsaber/cli/models.py +18 -28
  17. sqlsaber/cli/onboarding.py +325 -0
  18. sqlsaber/cli/streaming.py +15 -17
  19. sqlsaber/cli/threads.py +10 -6
  20. sqlsaber/config/api_keys.py +2 -2
  21. sqlsaber/config/settings.py +25 -2
  22. sqlsaber/database/__init__.py +55 -1
  23. sqlsaber/database/base.py +124 -0
  24. sqlsaber/database/csv.py +133 -0
  25. sqlsaber/database/duckdb.py +313 -0
  26. sqlsaber/database/mysql.py +345 -0
  27. sqlsaber/database/postgresql.py +328 -0
  28. sqlsaber/database/schema.py +66 -963
  29. sqlsaber/database/sqlite.py +258 -0
  30. sqlsaber/mcp/mcp.py +1 -1
  31. sqlsaber/tools/sql_tools.py +1 -1
  32. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
  33. sqlsaber-0.27.0.dist-info/RECORD +58 -0
  34. sqlsaber/database/connection.py +0 -535
  35. sqlsaber-0.25.0.dist-info/RECORD +0 -47
  36. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/database.py CHANGED
@@ -81,95 +81,34 @@ def add(
81
81
 
82
82
  if interactive:
83
83
  # Interactive mode - prompt for all required fields
84
- console.print(f"[bold]Adding database connection: {name}[/bold]")
84
+ from sqlsaber.application.db_setup import collect_db_input
85
+ from sqlsaber.application.prompts import AsyncPrompter
85
86
 
86
- # Database type
87
- if not type or type == "postgresql":
88
- type = questionary.select(
89
- "Database type:",
90
- choices=["postgresql", "mysql", "sqlite", "duckdb"],
91
- default="postgresql",
92
- ).ask()
93
-
94
- if type in {"sqlite", "duckdb"}:
95
- # SQLite/DuckDB only need database file path
96
- database = database or questionary.path("Database file path:").ask()
97
- database = str(Path(database).expanduser().resolve())
98
- host = "localhost"
99
- port = 0
100
- username = type
101
- password = ""
102
- else:
103
- # PostgreSQL/MySQL need connection details
104
- host = host or questionary.text("Host:", default="localhost").ask()
87
+ console.print(f"[bold]Adding database connection: {name}[/bold]")
105
88
 
106
- default_port = 5432 if type == "postgresql" else 3306
107
- port = port or int(
108
- questionary.text("Port:", default=str(default_port)).ask()
89
+ async def collect_input():
90
+ prompter = AsyncPrompter()
91
+ return await collect_db_input(
92
+ prompter=prompter, name=name, db_type=type, include_ssl=True
109
93
  )
110
94
 
111
- database = database or questionary.text("Database name:").ask()
112
- username = username or questionary.text("Username:").ask()
113
-
114
- # Ask for password
115
- password = getpass.getpass("Password (stored in your OS keychain): ")
116
-
117
- # Ask for SSL configuration
118
- if questionary.confirm("Configure SSL/TLS settings?", default=False).ask():
119
- if type == "postgresql":
120
- ssl_mode = (
121
- ssl_mode
122
- or questionary.select(
123
- "SSL mode for PostgreSQL:",
124
- choices=[
125
- "disable",
126
- "allow",
127
- "prefer",
128
- "require",
129
- "verify-ca",
130
- "verify-full",
131
- ],
132
- default="prefer",
133
- ).ask()
134
- )
135
- elif type == "mysql":
136
- ssl_mode = (
137
- ssl_mode
138
- or questionary.select(
139
- "SSL mode for MySQL:",
140
- choices=[
141
- "DISABLED",
142
- "PREFERRED",
143
- "REQUIRED",
144
- "VERIFY_CA",
145
- "VERIFY_IDENTITY",
146
- ],
147
- default="PREFERRED",
148
- ).ask()
149
- )
150
-
151
- if ssl_mode and ssl_mode not in ["disable", "DISABLED"]:
152
- if questionary.confirm(
153
- "Specify SSL certificate files?", default=False
154
- ).ask():
155
- ssl_ca = (
156
- ssl_ca or questionary.path("SSL CA certificate file:").ask()
157
- )
158
- if questionary.confirm(
159
- "Specify client certificate?", default=False
160
- ).ask():
161
- ssl_cert = (
162
- ssl_cert
163
- or questionary.path(
164
- "SSL client certificate file:"
165
- ).ask()
166
- )
167
- ssl_key = (
168
- ssl_key
169
- or questionary.path(
170
- "SSL client private key file:"
171
- ).ask()
172
- )
95
+ db_input = asyncio.run(collect_input())
96
+
97
+ if db_input is None:
98
+ console.print("[yellow]Operation cancelled[/yellow]")
99
+ return
100
+
101
+ # Extract values from db_input
102
+ type = db_input.type
103
+ host = db_input.host
104
+ port = db_input.port
105
+ database = db_input.database
106
+ username = db_input.username
107
+ password = db_input.password
108
+ ssl_mode = db_input.ssl_mode
109
+ ssl_ca = db_input.ssl_ca
110
+ ssl_cert = db_input.ssl_cert
111
+ ssl_key = db_input.ssl_key
173
112
  else:
174
113
  # Non-interactive mode - use provided values or defaults
175
114
  if type == "sqlite":
@@ -354,7 +293,7 @@ def test(
354
293
 
355
294
  async def test_connection():
356
295
  # Lazy import to keep CLI startup fast
357
- from sqlsaber.database.connection import DatabaseConnection
296
+ from sqlsaber.database import DatabaseConnection
358
297
 
359
298
  if name:
360
299
  db_config = config_manager.get_database(name)
sqlsaber/cli/display.py CHANGED
@@ -8,7 +8,7 @@ rendered with Live.
8
8
  import json
9
9
  from typing import Sequence, Type
10
10
 
11
- from pydantic_ai.messages import ModelResponsePart, TextPart
11
+ from pydantic_ai.messages import ModelResponsePart, TextPart, ThinkingPart
12
12
  from rich.columns import Columns
13
13
  from rich.console import Console, ConsoleOptions, RenderResult
14
14
  from rich.live import Live
@@ -75,7 +75,7 @@ class LiveMarkdownRenderer:
75
75
  self.end()
76
76
  self.paragraph_break()
77
77
 
78
- self._start()
78
+ self._start(kind)
79
79
  self._current_kind = kind
80
80
 
81
81
  def append(self, text: str | None) -> None:
@@ -87,7 +87,13 @@ class LiveMarkdownRenderer:
87
87
  self.ensure_segment(TextPart)
88
88
 
89
89
  self._buffer += text
90
- self._live.update(Markdown(self._buffer))
90
+
91
+ # Apply dim styling for thinking segments
92
+ if self._current_kind == ThinkingPart:
93
+ content = Markdown(self._buffer, style="dim")
94
+ self._live.update(content)
95
+ else:
96
+ self._live.update(Markdown(self._buffer))
91
97
 
92
98
  def end(self) -> None:
93
99
  """Finalize and stop the current Live segment, if any."""
@@ -95,13 +101,17 @@ class LiveMarkdownRenderer:
95
101
  return
96
102
  # Persist the *final* render exactly once, then shut Live down.
97
103
  buf = self._buffer
104
+ kind = self._current_kind
98
105
  self._live.stop()
99
106
  self._live = None
100
107
  self._buffer = ""
101
108
  self._current_kind = None
102
109
  # Print the complete markdown to scroll-back for permanent reference
103
110
  if buf:
104
- self.console.print(Markdown(buf))
111
+ if kind == ThinkingPart:
112
+ self.console.print(Markdown(buf, style="dim"))
113
+ else:
114
+ self.console.print(Markdown(buf))
105
115
 
106
116
  def end_if_active(self) -> None:
107
117
  self.end()
@@ -153,10 +163,20 @@ class LiveMarkdownRenderer:
153
163
  text = Text(f" {message}", style="yellow")
154
164
  return Columns([spinner, text], expand=False)
155
165
 
156
- def _start(self, initial_markdown: str = "") -> None:
166
+ def _start(
167
+ self, kind: Type[ModelResponsePart] | None = None, initial_markdown: str = ""
168
+ ) -> None:
157
169
  if self._live is not None:
158
170
  self.end()
159
171
  self._buffer = initial_markdown or ""
172
+
173
+ # Add visual styling for thinking segments
174
+ if kind == ThinkingPart:
175
+ if self.console.is_terminal:
176
+ self.console.print("[dim]💭 Thinking...[/dim]")
177
+ else:
178
+ self.console.print("*Thinking...*\n")
179
+
160
180
  # NOTE: Use transient=True so the live widget disappears on exit,
161
181
  # giving a clean transition to the final printed result.
162
182
  live = Live(
@@ -219,7 +239,9 @@ class DisplayManager:
219
239
  if self.console.is_terminal:
220
240
  self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
221
241
  self.show_newline()
222
- syntax = Syntax(query, "sql", background_color="default", word_wrap=True)
242
+ syntax = Syntax(
243
+ query, "sql", background_color="default", word_wrap=True
244
+ )
223
245
  self.console.print(syntax)
224
246
  else:
225
247
  self.console.print("**Executing SQL:**\n")
@@ -271,9 +293,7 @@ class DisplayManager:
271
293
  f"[yellow]... and {len(results) - 20} more rows[/yellow]"
272
294
  )
273
295
  else:
274
- self.console.print(
275
- f"*... and {len(results) - 20} more rows*\n"
276
- )
296
+ self.console.print(f"*... and {len(results) - 20} more rows*\n")
277
297
 
278
298
  def show_error(self, error_message: str):
279
299
  """Display error message."""
@@ -3,13 +3,13 @@
3
3
  import asyncio
4
4
  from pathlib import Path
5
5
  from textwrap import dedent
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  import platformdirs
8
9
  from prompt_toolkit import PromptSession
9
10
  from prompt_toolkit.history import FileHistory
10
11
  from prompt_toolkit.patch_stdout import patch_stdout
11
12
  from prompt_toolkit.styles import Style
12
- from pydantic_ai import Agent
13
13
  from rich.console import Console
14
14
  from rich.markdown import Markdown
15
15
  from rich.panel import Panel
@@ -21,7 +21,7 @@ from sqlsaber.cli.completers import (
21
21
  )
22
22
  from sqlsaber.cli.display import DisplayManager
23
23
  from sqlsaber.cli.streaming import StreamingQueryHandler
24
- from sqlsaber.database.connection import (
24
+ from sqlsaber.database import (
25
25
  CSVConnection,
26
26
  DuckDBConnection,
27
27
  MySQLConnection,
@@ -31,31 +31,20 @@ from sqlsaber.database.connection import (
31
31
  from sqlsaber.database.schema import SchemaManager
32
32
  from sqlsaber.threads import ThreadStorage
33
33
 
34
-
35
- def bottom_toolbar():
36
- return [
37
- (
38
- "class:bottom-toolbar",
39
- " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
40
- )
41
- ]
42
-
43
-
44
- style = Style.from_dict(
45
- {
46
- "frame.border": "#ebbcba",
47
- "bottom-toolbar": "#ebbcba bg:#21202e",
48
- }
49
- )
34
+ if TYPE_CHECKING:
35
+ from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
50
36
 
51
37
 
52
38
  class InteractiveSession:
53
39
  """Manages interactive CLI sessions."""
54
40
 
41
+ exit_commands = {"/exit", "/quit", "exit", "quit"}
42
+ resume_command_template = "saber threads resume {thread_id}"
43
+
55
44
  def __init__(
56
45
  self,
57
46
  console: Console,
58
- agent: Agent,
47
+ sqlsaber_agent: "SQLSaberAgent",
59
48
  db_conn,
60
49
  database_name: str,
61
50
  *,
@@ -63,7 +52,7 @@ class InteractiveSession:
63
52
  initial_history: list | None = None,
64
53
  ):
65
54
  self.console = console
66
- self.agent = agent
55
+ self.sqlsaber_agent = sqlsaber_agent
67
56
  self.db_conn = db_conn
68
57
  self.database_name = database_name
69
58
  self.display = DisplayManager(console)
@@ -77,28 +66,33 @@ class InteractiveSession:
77
66
  self._thread_id: str | None = initial_thread_id
78
67
  self.first_message = not self._thread_id
79
68
 
80
- def show_welcome_message(self):
81
- """Display welcome message for interactive mode."""
82
- # Show database information
83
- db_name = self.database_name or "Unknown"
84
- db_type = (
85
- "PostgreSQL"
86
- if isinstance(self.db_conn, PostgreSQLConnection)
87
- else "MySQL"
88
- if isinstance(self.db_conn, MySQLConnection)
89
- else "DuckDB"
90
- if isinstance(self.db_conn, DuckDBConnection)
91
- else "DuckDB"
92
- if isinstance(self.db_conn, CSVConnection)
93
- else "SQLite"
94
- if isinstance(self.db_conn, SQLiteConnection)
95
- else "database"
69
+ def _history_path(self) -> Path:
70
+ """Get the history file path, ensuring directory exists."""
71
+ history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
72
+ history_dir.mkdir(parents=True, exist_ok=True)
73
+ return history_dir / "history"
74
+
75
+ def _prompt_style(self) -> Style:
76
+ """Get the prompt style configuration."""
77
+ return Style.from_dict(
78
+ {
79
+ "frame.border": "gray",
80
+ "bottom-toolbar": "white bg:#21202e",
81
+ }
96
82
  )
97
83
 
98
- if self.first_message:
99
- self.console.print(
100
- Panel.fit(
101
- """
84
+ def _bottom_toolbar(self):
85
+ """Get the bottom toolbar text."""
86
+ return [
87
+ (
88
+ "class:bottom-toolbar",
89
+ " Use 'Esc-Enter' or 'Meta-Enter' to submit.",
90
+ )
91
+ ]
92
+
93
+ def _banner(self) -> str:
94
+ """Get the ASCII banner."""
95
+ return """
102
96
  ███████ ██████ ██ ███████ █████ ██████ ███████ ██████
103
97
  ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
104
98
  ███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
@@ -106,35 +100,110 @@ class InteractiveSession:
106
100
  ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
107
101
  ▀▀
108
102
  """
109
- )
110
- )
111
- self.console.print(
112
- Markdown(
113
- dedent("""
103
+
104
+ def _instructions(self) -> str:
105
+ """Get the instruction text."""
106
+ return dedent("""
114
107
  - Use `/` for slash commands
115
108
  - Type `@` to get table name completions
116
109
  - Start message with `#` to add something to agent's memory
117
110
  - Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
118
111
  """)
119
- )
120
- )
121
112
 
113
+ def _db_type_name(self) -> str:
114
+ """Get human-readable database type name."""
115
+ mapping = {
116
+ PostgreSQLConnection: "PostgreSQL",
117
+ MySQLConnection: "MySQL",
118
+ DuckDBConnection: "DuckDB",
119
+ CSVConnection: "DuckDB",
120
+ SQLiteConnection: "SQLite",
121
+ }
122
+ for cls, name in mapping.items():
123
+ if isinstance(self.db_conn, cls):
124
+ return name
125
+ return "database"
126
+
127
+ def _resume_hint(self, thread_id: str) -> str:
128
+ """Build resume command hint."""
129
+ return self.resume_command_template.format(thread_id=thread_id)
130
+
131
+ def show_welcome_message(self):
132
+ """Display welcome message for interactive mode."""
133
+ if self.first_message:
134
+ self.console.print(Panel.fit(self._banner()))
135
+ self.console.print(Markdown(self._instructions()))
136
+
137
+ db_name = self.database_name or "Unknown"
122
138
  self.console.print(
123
- f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
139
+ f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({self._db_type_name()})\n"
124
140
  )
125
- # If resuming a thread, show a notice
141
+
126
142
  if self._thread_id:
127
143
  self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
128
144
 
129
- async def _end_thread_and_display_resume_hint(self):
130
- """End thread and display command to resume thread"""
131
- # Print resume hint if there is an active thread
145
+ async def _end_thread(self):
146
+ """End thread and display resume hint."""
132
147
  if self._thread_id:
133
148
  await self._threads.end_thread(self._thread_id)
134
149
  self.console.print(
135
- f"[dim]You can continue this thread using:[/dim] saber threads resume {self._thread_id}"
150
+ f"[dim]You can continue this thread using:[/dim] {self._resume_hint(self._thread_id)}"
136
151
  )
137
152
 
153
+ async def _handle_memory(self, content: str):
154
+ """Handle memory addition command."""
155
+ if not content:
156
+ self.console.print("[yellow]Empty memory content after '#'[/yellow]\n")
157
+ return
158
+
159
+ try:
160
+ mm = self.sqlsaber_agent.memory_manager
161
+ if mm and self.database_name:
162
+ memory = mm.add_memory(self.database_name, content)
163
+ self.console.print(f"[green]✓ Memory added:[/green] {content}")
164
+ self.console.print(f"[dim]Memory ID: {memory.id}[/dim]\n")
165
+ else:
166
+ self.console.print(
167
+ "[yellow]Could not add memory (no database context)[/yellow]\n"
168
+ )
169
+ except Exception as exc:
170
+ self.console.print(f"[yellow]Could not add memory:[/yellow] {exc}\n")
171
+
172
+ async def _cmd_clear(self):
173
+ """Clear conversation history."""
174
+ self.message_history = []
175
+ try:
176
+ if self._thread_id:
177
+ await self._threads.end_thread(self._thread_id)
178
+ except Exception:
179
+ pass
180
+ self.console.print("[green]Conversation history cleared.[/green]\n")
181
+ self._thread_id = None
182
+ self.first_message = True
183
+
184
+ async def _cmd_thinking_on(self):
185
+ """Enable thinking mode."""
186
+ self.sqlsaber_agent.set_thinking(enabled=True)
187
+ self.console.print("[green]✓ Thinking enabled[/green]\n")
188
+
189
+ async def _cmd_thinking_off(self):
190
+ """Disable thinking mode."""
191
+ self.sqlsaber_agent.set_thinking(enabled=False)
192
+ self.console.print("[green]✓ Thinking disabled[/green]\n")
193
+
194
+ async def _handle_command(self, user_query: str) -> bool:
195
+ """Handle slash commands. Returns True if command was handled."""
196
+ if user_query == "/clear":
197
+ await self._cmd_clear()
198
+ return True
199
+ if user_query == "/thinking on":
200
+ await self._cmd_thinking_on()
201
+ return True
202
+ if user_query == "/thinking off":
203
+ await self._cmd_thinking_off()
204
+ return True
205
+ return False
206
+
138
207
  async def _update_table_cache(self):
139
208
  """Update the table completer cache with fresh data."""
140
209
  try:
@@ -167,6 +236,10 @@ class InteractiveSession:
167
236
  # If there's an error, just use empty cache
168
237
  self.table_completer.update_cache([])
169
238
 
239
+ async def before_prompt_loop(self):
240
+ """Hook to refresh context before prompt loop."""
241
+ await self._update_table_cache()
242
+
170
243
  async def _execute_query_with_cancellation(self, user_query: str):
171
244
  """Execute a query with cancellation support."""
172
245
  # Create cancellation token
@@ -176,7 +249,7 @@ class InteractiveSession:
176
249
  query_task = asyncio.create_task(
177
250
  self.streaming_handler.execute_streaming_query(
178
251
  user_query,
179
- self.agent,
252
+ self.sqlsaber_agent,
180
253
  self.cancellation_token,
181
254
  self.message_history,
182
255
  )
@@ -191,11 +264,6 @@ class InteractiveSession:
191
264
  # Use all_messages() so the system prompt and all prior turns are preserved
192
265
  self.message_history = run_result.all_messages()
193
266
 
194
- # Extract title (first user prompt) and model name
195
- if not self._thread_id:
196
- title = user_query
197
- model_name = self.agent.model.model_name
198
-
199
267
  # Persist snapshot to thread storage (create or overwrite)
200
268
  self._thread_id = await self._threads.save_snapshot(
201
269
  messages_json=run_result.all_messages_json(),
@@ -206,9 +274,10 @@ class InteractiveSession:
206
274
  if self.first_message:
207
275
  await self._threads.save_metadata(
208
276
  thread_id=self._thread_id,
209
- title=title,
210
- model_name=model_name,
277
+ title=user_query,
278
+ model_name=self.sqlsaber_agent.agent.model.model_name,
211
279
  )
280
+ self.first_message = False
212
281
  except Exception:
213
282
  pass
214
283
  finally:
@@ -220,15 +289,9 @@ class InteractiveSession:
220
289
  async def run(self):
221
290
  """Run the interactive session loop."""
222
291
  self.show_welcome_message()
292
+ await self.before_prompt_loop()
223
293
 
224
- # Initialize table cache
225
- await self._update_table_cache()
226
-
227
- session = PromptSession(
228
- history=FileHistory(
229
- Path(platformdirs.user_config_dir("sqlsaber")) / "history"
230
- )
231
- )
294
+ session = PromptSession(history=FileHistory(self._history_path()))
232
295
 
233
296
  while True:
234
297
  try:
@@ -240,72 +303,32 @@ class InteractiveSession:
240
303
  SlashCommandCompleter(), self.table_completer
241
304
  ),
242
305
  show_frame=True,
243
- bottom_toolbar=bottom_toolbar,
244
- style=style,
306
+ bottom_toolbar=self._bottom_toolbar,
307
+ style=self._prompt_style(),
245
308
  )
246
309
 
247
310
  if not user_query:
248
311
  continue
249
312
 
250
- if (
251
- user_query in ["/exit", "/quit", "exit", "quit"]
252
- or user_query.startswith("/exit")
253
- or user_query.startswith("/quit")
313
+ # Handle exit commands
314
+ if user_query in self.exit_commands or any(
315
+ user_query.startswith(cmd) for cmd in self.exit_commands
254
316
  ):
255
- await self._end_thread_and_display_resume_hint()
317
+ await self._end_thread()
256
318
  break
257
319
 
258
- if user_query == "/clear":
259
- # Reset local history (pydantic-ai call will receive empty history on next run)
260
- self.message_history = []
261
- # End current thread (if any) so the next turn creates a fresh one
262
- try:
263
- if self._thread_id:
264
- await self._threads.end_thread(self._thread_id)
265
- except Exception:
266
- pass
267
- self.console.print("[green]Conversation history cleared.[/green]\n")
268
- # Do not print resume hint when clearing; a new thread will be created on next turn
269
- self._thread_id = None
320
+ # Handle slash commands
321
+ if await self._handle_command(user_query):
270
322
  continue
271
323
 
272
- if memory_text := user_query.strip():
273
- # Check if query starts with # for memory addition
274
- if memory_text.startswith("#"):
275
- memory_content = memory_text[1:].strip() # Remove # and trim
276
- if memory_content:
277
- # Add memory via the agent's memory manager
278
- try:
279
- mm = getattr(
280
- self.agent, "_sqlsaber_memory_manager", None
281
- )
282
- if mm and self.database_name:
283
- memory = mm.add_memory(
284
- self.database_name, memory_content
285
- )
286
- self.console.print(
287
- f"[green]✓ Memory added:[/green] {memory_content}"
288
- )
289
- self.console.print(
290
- f"[dim]Memory ID: {memory.id}[/dim]\n"
291
- )
292
- else:
293
- self.console.print(
294
- "[yellow]Could not add memory (no database context)[/yellow]\n"
295
- )
296
- except Exception:
297
- self.console.print(
298
- "[yellow]Could not add memory[/yellow]\n"
299
- )
300
- else:
301
- self.console.print(
302
- "[yellow]Empty memory content after '#'[/yellow]\n"
303
- )
304
- continue
324
+ # Handle memory addition
325
+ if user_query.strip().startswith("#"):
326
+ await self._handle_memory(user_query.strip()[1:].strip())
327
+ continue
305
328
 
306
- # Execute query with cancellation support
307
- await self._execute_query_with_cancellation(user_query)
308
- self.display.show_newline() # Empty line for readability
329
+ # Execute query with cancellation support
330
+ await self._execute_query_with_cancellation(user_query)
331
+ self.display.show_newline()
309
332
 
310
333
  except KeyboardInterrupt:
311
334
  # Handle Ctrl+C - cancel current task if running
@@ -324,7 +347,7 @@ class InteractiveSession:
324
347
  )
325
348
  except EOFError:
326
349
  # Exit when Ctrl+D is pressed
327
- await self._end_thread_and_display_resume_hint()
350
+ await self._end_thread()
328
351
  break
329
- except Exception as e:
330
- self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
352
+ except Exception as exc:
353
+ self.console.print(f"[bold red]Error:[/bold red] {exc}")