sqlsaber 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the differences between two publicly available versions of a package that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
sqlsaber/cli/threads.py CHANGED
@@ -12,10 +12,12 @@ from rich.markdown import Markdown
12
12
  from rich.panel import Panel
13
13
  from rich.table import Table
14
14
 
15
+ from sqlsaber.theme.manager import create_console, get_theme_manager
15
16
  from sqlsaber.threads import ThreadStorage
16
17
 
17
18
  # Globals consistent with other CLI modules
18
- console = Console()
19
+ console = create_console()
20
+ tm = get_theme_manager()
19
21
 
20
22
 
21
23
  threads_app = cyclopts.App(
@@ -84,13 +86,23 @@ def _render_transcript(
84
86
  console.print(f"**User:**\n\n{text}\n")
85
87
  else:
86
88
  console.print(
87
- Panel.fit(Markdown(text), title="User", border_style="cyan")
89
+ Panel.fit(
90
+ Markdown(text, code_theme=tm.pygments_style_name),
91
+ title="User",
92
+ border_style=tm.style("panel.border.user"),
93
+ )
88
94
  )
89
95
  return
90
96
  if is_redirected:
91
97
  console.print("**User:** (no content)\n")
92
98
  else:
93
- console.print(Panel.fit("(no content)", title="User", border_style="cyan"))
99
+ console.print(
100
+ Panel.fit(
101
+ "(no content)",
102
+ title="User",
103
+ border_style=tm.style("panel.border.user"),
104
+ )
105
+ )
94
106
 
95
107
  def _render_response(message: ModelMessage) -> None:
96
108
  for part in getattr(message, "parts", []):
@@ -103,7 +115,9 @@ def _render_transcript(
103
115
  else:
104
116
  console.print(
105
117
  Panel.fit(
106
- Markdown(text), title="Assistant", border_style="green"
118
+ Markdown(text, code_theme=tm.pygments_style_name),
119
+ title="Assistant",
120
+ border_style=tm.style("panel.border.assistant"),
107
121
  )
108
122
  )
109
123
  elif kind in ("tool-call", "builtin-tool-call"):
@@ -211,11 +225,11 @@ def list_threads(
211
225
  console.print("No threads found.")
212
226
  return
213
227
  table = Table(title="Threads")
214
- table.add_column("ID", style="cyan")
215
- table.add_column("Database", style="magenta")
216
- table.add_column("Title", style="green")
217
- table.add_column("Last Activity", style="dim")
218
- table.add_column("Model", style="yellow")
228
+ table.add_column("ID", style=tm.style("info"))
229
+ table.add_column("Database", style=tm.style("accent"))
230
+ table.add_column("Title", style=tm.style("success"))
231
+ table.add_column("Last Activity", style=tm.style("muted"))
232
+ table.add_column("Model", style=tm.style("warning"))
219
233
  for t in threads:
220
234
  table.add_row(
221
235
  t.id,
@@ -235,7 +249,7 @@ def show(
235
249
  store = ThreadStorage()
236
250
  thread = asyncio.run(store.get_thread(thread_id))
237
251
  if not thread:
238
- console.print(f"[red]Thread not found:[/red] {thread_id}")
252
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
239
253
  return
240
254
  msgs = asyncio.run(store.get_thread_messages(thread_id))
241
255
  console.print(f"[bold]Thread: {thread.id}[/bold]")
@@ -273,12 +287,12 @@ def resume(
273
287
 
274
288
  thread = await store.get_thread(thread_id)
275
289
  if not thread:
276
- console.print(f"[red]Thread not found:[/red] {thread_id}")
290
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
277
291
  return
278
292
  db_selector = database or thread.database_name
279
293
  if not db_selector:
280
294
  console.print(
281
- "[red]No database specified or stored with this thread.[/red]"
295
+ "[error]No database specified or stored with this thread.[/error]"
282
296
  )
283
297
  return
284
298
  try:
@@ -287,7 +301,7 @@ def resume(
287
301
  connection_string = resolved.connection_string
288
302
  db_name = resolved.name
289
303
  except DatabaseResolutionError as e:
290
- console.print(f"[red]Database resolution error:[/red] {e}")
304
+ console.print(f"[error]Database resolution error:[/error] {e}")
291
305
  return
292
306
 
293
307
  db_conn = DatabaseConnection(connection_string)
@@ -295,7 +309,12 @@ def resume(
295
309
  sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
296
310
  history = await store.get_thread_messages(thread_id)
297
311
  if console.is_terminal:
298
- console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
312
+ console.print(
313
+ Panel.fit(
314
+ f"Thread: {thread.id}",
315
+ border_style=tm.style("panel.border.thread"),
316
+ )
317
+ )
299
318
  else:
300
319
  console.print(f"# Thread: {thread.id}\n")
301
320
  _render_transcript(console, history, None)
@@ -310,7 +329,7 @@ def resume(
310
329
  await session.run()
311
330
  finally:
312
331
  await db_conn.close()
313
- console.print("\n[green]Goodbye![/green]")
332
+ console.print("\n[success]Goodbye![/success]")
314
333
 
315
334
  asyncio.run(_run())
316
335
 
@@ -329,7 +348,7 @@ def prune(
329
348
 
330
349
  async def _run() -> None:
331
350
  deleted = await store.prune_threads(older_than_days=days)
332
- console.print(f"[green]✓ Pruned {deleted} thread(s).[/green]")
351
+ console.print(f"[success]✓ Pruned {deleted} thread(s).[/success]")
333
352
 
334
353
  asyncio.run(_run())
335
354
 
@@ -4,11 +4,11 @@ import getpass
4
4
  import os
5
5
 
6
6
  import keyring
7
- from rich.console import Console
8
7
 
9
8
  from sqlsaber.config import providers
9
+ from sqlsaber.theme.manager import create_console
10
10
 
11
- console = Console()
11
+ console = create_console()
12
12
 
13
13
 
14
14
  class APIKeyManager:
@@ -10,12 +10,13 @@ from datetime import datetime, timezone
10
10
 
11
11
  import httpx
12
12
  import questionary
13
- from rich.console import Console
14
13
  from rich.progress import Progress, SpinnerColumn, TextColumn
15
14
 
15
+ from sqlsaber.theme.manager import create_console
16
+
16
17
  from .oauth_tokens import OAuthToken, OAuthTokenManager
17
18
 
18
- console = Console()
19
+ console = create_console()
19
20
  logger = logging.getLogger(__name__)
20
21
 
21
22
 
@@ -6,9 +6,10 @@ from datetime import datetime, timedelta, timezone
6
6
  from typing import Any
7
7
 
8
8
  import keyring
9
- from rich.console import Console
10
9
 
11
- console = Console()
10
+ from sqlsaber.theme.manager import create_console
11
+
12
+ console = create_console()
12
13
  logger = logging.getLogger(__name__)
13
14
 
14
15
 
@@ -158,9 +159,6 @@ class OAuthTokenManager:
158
159
  keyring.delete_password(service_name, provider)
159
160
  console.print(f"OAuth token for {provider} removed", style="green")
160
161
  return True
161
- except keyring.errors.PasswordDeleteError:
162
- # Token doesn't exist
163
- return True
164
162
  except Exception as e:
165
163
  logger.error(f"Failed to remove OAuth token for {provider}: {e}")
166
164
  console.print(f"Warning: Could not remove OAuth token: {e}", style="yellow")
sqlsaber/database/base.py CHANGED
@@ -61,6 +61,12 @@ class BaseDatabaseConnection(ABC):
61
61
  self.connection_string = connection_string
62
62
  self._pool = None
63
63
 
64
+ @property
65
+ @abstractmethod
66
+ def sqlglot_dialect(self) -> str:
67
+ """Return the sqlglot dialect name for this database."""
68
+ pass
69
+
64
70
  @abstractmethod
65
71
  async def get_pool(self):
66
72
  """Get or create connection pool."""
sqlsaber/database/csv.py CHANGED
@@ -58,6 +58,11 @@ class CSVConnection(BaseDatabaseConnection):
58
58
 
59
59
  self.table_name = Path(self.csv_path).stem or "csv_table"
60
60
 
61
+ @property
62
+ def sqlglot_dialect(self) -> str:
63
+ """Return the sqlglot dialect name."""
64
+ return "duckdb"
65
+
61
66
  async def get_pool(self):
62
67
  """CSV connections do not maintain a pool."""
63
68
  return None
@@ -52,6 +52,11 @@ class DuckDBConnection(BaseDatabaseConnection):
52
52
 
53
53
  self.database_path = db_path or ":memory:"
54
54
 
55
+ @property
56
+ def sqlglot_dialect(self) -> str:
57
+ """Return the sqlglot dialect name."""
58
+ return "duckdb"
59
+
55
60
  async def get_pool(self):
56
61
  """DuckDB creates connections per query, return database path."""
57
62
  return self.database_path
@@ -23,6 +23,11 @@ class MySQLConnection(BaseDatabaseConnection):
23
23
  self._pool: aiomysql.Pool | None = None
24
24
  self._parse_connection_string()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "mysql"
30
+
26
31
  def _parse_connection_string(self):
27
32
  """Parse MySQL connection string into components."""
28
33
  parsed = urlparse(self.connection_string)
@@ -23,6 +23,11 @@ class PostgreSQLConnection(BaseDatabaseConnection):
23
23
  self._pool: asyncpg.Pool | None = None
24
24
  self._ssl_context = self._create_ssl_context()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "postgres"
30
+
26
31
  def _create_ssl_context(self) -> ssl.SSLContext | None:
27
32
  """Create SSL context from connection string parameters."""
28
33
  parsed = urlparse(self.connection_string)
@@ -21,6 +21,11 @@ class SQLiteConnection(BaseDatabaseConnection):
21
21
  # Extract database path from sqlite:///path format
22
22
  self.database_path = connection_string.replace("sqlite:///", "")
23
23
 
24
+ @property
25
+ def sqlglot_dialect(self) -> str:
26
+ """Return the sqlglot dialect name."""
27
+ return "sqlite"
28
+
24
29
  async def get_pool(self):
25
30
  """SQLite doesn't use connection pooling, return database path."""
26
31
  return self.database_path
@@ -0,0 +1,5 @@
1
+ """Theme management for SQLSaber."""
2
+
3
+ from sqlsaber.theme.manager import ThemeManager, create_console, get_theme_manager
4
+
5
+ __all__ = ["ThemeManager", "create_console", "get_theme_manager"]
@@ -0,0 +1,219 @@
1
+ """Theme management for unified theming across Rich and prompt_toolkit."""
2
+
3
+ import os
4
+ import tomllib
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Dict
8
+
9
+ from platformdirs import user_config_dir
10
+ from prompt_toolkit.styles import Style as PTStyle
11
+ from prompt_toolkit.styles.pygments import style_from_pygments_cls
12
+ from pygments.styles import get_style_by_name
13
+ from rich.console import Console
14
+ from rich.theme import Theme
15
+
16
+ DEFAULT_THEME_NAME = "nord"
17
+
18
+ DEFAULT_ROLE_PALETTE = {
19
+ # base roles
20
+ "primary": "cyan",
21
+ "accent": "magenta",
22
+ "success": "green",
23
+ "warning": "yellow",
24
+ "error": "red",
25
+ "info": "cyan",
26
+ "muted": "dim",
27
+ # components
28
+ "table.header": "bold $primary",
29
+ "panel.border.user": "$info",
30
+ "panel.border.assistant": "$success",
31
+ "panel.border.thread": "$primary",
32
+ "spinner": "$warning",
33
+ "status": "$warning",
34
+ # domain-specific
35
+ "key.primary": "bold $warning",
36
+ "key.foreign": "bold $accent",
37
+ "key.index": "bold $primary",
38
+ "column.schema": "$info",
39
+ "column.name": "white",
40
+ "column.type": "$warning",
41
+ "heading": "bold $primary",
42
+ "section": "bold $accent",
43
+ "title": "bold $success",
44
+ }
45
+
46
+ # Theme presets using exact Pygments colors
47
+ THEME_PRESETS = {
48
+ # Nord - exact colors from pygments nord theme
49
+ "nord": {
50
+ "primary": "#81a1c1", # Keyword (frost)
51
+ "accent": "#b48ead", # Number (aurora purple)
52
+ "success": "#a3be8c", # String (aurora green)
53
+ "warning": "#ebcb8b", # String.Escape (aurora yellow)
54
+ "error": "#bf616a", # Error/Generic.Error (aurora red)
55
+ "info": "#88c0d0", # Name.Function (frost cyan)
56
+ "muted": "dim",
57
+ },
58
+ # Dracula - exact colors from pygments dracula theme
59
+ "dracula": {
60
+ "primary": "#bd93f9", # purple
61
+ "accent": "#ff79c6", # pink
62
+ "success": "#50fa7b", # green
63
+ "warning": "#f1fa8c", # yellow
64
+ "error": "#ff5555", # red
65
+ "info": "#8be9fd", # cyan
66
+ "muted": "dim",
67
+ },
68
+ # Solarized Light - exact colors from pygments solarized-light theme
69
+ "solarized-light": {
70
+ "primary": "#268bd2", # blue
71
+ "accent": "#d33682", # magenta
72
+ "success": "#859900", # green
73
+ "warning": "#b58900", # yellow
74
+ "error": "#dc322f", # red
75
+ "info": "#2aa198", # cyan
76
+ "muted": "dim",
77
+ },
78
+ # VS (Visual Studio Light) - exact colors from pygments vs theme
79
+ "vs": {
80
+ "primary": "#0000ff", # Keyword (blue)
81
+ "accent": "#2b91af", # Keyword.Type/Name.Class
82
+ "success": "#008000", # Comment (green)
83
+ "warning": "#b58900", # (using solarized yellow as fallback)
84
+ "error": "#dc322f", # (using solarized red as fallback)
85
+ "info": "#2aa198", # (using solarized cyan as fallback)
86
+ "muted": "dim",
87
+ },
88
+ # Material (approximation based on material design colors)
89
+ "material": {
90
+ "primary": "#89ddff", # cyan
91
+ "accent": "#f07178", # pink/red
92
+ "success": "#c3e88d", # green
93
+ "warning": "#ffcb6b", # yellow
94
+ "error": "#ff5370", # red
95
+ "info": "#82aaff", # blue
96
+ "muted": "dim",
97
+ },
98
+ # One Dark - exact colors from pygments one-dark theme
99
+ "one-dark": {
100
+ "primary": "#c678dd", # Keyword (purple)
101
+ "accent": "#e06c75", # Name (red)
102
+ "success": "#98c379", # String (green)
103
+ "warning": "#e5c07b", # Keyword.Type (yellow)
104
+ "error": "#e06c75", # Name (red, used for errors)
105
+ "info": "#61afef", # Name.Function (blue)
106
+ "muted": "dim",
107
+ },
108
+ # Lightbulb - exact colors from pygments lightbulb theme (minimal dark)
109
+ "lightbulb": {
110
+ "primary": "#73d0ff", # Keyword.Type/Name.Class (blue_1)
111
+ "accent": "#dfbfff", # Number (magenta_1)
112
+ "success": "#d5ff80", # String (green_1)
113
+ "warning": "#ffd173", # Name.Function (yellow_1)
114
+ "error": "#f88f7f", # Error (red_1)
115
+ "info": "#95e6cb", # Name.Entity (cyan_1)
116
+ "muted": "dim",
117
+ },
118
+ }
119
+
120
+
121
+ def _load_user_theme_config() -> dict:
122
+ """Load theme configuration from user config directory."""
123
+ cfg_dir = user_config_dir("sqlsaber")
124
+ path = os.path.join(cfg_dir, "theme.toml")
125
+ if not os.path.exists(path):
126
+ return {}
127
+ with open(path, "rb") as f:
128
+ return tomllib.load(f)
129
+
130
+
131
+ def _resolve_refs(palette: dict[str, str]) -> dict[str, str]:
132
+ """Resolve $var references in palette values."""
133
+ out = {}
134
+ for k, v in palette.items():
135
+ if isinstance(v, str) and "$" in v:
136
+ parts = v.split()
137
+ resolved = []
138
+ for part in parts:
139
+ if part.startswith("$"):
140
+ ref = part[1:]
141
+ resolved.append(palette.get(ref, ""))
142
+ else:
143
+ resolved.append(part)
144
+ out[k] = " ".join(p for p in resolved if p)
145
+ else:
146
+ out[k] = v
147
+ return out
148
+
149
+
150
+ @dataclass(frozen=True)
151
+ class ThemeConfig:
152
+ """Theme configuration."""
153
+
154
+ name: str
155
+ pygments_style: str
156
+ roles: Dict[str, str]
157
+
158
+
159
+ class ThemeManager:
160
+ """Manages theme configuration and provides themed components."""
161
+
162
+ def __init__(self, cfg: ThemeConfig):
163
+ self._cfg = cfg
164
+ self._roles = _resolve_refs({**DEFAULT_ROLE_PALETTE, **cfg.roles})
165
+ self._rich_theme = Theme(self._roles)
166
+ self._pt_style = None
167
+
168
+ @property
169
+ def rich_theme(self) -> Theme:
170
+ """Get Rich theme with semantic role mappings."""
171
+ return self._rich_theme
172
+
173
+ @property
174
+ def pygments_style_name(self) -> str:
175
+ """Get pygments style name for syntax highlighting."""
176
+ return self._cfg.pygments_style
177
+
178
+ def pt_style(self) -> PTStyle:
179
+ """Get prompt_toolkit style derived from Pygments theme."""
180
+ if self._pt_style is None:
181
+ try:
182
+ # Try to use Pygments style directly
183
+ pygments_style = get_style_by_name(self._cfg.pygments_style)
184
+ self._pt_style = style_from_pygments_cls(pygments_style)
185
+ except Exception:
186
+ # Fallback to basic style if Pygments theme not found
187
+ self._pt_style = PTStyle.from_dict({})
188
+ return self._pt_style
189
+
190
+ def style(self, role: str) -> str:
191
+ """Get style string for a semantic role."""
192
+ return self._roles.get(role, "")
193
+
194
+
195
+ @lru_cache(maxsize=1)
196
+ def get_theme_manager() -> ThemeManager:
197
+ """Get the global theme manager instance."""
198
+ user_cfg = _load_user_theme_config()
199
+ env_name = os.getenv("SQLSABER_THEME")
200
+
201
+ name = (
202
+ env_name or user_cfg.get("theme", {}).get("name") or DEFAULT_THEME_NAME
203
+ ).lower()
204
+ pygments_style = user_cfg.get("theme", {}).get("pygments_style") or name
205
+
206
+ roles = dict(DEFAULT_ROLE_PALETTE)
207
+ roles.update(THEME_PRESETS.get(name, {}))
208
+ roles.update(user_cfg.get("roles", {}))
209
+
210
+ cfg = ThemeConfig(name=name, pygments_style=pygments_style, roles=roles)
211
+ return ThemeManager(cfg)
212
+
213
+
214
+ def create_console(**kwargs):
215
+ """Create a Rich Console with theme applied."""
216
+ # from rich.console import Console
217
+
218
+ tm = get_theme_manager()
219
+ return Console(theme=tm.rich_theme, **kwargs)