sqlsaber 0.14.0__tar.gz → 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/CHANGELOG.md +14 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/PKG-INFO +1 -1
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/pyproject.toml +1 -1
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/anthropic.py +28 -116
- sqlsaber-0.15.0/src/sqlsaber/agents/base.py +187 -0
- sqlsaber-0.15.0/src/sqlsaber/mcp/mcp.py +129 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/__init__.py +25 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/base.py +83 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/enums.py +21 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/instructions.py +251 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/registry.py +130 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber-0.15.0/src/sqlsaber/tools/visualization_tools.py +144 -0
- sqlsaber-0.15.0/tests/test_tools/__init__.py +1 -0
- sqlsaber-0.15.0/tests/test_tools/test_base.py +63 -0
- sqlsaber-0.15.0/tests/test_tools/test_instructions.py +255 -0
- sqlsaber-0.15.0/tests/test_tools/test_registry.py +189 -0
- sqlsaber-0.15.0/tests/test_tools/test_sql_tools.py +218 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/uv.lock +1 -1
- sqlsaber-0.14.0/src/sqlsaber/agents/base.py +0 -389
- sqlsaber-0.14.0/src/sqlsaber/mcp/mcp.py +0 -137
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.github/workflows/claude-code-review.yml +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.github/workflows/claude.yml +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.github/workflows/publish.yml +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.github/workflows/test.yml +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.gitignore +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/.python-version +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/AGENT.md +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/CLAUDE.md +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/LICENSE +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/README.md +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/pytest.ini +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/sqlsaber.svg +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/__main__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/mcp.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/agents/streaming.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/auth.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/commands.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/completers.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/database.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/display.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/interactive.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/memory.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/models.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/cli/streaming.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/anthropic.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/base.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/exceptions.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/models.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/clients/streaming.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/api_keys.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/auth.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/database.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_flow.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/oauth_tokens.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/config/settings.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/conversation/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/conversation/manager.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/conversation/models.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/conversation/storage.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/database/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/database/connection.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/database/resolver.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/database/schema.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/mcp/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/manager.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/memory/storage.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/models/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/models/events.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/src/sqlsaber/models/types.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/conftest.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_agents/test_anthropic_oauth.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_cli/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_cli/test_commands.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_clients/test_anthropic_client.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_clients/test_streaming.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_config/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_config/test_database.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_config/test_oauth.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_config/test_settings.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_conversation_storage.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_database/__init__.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_database/test_connection.py +0 -0
- {sqlsaber-0.14.0 → sqlsaber-0.15.0}/tests/test_database_resolver.py +0 -0
|
@@ -4,6 +4,20 @@ All notable changes to SQLSaber will be documented in this file.
|
|
|
4
4
|
|
|
5
5
|
## [Unreleased]
|
|
6
6
|
|
|
7
|
+
## [0.15.0] - 2025-08-18
|
|
8
|
+
|
|
9
|
+
### Added
|
|
10
|
+
|
|
11
|
+
- Tool abstraction system with centralized registry (new `Tool` base class, `ToolRegistry`, decorators)
|
|
12
|
+
- Dynamic instruction generation system (`InstructionBuilder`)
|
|
13
|
+
- Comprehensive test suite for the tools module
|
|
14
|
+
|
|
15
|
+
### Changed
|
|
16
|
+
|
|
17
|
+
- Refactored agents to use centralized tool registry instead of hardcoded tools
|
|
18
|
+
- Enhanced MCP server with dynamic tool registration
|
|
19
|
+
- Moved core SQL functionality to dedicated tool classes
|
|
20
|
+
|
|
7
21
|
## [0.14.0] - 2025-08-01
|
|
8
22
|
|
|
9
23
|
### Added
|
|
@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
|
|
|
21
21
|
from sqlsaber.database.connection import BaseDatabaseConnection
|
|
22
22
|
from sqlsaber.memory.manager import MemoryManager
|
|
23
23
|
from sqlsaber.models.events import StreamEvent
|
|
24
|
+
from sqlsaber.tools import tool_registry
|
|
25
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class AnthropicSQLAgent(BaseSQLAgent):
|
|
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
51
53
|
self._last_results = None
|
|
52
54
|
self._last_query = None
|
|
53
55
|
|
|
54
|
-
#
|
|
55
|
-
self.tools: list[ToolDefinition] =
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
input_schema={
|
|
60
|
-
"type": "object",
|
|
61
|
-
"properties": {},
|
|
62
|
-
"required": [],
|
|
63
|
-
},
|
|
64
|
-
),
|
|
65
|
-
ToolDefinition(
|
|
66
|
-
name="introspect_schema",
|
|
67
|
-
description="Introspect database schema to understand table structures.",
|
|
68
|
-
input_schema={
|
|
69
|
-
"type": "object",
|
|
70
|
-
"properties": {
|
|
71
|
-
"table_pattern": {
|
|
72
|
-
"type": "string",
|
|
73
|
-
"description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
|
|
74
|
-
}
|
|
75
|
-
},
|
|
76
|
-
"required": [],
|
|
77
|
-
},
|
|
78
|
-
),
|
|
79
|
-
ToolDefinition(
|
|
80
|
-
name="execute_sql",
|
|
81
|
-
description="Execute a SQL query against the database.",
|
|
82
|
-
input_schema={
|
|
83
|
-
"type": "object",
|
|
84
|
-
"properties": {
|
|
85
|
-
"query": {
|
|
86
|
-
"type": "string",
|
|
87
|
-
"description": "SQL query to execute",
|
|
88
|
-
},
|
|
89
|
-
"limit": {
|
|
90
|
-
"type": "integer",
|
|
91
|
-
"description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
|
|
92
|
-
"default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
|
|
93
|
-
},
|
|
94
|
-
},
|
|
95
|
-
"required": ["query"],
|
|
96
|
-
},
|
|
97
|
-
),
|
|
98
|
-
ToolDefinition(
|
|
99
|
-
name="plot_data",
|
|
100
|
-
description="Create a plot of query results.",
|
|
101
|
-
input_schema={
|
|
102
|
-
"type": "object",
|
|
103
|
-
"properties": {
|
|
104
|
-
"y_values": {
|
|
105
|
-
"type": "array",
|
|
106
|
-
"items": {"type": ["number", "null"]},
|
|
107
|
-
"description": "Y-axis data points (required)",
|
|
108
|
-
},
|
|
109
|
-
"x_values": {
|
|
110
|
-
"type": "array",
|
|
111
|
-
"items": {"type": ["number", "null"]},
|
|
112
|
-
"description": "X-axis data points (optional, will use indices if not provided)",
|
|
113
|
-
},
|
|
114
|
-
"plot_type": {
|
|
115
|
-
"type": "string",
|
|
116
|
-
"enum": ["line", "scatter", "histogram"],
|
|
117
|
-
"description": "Type of plot to create (default: line)",
|
|
118
|
-
"default": "line",
|
|
119
|
-
},
|
|
120
|
-
"title": {
|
|
121
|
-
"type": "string",
|
|
122
|
-
"description": "Title for the plot",
|
|
123
|
-
},
|
|
124
|
-
"x_label": {
|
|
125
|
-
"type": "string",
|
|
126
|
-
"description": "Label for X-axis",
|
|
127
|
-
},
|
|
128
|
-
"y_label": {
|
|
129
|
-
"type": "string",
|
|
130
|
-
"description": "Label for Y-axis",
|
|
131
|
-
},
|
|
132
|
-
},
|
|
133
|
-
"required": ["y_values"],
|
|
134
|
-
},
|
|
135
|
-
),
|
|
136
|
-
]
|
|
56
|
+
# Get tool definitions from registry
|
|
57
|
+
self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
|
|
58
|
+
|
|
59
|
+
# Initialize instruction builder
|
|
60
|
+
self.instruction_builder = InstructionBuilder(tool_registry)
|
|
137
61
|
|
|
138
62
|
# Build system prompt with memories if available
|
|
139
63
|
self.system_prompt = self._build_system_prompt()
|
|
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
|
|
|
157
81
|
def _get_sql_assistant_instructions(self) -> str:
|
|
158
82
|
"""Get the detailed SQL assistant instructions."""
|
|
159
83
|
db_type = self._get_database_type_name()
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
1. Understand user's natural language requests, think and convert them to SQL
|
|
164
|
-
2. Use the provided tools efficiently to explore database schema
|
|
165
|
-
3. Generate appropriate SQL queries
|
|
166
|
-
4. Execute queries safely - queries that modify the database are not allowed
|
|
167
|
-
5. Format and explain results clearly
|
|
168
|
-
6. Create visualizations when requested or when they would be helpful
|
|
169
|
-
|
|
170
|
-
IMPORTANT - Schema Discovery Strategy:
|
|
171
|
-
1. ALWAYS start with 'list_tables' to see available tables and row counts
|
|
172
|
-
2. Based on the user's query, identify which specific tables are relevant
|
|
173
|
-
3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
|
|
174
|
-
4. Timestamp columns must be converted to text when you write queries
|
|
175
|
-
|
|
176
|
-
Guidelines:
|
|
177
|
-
- Use list_tables first, then introspect_schema for specific tables only
|
|
178
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
179
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
180
|
-
- Include appropriate WHERE clauses to limit results
|
|
181
|
-
- Explain what the query does in simple terms
|
|
182
|
-
- Handle errors gracefully and suggest fixes
|
|
183
|
-
- Be security conscious - use parameterized queries when needed
|
|
184
|
-
"""
|
|
84
|
+
|
|
85
|
+
# Build dynamic instructions from available tools
|
|
86
|
+
instructions = self.instruction_builder.build_instructions(db_type=db_type)
|
|
185
87
|
|
|
186
88
|
# Add memory context if database name is available
|
|
187
89
|
if self.database_name:
|
|
@@ -189,7 +91,7 @@ Guidelines:
|
|
|
189
91
|
self.database_name
|
|
190
92
|
)
|
|
191
93
|
if memory_context.strip():
|
|
192
|
-
instructions += memory_context
|
|
94
|
+
instructions += "\n\n" + memory_context
|
|
193
95
|
|
|
194
96
|
return instructions
|
|
195
97
|
|
|
@@ -199,16 +101,19 @@ Guidelines:
|
|
|
199
101
|
return None
|
|
200
102
|
|
|
201
103
|
memory = self.memory_manager.add_memory(self.database_name, content)
|
|
202
|
-
# Rebuild system prompt with new memory
|
|
104
|
+
# Rebuild system prompt with new memory (includes dynamic instructions)
|
|
203
105
|
self.system_prompt = self._build_system_prompt()
|
|
204
106
|
return memory.id
|
|
205
107
|
|
|
206
|
-
async def
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
108
|
+
async def _execute_sql_with_tracking(
|
|
109
|
+
self, query: str, limit: int | None = None
|
|
110
|
+
) -> str:
|
|
111
|
+
"""Execute SQL and track results for streaming."""
|
|
112
|
+
# Get the execute_sql tool and run it
|
|
113
|
+
tool = tool_registry.get_tool("execute_sql")
|
|
114
|
+
result = await tool.execute(query=query, limit=limit)
|
|
210
115
|
|
|
211
|
-
# Parse result to extract data for streaming
|
|
116
|
+
# Parse result to extract data for streaming
|
|
212
117
|
try:
|
|
213
118
|
result_data = json.loads(result)
|
|
214
119
|
if result_data.get("success") and "results" in result_data:
|
|
@@ -228,7 +133,14 @@ Guidelines:
|
|
|
228
133
|
self, tool_name: str, tool_input: dict[str, Any]
|
|
229
134
|
) -> str:
|
|
230
135
|
"""Process a tool call and return the result."""
|
|
231
|
-
#
|
|
136
|
+
# Special handling for execute_sql to track results
|
|
137
|
+
if tool_name == "execute_sql":
|
|
138
|
+
return await self._execute_sql_with_tracking(
|
|
139
|
+
tool_input.get("query", ""),
|
|
140
|
+
tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Use parent implementation for all other tools
|
|
232
144
|
return await super().process_tool_call(tool_name, tool_input)
|
|
233
145
|
|
|
234
146
|
def _convert_user_message_to_message(
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Abstract base class for SQL agents."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, AsyncIterator
|
|
7
|
+
|
|
8
|
+
from sqlsaber.conversation.manager import ConversationManager
|
|
9
|
+
from sqlsaber.database.connection import (
|
|
10
|
+
BaseDatabaseConnection,
|
|
11
|
+
CSVConnection,
|
|
12
|
+
MySQLConnection,
|
|
13
|
+
PostgreSQLConnection,
|
|
14
|
+
SQLiteConnection,
|
|
15
|
+
)
|
|
16
|
+
from sqlsaber.database.schema import SchemaManager
|
|
17
|
+
from sqlsaber.models.events import StreamEvent
|
|
18
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseSQLAgent(ABC):
    """Abstract base class for SQL agents.

    Holds the database connection, a ``SchemaManager`` for it, the in-memory
    conversation history, and the state used to persist the conversation via
    ``ConversationManager``. Tool calls are dispatched through the shared
    ``tool_registry``.
    """

    def __init__(self, db_connection: BaseDatabaseConnection):
        self.db = db_connection
        self.schema_manager = SchemaManager(db_connection)
        self.conversation_history: list[dict[str, Any]] = []

        # Conversation persistence
        self._conv_manager = ConversationManager()
        self._conversation_id: str | None = None
        self._msg_index: int = 0

        # Initialize SQL tools with database connection
        self._init_tools()

    @abstractmethod
    async def query_stream(
        self,
        user_query: str,
        use_history: bool = True,
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[StreamEvent]:
        """Process a user query and stream responses.

        Args:
            user_query: The user's query to process
            use_history: Whether to include conversation history
            cancellation_token: Optional event to signal cancellation
        """
        pass

    async def clear_history(self):
        """Clear conversation history.

        Ends the persisted conversation (if any) and empties the in-memory
        history.
        """
        # End current conversation in storage
        await self._end_conversation()

        # Clear in-memory history
        self.conversation_history = []

    def _get_database_type_name(self) -> str:
        """Get the human-readable database type name.

        Returns:
            "PostgreSQL", "MySQL" or "SQLite" for the known connection classes
            (CSV connections report "SQLite"), otherwise the generic "database".
        """
        if isinstance(self.db, PostgreSQLConnection):
            return "PostgreSQL"
        elif isinstance(self.db, MySQLConnection):
            return "MySQL"
        elif isinstance(self.db, SQLiteConnection):
            return "SQLite"
        elif isinstance(self.db, CSVConnection):
            return "SQLite"  # we convert csv to in-memory sqlite
        else:
            return "database"  # Fallback

    def _init_tools(self) -> None:
        """Bind this agent's connection to every SQL-category registry tool.

        The tools live in the shared ``tool_registry``, so this rebinds them for
        every user of the registry, not just this agent.
        """
        for tool_name in tool_registry.list_tools(category="sql"):
            tool = tool_registry.get_tool(tool_name)
            if isinstance(tool, SQLTool):
                tool.set_connection(self.db)

    async def process_tool_call(
        self, tool_name: str, tool_input: dict[str, Any]
    ) -> str:
        """Process a tool call and return the result.

        Args:
            tool_name: Name of the tool in ``tool_registry``.
            tool_input: Keyword arguments passed to the tool's ``execute``.

        Returns:
            The tool's JSON output, or a JSON ``{"error": ...}`` object when the
            tool is unknown or raises while running.
        """
        try:
            # Only a KeyError from the registry lookup means "unknown tool"; a
            # KeyError raised inside the tool itself is an execution error.
            try:
                tool = tool_registry.get_tool(tool_name)
            except KeyError:
                return json.dumps({"error": f"Unknown tool: {tool_name}"})
            return await tool.execute(**tool_input)
        except Exception as e:
            return json.dumps(
                {"error": f"Error executing tool '{tool_name}': {str(e)}"}
            )

    # Conversation persistence helpers

    async def _ensure_conversation(self) -> None:
        """Ensure a conversation is active for storing messages.

        Starts a new stored conversation keyed by ``database_name`` (or
        "unknown" if the subclass has none) and resets the message index.
        """
        if self._conversation_id is None:
            db_name = getattr(self, "database_name", "unknown")
            self._conversation_id = await self._conv_manager.start_conversation(db_name)
            self._msg_index = 0

    async def _store_user_message(self, content: str | dict[str, Any]) -> None:
        """Store a user message in conversation history.

        No-op when no conversation is active.
        """
        if self._conversation_id is None:
            return

        await self._conv_manager.add_user_message(
            self._conversation_id, content, self._msg_index
        )
        self._msg_index += 1

    async def _store_assistant_message(
        self, content: list[dict[str, Any]] | dict[str, Any]
    ) -> None:
        """Store an assistant message in conversation history.

        No-op when no conversation is active.
        """
        if self._conversation_id is None:
            return

        await self._conv_manager.add_assistant_message(
            self._conversation_id, content, self._msg_index
        )
        self._msg_index += 1

    async def _store_tool_message(
        self, content: list[dict[str, Any]] | dict[str, Any]
    ) -> None:
        """Store a tool/system message in conversation history.

        No-op when no conversation is active.
        """
        if self._conversation_id is None:
            return

        await self._conv_manager.add_tool_message(
            self._conversation_id, content, self._msg_index
        )
        self._msg_index += 1

    async def _end_conversation(self) -> None:
        """End the current conversation and reset persistence state."""
        if self._conversation_id:
            await self._conv_manager.end_conversation(self._conversation_id)
            self._conversation_id = None
            self._msg_index = 0

    async def restore_conversation(self, conversation_id: str) -> bool:
        """Restore a conversation from storage to in-memory history.

        Args:
            conversation_id: ID of the conversation to restore

        Returns:
            True if successfully restored, False otherwise
        """
        success = await self._conv_manager.restore_conversation_to_agent(
            conversation_id, self.conversation_history
        )

        if success:
            # Continue appending to the restored conversation after its
            # existing messages.
            self._conversation_id = conversation_id
            self._msg_index = len(self.conversation_history)

        return success

    async def list_conversations(self, limit: int = 50) -> list:
        """List conversations for this agent's database.

        Args:
            limit: Maximum number of conversations to return

        Returns:
            List of dicts with id, database_name, started_at, ended_at and
            duration for each conversation
        """
        db_name = getattr(self, "database_name", None)
        conversations = await self._conv_manager.list_conversations(db_name, limit)

        return [
            {
                "id": conv.id,
                "database_name": conv.database_name,
                "started_at": conv.formatted_start_time(),
                "ended_at": conv.formatted_end_time(),
                "duration": conv.duration_seconds(),
            }
            for conv in conversations
        ]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""FastMCP server implementation for SQLSaber."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from fastmcp import FastMCP
|
|
6
|
+
|
|
7
|
+
from sqlsaber.agents.mcp import MCPSQLAgent
|
|
8
|
+
from sqlsaber.config.database import DatabaseConfigManager
|
|
9
|
+
from sqlsaber.database.connection import DatabaseConnection
|
|
10
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
11
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
12
|
+
|
|
13
|
+
# Server instructions are generated once, at import time, from the shared
# tool registry.
instruction_builder = InstructionBuilder(tool_registry)
DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()

# FastMCP server; the @mcp.tool functions in this module register on it.
mcp = FastMCP(
    name="SQL Assistant",
    instructions=DYNAMIC_INSTRUCTIONS,
)

# Lookup of configured databases, shared by every tool call.
config_manager = DatabaseConfigManager()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def _create_agent_for_database(database_name: str) -> MCPSQLAgent | None:
    """Build an MCPSQLAgent bound to a configured database.

    Args:
        database_name: Name of a database registered with ``config_manager``.

    Returns:
        A new agent wrapping a fresh connection, or ``None`` if the database is
        not configured or any step of the setup raises.
    """
    try:
        db_config = config_manager.get_database(database_name)
        if not db_config:
            return None
        # Connection string -> connection -> agent, in that order.
        return MCPSQLAgent(DatabaseConnection(db_config.to_connection_string()))
    except Exception:
        # Best-effort: callers treat None as "not found or could not connect".
        return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@mcp.tool
def get_databases() -> dict:
    """List all configured databases with their types."""
    # NOTE: the docstring above is the MCP tool description sent to clients.
    configs = config_manager.list_databases()
    # The default name does not change while we iterate, so resolve it once
    # (and only when there is something to compare it against).
    default_name = config_manager.get_default_name() if configs else None

    databases = [
        {
            "name": db_config.name,
            "type": db_config.type,
            "database": db_config.database,
            "host": db_config.host,
            "port": db_config.port,
            "is_default": db_config.name == default_name,
        }
        for db_config in configs
    ]

    return {"databases": databases, "count": len(databases)}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
    """Execute a SQL tool with database connection management.

    A fresh agent (and therefore a fresh connection) is created for each call,
    and that connection is always closed afterwards, including when the tool
    raises.

    Args:
        tool_name: Name of the tool to execute
        database: Database name to connect to
        **kwargs: Tool-specific parameters

    Returns:
        JSON string with the tool's output, or a JSON ``{"error": ...}`` object
        if the database is unknown/unreachable or the call fails.
    """
    try:
        agent = await _create_agent_for_database(database)
        if agent is None:
            return json.dumps(
                {"error": f"Database '{database}' not found or could not connect"}
            )

        try:
            tool = tool_registry.get_tool(tool_name)
            # NOTE(review): registry tools are shared module-level instances, so
            # set_connection() rebinds them for every caller; concurrent requests
            # against different databases may race here - confirm.
            if isinstance(tool, SQLTool):
                tool.set_connection(agent.db)

            return await tool.execute(**kwargs)
        finally:
            # Release the connection whether the tool succeeded or raised.
            await agent.db.close()

    except Exception as e:
        return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# SQL Tool Wrappers with explicit signatures
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@mcp.tool
async def list_tables(database: str) -> str:
    """Get a list of all tables in the database with row counts. Use this first to discover available tables."""
    # Thin MCP-facing wrapper around the registry tool of the same name.
    tool_name = "list_tables"
    return await _execute_with_connection(tool_name, database)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@mcp.tool
async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
    """Introspect database schema to understand table structures."""
    # FastMCP derives the tool's input schema from these annotations, so
    # table_pattern is typed as optional to match its None default.
    # Forward it only when given, so the underlying tool's own default applies.
    kwargs = {} if table_pattern is None else {"table_pattern": table_pattern}
    return await _execute_with_connection("introspect_schema", database, **kwargs)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@mcp.tool
async def execute_sql(database: str, query: str, limit: int = 100) -> str:
    """Execute a SQL query against the database."""
    # Pass the query and row cap through unchanged to the shared executor.
    tool_kwargs = {"query": query, "limit": limit}
    return await _execute_with_connection("execute_sql", database, **tool_kwargs)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def main():
    """Start the SQLSaber MCP server (console-script entry point)."""
    server = mcp
    server.run()


if __name__ == "__main__":
    # Allow running the module directly as well as via the console script.
    main()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""SQLSaber tools module."""
