sqlsaber 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -1,6 +1,5 @@
1
1
  """Database schema introspection utilities."""
2
2
 
3
- import time
4
3
  from abc import ABC, abstractmethod
5
4
  from typing import Any
6
5
 
@@ -532,12 +531,10 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
532
531
 
533
532
 
534
533
  class SchemaManager:
535
- """Manages database schema introspection with caching."""
534
+ """Manages database schema introspection."""
536
535
 
537
- def __init__(self, db_connection: BaseDatabaseConnection, cache_ttl: int = 900):
536
+ def __init__(self, db_connection: BaseDatabaseConnection):
538
537
  self.db = db_connection
539
- self.cache_ttl = cache_ttl # Default 15 minutes
540
- self._schema_cache: dict[str, tuple[float, dict[str, Any]]] = {}
541
538
 
542
539
  # Select appropriate introspector based on connection type
543
540
  if isinstance(db_connection, PostgreSQLConnection):
@@ -551,10 +548,6 @@ class SchemaManager:
551
548
  f"Unsupported database connection type: {type(db_connection)}"
552
549
  )
553
550
 
554
- def clear_schema_cache(self):
555
- """Clear the schema cache."""
556
- self._schema_cache.clear()
557
-
558
551
  async def get_schema_info(
559
552
  self, table_pattern: str | None = None
560
553
  ) -> dict[str, SchemaInfo]:
@@ -563,31 +556,6 @@ class SchemaManager:
563
556
  Args:
564
557
  table_pattern: Optional SQL LIKE pattern to filter tables (e.g., 'public.user%')
565
558
  """
566
- # Check cache first
567
- cache_key = f"schema:{table_pattern or 'all'}"
568
- cached_data = self._get_cached_schema(cache_key)
569
- if cached_data is not None:
570
- return cached_data
571
-
572
- # Fetch from database if not cached
573
- schema_info = await self._fetch_schema_from_db(table_pattern)
574
-
575
- # Cache the result
576
- self._schema_cache[cache_key] = (time.time(), schema_info)
577
- return schema_info
578
-
579
- def _get_cached_schema(self, cache_key: str) -> dict[str, SchemaInfo] | None:
580
- """Get schema from cache if available and not expired."""
581
- if cache_key in self._schema_cache:
582
- cached_time, cached_data = self._schema_cache[cache_key]
583
- if time.time() - cached_time < self.cache_ttl:
584
- return cached_data
585
- return None
586
-
587
- async def _fetch_schema_from_db(
588
- self, table_pattern: str | None
589
- ) -> dict[str, SchemaInfo]:
590
- """Fetch schema information from database."""
591
559
  # Get all schema components
592
560
  tables = await self.introspector.get_tables_info(self.db, table_pattern)
593
561
  columns = await self.introspector.get_columns_info(self.db, tables)
@@ -672,13 +640,6 @@ class SchemaManager:
672
640
 
673
641
  async def list_tables(self) -> dict[str, Any]:
674
642
  """Get a list of all tables with basic information."""
675
- # Check cache first
676
- cache_key = "list_tables"
677
- cached_data = self._get_cached_tables(cache_key)
678
- if cached_data is not None:
679
- return cached_data
680
-
681
- # Fetch from database if not cached
682
643
  tables = await self.introspector.list_tables_info(self.db)
683
644
 
684
645
  # Format the result
@@ -694,14 +655,4 @@ class SchemaManager:
694
655
  }
695
656
  )
696
657
 
697
- # Cache the result
698
- self._schema_cache[cache_key] = (time.time(), result)
699
658
  return result
700
-
701
- def _get_cached_tables(self, cache_key: str) -> dict[str, Any] | None:
702
- """Get table list from cache if available and not expired."""
703
- if cache_key in self._schema_cache:
704
- cached_time, cached_data = self._schema_cache[cache_key]
705
- if time.time() - cached_time < self.cache_ttl:
706
- return cached_data
707
- return None
sqlsaber/mcp/mcp.py CHANGED
@@ -7,25 +7,17 @@ from fastmcp import FastMCP
7
7
  from sqlsaber.agents.mcp import MCPSQLAgent
8
8
  from sqlsaber.config.database import DatabaseConfigManager
9
9
  from sqlsaber.database.connection import DatabaseConnection
10
+ from sqlsaber.tools import SQLTool, tool_registry
11
+ from sqlsaber.tools.instructions import InstructionBuilder
10
12
 
11
- INSTRUCTIONS = """
12
- This server provides helpful resources and tools that will help you address users queries on their database.
13
+ # Initialize the instruction builder
14
+ instruction_builder = InstructionBuilder(tool_registry)
13
15
 
