sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -4
- sqlsaber/agents/base.py +18 -221
- sqlsaber/agents/mcp.py +2 -2
- sqlsaber/agents/pydantic_ai_agent.py +170 -0
- sqlsaber/cli/auth.py +146 -79
- sqlsaber/cli/commands.py +22 -7
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/interactive.py +65 -30
- sqlsaber/cli/models.py +58 -29
- sqlsaber/cli/streaming.py +114 -77
- sqlsaber/config/api_keys.py +9 -11
- sqlsaber/config/providers.py +116 -0
- sqlsaber/config/settings.py +50 -30
- sqlsaber/database/connection.py +3 -3
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/models/__init__.py +0 -3
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +85 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
- sqlsaber-0.16.0.dist-info/RECORD +51 -0
- sqlsaber/agents/anthropic.py +0 -579
- sqlsaber/agents/streaming.py +0 -16
- sqlsaber/clients/__init__.py +0 -6
- sqlsaber/clients/anthropic.py +0 -285
- sqlsaber/clients/base.py +0 -31
- sqlsaber/clients/exceptions.py +0 -117
- sqlsaber/clients/models.py +0 -282
- sqlsaber/clients/streaming.py +0 -257
- sqlsaber/models/events.py +0 -28
- sqlsaber-0.14.0.dist-info/RECORD +0 -51
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/models.py
CHANGED
|
@@ -3,12 +3,13 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import sys
|
|
5
5
|
|
|
6
|
+
import cyclopts
|
|
6
7
|
import httpx
|
|
7
8
|
import questionary
|
|
8
|
-
import cyclopts
|
|
9
9
|
from rich.console import Console
|
|
10
10
|
from rich.table import Table
|
|
11
11
|
|
|
12
|
+
from sqlsaber.config import providers
|
|
12
13
|
from sqlsaber.config.settings import Config
|
|
13
14
|
|
|
14
15
|
# Global instances for CLI commands
|
|
@@ -26,49 +27,75 @@ class ModelManager:
|
|
|
26
27
|
|
|
27
28
|
DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
|
|
28
29
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
|
+
# Providers come from central registry
|
|
31
|
+
SUPPORTED_PROVIDERS = providers.all_keys()
|
|
32
|
+
|
|
33
|
+
async def fetch_available_models(
|
|
34
|
+
self, providers: list[str] | None = None
|
|
35
|
+
) -> list[dict]:
|
|
36
|
+
"""Fetch available models across providers from models.dev API.
|
|
29
37
|
|
|
30
|
-
|
|
31
|
-
"""
|
|
38
|
+
Returns list of dicts with keys: id (provider:model_id), provider, name, description, context_length, knowledge.
|
|
39
|
+
"""
|
|
32
40
|
try:
|
|
33
41
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
34
42
|
response = await client.get(self.MODELS_API_URL)
|
|
35
43
|
response.raise_for_status()
|
|
36
44
|
data = response.json()
|
|
37
45
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
providers = providers or self.SUPPORTED_PROVIDERS
|
|
47
|
+
results: list[dict] = []
|
|
48
|
+
|
|
49
|
+
for provider in providers:
|
|
50
|
+
prov_data = data.get(provider, {})
|
|
51
|
+
models_obj = (
|
|
52
|
+
prov_data.get("models") or prov_data.get("Models") or {}
|
|
53
|
+
)
|
|
54
|
+
if not isinstance(models_obj, dict):
|
|
55
|
+
continue
|
|
56
|
+
for model_id, model_info in models_obj.items():
|
|
57
|
+
formatted_id = f"{provider}:{model_id}"
|
|
58
|
+
# cost
|
|
59
|
+
cost_info = (
|
|
60
|
+
model_info.get("cost", {})
|
|
61
|
+
if isinstance(model_info, dict)
|
|
62
|
+
else {}
|
|
63
|
+
)
|
|
49
64
|
cost_display = ""
|
|
50
|
-
if cost_info:
|
|
65
|
+
if isinstance(cost_info, dict) and cost_info:
|
|
51
66
|
input_cost = cost_info.get("input", 0)
|
|
52
67
|
output_cost = cost_info.get("output", 0)
|
|
53
68
|
cost_display = f"${input_cost}/{output_cost} per 1M tokens"
|
|
69
|
+
# context
|
|
70
|
+
limit_info = (
|
|
71
|
+
model_info.get("limit", {})
|
|
72
|
+
if isinstance(model_info, dict)
|
|
73
|
+
else {}
|
|
74
|
+
)
|
|
75
|
+
context_length = (
|
|
76
|
+
limit_info.get("context", 0)
|
|
77
|
+
if isinstance(limit_info, dict)
|
|
78
|
+
else 0
|
|
79
|
+
)
|
|
54
80
|
|
|
55
|
-
|
|
56
|
-
limit_info = model_info.get("limit", {})
|
|
57
|
-
context_length = limit_info.get("context", 0)
|
|
58
|
-
|
|
59
|
-
anthropic_models.append(
|
|
81
|
+
results.append(
|
|
60
82
|
{
|
|
61
83
|
"id": formatted_id,
|
|
62
|
-
"
|
|
84
|
+
"provider": provider,
|
|
85
|
+
"name": model_info.get("name", model_id)
|
|
86
|
+
if isinstance(model_info, dict)
|
|
87
|
+
else model_id,
|
|
63
88
|
"description": cost_display,
|
|
64
89
|
"context_length": context_length,
|
|
65
|
-
"knowledge": model_info.get("knowledge", "")
|
|
90
|
+
"knowledge": model_info.get("knowledge", "")
|
|
91
|
+
if isinstance(model_info, dict)
|
|
92
|
+
else "",
|
|
66
93
|
}
|
|
67
94
|
)
|
|
68
95
|
|
|
69
|
-
# Sort by
|
|
70
|
-
|
|
71
|
-
return
|
|
96
|
+
# Sort by provider then by name
|
|
97
|
+
results.sort(key=lambda x: (x["provider"], x["name"]))
|
|
98
|
+
return results
|
|
72
99
|
except Exception as e:
|
|
73
100
|
console.print(f"[red]Error fetching models: {e}[/red]")
|
|
74
101
|
return []
|
|
@@ -110,7 +137,8 @@ def list():
|
|
|
110
137
|
)
|
|
111
138
|
return
|
|
112
139
|
|
|
113
|
-
table = Table(title="Available
|
|
140
|
+
table = Table(title="Available Models")
|
|
141
|
+
table.add_column("Provider", style="magenta")
|
|
114
142
|
table.add_column("ID", style="cyan")
|
|
115
143
|
table.add_column("Name", style="green")
|
|
116
144
|
table.add_column("Description", style="white")
|
|
@@ -133,6 +161,7 @@ def list():
|
|
|
133
161
|
)
|
|
134
162
|
|
|
135
163
|
table.add_row(
|
|
164
|
+
model.get("provider", "-"),
|
|
136
165
|
model["id"],
|
|
137
166
|
model["name"],
|
|
138
167
|
description,
|
|
@@ -161,8 +190,9 @@ def set():
|
|
|
161
190
|
# Create choices for questionary
|
|
162
191
|
choices = []
|
|
163
192
|
for model in models:
|
|
164
|
-
# Format: "ID - Name (Description)"
|
|
165
|
-
|
|
193
|
+
# Format: "[provider] ID - Name (Description)"
|
|
194
|
+
prov = model.get("provider", "?")
|
|
195
|
+
choice_text = f"[{prov}] {model['id']} - {model['name']}"
|
|
166
196
|
if model["description"]:
|
|
167
197
|
choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
|
|
168
198
|
|
|
@@ -179,7 +209,6 @@ def set():
|
|
|
179
209
|
selected_model = await questionary.select(
|
|
180
210
|
"Select a model:",
|
|
181
211
|
choices=choices,
|
|
182
|
-
use_shortcuts=True,
|
|
183
212
|
use_search_filter=True,
|
|
184
213
|
use_jk_keys=False, # Disable j/k keys when using search filter
|
|
185
214
|
default=choices[default_index] if choices else None,
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -1,100 +1,137 @@
|
|
|
1
|
-
"""Streaming query handling for the CLI."""
|
|
1
|
+
"""Streaming query handling for the CLI (pydantic-ai based)."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
|
|
4
|
+
import json
|
|
5
|
+
from typing import AsyncIterable
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import Agent, RunContext
|
|
8
|
+
from pydantic_ai.messages import (
|
|
9
|
+
AgentStreamEvent,
|
|
10
|
+
FunctionToolCallEvent,
|
|
11
|
+
FunctionToolResultEvent,
|
|
12
|
+
PartDeltaEvent,
|
|
13
|
+
PartStartEvent,
|
|
14
|
+
TextPart,
|
|
15
|
+
TextPartDelta,
|
|
16
|
+
ThinkingPart,
|
|
17
|
+
ThinkingPartDelta,
|
|
18
|
+
)
|
|
5
19
|
from rich.console import Console
|
|
6
20
|
|
|
7
|
-
from sqlsaber.agents.base import BaseSQLAgent
|
|
8
21
|
from sqlsaber.cli.display import DisplayManager
|
|
9
22
|
|
|
10
23
|
|
|
11
24
|
class StreamingQueryHandler:
|
|
12
|
-
"""Handles streaming query execution and display."""
|
|
25
|
+
"""Handles streaming query execution and display using pydantic-ai events."""
|
|
13
26
|
|
|
14
27
|
def __init__(self, console: Console):
|
|
15
28
|
self.console = console
|
|
16
29
|
self.display = DisplayManager(console)
|
|
17
30
|
|
|
31
|
+
self.status = self.console.status(
|
|
32
|
+
"[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
async def _event_stream_handler(
|
|
36
|
+
self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
|
|
37
|
+
) -> None:
|
|
38
|
+
async for event in event_stream:
|
|
39
|
+
if isinstance(event, PartStartEvent):
|
|
40
|
+
if isinstance(event.part, (TextPart, ThinkingPart)):
|
|
41
|
+
self.status.stop()
|
|
42
|
+
self.display.show_text_stream(event.part.content)
|
|
43
|
+
|
|
44
|
+
elif isinstance(event, PartDeltaEvent):
|
|
45
|
+
if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
|
|
46
|
+
delta = event.delta.content_delta or ""
|
|
47
|
+
if delta:
|
|
48
|
+
self.status.stop()
|
|
49
|
+
self.display.show_text_stream(delta)
|
|
50
|
+
|
|
51
|
+
elif isinstance(event, FunctionToolCallEvent):
|
|
52
|
+
# Show tool execution start
|
|
53
|
+
self.status.stop()
|
|
54
|
+
args = event.part.args_as_dict()
|
|
55
|
+
self.display.show_newline()
|
|
56
|
+
self.display.show_tool_executing(event.part.tool_name, args)
|
|
57
|
+
|
|
58
|
+
elif isinstance(event, FunctionToolResultEvent):
|
|
59
|
+
self.status.stop()
|
|
60
|
+
# Route tool result to appropriate display
|
|
61
|
+
tool_name = event.result.tool_name
|
|
62
|
+
content = event.result.content
|
|
63
|
+
if tool_name == "list_tables":
|
|
64
|
+
self.display.show_table_list(content)
|
|
65
|
+
elif tool_name == "introspect_schema":
|
|
66
|
+
self.display.show_schema_info(content)
|
|
67
|
+
elif tool_name == "execute_sql":
|
|
68
|
+
try:
|
|
69
|
+
data = json.loads(content)
|
|
70
|
+
if data.get("success") and data.get("results"):
|
|
71
|
+
self.display.show_query_results(data["results"]) # type: ignore[arg-type]
|
|
72
|
+
except json.JSONDecodeError:
|
|
73
|
+
# If not JSON, ignore here
|
|
74
|
+
pass
|
|
75
|
+
elif tool_name == "plot_data":
|
|
76
|
+
self.display.show_plot(
|
|
77
|
+
{"tool_name": tool_name, "result": content, "input": {}}
|
|
78
|
+
)
|
|
79
|
+
|
|
18
80
|
async def execute_streaming_query(
|
|
19
81
|
self,
|
|
20
82
|
user_query: str,
|
|
21
|
-
agent:
|
|
83
|
+
agent: Agent,
|
|
22
84
|
cancellation_token: asyncio.Event | None = None,
|
|
85
|
+
message_history: list | None = None,
|
|
23
86
|
):
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
status = self.console.status(
|
|
27
|
-
"[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
|
|
28
|
-
)
|
|
29
|
-
status.start()
|
|
30
|
-
|
|
87
|
+
self.status.start()
|
|
31
88
|
try:
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
self.display.
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
elif event.type == "error":
|
|
74
|
-
self._stop_status(status)
|
|
75
|
-
self.display.show_error(event.data)
|
|
76
|
-
|
|
89
|
+
# If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
|
|
90
|
+
prepared_prompt: str | list[str] = user_query
|
|
91
|
+
is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
|
|
92
|
+
no_history = not message_history
|
|
93
|
+
if is_oauth and no_history:
|
|
94
|
+
ib = getattr(agent, "_sqlsaber_instruction_builder", None)
|
|
95
|
+
mm = getattr(agent, "_sqlsaber_memory_manager", None)
|
|
96
|
+
db_type = getattr(agent, "_sqlsaber_db_type", "database")
|
|
97
|
+
db_name = getattr(agent, "_sqlsaber_database_name", None)
|
|
98
|
+
instructions = (
|
|
99
|
+
ib.build_instructions(db_type=db_type) if ib is not None else ""
|
|
100
|
+
)
|
|
101
|
+
mem = (
|
|
102
|
+
mm.format_memories_for_prompt(db_name)
|
|
103
|
+
if (mm is not None and db_name)
|
|
104
|
+
else ""
|
|
105
|
+
)
|
|
106
|
+
parts = [p for p in (instructions, mem) if p and str(p).strip()]
|
|
107
|
+
if parts:
|
|
108
|
+
injected = "\n\n".join(parts)
|
|
109
|
+
prepared_prompt = [injected, user_query]
|
|
110
|
+
|
|
111
|
+
# Run the agent with our event stream handler
|
|
112
|
+
run = await agent.run(
|
|
113
|
+
prepared_prompt,
|
|
114
|
+
message_history=message_history,
|
|
115
|
+
event_stream_handler=self._event_stream_handler,
|
|
116
|
+
)
|
|
117
|
+
# After the run completes, show the assistant's final text as markdown if available
|
|
118
|
+
try:
|
|
119
|
+
output = run.output
|
|
120
|
+
if isinstance(output, str) and output.strip():
|
|
121
|
+
self.display.show_newline()
|
|
122
|
+
self.display.show_markdown_response(
|
|
123
|
+
[{"type": "text", "text": output}]
|
|
124
|
+
)
|
|
125
|
+
except Exception as e:
|
|
126
|
+
self.display.show_error(str(e))
|
|
127
|
+
self.display.show_newline()
|
|
128
|
+
return run
|
|
77
129
|
except asyncio.CancelledError:
|
|
78
|
-
# Handle cancellation gracefully
|
|
79
|
-
self._stop_status(status)
|
|
80
130
|
self.display.show_newline()
|
|
81
131
|
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
82
|
-
return
|
|
132
|
+
return None
|
|
83
133
|
finally:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if hasattr(agent, "conversation_history") and agent.conversation_history:
|
|
89
|
-
last_message = agent.conversation_history[-1]
|
|
90
|
-
if last_message.get("role") == "assistant" and last_message.get(
|
|
91
|
-
"content"
|
|
92
|
-
):
|
|
93
|
-
self.display.show_markdown_response(last_message["content"])
|
|
94
|
-
|
|
95
|
-
def _stop_status(self, status):
|
|
96
|
-
"""Safely stop a status spinner."""
|
|
97
|
-
try:
|
|
98
|
-
status.stop()
|
|
99
|
-
except Exception:
|
|
100
|
-
pass # Status might already be stopped
|
|
134
|
+
try:
|
|
135
|
+
self.status.stop()
|
|
136
|
+
except Exception:
|
|
137
|
+
pass
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -6,6 +6,8 @@ import os
|
|
|
6
6
|
import keyring
|
|
7
7
|
from rich.console import Console
|
|
8
8
|
|
|
9
|
+
from sqlsaber.config import providers
|
|
10
|
+
|
|
9
11
|
console = Console()
|
|
10
12
|
|
|
11
13
|
|
|
@@ -30,9 +32,7 @@ class APIKeyManager:
|
|
|
30
32
|
try:
|
|
31
33
|
api_key = keyring.get_password(service_name, provider)
|
|
32
34
|
if api_key:
|
|
33
|
-
console.print(
|
|
34
|
-
f"Using stored {provider} API key from keyring", style="dim"
|
|
35
|
-
)
|
|
35
|
+
console.print(f"Using stored {provider} API key", style="dim")
|
|
36
36
|
return api_key
|
|
37
37
|
except Exception as e:
|
|
38
38
|
# Keyring access failed, continue to prompt
|
|
@@ -43,12 +43,9 @@ class APIKeyManager:
|
|
|
43
43
|
|
|
44
44
|
def _get_env_var_name(self, provider: str) -> str:
|
|
45
45
|
"""Get the expected environment variable name for a provider."""
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
return "ANTHROPIC_API_KEY"
|
|
50
|
-
else:
|
|
51
|
-
return "AI_API_KEY"
|
|
46
|
+
# Normalize aliases to canonical provider keys
|
|
47
|
+
key = providers.canonical(provider) or provider
|
|
48
|
+
return providers.env_var_name(key)
|
|
52
49
|
|
|
53
50
|
def _get_service_name(self, provider: str) -> str:
|
|
54
51
|
"""Get the keyring service name for a provider."""
|
|
@@ -60,7 +57,7 @@ class APIKeyManager:
|
|
|
60
57
|
"""Prompt user for API key and store it in keyring."""
|
|
61
58
|
try:
|
|
62
59
|
console.print(
|
|
63
|
-
f"\n{provider.title()} API key not found in environment or
|
|
60
|
+
f"\n{provider.title()} API key not found in environment or your OS's credentials store."
|
|
64
61
|
)
|
|
65
62
|
console.print("You can either:")
|
|
66
63
|
console.print(f" 1. Set the {env_var_name} environment variable")
|
|
@@ -85,7 +82,8 @@ class APIKeyManager:
|
|
|
85
82
|
console.print("API key stored securely for future use", style="green")
|
|
86
83
|
except Exception as e:
|
|
87
84
|
console.print(
|
|
88
|
-
f"Warning: Could not store API key in
|
|
85
|
+
f"Warning: Could not store API key in your operating system's credentials store: {e}",
|
|
86
|
+
style="yellow",
|
|
89
87
|
)
|
|
90
88
|
console.print(
|
|
91
89
|
"You may need to enter it again next time", style="yellow"
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Central registry for supported AI providers.
|
|
2
|
+
|
|
3
|
+
This module defines a single source of truth for providers used across the
|
|
4
|
+
codebase (CLI, config, agents). Update this file to add or modify providers.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Dict, Iterable, List, Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
|
|
14
|
+
class ProviderSpec:
|
|
15
|
+
"""Specification for a provider."""
|
|
16
|
+
|
|
17
|
+
key: str
|
|
18
|
+
env_var: str
|
|
19
|
+
supports_oauth: bool = False
|
|
20
|
+
aliases: tuple[str, ...] = ()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Ordered definition -> used for CLI display order
|
|
24
|
+
_PROVIDERS: List[ProviderSpec] = [
|
|
25
|
+
ProviderSpec(
|
|
26
|
+
key="anthropic",
|
|
27
|
+
env_var="ANTHROPIC_API_KEY",
|
|
28
|
+
supports_oauth=True,
|
|
29
|
+
aliases=(),
|
|
30
|
+
),
|
|
31
|
+
ProviderSpec(
|
|
32
|
+
key="openai",
|
|
33
|
+
env_var="OPENAI_API_KEY",
|
|
34
|
+
aliases=(),
|
|
35
|
+
),
|
|
36
|
+
ProviderSpec(
|
|
37
|
+
key="google",
|
|
38
|
+
env_var="GOOGLE_API_KEY",
|
|
39
|
+
# Historically some model IDs start with "google-gla"; treat as alias
|
|
40
|
+
aliases=("google-gla",),
|
|
41
|
+
),
|
|
42
|
+
ProviderSpec(
|
|
43
|
+
key="groq",
|
|
44
|
+
env_var="GROQ_API_KEY",
|
|
45
|
+
aliases=(),
|
|
46
|
+
),
|
|
47
|
+
ProviderSpec(
|
|
48
|
+
key="mistral",
|
|
49
|
+
env_var="MISTRAL_API_KEY",
|
|
50
|
+
aliases=(),
|
|
51
|
+
),
|
|
52
|
+
ProviderSpec(
|
|
53
|
+
key="cohere",
|
|
54
|
+
env_var="COHERE_API_KEY",
|
|
55
|
+
aliases=(),
|
|
56
|
+
),
|
|
57
|
+
ProviderSpec(
|
|
58
|
+
key="huggingface",
|
|
59
|
+
env_var="HUGGINGFACE_API_KEY",
|
|
60
|
+
aliases=(),
|
|
61
|
+
),
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Fast lookup maps
|
|
66
|
+
_BY_KEY: Dict[str, ProviderSpec] = {p.key: p for p in _PROVIDERS}
|
|
67
|
+
_ALIAS_TO_KEY: Dict[str, str] = {
|
|
68
|
+
alias: p.key for p in _PROVIDERS for alias in p.aliases
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def all_keys() -> List[str]:
|
|
73
|
+
"""Return provider keys in display order."""
|
|
74
|
+
return [p.key for p in _PROVIDERS]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def env_var_name(key: str) -> str:
|
|
78
|
+
"""Return the expected environment variable for a provider.
|
|
79
|
+
|
|
80
|
+
Falls back to a generic name if the provider is unknown.
|
|
81
|
+
"""
|
|
82
|
+
spec = _BY_KEY.get(key)
|
|
83
|
+
return spec.env_var if spec else "AI_API_KEY"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def supports_oauth(key: str) -> bool:
|
|
87
|
+
"""Return True if the provider supports OAuth in SQLsaber."""
|
|
88
|
+
spec = _BY_KEY.get(key)
|
|
89
|
+
return bool(spec and spec.supports_oauth)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def canonical(key_or_alias: str) -> Optional[str]:
|
|
93
|
+
"""Return the canonical provider key for a provider or alias.
|
|
94
|
+
|
|
95
|
+
Returns None if not recognized.
|
|
96
|
+
"""
|
|
97
|
+
if key_or_alias in _BY_KEY:
|
|
98
|
+
return key_or_alias
|
|
99
|
+
return _ALIAS_TO_KEY.get(key_or_alias)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def provider_from_model(model_name: str) -> Optional[str]:
|
|
103
|
+
"""Infer the canonical provider key from a model identifier.
|
|
104
|
+
|
|
105
|
+
Accepts either "provider:model_id" or a bare provider string. Aliases are
|
|
106
|
+
normalized to their canonical provider key.
|
|
107
|
+
"""
|
|
108
|
+
if not model_name:
|
|
109
|
+
return None
|
|
110
|
+
provider_raw = model_name.split(":", 1)[0]
|
|
111
|
+
return canonical(provider_raw)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def specs() -> Iterable[ProviderSpec]:
|
|
115
|
+
"""Iterate provider specifications (in display order)."""
|
|
116
|
+
return tuple(_PROVIDERS)
|
sqlsaber/config/settings.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Any
|
|
|
9
9
|
|
|
10
10
|
import platformdirs
|
|
11
11
|
|
|
12
|
+
from sqlsaber.config import providers
|
|
12
13
|
from sqlsaber.config.api_keys import APIKeyManager
|
|
13
14
|
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
14
15
|
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
@@ -84,47 +85,66 @@ class Config:
|
|
|
84
85
|
self.model_name = self.model_config_manager.get_model()
|
|
85
86
|
self.api_key_manager = APIKeyManager()
|
|
86
87
|
self.auth_config_manager = AuthConfigManager()
|
|
87
|
-
self.oauth_flow = AnthropicOAuthFlow()
|
|
88
88
|
|
|
89
|
-
#
|
|
89
|
+
# Authentication method (API key or Anthropic OAuth)
|
|
90
90
|
self.auth_method = self.auth_config_manager.get_auth_method()
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
token = self.oauth_flow.refresh_token_if_needed()
|
|
98
|
-
if token:
|
|
99
|
-
self.oauth_token = token.access_token
|
|
100
|
-
except Exception:
|
|
101
|
-
# OAuth token unavailable, will need to re-authenticate
|
|
102
|
-
pass
|
|
91
|
+
|
|
92
|
+
# Optional Anthropic OAuth access token (only relevant for provider=='anthropic')
|
|
93
|
+
if self.auth_method == AuthMethod.CLAUDE_PRO and self.model_name.startswith(
|
|
94
|
+
"anthropic"
|
|
95
|
+
):
|
|
96
|
+
self.oauth_token = self.get_oauth_access_token()
|
|
103
97
|
else:
|
|
104
|
-
# Use API key authentication (default or explicitly configured)
|
|
105
98
|
self.api_key = self._get_api_key()
|
|
99
|
+
# self.oauth_token = None
|
|
106
100
|
|
|
107
101
|
def _get_api_key(self) -> str | None:
|
|
108
102
|
"""Get API key for the model provider using cascading logic."""
|
|
109
|
-
model = self.model_name
|
|
110
|
-
|
|
111
|
-
|
|
103
|
+
model = self.model_name or ""
|
|
104
|
+
prov = providers.provider_from_model(model)
|
|
105
|
+
if prov in set(providers.all_keys()):
|
|
106
|
+
return self.api_key_manager.get_api_key(prov) # type: ignore[arg-type]
|
|
107
|
+
return None
|
|
112
108
|
|
|
113
109
|
def set_model(self, model: str) -> None:
|
|
114
110
|
"""Set the model and update configuration."""
|
|
115
111
|
self.model_config_manager.set_model(model)
|
|
116
112
|
self.model_name = model
|
|
117
113
|
|
|
114
|
+
def get_oauth_access_token(self) -> str | None:
|
|
115
|
+
"""Return a valid Anthropic OAuth access token if configured, else None.
|
|
116
|
+
|
|
117
|
+
Uses the stored refresh token (if present) to refresh as needed.
|
|
118
|
+
Only relevant when provider is 'anthropic'.
|
|
119
|
+
"""
|
|
120
|
+
if not self.model_name.startswith("anthropic"):
|
|
121
|
+
return None
|
|
122
|
+
try:
|
|
123
|
+
flow = AnthropicOAuthFlow()
|
|
124
|
+
token = flow.refresh_token_if_needed()
|
|
125
|
+
return token.access_token if token else None
|
|
126
|
+
except Exception:
|
|
127
|
+
return None
|
|
128
|
+
|
|
118
129
|
def validate(self):
|
|
119
|
-
"""Validate that necessary configuration is present.
|
|
120
|
-
|
|
121
|
-
if
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
130
|
+
"""Validate that necessary configuration is present.
|
|
131
|
+
|
|
132
|
+
Also ensure provider env var is set from keyring if needed for API-key flows.
|
|
133
|
+
"""
|
|
134
|
+
model = self.model_name or ""
|
|
135
|
+
provider_key = providers.provider_from_model(model)
|
|
136
|
+
env_var = providers.env_var_name(provider_key or "") if provider_key else None
|
|
137
|
+
if env_var:
|
|
138
|
+
# Anthropic special-case: allow OAuth in lieu of API key only when explicitly configured
|
|
139
|
+
if (
|
|
140
|
+
provider_key == "anthropic"
|
|
141
|
+
and self.auth_method == AuthMethod.CLAUDE_PRO
|
|
142
|
+
and self.oauth_token
|
|
143
|
+
):
|
|
144
|
+
return
|
|
145
|
+
# If we don't have a key resolved from env/keyring, raise
|
|
146
|
+
if not self.api_key:
|
|
147
|
+
raise ValueError(f"{provider_key.capitalize()} API key not found.")
|
|
148
|
+
# Hydrate env var for downstream SDKs if missing
|
|
149
|
+
if not os.getenv(env_var):
|
|
150
|
+
os.environ[env_var] = self.api_key
|
sqlsaber/database/connection.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Database connection management."""
|
|
2
2
|
|
|
3
|
-
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import Any
|
|
5
|
-
from urllib.parse import urlparse, parse_qs
|
|
6
3
|
import ssl
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
7
5
|
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
from urllib.parse import parse_qs, urlparse
|
|
8
8
|
|
|
9
9
|
import aiomysql
|
|
10
10
|
import aiosqlite
|