|
|
2
|
+
|
|
3
|
+
from .base import Tool
|
|
4
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
5
|
+
from .instructions import InstructionBuilder
|
|
6
|
+
from .registry import ToolRegistry, register_tool, tool_registry
|
|
7
|
+
|
|
8
|
+
# Import concrete tools to register them
|
|
9
|
+
from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
|
|
10
|
+
from .visualization_tools import PlotDataTool
|
|
11
|
+
|
|
12
|
+
__all__ = [
    # Core abstractions
    "Tool",
    "ToolCategory",
    "WorkflowPosition",
    # Registry
    "ToolRegistry",
    "tool_registry",
    "register_tool",
    # Instruction generation
    "InstructionBuilder",
    # Concrete tools
    "SQLTool",
    "ListTablesTool",
    "IntrospectSchemaTool",
    "ExecuteSQLTool",
    "PlotDataTool",
]
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Base class for SQLSaber tools."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlsaber.clients.models import ToolDefinition
|
|
7
|
+
|
|
8
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Tool(ABC):
    """Common interface implemented by every SQLSaber tool.

    Subclasses must supply ``name``, ``description``, ``input_schema`` and an
    async ``execute``. ``category``, ``get_usage_instructions``,
    ``get_priority`` and ``get_workflow_position`` have defaults that
    subclasses may override.
    """

    def __init__(self):
        """Create the tool; the base class holds no state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier the tool is exposed under."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable summary of what the tool does."""

    @property
    @abstractmethod
    def input_schema(self) -> dict[str, Any]:
        """Schema describing the tool's accepted inputs."""

    @abstractmethod
    async def execute(self, **kwargs) -> str:
        """Run the tool.

        Args:
            **kwargs: Tool-specific keyword arguments

        Returns:
            The tool's output serialized as a JSON string
        """

    def to_definition(self) -> ToolDefinition:
        """Bundle name, description and input schema into a ``ToolDefinition``."""
        definition = ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )
        return definition

    @property
    def category(self) -> ToolCategory:
        """Category of the tool; ``ToolCategory.GENERAL`` unless overridden."""
        default_category = ToolCategory.GENERAL
        return default_category

    def get_usage_instructions(self) -> str | None:
        """Optional guidance text for the LLM about using this tool.

        Returns:
            The instructions string, or None when the tool has none
        """
        instructions = None
        return instructions

    def get_priority(self) -> int:
        """Sort key for ordering tools in generated instructions.

        Returns:
            Priority number; lower sorts first. The base value is 100.
        """
        default_priority = 100
        return default_priority

    def get_workflow_position(self) -> WorkflowPosition:
        """Stage of the workflow at which this tool is typically used.

        Returns:
            A WorkflowPosition member; ``WorkflowPosition.OTHER`` by default
        """
        default_position = WorkflowPosition.OTHER
        return default_position
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Enums for tool categories and workflow positions."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ToolCategory(Enum):
    """Groups used to organize and filter registered tools."""

    # Base-class default category for tools that do not override it.
    GENERAL = "general"
    # Tools that talk to the database.
    SQL = "sql"
    # Tools that render results.
    VISUALIZATION = "visualization"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WorkflowPosition(Enum):
    """Stage of a typical workflow at which a tool is used, for ordering tools."""

    DISCOVERY = "discovery"
    ANALYSIS = "analysis"
    EXECUTION = "execution"
    VISUALIZATION = "visualization"
    # Catch-all for tools that do not fit a specific stage.
    OTHER = "other"
|