sqlsaber 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the content of publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of sqlsaber might be problematic.
- sqlsaber/agents/anthropic.py +19 -113
- sqlsaber/agents/base.py +120 -3
- sqlsaber/agents/mcp.py +21 -0
- sqlsaber/agents/streaming.py +0 -10
- sqlsaber/cli/commands.py +28 -10
- sqlsaber/cli/database.py +1 -1
- sqlsaber/config/database.py +25 -3
- sqlsaber/database/connection.py +129 -0
- sqlsaber/database/schema.py +92 -68
- sqlsaber/mcp/__init__.py +5 -0
- sqlsaber/mcp/mcp.py +138 -0
- {sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/METADATA +41 -1
- {sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/RECORD +16 -13
- {sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/entry_points.txt +2 -0
- {sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/anthropic.py
CHANGED
@@ -11,13 +11,7 @@ from sqlsaber.agents.streaming import (
     build_tool_result_block,
 )
 from sqlsaber.config.settings import Config
-from sqlsaber.database.connection import (
-    BaseDatabaseConnection,
-    MySQLConnection,
-    PostgreSQLConnection,
-    SQLiteConnection,
-)
-from sqlsaber.database.schema import SchemaManager
+from sqlsaber.database.connection import BaseDatabaseConnection
 from sqlsaber.memory.manager import MemoryManager
 from sqlsaber.models.events import StreamEvent
 from sqlsaber.models.types import ToolDefinition
@@ -36,7 +30,6 @@ class AnthropicSQLAgent(BaseSQLAgent):
 
         self.client = AsyncAnthropic(api_key=config.api_key)
         self.model = config.model_name.replace("anthropic:", "")
-        self.schema_manager = SchemaManager(db_connection)
 
         self.database_name = database_name
         self.memory_manager = MemoryManager()
@@ -94,17 +87,6 @@ class AnthropicSQLAgent(BaseSQLAgent):
         # Build system prompt with memories if available
         self.system_prompt = self._build_system_prompt()
 
-    def _get_database_type_name(self) -> str:
-        """Get the human-readable database type name."""
-        if isinstance(self.db, PostgreSQLConnection):
-            return "PostgreSQL"
-        elif isinstance(self.db, MySQLConnection):
-            return "MySQL"
-        elif isinstance(self.db, SQLiteConnection):
-            return "SQLite"
-        else:
-            return "database"  # Fallback
-
     def _build_system_prompt(self) -> str:
         """Build system prompt with optional memory context."""
         db_type = self._get_database_type_name()
@@ -152,109 +134,33 @@ Guidelines:
         self.system_prompt = self._build_system_prompt()
         return memory.id
 
-    async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
-        """Introspect database schema to understand table structures."""
-        try:
-            # Pass table_pattern to get_schema_info for efficient filtering at DB level
-            schema_info = await self.schema_manager.get_schema_info(table_pattern)
-
-            # Format the schema information
-            formatted_info = {}
-            for table_name, table_info in schema_info.items():
-                formatted_info[table_name] = {
-                    "columns": {
-                        col_name: {
-                            "type": col_info["data_type"],
-                            "nullable": col_info["nullable"],
-                            "default": col_info["default"],
-                        }
-                        for col_name, col_info in table_info["columns"].items()
-                    },
-                    "primary_keys": table_info["primary_keys"],
-                    "foreign_keys": [
-                        f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
-                        for fk in table_info["foreign_keys"]
-                    ],
-                }
-
-            return json.dumps(formatted_info)
-        except Exception as e:
-            return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
-
-    async def list_tables(self) -> str:
-        """List all tables in the database with basic information."""
-        try:
-            tables_info = await self.schema_manager.list_tables()
-            return json.dumps(tables_info)
-        except Exception as e:
-            return json.dumps({"error": f"Error listing tables: {str(e)}"})
-
     async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
-        """Execute a SQL query against the database."""
-        try:
-            # Security check - only allow SELECT queries unless write is enabled
-            write_error = self._validate_write_operation(query)
-            if write_error:
-                return json.dumps(
-                    {
-                        "error": write_error,
-                    }
-                )
-
-            # Add LIMIT if not present and it's a SELECT query
-            query = self._add_limit_to_query(query, limit)
-
-            # Execute the query (wrapped in a transaction for safety)
-            results = await self.db.execute_query(query)
-
-            # Format results - but also store the actual data
-            actual_limit = limit if limit is not None else len(results)
-            self._last_results = results[:actual_limit]
-            self._last_query = query
+        """Execute a SQL query against the database with streaming support."""
+        # Call parent implementation for core functionality
+        result = await super().execute_sql(query, limit)
 
-            return json.dumps(
-                {
-                    "success": True,
-                    "row_count": len(results),
-                    "results": results[:actual_limit],
-                    "truncated": len(results) > actual_limit,
-                }
-            )
-
-        except Exception as e:
-            error_msg = str(e)
-
-            # Provide helpful error messages
-            suggestions = []
-            if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
-                suggestions.append(
-                    "Check column names using the schema introspection tool"
-                )
-            elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
-                suggestions.append(
-                    "Check table names using the schema introspection tool"
-                )
-            elif "syntax error" in error_msg.lower():
-                suggestions.append(
-                    "Review SQL syntax, especially JOIN conditions and WHERE clauses"
+        # Parse result to extract data for streaming (AnthropicSQLAgent specific)
+        try:
+            result_data = json.loads(result)
+            if result_data.get("success") and "results" in result_data:
+                # Store results for streaming
+                actual_limit = (
+                    limit if limit is not None else len(result_data["results"])
                 )
+                self._last_results = result_data["results"][:actual_limit]
+                self._last_query = query
+        except (json.JSONDecodeError, KeyError):
+            # If we can't parse the result, just continue without storing
+            pass
 
-
+        return result
 
     async def process_tool_call(
         self, tool_name: str, tool_input: Dict[str, Any]
     ) -> str:
         """Process a tool call and return the result."""
-        if tool_name == "list_tables":
-            return await self.list_tables()
-        elif tool_name == "introspect_schema":
-            return await self.introspect_schema(tool_input.get("table_pattern"))
-        elif tool_name == "execute_sql":
-            return await self.execute_sql(
-                tool_input["query"], tool_input.get("limit", 100)
-            )
-        else:
-            return json.dumps({"error": f"Unknown tool: {tool_name}"})
+        # Use parent implementation for core tools
+        return await super().process_tool_call(tool_name, tool_input)
 
     async def _process_stream_events(
         self, stream, content_blocks: List[Dict], tool_use_blocks: List[Dict]
sqlsaber/agents/base.py
CHANGED
@@ -1,9 +1,17 @@
 """Abstract base class for SQL agents."""
 
+import json
 from abc import ABC, abstractmethod
 from typing import Any, AsyncIterator, Dict, List, Optional
 
-from sqlsaber.database.connection import BaseDatabaseConnection
+from sqlsaber.database.connection import (
+    BaseDatabaseConnection,
+    CSVConnection,
+    MySQLConnection,
+    PostgreSQLConnection,
+    SQLiteConnection,
+)
+from sqlsaber.database.schema import SchemaManager
 from sqlsaber.models.events import StreamEvent
 
 
@@ -12,6 +20,7 @@ class BaseSQLAgent(ABC):
 
     def __init__(self, db_connection: BaseDatabaseConnection):
         self.db = db_connection
+        self.schema_manager = SchemaManager(db_connection)
         self.conversation_history: List[Dict[str, Any]] = []
 
     @abstractmethod
@@ -25,12 +34,120 @@ class BaseSQLAgent(ABC):
         """Clear conversation history."""
         self.conversation_history = []
 
-
+    def _get_database_type_name(self) -> str:
+        """Get the human-readable database type name."""
+        if isinstance(self.db, PostgreSQLConnection):
+            return "PostgreSQL"
+        elif isinstance(self.db, MySQLConnection):
+            return "MySQL"
+        elif isinstance(self.db, SQLiteConnection):
+            return "SQLite"
+        elif isinstance(self.db, CSVConnection):
+            return "SQLite"  # we convert csv to in-memory sqlite
+        else:
+            return "database"  # Fallback
+
+    async def introspect_schema(self, table_pattern: Optional[str] = None) -> str:
+        """Introspect database schema to understand table structures."""
+        try:
+            # Pass table_pattern to get_schema_info for efficient filtering at DB level
+            schema_info = await self.schema_manager.get_schema_info(table_pattern)
+
+            # Format the schema information
+            formatted_info = {}
+            for table_name, table_info in schema_info.items():
+                formatted_info[table_name] = {
+                    "columns": {
+                        col_name: {
+                            "type": col_info["data_type"],
+                            "nullable": col_info["nullable"],
+                            "default": col_info["default"],
+                        }
+                        for col_name, col_info in table_info["columns"].items()
+                    },
+                    "primary_keys": table_info["primary_keys"],
+                    "foreign_keys": [
+                        f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
+                        for fk in table_info["foreign_keys"]
+                    ],
+                }
+
+            return json.dumps(formatted_info)
+        except Exception as e:
+            return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
+
+    async def list_tables(self) -> str:
+        """List all tables in the database with basic information."""
+        try:
+            tables_info = await self.schema_manager.list_tables()
+            return json.dumps(tables_info)
+        except Exception as e:
+            return json.dumps({"error": f"Error listing tables: {str(e)}"})
+
+    async def execute_sql(self, query: str, limit: Optional[int] = 100) -> str:
+        """Execute a SQL query against the database."""
+        try:
+            # Security check - only allow SELECT queries unless write is enabled
+            write_error = self._validate_write_operation(query)
+            if write_error:
+                return json.dumps(
+                    {
+                        "error": write_error,
+                    }
+                )
+
+            # Add LIMIT if not present and it's a SELECT query
+            query = self._add_limit_to_query(query, limit)
+
+            # Execute the query (wrapped in a transaction for safety)
+            results = await self.db.execute_query(query)
+
+            # Format results
+            actual_limit = limit if limit is not None else len(results)
+
+            return json.dumps(
+                {
+                    "success": True,
+                    "row_count": len(results),
+                    "results": results[:actual_limit],  # Extra safety for limit
+                    "truncated": len(results) > actual_limit,
+                }
+            )
+
+        except Exception as e:
+            error_msg = str(e)
+
+            # Provide helpful error messages
+            suggestions = []
+            if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
+                suggestions.append(
+                    "Check column names using the schema introspection tool"
+                )
+            elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
+                suggestions.append(
+                    "Check table names using the schema introspection tool"
+                )
+            elif "syntax error" in error_msg.lower():
+                suggestions.append(
+                    "Review SQL syntax, especially JOIN conditions and WHERE clauses"
+                )
+
+            return json.dumps({"error": error_msg, "suggestions": suggestions})
+
     async def process_tool_call(
         self, tool_name: str, tool_input: Dict[str, Any]
     ) -> str:
         """Process a tool call and return the result."""
-
+        if tool_name == "list_tables":
+            return await self.list_tables()
+        elif tool_name == "introspect_schema":
+            return await self.introspect_schema(tool_input.get("table_pattern"))
+        elif tool_name == "execute_sql":
+            return await self.execute_sql(
+                tool_input["query"], tool_input.get("limit", 100)
+            )
+        else:
+            return json.dumps({"error": f"Unknown tool: {tool_name}"})
 
     def _validate_write_operation(self, query: str) -> Optional[str]:
         """Validate if a write operation is allowed.
sqlsaber/agents/mcp.py
ADDED
@@ -0,0 +1,21 @@
+"""Generic SQL agent implementation for MCP tools."""
+
+from typing import AsyncIterator
+from sqlsaber.agents.base import BaseSQLAgent
+from sqlsaber.database.connection import BaseDatabaseConnection
+from sqlsaber.models.events import StreamEvent
+
+
+class MCPSQLAgent(BaseSQLAgent):
+    """MCP SQL Agent for MCP tool operations without LLM-specific logic."""
+
+    def __init__(self, db_connection: BaseDatabaseConnection):
+        super().__init__(db_connection)
+
+    async def query_stream(
+        self, user_query: str, use_history: bool = True
+    ) -> AsyncIterator[StreamEvent]:
+        """Not implemented for generic agent as it's only used for tool operations."""
+        raise NotImplementedError(
+            "MCPSQLAgent does not support query streaming. Use specific agent implementations for conversation."
+        )
sqlsaber/agents/streaming.py
CHANGED
@@ -14,13 +14,3 @@ class StreamingResponse:
 def build_tool_result_block(tool_use_id: str, content: str) -> Dict[str, Any]:
     """Build a tool result block for the conversation."""
     return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
-
-
-def extract_sql_from_text(text: str) -> str:
-    """Extract SQL query from markdown-formatted text."""
-    if "```sql" in text:
-        sql_start = text.find("```sql") + 6
-        sql_end = text.find("```", sql_start)
-        if sql_end > sql_start:
-            return text[sql_start:sql_end].strip()
-    return ""
sqlsaber/cli/commands.py
CHANGED
@@ -1,6 +1,7 @@
 """CLI command definitions and handlers."""
 
 import asyncio
+from pathlib import Path
 from typing import Optional
 
 import typer
@@ -62,15 +63,31 @@ def query(
     """Run a query against the database or start interactive mode."""
 
     async def run_session():
-        # Get database configuration
+        # Get database configuration or handle direct CSV file
         if database:
-            db_config = config_manager.get_database(database)
-            if not db_config:
-                console.print(
-                    f"[bold red]Error:[/bold red] Database connection '{database}' not found."
-                )
-                console.print("Use 'sqlsaber db list' to see available connections.")
-                raise typer.Exit(1)
+            # Check if this is a direct CSV file path
+            if database.endswith(".csv"):
+                csv_path = Path(database).expanduser().resolve()
+                if not csv_path.exists():
+                    console.print(
+                        f"[bold red]Error:[/bold red] CSV file '{database}' not found."
+                    )
+                    raise typer.Exit(1)
+                connection_string = f"csv:///{csv_path}"
+                db_name = csv_path.stem
+            else:
+                # Look up configured database connection
+                db_config = config_manager.get_database(database)
+                if not db_config:
+                    console.print(
+                        f"[bold red]Error:[/bold red] Database connection '{database}' not found."
+                    )
+                    console.print(
+                        "Use 'sqlsaber db list' to see available connections."
+                    )
+                    raise typer.Exit(1)
+                connection_string = db_config.to_connection_string()
+                db_name = db_config.name
         else:
             db_config = config_manager.get_default_database()
             if not db_config:
@@ -81,10 +98,11 @@ def query(
                     "Use 'sqlsaber db add <name>' to add a database connection."
                 )
                 raise typer.Exit(1)
+            connection_string = db_config.to_connection_string()
+            db_name = db_config.name
 
         # Create database connection
         try:
-            connection_string = db_config.to_connection_string()
            db_conn = DatabaseConnection(connection_string)
         except Exception as e:
             console.print(
@@ -93,7 +111,7 @@ def query(
             raise typer.Exit(1)
 
         # Create agent instance with database name for memory context
-        agent = AnthropicSQLAgent(db_conn, db_config.name)
+        agent = AnthropicSQLAgent(db_conn, db_name)
 
         try:
             if query_text:
sqlsaber/cli/database.py
CHANGED
@@ -75,7 +75,7 @@ def add_database(
     if type == "sqlite":
         # SQLite only needs database path
         database = database or questionary.path("Database file path:").ask()
-        database = str(Path(database).expanduser())
+        database = str(Path(database).expanduser().resolve())
         host = "localhost"
         port = 0
         username = "sqlite"
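
The one-line change makes the stored SQLite path absolute, so a config created from inside a project directory still resolves after the working directory changes. For example:

```python
from pathlib import Path

p = Path("data/app.db")           # a relative path, as a user might type it
print(p.expanduser())             # data/app.db - still relative to the CWD
print(p.expanduser().resolve())   # e.g. /home/user/project/data/app.db - absolute
```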
sqlsaber/config/database.py
CHANGED
@@ -4,12 +4,12 @@ import json
 import os
 import platform
 import stat
-import keyring
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from urllib.parse import quote_plus
 
+import keyring
 import platformdirs
 
 
@@ -18,7 +18,7 @@ class DatabaseConfig:
     """Database connection configuration."""
 
     name: str
-    type: str  # postgresql, mysql, sqlite
+    type: str  # postgresql, mysql, sqlite, csv
     host: Optional[str]
     port: Optional[int]
     database: str
@@ -90,6 +90,28 @@ class DatabaseConfig:
 
         elif self.type == "sqlite":
             return f"sqlite:///{self.database}"
+        elif self.type == "csv":
+            # For CSV files, database field contains the file path
+            base_url = f"csv:///{self.database}"
+
+            # Add CSV-specific parameters if they exist in schema field
+            if self.schema:
+                # Schema field can contain CSV options in JSON format
+                try:
+                    csv_options = json.loads(self.schema)
+                    params = []
+                    if "delimiter" in csv_options:
+                        params.append(f"delimiter={csv_options['delimiter']}")
+                    if "encoding" in csv_options:
+                        params.append(f"encoding={csv_options['encoding']}")
+                    if "header" in csv_options:
+                        params.append(f"header={str(csv_options['header']).lower()}")
+
+                    if params:
+                        return f"{base_url}?{'&'.join(params)}"
+                except (json.JSONDecodeError, KeyError):
+                    pass
+            return base_url
         else:
             raise ValueError(f"Unsupported database type: {self.type}")
 
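
The csv branch rides the parser options along in the existing schema field as JSON. A standalone sketch of the resulting connection string (the file path is hypothetical):

```python
import json

# Mirror of the new "csv" branch of to_connection_string().
database = "/home/user/data/sales.csv"
schema = json.dumps({"delimiter": ";", "encoding": "latin-1", "header": True})

base_url = f"csv:///{database}"
csv_options = json.loads(schema)
params = []
if "delimiter" in csv_options:
    params.append(f"delimiter={csv_options['delimiter']}")
if "encoding" in csv_options:
    params.append(f"encoding={csv_options['encoding']}")
if "header" in csv_options:
    params.append(f"header={str(csv_options['header']).lower()}")

print(f"{base_url}?{'&'.join(params)}" if params else base_url)
# csv:////home/user/data/sales.csv?delimiter=;&encoding=latin-1&header=true
```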
sqlsaber/database/connection.py
CHANGED
@@ -4,10 +4,12 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse, parse_qs
 import ssl
+from pathlib import Path
 
 import aiomysql
 import aiosqlite
 import asyncpg
+import pandas as pd
 
 
 class BaseDatabaseConnection(ABC):
@@ -272,6 +274,131 @@ class SQLiteConnection(BaseDatabaseConnection):
             await conn.rollback()
 
 
+class CSVConnection(BaseDatabaseConnection):
+    """CSV file connection using in-memory SQLite database."""
+
+    def __init__(self, connection_string: str):
+        super().__init__(connection_string)
+
+        # Parse CSV file path from connection string
+        self.csv_path = connection_string.replace("csv:///", "")
+
+        # CSV parsing options
+        self.delimiter = ","
+        self.encoding = "utf-8"
+        self.has_header = True
+
+        # Parse additional options from connection string
+        parsed = urlparse(connection_string)
+        if parsed.query:
+            params = parse_qs(parsed.query)
+            self.delimiter = params.get("delimiter", [","])[0]
+            self.encoding = params.get("encoding", ["utf-8"])[0]
+            self.has_header = params.get("header", ["true"])[0].lower() == "true"
+
+        # Table name derived from filename
+        self.table_name = Path(self.csv_path).stem
+
+        # Initialize connection and flag to track if CSV is loaded
+        self._conn = None
+        self._csv_loaded = False
+
+    async def get_pool(self):
+        """Get or create the in-memory database connection."""
+        if self._conn is None:
+            self._conn = await aiosqlite.connect(":memory:")
+            self._conn.row_factory = aiosqlite.Row
+            await self._load_csv_data()
+        return self._conn
+
+    async def close(self):
+        """Close the database connection."""
+        if self._conn:
+            await self._conn.close()
+            self._conn = None
+            self._csv_loaded = False
+
+    async def _load_csv_data(self):
+        """Load CSV data into the in-memory SQLite database."""
+        if self._csv_loaded or not self._conn:
+            return
+
+        try:
+            # Read CSV file using pandas
+            df = pd.read_csv(
+                self.csv_path,
+                delimiter=self.delimiter,
+                encoding=self.encoding,
+                header=0 if self.has_header else None,
+            )
+
+            # If no header, create column names
+            if not self.has_header:
+                df.columns = [f"column_{i}" for i in range(len(df.columns))]
+
+            # Create table with proper column types
+            columns_sql = []
+            for col in df.columns:
+                # Infer SQLite type from pandas dtype
+                dtype = df[col].dtype
+                if pd.api.types.is_integer_dtype(dtype):
+                    sql_type = "INTEGER"
+                elif pd.api.types.is_float_dtype(dtype):
+                    sql_type = "REAL"
+                elif pd.api.types.is_bool_dtype(dtype):
+                    sql_type = "INTEGER"  # SQLite doesn't have BOOLEAN
+                else:
+                    sql_type = "TEXT"
+
+                columns_sql.append(f'"{col}" {sql_type}')
+
+            create_table_sql = (
+                f'CREATE TABLE "{self.table_name}" ({", ".join(columns_sql)})'
+            )
+            await self._conn.execute(create_table_sql)
+
+            # Insert data row by row
+            placeholders = ", ".join(["?" for _ in df.columns])
+            insert_sql = f'INSERT INTO "{self.table_name}" VALUES ({placeholders})'
+
+            for _, row in df.iterrows():
+                # Convert pandas values to Python native types
+                values = []
+                for val in row:
+                    if pd.isna(val):
+                        values.append(None)
+                    elif isinstance(val, (pd.Timestamp, pd.Timedelta)):
+                        values.append(str(val))
+                    else:
+                        values.append(val)
+
+                await self._conn.execute(insert_sql, values)
+
+            await self._conn.commit()
+            self._csv_loaded = True
+
+        except Exception as e:
+            raise ValueError(f"Error loading CSV file '{self.csv_path}': {str(e)}")
+
+    async def execute_query(self, query: str, *args) -> List[Dict[str, Any]]:
+        """Execute a query and return results as list of dicts.
+
+        All queries run in a transaction that is rolled back at the end,
+        ensuring no changes are persisted to the database.
+        """
+        conn = await self.get_pool()
+
+        # Start transaction
+        await conn.execute("BEGIN")
+        try:
+            cursor = await conn.execute(query, args if args else ())
+            rows = await cursor.fetchall()
+            return [dict(row) for row in rows]
+        finally:
+            # Always rollback to ensure no changes are committed
+            await conn.rollback()
+
+
 def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
     """Factory function to create appropriate database connection based on connection string."""
     if connection_string.startswith("postgresql://"):
@@ -280,6 +407,8 @@ def DatabaseConnection(connection_string: str) -> BaseDatabaseConnection:
         return MySQLConnection(connection_string)
     elif connection_string.startswith("sqlite:///"):
         return SQLiteConnection(connection_string)
+    elif connection_string.startswith("csv:///"):
+        return CSVConnection(connection_string)
     else:
         raise ValueError(
             f"Unsupported database type in connection string: {connection_string}"
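
End to end, the factory now turns a csv:/// URL into an in-memory SQLite table named after the file's stem. A usage sketch that builds a throwaway CSV and queries it (the temp path is generated at run time):

```python
import asyncio
import csv
import tempfile
from pathlib import Path

from sqlsaber.database.connection import DatabaseConnection

# The table is named after the file's stem ("people" here).
async def main():
    tmp = Path(tempfile.mkdtemp()) / "people.csv"
    with tmp.open("w", newline="") as f:
        csv.writer(f).writerows([["name", "age"], ["Ada", 36], ["Alan", 41]])

    conn = DatabaseConnection(f"csv:///{tmp}")  # factory returns a CSVConnection
    rows = await conn.execute_query("SELECT name FROM people WHERE age > 40")
    print(rows)  # [{'name': 'Alan'}]
    await conn.close()


asyncio.run(main())
```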
sqlsaber/database/schema.py
CHANGED
@@ -8,6 +8,7 @@ import aiosqlite
 
 from sqlsaber.database.connection import (
     BaseDatabaseConnection,
+    CSVConnection,
     MySQLConnection,
     PostgreSQLConnection,
     SQLiteConnection,
@@ -200,8 +201,8 @@ class PostgreSQLSchemaIntrospector(BaseSchemaIntrospector):
                 t.table_type,
                 COALESCE(ts.approximate_row_count, 0) as row_count
             FROM information_schema.tables t
-            LEFT JOIN table_stats ts
-                ON t.table_schema = ts.schemaname
+            LEFT JOIN table_stats ts
+                ON t.table_schema = ts.schemaname
                 AND t.table_name = ts.tablename
             WHERE t.table_schema NOT IN ('pg_catalog', 'information_schema')
             ORDER BY t.table_schema, t.table_name;
@@ -375,15 +376,30 @@ class MySQLSchemaIntrospector(BaseSchemaIntrospector):
 class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
     """SQLite-specific schema introspection."""
 
+    async def _execute_query(self, connection, query: str, params=()) -> list:
+        """Helper method to execute queries on both SQLite and CSV connections."""
+        # Handle both SQLite and CSV connections
+        if hasattr(connection, "database_path"):
+            # Regular SQLite connection
+            async with aiosqlite.connect(connection.database_path) as conn:
+                conn.row_factory = aiosqlite.Row
+                cursor = await conn.execute(query, params)
+                return await cursor.fetchall()
+        else:
+            # CSV connection - use the existing connection
+            conn = await connection.get_pool()
+            cursor = await conn.execute(query, params)
+            return await cursor.fetchall()
+
     async def get_tables_info(
         self, connection, table_pattern: Optional[str] = None
     ) -> Dict[str, Any]:
         """Get tables information for SQLite."""
-        where_clause = ""
+        where_conditions = ["type IN ('table', 'view')", "name NOT LIKE 'sqlite_%'"]
         params = ()
 
         if table_pattern:
-            where_clause = "AND name LIKE ?"
+            where_conditions.append("name LIKE ?")
            params = (table_pattern,)
 
         query = f"""
@@ -392,16 +408,11 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
             name as table_name,
             type as table_type
             FROM sqlite_master
-            WHERE type IN ('table', 'view')
-                AND name NOT LIKE 'sqlite_%'
-                {where_clause}
+            WHERE {" AND ".join(where_conditions)}
             ORDER BY name;
         """
 
-        async with aiosqlite.connect(connection.database_path) as conn:
-            conn.row_factory = aiosqlite.Row
-            cursor = await conn.execute(query, params)
-            return await cursor.fetchall()
+        return await self._execute_query(connection, query, params)
 
     async def get_columns_info(self, connection, tables: list) -> list:
         """Get columns information for SQLite."""
@@ -414,26 +425,22 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
 
             # Get table info using PRAGMA
             pragma_query = f"PRAGMA table_info({table_name})"
+            table_columns = await self._execute_query(connection, pragma_query)
 
-            async with aiosqlite.connect(connection.database_path) as conn:
-                conn.row_factory = aiosqlite.Row
-                cursor = await conn.execute(pragma_query)
-                table_columns = await cursor.fetchall()
-
-                for col in table_columns:
-                    columns.append(
-                        {
-                            "table_schema": "main",
-                            "table_name": table_name,
-                            "column_name": col["name"],
-                            "data_type": col["type"],
-                            "is_nullable": "YES" if not col["notnull"] else "NO",
-                            "column_default": col["dflt_value"],
-                            "character_maximum_length": None,
-                            "numeric_precision": None,
-                            "numeric_scale": None,
-                        }
-                    )
+            for col in table_columns:
+                columns.append(
+                    {
+                        "table_schema": "main",
+                        "table_name": table_name,
+                        "column_name": col["name"],
+                        "data_type": col["type"],
+                        "is_nullable": "YES" if not col["notnull"] else "NO",
+                        "column_default": col["dflt_value"],
+                        "character_maximum_length": None,
+                        "numeric_precision": None,
+                        "numeric_scale": None,
+                    }
+                )
 
         return columns
 
@@ -448,23 +455,19 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
 
             # Get foreign key info using PRAGMA
             pragma_query = f"PRAGMA foreign_key_list({table_name})"
+            table_fks = await self._execute_query(connection, pragma_query)
 
-            async with aiosqlite.connect(connection.database_path) as conn:
-                conn.row_factory = aiosqlite.Row
-                cursor = await conn.execute(pragma_query)
-                table_fks = await cursor.fetchall()
-
-                for fk in table_fks:
-                    foreign_keys.append(
-                        {
-                            "table_schema": "main",
-                            "table_name": table_name,
-                            "column_name": fk["from"],
-                            "foreign_table_schema": "main",
-                            "foreign_table_name": fk["table"],
-                            "foreign_column_name": fk["to"],
-                        }
-                    )
+            for fk in table_fks:
+                foreign_keys.append(
+                    {
+                        "table_schema": "main",
+                        "table_name": table_name,
+                        "column_name": fk["from"],
+                        "foreign_table_schema": "main",
+                        "foreign_table_name": fk["table"],
+                        "foreign_column_name": fk["to"],
+                    }
+                )
 
         return foreign_keys
 
@@ -479,43 +482,64 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
 
             # Get table info using PRAGMA to find primary keys
             pragma_query = f"PRAGMA table_info({table_name})"
+            table_columns = await self._execute_query(connection, pragma_query)
 
-            async with aiosqlite.connect(connection.database_path) as conn:
-                conn.row_factory = aiosqlite.Row
-                cursor = await conn.execute(pragma_query)
-                table_columns = await cursor.fetchall()
-
-                for col in table_columns:
-                    if col["pk"]:
-                        primary_keys.append(
-                            {
-                                "table_schema": "main",
-                                "table_name": table_name,
-                                "column_name": col["name"],
-                            }
-                        )
+            for col in table_columns:
+                if col["pk"]:  # Primary key indicator
+                    primary_keys.append(
+                        {
+                            "table_schema": "main",
+                            "table_name": table_name,
+                            "column_name": col["name"],
+                        }
+                    )
 
         return primary_keys
 
     async def list_tables_info(self, connection) -> Dict[str, Any]:
         """Get list of tables with basic information for SQLite."""
-        #
+        # First get the table names
         tables_query = """
             SELECT
                 'main' as table_schema,
                 name as table_name,
-                type as table_type,
-                0 as row_count
+                type as table_type
             FROM sqlite_master
             WHERE type IN ('table', 'view')
                 AND name NOT LIKE 'sqlite_%'
             ORDER BY name;
         """
 
-        async with aiosqlite.connect(connection.database_path) as conn:
-            conn.row_factory = aiosqlite.Row
-            cursor = await conn.execute(tables_query)
-            return await cursor.fetchall()
+        tables = await self._execute_query(connection, tables_query)
+
+        # Now get row counts for each table
+        result = []
+        for table in tables:
+            table_name = table["table_name"]
+            table_type = table["table_type"]
+
+            # Only count rows for tables, not views
+            if table_type.lower() == "table":
+                try:
+                    count_query = f"SELECT COUNT(*) as count FROM [{table_name}]"
+                    count_result = await self._execute_query(connection, count_query)
+                    row_count = count_result[0]["count"] if count_result else 0
+                except Exception:
+                    # If count fails (e.g., table locked), default to 0
+                    row_count = 0
+            else:
+                # For views, we don't count rows as it could be expensive
+                row_count = 0
+
+            result.append(
+                {
+                    "table_schema": table["table_schema"],
+                    "table_name": table_name,
+                    "table_type": table_type,
+                    "row_count": row_count,
+                }
+            )
+        return result
 
 
 class SchemaManager:
@@ -531,7 +555,7 @@ class SchemaManager:
             self.introspector = PostgreSQLSchemaIntrospector()
         elif isinstance(db_connection, MySQLConnection):
             self.introspector = MySQLSchemaIntrospector()
-        elif isinstance(db_connection, SQLiteConnection):
+        elif isinstance(db_connection, (SQLiteConnection, CSVConnection)):
            self.introspector = SQLiteSchemaIntrospector()
         else:
             raise ValueError(
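
The `hasattr(connection, "database_path")` check in the new helper is the crux: a file-backed SQLite database can be reopened per query, but a CSV-backed :memory: database exists only on the single connection that loaded it, so the helper must reuse get_pool(). A standalone sketch of just that dispatch rule (both classes are stand-ins):

```python
# Mirror of the duck-typed branch in SQLiteSchemaIntrospector._execute_query.
class FileBackedSQLite:
    database_path = "/tmp/example.db"  # hypothetical file path


class CSVBacked:
    async def get_pool(self):
        return "the single shared :memory: connection"


def strategy(connection) -> str:
    if hasattr(connection, "database_path"):
        return "open a fresh aiosqlite connection to the file"
    return "reuse connection.get_pool(); the data lives only on that connection"


print(strategy(FileBackedSQLite()))
print(strategy(CSVBacked()))
```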
sqlsaber/mcp/__init__.py
ADDED
sqlsaber/mcp/mcp.py
ADDED
@@ -0,0 +1,138 @@
+"""FastMCP server implementation for SQLSaber."""
+
+import json
+from typing import Optional
+
+from fastmcp import FastMCP
+
+from sqlsaber.agents.mcp import MCPSQLAgent
+from sqlsaber.config.database import DatabaseConfigManager
+from sqlsaber.database.connection import DatabaseConnection
+
+INSTRUCTIONS = """
+This server provides helpful resources and tools that will help you address users queries on their database.
+
+- Get all databases using `get_databases()`
+- Call `list_tables()` to get a list of all tables in the database with row counts. Use this first to discover available tables.
+- Call `introspect_schema()` to introspect database schema to understand table structures.
+- Call `execute_sql()` to execute SQL queries against the database and retrieve results.
+
+Guidelines:
+- Use list_tables first, then introspect_schema for specific tables only
+- Use table patterns like 'sample%' or '%experiment%' to filter related tables
+- Use proper JOIN syntax and avoid cartesian products
+- Include appropriate WHERE clauses to limit results
+- Handle errors gracefully and suggest fixes
+"""
+
+# Create the FastMCP server instance
+mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
+
+# Initialize the database config manager
+config_manager = DatabaseConfigManager()
+
+
+async def _create_agent_for_database(database_name: str) -> Optional[MCPSQLAgent]:
+    """Create a MCPSQLAgent for the specified database."""
+    try:
+        # Look up configured database connection
+        db_config = config_manager.get_database(database_name)
+        if not db_config:
+            return None
+        connection_string = db_config.to_connection_string()
+
+        # Create database connection
+        db_conn = DatabaseConnection(connection_string)
+
+        # Create and return the agent
+        agent = MCPSQLAgent(db_conn)
+        return agent
+
+    except Exception:
+        return None
+
+
+@mcp.tool
+def get_databases() -> dict:
+    """List all configured databases with their types."""
+    databases = []
+    for db_config in config_manager.list_databases():
+        databases.append(
+            {
+                "name": db_config.name,
+                "type": db_config.type,
+                "database": db_config.database,
+                "host": db_config.host,
+                "port": db_config.port,
+                "is_default": db_config.name == config_manager.get_default_name(),
+            }
+        )
+
+    return {"databases": databases, "count": len(databases)}
+
+
+@mcp.tool
+async def list_tables(database: str) -> str:
+    """
+    Get a list of all tables in the database with row counts. Use this first to discover available tables.
+    """
+    try:
+        agent = await _create_agent_for_database(database)
+        if not agent:
+            return json.dumps(
+                {"error": f"Database '{database}' not found or could not connect"}
+            )
+
+        result = await agent.list_tables()
+        await agent.db.close()
+        return result
+
+    except Exception as e:
+        return json.dumps({"error": f"Error listing tables: {str(e)}"})
+
+
+@mcp.tool
+async def introspect_schema(database: str, table_pattern: Optional[str] = None) -> str:
+    """
+    Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
+    """
+    try:
+        agent = await _create_agent_for_database(database)
+        if not agent:
+            return json.dumps(
+                {"error": f"Database '{database}' not found or could not connect"}
+            )
+
+        result = await agent.introspect_schema(table_pattern)
+        await agent.db.close()
+        return result
+
+    except Exception as e:
+        return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
+
+
+@mcp.tool
+async def execute_sql(database: str, query: str, limit: Optional[int] = 100) -> str:
+    """Execute a SQL query against the specified database."""
+    try:
+        agent = await _create_agent_for_database(database)
+        if not agent:
+            return json.dumps(
+                {"error": f"Database '{database}' not found or could not connect"}
+            )
+
+        result = await agent.execute_sql(query, limit)
+        await agent.db.close()
+        return result
+
+    except Exception as e:
+        return json.dumps({"error": f"Error executing SQL: {str(e)}"})
+
+
+def main():
+    """Entry point for the MCP server console script."""
+    mcp.run()
+
+
+if __name__ == "__main__":
+    main()
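
A smoke-test sketch for the new server: exercising its tools in-process. Passing the server instance to Client is an assumption about FastMCP 2.x's in-memory transport, not something defined in this package:

```python
import asyncio

from fastmcp import Client

from sqlsaber.mcp.mcp import mcp

# Assumes fastmcp's in-memory Client (FastMCP 2.x) accepts a server instance.
async def main():
    async with Client(mcp) as client:
        tools = await client.list_tools()
        print([t.name for t in tools])  # get_databases, list_tables, ...
        print(await client.call_tool("get_databases", {}))


asyncio.run(main())
```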
{sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sqlsaber
-Version: 0.2.0
+Version: 0.4.0
 Summary: SQLSaber - Agentic SQL assistant like Claude Code
 License-File: LICENSE
 Requires-Python: >=3.12
@@ -8,8 +8,10 @@ Requires-Dist: aiomysql>=0.2.0
 Requires-Dist: aiosqlite>=0.21.0
 Requires-Dist: anthropic>=0.54.0
 Requires-Dist: asyncpg>=0.30.0
+Requires-Dist: fastmcp>=2.9.0
 Requires-Dist: httpx>=0.28.1
 Requires-Dist: keyring>=25.6.0
+Requires-Dist: pandas>=2.0.0
 Requires-Dist: platformdirs>=4.0.0
 Requires-Dist: questionary>=2.1.0
 Requires-Dist: rich>=13.7.0
@@ -33,6 +35,7 @@ Ask your questions in natural language and it will gather the right context and
 - 💬 Interactive REPL mode
 - 🎨 Beautiful formatted output with syntax highlighting
 - 🗄️ Support for PostgreSQL, SQLite, and MySQL
+- 🔌 MCP (Model Context Protocol) server support
 
 ## Installation
 
@@ -139,6 +142,43 @@ saber query "show me the distribution of customer ages"
 saber query "which products had the highest sales growth last quarter?"
 ```
 
+## MCP Server Integration
+
+SQLSaber includes an MCP (Model Context Protocol) server that allows AI agents like Claude Code to directly leverage tools available in SQLSaber.
+
+### Starting the MCP Server
+
+Run the MCP server using uvx:
+
+```bash
+uvx saber-mcp
+```
+
+### Configuring MCP Clients
+
+#### Claude Code
+
+Add SQLSaber as an MCP server in Claude Code:
+
+```bash
+claude mcp add -- uvx saber-mcp
+```
+
+#### Other MCP Clients
+
+For other MCP clients, configure them to run the command: `uvx saber-mcp`
+
+### Available MCP Tools
+
+Once connected, the MCP client will have access to these tools:
+
+- `get_databases()` - Lists all configured databases
+- `list_tables(database)` - Get all tables in a database with row counts
+- `introspect_schema(database, table_pattern?)` - Get detailed schema information
+- `execute_sql(database, query, limit?)` - Execute SQL queries (read-only)
+
+The MCP server uses your existing SQLSaber database configurations, so make sure to set up your databases using `saber db add` first.
+
 ## How It Works
 
 SQLSaber uses an intelligent three-step process optimized for minimal token usage:
{sqlsaber-0.2.0.dist-info → sqlsaber-0.4.0.dist-info}/RECORD
CHANGED
@@ -1,12 +1,13 @@
 sqlsaber/__init__.py,sha256=QCFi8xTVMohelfi7zOV1-6oLCcGoiXoOcKQY-HNBCk8,66
 sqlsaber/__main__.py,sha256=RIHxWeWh2QvLfah-2OkhI5IJxojWfy4fXpMnVEJYvxw,78
 sqlsaber/agents/__init__.py,sha256=LWeSeEUE4BhkyAYFF3TE-fx8TtLud3oyEtyB8ojFJgo,167
-sqlsaber/agents/anthropic.py,sha256=…
-sqlsaber/agents/base.py,sha256=…
-sqlsaber/agents/…
+sqlsaber/agents/anthropic.py,sha256=xAjKeQSnaut-P5VBeBISbQeqdP41epDjX6MJb2ZUXWg,14060
+sqlsaber/agents/base.py,sha256=IuVyCaA7VsA92odfQS2_lYNzwIZwPxK55mL_xRewgwQ,6943
+sqlsaber/agents/mcp.py,sha256=FKtXgDrPZ2-xqUYCw2baI5JzrWekXaC5fjkYW1_Mg50,827
+sqlsaber/agents/streaming.py,sha256=_EO390-FHUrL1fRCNfibtE9QuJz3LGQygbwG3CB2ViY,533
 sqlsaber/cli/__init__.py,sha256=qVSLVJLLJYzoC6aj6y9MFrzZvAwc4_OgxU9DlkQnZ4M,86
-sqlsaber/cli/commands.py,sha256=…
-sqlsaber/cli/database.py,sha256=…
+sqlsaber/cli/commands.py,sha256=h418lgh_Xp7XEQ1xvjcDyplC2JON0-y98QMaDm6o29k,4919
+sqlsaber/cli/database.py,sha256=DUfyvNBDp47oFM_VAC_hXHQy_qyE7JbXtowflJpwwH8,12643
 sqlsaber/cli/display.py,sha256=5J4AgJADmMwKi9Aq5u6_MKRO1TA6unS4F4RUfml_sfU,7651
 sqlsaber/cli/interactive.py,sha256=y92rdoM49SOSwEctm9ZcrEN220fhJ_DMHPSd_7KsORg,3701
 sqlsaber/cli/memory.py,sha256=LW4ZF2V6Gw6hviUFGZ4ym9ostFCwucgBTIMZ3EANO-I,7671
@@ -14,19 +15,21 @@ sqlsaber/cli/models.py,sha256=3IcXeeU15IQvemSv-V-RQzVytJ3wuQ4YmWk89nTDcSE,7813
 sqlsaber/cli/streaming.py,sha256=5QGAYTAvg9mzQLxDEVtdDH-TIbGfYYzMOLoOYPrHPu0,3788
 sqlsaber/config/__init__.py,sha256=olwC45k8Nc61yK0WmPUk7XHdbsZH9HuUAbwnmKe3IgA,100
 sqlsaber/config/api_keys.py,sha256=kLdoExF_My9ojmdhO5Ca7-ZeowsO0v1GVa_QT5jjUPo,3658
-sqlsaber/config/database.py,sha256=…
+sqlsaber/config/database.py,sha256=vKFOxPjVakjQhj1uoLcfzhS9ZFr6Z2F5b4MmYALQZoA,11421
 sqlsaber/config/settings.py,sha256=zjQ7nS3ybcCb88Ea0tmwJox5-q0ettChZw89ZqRVpX8,3975
 sqlsaber/database/__init__.py,sha256=a_gtKRJnZVO8-fEZI7g3Z8YnGa6Nio-5Y50PgVp07ss,176
-sqlsaber/database/connection.py,sha256=…
-sqlsaber/database/schema.py,sha256=…
+sqlsaber/database/connection.py,sha256=s8GSFZebB8be8sVUr-N0x88-20YfkfljJFRyfoB1gH0,15154
+sqlsaber/database/schema.py,sha256=9QoH-gADzWlepq-tGz3nPU3miSUU0koWmpDaoWvz8Q0,27951
+sqlsaber/mcp/__init__.py,sha256=COdWq7wauPBp5Ew8tfZItFzbcLDSEkHBJSMhxzy8C9c,112
+sqlsaber/mcp/mcp.py,sha256=ACm1P1TnicjOptQgeLNhXg5xgZf4MYq2kqdfVdj6wh0,4477
 sqlsaber/memory/__init__.py,sha256=GiWkU6f6YYVV0EvvXDmFWe_CxarmDCql05t70MkTEWs,63
 sqlsaber/memory/manager.py,sha256=ML2NEO5Z4Aw36sEI9eOvWVnjl-qT2VOTojViJAj7Seo,2777
 sqlsaber/memory/storage.py,sha256=DvZBsSPaAfk_DqrNEn86uMD-TQsWUI6rQLfNw6PSCB8,5788
 sqlsaber/models/__init__.py,sha256=RJ7p3WtuSwwpFQ1Iw4_DHV2zzCtHqIzsjJzxv8kUjUE,287
 sqlsaber/models/events.py,sha256=55m41tDwMsFxnKKA5_VLJz8iV-V4Sq3LDfta4VoutJI,737
 sqlsaber/models/types.py,sha256=3U_30n91EB3IglBTHipwiW4MqmmaA2qfshfraMZyPps,896
-sqlsaber-0.2.0.dist-info/METADATA,…
-sqlsaber-0.2.0.dist-info/WHEEL,…
-sqlsaber-0.2.0.dist-info/entry_points.txt,…
-sqlsaber-0.2.0.dist-info/licenses/LICENSE,…
-sqlsaber-0.2.0.dist-info/RECORD,,
+sqlsaber-0.4.0.dist-info/METADATA,sha256=CL1mNjOLrc6VDJqE2dSrCXO5OJz9gTMxYNoYq6jtzYE,5071
+sqlsaber-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+sqlsaber-0.4.0.dist-info/entry_points.txt,sha256=jmFo96Ylm0zIKXJBwhv_P5wQ7SXP9qdaBcnTp8iCEe8,195
+sqlsaber-0.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sqlsaber-0.4.0.dist-info/RECORD,,
|
|
File without changes
|