sqlsaber 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -2
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/mcp.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +207 -135
- sqlsaber/application/__init__.py +1 -0
- sqlsaber/application/auth_setup.py +164 -0
- sqlsaber/application/db_setup.py +223 -0
- sqlsaber/application/model_selection.py +98 -0
- sqlsaber/application/prompts.py +115 -0
- sqlsaber/cli/auth.py +22 -50
- sqlsaber/cli/commands.py +22 -28
- sqlsaber/cli/completers.py +2 -0
- sqlsaber/cli/database.py +25 -86
- sqlsaber/cli/display.py +29 -9
- sqlsaber/cli/interactive.py +150 -127
- sqlsaber/cli/models.py +18 -28
- sqlsaber/cli/onboarding.py +325 -0
- sqlsaber/cli/streaming.py +15 -17
- sqlsaber/cli/threads.py +10 -6
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/settings.py +25 -2
- sqlsaber/database/__init__.py +55 -1
- sqlsaber/database/base.py +124 -0
- sqlsaber/database/csv.py +133 -0
- sqlsaber/database/duckdb.py +313 -0
- sqlsaber/database/mysql.py +345 -0
- sqlsaber/database/postgresql.py +328 -0
- sqlsaber/database/schema.py +66 -963
- sqlsaber/database/sqlite.py +258 -0
- sqlsaber/mcp/mcp.py +1 -1
- sqlsaber/tools/sql_tools.py +1 -1
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/METADATA +43 -9
- sqlsaber-0.27.0.dist-info/RECORD +58 -0
- sqlsaber/database/connection.py +0 -535
- sqlsaber-0.25.0.dist-info/RECORD +0 -47
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.25.0.dist-info → sqlsaber-0.27.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/database.py
CHANGED
|
@@ -81,95 +81,34 @@ def add(
|
|
|
81
81
|
|
|
82
82
|
if interactive:
|
|
83
83
|
# Interactive mode - prompt for all required fields
|
|
84
|
-
|
|
84
|
+
from sqlsaber.application.db_setup import collect_db_input
|
|
85
|
+
from sqlsaber.application.prompts import AsyncPrompter
|
|
85
86
|
|
|
86
|
-
|
|
87
|
-
if not type or type == "postgresql":
|
|
88
|
-
type = questionary.select(
|
|
89
|
-
"Database type:",
|
|
90
|
-
choices=["postgresql", "mysql", "sqlite", "duckdb"],
|
|
91
|
-
default="postgresql",
|
|
92
|
-
).ask()
|
|
93
|
-
|
|
94
|
-
if type in {"sqlite", "duckdb"}:
|
|
95
|
-
# SQLite/DuckDB only need database file path
|
|
96
|
-
database = database or questionary.path("Database file path:").ask()
|
|
97
|
-
database = str(Path(database).expanduser().resolve())
|
|
98
|
-
host = "localhost"
|
|
99
|
-
port = 0
|
|
100
|
-
username = type
|
|
101
|
-
password = ""
|
|
102
|
-
else:
|
|
103
|
-
# PostgreSQL/MySQL need connection details
|
|
104
|
-
host = host or questionary.text("Host:", default="localhost").ask()
|
|
87
|
+
console.print(f"[bold]Adding database connection: {name}[/bold]")
|
|
105
88
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
89
|
+
async def collect_input():
|
|
90
|
+
prompter = AsyncPrompter()
|
|
91
|
+
return await collect_db_input(
|
|
92
|
+
prompter=prompter, name=name, db_type=type, include_ssl=True
|
|
109
93
|
)
|
|
110
94
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
"require",
|
|
129
|
-
"verify-ca",
|
|
130
|
-
"verify-full",
|
|
131
|
-
],
|
|
132
|
-
default="prefer",
|
|
133
|
-
).ask()
|
|
134
|
-
)
|
|
135
|
-
elif type == "mysql":
|
|
136
|
-
ssl_mode = (
|
|
137
|
-
ssl_mode
|
|
138
|
-
or questionary.select(
|
|
139
|
-
"SSL mode for MySQL:",
|
|
140
|
-
choices=[
|
|
141
|
-
"DISABLED",
|
|
142
|
-
"PREFERRED",
|
|
143
|
-
"REQUIRED",
|
|
144
|
-
"VERIFY_CA",
|
|
145
|
-
"VERIFY_IDENTITY",
|
|
146
|
-
],
|
|
147
|
-
default="PREFERRED",
|
|
148
|
-
).ask()
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
if ssl_mode and ssl_mode not in ["disable", "DISABLED"]:
|
|
152
|
-
if questionary.confirm(
|
|
153
|
-
"Specify SSL certificate files?", default=False
|
|
154
|
-
).ask():
|
|
155
|
-
ssl_ca = (
|
|
156
|
-
ssl_ca or questionary.path("SSL CA certificate file:").ask()
|
|
157
|
-
)
|
|
158
|
-
if questionary.confirm(
|
|
159
|
-
"Specify client certificate?", default=False
|
|
160
|
-
).ask():
|
|
161
|
-
ssl_cert = (
|
|
162
|
-
ssl_cert
|
|
163
|
-
or questionary.path(
|
|
164
|
-
"SSL client certificate file:"
|
|
165
|
-
).ask()
|
|
166
|
-
)
|
|
167
|
-
ssl_key = (
|
|
168
|
-
ssl_key
|
|
169
|
-
or questionary.path(
|
|
170
|
-
"SSL client private key file:"
|
|
171
|
-
).ask()
|
|
172
|
-
)
|
|
95
|
+
db_input = asyncio.run(collect_input())
|
|
96
|
+
|
|
97
|
+
if db_input is None:
|
|
98
|
+
console.print("[yellow]Operation cancelled[/yellow]")
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
# Extract values from db_input
|
|
102
|
+
type = db_input.type
|
|
103
|
+
host = db_input.host
|
|
104
|
+
port = db_input.port
|
|
105
|
+
database = db_input.database
|
|
106
|
+
username = db_input.username
|
|
107
|
+
password = db_input.password
|
|
108
|
+
ssl_mode = db_input.ssl_mode
|
|
109
|
+
ssl_ca = db_input.ssl_ca
|
|
110
|
+
ssl_cert = db_input.ssl_cert
|
|
111
|
+
ssl_key = db_input.ssl_key
|
|
173
112
|
else:
|
|
174
113
|
# Non-interactive mode - use provided values or defaults
|
|
175
114
|
if type == "sqlite":
|
|
@@ -354,7 +293,7 @@ def test(
|
|
|
354
293
|
|
|
355
294
|
async def test_connection():
|
|
356
295
|
# Lazy import to keep CLI startup fast
|
|
357
|
-
from sqlsaber.database
|
|
296
|
+
from sqlsaber.database import DatabaseConnection
|
|
358
297
|
|
|
359
298
|
if name:
|
|
360
299
|
db_config = config_manager.get_database(name)
|
sqlsaber/cli/display.py
CHANGED
|
@@ -8,7 +8,7 @@ rendered with Live.
|
|
|
8
8
|
import json
|
|
9
9
|
from typing import Sequence, Type
|
|
10
10
|
|
|
11
|
-
from pydantic_ai.messages import ModelResponsePart, TextPart
|
|
11
|
+
from pydantic_ai.messages import ModelResponsePart, TextPart, ThinkingPart
|
|
12
12
|
from rich.columns import Columns
|
|
13
13
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
14
14
|
from rich.live import Live
|
|
@@ -75,7 +75,7 @@ class LiveMarkdownRenderer:
|
|
|
75
75
|
self.end()
|
|
76
76
|
self.paragraph_break()
|
|
77
77
|
|
|
78
|
-
self._start()
|
|
78
|
+
self._start(kind)
|
|
79
79
|
self._current_kind = kind
|
|
80
80
|
|
|
81
81
|
def append(self, text: str | None) -> None:
|
|
@@ -87,7 +87,13 @@ class LiveMarkdownRenderer:
|
|
|
87
87
|
self.ensure_segment(TextPart)
|
|
88
88
|
|
|
89
89
|
self._buffer += text
|
|
90
|
-
|
|
90
|
+
|
|
91
|
+
# Apply dim styling for thinking segments
|
|
92
|
+
if self._current_kind == ThinkingPart:
|
|
93
|
+
content = Markdown(self._buffer, style="dim")
|
|
94
|
+
self._live.update(content)
|
|
95
|
+
else:
|
|
96
|
+
self._live.update(Markdown(self._buffer))
|
|
91
97
|
|
|
92
98
|
def end(self) -> None:
|
|
93
99
|
"""Finalize and stop the current Live segment, if any."""
|
|
@@ -95,13 +101,17 @@ class LiveMarkdownRenderer:
|
|
|
95
101
|
return
|
|
96
102
|
# Persist the *final* render exactly once, then shut Live down.
|
|
97
103
|
buf = self._buffer
|
|
104
|
+
kind = self._current_kind
|
|
98
105
|
self._live.stop()
|
|
99
106
|
self._live = None
|
|
100
107
|
self._buffer = ""
|
|
101
108
|
self._current_kind = None
|
|
102
109
|
# Print the complete markdown to scroll-back for permanent reference
|
|
103
110
|
if buf:
|
|
104
|
-
|
|
111
|
+
if kind == ThinkingPart:
|
|
112
|
+
self.console.print(Markdown(buf, style="dim"))
|
|
113
|
+
else:
|
|
114
|
+
self.console.print(Markdown(buf))
|
|
105
115
|
|
|
106
116
|
def end_if_active(self) -> None:
|
|
107
117
|
self.end()
|
|
@@ -153,10 +163,20 @@ class LiveMarkdownRenderer:
|
|
|
153
163
|
text = Text(f" {message}", style="yellow")
|
|
154
164
|
return Columns([spinner, text], expand=False)
|
|
155
165
|
|
|
156
|
-
def _start(
|
|
166
|
+
def _start(
|
|
167
|
+
self, kind: Type[ModelResponsePart] | None = None, initial_markdown: str = ""
|
|
168
|
+
) -> None:
|
|
157
169
|
if self._live is not None:
|
|
158
170
|
self.end()
|
|
159
171
|
self._buffer = initial_markdown or ""
|
|
172
|
+
|
|
173
|
+
# Add visual styling for thinking segments
|
|
174
|
+
if kind == ThinkingPart:
|
|
175
|
+
if self.console.is_terminal:
|
|
176
|
+
self.console.print("[dim]💭 Thinking...[/dim]")
|
|
177
|
+
else:
|
|
178
|
+
self.console.print("*Thinking...*\n")
|
|
179
|
+
|
|
160
180
|
# NOTE: Use transient=True so the live widget disappears on exit,
|
|
161
181
|
# giving a clean transition to the final printed result.
|
|
162
182
|
live = Live(
|
|
@@ -219,7 +239,9 @@ class DisplayManager:
|
|
|
219
239
|
if self.console.is_terminal:
|
|
220
240
|
self.console.print("[dim bold]:gear: Executing SQL:[/dim bold]")
|
|
221
241
|
self.show_newline()
|
|
222
|
-
syntax = Syntax(
|
|
242
|
+
syntax = Syntax(
|
|
243
|
+
query, "sql", background_color="default", word_wrap=True
|
|
244
|
+
)
|
|
223
245
|
self.console.print(syntax)
|
|
224
246
|
else:
|
|
225
247
|
self.console.print("**Executing SQL:**\n")
|
|
@@ -271,9 +293,7 @@ class DisplayManager:
|
|
|
271
293
|
f"[yellow]... and {len(results) - 20} more rows[/yellow]"
|
|
272
294
|
)
|
|
273
295
|
else:
|
|
274
|
-
self.console.print(
|
|
275
|
-
f"*... and {len(results) - 20} more rows*\n"
|
|
276
|
-
)
|
|
296
|
+
self.console.print(f"*... and {len(results) - 20} more rows*\n")
|
|
277
297
|
|
|
278
298
|
def show_error(self, error_message: str):
|
|
279
299
|
"""Display error message."""
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -3,13 +3,13 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from textwrap import dedent
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
6
7
|
|
|
7
8
|
import platformdirs
|
|
8
9
|
from prompt_toolkit import PromptSession
|
|
9
10
|
from prompt_toolkit.history import FileHistory
|
|
10
11
|
from prompt_toolkit.patch_stdout import patch_stdout
|
|
11
12
|
from prompt_toolkit.styles import Style
|
|
12
|
-
from pydantic_ai import Agent
|
|
13
13
|
from rich.console import Console
|
|
14
14
|
from rich.markdown import Markdown
|
|
15
15
|
from rich.panel import Panel
|
|
@@ -21,7 +21,7 @@ from sqlsaber.cli.completers import (
|
|
|
21
21
|
)
|
|
22
22
|
from sqlsaber.cli.display import DisplayManager
|
|
23
23
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
24
|
-
from sqlsaber.database
|
|
24
|
+
from sqlsaber.database import (
|
|
25
25
|
CSVConnection,
|
|
26
26
|
DuckDBConnection,
|
|
27
27
|
MySQLConnection,
|
|
@@ -31,31 +31,20 @@ from sqlsaber.database.connection import (
|
|
|
31
31
|
from sqlsaber.database.schema import SchemaManager
|
|
32
32
|
from sqlsaber.threads import ThreadStorage
|
|
33
33
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
return [
|
|
37
|
-
(
|
|
38
|
-
"class:bottom-toolbar",
|
|
39
|
-
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
40
|
-
)
|
|
41
|
-
]
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
style = Style.from_dict(
|
|
45
|
-
{
|
|
46
|
-
"frame.border": "#ebbcba",
|
|
47
|
-
"bottom-toolbar": "#ebbcba bg:#21202e",
|
|
48
|
-
}
|
|
49
|
-
)
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from sqlsaber.agents.pydantic_ai_agent import SQLSaberAgent
|
|
50
36
|
|
|
51
37
|
|
|
52
38
|
class InteractiveSession:
|
|
53
39
|
"""Manages interactive CLI sessions."""
|
|
54
40
|
|
|
41
|
+
exit_commands = {"/exit", "/quit", "exit", "quit"}
|
|
42
|
+
resume_command_template = "saber threads resume {thread_id}"
|
|
43
|
+
|
|
55
44
|
def __init__(
|
|
56
45
|
self,
|
|
57
46
|
console: Console,
|
|
58
|
-
|
|
47
|
+
sqlsaber_agent: "SQLSaberAgent",
|
|
59
48
|
db_conn,
|
|
60
49
|
database_name: str,
|
|
61
50
|
*,
|
|
@@ -63,7 +52,7 @@ class InteractiveSession:
|
|
|
63
52
|
initial_history: list | None = None,
|
|
64
53
|
):
|
|
65
54
|
self.console = console
|
|
66
|
-
self.
|
|
55
|
+
self.sqlsaber_agent = sqlsaber_agent
|
|
67
56
|
self.db_conn = db_conn
|
|
68
57
|
self.database_name = database_name
|
|
69
58
|
self.display = DisplayManager(console)
|
|
@@ -77,28 +66,33 @@ class InteractiveSession:
|
|
|
77
66
|
self._thread_id: str | None = initial_thread_id
|
|
78
67
|
self.first_message = not self._thread_id
|
|
79
68
|
|
|
80
|
-
def
|
|
81
|
-
"""
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
else "SQLite"
|
|
94
|
-
if isinstance(self.db_conn, SQLiteConnection)
|
|
95
|
-
else "database"
|
|
69
|
+
def _history_path(self) -> Path:
|
|
70
|
+
"""Get the history file path, ensuring directory exists."""
|
|
71
|
+
history_dir = Path(platformdirs.user_config_dir("sqlsaber"))
|
|
72
|
+
history_dir.mkdir(parents=True, exist_ok=True)
|
|
73
|
+
return history_dir / "history"
|
|
74
|
+
|
|
75
|
+
def _prompt_style(self) -> Style:
|
|
76
|
+
"""Get the prompt style configuration."""
|
|
77
|
+
return Style.from_dict(
|
|
78
|
+
{
|
|
79
|
+
"frame.border": "gray",
|
|
80
|
+
"bottom-toolbar": "white bg:#21202e",
|
|
81
|
+
}
|
|
96
82
|
)
|
|
97
83
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
84
|
+
def _bottom_toolbar(self):
|
|
85
|
+
"""Get the bottom toolbar text."""
|
|
86
|
+
return [
|
|
87
|
+
(
|
|
88
|
+
"class:bottom-toolbar",
|
|
89
|
+
" Use 'Esc-Enter' or 'Meta-Enter' to submit.",
|
|
90
|
+
)
|
|
91
|
+
]
|
|
92
|
+
|
|
93
|
+
def _banner(self) -> str:
|
|
94
|
+
"""Get the ASCII banner."""
|
|
95
|
+
return """
|
|
102
96
|
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
103
97
|
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
104
98
|
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
@@ -106,35 +100,110 @@ class InteractiveSession:
|
|
|
106
100
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
107
101
|
▀▀
|
|
108
102
|
"""
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
dedent("""
|
|
103
|
+
|
|
104
|
+
def _instructions(self) -> str:
|
|
105
|
+
"""Get the instruction text."""
|
|
106
|
+
return dedent("""
|
|
114
107
|
- Use `/` for slash commands
|
|
115
108
|
- Type `@` to get table name completions
|
|
116
109
|
- Start message with `#` to add something to agent's memory
|
|
117
110
|
- Use `Ctrl+C` to interrupt and `Ctrl+D` to exit
|
|
118
111
|
""")
|
|
119
|
-
)
|
|
120
|
-
)
|
|
121
112
|
|
|
113
|
+
def _db_type_name(self) -> str:
|
|
114
|
+
"""Get human-readable database type name."""
|
|
115
|
+
mapping = {
|
|
116
|
+
PostgreSQLConnection: "PostgreSQL",
|
|
117
|
+
MySQLConnection: "MySQL",
|
|
118
|
+
DuckDBConnection: "DuckDB",
|
|
119
|
+
CSVConnection: "DuckDB",
|
|
120
|
+
SQLiteConnection: "SQLite",
|
|
121
|
+
}
|
|
122
|
+
for cls, name in mapping.items():
|
|
123
|
+
if isinstance(self.db_conn, cls):
|
|
124
|
+
return name
|
|
125
|
+
return "database"
|
|
126
|
+
|
|
127
|
+
def _resume_hint(self, thread_id: str) -> str:
|
|
128
|
+
"""Build resume command hint."""
|
|
129
|
+
return self.resume_command_template.format(thread_id=thread_id)
|
|
130
|
+
|
|
131
|
+
def show_welcome_message(self):
|
|
132
|
+
"""Display welcome message for interactive mode."""
|
|
133
|
+
if self.first_message:
|
|
134
|
+
self.console.print(Panel.fit(self._banner()))
|
|
135
|
+
self.console.print(Markdown(self._instructions()))
|
|
136
|
+
|
|
137
|
+
db_name = self.database_name or "Unknown"
|
|
122
138
|
self.console.print(
|
|
123
|
-
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({
|
|
139
|
+
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({self._db_type_name()})\n"
|
|
124
140
|
)
|
|
125
|
-
|
|
141
|
+
|
|
126
142
|
if self._thread_id:
|
|
127
143
|
self.console.print(f"[dim]Resuming thread:[/dim] {self._thread_id}\n")
|
|
128
144
|
|
|
129
|
-
async def
|
|
130
|
-
"""End thread and display
|
|
131
|
-
# Print resume hint if there is an active thread
|
|
145
|
+
async def _end_thread(self):
|
|
146
|
+
"""End thread and display resume hint."""
|
|
132
147
|
if self._thread_id:
|
|
133
148
|
await self._threads.end_thread(self._thread_id)
|
|
134
149
|
self.console.print(
|
|
135
|
-
f"[dim]You can continue this thread using:[/dim]
|
|
150
|
+
f"[dim]You can continue this thread using:[/dim] {self._resume_hint(self._thread_id)}"
|
|
136
151
|
)
|
|
137
152
|
|
|
153
|
+
async def _handle_memory(self, content: str):
|
|
154
|
+
"""Handle memory addition command."""
|
|
155
|
+
if not content:
|
|
156
|
+
self.console.print("[yellow]Empty memory content after '#'[/yellow]\n")
|
|
157
|
+
return
|
|
158
|
+
|
|
159
|
+
try:
|
|
160
|
+
mm = self.sqlsaber_agent.memory_manager
|
|
161
|
+
if mm and self.database_name:
|
|
162
|
+
memory = mm.add_memory(self.database_name, content)
|
|
163
|
+
self.console.print(f"[green]✓ Memory added:[/green] {content}")
|
|
164
|
+
self.console.print(f"[dim]Memory ID: {memory.id}[/dim]\n")
|
|
165
|
+
else:
|
|
166
|
+
self.console.print(
|
|
167
|
+
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
168
|
+
)
|
|
169
|
+
except Exception as exc:
|
|
170
|
+
self.console.print(f"[yellow]Could not add memory:[/yellow] {exc}\n")
|
|
171
|
+
|
|
172
|
+
async def _cmd_clear(self):
|
|
173
|
+
"""Clear conversation history."""
|
|
174
|
+
self.message_history = []
|
|
175
|
+
try:
|
|
176
|
+
if self._thread_id:
|
|
177
|
+
await self._threads.end_thread(self._thread_id)
|
|
178
|
+
except Exception:
|
|
179
|
+
pass
|
|
180
|
+
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
181
|
+
self._thread_id = None
|
|
182
|
+
self.first_message = True
|
|
183
|
+
|
|
184
|
+
async def _cmd_thinking_on(self):
|
|
185
|
+
"""Enable thinking mode."""
|
|
186
|
+
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
187
|
+
self.console.print("[green]✓ Thinking enabled[/green]\n")
|
|
188
|
+
|
|
189
|
+
async def _cmd_thinking_off(self):
|
|
190
|
+
"""Disable thinking mode."""
|
|
191
|
+
self.sqlsaber_agent.set_thinking(enabled=False)
|
|
192
|
+
self.console.print("[green]✓ Thinking disabled[/green]\n")
|
|
193
|
+
|
|
194
|
+
async def _handle_command(self, user_query: str) -> bool:
|
|
195
|
+
"""Handle slash commands. Returns True if command was handled."""
|
|
196
|
+
if user_query == "/clear":
|
|
197
|
+
await self._cmd_clear()
|
|
198
|
+
return True
|
|
199
|
+
if user_query == "/thinking on":
|
|
200
|
+
await self._cmd_thinking_on()
|
|
201
|
+
return True
|
|
202
|
+
if user_query == "/thinking off":
|
|
203
|
+
await self._cmd_thinking_off()
|
|
204
|
+
return True
|
|
205
|
+
return False
|
|
206
|
+
|
|
138
207
|
async def _update_table_cache(self):
|
|
139
208
|
"""Update the table completer cache with fresh data."""
|
|
140
209
|
try:
|
|
@@ -167,6 +236,10 @@ class InteractiveSession:
|
|
|
167
236
|
# If there's an error, just use empty cache
|
|
168
237
|
self.table_completer.update_cache([])
|
|
169
238
|
|
|
239
|
+
async def before_prompt_loop(self):
|
|
240
|
+
"""Hook to refresh context before prompt loop."""
|
|
241
|
+
await self._update_table_cache()
|
|
242
|
+
|
|
170
243
|
async def _execute_query_with_cancellation(self, user_query: str):
|
|
171
244
|
"""Execute a query with cancellation support."""
|
|
172
245
|
# Create cancellation token
|
|
@@ -176,7 +249,7 @@ class InteractiveSession:
|
|
|
176
249
|
query_task = asyncio.create_task(
|
|
177
250
|
self.streaming_handler.execute_streaming_query(
|
|
178
251
|
user_query,
|
|
179
|
-
self.
|
|
252
|
+
self.sqlsaber_agent,
|
|
180
253
|
self.cancellation_token,
|
|
181
254
|
self.message_history,
|
|
182
255
|
)
|
|
@@ -191,11 +264,6 @@ class InteractiveSession:
|
|
|
191
264
|
# Use all_messages() so the system prompt and all prior turns are preserved
|
|
192
265
|
self.message_history = run_result.all_messages()
|
|
193
266
|
|
|
194
|
-
# Extract title (first user prompt) and model name
|
|
195
|
-
if not self._thread_id:
|
|
196
|
-
title = user_query
|
|
197
|
-
model_name = self.agent.model.model_name
|
|
198
|
-
|
|
199
267
|
# Persist snapshot to thread storage (create or overwrite)
|
|
200
268
|
self._thread_id = await self._threads.save_snapshot(
|
|
201
269
|
messages_json=run_result.all_messages_json(),
|
|
@@ -206,9 +274,10 @@ class InteractiveSession:
|
|
|
206
274
|
if self.first_message:
|
|
207
275
|
await self._threads.save_metadata(
|
|
208
276
|
thread_id=self._thread_id,
|
|
209
|
-
title=
|
|
210
|
-
model_name=model_name,
|
|
277
|
+
title=user_query,
|
|
278
|
+
model_name=self.sqlsaber_agent.agent.model.model_name,
|
|
211
279
|
)
|
|
280
|
+
self.first_message = False
|
|
212
281
|
except Exception:
|
|
213
282
|
pass
|
|
214
283
|
finally:
|
|
@@ -220,15 +289,9 @@ class InteractiveSession:
|
|
|
220
289
|
async def run(self):
|
|
221
290
|
"""Run the interactive session loop."""
|
|
222
291
|
self.show_welcome_message()
|
|
292
|
+
await self.before_prompt_loop()
|
|
223
293
|
|
|
224
|
-
|
|
225
|
-
await self._update_table_cache()
|
|
226
|
-
|
|
227
|
-
session = PromptSession(
|
|
228
|
-
history=FileHistory(
|
|
229
|
-
Path(platformdirs.user_config_dir("sqlsaber")) / "history"
|
|
230
|
-
)
|
|
231
|
-
)
|
|
294
|
+
session = PromptSession(history=FileHistory(self._history_path()))
|
|
232
295
|
|
|
233
296
|
while True:
|
|
234
297
|
try:
|
|
@@ -240,72 +303,32 @@ class InteractiveSession:
|
|
|
240
303
|
SlashCommandCompleter(), self.table_completer
|
|
241
304
|
),
|
|
242
305
|
show_frame=True,
|
|
243
|
-
bottom_toolbar=
|
|
244
|
-
style=
|
|
306
|
+
bottom_toolbar=self._bottom_toolbar,
|
|
307
|
+
style=self._prompt_style(),
|
|
245
308
|
)
|
|
246
309
|
|
|
247
310
|
if not user_query:
|
|
248
311
|
continue
|
|
249
312
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
or user_query.startswith("/quit")
|
|
313
|
+
# Handle exit commands
|
|
314
|
+
if user_query in self.exit_commands or any(
|
|
315
|
+
user_query.startswith(cmd) for cmd in self.exit_commands
|
|
254
316
|
):
|
|
255
|
-
await self.
|
|
317
|
+
await self._end_thread()
|
|
256
318
|
break
|
|
257
319
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
self.message_history = []
|
|
261
|
-
# End current thread (if any) so the next turn creates a fresh one
|
|
262
|
-
try:
|
|
263
|
-
if self._thread_id:
|
|
264
|
-
await self._threads.end_thread(self._thread_id)
|
|
265
|
-
except Exception:
|
|
266
|
-
pass
|
|
267
|
-
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
268
|
-
# Do not print resume hint when clearing; a new thread will be created on next turn
|
|
269
|
-
self._thread_id = None
|
|
320
|
+
# Handle slash commands
|
|
321
|
+
if await self._handle_command(user_query):
|
|
270
322
|
continue
|
|
271
323
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
if memory_content:
|
|
277
|
-
# Add memory via the agent's memory manager
|
|
278
|
-
try:
|
|
279
|
-
mm = getattr(
|
|
280
|
-
self.agent, "_sqlsaber_memory_manager", None
|
|
281
|
-
)
|
|
282
|
-
if mm and self.database_name:
|
|
283
|
-
memory = mm.add_memory(
|
|
284
|
-
self.database_name, memory_content
|
|
285
|
-
)
|
|
286
|
-
self.console.print(
|
|
287
|
-
f"[green]✓ Memory added:[/green] {memory_content}"
|
|
288
|
-
)
|
|
289
|
-
self.console.print(
|
|
290
|
-
f"[dim]Memory ID: {memory.id}[/dim]\n"
|
|
291
|
-
)
|
|
292
|
-
else:
|
|
293
|
-
self.console.print(
|
|
294
|
-
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
295
|
-
)
|
|
296
|
-
except Exception:
|
|
297
|
-
self.console.print(
|
|
298
|
-
"[yellow]Could not add memory[/yellow]\n"
|
|
299
|
-
)
|
|
300
|
-
else:
|
|
301
|
-
self.console.print(
|
|
302
|
-
"[yellow]Empty memory content after '#'[/yellow]\n"
|
|
303
|
-
)
|
|
304
|
-
continue
|
|
324
|
+
# Handle memory addition
|
|
325
|
+
if user_query.strip().startswith("#"):
|
|
326
|
+
await self._handle_memory(user_query.strip()[1:].strip())
|
|
327
|
+
continue
|
|
305
328
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
329
|
+
# Execute query with cancellation support
|
|
330
|
+
await self._execute_query_with_cancellation(user_query)
|
|
331
|
+
self.display.show_newline()
|
|
309
332
|
|
|
310
333
|
except KeyboardInterrupt:
|
|
311
334
|
# Handle Ctrl+C - cancel current task if running
|
|
@@ -324,7 +347,7 @@ class InteractiveSession:
|
|
|
324
347
|
)
|
|
325
348
|
except EOFError:
|
|
326
349
|
# Exit when Ctrl+D is pressed
|
|
327
|
-
await self.
|
|
350
|
+
await self._end_thread()
|
|
328
351
|
break
|
|
329
|
-
except Exception as
|
|
330
|
-
self.console.print(f"[bold red]Error:[/bold red] {
|
|
352
|
+
except Exception as exc:
|
|
353
|
+
self.console.print(f"[bold red]Error:[/bold red] {exc}")
|