14
- - Get all databases using `get_databases()`
15
- - Call `list_tables()` to get a list of all tables in the database with row counts. Use this first to discover available tables.
16
- - Call `introspect_schema()` to introspect database schema to understand table structures.
17
- - Call `execute_sql()` to execute SQL queries against the database and retrieve results.
16
+ # Generate dynamic instructions
17
+ DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()
18
18
 
19
- Guidelines:
20
- - Use list_tables first, then introspect_schema for specific tables only
21
- - Use table patterns like 'sample%' or '%experiment%' to filter related tables
22
- - Use proper JOIN syntax and avoid cartesian products
23
- - Include appropriate WHERE clauses to limit results
24
- - Handle errors gracefully and suggest fixes
25
- """
26
-
27
- # Create the FastMCP server instance
28
- mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
19
+ # Create the FastMCP server instance with dynamic instructions
20
+ mcp = FastMCP(name="SQL Assistant", instructions=DYNAMIC_INSTRUCTIONS)
29
21
 
30
22
  # Initialize the database config manager
31
23
  config_manager = DatabaseConfigManager()
@@ -70,10 +62,16 @@ def get_databases() -> dict:
70
62
  return {"databases": databases, "count": len(databases)}
71
63
 
72
64
 
73
- @mcp.tool
74
- async def list_tables(database: str) -> str:
75
- """
76
- Get a list of all tables in the database with row counts. Use this first to discover available tables.
65
+ async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
66
+ """Execute a SQL tool with database connection management.
67
+
68
+ Args:
69
+ tool_name: Name of the tool to execute
70
+ database: Database name to connect to
71
+ **kwargs: Tool-specific parameters
72
+
73
+ Returns:
74
+ JSON string with the tool's output
77
75
  """
78
76
  try:
79
77
  agent = await _create_agent_for_database(database)
@@ -82,50 +80,44 @@ async def list_tables(database: str) -> str:
82
80
  {"error": f"Database '{database}' not found or could not connect"}
83
81
  )
84
82
 
85
- result = await agent.list_tables()
83
+ # Get the tool and set up connection
84
+ tool = tool_registry.get_tool(tool_name)
85
+ if isinstance(tool, SQLTool):
86
+ tool.set_connection(agent.db)
87
+
88
+ # Execute the tool
89
+ result = await tool.execute(**kwargs)
86
90
  await agent.db.close()
87
91
  return result
88
92
 
89
93
  except Exception as e:
90
- return json.dumps({"error": f"Error listing tables: {str(e)}"})
94
+ return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
91
95
 
92
96
 
93
- @mcp.tool
94
- async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
95
- """
96
- Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
97
- """
98
- try:
99
- agent = await _create_agent_for_database(database)
100
- if not agent:
101
- return json.dumps(
102
- {"error": f"Database '{database}' not found or could not connect"}
103
- )
97
+ # SQL Tool Wrappers with explicit signatures
104
98
 
105
- result = await agent.introspect_schema(table_pattern)
106
- await agent.db.close()
107
- return result
108
99
 
109
- except Exception as e:
110
- return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
100
+ @mcp.tool
101
+ async def list_tables(database: str) -> str:
102
+ """Get a list of all tables in the database with row counts. Use this first to discover available tables."""
103
+ return await _execute_with_connection("list_tables", database)
111
104
 
112
105
 
113
106
  @mcp.tool
114
- async def execute_sql(database: str, query: str, limit: int | None = 100) -> str:
115
- """Execute a SQL query against the specified database."""
116
- try:
117
- agent = await _create_agent_for_database(database)
118
- if not agent:
119
- return json.dumps(
120
- {"error": f"Database '{database}' not found or could not connect"}
121
- )
107
+ async def introspect_schema(database: str, table_pattern: str = None) -> str:
108
+ """Introspect database schema to understand table structures."""
109
+ kwargs = {}
110
+ if table_pattern is not None:
111
+ kwargs["table_pattern"] = table_pattern
112
+ return await _execute_with_connection("introspect_schema", database, **kwargs)
122
113
 
