sqlsaber 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +28 -116
- sqlsaber/agents/base.py +17 -219
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +83 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.15.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.15.0.dist-info}/RECORD +15 -8
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.15.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.15.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.14.0.dist-info → sqlsaber-0.15.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
|
@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
|
|
|
21
21
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
22
22
|
from sqlsaber.memory.manager import MemoryManager
|
|
23
23
|
from sqlsaber.models.events import StreamEvent
|
|
24
|
+
from sqlsaber.tools import tool_registry
|
|
25
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class AnthropicSQLAgent(BaseSQLAgent):
|
|
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
51
53
|
self._last_results = None
|
|
52
54
|
self._last_query = None
|
|
53
55
|
|
|
54
|
-
#
|
|
55
|
-
self.tools: list[ToolDefinition] =
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
input_schema={
|
|
60
|
-
"type": "object",
|
|
61
|
-
"properties": {},
|
|
62
|
-
"required": [],
|
|
63
|
-
},
|
|
64
|
-
),
|
|
65
|
-
ToolDefinition(
|
|
66
|
-
name="introspect_schema",
|
|
67
|
-
description="Introspect database schema to understand table structures.",
|
|
68
|
-
input_schema={
|
|
69
|
-
"type": "object",
|
|
70
|
-
"properties": {
|
|
71
|
-
"table_pattern": {
|
|
72
|
-
"type": "string",
|
|
73
|
-
"description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
|
|
74
|
-
}
|
|
75
|
-
},
|
|
76
|
-
"required": [],
|
|
77
|
-
},
|
|
78
|
-
),
|
|
79
|
-
ToolDefinition(
|
|
80
|
-
name="execute_sql",
|
|
81
|
-
description="Execute a SQL query against the database.",
|
|
82
|
-
input_schema={
|
|
83
|
-
"type": "object",
|
|
84
|
-
"properties": {
|
|
85
|
-
"query": {
|
|
86
|
-
"type": "string",
|
|
87
|
-
"description": "SQL query to execute",
|
|
88
|
-
},
|
|
89
|
-
"limit": {
|
|
90
|
-
"type": "integer",
|
|
91
|
-
"description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
|
|
92
|
-
"default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
|
|
93
|
-
},
|
|
94
|
-
},
|
|
95
|
-
"required": ["query"],
|
|
96
|
-
},
|
|
97
|
-
),
|
|
98
|
-
ToolDefinition(
|
|
99
|
-
name="plot_data",
|
|
100
|
-
description="Create a plot of query results.",
|
|
101
|
-
input_schema={
|
|
102
|
-
"type": "object",
|
|
103
|
-
"properties": {
|
|
104
|
-
"y_values": {
|
|
105
|
-
"type": "array",
|
|
106
|
-
"items": {"type": ["number", "null"]},
|
|
107
|
-
"description": "Y-axis data points (required)",
|
|
108
|
-
},
|
|
109
|
-
"x_values": {
|
|
110
|
-
"type": "array",
|
|
111
|
-
"items": {"type": ["number", "null"]},
|
|
112
|
-
"description": "X-axis data points (optional, will use indices if not provided)",
|
|
113
|
-
},
|
|
114
|
-
"plot_type": {
|
|
115
|
-
"type": "string",
|
|
116
|
-
"enum": ["line", "scatter", "histogram"],
|
|
117
|
-
"description": "Type of plot to create (default: line)",
|
|
118
|
-
"default": "line",
|
|
119
|
-
},
|
|
120
|
-
"title": {
|
|
121
|
-
"type": "string",
|
|
122
|
-
"description": "Title for the plot",
|
|
123
|
-
},
|
|
124
|
-
"x_label": {
|
|
125
|
-
"type": "string",
|
|
126
|
-
"description": "Label for X-axis",
|
|
127
|
-
},
|
|
128
|
-
"y_label": {
|
|
129
|
-
"type": "string",
|
|
130
|
-
"description": "Label for Y-axis",
|
|
131
|
-
},
|
|
132
|
-
},
|
|
133
|
-
"required": ["y_values"],
|
|
134
|
-
},
|
|
135
|
-
),
|
|
136
|
-
]
|
|
56
|
+
# Get tool definitions from registry
|
|
57
|
+
self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
|
|
58
|
+
|
|
59
|
+
# Initialize instruction builder
|
|
60
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
137
61
|
|
|
138
62
|
# Build system prompt with memories if available
|
|
139
63
|
self.system_prompt = self._build_system_prompt()
|
|
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
157
81
|
def _get_sql_assistant_instructions(self) -> str:
|
|
158
82
|
"""Get the detailed SQL assistant instructions."""
|
|
159
83
|
db_type = self._get_database_type_name()
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
1. Understand user's natural language requests, think and convert them to SQL
|
|
164
|
-
2. Use the provided tools efficiently to explore database schema
|
|
165
|
-
3. Generate appropriate SQL queries
|
|
166
|
-
4. Execute queries safely - queries that modify the database are not allowed
|
|
167
|
-
5. Format and explain results clearly
|
|
168
|
-
6. Create visualizations when requested or when they would be helpful
|
|
169
|
-
|
|
170
|
-
IMPORTANT - Schema Discovery Strategy:
|
|
171
|
-
1. ALWAYS start with 'list_tables' to see available tables and row counts
|
|
172
|
-
2. Based on the user's query, identify which specific tables are relevant
|
|
173
|
-
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
|
|
174
|
-
4. Timestamp columns must be converted to text when you write queries
|
|
175
|
-
|
|
176
|
-
Guidelines:
|
|
177
|
-
- Use list_tables first, then introspect_schema for specific tables only
|
|
178
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
179
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
180
|
-
- Include appropriate WHERE clauses to limit results
|
|
181
|
-
- Explain what the query does in simple terms
|
|
182
|
-
- Handle errors gracefully and suggest fixes
|
|
183
|
-
- Be security conscious - use parameterized queries when needed
|
|
184
|
-
"""
|
|
84
|
+
|
|
85
|
+
# Build dynamic instructions from available tools
|
|
86
|
+
instructions = self.instruction_builder.build_instructions(db_type=db_type)
|
|
185
87
|
|
|
186
88
|
# Add memory context if database name is available
|
|
187
89
|
if self.database_name:
|
|
@@ -189,7 +91,7 @@ Guidelines:
|
|
|
189
91
|
self.database_name
|
|
190
92
|
)
|
|
191
93
|
if memory_context.strip():
|
|
192
|
-
instructions += memory_context
|
|
94
|
+
instructions += "\n\n" + memory_context
|
|
193
95
|
|
|
194
96
|
return instructions
|
|
195
97
|
|
|
@@ -199,16 +101,19 @@ Guidelines:
|
|
|
199
101
|
return None
|
|
200
102
|
|
|
201
103
|
memory = self.memory_manager.add_memory(self.database_name, content)
|
|
202
|
-
# Rebuild system prompt with new memory
|
|
104
|
+
# Rebuild system prompt with new memory (includes dynamic instructions)
|
|
203
105
|
self.system_prompt = self._build_system_prompt()
|
|
204
106
|
return memory.id
|
|
205
107
|
|
|
206
|
-
async def
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
108
|
+
async def _execute_sql_with_tracking(
|
|
109
|
+
self, query: str, limit: int | None = None
|
|
110
|
+
) -> str:
|
|
111
|
+
"""Execute SQL and track results for streaming."""
|
|
112
|
+
# Get the execute_sql tool and run it
|
|
113
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
114
|
+
result = await tool.execute(query=query, limit=limit)
|
|
210
115
|
|
|
211
|
-
# Parse result to extract data for streaming
|
|
116
|
+
# Parse result to extract data for streaming
|
|
212
117
|
try:
|
|
213
118
|
result_data = json.loads(result)
|
|
214
119
|
if result_data.get("success") and "results" in result_data:
|
|
@@ -228,7 +133,14 @@ Guidelines:
|
|
|
228
133
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
229
134
|
) -> str:
|
|
230
135
|
"""Process a tool call and return the result."""
|
|
231
|
-
#
|
|
136
|
+
# Special handling for execute_sql to track results
|
|
137
|
+
if tool_name == "execute_sql":
|
|
138
|
+
return await self._execute_sql_with_tracking(
|
|
139
|
+
tool_input.get("query", ""),
|
|
140
|
+
tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Use parent implementation for all other tools
|
|
232
144
|
return await super().process_tool_call(tool_name, tool_input)
|
|
233
145
|
|
|
234
146
|
def _convert_user_message_to_message(
|
sqlsaber/agents/base.py
CHANGED
|
@@ -5,8 +5,6 @@ import json
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from typing import Any, AsyncIterator
|
|
7
7
|
|
|
8
|
-
from uniplot import histogram, plot
|
|
9
|
-
|
|
10
8
|
from sqlsaber.conversation.manager import ConversationManager
|
|
11
9
|
from sqlsaber.database.connection import (
|
|
12
10
|
BaseDatabaseConnection,
|
|
@@ -17,6 +15,7 @@ from sqlsaber.database.connection import (
|
|
|
17
15
|
)
|
|
18
16
|
from sqlsaber.database.schema import SchemaManager
|
|
19
17
|
from sqlsaber.models.events import StreamEvent
|
|
18
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
class BaseSQLAgent(ABC):
|
|
@@ -32,6 +31,9 @@ class BaseSQLAgent(ABC):
|
|
|
32
31
|
self._conversation_id: str | None = None
|
|
33
32
|
self._msg_index: int = 0
|
|
34
33
|
|
|
34
|
+
# Initialize SQL tools with database connection
|
|
35
|
+
self._init_tools()
|
|
36
|
+
|
|
35
37
|
@abstractmethod
|
|
36
38
|
async def query_stream(
|
|
37
39
|
self,
|
|
@@ -69,232 +71,28 @@ class BaseSQLAgent(ABC):
|
|
|
69
71
|
else:
|
|
70
72
|
return "database" # Fallback
|
|
71
73
|
|
|
72
|
-
|
|
73
|
-
"""
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
formatted_info = {}
|
|
80
|
-
for table_name, table_info in schema_info.items():
|
|
81
|
-
formatted_info[table_name] = {
|
|
82
|
-
"columns": {
|
|
83
|
-
col_name: {
|
|
84
|
-
"type": col_info["data_type"],
|
|
85
|
-
"nullable": col_info["nullable"],
|
|
86
|
-
"default": col_info["default"],
|
|
87
|
-
}
|
|
88
|
-
for col_name, col_info in table_info["columns"].items()
|
|
89
|
-
},
|
|
90
|
-
"primary_keys": table_info["primary_keys"],
|
|
91
|
-
"foreign_keys": [
|
|
92
|
-
f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
|
|
93
|
-
for fk in table_info["foreign_keys"]
|
|
94
|
-
],
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
return json.dumps(formatted_info)
|
|
98
|
-
except Exception as e:
|
|
99
|
-
return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
|
|
100
|
-
|
|
101
|
-
async def list_tables(self) -> str:
|
|
102
|
-
"""List all tables in the database with basic information."""
|
|
103
|
-
try:
|
|
104
|
-
tables_info = await self.schema_manager.list_tables()
|
|
105
|
-
return json.dumps(tables_info)
|
|
106
|
-
except Exception as e:
|
|
107
|
-
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
108
|
-
|
|
109
|
-
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
110
|
-
"""Execute a SQL query against the database."""
|
|
111
|
-
try:
|
|
112
|
-
# Security check - only allow SELECT queries unless write is enabled
|
|
113
|
-
write_error = self._validate_write_operation(query)
|
|
114
|
-
if write_error:
|
|
115
|
-
return json.dumps(
|
|
116
|
-
{
|
|
117
|
-
"error": write_error,
|
|
118
|
-
}
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
# Add LIMIT if not present and it's a SELECT query
|
|
122
|
-
query = self._add_limit_to_query(query, limit)
|
|
123
|
-
|
|
124
|
-
# Execute the query (wrapped in a transaction for safety)
|
|
125
|
-
results = await self.db.execute_query(query)
|
|
126
|
-
|
|
127
|
-
# Format results
|
|
128
|
-
actual_limit = limit if limit is not None else len(results)
|
|
129
|
-
|
|
130
|
-
return json.dumps(
|
|
131
|
-
{
|
|
132
|
-
"success": True,
|
|
133
|
-
"row_count": len(results),
|
|
134
|
-
"results": results[:actual_limit], # Extra safety for limit
|
|
135
|
-
"truncated": len(results) > actual_limit,
|
|
136
|
-
}
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
except Exception as e:
|
|
140
|
-
error_msg = str(e)
|
|
141
|
-
|
|
142
|
-
# Provide helpful error messages
|
|
143
|
-
suggestions = []
|
|
144
|
-
if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
145
|
-
suggestions.append(
|
|
146
|
-
"Check column names using the schema introspection tool"
|
|
147
|
-
)
|
|
148
|
-
elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
|
|
149
|
-
suggestions.append(
|
|
150
|
-
"Check table names using the schema introspection tool"
|
|
151
|
-
)
|
|
152
|
-
elif "syntax error" in error_msg.lower():
|
|
153
|
-
suggestions.append(
|
|
154
|
-
"Review SQL syntax, especially JOIN conditions and WHERE clauses"
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
74
|
+
def _init_tools(self) -> None:
|
|
75
|
+
"""Initialize SQL tools with database connection."""
|
|
76
|
+
# Get all SQL tools and set their database connection
|
|
77
|
+
for tool_name in tool_registry.list_tools(category="sql"):
|
|
78
|
+
tool = tool_registry.get_tool(tool_name)
|
|
79
|
+
if isinstance(tool, SQLTool):
|
|
80
|
+
tool.set_connection(self.db)
|
|
158
81
|
|
|
159
82
|
async def process_tool_call(
|
|
160
83
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
161
84
|
) -> str:
|
|
162
85
|
"""Process a tool call and return the result."""
|
|
163
|
-
if tool_name == "list_tables":
|
|
164
|
-
return await self.list_tables()
|
|
165
|
-
elif tool_name == "introspect_schema":
|
|
166
|
-
return await self.introspect_schema(tool_input.get("table_pattern"))
|
|
167
|
-
elif tool_name == "execute_sql":
|
|
168
|
-
return await self.execute_sql(
|
|
169
|
-
tool_input["query"], tool_input.get("limit", 100)
|
|
170
|
-
)
|
|
171
|
-
elif tool_name == "plot_data":
|
|
172
|
-
return await self.plot_data(
|
|
173
|
-
y_values=tool_input["y_values"],
|
|
174
|
-
x_values=tool_input.get("x_values"),
|
|
175
|
-
plot_type=tool_input.get("plot_type", "line"),
|
|
176
|
-
title=tool_input.get("title"),
|
|
177
|
-
x_label=tool_input.get("x_label"),
|
|
178
|
-
y_label=tool_input.get("y_label"),
|
|
179
|
-
)
|
|
180
|
-
else:
|
|
181
|
-
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
182
|
-
|
|
183
|
-
def _validate_write_operation(self, query: str) -> str | None:
|
|
184
|
-
"""Validate if a write operation is allowed.
|
|
185
|
-
|
|
186
|
-
Returns:
|
|
187
|
-
None if operation is allowed, error message if not allowed.
|
|
188
|
-
"""
|
|
189
|
-
query_upper = query.strip().upper()
|
|
190
|
-
|
|
191
|
-
# Check for write operations
|
|
192
|
-
write_keywords = [
|
|
193
|
-
"INSERT",
|
|
194
|
-
"UPDATE",
|
|
195
|
-
"DELETE",
|
|
196
|
-
"DROP",
|
|
197
|
-
"CREATE",
|
|
198
|
-
"ALTER",
|
|
199
|
-
"TRUNCATE",
|
|
200
|
-
]
|
|
201
|
-
is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
|
|
202
|
-
|
|
203
|
-
if is_write_query:
|
|
204
|
-
return (
|
|
205
|
-
"Write operations are not allowed. Only SELECT queries are permitted."
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
return None
|
|
209
|
-
|
|
210
|
-
def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
|
|
211
|
-
"""Add LIMIT clause to SELECT queries if not present."""
|
|
212
|
-
query_upper = query.strip().upper()
|
|
213
|
-
if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
|
|
214
|
-
return f"{query.rstrip(';')} LIMIT {limit};"
|
|
215
|
-
return query
|
|
216
|
-
|
|
217
|
-
async def plot_data(
|
|
218
|
-
self,
|
|
219
|
-
y_values: list[float],
|
|
220
|
-
x_values: list[float] | None = None,
|
|
221
|
-
plot_type: str = "line",
|
|
222
|
-
title: str | None = None,
|
|
223
|
-
x_label: str | None = None,
|
|
224
|
-
y_label: str | None = None,
|
|
225
|
-
) -> str:
|
|
226
|
-
"""Create a terminal plot using uniplot.
|
|
227
|
-
|
|
228
|
-
Args:
|
|
229
|
-
y_values: Y-axis data points
|
|
230
|
-
x_values: X-axis data points (optional)
|
|
231
|
-
plot_type: Type of plot - "line", "scatter", or "histogram"
|
|
232
|
-
title: Plot title
|
|
233
|
-
x_label: X-axis label
|
|
234
|
-
y_label: Y-axis label
|
|
235
|
-
|
|
236
|
-
Returns:
|
|
237
|
-
JSON string with success status and plot details
|
|
238
|
-
"""
|
|
239
86
|
try:
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
try:
|
|
246
|
-
y_values = [float(v) if v is not None else None for v in y_values]
|
|
247
|
-
if x_values:
|
|
248
|
-
x_values = [float(v) if v is not None else None for v in x_values]
|
|
249
|
-
except (ValueError, TypeError) as e:
|
|
250
|
-
return json.dumps({"error": f"Invalid data format: {str(e)}"})
|
|
251
|
-
|
|
252
|
-
# Create the plot
|
|
253
|
-
if plot_type == "histogram":
|
|
254
|
-
# For histogram, we only need y_values
|
|
255
|
-
histogram(
|
|
256
|
-
y_values,
|
|
257
|
-
title=title,
|
|
258
|
-
bins=min(20, len(set(y_values))), # Adaptive bin count
|
|
259
|
-
)
|
|
260
|
-
plot_info = {
|
|
261
|
-
"type": "histogram",
|
|
262
|
-
"data_points": len(y_values),
|
|
263
|
-
"title": title or "Histogram",
|
|
264
|
-
}
|
|
265
|
-
elif plot_type in ["line", "scatter"]:
|
|
266
|
-
# For line/scatter plots
|
|
267
|
-
plot_kwargs = {
|
|
268
|
-
"ys": y_values,
|
|
269
|
-
"title": title,
|
|
270
|
-
"lines": plot_type == "line",
|
|
271
|
-
}
|
|
272
|
-
|
|
273
|
-
if x_values:
|
|
274
|
-
plot_kwargs["xs"] = x_values
|
|
275
|
-
if x_label:
|
|
276
|
-
plot_kwargs["x_unit"] = x_label
|
|
277
|
-
if y_label:
|
|
278
|
-
plot_kwargs["y_unit"] = y_label
|
|
279
|
-
|
|
280
|
-
plot(**plot_kwargs)
|
|
281
|
-
|
|
282
|
-
plot_info = {
|
|
283
|
-
"type": plot_type,
|
|
284
|
-
"data_points": len(y_values),
|
|
285
|
-
"title": title or f"{plot_type.capitalize()} Plot",
|
|
286
|
-
"has_x_values": x_values is not None,
|
|
287
|
-
}
|
|
288
|
-
else:
|
|
289
|
-
return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
|
|
290
|
-
|
|
87
|
+
tool = tool_registry.get_tool(tool_name)
|
|
88
|
+
return await tool.execute(**tool_input)
|
|
89
|
+
except KeyError:
|
|
90
|
+
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
91
|
+
except Exception as e:
|
|
291
92
|
return json.dumps(
|
|
292
|
-
{"
|
|
93
|
+
{"error": f"Error executing tool '{tool_name}': {str(e)}"}
|
|
293
94
|
)
|
|
294
95
|
|
|
295
|
-
except Exception as e:
|
|
296
|
-
return json.dumps({"error": f"Error creating plot: {str(e)}"})
|
|
297
|
-
|
|
298
96
|
# Conversation persistence helpers
|
|
299
97
|
|
|
300
98
|
async def _ensure_conversation(self) -> None:
|
sqlsaber/mcp/mcp.py
CHANGED
|
@@ -7,25 +7,17 @@ from fastmcp import FastMCP
|
|
|
7
7
|
from sqlsaber.agents.mcp import MCPSQLAgent
|
|
8
8
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
9
9
|
from sqlsaber.database.connection import DatabaseConnection
|
|
10
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
11
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
10
12
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
+
# Initialize the instruction builder
|
|
14
|
+
instruction_builder = InstructionBuilder(tool_registry)
|
|
13
15
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
- Call `introspect_schema()` to introspect database schema to understand table structures.
|
|
17
|
-
- Call `execute_sql()` to execute SQL queries against the database and retrieve results.
|
|
16
|
+
# Generate dynamic instructions
|
|
17
|
+
DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
22
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
23
|
-
- Include appropriate WHERE clauses to limit results
|
|
24
|
-
- Handle errors gracefully and suggest fixes
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
# Create the FastMCP server instance
|
|
28
|
-
mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
|
|
19
|
+
# Create the FastMCP server instance with dynamic instructions
|
|
20
|
+
mcp = FastMCP(name="SQL Assistant", instructions=DYNAMIC_INSTRUCTIONS)
|
|
29
21
|
|
|
30
22
|
# Initialize the database config manager
|
|
31
23
|
config_manager = DatabaseConfigManager()
|
|
@@ -70,10 +62,16 @@ def get_databases() -> dict:
|
|
|
70
62
|
return {"databases": databases, "count": len(databases)}
|
|
71
63
|
|
|
72
64
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
|
|
66
|
+
"""Execute a SQL tool with database connection management.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
tool_name: Name of the tool to execute
|
|
70
|
+
database: Database name to connect to
|
|
71
|
+
**kwargs: Tool-specific parameters
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
JSON string with the tool's output
|
|
77
75
|
"""
|
|
78
76
|
try:
|
|
79
77
|
agent = await _create_agent_for_database(database)
|
|
@@ -82,50 +80,44 @@ async def list_tables(database: str) -> str:
|
|
|
82
80
|
{"error": f"Database '{database}' not found or could not connect"}
|
|
83
81
|
)
|
|
84
82
|
|
|
85
|
-
|
|
83
|
+
# Get the tool and set up connection
|
|
84
|
+
tool = tool_registry.get_tool(tool_name)
|
|
85
|
+
if isinstance(tool, SQLTool):
|
|
86
|
+
tool.set_connection(agent.db)
|
|
87
|
+
|
|
88
|
+
# Execute the tool
|
|
89
|
+
result = await tool.execute(**kwargs)
|
|
86
90
|
await agent.db.close()
|
|
87
91
|
return result
|
|
88
92
|
|
|
89
93
|
except Exception as e:
|
|
90
|
-
return json.dumps({"error": f"Error
|
|
94
|
+
return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
|
|
91
95
|
|
|
92
96
|
|
|
93
|
-
|
|
94
|
-
async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
|
|
95
|
-
"""
|
|
96
|
-
Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
|
|
97
|
-
"""
|
|
98
|
-
try:
|
|
99
|
-
agent = await _create_agent_for_database(database)
|
|
100
|
-
if not agent:
|
|
101
|
-
return json.dumps(
|
|
102
|
-
{"error": f"Database '{database}' not found or could not connect"}
|
|
103
|
-
)
|
|
97
|
+
# SQL Tool Wrappers with explicit signatures
|
|
104
98
|
|
|
105
|
-
result = await agent.introspect_schema(table_pattern)
|
|
106
|
-
await agent.db.close()
|
|
107
|
-
return result
|
|
108
99
|
|
|
109
|
-
|
|
110
|
-
|
|
100
|
+
@mcp.tool
|
|
101
|
+
async def list_tables(database: str) -> str:
|
|
102
|
+
"""Get a list of all tables in the database with row counts. Use this first to discover available tables."""
|
|
103
|
+
return await _execute_with_connection("list_tables", database)
|
|
111
104
|
|
|
112
105
|
|
|
113
106
|
@mcp.tool
|
|
114
|
-
async def
|
|
115
|
-
"""
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
{"error": f"Database '{database}' not found or could not connect"}
|
|
121
|
-
)
|
|
107
|
+
async def introspect_schema(database: str, table_pattern: str = None) -> str:
|
|
108
|
+
"""Introspect database schema to understand table structures."""
|
|
109
|
+
kwargs = {}
|
|
110
|
+
if table_pattern is not None:
|
|
111
|
+
kwargs["table_pattern"] = table_pattern
|
|
112
|
+
return await _execute_with_connection("introspect_schema", database, **kwargs)
|
|
122
113
|
|
|
123
|
-
result = await agent.execute_sql(query, limit)
|
|
124
|
-
await agent.db.close()
|
|
125
|
-
return result
|
|
126
114
|
|
|
127
|
-
|
|
128
|
-
|
|
115
|
+
@mcp.tool
|
|
116
|
+
async def execute_sql(database: str, query: str, limit: int = 100) -> str:
|
|
117
|
+
"""Execute a SQL query against the database."""
|
|
118
|
+
return await _execute_with_connection(
|
|
119
|
+
"execute_sql", database, query=query, limit=limit
|
|
120
|
+
)
|
|
129
121
|
|
|
130
122
|
|
|
131
123
|
def main():
|
sqlsaber/tools/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""SQLSaber tools module."""
|
|
2
|
+
|
|
3
|
+
from .base import Tool
|
|
4
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
5
|
+
from .instructions import InstructionBuilder
|
|
6
|
+
from .registry import ToolRegistry, register_tool, tool_registry
|
|
7
|
+
|
|
8
|
+
# Import concrete tools to register them
|
|
9
|
+
from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
|
|
10
|
+
from .visualization_tools import PlotDataTool
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"Tool",
|
|
14
|
+
"ToolCategory",
|
|
15
|
+
"WorkflowPosition",
|
|
16
|
+
"ToolRegistry",
|
|
17
|
+
"tool_registry",
|
|
18
|
+
"register_tool",
|
|
19
|
+
"InstructionBuilder",
|
|
20
|
+
"SQLTool",
|
|
21
|
+
"ListTablesTool",
|
|
22
|
+
"IntrospectSchemaTool",
|
|
23
|
+
"ExecuteSQLTool",
|
|
24
|
+
"PlotDataTool",
|
|
25
|
+
]
|