sqlsaber 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber/application/auth_setup.py +2 -2
- sqlsaber/application/db_setup.py +2 -3
- sqlsaber/application/model_selection.py +2 -2
- sqlsaber/cli/auth.py +2 -2
- sqlsaber/cli/commands.py +2 -2
- sqlsaber/cli/database.py +2 -2
- sqlsaber/cli/display.py +59 -40
- sqlsaber/cli/interactive.py +18 -27
- sqlsaber/cli/memory.py +2 -2
- sqlsaber/cli/models.py +2 -2
- sqlsaber/cli/onboarding.py +2 -2
- sqlsaber/cli/streaming.py +1 -1
- sqlsaber/cli/threads.py +35 -16
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/oauth_flow.py +3 -2
- sqlsaber/config/oauth_tokens.py +3 -5
- sqlsaber/database/base.py +6 -0
- sqlsaber/database/csv.py +5 -0
- sqlsaber/database/duckdb.py +5 -0
- sqlsaber/database/mysql.py +5 -0
- sqlsaber/database/postgresql.py +5 -0
- sqlsaber/database/sqlite.py +5 -0
- sqlsaber/theme/__init__.py +5 -0
- sqlsaber/theme/manager.py +219 -0
- sqlsaber/tools/sql_guard.py +225 -0
- sqlsaber/tools/sql_tools.py +10 -35
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/METADATA +2 -1
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/RECORD +31 -28
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,15 +3,15 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
|
|
5
5
|
from questionary import Choice
|
|
6
|
-
from rich.console import Console
|
|
7
6
|
|
|
8
7
|
from sqlsaber.application.prompts import Prompter
|
|
9
8
|
from sqlsaber.config import providers
|
|
10
9
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
11
10
|
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
12
11
|
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
12
|
+
from sqlsaber.theme.manager import create_console
|
|
13
13
|
|
|
14
|
-
console =
|
|
14
|
+
console = create_console()
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
async def select_provider(prompter: Prompter, default: str = "anthropic") -> str | None:
|
sqlsaber/application/db_setup.py
CHANGED
|
@@ -4,12 +4,11 @@ import getpass
|
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
-
from rich.console import Console
|
|
8
|
-
|
|
9
7
|
from sqlsaber.application.prompts import Prompter
|
|
10
8
|
from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
|
|
9
|
+
from sqlsaber.theme.manager import create_console
|
|
11
10
|
|
|
12
|
-
console =
|
|
11
|
+
console = create_console()
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
@dataclass
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
"""Shared model selection logic for onboarding and CLI."""
|
|
2
2
|
|
|
3
3
|
from questionary import Choice
|
|
4
|
-
from rich.console import Console
|
|
5
4
|
|
|
6
5
|
from sqlsaber.application.prompts import Prompter
|
|
7
6
|
from sqlsaber.cli.models import ModelManager
|
|
7
|
+
from sqlsaber.theme.manager import create_console
|
|
8
8
|
|
|
9
|
-
console =
|
|
9
|
+
console = create_console()
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
async def fetch_models(
|
sqlsaber/cli/auth.py
CHANGED
|
@@ -5,15 +5,15 @@ import os
|
|
|
5
5
|
import cyclopts
|
|
6
6
|
import keyring
|
|
7
7
|
import questionary
|
|
8
|
-
from rich.console import Console
|
|
9
8
|
|
|
10
9
|
from sqlsaber.config import providers
|
|
11
10
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
12
11
|
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
13
12
|
from sqlsaber.config.oauth_tokens import OAuthTokenManager
|
|
13
|
+
from sqlsaber.theme.manager import create_console
|
|
14
14
|
|
|
15
15
|
# Global instances for CLI commands
|
|
16
|
-
console =
|
|
16
|
+
console = create_console()
|
|
17
17
|
config_manager = AuthConfigManager()
|
|
18
18
|
|
|
19
19
|
# Create the authentication management CLI app
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -5,7 +5,6 @@ import sys
|
|
|
5
5
|
from typing import Annotated
|
|
6
6
|
|
|
7
7
|
import cyclopts
|
|
8
|
-
from rich.console import Console
|
|
9
8
|
|
|
10
9
|
from sqlsaber.cli.auth import create_auth_app
|
|
11
10
|
from sqlsaber.cli.database import create_db_app
|
|
@@ -16,6 +15,7 @@ from sqlsaber.cli.threads import create_threads_app
|
|
|
16
15
|
|
|
17
16
|
# Lazy imports - only import what's needed for CLI parsing
|
|
18
17
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
18
|
+
from sqlsaber.theme.manager import create_console
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class CLIError(Exception):
|
|
@@ -37,7 +37,7 @@ app.command(create_memory_app(), name="memory")
|
|
|
37
37
|
app.command(create_models_app(), name="models")
|
|
38
38
|
app.command(create_threads_app(), name="threads")
|
|
39
39
|
|
|
40
|
-
console =
|
|
40
|
+
console = create_console()
|
|
41
41
|
config_manager = DatabaseConfigManager()
|
|
42
42
|
|
|
43
43
|
|
sqlsaber/cli/database.py
CHANGED
|
@@ -8,13 +8,13 @@ from typing import Annotated
|
|
|
8
8
|
|
|
9
9
|
import cyclopts
|
|
10
10
|
import questionary
|
|
11
|
-
from rich.console import Console
|
|
12
11
|
from rich.table import Table
|
|
13
12
|
|
|
14
13
|
from sqlsaber.config.database import DatabaseConfig, DatabaseConfigManager
|
|
14
|
+
from sqlsaber.theme.manager import create_console
|
|
15
15
|
|
|
16
16
|
# Global instances for CLI commands
|
|
17
|
-
console =
|
|
17
|
+
console = create_console()
|
|
18
18
|
config_manager = DatabaseConfigManager()
|
|
19
19
|
|
|
20
20
|
# Create the database management CLI app
|
sqlsaber/cli/display.py
CHANGED
|
@@ -19,6 +19,8 @@ from rich.syntax import Syntax
|
|
|
19
19
|
from rich.table import Table
|
|
20
20
|
from rich.text import Text
|
|
21
21
|
|
|
22
|
+
from sqlsaber.theme.manager import get_theme_manager
|
|
23
|
+
|
|
22
24
|
|
|
23
25
|
class _SimpleCodeBlock(CodeBlock):
|
|
24
26
|
def __rich_console__(
|
|
@@ -46,6 +48,7 @@ class LiveMarkdownRenderer:
|
|
|
46
48
|
|
|
47
49
|
def __init__(self, console: Console):
|
|
48
50
|
self.console = console
|
|
51
|
+
self.tm = get_theme_manager()
|
|
49
52
|
self._live: Live | None = None
|
|
50
53
|
self._status_live: Live | None = None
|
|
51
54
|
self._buffer: str = ""
|
|
@@ -90,10 +93,14 @@ class LiveMarkdownRenderer:
|
|
|
90
93
|
|
|
91
94
|
# Apply dim styling for thinking segments
|
|
92
95
|
if self._current_kind == ThinkingPart:
|
|
93
|
-
content = Markdown(
|
|
96
|
+
content = Markdown(
|
|
97
|
+
self._buffer, style="muted", code_theme=self.tm.pygments_style_name
|
|
98
|
+
)
|
|
94
99
|
self._live.update(content)
|
|
95
100
|
else:
|
|
96
|
-
self._live.update(
|
|
101
|
+
self._live.update(
|
|
102
|
+
Markdown(self._buffer, code_theme=self.tm.pygments_style_name)
|
|
103
|
+
)
|
|
97
104
|
|
|
98
105
|
def end(self) -> None:
|
|
99
106
|
"""Finalize and stop the current Live segment, if any."""
|
|
@@ -109,9 +116,13 @@ class LiveMarkdownRenderer:
|
|
|
109
116
|
# Print the complete markdown to scroll-back for permanent reference
|
|
110
117
|
if buf:
|
|
111
118
|
if kind == ThinkingPart:
|
|
112
|
-
self.console.print(
|
|
119
|
+
self.console.print(
|
|
120
|
+
Markdown(buf, style="muted", code_theme=self.tm.pygments_style_name)
|
|
121
|
+
)
|
|
113
122
|
else:
|
|
114
|
-
self.console.print(
|
|
123
|
+
self.console.print(
|
|
124
|
+
Markdown(buf, code_theme=self.tm.pygments_style_name)
|
|
125
|
+
)
|
|
115
126
|
|
|
116
127
|
def end_if_active(self) -> None:
|
|
117
128
|
self.end()
|
|
@@ -129,7 +140,7 @@ class LiveMarkdownRenderer:
|
|
|
129
140
|
self._buffer = f"```sql\n{sql}\n```"
|
|
130
141
|
# Use context manager to auto-stop and persist final render
|
|
131
142
|
with Live(
|
|
132
|
-
Markdown(self._buffer),
|
|
143
|
+
Markdown(self._buffer, code_theme=self.tm.pygments_style_name),
|
|
133
144
|
console=self.console,
|
|
134
145
|
vertical_overflow="visible",
|
|
135
146
|
refresh_per_second=12,
|
|
@@ -159,8 +170,8 @@ class LiveMarkdownRenderer:
|
|
|
159
170
|
self._status_live = None
|
|
160
171
|
|
|
161
172
|
def _status_renderable(self, message: str):
|
|
162
|
-
spinner = Spinner("dots", style="
|
|
163
|
-
text = Text(f" {message}", style="
|
|
173
|
+
spinner = Spinner("dots", style=self.tm.style("spinner"))
|
|
174
|
+
text = Text(f" {message}", style=self.tm.style("status"))
|
|
164
175
|
return Columns([spinner, text], expand=False)
|
|
165
176
|
|
|
166
177
|
def _start(
|
|
@@ -173,14 +184,14 @@ class LiveMarkdownRenderer:
|
|
|
173
184
|
# Add visual styling for thinking segments
|
|
174
185
|
if kind == ThinkingPart:
|
|
175
186
|
if self.console.is_terminal:
|
|
176
|
-
self.console.print("[
|
|
187
|
+
self.console.print("[muted]💭 Thinking...[/muted]")
|
|
177
188
|
else:
|
|
178
189
|
self.console.print("*Thinking...*\n")
|
|
179
190
|
|
|
180
191
|
# NOTE: Use transient=True so the live widget disappears on exit,
|
|
181
192
|
# giving a clean transition to the final printed result.
|
|
182
193
|
live = Live(
|
|
183
|
-
Markdown(self._buffer),
|
|
194
|
+
Markdown(self._buffer, code_theme=self.tm.pygments_style_name),
|
|
184
195
|
console=self.console,
|
|
185
196
|
transient=True,
|
|
186
197
|
refresh_per_second=12,
|
|
@@ -195,14 +206,16 @@ class DisplayManager:
|
|
|
195
206
|
def __init__(self, console: Console):
|
|
196
207
|
self.console = console
|
|
197
208
|
self.live = LiveMarkdownRenderer(console)
|
|
209
|
+
self.tm = get_theme_manager()
|
|
198
210
|
|
|
199
211
|
def _create_table(
|
|
200
212
|
self,
|
|
201
213
|
columns: Sequence[str | dict[str, str]],
|
|
202
|
-
header_style: str =
|
|
214
|
+
header_style: str | None = None,
|
|
203
215
|
title: str | None = None,
|
|
204
216
|
) -> Table:
|
|
205
217
|
"""Create a Rich table with specified columns."""
|
|
218
|
+
header_style = header_style or self.tm.style("table.header")
|
|
206
219
|
table = Table(show_header=True, header_style=header_style, title=title)
|
|
207
220
|
for col in columns:
|
|
208
221
|
if isinstance(col, dict):
|
|
@@ -220,7 +233,7 @@ class DisplayManager:
|
|
|
220
233
|
if tool_name == "list_tables":
|
|
221
234
|
if self.console.is_terminal:
|
|
222
235
|
self.console.print(
|
|
223
|
-
"[
|
|
236
|
+
"[muted bold]:gear: Discovering available tables[/muted bold]"
|
|
224
237
|
)
|
|
225
238
|
else:
|
|
226
239
|
self.console.print("**Discovering available tables**\n")
|
|
@@ -228,7 +241,7 @@ class DisplayManager:
|
|
|
228
241
|
pattern = tool_input.get("table_pattern", "all tables")
|
|
229
242
|
if self.console.is_terminal:
|
|
230
243
|
self.console.print(
|
|
231
|
-
f"[
|
|
244
|
+
f"[muted bold]:gear: Examining schema for: {pattern}[/muted bold]"
|
|
232
245
|
)
|
|
233
246
|
else:
|
|
234
247
|
self.console.print(f"**Examining schema for:** {pattern}\n")
|
|
@@ -237,10 +250,14 @@ class DisplayManager:
|
|
|
237
250
|
# rendering for threads show/resume. Controlled by include_sql flag.
|
|
238
251
|
query = tool_input.get("query", "")
|
|
239
252
|
if self.console.is_terminal:
|
|
240
|
-
self.console.print("[
|
|
253
|
+
self.console.print("[muted bold]:gear: Executing SQL:[/muted bold]")
|
|
241
254
|
self.show_newline()
|
|
242
255
|
syntax = Syntax(
|
|
243
|
-
query,
|
|
256
|
+
query,
|
|
257
|
+
"sql",
|
|
258
|
+
theme=self.tm.pygments_style_name,
|
|
259
|
+
background_color="default",
|
|
260
|
+
word_wrap=True,
|
|
244
261
|
)
|
|
245
262
|
self.console.print(syntax)
|
|
246
263
|
else:
|
|
@@ -258,9 +275,7 @@ class DisplayManager:
|
|
|
258
275
|
return
|
|
259
276
|
|
|
260
277
|
if self.console.is_terminal:
|
|
261
|
-
self.console.print(
|
|
262
|
-
f"\n[bold magenta]Results ({len(results)} rows):[/bold magenta]"
|
|
263
|
-
)
|
|
278
|
+
self.console.print(f"\n[section]Results ({len(results)} rows):[/section]")
|
|
264
279
|
else:
|
|
265
280
|
self.console.print(f"\n**Results ({len(results)} rows):**\n")
|
|
266
281
|
|
|
@@ -272,7 +287,7 @@ class DisplayManager:
|
|
|
272
287
|
if len(all_columns) > 15:
|
|
273
288
|
if self.console.is_terminal:
|
|
274
289
|
self.console.print(
|
|
275
|
-
f"[
|
|
290
|
+
f"[warning]Note: Showing first 15 of {len(all_columns)} columns[/warning]"
|
|
276
291
|
)
|
|
277
292
|
else:
|
|
278
293
|
self.console.print(
|
|
@@ -290,21 +305,21 @@ class DisplayManager:
|
|
|
290
305
|
if len(results) > 20:
|
|
291
306
|
if self.console.is_terminal:
|
|
292
307
|
self.console.print(
|
|
293
|
-
f"[
|
|
308
|
+
f"[warning]... and {len(results) - 20} more rows[/warning]"
|
|
294
309
|
)
|
|
295
310
|
else:
|
|
296
311
|
self.console.print(f"*... and {len(results) - 20} more rows*\n")
|
|
297
312
|
|
|
298
313
|
def show_error(self, error_message: str):
|
|
299
314
|
"""Display error message."""
|
|
300
|
-
self.console.print(f"\n[
|
|
315
|
+
self.console.print(f"\n[error]Error:[/error] {error_message}")
|
|
301
316
|
|
|
302
317
|
def show_sql_error(self, error_message: str, suggestions: list[str] | None = None):
|
|
303
318
|
"""Display SQL-specific error with optional suggestions."""
|
|
304
319
|
self.show_newline()
|
|
305
|
-
self.console.print(f"[
|
|
320
|
+
self.console.print(f"[error]SQL error:[/error] {error_message}")
|
|
306
321
|
if suggestions:
|
|
307
|
-
self.console.print("[
|
|
322
|
+
self.console.print("[warning]Hints:[/warning]")
|
|
308
323
|
for suggestion in suggestions:
|
|
309
324
|
self.console.print(f" • {suggestion}")
|
|
310
325
|
|
|
@@ -312,7 +327,7 @@ class DisplayManager:
|
|
|
312
327
|
"""Display processing message."""
|
|
313
328
|
self.console.print() # Add newline
|
|
314
329
|
return self.console.status(
|
|
315
|
-
f"[
|
|
330
|
+
f"[status]{message}[/status]", spinner="bouncingBall"
|
|
316
331
|
)
|
|
317
332
|
|
|
318
333
|
def show_newline(self):
|
|
@@ -335,18 +350,20 @@ class DisplayManager:
|
|
|
335
350
|
total_tables = data.get("total_tables", 0)
|
|
336
351
|
|
|
337
352
|
if not tables:
|
|
338
|
-
self.console.print(
|
|
353
|
+
self.console.print(
|
|
354
|
+
"[warning]No tables found in the database.[/warning]"
|
|
355
|
+
)
|
|
339
356
|
return
|
|
340
357
|
|
|
341
358
|
self.console.print(
|
|
342
|
-
f"\n[
|
|
359
|
+
f"\n[title]Database Tables ({total_tables} total):[/title]"
|
|
343
360
|
)
|
|
344
361
|
|
|
345
362
|
# Create a rich table for displaying table information
|
|
346
363
|
columns = [
|
|
347
|
-
{"name": "Schema", "style": "
|
|
348
|
-
{"name": "Table Name", "style": "
|
|
349
|
-
{"name": "Type", "style": "
|
|
364
|
+
{"name": "Schema", "style": "column.schema"},
|
|
365
|
+
{"name": "Table Name", "style": "column.name"},
|
|
366
|
+
{"name": "Type", "style": "column.type"},
|
|
350
367
|
]
|
|
351
368
|
table = self._create_table(columns)
|
|
352
369
|
|
|
@@ -378,26 +395,26 @@ class DisplayManager:
|
|
|
378
395
|
return
|
|
379
396
|
|
|
380
397
|
if not data:
|
|
381
|
-
self.console.print("[
|
|
398
|
+
self.console.print("[warning]No schema information found.[/warning]")
|
|
382
399
|
return
|
|
383
400
|
|
|
384
401
|
self.console.print(
|
|
385
|
-
f"\n[
|
|
402
|
+
f"\n[title]Schema Information ({len(data)} tables):[/title]"
|
|
386
403
|
)
|
|
387
404
|
|
|
388
405
|
# Display each table's schema
|
|
389
406
|
for table_name, table_info in data.items():
|
|
390
|
-
self.console.print(f"\n[
|
|
407
|
+
self.console.print(f"\n[heading]Table: {table_name}[/heading]")
|
|
391
408
|
|
|
392
409
|
# Show columns
|
|
393
410
|
table_columns = table_info.get("columns", {})
|
|
394
411
|
if table_columns:
|
|
395
412
|
# Create a table for columns
|
|
396
413
|
columns = [
|
|
397
|
-
{"name": "Column Name", "style": "
|
|
398
|
-
{"name": "Type", "style": "
|
|
399
|
-
{"name": "Nullable", "style": "
|
|
400
|
-
{"name": "Default", "style": "
|
|
414
|
+
{"name": "Column Name", "style": "column.name"},
|
|
415
|
+
{"name": "Type", "style": "column.type"},
|
|
416
|
+
{"name": "Nullable", "style": "info"},
|
|
417
|
+
{"name": "Default", "style": "muted"},
|
|
401
418
|
]
|
|
402
419
|
col_table = self._create_table(columns, title="Columns")
|
|
403
420
|
|
|
@@ -418,20 +435,20 @@ class DisplayManager:
|
|
|
418
435
|
primary_keys = table_info.get("primary_keys", [])
|
|
419
436
|
if primary_keys:
|
|
420
437
|
self.console.print(
|
|
421
|
-
f"[
|
|
438
|
+
f"[key.primary]Primary Keys:[/key.primary] {', '.join(primary_keys)}"
|
|
422
439
|
)
|
|
423
440
|
|
|
424
441
|
# Show foreign keys
|
|
425
442
|
foreign_keys = table_info.get("foreign_keys", [])
|
|
426
443
|
if foreign_keys:
|
|
427
|
-
self.console.print("[
|
|
444
|
+
self.console.print("[key.foreign]Foreign Keys:[/key.foreign]")
|
|
428
445
|
for fk in foreign_keys:
|
|
429
446
|
self.console.print(f" • {fk}")
|
|
430
447
|
|
|
431
448
|
# Show indexes
|
|
432
449
|
indexes = table_info.get("indexes", [])
|
|
433
450
|
if indexes:
|
|
434
|
-
self.console.print("[
|
|
451
|
+
self.console.print("[key.index]Indexes:[/key.index]")
|
|
435
452
|
for idx in indexes:
|
|
436
453
|
self.console.print(f" • {idx}")
|
|
437
454
|
|
|
@@ -457,7 +474,9 @@ class DisplayManager:
|
|
|
457
474
|
full_text = "".join(text_parts).strip()
|
|
458
475
|
if full_text:
|
|
459
476
|
self.console.print() # Add spacing before panel
|
|
460
|
-
markdown = Markdown(full_text)
|
|
461
|
-
panel = Panel.fit(
|
|
477
|
+
markdown = Markdown(full_text, code_theme=self.tm.pygments_style_name)
|
|
478
|
+
panel = Panel.fit(
|
|
479
|
+
markdown, border_style=self.tm.style("panel.border.assistant")
|
|
480
|
+
)
|
|
462
481
|
self.console.print(panel)
|
|
463
482
|
self.console.print() # Add spacing after panel
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -9,7 +9,6 @@ import platformdirs
|
|
|
9
9
|
from prompt_toolkit import PromptSession
|
|
10
10
|
from prompt_toolkit.history import FileHistory
|
|
11
11
|
from prompt_toolkit.patch_stdout import patch_stdout
|
|
12
|
-
from prompt_toolkit.styles import Style
|
|
13
12
|
from rich.console import Console
|
|
14
13
|
from rich.markdown import Markdown
|
|
15
14
|
from rich.panel import Panel
|
|
@@ -29,6 +28,7 @@ from sqlsaber.database import (
|
|
|
29
28
|
SQLiteConnection,
|
|
30
29
|
)
|
|
31
30
|
from sqlsaber.database.schema import SchemaManager
|
|
31
|
+
from sqlsaber.theme.manager import get_theme_manager
|
|
32
32
|
from sqlsaber.threads import ThreadStorage
|
|
33
33
|
|
|
34
34
|
if TYPE_CHECKING:
|
|
@@ -61,6 +61,7 @@ class InteractiveSession:
|
|
|
61
61
|
self.cancellation_token: asyncio.Event | None = None
|
|
62
62
|
self.table_completer = TableNameCompleter()
|
|
63
63
|
self.message_history: list | None = initial_history or []
|
|
64
|
+
self.tm = get_theme_manager()
|
|
64
65
|
# Conversation Thread persistence
|
|
65
66
|
self._threads = ThreadStorage()
|
|
66
67
|
self._thread_id: str | None = initial_thread_id
|
|
@@ -72,15 +73,6 @@ class InteractiveSession:
|
|
|
72
73
|
history_dir.mkdir(parents=True, exist_ok=True)
|
|
73
74
|
return history_dir / "history"
|
|
74
75
|
|
|
75
|
-
def _prompt_style(self) -> Style:
|
|
76
|
-
"""Get the prompt style configuration."""
|
|
77
|
-
return Style.from_dict(
|
|
78
|
-
{
|
|
79
|
-
"frame.border": "gray",
|
|
80
|
-
"bottom-toolbar": "white bg:#21202e",
|
|
81
|
-
}
|
|
82
|
-
)
|
|
83
|
-
|
|
84
76
|
def _bottom_toolbar(self):
|
|
85
77
|
"""Get the bottom toolbar text."""
|
|
86
78
|
return [
|
|
@@ -136,38 +128,38 @@ class InteractiveSession:
|
|
|
136
128
|
|
|
137
129
|
db_name = self.database_name or "Unknown"
|
|
138
130
|
self.console.print(
|
|
139
|
-
f"[
|
|
131
|
+
f"[heading]\n\nConnected to:[/heading] {db_name} ({self._db_type_name()})\n"
|
|
140
132
|
)
|
|
141
133
|
|
|
142
134
|
if self._thread_id:
|
|
143
|
-
self.console.print(f"[
|
|
135
|
+
self.console.print(f"[muted]Resuming thread:[/muted] {self._thread_id}\n")
|
|
144
136
|
|
|
145
137
|
async def _end_thread(self):
|
|
146
138
|
"""End thread and display resume hint."""
|
|
147
139
|
if self._thread_id:
|
|
148
140
|
await self._threads.end_thread(self._thread_id)
|
|
149
141
|
self.console.print(
|
|
150
|
-
f"[
|
|
142
|
+
f"[muted]You can continue this thread using:[/muted] {self._resume_hint(self._thread_id)}"
|
|
151
143
|
)
|
|
152
144
|
|
|
153
145
|
async def _handle_memory(self, content: str):
|
|
154
146
|
"""Handle memory addition command."""
|
|
155
147
|
if not content:
|
|
156
|
-
self.console.print("[
|
|
148
|
+
self.console.print("[warning]Empty memory content after '#'[/warning]\n")
|
|
157
149
|
return
|
|
158
150
|
|
|
159
151
|
try:
|
|
160
152
|
mm = self.sqlsaber_agent.memory_manager
|
|
161
153
|
if mm and self.database_name:
|
|
162
154
|
memory = mm.add_memory(self.database_name, content)
|
|
163
|
-
self.console.print(f"[
|
|
164
|
-
self.console.print(f"[
|
|
155
|
+
self.console.print(f"[success]✓ Memory added:[/success] {content}")
|
|
156
|
+
self.console.print(f"[muted]Memory ID: {memory.id}[/muted]\n")
|
|
165
157
|
else:
|
|
166
158
|
self.console.print(
|
|
167
|
-
"[
|
|
159
|
+
"[warning]Could not add memory (no database context)[/warning]\n"
|
|
168
160
|
)
|
|
169
161
|
except Exception as exc:
|
|
170
|
-
self.console.print(f"[
|
|
162
|
+
self.console.print(f"[warning]Could not add memory:[/warning] {exc}\n")
|
|
171
163
|
|
|
172
164
|
async def _cmd_clear(self):
|
|
173
165
|
"""Clear conversation history."""
|
|
@@ -177,19 +169,19 @@ class InteractiveSession:
|
|
|
177
169
|
await self._threads.end_thread(self._thread_id)
|
|
178
170
|
except Exception:
|
|
179
171
|
pass
|
|
180
|
-
self.console.print("[
|
|
172
|
+
self.console.print("[success]Conversation history cleared.[/success]\n")
|
|
181
173
|
self._thread_id = None
|
|
182
174
|
self.first_message = True
|
|
183
175
|
|
|
184
176
|
async def _cmd_thinking_on(self):
|
|
185
177
|
"""Enable thinking mode."""
|
|
186
178
|
self.sqlsaber_agent.set_thinking(enabled=True)
|
|
187
|
-
self.console.print("[
|
|
179
|
+
self.console.print("[success]✓ Thinking enabled[/success]\n")
|
|
188
180
|
|
|
189
181
|
async def _cmd_thinking_off(self):
|
|
190
182
|
"""Disable thinking mode."""
|
|
191
183
|
self.sqlsaber_agent.set_thinking(enabled=False)
|
|
192
|
-
self.console.print("[
|
|
184
|
+
self.console.print("[success]✓ Thinking disabled[/success]\n")
|
|
193
185
|
|
|
194
186
|
async def _handle_command(self, user_query: str) -> bool:
|
|
195
187
|
"""Handle slash commands. Returns True if command was handled."""
|
|
@@ -297,14 +289,13 @@ class InteractiveSession:
|
|
|
297
289
|
try:
|
|
298
290
|
with patch_stdout():
|
|
299
291
|
user_query = await session.prompt_async(
|
|
300
|
-
"",
|
|
292
|
+
"> ",
|
|
301
293
|
multiline=True,
|
|
302
294
|
completer=CompositeCompleter(
|
|
303
295
|
SlashCommandCompleter(), self.table_completer
|
|
304
296
|
),
|
|
305
|
-
show_frame=True,
|
|
306
297
|
bottom_toolbar=self._bottom_toolbar,
|
|
307
|
-
style=self.
|
|
298
|
+
style=self.tm.pt_style(),
|
|
308
299
|
)
|
|
309
300
|
|
|
310
301
|
if not user_query:
|
|
@@ -340,14 +331,14 @@ class InteractiveSession:
|
|
|
340
331
|
await self.current_task
|
|
341
332
|
except asyncio.CancelledError:
|
|
342
333
|
pass
|
|
343
|
-
self.console.print("\n[
|
|
334
|
+
self.console.print("\n[warning]Query interrupted[/warning]")
|
|
344
335
|
else:
|
|
345
336
|
self.console.print(
|
|
346
|
-
"\n[
|
|
337
|
+
"\n[warning]Press Ctrl+D to exit. Or use '/exit' or '/quit' slash command.[/warning]"
|
|
347
338
|
)
|
|
348
339
|
except EOFError:
|
|
349
340
|
# Exit when Ctrl+D is pressed
|
|
350
341
|
await self._end_thread()
|
|
351
342
|
break
|
|
352
343
|
except Exception as exc:
|
|
353
|
-
self.console.print(f"[
|
|
344
|
+
self.console.print(f"[error]Error:[/error] {exc}")
|
sqlsaber/cli/memory.py
CHANGED
|
@@ -5,14 +5,14 @@ from typing import Annotated
|
|
|
5
5
|
|
|
6
6
|
import cyclopts
|
|
7
7
|
import questionary
|
|
8
|
-
from rich.console import Console
|
|
9
8
|
from rich.table import Table
|
|
10
9
|
|
|
11
10
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
12
11
|
from sqlsaber.memory.manager import MemoryManager
|
|
12
|
+
from sqlsaber.theme.manager import create_console
|
|
13
13
|
|
|
14
14
|
# Global instances for CLI commands
|
|
15
|
-
console =
|
|
15
|
+
console = create_console()
|
|
16
16
|
config_manager = DatabaseConfigManager()
|
|
17
17
|
memory_manager = MemoryManager()
|
|
18
18
|
|
sqlsaber/cli/models.py
CHANGED
|
@@ -6,14 +6,14 @@ import sys
|
|
|
6
6
|
import cyclopts
|
|
7
7
|
import httpx
|
|
8
8
|
import questionary
|
|
9
|
-
from rich.console import Console
|
|
10
9
|
from rich.table import Table
|
|
11
10
|
|
|
12
11
|
from sqlsaber.config import providers
|
|
13
12
|
from sqlsaber.config.settings import Config
|
|
13
|
+
from sqlsaber.theme.manager import create_console
|
|
14
14
|
|
|
15
15
|
# Global instances for CLI commands
|
|
16
|
-
console =
|
|
16
|
+
console = create_console()
|
|
17
17
|
|
|
18
18
|
# Create the model management CLI app
|
|
19
19
|
models_app = cyclopts.App(
|
sqlsaber/cli/onboarding.py
CHANGED
|
@@ -2,15 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
import sys
|
|
4
4
|
|
|
5
|
-
from rich.console import Console
|
|
6
5
|
from rich.panel import Panel
|
|
7
6
|
|
|
8
7
|
from sqlsaber.cli.models import ModelManager
|
|
9
8
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
10
9
|
from sqlsaber.config.auth import AuthConfigManager
|
|
11
10
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
11
|
+
from sqlsaber.theme.manager import create_console
|
|
12
12
|
|
|
13
|
-
console =
|
|
13
|
+
console = create_console()
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def needs_onboarding(database_arg: str | None = None) -> bool:
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -170,7 +170,7 @@ class StreamingQueryHandler:
|
|
|
170
170
|
except asyncio.CancelledError:
|
|
171
171
|
# Show interruption message outside of Live
|
|
172
172
|
self.display.show_newline()
|
|
173
|
-
self.console.print("[
|
|
173
|
+
self.console.print("[warning]Query interrupted[/warning]")
|
|
174
174
|
return None
|
|
175
175
|
finally:
|
|
176
176
|
# End any active status and live markdown segments
|