123
- result = await agent.execute_sql(query, limit)
124
- await agent.db.close()
125
- return result
126
114
 
127
- except Exception as e:
128
- return json.dumps({"error": f"Error executing SQL: {str(e)}"})
115
+ @mcp.tool
116
+ async def execute_sql(database: str, query: str, limit: int = 100) -> str:
117
+ """Execute a SQL query against the database."""
118
+ return await _execute_with_connection(
119
+ "execute_sql", database, query=query, limit=limit
120
+ )
129
121
 
130
122
 
131
123
  def main():
@@ -0,0 +1,25 @@
1
+ """SQLSaber tools module."""
2
+
3
+ from .base import Tool
4
+ from .enums import ToolCategory, WorkflowPosition
5
+ from .instructions import InstructionBuilder
6
+ from .registry import ToolRegistry, register_tool, tool_registry
7
+
8
+ # Import concrete tools to register them
9
+ from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
10
+ from .visualization_tools import PlotDataTool
11
+
12
+ __all__ = [
13
+ "Tool",
14
+ "ToolCategory",
15
+ "WorkflowPosition",
16
+ "ToolRegistry",
17
+ "tool_registry",
18
+ "register_tool",
19
+ "InstructionBuilder",
20
+ "SQLTool",
21
+ "ListTablesTool",
22
+ "IntrospectSchemaTool",
23
+ "ExecuteSQLTool",
24
+ "PlotDataTool",
25
+ ]
sqlsaber/tools/base.py ADDED
@@ -0,0 +1,83 @@
1
+ """Base class for SQLSaber tools."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from sqlsaber.clients.models import ToolDefinition
7
+
8
+ from .enums import ToolCategory, WorkflowPosition
9
+
10
+
11
class Tool(ABC):
    """Abstract base class for all tools.

    Concrete tools must implement the ``name``, ``description`` and
    ``input_schema`` properties and the ``execute`` coroutine. The remaining
    members (``category``, ``get_usage_instructions``, ``get_priority`` and
    ``get_workflow_position``) ship with defaults that subclasses may
    override.
    """

    def __init__(self):
        """Initialize the tool."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Return the tool description."""
        pass

    @property
    @abstractmethod
    def input_schema(self) -> dict[str, Any]:
        """Return the tool's input schema."""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> str:
        """Execute the tool with given inputs.

        Args:
            **kwargs: Tool-specific keyword arguments

        Returns:
            JSON string with the tool's output
        """
        pass

    def to_definition(self) -> ToolDefinition:
        """Convert this tool to a ToolDefinition.

        The definition is built from this tool's ``name``, ``description``
        and ``input_schema`` properties.
        """
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )

    @property
    def category(self) -> ToolCategory:
        """Return the tool category. Override to customize.

        Defaults to ``ToolCategory.GENERAL``.
        """
        return ToolCategory.GENERAL

    def get_usage_instructions(self) -> str | None:
        """Return tool-specific usage instructions for LLM guidance.

        Returns:
            Usage instructions string, or None for no specific guidance
            (the default).
        """
        return None

    def get_priority(self) -> int:
        """Return priority for tool ordering in instructions.

        Returns:
            Priority number (lower = higher priority, default = 100)
        """
        return 100

    def get_workflow_position(self) -> WorkflowPosition:
        """Return the typical workflow position for this tool.

        Returns:
            WorkflowPosition enum value (``WorkflowPosition.OTHER`` by default)
        """
        return WorkflowPosition.OTHER
@@ -0,0 +1,21 @@
1
+ """Enums for tool categories and workflow positions."""
2
+
3
+ from enum import Enum
4
+
5
+
6
class ToolCategory(Enum):
    """Tool categories for organizing and filtering tools."""

    # Each member's value is the lowercase string form of its name.
    GENERAL = "general"
    SQL = "sql"
    VISUALIZATION = "visualization"
