sqlsaber 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -21,6 +21,8 @@ from sqlsaber.config.settings import Config
21
21
  from sqlsaber.database.connection import BaseDatabaseConnection
22
22
  from sqlsaber.memory.manager import MemoryManager
23
23
  from sqlsaber.models.events import StreamEvent
24
+ from sqlsaber.tools import tool_registry
25
+ from sqlsaber.tools.instructions import InstructionBuilder
24
26
 
25
27
 
26
28
  class AnthropicSQLAgent(BaseSQLAgent):
@@ -51,89 +53,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
51
53
  self._last_results = None
52
54
  self._last_query = None
53
55
 
54
- # Define tools in the new format
55
- self.tools: list[ToolDefinition] = [
56
- ToolDefinition(
57
- name="list_tables",
58
- description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
59
- input_schema={
60
- "type": "object",
61
- "properties": {},
62
- "required": [],
63
- },
64
- ),
65
- ToolDefinition(
66
- name="introspect_schema",
67
- description="Introspect database schema to understand table structures.",
68
- input_schema={
69
- "type": "object",
70
- "properties": {
71
- "table_pattern": {
72
- "type": "string",
73
- "description": "Optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%')",
74
- }
75
- },
76
- "required": [],
77
- },
78
- ),
79
- ToolDefinition(
80
- name="execute_sql",
81
- description="Execute a SQL query against the database.",
82
- input_schema={
83
- "type": "object",
84
- "properties": {
85
- "query": {
86
- "type": "string",
87
- "description": "SQL query to execute",
88
- },
89
- "limit": {
90
- "type": "integer",
91
- "description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
92
- "default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
93
- },
94
- },
95
- "required": ["query"],
96
- },
97
- ),
98
- ToolDefinition(
99
- name="plot_data",
100
- description="Create a plot of query results.",
101
- input_schema={
102
- "type": "object",
103
- "properties": {
104
- "y_values": {
105
- "type": "array",
106
- "items": {"type": ["number", "null"]},
107
- "description": "Y-axis data points (required)",
108
- },
109
- "x_values": {
110
- "type": "array",
111
- "items": {"type": ["number", "null"]},
112
- "description": "X-axis data points (optional, will use indices if not provided)",
113
- },
114
- "plot_type": {
115
- "type": "string",
116
- "enum": ["line", "scatter", "histogram"],
117
- "description": "Type of plot to create (default: line)",
118
- "default": "line",
119
- },
120
- "title": {
121
- "type": "string",
122
- "description": "Title for the plot",
123
- },
124
- "x_label": {
125
- "type": "string",
126
- "description": "Label for X-axis",
127
- },
128
- "y_label": {
129
- "type": "string",
130
- "description": "Label for Y-axis",
131
- },
132
- },
133
- "required": ["y_values"],
134
- },
135
- ),
136
- ]
56
+ # Get tool definitions from registry
57
+ self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
58
+
59
+ # Initialize instruction builder
60
+ self.instruction_builder = InstructionBuilder(tool_registry)
137
61
 
138
62
  # Build system prompt with memories if available
139
63
  self.system_prompt = self._build_system_prompt()
@@ -157,31 +81,9 @@ class AnthropicSQLAgent(BaseSQLAgent):
157
81
  def _get_sql_assistant_instructions(self) -> str:
158
82
  """Get the detailed SQL assistant instructions."""
159
83
  db_type = self._get_database_type_name()
