sqlsaber 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -4
- sqlsaber/agents/base.py +1 -2
- sqlsaber/agents/mcp.py +2 -2
- sqlsaber/agents/pydantic_ai_agent.py +170 -0
- sqlsaber/cli/auth.py +146 -79
- sqlsaber/cli/commands.py +22 -7
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/interactive.py +65 -30
- sqlsaber/cli/models.py +58 -29
- sqlsaber/cli/streaming.py +114 -77
- sqlsaber/config/api_keys.py +9 -11
- sqlsaber/config/providers.py +116 -0
- sqlsaber/config/settings.py +50 -30
- sqlsaber/database/connection.py +3 -3
- sqlsaber/models/__init__.py +0 -3
- sqlsaber/tools/base.py +7 -5
- {sqlsaber-0.15.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
- {sqlsaber-0.15.0.dist-info → sqlsaber-0.16.0.dist-info}/RECORD +21 -28
- sqlsaber/agents/anthropic.py +0 -491
- sqlsaber/agents/streaming.py +0 -16
- sqlsaber/clients/__init__.py +0 -6
- sqlsaber/clients/anthropic.py +0 -285
- sqlsaber/clients/base.py +0 -31
- sqlsaber/clients/exceptions.py +0 -117
- sqlsaber/clients/models.py +0 -282
- sqlsaber/clients/streaming.py +0 -257
- sqlsaber/models/events.py +0 -28
- {sqlsaber-0.15.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.15.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.15.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/interactive.py
CHANGED
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
|
|
5
5
|
import questionary
|
|
6
|
+
from pydantic_ai import Agent
|
|
6
7
|
from rich.console import Console
|
|
7
8
|
from rich.panel import Panel
|
|
8
9
|
|
|
9
|
-
from sqlsaber.agents.base import BaseSQLAgent
|
|
10
10
|
from sqlsaber.cli.completers import (
|
|
11
11
|
CompositeCompleter,
|
|
12
12
|
SlashCommandCompleter,
|
|
@@ -14,25 +14,44 @@ from sqlsaber.cli.completers import (
|
|
|
14
14
|
)
|
|
15
15
|
from sqlsaber.cli.display import DisplayManager
|
|
16
16
|
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
17
|
+
from sqlsaber.database.schema import SchemaManager
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class InteractiveSession:
|
|
20
21
|
"""Manages interactive CLI sessions."""
|
|
21
22
|
|
|
22
|
-
def __init__(self, console: Console, agent:
|
|
23
|
+
def __init__(self, console: Console, agent: Agent, db_conn, database_name: str):
|
|
23
24
|
self.console = console
|
|
24
25
|
self.agent = agent
|
|
26
|
+
self.db_conn = db_conn
|
|
27
|
+
self.database_name = database_name
|
|
25
28
|
self.display = DisplayManager(console)
|
|
26
29
|
self.streaming_handler = StreamingQueryHandler(console)
|
|
27
30
|
self.current_task: asyncio.Task | None = None
|
|
28
31
|
self.cancellation_token: asyncio.Event | None = None
|
|
29
32
|
self.table_completer = TableNameCompleter()
|
|
33
|
+
self.message_history: list | None = []
|
|
30
34
|
|
|
31
35
|
def show_welcome_message(self):
|
|
32
36
|
"""Display welcome message for interactive mode."""
|
|
33
37
|
# Show database information
|
|
34
|
-
db_name =
|
|
35
|
-
|
|
38
|
+
db_name = self.database_name or "Unknown"
|
|
39
|
+
from sqlsaber.database.connection import (
|
|
40
|
+
CSVConnection,
|
|
41
|
+
MySQLConnection,
|
|
42
|
+
PostgreSQLConnection,
|
|
43
|
+
SQLiteConnection,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
db_type = (
|
|
47
|
+
"PostgreSQL"
|
|
48
|
+
if isinstance(self.db_conn, PostgreSQLConnection)
|
|
49
|
+
else "MySQL"
|
|
50
|
+
if isinstance(self.db_conn, MySQLConnection)
|
|
51
|
+
else "SQLite"
|
|
52
|
+
if isinstance(self.db_conn, (SQLiteConnection, CSVConnection))
|
|
53
|
+
else "database"
|
|
54
|
+
)
|
|
36
55
|
|
|
37
56
|
self.console.print(
|
|
38
57
|
Panel.fit(
|
|
@@ -44,26 +63,27 @@ class InteractiveSession:
|
|
|
44
63
|
███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
|
|
45
64
|
▀▀
|
|
46
65
|
"""
|
|
47
|
-
"\n\n"
|
|
48
|
-
"[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
|
|
49
|
-
"[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
|
|
50
|
-
"[dim]Type '@' to get table name completions.[/dim]",
|
|
51
|
-
border_style="green",
|
|
52
66
|
)
|
|
53
67
|
)
|
|
54
68
|
self.console.print(
|
|
55
|
-
|
|
69
|
+
"\n",
|
|
70
|
+
"[dim] ≥ Use '/clear' to reset conversation",
|
|
71
|
+
"[dim] ≥ Use '/exit' or '/quit' to leave[/dim]",
|
|
72
|
+
"[dim] ≥ Use 'Ctrl+C' to interrupt and return to prompt\n\n",
|
|
73
|
+
"[dim] ≥ Start message with '#' to add something to agent's memory for this database",
|
|
74
|
+
"[dim] ≥ Type '@' to get table name completions",
|
|
75
|
+
"[dim] ≥ Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
|
|
76
|
+
sep="\n",
|
|
56
77
|
)
|
|
78
|
+
|
|
57
79
|
self.console.print(
|
|
58
|
-
"[
|
|
59
|
-
"[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
|
|
80
|
+
f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
|
|
60
81
|
)
|
|
61
82
|
|
|
62
83
|
async def _update_table_cache(self):
|
|
63
84
|
"""Update the table completer cache with fresh data."""
|
|
64
85
|
try:
|
|
65
|
-
|
|
66
|
-
tables_data = await self.agent.schema_manager.list_tables()
|
|
86
|
+
tables_data = await SchemaManager(self.db_conn).list_tables()
|
|
67
87
|
|
|
68
88
|
# Parse the table information
|
|
69
89
|
table_list = []
|
|
@@ -100,16 +120,20 @@ class InteractiveSession:
|
|
|
100
120
|
# Create the query task
|
|
101
121
|
query_task = asyncio.create_task(
|
|
102
122
|
self.streaming_handler.execute_streaming_query(
|
|
103
|
-
user_query, self.agent, self.cancellation_token
|
|
123
|
+
user_query, self.agent, self.cancellation_token, self.message_history
|
|
104
124
|
)
|
|
105
125
|
)
|
|
106
126
|
self.current_task = query_task
|
|
107
127
|
|
|
108
128
|
try:
|
|
109
|
-
|
|
110
|
-
#
|
|
111
|
-
|
|
112
|
-
|
|
129
|
+
run_result = await query_task
|
|
130
|
+
# Persist message history from this run using pydantic-ai API
|
|
131
|
+
if run_result is not None:
|
|
132
|
+
try:
|
|
133
|
+
# Use all_messages() so the system prompt and all prior turns are preserved
|
|
134
|
+
self.message_history = run_result.all_messages()
|
|
135
|
+
except Exception:
|
|
136
|
+
pass
|
|
113
137
|
finally:
|
|
114
138
|
self.current_task = None
|
|
115
139
|
self.cancellation_token = None
|
|
@@ -144,7 +168,8 @@ class InteractiveSession:
|
|
|
144
168
|
break
|
|
145
169
|
|
|
146
170
|
if user_query == "/clear":
|
|
147
|
-
|
|
171
|
+
# Reset local history (pydantic-ai call will receive empty history on next run)
|
|
172
|
+
self.message_history = []
|
|
148
173
|
self.console.print("[green]Conversation history cleared.[/green]\n")
|
|
149
174
|
continue
|
|
150
175
|
|
|
@@ -153,18 +178,28 @@ class InteractiveSession:
|
|
|
153
178
|
if memory_text.startswith("#"):
|
|
154
179
|
memory_content = memory_text[1:].strip() # Remove # and trim
|
|
155
180
|
if memory_content:
|
|
156
|
-
# Add memory
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
f"[green]✓ Memory added:[/green] {memory_content}"
|
|
161
|
-
)
|
|
162
|
-
self.console.print(
|
|
163
|
-
f"[dim]Memory ID: {memory_id}[/dim]\n"
|
|
181
|
+
# Add memory via the agent's memory manager
|
|
182
|
+
try:
|
|
183
|
+
mm = getattr(
|
|
184
|
+
self.agent, "_sqlsaber_memory_manager", None
|
|
164
185
|
)
|
|
165
|
-
|
|
186
|
+
if mm and self.database_name:
|
|
187
|
+
memory = mm.add_memory(
|
|
188
|
+
self.database_name, memory_content
|
|
189
|
+
)
|
|
190
|
+
self.console.print(
|
|
191
|
+
f"[green]✓ Memory added:[/green] {memory_content}"
|
|
192
|
+
)
|
|
193
|
+
self.console.print(
|
|
194
|
+
f"[dim]Memory ID: {memory.id}[/dim]\n"
|
|
195
|
+
)
|
|
196
|
+
else:
|
|
197
|
+
self.console.print(
|
|
198
|
+
"[yellow]Could not add memory (no database context)[/yellow]\n"
|
|
199
|
+
)
|
|
200
|
+
except Exception:
|
|
166
201
|
self.console.print(
|
|
167
|
-
"[yellow]Could not add memory
|
|
202
|
+
"[yellow]Could not add memory[/yellow]\n"
|
|
168
203
|
)
|
|
169
204
|
else:
|
|
170
205
|
self.console.print(
|
sqlsaber/cli/models.py
CHANGED
|
@@ -3,12 +3,13 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import sys
|
|
5
5
|
|
|
6
|
+
import cyclopts
|
|
6
7
|
import httpx
|
|
7
8
|
import questionary
|
|
8
|
-
import cyclopts
|
|
9
9
|
from rich.console import Console
|
|
10
10
|
from rich.table import Table
|
|
11
11
|
|
|
12
|
+
from sqlsaber.config import providers
|
|
12
13
|
from sqlsaber.config.settings import Config
|
|
13
14
|
|
|
14
15
|
# Global instances for CLI commands
|
|
@@ -26,49 +27,75 @@ class ModelManager:
|
|
|
26
27
|
|
|
27
28
|
DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
|
|
28
29
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
|
+
# Providers come from central registry
|
|
31
|
+
SUPPORTED_PROVIDERS = providers.all_keys()
|
|
32
|
+
|
|
33
|
+
async def fetch_available_models(
|
|
34
|
+
self, providers: list[str] | None = None
|
|
35
|
+
) -> list[dict]:
|
|
36
|
+
"""Fetch available models across providers from models.dev API.
|
|
29
37
|
|
|
30
|
-
|
|
31
|
-
"""
|
|
38
|
+
Returns list of dicts with keys: id (provider:model_id), provider, name, description, context_length, knowledge.
|
|
39
|
+
"""
|
|
32
40
|
try:
|
|
33
41
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
34
42
|
response = await client.get(self.MODELS_API_URL)
|
|
35
43
|
response.raise_for_status()
|
|
36
44
|
data = response.json()
|
|
37
45
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
providers = providers or self.SUPPORTED_PROVIDERS
|
|
47
|
+
results: list[dict] = []
|
|
48
|
+
|
|
49
|
+
for provider in providers:
|
|
50
|
+
prov_data = data.get(provider, {})
|
|
51
|
+
models_obj = (
|
|
52
|
+
prov_data.get("models") or prov_data.get("Models") or {}
|
|
53
|
+
)
|
|
54
|
+
if not isinstance(models_obj, dict):
|
|
55
|
+
continue
|
|
56
|
+
for model_id, model_info in models_obj.items():
|
|
57
|
+
formatted_id = f"{provider}:{model_id}"
|
|
58
|
+
# cost
|
|
59
|
+
cost_info = (
|
|
60
|
+
model_info.get("cost", {})
|
|
61
|
+
if isinstance(model_info, dict)
|
|
62
|
+
else {}
|
|
63
|
+
)
|
|
49
64
|
cost_display = ""
|
|
50
|
-
if cost_info:
|
|
65
|
+
if isinstance(cost_info, dict) and cost_info:
|
|
51
66
|
input_cost = cost_info.get("input", 0)
|
|
52
67
|
output_cost = cost_info.get("output", 0)
|
|
53
68
|
cost_display = f"${input_cost}/{output_cost} per 1M tokens"
|
|
69
|
+
# context
|
|
70
|
+
limit_info = (
|
|
71
|
+
model_info.get("limit", {})
|
|
72
|
+
if isinstance(model_info, dict)
|
|
73
|
+
else {}
|
|
74
|
+
)
|
|
75
|
+
context_length = (
|
|
76
|
+
limit_info.get("context", 0)
|
|
77
|
+
if isinstance(limit_info, dict)
|
|
78
|
+
else 0
|
|
79
|
+
)
|
|
54
80
|
|
|
55
|
-
|
|
56
|
-
limit_info = model_info.get("limit", {})
|
|
57
|
-
context_length = limit_info.get("context", 0)
|
|
58
|
-
|
|
59
|
-
anthropic_models.append(
|
|
81
|
+
results.append(
|
|
60
82
|
{
|
|
61
83
|
"id": formatted_id,
|
|
62
|
-
"
|
|
84
|
+
"provider": provider,
|
|
85
|
+
"name": model_info.get("name", model_id)
|
|
86
|
+
if isinstance(model_info, dict)
|
|
87
|
+
else model_id,
|
|
63
88
|
"description": cost_display,
|
|
64
89
|
"context_length": context_length,
|
|
65
|
-
"knowledge": model_info.get("knowledge", "")
|
|
90
|
+
"knowledge": model_info.get("knowledge", "")
|
|
91
|
+
if isinstance(model_info, dict)
|
|
92
|
+
else "",
|
|
66
93
|
}
|
|
67
94
|
)
|
|
68
95
|
|
|
69
|
-
# Sort by
|
|
70
|
-
|
|
71
|
-
return
|
|
96
|
+
# Sort by provider then by name
|
|
97
|
+
results.sort(key=lambda x: (x["provider"], x["name"]))
|
|
98
|
+
return results
|
|
72
99
|
except Exception as e:
|
|
73
100
|
console.print(f"[red]Error fetching models: {e}[/red]")
|
|
74
101
|
return []
|
|
@@ -110,7 +137,8 @@ def list():
|
|
|
110
137
|
)
|
|
111
138
|
return
|
|
112
139
|
|
|
113
|
-
table = Table(title="Available
|
|
140
|
+
table = Table(title="Available Models")
|
|
141
|
+
table.add_column("Provider", style="magenta")
|
|
114
142
|
table.add_column("ID", style="cyan")
|
|
115
143
|
table.add_column("Name", style="green")
|
|
116
144
|
table.add_column("Description", style="white")
|
|
@@ -133,6 +161,7 @@ def list():
|
|
|
133
161
|
)
|
|
134
162
|
|
|
135
163
|
table.add_row(
|
|
164
|
+
model.get("provider", "-"),
|
|
136
165
|
model["id"],
|
|
137
166
|
model["name"],
|
|
138
167
|
description,
|
|
@@ -161,8 +190,9 @@ def set():
|
|
|
161
190
|
# Create choices for questionary
|
|
162
191
|
choices = []
|
|
163
192
|
for model in models:
|
|
164
|
-
# Format: "ID - Name (Description)"
|
|
165
|
-
|
|
193
|
+
# Format: "[provider] ID - Name (Description)"
|
|
194
|
+
prov = model.get("provider", "?")
|
|
195
|
+
choice_text = f"[{prov}] {model['id']} - {model['name']}"
|
|
166
196
|
if model["description"]:
|
|
167
197
|
choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
|
|
168
198
|
|
|
@@ -179,7 +209,6 @@ def set():
|
|
|
179
209
|
selected_model = await questionary.select(
|
|
180
210
|
"Select a model:",
|
|
181
211
|
choices=choices,
|
|
182
|
-
use_shortcuts=True,
|
|
183
212
|
use_search_filter=True,
|
|
184
213
|
use_jk_keys=False, # Disable j/k keys when using search filter
|
|
185
214
|
default=choices[default_index] if choices else None,
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -1,100 +1,137 @@
|
|
|
1
|
-
"""Streaming query handling for the CLI."""
|
|
1
|
+
"""Streaming query handling for the CLI (pydantic-ai based)."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
|
|
4
|
+
import json
|
|
5
|
+
from typing import AsyncIterable
|
|
6
|
+
|
|
7
|
+
from pydantic_ai import Agent, RunContext
|
|
8
|
+
from pydantic_ai.messages import (
|
|
9
|
+
AgentStreamEvent,
|
|
10
|
+
FunctionToolCallEvent,
|
|
11
|
+
FunctionToolResultEvent,
|
|
12
|
+
PartDeltaEvent,
|
|
13
|
+
PartStartEvent,
|
|
14
|
+
TextPart,
|
|
15
|
+
TextPartDelta,
|
|
16
|
+
ThinkingPart,
|
|
17
|
+
ThinkingPartDelta,
|
|
18
|
+
)
|
|
5
19
|
from rich.console import Console
|
|
6
20
|
|
|
7
|
-
from sqlsaber.agents.base import BaseSQLAgent
|
|
8
21
|
from sqlsaber.cli.display import DisplayManager
|
|
9
22
|
|
|
10
23
|
|
|
11
24
|
class StreamingQueryHandler:
|
|
12
|
-
"""Handles streaming query execution and display."""
|
|
25
|
+
"""Handles streaming query execution and display using pydantic-ai events."""
|
|
13
26
|
|
|
14
27
|
def __init__(self, console: Console):
|
|
15
28
|
self.console = console
|
|
16
29
|
self.display = DisplayManager(console)
|
|
17
30
|
|
|
31
|
+
self.status = self.console.status(
|
|
32
|
+
"[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
async def _event_stream_handler(
|
|
36
|
+
self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
|
|
37
|
+
) -> None:
|
|
38
|
+
async for event in event_stream:
|
|
39
|
+
if isinstance(event, PartStartEvent):
|
|
40
|
+
if isinstance(event.part, (TextPart, ThinkingPart)):
|
|
41
|
+
self.status.stop()
|
|
42
|
+
self.display.show_text_stream(event.part.content)
|
|
43
|
+
|
|
44
|
+
elif isinstance(event, PartDeltaEvent):
|
|
45
|
+
if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
|
|
46
|
+
delta = event.delta.content_delta or ""
|
|
47
|
+
if delta:
|
|
48
|
+
self.status.stop()
|
|
49
|
+
self.display.show_text_stream(delta)
|
|
50
|
+
|
|
51
|
+
elif isinstance(event, FunctionToolCallEvent):
|
|
52
|
+
# Show tool execution start
|
|
53
|
+
self.status.stop()
|
|
54
|
+
args = event.part.args_as_dict()
|
|
55
|
+
self.display.show_newline()
|
|
56
|
+
self.display.show_tool_executing(event.part.tool_name, args)
|
|
57
|
+
|
|
58
|
+
elif isinstance(event, FunctionToolResultEvent):
|
|
59
|
+
self.status.stop()
|
|
60
|
+
# Route tool result to appropriate display
|
|
61
|
+
tool_name = event.result.tool_name
|
|
62
|
+
content = event.result.content
|
|
63
|
+
if tool_name == "list_tables":
|
|
64
|
+
self.display.show_table_list(content)
|
|
65
|
+
elif tool_name == "introspect_schema":
|
|
66
|
+
self.display.show_schema_info(content)
|
|
67
|
+
elif tool_name == "execute_sql":
|
|
68
|
+
try:
|
|
69
|
+
data = json.loads(content)
|
|
70
|
+
if data.get("success") and data.get("results"):
|
|
71
|
+
self.display.show_query_results(data["results"]) # type: ignore[arg-type]
|
|
72
|
+
except json.JSONDecodeError:
|
|
73
|
+
# If not JSON, ignore here
|
|
74
|
+
pass
|
|
75
|
+
elif tool_name == "plot_data":
|
|
76
|
+
self.display.show_plot(
|
|
77
|
+
{"tool_name": tool_name, "result": content, "input": {}}
|
|
78
|
+
)
|
|
79
|
+
|
|
18
80
|
async def execute_streaming_query(
|
|
19
81
|
self,
|
|
20
82
|
user_query: str,
|
|
21
|
-
agent:
|
|
83
|
+
agent: Agent,
|
|
22
84
|
cancellation_token: asyncio.Event | None = None,
|
|
85
|
+
message_history: list | None = None,
|
|
23
86
|
):
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
status = self.console.status(
|
|
27
|
-
"[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
|
|
28
|
-
)
|
|
29
|
-
status.start()
|
|
30
|
-
|
|
87
|
+
self.status.start()
|
|
31
88
|
try:
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
self.display.
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
elif event.type == "error":
|
|
74
|
-
self._stop_status(status)
|
|
75
|
-
self.display.show_error(event.data)
|
|
76
|
-
|
|
89
|
+
# If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
|
|
90
|
+
prepared_prompt: str | list[str] = user_query
|
|
91
|
+
is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
|
|
92
|
+
no_history = not message_history
|
|
93
|
+
if is_oauth and no_history:
|
|
94
|
+
ib = getattr(agent, "_sqlsaber_instruction_builder", None)
|
|
95
|
+
mm = getattr(agent, "_sqlsaber_memory_manager", None)
|
|
96
|
+
db_type = getattr(agent, "_sqlsaber_db_type", "database")
|
|
97
|
+
db_name = getattr(agent, "_sqlsaber_database_name", None)
|
|
98
|
+
instructions = (
|
|
99
|
+
ib.build_instructions(db_type=db_type) if ib is not None else ""
|
|
100
|
+
)
|
|
101
|
+
mem = (
|
|
102
|
+
mm.format_memories_for_prompt(db_name)
|
|
103
|
+
if (mm is not None and db_name)
|
|
104
|
+
else ""
|
|
105
|
+
)
|
|
106
|
+
parts = [p for p in (instructions, mem) if p and str(p).strip()]
|
|
107
|
+
if parts:
|
|
108
|
+
injected = "\n\n".join(parts)
|
|
109
|
+
prepared_prompt = [injected, user_query]
|
|
110
|
+
|
|
111
|
+
# Run the agent with our event stream handler
|
|
112
|
+
run = await agent.run(
|
|
113
|
+
prepared_prompt,
|
|
114
|
+
message_history=message_history,
|
|
115
|
+
event_stream_handler=self._event_stream_handler,
|
|
116
|
+
)
|
|
117
|
+
# After the run completes, show the assistant's final text as markdown if available
|
|
118
|
+
try:
|
|
119
|
+
output = run.output
|
|
120
|
+
if isinstance(output, str) and output.strip():
|
|
121
|
+
self.display.show_newline()
|
|
122
|
+
self.display.show_markdown_response(
|
|
123
|
+
[{"type": "text", "text": output}]
|
|
124
|
+
)
|
|
125
|
+
except Exception as e:
|
|
126
|
+
self.display.show_error(str(e))
|
|
127
|
+
self.display.show_newline()
|
|
128
|
+
return run
|
|
77
129
|
except asyncio.CancelledError:
|
|
78
|
-
# Handle cancellation gracefully
|
|
79
|
-
self._stop_status(status)
|
|
80
130
|
self.display.show_newline()
|
|
81
131
|
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
82
|
-
return
|
|
132
|
+
return None
|
|
83
133
|
finally:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if hasattr(agent, "conversation_history") and agent.conversation_history:
|
|
89
|
-
last_message = agent.conversation_history[-1]
|
|
90
|
-
if last_message.get("role") == "assistant" and last_message.get(
|
|
91
|
-
"content"
|
|
92
|
-
):
|
|
93
|
-
self.display.show_markdown_response(last_message["content"])
|
|
94
|
-
|
|
95
|
-
def _stop_status(self, status):
|
|
96
|
-
"""Safely stop a status spinner."""
|
|
97
|
-
try:
|
|
98
|
-
status.stop()
|
|
99
|
-
except Exception:
|
|
100
|
-
pass # Status might already be stopped
|
|
134
|
+
try:
|
|
135
|
+
self.status.stop()
|
|
136
|
+
except Exception:
|
|
137
|
+
pass
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -6,6 +6,8 @@ import os
|
|
|
6
6
|
import keyring
|
|
7
7
|
from rich.console import Console
|
|
8
8
|
|
|
9
|
+
from sqlsaber.config import providers
|
|
10
|
+
|
|
9
11
|
console = Console()
|
|
10
12
|
|
|
11
13
|
|
|
@@ -30,9 +32,7 @@ class APIKeyManager:
|
|
|
30
32
|
try:
|
|
31
33
|
api_key = keyring.get_password(service_name, provider)
|
|
32
34
|
if api_key:
|
|
33
|
-
console.print(
|
|
34
|
-
f"Using stored {provider} API key from keyring", style="dim"
|
|
35
|
-
)
|
|
35
|
+
console.print(f"Using stored {provider} API key", style="dim")
|
|
36
36
|
return api_key
|
|
37
37
|
except Exception as e:
|
|
38
38
|
# Keyring access failed, continue to prompt
|
|
@@ -43,12 +43,9 @@ class APIKeyManager:
|
|
|
43
43
|
|
|
44
44
|
def _get_env_var_name(self, provider: str) -> str:
|
|
45
45
|
"""Get the expected environment variable name for a provider."""
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
return "ANTHROPIC_API_KEY"
|
|
50
|
-
else:
|
|
51
|
-
return "AI_API_KEY"
|
|
46
|
+
# Normalize aliases to canonical provider keys
|
|
47
|
+
key = providers.canonical(provider) or provider
|
|
48
|
+
return providers.env_var_name(key)
|
|
52
49
|
|
|
53
50
|
def _get_service_name(self, provider: str) -> str:
|
|
54
51
|
"""Get the keyring service name for a provider."""
|
|
@@ -60,7 +57,7 @@ class APIKeyManager:
|
|
|
60
57
|
"""Prompt user for API key and store it in keyring."""
|
|
61
58
|
try:
|
|
62
59
|
console.print(
|
|
63
|
-
f"\n{provider.title()} API key not found in environment or
|
|
60
|
+
f"\n{provider.title()} API key not found in environment or your OS's credentials store."
|
|
64
61
|
)
|
|
65
62
|
console.print("You can either:")
|
|
66
63
|
console.print(f" 1. Set the {env_var_name} environment variable")
|
|
@@ -85,7 +82,8 @@ class APIKeyManager:
|
|
|
85
82
|
console.print("API key stored securely for future use", style="green")
|
|
86
83
|
except Exception as e:
|
|
87
84
|
console.print(
|
|
88
|
-
f"Warning: Could not store API key in
|
|
85
|
+
f"Warning: Could not store API key in your operating system's credentials store: {e}",
|
|
86
|
+
style="yellow",
|
|
89
87
|
)
|
|
90
88
|
console.print(
|
|
91
89
|
"You may need to enter it again next time", style="yellow"
|