+
13
+
14
class WorkflowPosition(Enum):
    """Workflow positions for organizing tools by usage order."""

    # Each member's value is the lowercase string form of its name.
    DISCOVERY = "discovery"
    ANALYSIS = "analysis"
    EXECUTION = "execution"
    VISUALIZATION = "visualization"
    OTHER = "other"
@@ -0,0 +1,251 @@
1
+ """Dynamic instruction builder for tools."""
2
+
3
+ from .base import Tool
4
+ from .enums import ToolCategory, WorkflowPosition
5
+ from .registry import ToolRegistry
6
+
7
+
8
class InstructionBuilder:
    """Builds dynamic instructions based on available tools.

    The builder reads tools from a registry and assembles LLM-facing
    instruction text from each tool's name, description, usage instructions,
    priority, category and workflow position.
    """

    def __init__(self, tool_registry: ToolRegistry):
        """Initialize with a tool registry."""
        self.registry = tool_registry

    def build_instructions(
        self,
        db_type: str = "database",
        category: str | ToolCategory | None = None,
        include_base_instructions: bool = True,
    ) -> str:
        """Build dynamic instructions from available tools.

        Args:
            db_type: Type of database (PostgreSQL, MySQL, SQLite, etc.)
            category: Optional category to filter tools by (string or ToolCategory enum)
            include_base_instructions: Whether to include base SQL assistant instructions

        Returns:
            Complete instruction string for LLM. When the registry yields no
            tools, the base instructions alone are returned, regardless of
            ``include_base_instructions``.
        """
        tools = self.registry.get_all_tools(category)

        if not tools:
            return self._get_base_instructions(db_type)

        sorted_tools = self._sort_tools_by_workflow(tools)

        instructions_parts = []

        if include_base_instructions:
            instructions_parts.append(self._get_base_instructions(db_type))

        # Each section is optional: empty sections are left out so the
        # final text has no stray blank paragraphs.
        workflow_instructions = self._build_workflow_instructions(sorted_tools)
        if workflow_instructions:
            instructions_parts.append(workflow_instructions)

        tool_guidelines = self._build_tool_guidelines(sorted_tools)
        if tool_guidelines:
            instructions_parts.append(tool_guidelines)

        general_guidelines = self._build_general_guidelines(sorted_tools)
        if general_guidelines:
            instructions_parts.append(general_guidelines)

        return "\n\n".join(instructions_parts)

    def _get_base_instructions(self, db_type: str) -> str:
        """Get base SQL assistant instructions for the given database type."""
        lines = [
            f"You are also a helpful SQL assistant that helps users query their {db_type} database.",
            "",
            "Your responsibilities:",
            "1. Understand user's natural language requests, think and convert them to SQL",
            "2. Use the provided tools efficiently to explore database schema",
            "3. Generate appropriate SQL queries",
            "4. Execute queries safely - queries that modify the database are not allowed",
            "5. Format and explain results clearly",
            "6. Create visualizations when requested or when they would be helpful",
        ]
        return "\n".join(lines)

    def _sort_tools_by_workflow(self, tools: list[Tool]) -> list[Tool]:
        """Sort tools by workflow position, then priority, then name."""
        position_order = {
            WorkflowPosition.DISCOVERY: 1,
            WorkflowPosition.ANALYSIS: 2,
            WorkflowPosition.EXECUTION: 3,
            WorkflowPosition.VISUALIZATION: 4,
            WorkflowPosition.OTHER: 5,
        }

        return sorted(
            tools,
            key=lambda tool: (
                # Unknown positions sort together with OTHER.
                position_order.get(tool.get_workflow_position(), 5),
                tool.get_priority(),
                tool.name,
            ),
        )

    def _build_workflow_instructions(self, sorted_tools: list[Tool]) -> str:
        """Build workflow-based instructions.

        Produces a header line followed by one numbered step per tool, in
        the order discovery, analysis, execution, visualization. Tools in
        any other workflow position are not listed. Returns an empty string
        when no step was produced.
        """
        # Group tools by workflow position, keeping the incoming order
        # within each group.
        workflow_groups = {}
        for tool in sorted_tools:
            workflow_groups.setdefault(tool.get_workflow_position(), []).append(tool)

        ordered_positions = (
            WorkflowPosition.DISCOVERY,
            WorkflowPosition.ANALYSIS,
            WorkflowPosition.EXECUTION,
            WorkflowPosition.VISUALIZATION,
        )

        instructions = ["IMPORTANT - Tool Usage Strategy:"]
        step = 1

        for position in ordered_positions:
            for tool in workflow_groups.get(position, []):
                usage = tool.get_usage_instructions()
                if usage:
                    text = usage
                elif position is WorkflowPosition.VISUALIZATION:
                    # Visualization tools get a fixed fallback phrase
                    # instead of one derived from the description.
                    text = f"Use '{tool.name}' when creating visualizations"
                else:
                    text = f"Use '{tool.name}' to {tool.description.lower()}"
                instructions.append(f"{step}. {text}")
                step += 1

        return "\n".join(instructions) if len(instructions) > 1 else ""

    def _build_tool_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build tool-specific guidelines.

        Lists the usage instructions of every tool whose usage text is not
        judged to be workflow guidance by ``_is_usage_in_workflow``. Returns
        an empty string when nothing qualifies.
        """
        guidelines = []

        for tool in sorted_tools:
            usage = tool.get_usage_instructions()
            if usage and not self._is_usage_in_workflow(usage):
                guidelines.append(f"- {tool.name}: {usage}")

        if guidelines:
            return "Tool-Specific Guidelines:\n" + "\n".join(guidelines)
        return ""

    def _build_general_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build general usage guidelines, extended per tool category present."""
        guidelines = [
            "Guidelines:",
            "- Use proper JOIN syntax and avoid cartesian products",
            "- Include appropriate WHERE clauses to limit results",
            "- Explain what the query does in simple terms",
            "- Handle errors gracefully and suggest fixes",
            "- Be security conscious - use parameterized queries when needed",
        ]

        # Category-specific guidelines are added only when at least one
        # tool of that category is present.
        categories = {tool.category for tool in sorted_tools}

        if ToolCategory.SQL in categories:
            guidelines.extend(
                [
                    "- Timestamp columns must be converted to text when you write queries",
                    "- Use table patterns like 'sample%' or '%experiment%' to filter related tables",
                ]
            )

        if ToolCategory.VISUALIZATION in categories:
            guidelines.append(
                "- Create visualizations when they would enhance understanding of the data"
            )

        return "\n".join(guidelines)

    def _is_usage_in_workflow(self, usage: str) -> bool:
        """Check if usage instruction is already covered in workflow section."""
        # Simple heuristic - if the usage text contains a workflow phrase
        # anywhere (case-insensitive), it's probably in the workflow section.
        workflow_words = ["always start", "first", "use this", "begin with", "start by"]
        usage_lower = usage.lower()
        return any(word in usage_lower for word in workflow_words)

    def build_mcp_instructions(self) -> str:
        """Build instructions specifically for MCP server.

        Returns:
            Instruction text made of an introduction, one bullet per SQL
            tool, and a "Guidelines:" section holding the workflow steps
            (as bullets) followed by fixed general guidelines.
        """
        instructions = [
            "This server provides helpful resources and tools that will help you address users queries on their database.",
            "",
        ]

        # Database discovery comes first.
        instructions.append("- Get all databases using `get_databases()`")

        sql_tools = self.registry.get_all_tools(category=ToolCategory.SQL)
        sorted_tools = self._sort_tools_by_workflow(sql_tools)

        for tool in sorted_tools:
            instructions.append(f"- Call `{tool.name}()` to {tool.description.lower()}")

        instructions.extend(["", "Guidelines:"])

        workflow_instructions = self._build_workflow_instructions(sorted_tools)
        if workflow_instructions:
            # Skip the "IMPORTANT" header and turn each numbered step into
            # a bullet point.
            for line in workflow_instructions.split("\n")[1:]:
                stripped = line.strip()
                # Split on the first ". " rather than slicing a fixed three
                # characters, so steps numbered 10 and above are handled
                # correctly too.
                number, separator, text = stripped.partition(". ")
                if separator and number.isdigit():
                    instructions.append(f"- {text}")

        instructions.extend(
            [
                "- Use proper JOIN syntax and avoid cartesian products",
                "- Include appropriate WHERE clauses to limit results",
                "- Handle errors gracefully and suggest fixes",
            ]
        )

        return "\n".join(instructions)