160
- instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.
161
-
162
- Your responsibilities:
163
- 1. Understand user's natural language requests, think and convert them to SQL
164
- 2. Use the provided tools efficiently to explore database schema
165
- 3. Generate appropriate SQL queries
166
- 4. Execute queries safely - queries that modify the database are not allowed
167
- 5. Format and explain results clearly
168
- 6. Create visualizations when requested or when they would be helpful
169
-
170
- IMPORTANT - Schema Discovery Strategy:
171
- 1. ALWAYS start with 'list_tables' to see available tables and row counts
172
- 2. Based on the user's query, identify which specific tables are relevant
173
- 3. Use 'introspect_schema' with a table_pattern to get details ONLY for relevant tables
174
- 4. Timestamp columns must be converted to text when you write queries
175
-
176
- Guidelines:
177
- - Use list_tables first, then introspect_schema for specific tables only
178
- - Use table patterns like 'sample%' or '%experiment%' to filter related tables
179
- - Use proper JOIN syntax and avoid cartesian products
180
- - Include appropriate WHERE clauses to limit results
181
- - Explain what the query does in simple terms
182
- - Handle errors gracefully and suggest fixes
183
- - Be security conscious - use parameterized queries when needed
184
- """
84
+
85
+ # Build dynamic instructions from available tools
86
+ instructions = self.instruction_builder.build_instructions(db_type=db_type)
185
87
 
186
88
  # Add memory context if database name is available
187
89
  if self.database_name:
@@ -189,7 +91,7 @@ Guidelines:
189
91
  self.database_name
190
92
  )
191
93
  if memory_context.strip():
192
- instructions += memory_context
94
+ instructions += "\n\n" + memory_context
193
95
 
194
96
  return instructions
195
97
 
@@ -199,16 +101,19 @@ Guidelines:
199
101
  return None
200
102
 
201
103
  memory = self.memory_manager.add_memory(self.database_name, content)
202
- # Rebuild system prompt with new memory
104
+ # Rebuild system prompt with new memory (includes dynamic instructions)
203
105
  self.system_prompt = self._build_system_prompt()
204
106
  return memory.id
205
107
 
206
- async def execute_sql(self, query: str, limit: int | None = None) -> str:
207
- """Execute a SQL query against the database with streaming support."""
208
- # Call parent implementation for core functionality
209
- result = await super().execute_sql(query, limit)
108
+ async def _execute_sql_with_tracking(
109
+ self, query: str, limit: int | None = None
110
+ ) -> str:
111
+ """Execute SQL and track results for streaming."""
112
+ # Get the execute_sql tool and run it
113
+ tool = tool_registry.get_tool("execute_sql")
114
+ result = await tool.execute(query=query, limit=limit)
210
115
 
211
- # Parse result to extract data for streaming (AnthropicSQLAgent specific)
116
+ # Parse result to extract data for streaming
212
117
  try:
213
118
  result_data = json.loads(result)
214
119
  if result_data.get("success") and "results" in result_data:
@@ -228,7 +133,14 @@ Guidelines:
228
133
  self, tool_name: str, tool_input: dict[str, Any]
229
134
  ) -> str:
230
135
  """Process a tool call and return the result."""
231
- # Use parent implementation for core tools
136
+ # Special handling for execute_sql to track results
137
+ if tool_name == "execute_sql":
138
+ return await self._execute_sql_with_tracking(
139
+ tool_input.get("query", ""),
140
+ tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
141
+ )
142
+
143
+ # Use parent implementation for all other tools
232
144
  return await super().process_tool_call(tool_name, tool_input)
233
145
 
234
146
  def _convert_user_message_to_message(
sqlsaber/agents/base.py CHANGED
@@ -5,8 +5,6 @@ import json
5
5
  from abc import ABC, abstractmethod
6
6
  from typing import Any, AsyncIterator
7
7
 
8
- from uniplot import histogram, plot
9
-
10
8
  from sqlsaber.conversation.manager import ConversationManager
11
9
  from sqlsaber.database.connection import (
12
10
  BaseDatabaseConnection,
@@ -17,6 +15,7 @@ from sqlsaber.database.connection import (
17
15
  )
18
16
  from sqlsaber.database.schema import SchemaManager
19
17
  from sqlsaber.models.events import StreamEvent
18
+ from sqlsaber.tools import SQLTool, tool_registry
20
19
 
21
20
 
22
21
  class BaseSQLAgent(ABC):
@@ -32,6 +31,9 @@ class BaseSQLAgent(ABC):
32
31
  self._conversation_id: str | None = None
33
32
  self._msg_index: int = 0
34
33
 
34
+ # Initialize SQL tools with database connection
35
+ self._init_tools()
36
+
35
37
  @abstractmethod
36
38
  async def query_stream(
37
39
  self,
@@ -69,232 +71,28 @@ class BaseSQLAgent(ABC):
69
71
  else:
70
72
  return "database" # Fallback
71
73
 
72
- async def introspect_schema(self, table_pattern: str | None = None) -> str:
73
- """Introspect database schema to understand table structures."""
74
- try:
75
- # Pass table_pattern to get_schema_info for efficient filtering at DB level
76
- schema_info = await self.schema_manager.get_schema_info(table_pattern)
77
-
78
- # Format the schema information
79
- formatted_info = {}
80
- for table_name, table_info in schema_info.items():
81
- formatted_info[table_name] = {
82
- "columns": {
83
- col_name: {
84
- "type": col_info["data_type"],
85
- "nullable": col_info["nullable"],
86
- "default": col_info["default"],
87
- }
88
- for col_name, col_info in table_info["columns"].items()
89
- },
90
- "primary_keys": table_info["primary_keys"],
91
- "foreign_keys": [
92
- f"{fk['column']} -> {fk['references']['table']}.{fk['references']['column']}"
93
- for fk in table_info["foreign_keys"]
94
- ],
95
- }
96
-
97
- return json.dumps(formatted_info)
98
- except Exception as e:
99
- return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
100
-
101
- async def list_tables(self) -> str:
102
- """List all tables in the database with basic information."""
103
- try:
104
- tables_info = await self.schema_manager.list_tables()
105
- return json.dumps(tables_info)
106
- except Exception as e:
107
- return json.dumps({"error": f"Error listing tables: {str(e)}"})
108
-
109
- async def execute_sql(self, query: str, limit: int | None = None) -> str:
110
- """Execute a SQL query against the database."""
111
- try:
112
- # Security check - only allow SELECT queries unless write is enabled
113
- write_error = self._validate_write_operation(query)
114
- if write_error:
115
- return json.dumps(
116
- {
117
- "error": write_error,
118
- }
119
- )
120
-
121
- # Add LIMIT if not present and it's a SELECT query
122
- query = self._add_limit_to_query(query, limit)
123
-
124
- # Execute the query (wrapped in a transaction for safety)
125
- results = await self.db.execute_query(query)
126
-
127
- # Format results
128
- actual_limit = limit if limit is not None else len(results)
129
-
130
- return json.dumps(
131
- {
132
- "success": True,
133
- "row_count": len(results),
134
- "results": results[:actual_limit], # Extra safety for limit
135
- "truncated": len(results) > actual_limit,
136
- }
137
- )
138
-
139
- except Exception as e:
140
- error_msg = str(e)
141
-
142
- # Provide helpful error messages
143
- suggestions = []
144
- if "column" in error_msg.lower() and "does not exist" in error_msg.lower():
145
- suggestions.append(
146
- "Check column names using the schema introspection tool"
147
- )
148
- elif "table" in error_msg.lower() and "does not exist" in error_msg.lower():
149
- suggestions.append(
150
- "Check table names using the schema introspection tool"
151
- )
152
- elif "syntax error" in error_msg.lower():
153
- suggestions.append(
154
- "Review SQL syntax, especially JOIN conditions and WHERE clauses"
155
- )
156
-
157
- return json.dumps({"error": error_msg, "suggestions": suggestions})
74
+ def _init_tools(self) -> None:
75
+ """Initialize SQL tools with database connection."""
76
+ # Get all SQL tools and set their database connection
77
+ for tool_name in tool_registry.list_tools(category="sql"):
78
+ tool = tool_registry.get_tool(tool_name)
79
+ if isinstance(tool, SQLTool):
80
+ tool.set_connection(self.db)
158
81
 
159
82
  async def process_tool_call(
160
83
  self, tool_name: str, tool_input: dict[str, Any]
161
84
  ) -> str:
162
85
  """Process a tool call and return the result."""
163
- if tool_name == "list_tables":
164
- return await self.list_tables()
165
- elif tool_name == "introspect_schema":
166
- return await self.introspect_schema(tool_input.get("table_pattern"))
167
- elif tool_name == "execute_sql":
168
- return await self.execute_sql(
169
- tool_input["query"], tool_input.get("limit", 100)
170
- )
171
- elif tool_name == "plot_data":
172
- return await self.plot_data(
173
- y_values=tool_input["y_values"],
174
- x_values=tool_input.get("x_values"),
175
- plot_type=tool_input.get("plot_type", "line"),
176
- title=tool_input.get("title"),
177
- x_label=tool_input.get("x_label"),
178
- y_label=tool_input.get("y_label"),
179
- )
180
- else:
181
- return json.dumps({"error": f"Unknown tool: {tool_name}"})
182
-
183
- def _validate_write_operation(self, query: str) -> str | None:
184
- """Validate if a write operation is allowed.
185
-
186
- Returns:
187
- None if operation is allowed, error message if not allowed.
188
- """
189
- query_upper = query.strip().upper()
190
-
191
- # Check for write operations
192
- write_keywords = [
193
- "INSERT",
194
- "UPDATE",
195
- "DELETE",
196
- "DROP",
197
- "CREATE",
198
- "ALTER",
199
- "TRUNCATE",
200
- ]
201
- is_write_query = any(query_upper.startswith(kw) for kw in write_keywords)
202
-
203
- if is_write_query:
204
- return (
205
- "Write operations are not allowed. Only SELECT queries are permitted."
206
- )
207
-
208
- return None
209
-
210
- def _add_limit_to_query(self, query: str, limit: int = 100) -> str:
211
- """Add LIMIT clause to SELECT queries if not present."""
212
- query_upper = query.strip().upper()
213
- if query_upper.startswith("SELECT") and "LIMIT" not in query_upper:
214
- return f"{query.rstrip(';')} LIMIT {limit};"
215
- return query
216
-
217
- async def plot_data(
218
- self,
219
- y_values: list[float],
220
- x_values: list[float] | None = None,
221
- plot_type: str = "line",
222
- title: str | None = None,
223
- x_label: str | None = None,
224
- y_label: str | None = None,
225
- ) -> str:
226
- """Create a terminal plot using uniplot.
227
-
228
- Args:
229
- y_values: Y-axis data points
230
- x_values: X-axis data points (optional)
231
- plot_type: Type of plot - "line", "scatter", or "histogram"
232
- title: Plot title
233
- x_label: X-axis label
234
- y_label: Y-axis label
235
-
236
- Returns:
237
- JSON string with success status and plot details
238
- """
239
86
  try:
240
- # Validate inputs
241
- if not y_values:
242
- return json.dumps({"error": "No data provided for plotting"})
243
-
244
- # Convert to floats if needed
245
- try:
246
- y_values = [float(v) if v is not None else None for v in y_values]
247
- if x_values:
248
- x_values = [float(v) if v is not None else None for v in x_values]
249
- except (ValueError, TypeError) as e:
250
- return json.dumps({"error": f"Invalid data format: {str(e)}"})
251
-
252
- # Create the plot
253
- if plot_type == "histogram":
254
- # For histogram, we only need y_values
255
- histogram(
256
- y_values,
257
- title=title,
258
- bins=min(20, len(set(y_values))), # Adaptive bin count
259
- )
260
- plot_info = {
261
- "type": "histogram",
262
- "data_points": len(y_values),
263
- "title": title or "Histogram",
264
- }
265
- elif plot_type in ["line", "scatter"]:
266
- # For line/scatter plots
267
- plot_kwargs = {
268
- "ys": y_values,
269
- "title": title,
270
- "lines": plot_type == "line",
271
- }
272
-
273
- if x_values:
274
- plot_kwargs["xs"] = x_values
275
- if x_label:
276
- plot_kwargs["x_unit"] = x_label
277
- if y_label:
278
- plot_kwargs["y_unit"] = y_label
279
-
280
- plot(**plot_kwargs)
281
-
282
- plot_info = {
283
- "type": plot_type,
284
- "data_points": len(y_values),
285
- "title": title or f"{plot_type.capitalize()} Plot",
286
- "has_x_values": x_values is not None,
287
- }
288
- else:
289
- return json.dumps({"error": f"Unsupported plot type: {plot_type}"})
290
-
87
+ tool = tool_registry.get_tool(tool_name)
88
+ return await tool.execute(**tool_input)
89
+ except KeyError:
90
+ return json.dumps({"error": f"Unknown tool: {tool_name}"})
91
+ except Exception as e:
291
92
  return json.dumps(
292
- {"success": True, "plot_rendered": True, "plot_info": plot_info}
93
+ {"error": f"Error executing tool '{tool_name}': {str(e)}"}
293
94
  )
294
95
 
295
- except Exception as e:
296
- return json.dumps({"error": f"Error creating plot: {str(e)}"})
297
-
298
96
  # Conversation persistence helpers
299
97
 
300
98
  async def _ensure_conversation(self) -> None:
sqlsaber/mcp/mcp.py CHANGED
@@ -7,25 +7,17 @@ from fastmcp import FastMCP
7
7
  from sqlsaber.agents.mcp import MCPSQLAgent
8
8
  from sqlsaber.config.database import DatabaseConfigManager
9
9
  from sqlsaber.database.connection import DatabaseConnection
10
+ from sqlsaber.tools import SQLTool, tool_registry
11
+ from sqlsaber.tools.instructions import InstructionBuilder
10
12
 
11
- INSTRUCTIONS = """
12
- This server provides helpful resources and tools that will help you address users queries on their database.
13
+ # Initialize the instruction builder
14
+ instruction_builder = InstructionBuilder(tool_registry)
13
15
 
