sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/__init__.py +2 -4
- sqlsaber/agents/base.py +18 -221
- sqlsaber/agents/mcp.py +2 -2
- sqlsaber/agents/pydantic_ai_agent.py +170 -0
- sqlsaber/cli/auth.py +146 -79
- sqlsaber/cli/commands.py +22 -7
- sqlsaber/cli/database.py +1 -1
- sqlsaber/cli/interactive.py +65 -30
- sqlsaber/cli/models.py +58 -29
- sqlsaber/cli/streaming.py +114 -77
- sqlsaber/config/api_keys.py +9 -11
- sqlsaber/config/providers.py +116 -0
- sqlsaber/config/settings.py +50 -30
- sqlsaber/database/connection.py +3 -3
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/models/__init__.py +0 -3
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +85 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
- sqlsaber-0.16.0.dist-info/RECORD +51 -0
- sqlsaber/agents/anthropic.py +0 -579
- sqlsaber/agents/streaming.py +0 -16
- sqlsaber/clients/__init__.py +0 -6
- sqlsaber/clients/anthropic.py +0 -285
- sqlsaber/clients/base.py +0 -31
- sqlsaber/clients/exceptions.py +0 -117
- sqlsaber/clients/models.py +0 -282
- sqlsaber/clients/streaming.py +0 -257
- sqlsaber/models/events.py +0 -28
- sqlsaber-0.14.0.dist-info/RECORD +0 -51
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/__init__.py
CHANGED
sqlsaber/agents/base.py
CHANGED
|
@@ -5,8 +5,6 @@ import json
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from typing import Any, AsyncIterator
|
|
7
7
|
|
|
8
|
-
from uniplot import histogram, plot
|
|
9
|
-
|
|
10
8
|
from sqlsaber.conversation.manager import ConversationManager
|
|
11
9
|
from sqlsaber.database.connection import (
|
|
12
10
|
BaseDatabaseConnection,
|
|
@@ -16,7 +14,7 @@ from sqlsaber.database.connection import (
|
|
|
16
14
|
SQLiteConnection,
|
|
17
15
|
)
|
|
18
16
|
from sqlsaber.database.schema import SchemaManager
|
|
19
|
-
from sqlsaber.
|
|
17
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
20
18
|
|
|
21
19
|
|
|
22
20
|
class BaseSQLAgent(ABC):
|
|
@@ -32,13 +30,16 @@ class BaseSQLAgent(ABC):
|
|
|
32
30
|
self._conversation_id: str | None = None
|
|
33
31
|
self._msg_index: int = 0
|
|
34
32
|
|
|
33
|
+
# Initialize SQL tools with database connection
|
|
34
|
+
self._init_tools()
|
|
35
|
+
|
|
35
36
|
@abstractmethod
|
|
36
37
|
async def query_stream(
|
|
37
38
|
self,
|
|
38
39
|
user_query: str,
|
|
39
40
|
use_history: bool = True,
|
|
40
41
|
cancellation_token: asyncio.Event | None = None,
|
|
41
|
-
) -> AsyncIterator
|
|
42
|
+
) -> AsyncIterator:
|
|
42
43
|
"""Process a user query and stream responses.
|
|
43
44
|
|
|
44
45
|
Args:
|
|
@@ -69,232 +70,28 @@ class BaseSQLAgent(ABC):
|
|
|
69
70
|
else:
|
|
70
71
|
return "database" # Fallback
|
|
71
72
|
|
|
72
|
-
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
formatted_info = {}
|
|
80
|
-
for table_name, table_info in schema_info.items():
|
|
81
|
-
formatted_info[table_name] = {
|
|
82
|
-
"columns": {
|
|
83
|
-
col_name: {
|
|
84
|
-
"type": col_info["data_type"],
|
|
85
|
-
"nullable": col_info["nullable"],
|
|
86
|
-
"default": col_info["default"],
|
|
87
|
-
}
|
|
88
|
-
for col_name, col_info in table_info["columns"].items()
|
|
89
|
-
},
|
|
90
|
-
"primary_keys": table_info["primary_keys"],
|
|
91
|
-
"foreign_keys": [
|
|
92
|
-
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
93
|
-
for fk in table_info["foreign_keys"]
|
|
94
|
-
],
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
return json.dumps(formatted_info)
|
|
98
|
-
except Exception as e:
|
|
99
|
-
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
100
|
-
|
|
101
|
-
async def list_tables(self) -> str:
|
|
102
|
-
"""List all tables in the database with basic information."""
|
|
103
|
-
try:
|
|
104
|
-
tables_info = await self.schema_manager.list_tables()
|
|
105
|
-
return json.dumps(tables_info)
|
|
106
|
-
except Exception as e:
|
|
107
|
-
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
108
|
-
|
|
109
|
-
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
110
|
-
"""Execute a SQL query against the database."""
|
|
111
|
-
try:
|
|
112
|
-
# Security check - only allow SELECT queries unless write is enabled
|
|
113
|
-
write_error = self._validate_write_operation(query)
|
|
114
|
-
if write_error:
|
|
115
|
-
return json.dumps(
|
|
116
|
-
{
|
|
117
|
-
"error": write_error,
|
|
118
|
-
}
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
# Add LIMIT if not present and it's a SELECT query
|
|
122
|
-
query = self._add_limit_to_query(query, limit)
|
|
123
|
-
|
|
124
|
-
# Execute the query (wrapped in a transaction for safety)
|
|
125
|
-
results = await self.db.execute_query(query)
|
|
126
|
-
|
|
127
|
-
# Format results
|
|
128
|
-
actual_limit = limit if limit is not None else len(results)
|
|
129
|
-
|
|
130
|
-
return json.dumps(
|
|
131
|
-
{
|
|
132
|
-
"success": True,
|
|
133
|
-
"row_count": len(results),
|
|
134
|
-
"results": results[:actual_limit], # Extra safety for limit
|
|
135
|
-
"truncated": len(results) > actual_limit,
|
|
136
|
-
}
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
except Exception as e:
|
|
140
|
-
error_msg = str(e)
|
|
141
|
-
|
|
142
|
-
# Provide helpful error messages
|
|
143
|
-
suggestions = []
|
|
144
|
-
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
145
|
-
suggestions.append(
|
|
146
|
-
"Check column names using the schema introspection tool"
|
|
147
|
-
)
|
|
148
|
-
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
149
|
-
suggestions.append(
|
|
150
|
-
"Check table names using the schema introspection tool"
|
|
151
|
-
)
|
|
152
|
-
elif "syntax error" in error_msg.lower():
|
|
153
|
-
suggestions.append(
|
|
154
|
-
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
73
|
+
def _init_tools(self) -> None:
|
|
74
|
+
"""Initialize SQL tools with database connection."""
|
|
75
|
+
# Get all SQL tools and set their database connection
|
|
76
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
77
|
+
tool = tool_registry.get_tool(tool_name)
|
|
78
|
+
if isinstance(tool, SQLTool):
|
|
79
|
+
tool.set_connection(self.db)
|
|
158
80
|
|
|
159
81
|
async def process_tool_call(
|
|
160
82
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
161
83
|
) -> str:
|
|
162
84
|
"""Process a tool call and return the result."""
|
|
163
|
-
if tool_name == "list_tables":
|
|
164
|
-
return await self.list_tables()
|
|
165
|
-
elif tool_name == "introspect_schema":
|
|
166
|
-
return await self.introspect_schema(tool_input.get("table_pattern"))
|
|
167
|
-
elif tool_name == "execute_sql":
|
|
168
|
-
return await self.execute_sql(
|
|
169
|
-
tool_input["query"], tool_input.get("limit", 100)
|
|
170
|
-
)
|
|
171
|
-
elif tool_name == "plot_data":
|
|
172
|
-
return await self.plot_data(
|
|
173
|
-
y_values=tool_input["y_values"],
|
|
174
|
-
x_values=tool_input.get("x_values"),
|
|
175
|
-
plot_type=tool_input.get("plot_type", "line"),
|
|
176
|
-
title=tool_input.get("title"),
|
|
177
|
-
x_label=tool_input.get("x_label"),
|
|
178
|
-
y_label=tool_input.get("y_label"),
|
|
179
|
-
)
|
|
180
|
-
else:
|
|
181
|
-
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
182
|
-
|
|
183
|
-
def _validate_write_operation(self, query: str) -> str | None:
|
|
184
|
-
"""Validate if a write operation is allowed.
|
|
185
|
-
|
|
186
|
-
Returns:
|
|
187
|
-
None if operation is allowed, error message if not allowed.
|
|
188
|
-
"""
|
|
189
|
-
query_upper = query.strip().upper()
|
|
190
|
-
|
|
191
|
-
# Check for write operations
|
|
192
|
-
write_keywords = [
|
|
193
|
-
"INSERT",
|
|
194
|
-
"UPDATE",
|
|
195
|
-
"DELETE",
|
|
196
|
-
"DROP",
|
|
197
|
-
"CREATE",
|
|
198
|
-
"ALTER",
|
|
199
|
-
"TRUNCATE",
|
|
200
|
-
]
|
|
201
|
-
is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
|
|
202
|
-
|
|
203
|
-
if is_write_query:
|
|
204
|
-
return (
|
|
205
|
-
"Write operations are not allowed. Only SELECT queries are permitted."
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
return None
|
|
209
|
-
|
|
210
|
-
def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
|
|
211
|
-
"""Add LIMIT clause to SELECT queries if not present."""
|
|
212
|
-
query_upper = query.strip().upper()
|
|
213
|
-
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
214
|
-
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
215
|
-
return query
|
|
216
|
-
|
|
217
|
-
async def plot_data(
|
|
218
|
-
self,
|
|
219
|
-
y_values: list[float],
|
|
220
|
-
x_values: list[float] | None = None,
|
|
221
|
-
plot_type: str = "line",
|
|
222
|
-
title: str | None = None,
|
|
223
|
-
x_label: str | None = None,
|
|
224
|
-
y_label: str | None = None,
|
|
225
|
-
) -> str:
|
|
226
|
-
"""Create a terminal plot using uniplot.
|
|
227
|
-
|
|
228
|
-
Args:
|
|
229
|
-
y_values: Y-axis data points
|
|
230
|
-
x_values: X-axis data points (optional)
|
|
231
|
-
plot_type: Type of plot - "line", "scatter", or "histogram"
|
|
232
|
-
title: Plot title
|
|
233
|
-
x_label: X-axis label
|
|
234
|
-
y_label: Y-axis label
|
|
235
|
-
|
|
236
|
-
Returns:
|
|
237
|
-
JSON string with success status and plot details
|
|
238
|
-
"""
|
|
239
85
|
try:
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
try:
|
|
246
|
-
y_values = [float(v) if v is not None else None for v in y_values]
|
|
247
|
-
if x_values:
|
|
248
|
-
x_values = [float(v) if v is not None else None for v in x_values]
|
|
249
|
-
except (ValueError, TypeError) as e:
|
|
250
|
-
return json.dumps({"error": f"Invalid data format: {str(e)}"})
|
|
251
|
-
|
|
252
|
-
# Create the plot
|
|
253
|
-
if plot_type == "histogram":
|
|
254
|
-
# For histogram, we only need y_values
|
|
255
|
-
histogram(
|
|
256
|
-
y_values,
|
|
257
|
-
title=title,
|
|
258
|
-
bins=min(20, len(set(y_values))), # Adaptive bin count
|
|
259
|
-
)
|
|
260
|
-
plot_info = {
|
|
261
|
-
"type": "histogram",
|
|
262
|
-
"data_points": len(y_values),
|
|
263
|
-
"title": title or "Histogram",
|
|
264
|
-
}
|
|
265
|
-
elif plot_type in ["line", "scatter"]:
|
|
266
|
-
# For line/scatter plots
|
|
267
|
-
plot_kwargs = {
|
|
268
|
-
"ys": y_values,
|
|
269
|
-
"title": title,
|
|
270
|
-
"lines": plot_type == "line",
|
|
271
|
-
}
|
|
272
|
-
|
|
273
|
-
if x_values:
|
|
274
|
-
plot_kwargs["xs"] = x_values
|
|
275
|
-
if x_label:
|
|
276
|
-
plot_kwargs["x_unit"] = x_label
|
|
277
|
-
if y_label:
|
|
278
|
-
plot_kwargs["y_unit"] = y_label
|
|
279
|
-
|
|
280
|
-
plot(**plot_kwargs)
|
|
281
|
-
|
|
282
|
-
plot_info = {
|
|
283
|
-
"type": plot_type,
|
|
284
|
-
"data_points": len(y_values),
|
|
285
|
-
"title": title or f"{plot_type.capitalize()} Plot",
|
|
286
|
-
"has_x_values": x_values is not None,
|
|
287
|
-
}
|
|
288
|
-
else:
|
|
289
|
-
return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
|
|
290
|
-
|
|
86
|
+
tool = tool_registry.get_tool(tool_name)
|
|
87
|
+
return await tool.execute(**tool_input)
|
|
88
|
+
except KeyError:
|
|
89
|
+
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
90
|
+
except Exception as e:
|
|
291
91
|
return json.dumps(
|
|
292
|
-
{"
|
|
92
|
+
{"error": f"Error executing tool '{tool_name}': {str(e)}"}
|
|
293
93
|
)
|
|
294
94
|
|
|
295
|
-
except Exception as e:
|
|
296
|
-
return json.dumps({"error": f"Error creating plot: {str(e)}"})
|
|
297
|
-
|
|
298
95
|
# Conversation persistence helpers
|
|
299
96
|
|
|
300
97
|
async def _ensure_conversation(self) -> None:
|
sqlsaber/agents/mcp.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
"""Generic SQL agent implementation for MCP tools."""
|
|
2
2
|
|
|
3
3
|
from typing import AsyncIterator
|
|
4
|
+
|
|
4
5
|
from sqlsaber.agents.base import BaseSQLAgent
|
|
5
6
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
6
|
-
from sqlsaber.models.events import StreamEvent
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class MCPSQLAgent(BaseSQLAgent):
|
|
@@ -14,7 +14,7 @@ class MCPSQLAgent(BaseSQLAgent):
|
|
|
14
14
|
|
|
15
15
|
async def query_stream(
|
|
16
16
|
self, user_query: str, use_history: bool = True
|
|
17
|
-
) -> AsyncIterator
|
|
17
|
+
) -> AsyncIterator:
|
|
18
18
|
"""Not implemented for generic agent as it's only used for tool operations."""
|
|
19
19
|
raise NotImplementedError(
|
|
20
20
|
"MCPSQLAgent does not support query streaming. Use specific agent implementations for conversation."
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""Pydantic-AI Agent for SQLSaber.
|
|
2
|
+
|
|
3
|
+
This replaces the custom AnthropicSQLAgent and uses pydantic-ai's Agent,
|
|
4
|
+
function tools, and streaming event types directly.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
from pydantic_ai import Agent, RunContext
|
|
9
|
+
from pydantic_ai.models.anthropic import AnthropicModel
|
|
10
|
+
from pydantic_ai.models.google import GoogleModel
|
|
11
|
+
from pydantic_ai.providers.anthropic import AnthropicProvider
|
|
12
|
+
from pydantic_ai.providers.google import GoogleProvider
|
|
13
|
+
|
|
14
|
+
from sqlsaber.config import providers
|
|
15
|
+
from sqlsaber.config.settings import Config
|
|
16
|
+
from sqlsaber.database.connection import (
|
|
17
|
+
BaseDatabaseConnection,
|
|
18
|
+
CSVConnection,
|
|
19
|
+
MySQLConnection,
|
|
20
|
+
PostgreSQLConnection,
|
|
21
|
+
SQLiteConnection,
|
|
22
|
+
)
|
|
23
|
+
from sqlsaber.memory.manager import MemoryManager
|
|
24
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
25
|
+
from sqlsaber.tools.registry import tool_registry
|
|
26
|
+
from sqlsaber.tools.sql_tools import SQLTool
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def build_sqlsaber_agent(
    db_connection: BaseDatabaseConnection,
    database_name: str | None,
) -> Agent:
    """Create and configure a pydantic-ai Agent for SQLSaber.

    - Hands the active DB connection to every SQL tool in the registry
    - Builds the model from the configured model name, constructing the
      provider explicitly for Google models and for Anthropic OAuth tokens
    - Attaches a dynamic system prompt built from InstructionBuilder +
      MemoryManager (or a minimal identity prompt in Anthropic OAuth mode)
    - Registers function tools that delegate to the existing tool registry

    Args:
        db_connection: Connection the SQL tools will execute against.
        database_name: Name used to look up stored memories for the system
            prompt; when falsy, no memory context is added.

    Returns:
        The configured Agent. It also carries private ``_sqlsaber_*``
        attributes (memory manager, database name, instruction builder,
        database type name, OAuth flag) for callers to inspect.
    """
    # Ensure SQL tools receive the active connection
    for tool_name in tool_registry.list_tools(category="sql"):
        tool = tool_registry.get_tool(tool_name)
        if isinstance(tool, SQLTool):
            tool.set_connection(db_connection)

    cfg = Config()
    # Ensure provider env var is hydrated from keyring for current provider
    # (Config.validate handles it)
    cfg.validate()

    # "provider:model" -> "model"; names without a prefix are used unchanged.
    model_name_only = cfg.model_name.split(":", 1)[-1]
    provider = providers.provider_from_model(cfg.model_name) or ""

    # Decide OAuth mode exactly once so model construction and system-prompt
    # selection can never disagree about it.
    is_oauth = provider == "anthropic" and bool(
        getattr(cfg, "oauth_token", None)
    )
    # Pure isinstance dispatch on a fixed connection: compute it once and
    # reuse it for every prompt build and for the exposed attribute.
    db_type = _get_database_type_name(db_connection)

    # Build model/agent. For some providers (e.g., google), construct the
    # provider model explicitly to allow arbitrary model IDs even if they are
    # not in pydantic-ai's KnownModelName.
    if provider == "google":
        model_obj = GoogleModel(
            model_name_only, provider=GoogleProvider(api_key=cfg.api_key)
        )
        agent = Agent(model_obj, name="sqlsaber")
    elif is_oauth:
        # Build custom httpx client to inject OAuth headers for Anthropic
        async def add_oauth_headers(request: httpx.Request) -> None:
            # Remove API-key header if present and add OAuth headers
            if "x-api-key" in request.headers:
                del request.headers["x-api-key"]
            request.headers.update(
                {
                    "Authorization": f"Bearer {cfg.oauth_token}",
                    "anthropic-version": "2023-06-01",
                    "anthropic-beta": "oauth-2025-04-20",
                    "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
                    "X-Client-Name": "claude-code",
                    "X-Client-Version": "1.0.0",
                }
            )

        # NOTE(review): this AsyncClient is never explicitly closed here; its
        # lifetime is effectively tied to the returned agent.
        http_client = httpx.AsyncClient(event_hooks={"request": [add_oauth_headers]})
        # The real credential is injected by the request hook above, so the
        # provider only needs a dummy API key.
        provider_obj = AnthropicProvider(api_key="placeholder", http_client=http_client)
        model_obj = AnthropicModel(model_name_only, provider=provider_obj)
        agent = Agent(model_obj, name="sqlsaber")
    else:
        agent = Agent(cfg.model_name, name="sqlsaber")

    # Memory + dynamic system prompt
    memory_manager = MemoryManager()
    instruction_builder = InstructionBuilder(tool_registry)

    if is_oauth:

        @agent.system_prompt(dynamic=True)
        async def sqlsaber_system_prompt(ctx: RunContext) -> str:
            # Minimal system prompt in OAuth mode to match Claude Code identity.
            # NOTE(review): the SQL instructions are not sent here; presumably
            # callers inject them elsewhere when _sqlsaber_is_oauth is set.
            return "You are Claude Code, Anthropic's official CLI for Claude."

    else:

        @agent.system_prompt(dynamic=True)
        async def sqlsaber_system_prompt(ctx: RunContext) -> str:
            instructions = instruction_builder.build_instructions(db_type=db_type)
            # Memories are re-read on every run so newly saved ones apply
            # without rebuilding the agent.
            mem = (
                memory_manager.format_memories_for_prompt(database_name)
                if database_name
                else ""
            )
            parts = [p for p in (instructions, mem) if p and p.strip()]
            return "\n\n".join(parts)

    # Expose helpers and context on agent instance
    agent._sqlsaber_memory_manager = memory_manager  # type: ignore[attr-defined]
    agent._sqlsaber_database_name = database_name  # type: ignore[attr-defined]
    agent._sqlsaber_instruction_builder = instruction_builder  # type: ignore[attr-defined]
    agent._sqlsaber_db_type = db_type  # type: ignore[attr-defined]
    agent._sqlsaber_is_oauth = is_oauth  # type: ignore[attr-defined]

    # Tool wrappers that invoke the registered tools. Their docstrings are
    # kept verbatim because pydantic-ai uses them as the tool descriptions.
    @agent.tool(name="list_tables")
    async def list_tables(ctx: RunContext) -> str:
        """
        Get a list of all tables in the database with row counts.
        Use this first to discover available tables.
        """
        tool = tool_registry.get_tool("list_tables")
        return await tool.execute()

    @agent.tool(name="introspect_schema")
    async def introspect_schema(
        ctx: RunContext, table_pattern: str | None = None
    ) -> str:
        """
        Introspect database schema to understand table structures.

        Args:
            table_pattern: Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')
        """
        tool = tool_registry.get_tool("introspect_schema")
        return await tool.execute(table_pattern=table_pattern)

    @agent.tool(name="execute_sql")
    async def execute_sql(ctx: RunContext, query: str, limit: int | None = 100) -> str:
        """
        Execute a SQL query and return the results.

        Args:
            query: SQL query to execute
            limit: Maximum number of rows to return (default: 100)
        """
        tool = tool_registry.get_tool("execute_sql")
        return await tool.execute(query=query, limit=limit)

    return agent
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _get_database_type_name(db: BaseDatabaseConnection) -> str:
    """Return a human-readable engine name for ``db`` (mirrors BaseSQLAgent).

    Connection classes are checked in order. A CSV connection reports
    "SQLite", and any connection type not listed falls back to "database".
    """
    # Ordered (class, label) pairs; order matters for isinstance on subclasses.
    known_engines = (
        (PostgreSQLConnection, "PostgreSQL"),
        (MySQLConnection, "MySQL"),
        (SQLiteConnection, "SQLite"),
        (CSVConnection, "SQLite"),
    )
    for connection_cls, label in known_engines:
        if isinstance(db, connection_cls):
            return label
    return "database"
|