sqlsaber 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber/application/auth_setup.py +2 -2
- sqlsaber/application/db_setup.py +2 -3
- sqlsaber/application/model_selection.py +2 -2
- sqlsaber/cli/auth.py +2 -2
- sqlsaber/cli/commands.py +2 -2
- sqlsaber/cli/database.py +2 -2
- sqlsaber/cli/display.py +59 -40
- sqlsaber/cli/interactive.py +18 -27
- sqlsaber/cli/memory.py +2 -2
- sqlsaber/cli/models.py +2 -2
- sqlsaber/cli/onboarding.py +2 -2
- sqlsaber/cli/streaming.py +1 -1
- sqlsaber/cli/threads.py +35 -16
- sqlsaber/config/api_keys.py +2 -2
- sqlsaber/config/oauth_flow.py +3 -2
- sqlsaber/config/oauth_tokens.py +3 -5
- sqlsaber/database/base.py +6 -0
- sqlsaber/database/csv.py +5 -0
- sqlsaber/database/duckdb.py +5 -0
- sqlsaber/database/mysql.py +5 -0
- sqlsaber/database/postgresql.py +5 -0
- sqlsaber/database/sqlite.py +5 -0
- sqlsaber/theme/__init__.py +5 -0
- sqlsaber/theme/manager.py +219 -0
- sqlsaber/tools/sql_guard.py +225 -0
- sqlsaber/tools/sql_tools.py +10 -35
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/METADATA +2 -1
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/RECORD +31 -28
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.27.0.dist-info → sqlsaber-0.28.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/threads.py
CHANGED
|
@@ -12,10 +12,12 @@ from rich.markdown import Markdown
|
|
|
12
12
|
from rich.panel import Panel
|
|
13
13
|
from rich.table import Table
|
|
14
14
|
|
|
15
|
+
from sqlsaber.theme.manager import create_console, get_theme_manager
|
|
15
16
|
from sqlsaber.threads import ThreadStorage
|
|
16
17
|
|
|
17
18
|
# Globals consistent with other CLI modules
|
|
18
|
-
console =
|
|
19
|
+
console = create_console()
|
|
20
|
+
tm = get_theme_manager()
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
threads_app = cyclopts.App(
|
|
@@ -84,13 +86,23 @@ def _render_transcript(
|
|
|
84
86
|
console.print(f"**User:**\n\n{text}\n")
|
|
85
87
|
else:
|
|
86
88
|
console.print(
|
|
87
|
-
Panel.fit(
|
|
89
|
+
Panel.fit(
|
|
90
|
+
Markdown(text, code_theme=tm.pygments_style_name),
|
|
91
|
+
title="User",
|
|
92
|
+
border_style=tm.style("panel.border.user"),
|
|
93
|
+
)
|
|
88
94
|
)
|
|
89
95
|
return
|
|
90
96
|
if is_redirected:
|
|
91
97
|
console.print("**User:** (no content)\n")
|
|
92
98
|
else:
|
|
93
|
-
console.print(
|
|
99
|
+
console.print(
|
|
100
|
+
Panel.fit(
|
|
101
|
+
"(no content)",
|
|
102
|
+
title="User",
|
|
103
|
+
border_style=tm.style("panel.border.user"),
|
|
104
|
+
)
|
|
105
|
+
)
|
|
94
106
|
|
|
95
107
|
def _render_response(message: ModelMessage) -> None:
|
|
96
108
|
for part in getattr(message, "parts", []):
|
|
@@ -103,7 +115,9 @@ def _render_transcript(
|
|
|
103
115
|
else:
|
|
104
116
|
console.print(
|
|
105
117
|
Panel.fit(
|
|
106
|
-
Markdown(text
|
|
118
|
+
Markdown(text, code_theme=tm.pygments_style_name),
|
|
119
|
+
title="Assistant",
|
|
120
|
+
border_style=tm.style("panel.border.assistant"),
|
|
107
121
|
)
|
|
108
122
|
)
|
|
109
123
|
elif kind in ("tool-call", "builtin-tool-call"):
|
|
@@ -211,11 +225,11 @@ def list_threads(
|
|
|
211
225
|
console.print("No threads found.")
|
|
212
226
|
return
|
|
213
227
|
table = Table(title="Threads")
|
|
214
|
-
table.add_column("ID", style="
|
|
215
|
-
table.add_column("Database", style="
|
|
216
|
-
table.add_column("Title", style="
|
|
217
|
-
table.add_column("Last Activity", style="
|
|
218
|
-
table.add_column("Model", style="
|
|
228
|
+
table.add_column("ID", style=tm.style("info"))
|
|
229
|
+
table.add_column("Database", style=tm.style("accent"))
|
|
230
|
+
table.add_column("Title", style=tm.style("success"))
|
|
231
|
+
table.add_column("Last Activity", style=tm.style("muted"))
|
|
232
|
+
table.add_column("Model", style=tm.style("warning"))
|
|
219
233
|
for t in threads:
|
|
220
234
|
table.add_row(
|
|
221
235
|
t.id,
|
|
@@ -235,7 +249,7 @@ def show(
|
|
|
235
249
|
store = ThreadStorage()
|
|
236
250
|
thread = asyncio.run(store.get_thread(thread_id))
|
|
237
251
|
if not thread:
|
|
238
|
-
console.print(f"[
|
|
252
|
+
console.print(f"[error]Thread not found:[/error] {thread_id}")
|
|
239
253
|
return
|
|
240
254
|
msgs = asyncio.run(store.get_thread_messages(thread_id))
|
|
241
255
|
console.print(f"[bold]Thread: {thread.id}[/bold]")
|
|
@@ -273,12 +287,12 @@ def resume(
|
|
|
273
287
|
|
|
274
288
|
thread = await store.get_thread(thread_id)
|
|
275
289
|
if not thread:
|
|
276
|
-
console.print(f"[
|
|
290
|
+
console.print(f"[error]Thread not found:[/error] {thread_id}")
|
|
277
291
|
return
|
|
278
292
|
db_selector = database or thread.database_name
|
|
279
293
|
if not db_selector:
|
|
280
294
|
console.print(
|
|
281
|
-
"[
|
|
295
|
+
"[error]No database specified or stored with this thread.[/error]"
|
|
282
296
|
)
|
|
283
297
|
return
|
|
284
298
|
try:
|
|
@@ -287,7 +301,7 @@ def resume(
|
|
|
287
301
|
connection_string = resolved.connection_string
|
|
288
302
|
db_name = resolved.name
|
|
289
303
|
except DatabaseResolutionError as e:
|
|
290
|
-
console.print(f"[
|
|
304
|
+
console.print(f"[error]Database resolution error:[/error] {e}")
|
|
291
305
|
return
|
|
292
306
|
|
|
293
307
|
db_conn = DatabaseConnection(connection_string)
|
|
@@ -295,7 +309,12 @@ def resume(
|
|
|
295
309
|
sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
|
|
296
310
|
history = await store.get_thread_messages(thread_id)
|
|
297
311
|
if console.is_terminal:
|
|
298
|
-
console.print(
|
|
312
|
+
console.print(
|
|
313
|
+
Panel.fit(
|
|
314
|
+
f"Thread: {thread.id}",
|
|
315
|
+
border_style=tm.style("panel.border.thread"),
|
|
316
|
+
)
|
|
317
|
+
)
|
|
299
318
|
else:
|
|
300
319
|
console.print(f"# Thread: {thread.id}\n")
|
|
301
320
|
_render_transcript(console, history, None)
|
|
@@ -310,7 +329,7 @@ def resume(
|
|
|
310
329
|
await session.run()
|
|
311
330
|
finally:
|
|
312
331
|
await db_conn.close()
|
|
313
|
-
console.print("\n[
|
|
332
|
+
console.print("\n[success]Goodbye![/success]")
|
|
314
333
|
|
|
315
334
|
asyncio.run(_run())
|
|
316
335
|
|
|
@@ -329,7 +348,7 @@ def prune(
|
|
|
329
348
|
|
|
330
349
|
async def _run() -> None:
|
|
331
350
|
deleted = await store.prune_threads(older_than_days=days)
|
|
332
|
-
console.print(f"[
|
|
351
|
+
console.print(f"[success]✓ Pruned {deleted} thread(s).[/success]")
|
|
333
352
|
|
|
334
353
|
asyncio.run(_run())
|
|
335
354
|
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -4,11 +4,11 @@ import getpass
|
|
|
4
4
|
import os
|
|
5
5
|
|
|
6
6
|
import keyring
|
|
7
|
-
from rich.console import Console
|
|
8
7
|
|
|
9
8
|
from sqlsaber.config import providers
|
|
9
|
+
from sqlsaber.theme.manager import create_console
|
|
10
10
|
|
|
11
|
-
console =
|
|
11
|
+
console = create_console()
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class APIKeyManager:
|
sqlsaber/config/oauth_flow.py
CHANGED
|
@@ -10,12 +10,13 @@ from datetime import datetime, timezone
|
|
|
10
10
|
|
|
11
11
|
import httpx
|
|
12
12
|
import questionary
|
|
13
|
-
from rich.console import Console
|
|
14
13
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
15
14
|
|
|
15
|
+
from sqlsaber.theme.manager import create_console
|
|
16
|
+
|
|
16
17
|
from .oauth_tokens import OAuthToken, OAuthTokenManager
|
|
17
18
|
|
|
18
|
-
console =
|
|
19
|
+
console = create_console()
|
|
19
20
|
logger = logging.getLogger(__name__)
|
|
20
21
|
|
|
21
22
|
|
sqlsaber/config/oauth_tokens.py
CHANGED
|
@@ -6,9 +6,10 @@ from datetime import datetime, timedelta, timezone
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
import keyring
|
|
9
|
-
from rich.console import Console
|
|
10
9
|
|
|
11
|
-
|
|
10
|
+
from sqlsaber.theme.manager import create_console
|
|
11
|
+
|
|
12
|
+
console = create_console()
|
|
12
13
|
logger = logging.getLogger(__name__)
|
|
13
14
|
|
|
14
15
|
|
|
@@ -158,9 +159,6 @@ class OAuthTokenManager:
|
|
|
158
159
|
keyring.delete_password(service_name, provider)
|
|
159
160
|
console.print(f"OAuth token for {provider} removed", style="green")
|
|
160
161
|
return True
|
|
161
|
-
except keyring.errors.PasswordDeleteError:
|
|
162
|
-
# Token doesn't exist
|
|
163
|
-
return True
|
|
164
162
|
except Exception as e:
|
|
165
163
|
logger.error(f"Failed to remove OAuth token for {provider}: {e}")
|
|
166
164
|
console.print(f"Warning: Could not remove OAuth token: {e}", style="yellow")
|
sqlsaber/database/base.py
CHANGED
|
@@ -61,6 +61,12 @@ class BaseDatabaseConnection(ABC):
|
|
|
61
61
|
self.connection_string = connection_string
|
|
62
62
|
self._pool = None
|
|
63
63
|
|
|
64
|
+
@property
@abstractmethod
def sqlglot_dialect(self) -> str:
    """Name of the sqlglot dialect that matches this database backend.

    Every concrete connection class must override this property.
    """
|
|
69
|
+
|
|
64
70
|
@abstractmethod
|
|
65
71
|
async def get_pool(self):
|
|
66
72
|
"""Get or create connection pool."""
|
sqlsaber/database/csv.py
CHANGED
|
@@ -58,6 +58,11 @@ class CSVConnection(BaseDatabaseConnection):
|
|
|
58
58
|
|
|
59
59
|
self.table_name = Path(self.csv_path).stem or "csv_table"
|
|
60
60
|
|
|
61
|
+
@property
def sqlglot_dialect(self) -> str:
    """sqlglot dialect name for CSV sources, which use the DuckDB dialect."""
    return "duckdb"
|
|
65
|
+
|
|
61
66
|
async def get_pool(self):
|
|
62
67
|
"""CSV connections do not maintain a pool."""
|
|
63
68
|
return None
|
sqlsaber/database/duckdb.py
CHANGED
|
@@ -52,6 +52,11 @@ class DuckDBConnection(BaseDatabaseConnection):
|
|
|
52
52
|
|
|
53
53
|
self.database_path = db_path or ":memory:"
|
|
54
54
|
|
|
55
|
+
@property
def sqlglot_dialect(self) -> str:
    """sqlglot dialect name for DuckDB connections."""
    return "duckdb"
|
|
59
|
+
|
|
55
60
|
async def get_pool(self):
|
|
56
61
|
"""DuckDB creates connections per query, return database path."""
|
|
57
62
|
return self.database_path
|
sqlsaber/database/mysql.py
CHANGED
|
@@ -23,6 +23,11 @@ class MySQLConnection(BaseDatabaseConnection):
|
|
|
23
23
|
self._pool: aiomysql.Pool | None = None
|
|
24
24
|
self._parse_connection_string()
|
|
25
25
|
|
|
26
|
+
@property
def sqlglot_dialect(self) -> str:
    """sqlglot dialect name for MySQL connections."""
    return "mysql"
|
|
30
|
+
|
|
26
31
|
def _parse_connection_string(self):
|
|
27
32
|
"""Parse MySQL connection string into components."""
|
|
28
33
|
parsed = urlparse(self.connection_string)
|
sqlsaber/database/postgresql.py
CHANGED
|
@@ -23,6 +23,11 @@ class PostgreSQLConnection(BaseDatabaseConnection):
|
|
|
23
23
|
self._pool: asyncpg.Pool | None = None
|
|
24
24
|
self._ssl_context = self._create_ssl_context()
|
|
25
25
|
|
|
26
|
+
@property
def sqlglot_dialect(self) -> str:
    """sqlglot dialect name for PostgreSQL connections (sqlglot calls it "postgres")."""
    return "postgres"
|
|
30
|
+
|
|
26
31
|
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
|
27
32
|
"""Create SSL context from connection string parameters."""
|
|
28
33
|
parsed = urlparse(self.connection_string)
|
sqlsaber/database/sqlite.py
CHANGED
|
@@ -21,6 +21,11 @@ class SQLiteConnection(BaseDatabaseConnection):
|
|
|
21
21
|
# Extract database path from sqlite:///path format
|
|
22
22
|
self.database_path = connection_string.replace("sqlite:///", "")
|
|
23
23
|
|
|
24
|
+
@property
def sqlglot_dialect(self) -> str:
    """sqlglot dialect name for SQLite connections."""
    return "sqlite"
|
|
28
|
+
|
|
24
29
|
async def get_pool(self):
|
|
25
30
|
"""SQLite doesn't use connection pooling, return database path."""
|
|
26
31
|
return self.database_path
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Theme management for unified theming across Rich and prompt_toolkit."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import tomllib
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from typing import Dict
|
|
8
|
+
|
|
9
|
+
from platformdirs import user_config_dir
|
|
10
|
+
from prompt_toolkit.styles import Style as PTStyle
|
|
11
|
+
from prompt_toolkit.styles.pygments import style_from_pygments_cls
|
|
12
|
+
from pygments.styles import get_style_by_name
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.theme import Theme
|
|
15
|
+
|
|
16
|
+
# Theme used when neither the SQLSABER_THEME environment variable nor the
# [theme] table of theme.toml supplies a name (see get_theme_manager).
DEFAULT_THEME_NAME = "nord"

# Baseline mapping of semantic role name -> Rich style string. Every theme
# starts from this palette; the selected preset and the user's [roles] table
# are layered on top of it. Values may reference other roles as "$role";
# _resolve_refs expands those before the Rich Theme is built, and the role
# names become the style names usable in Rich markup (e.g. "[error]...").
DEFAULT_ROLE_PALETTE = {
    # base roles
    "primary": "cyan",
    "accent": "magenta",
    "success": "green",
    "warning": "yellow",
    "error": "red",
    "info": "cyan",
    "muted": "dim",
    # components
    "table.header": "bold $primary",
    "panel.border.user": "$info",
    "panel.border.assistant": "$success",
    "panel.border.thread": "$primary",
    "spinner": "$warning",
    "status": "$warning",
    # domain-specific
    "key.primary": "bold $warning",
    "key.foreign": "bold $accent",
    "key.index": "bold $primary",
    "column.schema": "$info",
    # NOTE(review): a fixed "white" may be hard to read on the light presets
    # (solarized-light, vs); consider deriving it from a role instead.
    "column.name": "white",
    "column.type": "$warning",
    "heading": "bold $primary",
    "section": "bold $accent",
    "title": "bold $success",
}

# Per-theme overrides of the base roles, keyed by lower-case theme name
# (get_theme_manager lower-cases the name before looking it up here). The key
# also serves as the default Pygments style name. Most colors are taken from
# the Pygments style of the same name; entries marked below as a fallback or
# approximation are not.
THEME_PRESETS = {
    # Nord - exact colors from pygments nord theme
    "nord": {
        "primary": "#81a1c1",  # Keyword (frost)
        "accent": "#b48ead",  # Number (aurora purple)
        "success": "#a3be8c",  # String (aurora green)
        "warning": "#ebcb8b",  # String.Escape (aurora yellow)
        "error": "#bf616a",  # Error/Generic.Error (aurora red)
        "info": "#88c0d0",  # Name.Function (frost cyan)
        "muted": "dim",
    },
    # Dracula - exact colors from pygments dracula theme
    "dracula": {
        "primary": "#bd93f9",  # purple
        "accent": "#ff79c6",  # pink
        "success": "#50fa7b",  # green
        "warning": "#f1fa8c",  # yellow
        "error": "#ff5555",  # red
        "info": "#8be9fd",  # cyan
        "muted": "dim",
    },
    # Solarized Light - exact colors from pygments solarized-light theme
    "solarized-light": {
        "primary": "#268bd2",  # blue
        "accent": "#d33682",  # magenta
        "success": "#859900",  # green
        "warning": "#b58900",  # yellow
        "error": "#dc322f",  # red
        "info": "#2aa198",  # cyan
        "muted": "dim",
    },
    # VS (Visual Studio Light) - colors from pygments vs theme where it has
    # them; warning/error/info borrow Solarized colors as fallbacks
    "vs": {
        "primary": "#0000ff",  # Keyword (blue)
        "accent": "#2b91af",  # Keyword.Type/Name.Class
        "success": "#008000",  # Comment (green)
        "warning": "#b58900",  # (using solarized yellow as fallback)
        "error": "#dc322f",  # (using solarized red as fallback)
        "info": "#2aa198",  # (using solarized cyan as fallback)
        "muted": "dim",
    },
    # Material (approximation based on material design colors)
    "material": {
        "primary": "#89ddff",  # cyan
        "accent": "#f07178",  # pink/red
        "success": "#c3e88d",  # green
        "warning": "#ffcb6b",  # yellow
        "error": "#ff5370",  # red
        "info": "#82aaff",  # blue
        "muted": "dim",
    },
    # One Dark - exact colors from pygments one-dark theme
    "one-dark": {
        "primary": "#c678dd",  # Keyword (purple)
        "accent": "#e06c75",  # Name (red)
        "success": "#98c379",  # String (green)
        "warning": "#e5c07b",  # Keyword.Type (yellow)
        "error": "#e06c75",  # Name (red, used for errors)
        "info": "#61afef",  # Name.Function (blue)
        "muted": "dim",
    },
    # Lightbulb - exact colors from pygments lightbulb theme (minimal dark)
    "lightbulb": {
        "primary": "#73d0ff",  # Keyword.Type/Name.Class (blue_1)
        "accent": "#dfbfff",  # Number (magenta_1)
        "success": "#d5ff80",  # String (green_1)
        "warning": "#ffd173",  # Name.Function (yellow_1)
        "error": "#f88f7f",  # Error (red_1)
        "info": "#95e6cb",  # Name.Entity (cyan_1)
        "muted": "dim",
    },
}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _load_user_theme_config() -> dict:
|
|
122
|
+
"""Load theme configuration from user config directory."""
|
|
123
|
+
cfg_dir = user_config_dir("sqlsaber")
|
|
124
|
+
path = os.path.join(cfg_dir, "theme.toml")
|
|
125
|
+
if not os.path.exists(path):
|
|
126
|
+
return {}
|
|
127
|
+
with open(path, "rb") as f:
|
|
128
|
+
return tomllib.load(f)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _resolve_refs(palette: dict[str, str]) -> dict[str, str]:
    """Expand ``$role`` references in palette values.

    Each whitespace-separated token of the form ``$name`` is replaced by the
    recursively expanded value of ``palette[name]``; other tokens are kept.
    References to unknown roles, to non-string values, or that would form a
    cycle (e.g. ``{"a": "$b", "b": "$a"}``) expand to nothing, so no ``$name``
    token survives for Rich to choke on. Values without ``$`` and non-string
    values are returned unchanged.

    Args:
        palette: Role name -> Rich style string, possibly with references.

    Returns:
        A new dict with the same keys and all references expanded.
    """

    def expand(value: str, seen: frozenset[str]) -> str:
        # `seen` holds the roles on the current expansion path (cycle guard).
        if "$" not in value:
            return value
        tokens: list[str] = []
        for token in value.split():
            if not token.startswith("$"):
                tokens.append(token)
                continue
            ref = token[1:]
            target = palette.get(ref)
            if ref in seen or not isinstance(target, str):
                continue
            tokens.append(expand(target, seen | {ref}))
        return " ".join(t for t in tokens if t)

    return {
        key: expand(value, frozenset({key})) if isinstance(value, str) else value
        for key, value in palette.items()
    }
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@dataclass(eq=True, frozen=True)
class ThemeConfig:
    """Immutable description of the selected theme.

    Attributes:
        name: Theme name (lower-cased when built by get_theme_manager).
        pygments_style: Pygments style name used for syntax highlighting.
        roles: Semantic role name -> Rich style string; values may still
            contain unexpanded ``$role`` references.
    """

    name: str  # theme identifier
    pygments_style: str  # Pygments style for code blocks / prompt_toolkit
    roles: Dict[str, str]  # role -> style, before $-reference expansion
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class ThemeManager:
    """Manages theme configuration and provides themed components.

    Role styles are merged over DEFAULT_ROLE_PALETTE, ``$role`` references are
    expanded, and every style is validated before the Rich theme is built, so
    a bad entry in theme.toml degrades to the default style for that role
    instead of raising while CLI modules create their consoles.
    """

    def __init__(self, cfg: ThemeConfig):
        self._cfg = cfg
        merged = _resolve_refs({**DEFAULT_ROLE_PALETTE, **cfg.roles})
        self._roles = self._validated_roles(merged)
        self._rich_theme = Theme(self._roles)
        self._pt_style = None  # built lazily by pt_style()

    @staticmethod
    def _validated_roles(roles: dict) -> dict[str, str]:
        """Return *roles* with every invalid Rich style replaced.

        An entry that is not a string, or that ``rich.style.Style.parse``
        rejects, is replaced by the resolved DEFAULT_ROLE_PALETTE value for the
        same role (or dropped when the role has no default), and a warning is
        logged.
        """
        import logging

        from rich.errors import StyleSyntaxError
        from rich.style import Style

        logger = logging.getLogger(__name__)
        defaults = _resolve_refs(DEFAULT_ROLE_PALETTE)
        valid: dict[str, str] = {}
        for role, value in roles.items():
            problem = None
            if not isinstance(value, str):
                problem = f"expected a string, got {type(value).__name__}"
            else:
                try:
                    Style.parse(value)
                except StyleSyntaxError as exc:
                    problem = str(exc)
            if problem is None:
                valid[role] = value
                continue
            logger.warning("Invalid style for theme role %r: %s", role, problem)
            if role in defaults:
                valid[role] = defaults[role]
        return valid

    @property
    def rich_theme(self) -> Theme:
        """Get Rich theme with semantic role mappings."""
        return self._rich_theme

    @property
    def pygments_style_name(self) -> str:
        """Get pygments style name for syntax highlighting."""
        return self._cfg.pygments_style

    def pt_style(self) -> PTStyle:
        """Get prompt_toolkit style derived from the Pygments theme.

        Built on first call and cached. If the configured Pygments style name
        is unknown, an empty prompt_toolkit style is used instead.
        """
        if self._pt_style is None:
            from pygments.util import ClassNotFound

            try:
                pygments_cls = get_style_by_name(self._cfg.pygments_style)
            except ClassNotFound:
                # Fallback to basic style if Pygments theme not found
                self._pt_style = PTStyle.from_dict({})
            else:
                self._pt_style = style_from_pygments_cls(pygments_cls)
        return self._pt_style

    def style(self, role: str) -> str:
        """Get style string for a semantic role ("" if the role is unknown)."""
        return self._roles.get(role, "")
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@lru_cache(maxsize=1)
def get_theme_manager() -> ThemeManager:
    """Build (once) and return the global ThemeManager.

    The theme name comes, in priority order, from the SQLSABER_THEME
    environment variable, ``[theme].name`` in theme.toml, and finally
    DEFAULT_THEME_NAME. Blank values are skipped, and the chosen name is
    trimmed and lower-cased before looking up THEME_PRESETS.
    ``[theme].pygments_style`` overrides the Pygments style, which otherwise
    defaults to the theme name. Role styles are layered as
    DEFAULT_ROLE_PALETTE < preset < ``[roles]`` from theme.toml. A ``theme``
    or ``roles`` entry that is not a TOML table is ignored instead of raising.

    The result is cached; call ``get_theme_manager.cache_clear()`` to pick up
    configuration changes.
    """
    user_cfg = _load_user_theme_config()
    theme_section = user_cfg.get("theme")
    if not isinstance(theme_section, dict):
        theme_section = {}
    user_roles = user_cfg.get("roles")
    if not isinstance(user_roles, dict):
        user_roles = {}

    name = DEFAULT_THEME_NAME.lower()
    for candidate in (os.getenv("SQLSABER_THEME"), theme_section.get("name")):
        # TOML may hold a non-string here; coerce so .strip()/.lower() work.
        if candidate is not None and str(candidate).strip():
            name = str(candidate).strip().lower()
            break

    pygments_style = str(theme_section.get("pygments_style") or name)

    roles = dict(DEFAULT_ROLE_PALETTE)
    roles.update(THEME_PRESETS.get(name, {}))
    roles.update(user_roles)

    cfg = ThemeConfig(name=name, pygments_style=pygments_style, roles=roles)
    return ThemeManager(cfg)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def create_console(**kwargs) -> Console:
    """Create a Rich Console with the sqlsaber theme applied.

    Args:
        **kwargs: Forwarded to ``rich.console.Console``. An explicit ``theme``
            keyword takes precedence over the sqlsaber theme.

    Returns:
        A new Console using ``get_theme_manager().rich_theme`` unless the
        caller supplied ``theme``.
    """
    kwargs.setdefault("theme", get_theme_manager().rich_theme)
    return Console(**kwargs)
|