14
- - Get all databases using `get_databases()`
15
- - Call `list_tables()` to get a list of all tables in the database with row counts. Use this first to discover available tables.
16
- - Call `introspect_schema()` to introspect database schema to understand table structures.
17
- - Call `execute_sql()` to execute SQL queries against the database and retrieve results.
16
+ # Generate dynamic instructions
17
+ DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()
18
18
 
19
- Guidelines:
20
- - Use list_tables first, then introspect_schema for specific tables only
21
- - Use table patterns like 'sample%' or '%experiment%' to filter related tables
22
- - Use proper JOIN syntax and avoid cartesian products
23
- - Include appropriate WHERE clauses to limit results
24
- - Handle errors gracefully and suggest fixes
25
- """
26
-
27
- # Create the FastMCP server instance
28
- mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
19
+ # Create the FastMCP server instance with dynamic instructions
20
+ mcp = FastMCP(name="SQL Assistant", instructions=DYNAMIC_INSTRUCTIONS)
29
21
 
30
22
  # Initialize the database config manager
31
23
  config_manager = DatabaseConfigManager()
@@ -70,10 +62,16 @@ def get_databases() -> dict:
70
62
  return {"databases": databases, "count": len(databases)}
71
63
 
72
64
 
73
- @mcp.tool
74
- async def list_tables(database: str) -> str:
75
- """
76
- Get a list of all tables in the database with row counts. Use this first to discover available tables.
65
+ async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
66
+ """Execute a SQL tool with database connection management.
67
+
68
+ Args:
69
+ tool_name: Name of the tool to execute
70
+ database: Database name to connect to
71
+ **kwargs: Tool-specific parameters
72
+
73
+ Returns:
74
+ JSON string with the tool's output
77
75
  """
