sqlsaber 0.27.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

sqlsaber/cli/theme.py ADDED
@@ -0,0 +1,146 @@
1
+ """Theme management CLI commands."""
2
+
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import cyclopts
10
+ import questionary
11
+ from platformdirs import user_config_dir
12
+ from pygments.styles import get_all_styles
13
+
14
+ from sqlsaber.theme.manager import DEFAULT_THEME_NAME, create_console
15
+
16
console = create_console()

# The `theme` sub-command group; the commands defined below attach
# themselves to it via `@theme_app.command`.
theme_app = cyclopts.App(name="theme", help="Manage theme settings")
23
+
24
+
25
class ThemeManager:
    """Manages theme configuration persistence for the ``theme`` CLI commands.

    The configuration lives in ``theme.json`` inside the ``sqlsaber`` user
    config directory. This class only owns the ``theme`` section of that file;
    any other top-level sections (e.g. ``roles`` overrides) are preserved.
    """

    def __init__(self):
        self.config_dir = Path(user_config_dir("sqlsaber"))
        self.config_file = self.config_dir / "theme.json"

    def _ensure_config_dir(self) -> None:
        """Ensure config directory exists."""
        self.config_dir.mkdir(parents=True, exist_ok=True)

    def _load_config(self) -> dict:
        """Load theme configuration from file.

        Returns:
            The parsed JSON object, or an empty dict when the file is missing,
            unreadable, not valid JSON, or not a JSON object at top level.
        """
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            return {}
        # A top-level list/str/number would break every `.get` below.
        return config if isinstance(config, dict) else {}

    def _save_config(self, config: dict) -> None:
        """Save theme configuration to file, creating the directory if needed."""
        self._ensure_config_dir()

        with open(self.config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    def get_current_theme(self) -> str:
        """Get the currently effective theme name.

        ``SQLSABER_THEME`` (stripped, lowercased) takes precedence; otherwise
        the saved ``theme.pygments_style`` is used, falling back to
        ``DEFAULT_THEME_NAME``.
        """
        env_theme = (os.getenv("SQLSABER_THEME") or "").strip()
        if env_theme:
            # Pygments style names are lowercase; normalize so comparisons
            # against available theme names work.
            return env_theme.lower()
        theme_section = self._load_config().get("theme")
        if isinstance(theme_section, dict) and theme_section.get("pygments_style"):
            return theme_section["pygments_style"]
        return DEFAULT_THEME_NAME

    def set_theme(self, theme_name: str) -> bool:
        """Persist ``theme_name`` as both the theme name and Pygments style.

        Returns:
            True on success, False if the configuration could not be written
            (an error message is printed in that case).
        """
        try:
            config = self._load_config()
            theme_section = config.get("theme")
            if not isinstance(theme_section, dict):
                # Replace a malformed section instead of failing on item
                # assignment.
                theme_section = {}
            theme_section["name"] = theme_name
            theme_section["pygments_style"] = theme_name
            config["theme"] = theme_section
            self._save_config(config)
            return True
        except Exception as e:
            console.print(f"[error]Error setting theme: {e}[/error]")
            return False

    def reset_theme(self) -> bool:
        """Reset to the default theme.

        Removes only the ``theme`` section so other settings (such as
        ``roles`` overrides) survive; the file is deleted when nothing else
        remains.

        Returns:
            True on success, False on failure (an error message is printed).
        """
        try:
            config = self._load_config()
            config.pop("theme", None)
            if config:
                self._save_config(config)
            elif self.config_file.exists():
                self.config_file.unlink()
            return True
        except Exception as e:
            console.print(f"[error]Error resetting theme: {e}[/error]")
            return False

    def get_available_themes(self) -> list[str]:
        """Get sorted list of available Pygments themes."""
        return sorted(get_all_styles())
89
+
90
+
91
# One shared instance serves both the `set` and `reset` commands below.
theme_manager = ThemeManager()
92
+
93
+
94
@theme_app.command
def set():
    """Set the theme to use for syntax highlighting."""
    # NOTE: the docstring above is the CLI help text (cyclopts), keep it short.
    # Lists every installed Pygments style (searchable), marks the current one,
    # and persists the choice; exits with status 1 if saving fails.

    async def interactive_set():
        themes = theme_manager.get_available_themes()
        current_theme = theme_manager.get_current_theme()

        # Create choices with current theme highlighted
        choices = [
            questionary.Choice(
                title=f"{theme} (current)" if theme == current_theme else theme,
                value=theme,
            )
            for theme in themes
        ]

        # questionary rejects a `default` that matches no choice (ValueError),
        # which happens when the env/config theme is not an installed style.
        default = current_theme if current_theme in themes else None

        selected_theme = await questionary.select(
            "Select a theme:",
            choices=choices,
            default=default,
            use_search_filter=True,
            use_jk_keys=False,
        ).ask_async()

        if not selected_theme:
            # None means the prompt was cancelled (e.g. Ctrl-C).
            console.print("[warning]Operation cancelled[/warning]")
            return

        if theme_manager.set_theme(selected_theme):
            console.print(f"[success]✓ Theme set to: {selected_theme}[/success]")
        else:
            console.print("[error]✗ Failed to set theme[/error]")
            sys.exit(1)

    asyncio.run(interactive_set())
129
+
130
+
131
@theme_app.command
def reset():
    """Reset to the default theme."""
    # Guard clause: report failure and exit non-zero; sys.exit raises, so the
    # success message below only runs when the reset worked.
    if not theme_manager.reset_theme():
        console.print("[error]✗ Failed to reset theme[/error]")
        sys.exit(1)
    console.print(f"[success]✓ Theme reset to default: {DEFAULT_THEME_NAME}[/success]")
142
+
143
+
144
def create_theme_app() -> cyclopts.App:
    """Return the module-level ``theme`` cyclopts app.

    The ``set`` and ``reset`` commands in this module are registered on it.
    """
    app = theme_app
    return app
sqlsaber/cli/threads.py CHANGED
@@ -12,10 +12,12 @@ from rich.markdown import Markdown
12
12
  from rich.panel import Panel
13
13
  from rich.table import Table
14
14
 
15
+ from sqlsaber.theme.manager import create_console, get_theme_manager
15
16
  from sqlsaber.threads import ThreadStorage
16
17
 
17
18
  # Globals consistent with other CLI modules
18
- console = Console()
19
+ console = create_console()
20
+ tm = get_theme_manager()
19
21
 
20
22
 
21
23
  threads_app = cyclopts.App(
@@ -84,13 +86,23 @@ def _render_transcript(
84
86
  console.print(f"**User:**\n\n{text}\n")
85
87
  else:
86
88
  console.print(
87
- Panel.fit(Markdown(text), title="User", border_style="cyan")
89
+ Panel.fit(
90
+ Markdown(text, code_theme=tm.pygments_style_name),
91
+ title="User",
92
+ border_style=tm.style("panel.border.user"),
93
+ )
88
94
  )
89
95
  return
90
96
  if is_redirected:
91
97
  console.print("**User:** (no content)\n")
92
98
  else:
93
- console.print(Panel.fit("(no content)", title="User", border_style="cyan"))
99
+ console.print(
100
+ Panel.fit(
101
+ "(no content)",
102
+ title="User",
103
+ border_style=tm.style("panel.border.user"),
104
+ )
105
+ )
94
106
 
95
107
  def _render_response(message: ModelMessage) -> None:
96
108
  for part in getattr(message, "parts", []):
@@ -103,7 +115,9 @@ def _render_transcript(
103
115
  else:
104
116
  console.print(
105
117
  Panel.fit(
106
- Markdown(text), title="Assistant", border_style="green"
118
+ Markdown(text, code_theme=tm.pygments_style_name),
119
+ title="Assistant",
120
+ border_style=tm.style("panel.border.assistant"),
107
121
  )
108
122
  )
109
123
  elif kind in ("tool-call", "builtin-tool-call"):
@@ -211,11 +225,11 @@ def list_threads(
211
225
  console.print("No threads found.")
212
226
  return
213
227
  table = Table(title="Threads")
214
- table.add_column("ID", style="cyan")
215
- table.add_column("Database", style="magenta")
216
- table.add_column("Title", style="green")
217
- table.add_column("Last Activity", style="dim")
218
- table.add_column("Model", style="yellow")
228
+ table.add_column("ID", style=tm.style("info"))
229
+ table.add_column("Database", style=tm.style("accent"))
230
+ table.add_column("Title", style=tm.style("success"))
231
+ table.add_column("Last Activity", style=tm.style("muted"))
232
+ table.add_column("Model", style=tm.style("warning"))
219
233
  for t in threads:
220
234
  table.add_row(
221
235
  t.id,
@@ -235,7 +249,7 @@ def show(
235
249
  store = ThreadStorage()
236
250
  thread = asyncio.run(store.get_thread(thread_id))
237
251
  if not thread:
238
- console.print(f"[red]Thread not found:[/red] {thread_id}")
252
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
239
253
  return
240
254
  msgs = asyncio.run(store.get_thread_messages(thread_id))
241
255
  console.print(f"[bold]Thread: {thread.id}[/bold]")
@@ -273,12 +287,12 @@ def resume(
273
287
 
274
288
  thread = await store.get_thread(thread_id)
275
289
  if not thread:
276
- console.print(f"[red]Thread not found:[/red] {thread_id}")
290
+ console.print(f"[error]Thread not found:[/error] {thread_id}")
277
291
  return
278
292
  db_selector = database or thread.database_name
279
293
  if not db_selector:
280
294
  console.print(
281
- "[red]No database specified or stored with this thread.[/red]"
295
+ "[error]No database specified or stored with this thread.[/error]"
282
296
  )
283
297
  return
284
298
  try:
@@ -287,7 +301,7 @@ def resume(
287
301
  connection_string = resolved.connection_string
288
302
  db_name = resolved.name
289
303
  except DatabaseResolutionError as e:
290
- console.print(f"[red]Database resolution error:[/red] {e}")
304
+ console.print(f"[error]Database resolution error:[/error] {e}")
291
305
  return
292
306
 
293
307
  db_conn = DatabaseConnection(connection_string)
@@ -295,7 +309,12 @@ def resume(
295
309
  sqlsaber_agent = SQLSaberAgent(db_conn, db_name)
296
310
  history = await store.get_thread_messages(thread_id)
297
311
  if console.is_terminal:
298
- console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
312
+ console.print(
313
+ Panel.fit(
314
+ f"Thread: {thread.id}",
315
+ border_style=tm.style("panel.border.thread"),
316
+ )
317
+ )
299
318
  else:
300
319
  console.print(f"# Thread: {thread.id}\n")
301
320
  _render_transcript(console, history, None)
@@ -310,7 +329,7 @@ def resume(
310
329
  await session.run()
311
330
  finally:
312
331
  await db_conn.close()
313
- console.print("\n[green]Goodbye![/green]")
332
+ console.print("\n[success]Goodbye![/success]")
314
333
 
315
334
  asyncio.run(_run())
316
335
 
@@ -329,7 +348,7 @@ def prune(
329
348
 
330
349
  async def _run() -> None:
331
350
  deleted = await store.prune_threads(older_than_days=days)
332
- console.print(f"[green]✓ Pruned {deleted} thread(s).[/green]")
351
+ console.print(f"[success]✓ Pruned {deleted} thread(s).[/success]")
333
352
 
334
353
  asyncio.run(_run())
335
354
 
@@ -4,11 +4,11 @@ import getpass
4
4
  import os
5
5
 
6
6
  import keyring
7
- from rich.console import Console
8
7
 
9
8
  from sqlsaber.config import providers
9
+ from sqlsaber.theme.manager import create_console
10
10
 
11
- console = Console()
11
+ console = create_console()
12
12
 
13
13
 
14
14
  class APIKeyManager:
@@ -10,12 +10,13 @@ from datetime import datetime, timezone
10
10
 
11
11
  import httpx
12
12
  import questionary
13
- from rich.console import Console
14
13
  from rich.progress import Progress, SpinnerColumn, TextColumn
15
14
 
15
+ from sqlsaber.theme.manager import create_console
16
+
16
17
  from .oauth_tokens import OAuthToken, OAuthTokenManager
17
18
 
18
- console = Console()
19
+ console = create_console()
19
20
  logger = logging.getLogger(__name__)
20
21
 
21
22
 
@@ -6,9 +6,10 @@ from datetime import datetime, timedelta, timezone
6
6
  from typing import Any
7
7
 
8
8
  import keyring
9
- from rich.console import Console
10
9
 
11
- console = Console()
10
+ from sqlsaber.theme.manager import create_console
11
+
12
+ console = create_console()
12
13
  logger = logging.getLogger(__name__)
13
14
 
14
15
 
@@ -158,9 +159,6 @@ class OAuthTokenManager:
158
159
  keyring.delete_password(service_name, provider)
159
160
  console.print(f"OAuth token for {provider} removed", style="green")
160
161
  return True
161
- except keyring.errors.PasswordDeleteError:
162
- # Token doesn't exist
163
- return True
164
162
  except Exception as e:
165
163
  logger.error(f"Failed to remove OAuth token for {provider}: {e}")
166
164
  console.print(f"Warning: Could not remove OAuth token: {e}", style="yellow")
sqlsaber/database/base.py CHANGED
@@ -61,6 +61,12 @@ class BaseDatabaseConnection(ABC):
61
61
  self.connection_string = connection_string
62
62
  self._pool = None
63
63
 
64
+ @property
65
+ @abstractmethod
66
+ def sqlglot_dialect(self) -> str:
67
+ """Return the sqlglot dialect name for this database."""
68
+ pass
69
+
64
70
  @abstractmethod
65
71
  async def get_pool(self):
66
72
  """Get or create connection pool."""
sqlsaber/database/csv.py CHANGED
@@ -58,6 +58,11 @@ class CSVConnection(BaseDatabaseConnection):
58
58
 
59
59
  self.table_name = Path(self.csv_path).stem or "csv_table"
60
60
 
61
+ @property
62
+ def sqlglot_dialect(self) -> str:
63
+ """Return the sqlglot dialect name."""
64
+ return "duckdb"
65
+
61
66
  async def get_pool(self):
62
67
  """CSV connections do not maintain a pool."""
63
68
  return None
@@ -52,6 +52,11 @@ class DuckDBConnection(BaseDatabaseConnection):
52
52
 
53
53
  self.database_path = db_path or ":memory:"
54
54
 
55
+ @property
56
+ def sqlglot_dialect(self) -> str:
57
+ """Return the sqlglot dialect name."""
58
+ return "duckdb"
59
+
55
60
  async def get_pool(self):
56
61
  """DuckDB creates connections per query, return database path."""
57
62
  return self.database_path
@@ -23,6 +23,11 @@ class MySQLConnection(BaseDatabaseConnection):
23
23
  self._pool: aiomysql.Pool | None = None
24
24
  self._parse_connection_string()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "mysql"
30
+
26
31
  def _parse_connection_string(self):
27
32
  """Parse MySQL connection string into components."""
28
33
  parsed = urlparse(self.connection_string)
@@ -23,6 +23,11 @@ class PostgreSQLConnection(BaseDatabaseConnection):
23
23
  self._pool: asyncpg.Pool | None = None
24
24
  self._ssl_context = self._create_ssl_context()
25
25
 
26
+ @property
27
+ def sqlglot_dialect(self) -> str:
28
+ """Return the sqlglot dialect name."""
29
+ return "postgres"
30
+
26
31
  def _create_ssl_context(self) -> ssl.SSLContext | None:
27
32
  """Create SSL context from connection string parameters."""
28
33
  parsed = urlparse(self.connection_string)
@@ -21,6 +21,11 @@ class SQLiteConnection(BaseDatabaseConnection):
21
21
  # Extract database path from sqlite:///path format
22
22
  self.database_path = connection_string.replace("sqlite:///", "")
23
23
 
24
+ @property
25
+ def sqlglot_dialect(self) -> str:
26
+ """Return the sqlglot dialect name."""
27
+ return "sqlite"
28
+
24
29
  async def get_pool(self):
25
30
  """SQLite doesn't use connection pooling, return database path."""
26
31
  return self.database_path
@@ -0,0 +1,5 @@
1
+ """Theme management for SQLSaber."""
2
+
3
+ from sqlsaber.theme.manager import ThemeManager, create_console, get_theme_manager
4
+
5
+ __all__ = ["ThemeManager", "create_console", "get_theme_manager"]
@@ -0,0 +1,229 @@
1
+ """Theme management for unified theming across Rich and prompt_toolkit."""
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Dict
8
+
9
+ from platformdirs import user_config_dir
10
+ from prompt_toolkit.styles import Style as PTStyle
11
+ from prompt_toolkit.styles.pygments import style_from_pygments_cls
12
+ from pygments.styles import get_style_by_name
13
+ from pygments.token import Token
14
+ from pygments.util import ClassNotFound
15
+ from rich.console import Console
16
+ from rich.theme import Theme
17
+
18
DEFAULT_THEME_NAME = "nord"

# Fallback Rich style for every semantic role. A value may reference another
# role as ``$name``; references are expanded when a ThemeManager is built.
# Base-role colors are normally overlaid by colors derived from the active
# Pygments style and then by user ``roles`` overrides.
DEFAULT_ROLE_PALETTE: dict[str, str] = {
    # --- base roles -------------------------------------------------------
    "primary": "cyan",
    "accent": "magenta",
    "success": "green",
    "warning": "yellow",
    "error": "red",
    "info": "cyan",
    "muted": "dim",
    # --- UI components ----------------------------------------------------
    "table.header": "bold $primary",
    "panel.border.user": "$info",
    "panel.border.assistant": "$success",
    "panel.border.thread": "$primary",
    "spinner": "$warning",
    "status": "$warning",
    # --- database/schema display -----------------------------------------
    "key.primary": "bold $warning",
    "key.foreign": "bold $accent",
    "key.index": "bold $primary",
    "column.schema": "$info",
    "column.name": "white",
    "column.type": "$warning",
    "heading": "bold $primary",
    "section": "bold $accent",
    "title": "bold $success",
}
47
+
48
# For each base role, Pygments token types whose color may stand in for that
# role, in order of preference: the first token with a usable color wins.
# Insertion order matters: "primary" must come before "accent" because the
# palette builder skips an accent color identical to the chosen primary one.
ROLE_TOKEN_PREFERENCES: dict[str, tuple] = {
    "primary": (Token.Keyword, Token.Keyword.Namespace, Token.Name.Tag),
    "accent": (
        Token.Name.Tag,
        Token.Keyword.Type,
        Token.Literal.Number,
        Token.Operator.Word,
    ),
    "success": (
        Token.Literal.String,
        Token.Generic.Inserted,
        Token.Name.Attribute,
    ),
    "warning": (
        Token.Literal.String.Escape,
        Token.Name.Constant,
        Token.Generic.Emph,
    ),
    "error": (
        Token.Error,
        Token.Generic.Error,
        Token.Generic.Deleted,
        Token.Name.Exception,
    ),
    "info": (Token.Name.Function, Token.Name.Builtin, Token.Keyword.Type),
    "muted": (Token.Comment, Token.Generic.Subheading, Token.Text),
}
87
+
88
+
89
+ def _normalize_hex(color: str | None) -> str | None:
90
+ if not color:
91
+ return None
92
+ color = color.strip()
93
+ if not color:
94
+ return None
95
+ if color.startswith("#"):
96
+ color = color[1:]
97
+ if len(color) == 3:
98
+ color = "".join(ch * 2 for ch in color)
99
+ if len(color) != 6:
100
+ return None
101
+ return f"#{color.lower()}"
102
+
103
+
104
def _build_role_palette_from_style(style_name: str) -> dict[str, str]:
    """Derive base-role colors from a Pygments style.

    For each role in ``ROLE_TOKEN_PREFERENCES`` the first preferred token
    with a valid color that differs from the plain-text color is used; the
    accent role additionally avoids repeating the primary color. Roles with
    no usable token are omitted. An unknown style name yields ``{}``.
    """
    try:
        style_cls = get_style_by_name(style_name)
    except ClassNotFound:
        return {}

    def token_color(token) -> str | None:
        # Normalized hex color of `token`, or None if unavailable.
        try:
            definition = style_cls.style_for_token(token)
        except KeyError:
            return None
        return _normalize_hex(definition.get("color"))

    text_color = token_color(Token.Text)
    derived: dict[str, str] = {}
    for role_name, candidates in ROLE_TOKEN_PREFERENCES.items():
        for candidate in candidates:
            found = token_color(candidate)
            if not found or found == text_color:
                continue
            if role_name == "accent" and found == derived.get("primary"):
                continue
            derived[role_name] = found
            break
    return derived
129
+
130
+
131
+ def _load_user_theme_config() -> dict:
132
+ """Load theme configuration from user config directory."""
133
+ cfg_dir = user_config_dir("sqlsaber")
134
+ path = os.path.join(cfg_dir, "theme.json")
135
+ if not os.path.exists(path):
136
+ return {}
137
+ with open(path, "r") as f:
138
+ return json.load(f)
139
+
140
+
141
+ def _resolve_refs(palette: dict[str, str]) -> dict[str, str]:
142
+ """Resolve $var references in palette values."""
143
+ out = {}
144
+ for k, v in palette.items():
145
+ if isinstance(v, str) and "$" in v:
146
+ parts = v.split()
147
+ resolved = []
148
+ for part in parts:
149
+ if part.startswith("$"):
150
+ ref = part[1:]
151
+ resolved.append(palette.get(ref, ""))
152
+ else:
153
+ resolved.append(part)
154
+ out[k] = " ".join(p for p in resolved if p)
155
+ else:
156
+ out[k] = v
157
+ return out
158
+
159
+
160
@dataclass(frozen=True)
class ThemeConfig:
    """Immutable settings a ThemeManager is built from."""

    # Theme name.
    name: str
    # Name of the Pygments style used for syntax highlighting.
    pygments_style: str
    # Semantic role name -> Rich style string (may contain `$role` references).
    roles: Dict[str, str]
167
+
168
+
169
class ThemeManager:
    """Expose the active theme to Rich and prompt_toolkit consumers.

    Role styles are the built-in defaults overlaid with ``cfg.roles`` and then
    ``$ref``-expanded; they back both the Rich theme and ``style()``.
    """

    def __init__(self, cfg: ThemeConfig):
        self._cfg = cfg
        merged_roles = dict(DEFAULT_ROLE_PALETTE)
        merged_roles.update(cfg.roles)
        self._roles = _resolve_refs(merged_roles)
        self._rich_theme = Theme(self._roles)
        # The prompt_toolkit style is built lazily on the first pt_style() call.
        self._pt_style = None

    @property
    def rich_theme(self) -> Theme:
        """Rich theme mapping each semantic role name to its style."""
        return self._rich_theme

    @property
    def pygments_style_name(self) -> str:
        """Name of the Pygments style used for code highlighting."""
        return self._cfg.pygments_style

    def pt_style(self) -> PTStyle:
        """Return (and memoize) a prompt_toolkit style for the Pygments theme.

        Falls back to an empty style if the Pygments style cannot be loaded or
        converted.
        """
        if self._pt_style is not None:
            return self._pt_style
        try:
            self._pt_style = style_from_pygments_cls(
                get_style_by_name(self._cfg.pygments_style)
            )
        except Exception:
            # Unknown style name or conversion failure: use a neutral style.
            self._pt_style = PTStyle.from_dict({})
        return self._pt_style

    def style(self, role: str) -> str:
        """Return the resolved style string for ``role`` ("" if unknown)."""
        try:
            return self._roles[role]
        except KeyError:
            return ""
203
+
204
+
205
@lru_cache(maxsize=1)
def get_theme_manager() -> ThemeManager:
    """Return the process-wide ThemeManager, built on first use.

    Theme selection, highest priority first:

    * ``SQLSABER_THEME`` (non-blank) picks both the theme name and the
      Pygments style, overriding ``theme.json``.
    * ``theme.name`` / ``theme.pygments_style`` from ``theme.json``.
    * ``DEFAULT_THEME_NAME``.

    Role styles are layered: built-in defaults, then colors derived from the
    Pygments style, then the ``roles`` overrides from ``theme.json``.

    The result is cached (``lru_cache``), so config changes made later in the
    same process are not picked up.
    """
    user_cfg = _load_user_theme_config()
    if not isinstance(user_cfg, dict):
        user_cfg = {}
    # Ignore malformed sections instead of crashing on `.get` / `.update`.
    theme_cfg = user_cfg.get("theme")
    if not isinstance(theme_cfg, dict):
        theme_cfg = {}
    user_roles = user_cfg.get("roles")
    if not isinstance(user_roles, dict):
        user_roles = {}

    env_name = (os.getenv("SQLSABER_THEME") or "").strip()
    if env_name:
        # The env override must also drive highlighting: `theme set` always
        # writes `pygments_style`, which would otherwise shadow it.
        name = env_name.lower()
        pygments_style = name
    else:
        name = (theme_cfg.get("name") or DEFAULT_THEME_NAME).lower()
        pygments_style = theme_cfg.get("pygments_style") or name

    roles = dict(DEFAULT_ROLE_PALETTE)
    roles.update(_build_role_palette_from_style(pygments_style))
    roles.update(user_roles)

    cfg = ThemeConfig(name=name, pygments_style=pygments_style, roles=roles)
    return ThemeManager(cfg)
222
+
223
+
224
def create_console(**kwargs) -> Console:
    """Create a Rich Console with the active SQLSaber theme applied.

    Keyword arguments are forwarded to ``rich.console.Console``. If the caller
    supplies ``theme=``, it is used as-is instead of the SQLSaber theme (rather
    than raising a duplicate-keyword ``TypeError``).
    """
    if "theme" not in kwargs:
        kwargs["theme"] = get_theme_manager().rich_theme
    return Console(**kwargs)