sqlsaber 0.29.0__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/base.py +1 -1
- sqlsaber/agents/pydantic_ai_agent.py +39 -17
- sqlsaber/application/auth_setup.py +3 -3
- sqlsaber/application/db_setup.py +2 -2
- sqlsaber/application/model_selection.py +2 -2
- sqlsaber/cli/auth.py +10 -8
- sqlsaber/cli/commands.py +3 -3
- sqlsaber/cli/database.py +22 -20
- sqlsaber/cli/interactive.py +11 -5
- sqlsaber/cli/memory.py +10 -10
- sqlsaber/cli/models.py +12 -12
- sqlsaber/cli/onboarding.py +41 -44
- sqlsaber/cli/streaming.py +2 -11
- sqlsaber/cli/threads.py +3 -3
- sqlsaber/config/api_keys.py +5 -5
- sqlsaber/config/oauth_flow.py +11 -10
- sqlsaber/config/oauth_tokens.py +7 -5
- sqlsaber/database/schema.py +1 -1
- sqlsaber/theme/manager.py +4 -9
- sqlsaber/tools/__init__.py +0 -5
- sqlsaber/tools/base.py +0 -31
- sqlsaber/tools/registry.py +6 -39
- sqlsaber/tools/sql_tools.py +0 -42
- {sqlsaber-0.29.0.dist-info → sqlsaber-0.30.0.dist-info}/METADATA +3 -44
- sqlsaber-0.30.0.dist-info/RECORD +57 -0
- {sqlsaber-0.29.0.dist-info → sqlsaber-0.30.0.dist-info}/entry_points.txt +0 -2
- sqlsaber/agents/mcp.py +0 -21
- sqlsaber/mcp/__init__.py +0 -5
- sqlsaber/mcp/mcp.py +0 -129
- sqlsaber/tools/enums.py +0 -19
- sqlsaber/tools/instructions.py +0 -231
- sqlsaber-0.29.0.dist-info/RECORD +0 -62
- {sqlsaber-0.29.0.dist-info → sqlsaber-0.30.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.29.0.dist-info → sqlsaber-0.30.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/onboarding.py
CHANGED
|
@@ -32,16 +32,16 @@ def needs_onboarding(database_arg: str | None = None) -> bool:
|
|
|
32
32
|
|
|
33
33
|
def welcome_screen() -> None:
|
|
34
34
|
"""Display welcome screen to new users."""
|
|
35
|
-
banner = """
|
|
35
|
+
banner = """[primary]
|
|
36
36
|
███████ ██████ ██ ███████ █████ ██████ ███████ ██████
|
|
37
37
|
██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
38
38
|
███████ ██ ██ ██ ███████ ███████ ██████ █████ ██████
|
|
39
39
|
██ ██ ▄▄ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
|
40
40
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
41
41
|
▀▀
|
|
42
|
-
"""
|
|
42
|
+
[/primary]"""
|
|
43
43
|
|
|
44
|
-
console.print(Panel.fit(banner, style="
|
|
44
|
+
console.print(Panel.fit(banner, style="primary"))
|
|
45
45
|
console.print()
|
|
46
46
|
|
|
47
47
|
welcome_message = """
|
|
@@ -52,7 +52,9 @@ SQLsaber is an agentic SQL assistant that lets you query your database using nat
|
|
|
52
52
|
Let's get you set up in just a few steps.
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
-
console.print(
|
|
55
|
+
console.print(
|
|
56
|
+
Panel.fit(welcome_message.strip(), border_style="primary", padding=(1, 2))
|
|
57
|
+
)
|
|
56
58
|
console.print()
|
|
57
59
|
|
|
58
60
|
|
|
@@ -69,9 +71,7 @@ async def setup_database_guided() -> str | None:
|
|
|
69
71
|
)
|
|
70
72
|
from sqlsaber.application.prompts import AsyncPrompter
|
|
71
73
|
|
|
72
|
-
console.print("
|
|
73
|
-
console.print("[bold cyan]Step 1 of 2: Database Connection[/bold cyan]")
|
|
74
|
-
console.print("━" * 80, style="dim")
|
|
74
|
+
console.print("[heading]Step 1 of 2: Database Connection[/heading]")
|
|
75
75
|
console.print()
|
|
76
76
|
|
|
77
77
|
try:
|
|
@@ -92,7 +92,7 @@ async def setup_database_guided() -> str | None:
|
|
|
92
92
|
db_manager = DatabaseConfigManager()
|
|
93
93
|
if db_manager.get_database(name):
|
|
94
94
|
console.print(
|
|
95
|
-
f"[
|
|
95
|
+
f"[warning]Database connection '{name}' already exists.[/warning]"
|
|
96
96
|
)
|
|
97
97
|
return name
|
|
98
98
|
|
|
@@ -108,7 +108,7 @@ async def setup_database_guided() -> str | None:
|
|
|
108
108
|
db_config = build_config(db_input)
|
|
109
109
|
|
|
110
110
|
# Test the connection
|
|
111
|
-
console.print(f"[
|
|
111
|
+
console.print(f"[muted]Testing connection to '{name}'...[/muted]")
|
|
112
112
|
connection_success = await test_connection(db_config, db_input.password)
|
|
113
113
|
|
|
114
114
|
if not connection_success:
|
|
@@ -119,25 +119,25 @@ async def setup_database_guided() -> str | None:
|
|
|
119
119
|
return await setup_database_guided()
|
|
120
120
|
else:
|
|
121
121
|
console.print(
|
|
122
|
-
"[
|
|
122
|
+
"[warning]You can add a database later using 'saber db add'[/warning]"
|
|
123
123
|
)
|
|
124
124
|
return None
|
|
125
125
|
|
|
126
126
|
# Save the configuration
|
|
127
127
|
try:
|
|
128
128
|
save_database(db_manager, db_config, db_input.password)
|
|
129
|
-
console.print(f"[
|
|
129
|
+
console.print(f"[success]✓ Connection to '{name}' successful![/success]")
|
|
130
130
|
console.print()
|
|
131
131
|
return name
|
|
132
132
|
except Exception as e:
|
|
133
|
-
console.print(f"[
|
|
133
|
+
console.print(f"[error]Error saving database:[/error] {e}")
|
|
134
134
|
return None
|
|
135
135
|
|
|
136
136
|
except KeyboardInterrupt:
|
|
137
|
-
console.print("\n[
|
|
137
|
+
console.print("\n[warning]Setup cancelled.[/warning]")
|
|
138
138
|
return None
|
|
139
139
|
except Exception as e:
|
|
140
|
-
console.print(f"[
|
|
140
|
+
console.print(f"[error]Unexpected error:[/error] {e}")
|
|
141
141
|
return None
|
|
142
142
|
|
|
143
143
|
|
|
@@ -151,14 +151,14 @@ async def select_model_for_provider(provider: str) -> str | None:
|
|
|
151
151
|
|
|
152
152
|
try:
|
|
153
153
|
console.print()
|
|
154
|
-
console.print(f"[
|
|
154
|
+
console.print(f"[muted]Fetching available {provider.title()} models...[/muted]")
|
|
155
155
|
|
|
156
156
|
model_manager = ModelManager()
|
|
157
157
|
models = await fetch_models(model_manager, providers=[provider])
|
|
158
158
|
|
|
159
159
|
if not models:
|
|
160
160
|
console.print(
|
|
161
|
-
f"[
|
|
161
|
+
f"[warning]Could not fetch models for {provider}. Using default.[/warning]"
|
|
162
162
|
)
|
|
163
163
|
# Use provider-specific default or fallback to Anthropic
|
|
164
164
|
default_model_id = ModelManager.RECOMMENDED_MODELS.get(
|
|
@@ -178,10 +178,10 @@ async def select_model_for_provider(provider: str) -> str | None:
|
|
|
178
178
|
return selected_model
|
|
179
179
|
|
|
180
180
|
except KeyboardInterrupt:
|
|
181
|
-
console.print("\n[
|
|
181
|
+
console.print("\n[warning]Model selection cancelled.[/warning]")
|
|
182
182
|
return None
|
|
183
183
|
except Exception as e:
|
|
184
|
-
console.print(f"[
|
|
184
|
+
console.print(f"[warning]Error selecting model: {e}. Using default.[/warning]")
|
|
185
185
|
# Fallback to provider default
|
|
186
186
|
if provider in ModelManager.RECOMMENDED_MODELS:
|
|
187
187
|
return f"{provider}:{ModelManager.RECOMMENDED_MODELS[provider]}"
|
|
@@ -196,9 +196,7 @@ async def setup_auth_guided() -> tuple[bool, str | None]:
|
|
|
196
196
|
from sqlsaber.application.auth_setup import setup_auth
|
|
197
197
|
from sqlsaber.application.prompts import AsyncPrompter
|
|
198
198
|
|
|
199
|
-
console.print("
|
|
200
|
-
console.print("[bold cyan]Step 2 of 2: Authentication[/bold cyan]")
|
|
201
|
-
console.print("━" * 80, style="dim")
|
|
199
|
+
console.print("[primary]Step 2 of 2: Authentication[/primary]")
|
|
202
200
|
console.print()
|
|
203
201
|
|
|
204
202
|
try:
|
|
@@ -218,7 +216,7 @@ async def setup_auth_guided() -> tuple[bool, str | None]:
|
|
|
218
216
|
|
|
219
217
|
if not success:
|
|
220
218
|
console.print(
|
|
221
|
-
"[
|
|
219
|
+
"[warning]You can set it up later using 'saber auth setup'[/warning]"
|
|
222
220
|
)
|
|
223
221
|
console.print()
|
|
224
222
|
return False, None
|
|
@@ -233,16 +231,16 @@ async def setup_auth_guided() -> tuple[bool, str | None]:
|
|
|
233
231
|
if selected_model:
|
|
234
232
|
model_manager = ModelManager()
|
|
235
233
|
model_manager.set_model(selected_model)
|
|
236
|
-
console.print(f"[
|
|
234
|
+
console.print(f"[success]✓ Model set to: {selected_model}[/success]")
|
|
237
235
|
console.print()
|
|
238
236
|
return True, selected_model
|
|
239
237
|
|
|
240
238
|
except KeyboardInterrupt:
|
|
241
|
-
console.print("\n[
|
|
239
|
+
console.print("\n[warning]Setup cancelled.[/warning]")
|
|
242
240
|
console.print()
|
|
243
241
|
return False, None
|
|
244
242
|
except Exception as e:
|
|
245
|
-
console.print(f"[
|
|
243
|
+
console.print(f"[error]Unexpected error:[/error] {e}")
|
|
246
244
|
console.print()
|
|
247
245
|
return False, None
|
|
248
246
|
|
|
@@ -251,35 +249,34 @@ def success_screen(
|
|
|
251
249
|
database_name: str | None, auth_configured: bool, model_name: str | None = None
|
|
252
250
|
) -> None:
|
|
253
251
|
"""Display success screen after onboarding."""
|
|
254
|
-
|
|
255
|
-
console.print("[
|
|
256
|
-
console.print("━" * 80, style="dim")
|
|
252
|
+
|
|
253
|
+
console.print("[success]You're all set! 🚀[/success]")
|
|
257
254
|
console.print()
|
|
258
255
|
|
|
259
256
|
if database_name and auth_configured:
|
|
260
257
|
console.print(
|
|
261
|
-
f"[
|
|
258
|
+
f"[success]✓ Database '{database_name}' connected and ready to use[/success]"
|
|
262
259
|
)
|
|
263
|
-
console.print("[
|
|
260
|
+
console.print("[success]✓ Authentication configured[/success]")
|
|
264
261
|
if model_name:
|
|
265
|
-
console.print(f"[
|
|
262
|
+
console.print(f"[success]✓ Model: {model_name}[/success]")
|
|
266
263
|
elif database_name:
|
|
267
264
|
console.print(
|
|
268
|
-
f"[
|
|
265
|
+
f"[success]✓ Database '{database_name}' connected and ready to use[/success]"
|
|
269
266
|
)
|
|
270
267
|
console.print(
|
|
271
|
-
"[
|
|
268
|
+
"[warning]⚠ AI authentication not configured - you'll be prompted when needed[/warning]"
|
|
272
269
|
)
|
|
273
270
|
elif auth_configured:
|
|
274
|
-
console.print("[
|
|
271
|
+
console.print("[success]✓ AI authentication configured[/success]")
|
|
275
272
|
if model_name:
|
|
276
|
-
console.print(f"[
|
|
273
|
+
console.print(f"[success]✓ Model: {model_name}[/success]")
|
|
277
274
|
console.print(
|
|
278
|
-
"[
|
|
275
|
+
"[warning]⚠ No database configured - you'll need to provide one via -d flag[/warning]"
|
|
279
276
|
)
|
|
280
277
|
|
|
281
278
|
console.print()
|
|
282
|
-
console.print("[
|
|
279
|
+
console.print("[muted]Starting interactive session...[/muted]")
|
|
283
280
|
console.print()
|
|
284
281
|
|
|
285
282
|
|
|
@@ -298,9 +295,9 @@ async def run_onboarding() -> bool:
|
|
|
298
295
|
|
|
299
296
|
# If user cancelled database setup, exit
|
|
300
297
|
if database_name is None:
|
|
301
|
-
console.print("[
|
|
298
|
+
console.print("[warning]Database setup is required to continue.[/warning]")
|
|
302
299
|
console.print(
|
|
303
|
-
"[
|
|
300
|
+
"[muted]You can also provide a connection string using: saber -d <connection-string>[/muted]"
|
|
304
301
|
)
|
|
305
302
|
return False
|
|
306
303
|
|
|
@@ -313,13 +310,13 @@ async def run_onboarding() -> bool:
|
|
|
313
310
|
return True
|
|
314
311
|
|
|
315
312
|
except KeyboardInterrupt:
|
|
316
|
-
console.print("\n[
|
|
313
|
+
console.print("\n[warning]Onboarding cancelled.[/warning]")
|
|
317
314
|
console.print(
|
|
318
|
-
"[
|
|
319
|
-
"[
|
|
320
|
-
"[
|
|
315
|
+
"[muted]You can run setup commands manually:[/muted]\n"
|
|
316
|
+
"[muted] - saber db add <name> # Add database connection[/muted]\n"
|
|
317
|
+
"[muted] - saber auth setup # Configure authentication[/muted]"
|
|
321
318
|
)
|
|
322
319
|
sys.exit(0)
|
|
323
320
|
except Exception as e:
|
|
324
|
-
console.print(f"[
|
|
321
|
+
console.print(f"[error]Onboarding failed:[/error] {e}")
|
|
325
322
|
return False
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -144,17 +144,8 @@ class StreamingQueryHandler:
|
|
|
144
144
|
prepared_prompt: str | list[str] = user_query
|
|
145
145
|
no_history = not message_history
|
|
146
146
|
if sqlsaber_agent.is_oauth and no_history:
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
)
|
|
150
|
-
mem = ""
|
|
151
|
-
if sqlsaber_agent.database_name:
|
|
152
|
-
mem = sqlsaber_agent.memory_manager.format_memories_for_prompt(
|
|
153
|
-
sqlsaber_agent.database_name
|
|
154
|
-
)
|
|
155
|
-
parts = [p for p in (instructions, mem) if p and str(p).strip()]
|
|
156
|
-
if parts:
|
|
157
|
-
injected = "\n\n".join(parts)
|
|
147
|
+
injected = sqlsaber_agent.system_prompt_text(include_memory=True)
|
|
148
|
+
if injected and str(injected).strip():
|
|
158
149
|
prepared_prompt = [injected, user_query]
|
|
159
150
|
|
|
160
151
|
# Show a transient status until events start streaming
|
sqlsaber/cli/threads.py
CHANGED
|
@@ -170,7 +170,7 @@ def _render_transcript(
|
|
|
170
170
|
Panel.fit(
|
|
171
171
|
content_str,
|
|
172
172
|
title=f"Tool result: {name}",
|
|
173
|
-
border_style="
|
|
173
|
+
border_style="warning",
|
|
174
174
|
)
|
|
175
175
|
)
|
|
176
176
|
except Exception:
|
|
@@ -183,7 +183,7 @@ def _render_transcript(
|
|
|
183
183
|
Panel.fit(
|
|
184
184
|
content_str,
|
|
185
185
|
title=f"Tool result: {name}",
|
|
186
|
-
border_style="
|
|
186
|
+
border_style="warning",
|
|
187
187
|
)
|
|
188
188
|
)
|
|
189
189
|
else:
|
|
@@ -194,7 +194,7 @@ def _render_transcript(
|
|
|
194
194
|
Panel.fit(
|
|
195
195
|
content_str,
|
|
196
196
|
title=f"Tool result: {name}",
|
|
197
|
-
border_style="
|
|
197
|
+
border_style="warning",
|
|
198
198
|
)
|
|
199
199
|
)
|
|
200
200
|
# Thinking parts omitted
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -36,7 +36,7 @@ class APIKeyManager:
|
|
|
36
36
|
return api_key
|
|
37
37
|
except Exception as e:
|
|
38
38
|
# Keyring access failed, continue to prompt
|
|
39
|
-
console.print(f"Keyring access failed: {e}", style="
|
|
39
|
+
console.print(f"Keyring access failed: {e}", style="muted warning")
|
|
40
40
|
|
|
41
41
|
# 3. Prompt user for API key
|
|
42
42
|
return self._prompt_and_store_key(provider, env_var_name, service_name)
|
|
@@ -72,7 +72,7 @@ class APIKeyManager:
|
|
|
72
72
|
if not api_key.strip():
|
|
73
73
|
console.print(
|
|
74
74
|
"No API key provided. Some functionality may not work.",
|
|
75
|
-
style="
|
|
75
|
+
style="warning",
|
|
76
76
|
)
|
|
77
77
|
return None
|
|
78
78
|
|
|
@@ -83,16 +83,16 @@ class APIKeyManager:
|
|
|
83
83
|
except Exception as e:
|
|
84
84
|
console.print(
|
|
85
85
|
f"Warning: Could not store API key in your operating system's credentials store: {e}",
|
|
86
|
-
style="
|
|
86
|
+
style="warning",
|
|
87
87
|
)
|
|
88
88
|
console.print(
|
|
89
|
-
"You may need to enter it again next time", style="
|
|
89
|
+
"You may need to enter it again next time", style="warning"
|
|
90
90
|
)
|
|
91
91
|
|
|
92
92
|
return api_key.strip()
|
|
93
93
|
|
|
94
94
|
except KeyboardInterrupt:
|
|
95
|
-
console.print("\nOperation cancelled", style="
|
|
95
|
+
console.print("\nOperation cancelled", style="warning")
|
|
96
96
|
return None
|
|
97
97
|
except Exception as e:
|
|
98
98
|
console.print(f"Error prompting for API key: {e}", style="red")
|
sqlsaber/config/oauth_flow.py
CHANGED
|
@@ -131,7 +131,7 @@ class AnthropicOAuthFlow:
|
|
|
131
131
|
if not questionary.confirm(
|
|
132
132
|
"Continue with browser-based authentication?", default=True
|
|
133
133
|
).ask():
|
|
134
|
-
console.print("[
|
|
134
|
+
console.print("[warning]Authentication cancelled.[/warning]")
|
|
135
135
|
return False
|
|
136
136
|
|
|
137
137
|
try:
|
|
@@ -168,7 +168,7 @@ class AnthropicOAuthFlow:
|
|
|
168
168
|
).ask()
|
|
169
169
|
|
|
170
170
|
if not auth_code:
|
|
171
|
-
console.print("[
|
|
171
|
+
console.print("[warning]Authentication cancelled.[/warning]")
|
|
172
172
|
return False
|
|
173
173
|
|
|
174
174
|
# Step 2: Exchange code for tokens
|
|
@@ -198,23 +198,23 @@ class AnthropicOAuthFlow:
|
|
|
198
198
|
)
|
|
199
199
|
|
|
200
200
|
if self.token_manager.store_oauth_token("anthropic", oauth_token):
|
|
201
|
-
console.print(
|
|
202
|
-
"\n[bold green]✓ Authentication successful![/bold green]"
|
|
203
|
-
)
|
|
201
|
+
console.print("\n[success]✓ Authentication successful![/success]")
|
|
204
202
|
console.print(
|
|
205
203
|
"Your Claude Pro/Max subscription is now configured for SQLSaber."
|
|
206
204
|
)
|
|
207
205
|
return True
|
|
208
206
|
else:
|
|
209
|
-
console.print(
|
|
207
|
+
console.print(
|
|
208
|
+
"[error]✗ Failed to store authentication tokens.[/error]"
|
|
209
|
+
)
|
|
210
210
|
return False
|
|
211
211
|
|
|
212
212
|
except KeyboardInterrupt:
|
|
213
|
-
console.print("\n[
|
|
213
|
+
console.print("\n[warning]Authentication cancelled by user.[/warning]")
|
|
214
214
|
return False
|
|
215
215
|
except Exception as e:
|
|
216
216
|
logger.error(f"OAuth authentication failed: {e}")
|
|
217
|
-
console.print(f"[
|
|
217
|
+
console.print(f"[error]✗ Authentication failed: {str(e)}[/error]")
|
|
218
218
|
return False
|
|
219
219
|
|
|
220
220
|
def refresh_token_if_needed(self) -> OAuthToken | None:
|
|
@@ -255,13 +255,14 @@ class AnthropicOAuthFlow:
|
|
|
255
255
|
console.print("OAuth token refreshed successfully", style="green")
|
|
256
256
|
return refreshed_token
|
|
257
257
|
else:
|
|
258
|
-
console.print("Failed to store refreshed token", style="
|
|
258
|
+
console.print("Failed to store refreshed token", style="warning")
|
|
259
259
|
return current_token
|
|
260
260
|
|
|
261
261
|
except Exception as e:
|
|
262
262
|
logger.warning(f"Token refresh failed: {e}")
|
|
263
263
|
console.print(
|
|
264
|
-
"Token refresh failed. You may need to re-authenticate.",
|
|
264
|
+
"Token refresh failed. You may need to re-authenticate.",
|
|
265
|
+
style="warning",
|
|
265
266
|
)
|
|
266
267
|
return current_token
|
|
267
268
|
|
sqlsaber/config/oauth_tokens.py
CHANGED
|
@@ -97,14 +97,14 @@ class OAuthTokenManager:
|
|
|
97
97
|
if token.is_expired():
|
|
98
98
|
console.print(
|
|
99
99
|
f"OAuth token for {provider} has expired and needs refresh",
|
|
100
|
-
style="
|
|
100
|
+
style="muted warning",
|
|
101
101
|
)
|
|
102
102
|
return token # Return anyway for refresh attempt
|
|
103
103
|
|
|
104
104
|
if token.expires_soon():
|
|
105
105
|
console.print(
|
|
106
106
|
f"OAuth token for {provider} expires soon, consider refreshing",
|
|
107
|
-
style="
|
|
107
|
+
style="muted warning",
|
|
108
108
|
)
|
|
109
109
|
|
|
110
110
|
return token
|
|
@@ -126,7 +126,7 @@ class OAuthTokenManager:
|
|
|
126
126
|
logger.error(f"Failed to store OAuth token for {provider}: {e}")
|
|
127
127
|
console.print(
|
|
128
128
|
f"Warning: Could not store OAuth token in keyring: {e}",
|
|
129
|
-
style="
|
|
129
|
+
style="warning",
|
|
130
130
|
)
|
|
131
131
|
return False
|
|
132
132
|
|
|
@@ -137,7 +137,7 @@ class OAuthTokenManager:
|
|
|
137
137
|
existing_token = self.get_oauth_token(provider)
|
|
138
138
|
if not existing_token:
|
|
139
139
|
console.print(
|
|
140
|
-
f"No existing OAuth token found for {provider}", style="
|
|
140
|
+
f"No existing OAuth token found for {provider}", style="warning"
|
|
141
141
|
)
|
|
142
142
|
return False
|
|
143
143
|
|
|
@@ -161,7 +161,9 @@ class OAuthTokenManager:
|
|
|
161
161
|
return True
|
|
162
162
|
except Exception as e:
|
|
163
163
|
logger.error(f"Failed to remove OAuth token for {provider}: {e}")
|
|
164
|
-
console.print(
|
|
164
|
+
console.print(
|
|
165
|
+
f"Warning: Could not remove OAuth token: {e}", style="warning"
|
|
166
|
+
)
|
|
165
167
|
return False
|
|
166
168
|
|
|
167
169
|
def has_oauth_token(self, provider: str) -> bool:
|
sqlsaber/database/schema.py
CHANGED
|
@@ -158,7 +158,7 @@ class SchemaManager:
|
|
|
158
158
|
table["schema"] = table["table_schema"]
|
|
159
159
|
table["type"] = table["table_type"] # Map table_type to type for display
|
|
160
160
|
|
|
161
|
-
return {"tables": tables_list}
|
|
161
|
+
return {"tables": tables_list, "total_tables": len(tables_list)}
|
|
162
162
|
|
|
163
163
|
async def close(self):
|
|
164
164
|
"""Close database connection."""
|
sqlsaber/theme/manager.py
CHANGED
|
@@ -9,7 +9,7 @@ from typing import Dict
|
|
|
9
9
|
from platformdirs import user_config_dir
|
|
10
10
|
from prompt_toolkit.styles import Style as PTStyle
|
|
11
11
|
from prompt_toolkit.styles.pygments import style_from_pygments_cls
|
|
12
|
-
from pygments.styles import get_style_by_name
|
|
12
|
+
from pygments.styles import get_all_styles, get_style_by_name
|
|
13
13
|
from pygments.token import Token
|
|
14
14
|
from pygments.util import ClassNotFound
|
|
15
15
|
from rich.console import Console
|
|
@@ -18,14 +18,6 @@ from rich.theme import Theme
|
|
|
18
18
|
DEFAULT_THEME_NAME = "nord"
|
|
19
19
|
|
|
20
20
|
DEFAULT_ROLE_PALETTE = {
|
|
21
|
-
# base roles
|
|
22
|
-
"primary": "cyan",
|
|
23
|
-
"accent": "magenta",
|
|
24
|
-
"success": "green",
|
|
25
|
-
"warning": "yellow",
|
|
26
|
-
"error": "red",
|
|
27
|
-
"info": "cyan",
|
|
28
|
-
"muted": "dim",
|
|
29
21
|
# components
|
|
30
22
|
"table.header": "bold $primary",
|
|
31
23
|
"panel.border.user": "$info",
|
|
@@ -208,6 +200,9 @@ def get_theme_manager() -> ThemeManager:
|
|
|
208
200
|
user_cfg = _load_user_theme_config()
|
|
209
201
|
env_name = os.getenv("SQLSABER_THEME")
|
|
210
202
|
|
|
203
|
+
if env_name and env_name.lower() not in get_all_styles():
|
|
204
|
+
env_name = None
|
|
205
|
+
|
|
211
206
|
name = (
|
|
212
207
|
env_name or user_cfg.get("theme", {}).get("name") or DEFAULT_THEME_NAME
|
|
213
208
|
).lower()
|
sqlsaber/tools/__init__.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
"""SQLSaber tools module."""
|
|
2
2
|
|
|
3
3
|
from .base import Tool
|
|
4
|
-
from .enums import ToolCategory, WorkflowPosition
|
|
5
|
-
from .instructions import InstructionBuilder
|
|
6
4
|
from .registry import ToolRegistry, register_tool, tool_registry
|
|
7
5
|
|
|
8
6
|
# Import concrete tools to register them
|
|
@@ -10,12 +8,9 @@ from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQL
|
|
|
10
8
|
|
|
11
9
|
__all__ = [
|
|
12
10
|
"Tool",
|
|
13
|
-
"ToolCategory",
|
|
14
|
-
"WorkflowPosition",
|
|
15
11
|
"ToolRegistry",
|
|
16
12
|
"tool_registry",
|
|
17
13
|
"register_tool",
|
|
18
|
-
"InstructionBuilder",
|
|
19
14
|
"SQLTool",
|
|
20
15
|
"ListTablesTool",
|
|
21
16
|
"IntrospectSchemaTool",
|
sqlsaber/tools/base.py
CHANGED
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from typing import Any
|
|
5
5
|
|
|
6
|
-
from .enums import ToolCategory, WorkflowPosition
|
|
7
|
-
|
|
8
6
|
|
|
9
7
|
class Tool(ABC):
|
|
10
8
|
"""Abstract base class for all tools."""
|
|
@@ -42,32 +40,3 @@ class Tool(ABC):
|
|
|
42
40
|
JSON string with the tool's output
|
|
43
41
|
"""
|
|
44
42
|
pass
|
|
45
|
-
|
|
46
|
-
@property
|
|
47
|
-
def category(self) -> ToolCategory:
|
|
48
|
-
"""Return the tool category. Override to customize."""
|
|
49
|
-
return ToolCategory.GENERAL
|
|
50
|
-
|
|
51
|
-
def get_usage_instructions(self) -> str | None:
|
|
52
|
-
"""Return tool-specific usage instructions for LLM guidance.
|
|
53
|
-
|
|
54
|
-
Returns:
|
|
55
|
-
Usage instructions string, or None for no specific guidance
|
|
56
|
-
"""
|
|
57
|
-
return None
|
|
58
|
-
|
|
59
|
-
def get_priority(self) -> int:
|
|
60
|
-
"""Return priority for tool ordering in instructions.
|
|
61
|
-
|
|
62
|
-
Returns:
|
|
63
|
-
Priority number (lower = higher priority, default = 100)
|
|
64
|
-
"""
|
|
65
|
-
return 100
|
|
66
|
-
|
|
67
|
-
def get_workflow_position(self) -> WorkflowPosition:
|
|
68
|
-
"""Return the typical workflow position for this tool.
|
|
69
|
-
|
|
70
|
-
Returns:
|
|
71
|
-
WorkflowPosition enum value
|
|
72
|
-
"""
|
|
73
|
-
return WorkflowPosition.OTHER
|
sqlsaber/tools/registry.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from typing import Type
|
|
4
4
|
|
|
5
5
|
from .base import Tool
|
|
6
|
-
from .enums import ToolCategory
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class ToolRegistry:
|
|
@@ -61,45 +60,13 @@ class ToolRegistry:
|
|
|
61
60
|
|
|
62
61
|
return self._instances[name]
|
|
63
62
|
|
|
64
|
-
def list_tools(self
|
|
65
|
-
"""List all registered tool names.
|
|
63
|
+
def list_tools(self) -> list[str]:
|
|
64
|
+
"""List all registered tool names."""
|
|
65
|
+
return list(self._tools.keys())
|
|
66
66
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
Returns:
|
|
71
|
-
List of tool names
|
|
72
|
-
"""
|
|
73
|
-
if category is None:
|
|
74
|
-
return list(self._tools.keys())
|
|
75
|
-
|
|
76
|
-
# Convert string to enum
|
|
77
|
-
if isinstance(category, str):
|
|
78
|
-
try:
|
|
79
|
-
category = ToolCategory(category)
|
|
80
|
-
except ValueError:
|
|
81
|
-
# If string doesn't match any enum, return empty list
|
|
82
|
-
return []
|
|
83
|
-
|
|
84
|
-
# Filter by category
|
|
85
|
-
result = []
|
|
86
|
-
for name, tool_class in self._tools.items():
|
|
87
|
-
tool = self.get_tool(name)
|
|
88
|
-
if tool.category == category:
|
|
89
|
-
result.append(name)
|
|
90
|
-
return result
|
|
91
|
-
|
|
92
|
-
def get_all_tools(self, category: str | ToolCategory | None = None) -> list[Tool]:
|
|
93
|
-
"""Get all tool instances.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
category: Optional category to filter by (string or ToolCategory enum)
|
|
97
|
-
|
|
98
|
-
Returns:
|
|
99
|
-
List of tool instances
|
|
100
|
-
"""
|
|
101
|
-
names = self.list_tools(category)
|
|
102
|
-
return [self.get_tool(name) for name in names]
|
|
67
|
+
def get_all_tools(self) -> list[Tool]:
|
|
68
|
+
"""Get all tool instances."""
|
|
69
|
+
return [self.get_tool(name) for name in self.list_tools()]
|
|
103
70
|
|
|
104
71
|
|
|
105
72
|
# Global registry instance
|