78
76
  try:
79
77
  agent = await _create_agent_for_database(database)
@@ -82,50 +80,44 @@ async def list_tables(database: str) -> str:
82
80
  {"error": f"Database '{database}' not found or could not connect"}
83
81
  )
84
82
 
85
- result = await agent.list_tables()
83
+ # Get the tool and set up connection
84
+ tool = tool_registry.get_tool(tool_name)
85
+ if isinstance(tool, SQLTool):
86
+ tool.set_connection(agent.db)
87
+
88
+ # Execute the tool
89
+ result = await tool.execute(**kwargs)
86
90
  await agent.db.close()
87
91
  return result
88
92
 
89
93
  except Exception as e:
90
- return json.dumps({"error": f"Error listing tables: {str(e)}"})
94
+ return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
91
95
 
92
96
 
93
- @mcp.tool
94
- async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
95
- """
96
- Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
97
- """
98
- try:
99
- agent = await _create_agent_for_database(database)
100
- if not agent:
101
- return json.dumps(
102
- {"error": f"Database '{database}' not found or could not connect"}
103
- )
97
+ # SQL Tool Wrappers with explicit signatures
104
98
 
105
- result = await agent.introspect_schema(table_pattern)
106
- await agent.db.close()
107
- return result
108
99
 
109
- except Exception as e:
110
- return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
100
+ @mcp.tool
101
+ async def list_tables(database: str) -> str:
102
+ """Get a list of all tables in the database with row counts. Use this first to discover available tables."""
103
+ return await _execute_with_connection("list_tables", database)
111
104
 
112
105
 
113
106
  @mcp.tool
114
- async def execute_sql(database: str, query: str, limit: int | None = 100) -> str:
115
- """Execute a SQL query against the specified database."""
116
- try:
117
- agent = await _create_agent_for_database(database)
118
- if not agent:
119
- return json.dumps(
120
- {"error": f"Database '{database}' not found or could not connect"}
121
- )
107
+ async def introspect_schema(database: str, table_pattern: str = None) -> str:
108
+ """Introspect database schema to understand table structures."""
109
+ kwargs = {}
110
+ if table_pattern is not None:
111
+ kwargs["table_pattern"] = table_pattern
112
+ return await _execute_with_connection("introspect_schema", database, **kwargs)
122
113
 
123
- result = await agent.execute_sql(query, limit)
124
- await agent.db.close()
125
- return result
126
114
 
127
- except Exception as e:
128
- return json.dumps({"error": f"Error executing SQL: {str(e)}"})
115
+ @mcp.tool
116
+ async def execute_sql(database: str, query: str, limit: int = 100) -> str:
117
+ """Execute a SQL query against the database."""
118
+ return await _execute_with_connection(
119
+ "execute_sql", database, query=query, limit=limit
120
+ )
129
121
 
130
122
 
131
123
  def main():
@@ -0,0 +1,25 @@
1
+ """SQLSaber tools module."""
2
+
3
+ from .base import Tool
4
+ from .enums import ToolCategory, WorkflowPosition
5
+ from .instructions import InstructionBuilder
6
+ from .registry import ToolRegistry, register_tool, tool_registry
7
+
8
+ # Import concrete tools to register them
9
+ from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
10
+ from .visualization_tools import PlotDataTool
11
+
12
+ __all__ = [
13
+ "Tool",
14
+ "ToolCategory",
15
+ "WorkflowPosition",
16
+ "ToolRegistry",
17
+ "tool_registry",
18
+ "register_tool",
19
+ "InstructionBuilder",
20
+ "SQLTool",
21
+ "ListTablesTool",
22
+ "IntrospectSchemaTool",
23
+ "ExecuteSQLTool",
24
+ "PlotDataTool",
25
+ ]