sqlsaber 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +63 -123
- sqlsaber/agents/base.py +111 -210
- sqlsaber/cli/interactive.py +6 -2
- sqlsaber/conversation/__init__.py +12 -0
- sqlsaber/conversation/manager.py +224 -0
- sqlsaber/conversation/models.py +120 -0
- sqlsaber/conversation/storage.py +362 -0
- sqlsaber/database/schema.py +2 -51
- sqlsaber/mcp/mcp.py +43 -51
- sqlsaber/tools/__init__.py +25 -0
- sqlsaber/tools/base.py +83 -0
- sqlsaber/tools/enums.py +21 -0
- sqlsaber/tools/instructions.py +251 -0
- sqlsaber/tools/registry.py +130 -0
- sqlsaber/tools/sql_tools.py +275 -0
- sqlsaber/tools/visualization_tools.py +144 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/METADATA +1 -1
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/RECORD +21 -10
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.13.0.dist-info → sqlsaber-0.15.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/database/schema.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Database schema introspection utilities."""
|
|
2
2
|
|
|
3
|
-
import time
|
|
4
3
|
from abc import ABC, abstractmethod
|
|
5
4
|
from typing import Any
|
|
6
5
|
|
|
@@ -532,12 +531,10 @@ class SQLiteSchemaIntrospector(BaseSchemaIntrospector):
|
|
|
532
531
|
|
|
533
532
|
|
|
534
533
|
class SchemaManager:
|
|
535
|
-
"""Manages database schema introspection
|
|
534
|
+
"""Manages database schema introspection."""
|
|
536
535
|
|
|
537
|
-
def __init__(self, db_connection: BaseDatabaseConnection
|
|
536
|
+
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
538
537
|
self.db = db_connection
|
|
539
|
-
self.cache_ttl = cache_ttl # Default 15 minutes
|
|
540
|
-
self._schema_cache: dict[str, tuple[float, dict[str, Any]]] = {}
|
|
541
538
|
|
|
542
539
|
# Select appropriate introspector based on connection type
|
|
543
540
|
if isinstance(db_connection, PostgreSQLConnection):
|
|
@@ -551,10 +548,6 @@ class SchemaManager:
|
|
|
551
548
|
f"Unsupported database connection type: {type(db_connection)}"
|
|
552
549
|
)
|
|
553
550
|
|
|
554
|
-
def clear_schema_cache(self):
|
|
555
|
-
"""Clear the schema cache."""
|
|
556
|
-
self._schema_cache.clear()
|
|
557
|
-
|
|
558
551
|
async def get_schema_info(
|
|
559
552
|
self, table_pattern: str | None = None
|
|
560
553
|
) -> dict[str, SchemaInfo]:
|
|
@@ -563,31 +556,6 @@ class SchemaManager:
|
|
|
563
556
|
Args:
|
|
564
557
|
table_pattern: Optional SQL LIKE pattern to filter tables (e.g., 'public.user%')
|
|
565
558
|
"""
|
|
566
|
-
# Check cache first
|
|
567
|
-
cache_key = f"schema:{table_pattern or 'all'}"
|
|
568
|
-
cached_data = self._get_cached_schema(cache_key)
|
|
569
|
-
if cached_data is not None:
|
|
570
|
-
return cached_data
|
|
571
|
-
|
|
572
|
-
# Fetch from database if not cached
|
|
573
|
-
schema_info = await self._fetch_schema_from_db(table_pattern)
|
|
574
|
-
|
|
575
|
-
# Cache the result
|
|
576
|
-
self._schema_cache[cache_key] = (time.time(), schema_info)
|
|
577
|
-
return schema_info
|
|
578
|
-
|
|
579
|
-
def _get_cached_schema(self, cache_key: str) -> dict[str, SchemaInfo] | None:
|
|
580
|
-
"""Get schema from cache if available and not expired."""
|
|
581
|
-
if cache_key in self._schema_cache:
|
|
582
|
-
cached_time, cached_data = self._schema_cache[cache_key]
|
|
583
|
-
if time.time() - cached_time < self.cache_ttl:
|
|
584
|
-
return cached_data
|
|
585
|
-
return None
|
|
586
|
-
|
|
587
|
-
async def _fetch_schema_from_db(
|
|
588
|
-
self, table_pattern: str | None
|
|
589
|
-
) -> dict[str, SchemaInfo]:
|
|
590
|
-
"""Fetch schema information from database."""
|
|
591
559
|
# Get all schema components
|
|
592
560
|
tables = await self.introspector.get_tables_info(self.db, table_pattern)
|
|
593
561
|
columns = await self.introspector.get_columns_info(self.db, tables)
|
|
@@ -672,13 +640,6 @@ class SchemaManager:
|
|
|
672
640
|
|
|
673
641
|
async def list_tables(self) -> dict[str, Any]:
|
|
674
642
|
"""Get a list of all tables with basic information."""
|
|
675
|
-
# Check cache first
|
|
676
|
-
cache_key = "list_tables"
|
|
677
|
-
cached_data = self._get_cached_tables(cache_key)
|
|
678
|
-
if cached_data is not None:
|
|
679
|
-
return cached_data
|
|
680
|
-
|
|
681
|
-
# Fetch from database if not cached
|
|
682
643
|
tables = await self.introspector.list_tables_info(self.db)
|
|
683
644
|
|
|
684
645
|
# Format the result
|
|
@@ -694,14 +655,4 @@ class SchemaManager:
|
|
|
694
655
|
}
|
|
695
656
|
)
|
|
696
657
|
|
|
697
|
-
# Cache the result
|
|
698
|
-
self._schema_cache[cache_key] = (time.time(), result)
|
|
699
658
|
return result
|
|
700
|
-
|
|
701
|
-
def _get_cached_tables(self, cache_key: str) -> dict[str, Any] | None:
|
|
702
|
-
"""Get table list from cache if available and not expired."""
|
|
703
|
-
if cache_key in self._schema_cache:
|
|
704
|
-
cached_time, cached_data = self._schema_cache[cache_key]
|
|
705
|
-
if time.time() - cached_time < self.cache_ttl:
|
|
706
|
-
return cached_data
|
|
707
|
-
return None
|
sqlsaber/mcp/mcp.py
CHANGED
|
@@ -7,25 +7,17 @@ from fastmcp import FastMCP
|
|
|
7
7
|
from sqlsaber.agents.mcp import MCPSQLAgent
|
|
8
8
|
from sqlsaber.config.database import DatabaseConfigManager
|
|
9
9
|
from sqlsaber.database.connection import DatabaseConnection
|
|
10
|
+
from sqlsaber.tools import SQLTool, tool_registry
|
|
11
|
+
from sqlsaber.tools.instructions import InstructionBuilder
|
|
10
12
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
+
# Initialize the instruction builder
|
|
14
|
+
instruction_builder = InstructionBuilder(tool_registry)
|
|
13
15
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
- Call `introspect_schema()` to introspect database schema to understand table structures.
|
|
17
|
-
- Call `execute_sql()` to execute SQL queries against the database and retrieve results.
|
|
16
|
+
# Generate dynamic instructions
|
|
17
|
+
DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
- Use table patterns like 'sample%' or '%experiment%' to filter related tables
|
|
22
|
-
- Use proper JOIN syntax and avoid cartesian products
|
|
23
|
-
- Include appropriate WHERE clauses to limit results
|
|
24
|
-
- Handle errors gracefully and suggest fixes
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
# Create the FastMCP server instance
|
|
28
|
-
mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
|
|
19
|
+
# Create the FastMCP server instance with dynamic instructions
|
|
20
|
+
mcp = FastMCP(name="SQL Assistant", instructions=DYNAMIC_INSTRUCTIONS)
|
|
29
21
|
|
|
30
22
|
# Initialize the database config manager
|
|
31
23
|
config_manager = DatabaseConfigManager()
|
|
@@ -70,10 +62,16 @@ def get_databases() -> dict:
|
|
|
70
62
|
return {"databases": databases, "count": len(databases)}
|
|
71
63
|
|
|
72
64
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
65
|
+
async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
|
|
66
|
+
"""Execute a SQL tool with database connection management.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
tool_name: Name of the tool to execute
|
|
70
|
+
database: Database name to connect to
|
|
71
|
+
**kwargs: Tool-specific parameters
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
JSON string with the tool's output
|
|
77
75
|
"""
|
|
78
76
|
try:
|
|
79
77
|
agent = await _create_agent_for_database(database)
|
|
@@ -82,50 +80,44 @@ async def list_tables(database: str) -> str:
|
|
|
82
80
|
{"error": f"Database '{database}' not found or could not connect"}
|
|
83
81
|
)
|
|
84
82
|
|
|
85
|
-
|
|
83
|
+
# Get the tool and set up connection
|
|
84
|
+
tool = tool_registry.get_tool(tool_name)
|
|
85
|
+
if isinstance(tool, SQLTool):
|
|
86
|
+
tool.set_connection(agent.db)
|
|
87
|
+
|
|
88
|
+
# Execute the tool
|
|
89
|
+
result = await tool.execute(**kwargs)
|
|
86
90
|
await agent.db.close()
|
|
87
91
|
return result
|
|
88
92
|
|
|
89
93
|
except Exception as e:
|
|
90
|
-
return json.dumps({"error": f"Error
|
|
94
|
+
return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
|
|
91
95
|
|
|
92
96
|
|
|
93
|
-
|
|
94
|
-
async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
|
|
95
|
-
"""
|
|
96
|
-
Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
|
|
97
|
-
"""
|
|
98
|
-
try:
|
|
99
|
-
agent = await _create_agent_for_database(database)
|
|
100
|
-
if not agent:
|
|
101
|
-
return json.dumps(
|
|
102
|
-
{"error": f"Database '{database}' not found or could not connect"}
|
|
103
|
-
)
|
|
97
|
+
# SQL Tool Wrappers with explicit signatures
|
|
104
98
|
|
|
105
|
-
result = await agent.introspect_schema(table_pattern)
|
|
106
|
-
await agent.db.close()
|
|
107
|
-
return result
|
|
108
99
|
|
|
109
|
-
|
|
110
|
-
|
|
100
|
+
@mcp.tool
async def list_tables(database: str) -> str:
    """Get a list of all tables in the database with row counts. Use this first to discover available tables."""
    # NOTE(review): the docstring above is presumably surfaced to MCP clients
    # as the tool description, so treat its wording as runtime text.
    # Thin wrapper: delegates to the "list_tables" tool via the shared helper,
    # passing no tool-specific keyword arguments.
    return await _execute_with_connection("list_tables", database)
|
|
111
104
|
|
|
112
105
|
|
|
113
106
|
@mcp.tool
|
|
114
|
-
async def
|
|
115
|
-
"""
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
{"error": f"Database '{database}' not found or could not connect"}
|
|
121
|
-
)
|
|
107
|
+
async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
    """Introspect database schema to understand table structures."""
    # NOTE(review): the docstring above is presumably surfaced to MCP clients
    # as the tool description, so treat its wording as runtime text.
    # `table_pattern` is annotated `str | None` (not the implicit-Optional
    # `str = None`) so the declared type admits the None default.
    # Only forward `table_pattern` when the caller supplied one, so the tool
    # never receives an explicit None for it.
    kwargs = {}
    if table_pattern is not None:
        kwargs["table_pattern"] = table_pattern
    return await _execute_with_connection("introspect_schema", database, **kwargs)
|
|
122
113
|
|
|
123
|
-
result = await agent.execute_sql(query, limit)
|
|
124
|
-
await agent.db.close()
|
|
125
|
-
return result
|
|
126
114
|
|
|
127
|
-
|
|
128
|
-
|
|
115
|
+
@mcp.tool
async def execute_sql(database: str, query: str, limit: int = 100) -> str:
    """Execute a SQL query against the database."""
    # NOTE(review): the docstring above is presumably surfaced to MCP clients
    # as the tool description, so treat its wording as runtime text.
    # Thin wrapper: forwards `query` and `limit` unchanged to the
    # "execute_sql" tool via the shared helper.
    return await _execute_with_connection(
        "execute_sql", database, query=query, limit=limit
    )
|
|
129
121
|
|
|
130
122
|
|
|
131
123
|
def main():
|
|
sqlsaber/tools/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""SQLSaber tools module."""
|
|
2
|
+
|
|
3
|
+
from .base import Tool
|
|
4
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
5
|
+
from .instructions import InstructionBuilder
|
|
6
|
+
from .registry import ToolRegistry, register_tool, tool_registry
|
|
7
|
+
|
|
8
|
+
# Import concrete tools to register them
|
|
9
|
+
from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
|
|
10
|
+
from .visualization_tools import PlotDataTool
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"Tool",
|
|
14
|
+
"ToolCategory",
|
|
15
|
+
"WorkflowPosition",
|
|
16
|
+
"ToolRegistry",
|
|
17
|
+
"tool_registry",
|
|
18
|
+
"register_tool",
|
|
19
|
+
"InstructionBuilder",
|
|
20
|
+
"SQLTool",
|
|
21
|
+
"ListTablesTool",
|
|
22
|
+
"IntrospectSchemaTool",
|
|
23
|
+
"ExecuteSQLTool",
|
|
24
|
+
"PlotDataTool",
|
|
25
|
+
]
|
sqlsaber/tools/base.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Base class for SQLSaber tools."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlsaber.clients.models import ToolDefinition
|
|
7
|
+
|
|
8
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Tool(ABC):
    """Abstract base class for all tools.

    Subclasses must provide the ``name``, ``description`` and ``input_schema``
    properties and the async ``execute`` method. The remaining members
    (``category``, ``get_usage_instructions``, ``get_priority``,
    ``get_workflow_position``) have defaults and may be overridden.
    """

    def __init__(self):
        """Initialize the tool."""
        # Intentionally empty: the base class keeps no state of its own.

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the tool name."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Return the tool description."""

    @property
    @abstractmethod
    def input_schema(self) -> dict[str, Any]:
        """Return the tool's input schema."""

    @abstractmethod
    async def execute(self, **kwargs) -> str:
        """Execute the tool with given inputs.

        Args:
            **kwargs: Tool-specific keyword arguments

        Returns:
            JSON string with the tool's output
        """

    def to_definition(self) -> ToolDefinition:
        """Convert this tool to a ToolDefinition.

        Returns:
            A ToolDefinition built from this tool's ``name``, ``description``
            and ``input_schema``.
        """
        return ToolDefinition(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )

    @property
    def category(self) -> ToolCategory:
        """Return the tool category. Override to customize."""
        return ToolCategory.GENERAL

    def get_usage_instructions(self) -> str | None:
        """Return tool-specific usage instructions for LLM guidance.

        Returns:
            Usage instructions string, or None for no specific guidance
        """
        return None

    def get_priority(self) -> int:
        """Return priority for tool ordering in instructions.

        Returns:
            Priority number (lower = higher priority, default = 100)
        """
        return 100

    def get_workflow_position(self) -> WorkflowPosition:
        """Return the typical workflow position for this tool.

        Returns:
            WorkflowPosition enum value
        """
        return WorkflowPosition.OTHER
|
sqlsaber/tools/enums.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Enums for tool categories and workflow positions."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ToolCategory(Enum):
    """Tool categories for organizing and filtering tools."""

    # Member values are the lowercase string forms of the member names.
    GENERAL = "general"
    SQL = "sql"
    VISUALIZATION = "visualization"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class WorkflowPosition(Enum):
    """Workflow positions for organizing tools by usage order."""

    # Member values are the lowercase string forms of the member names.
    DISCOVERY = "discovery"
    ANALYSIS = "analysis"
    EXECUTION = "execution"
    VISUALIZATION = "visualization"
    OTHER = "other"
|
|
sqlsaber/tools/instructions.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""Dynamic instruction builder for tools."""
|
|
2
|
+
|
|
3
|
+
from .base import Tool
|
|
4
|
+
from .enums import ToolCategory, WorkflowPosition
|
|
5
|
+
from .registry import ToolRegistry
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InstructionBuilder:
    """Builds dynamic instructions based on available tools."""

    def __init__(self, tool_registry: ToolRegistry):
        """Initialize with a tool registry."""
        self.registry = tool_registry

    def build_instructions(
        self,
        db_type: str = "database",
        category: str | ToolCategory | None = None,
        include_base_instructions: bool = True,
    ) -> str:
        """Build dynamic instructions from available tools.

        Args:
            db_type: Type of database (PostgreSQL, MySQL, SQLite, etc.)
            category: Optional category to filter tools by (string or ToolCategory enum)
            include_base_instructions: Whether to include base SQL assistant instructions

        Returns:
            Complete instruction string for LLM. When the registry returns no
            tools, only the base instructions are returned, regardless of
            ``include_base_instructions``.
        """
        tools = self.registry.get_all_tools(category)
        if not tools:
            return self._get_base_instructions(db_type)

        sorted_tools = self._sort_tools_by_workflow(tools)

        instructions_parts = []
        if include_base_instructions:
            instructions_parts.append(self._get_base_instructions(db_type))

        # Sections appear in this fixed order; empty sections are dropped so
        # the final join never produces stray blank paragraphs.
        sections = (
            self._build_workflow_instructions(sorted_tools),
            self._build_tool_guidelines(sorted_tools),
            self._build_general_guidelines(sorted_tools),
        )
        instructions_parts.extend(section for section in sections if section)

        return "\n\n".join(instructions_parts)

    def _get_base_instructions(self, db_type: str) -> str:
        """Get base SQL assistant instructions."""
        return f"""You are also a helpful SQL assistant that helps users query their {db_type} database.

Your responsibilities:
1. Understand user's natural language requests, think and convert them to SQL
2. Use the provided tools efficiently to explore database schema
3. Generate appropriate SQL queries
4. Execute queries safely - queries that modify the database are not allowed
5. Format and explain results clearly
6. Create visualizations when requested or when they would be helpful"""

    def _sort_tools_by_workflow(self, tools: list[Tool]) -> list[Tool]:
        """Sort tools by workflow position, then priority, then name."""
        position_order = {
            WorkflowPosition.DISCOVERY: 1,
            WorkflowPosition.ANALYSIS: 2,
            WorkflowPosition.EXECUTION: 3,
            WorkflowPosition.VISUALIZATION: 4,
            WorkflowPosition.OTHER: 5,
        }

        return sorted(
            tools,
            key=lambda tool: (
                # Unknown positions sort alongside OTHER.
                position_order.get(tool.get_workflow_position(), 5),
                tool.get_priority(),
                tool.name,
            ),
        )

    def _build_workflow_instructions(self, sorted_tools: list[Tool]) -> str:
        """Build workflow-based instructions.

        Emits one numbered step per tool, grouped in the order discovery,
        analysis, execution, visualization. Tools in any other workflow
        position are not listed here.

        Returns:
            The header line followed by the numbered steps, or an empty
            string when no tool produced a step.
        """
        # Group tools by workflow position, preserving the incoming order.
        workflow_groups: dict[WorkflowPosition, list[Tool]] = {}
        for tool in sorted_tools:
            workflow_groups.setdefault(tool.get_workflow_position(), []).append(tool)

        instructions = ["IMPORTANT - Tool Usage Strategy:"]
        step = 1

        ordered_positions = (
            WorkflowPosition.DISCOVERY,
            WorkflowPosition.ANALYSIS,
            WorkflowPosition.EXECUTION,
            WorkflowPosition.VISUALIZATION,
        )
        for position in ordered_positions:
            for tool in workflow_groups.get(position, []):
                usage = tool.get_usage_instructions()
                if usage:
                    text = usage
                elif position is WorkflowPosition.VISUALIZATION:
                    # Visualization tools get a fixed fallback phrase rather
                    # than one derived from their description.
                    text = f"Use '{tool.name}' when creating visualizations"
                else:
                    text = f"Use '{tool.name}' to {tool.description.lower()}"
                instructions.append(f"{step}. {text}")
                step += 1

        return "\n".join(instructions) if len(instructions) > 1 else ""

    def _build_tool_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build tool-specific guidelines.

        Lists each tool's usage instructions unless they look like they are
        already covered by the workflow section.
        """
        guidelines = []
        for tool in sorted_tools:
            usage = tool.get_usage_instructions()
            if usage and not self._is_usage_in_workflow(usage):
                guidelines.append(f"- {tool.name}: {usage}")

        if guidelines:
            return "Tool-Specific Guidelines:\n" + "\n".join(guidelines)
        return ""

    def _build_general_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build general usage guidelines, plus category-specific additions."""
        guidelines = [
            "Guidelines:",
            "- Use proper JOIN syntax and avoid cartesian products",
            "- Include appropriate WHERE clauses to limit results",
            "- Explain what the query does in simple terms",
            "- Handle errors gracefully and suggest fixes",
            "- Be security conscious - use parameterized queries when needed",
        ]

        # Extra guidelines depend on which tool categories are present.
        categories = {tool.category for tool in sorted_tools}

        if ToolCategory.SQL in categories:
            guidelines.extend(
                [
                    "- Timestamp columns must be converted to text when you write queries",
                    "- Use table patterns like 'sample%' or '%experiment%' to filter related tables",
                ]
            )

        if ToolCategory.VISUALIZATION in categories:
            guidelines.append(
                "- Create visualizations when they would enhance understanding of the data"
            )

        return "\n".join(guidelines)

    def _is_usage_in_workflow(self, usage: str) -> bool:
        """Check if usage instruction is already covered in workflow section."""
        # Simple heuristic - if the usage text contains a workflow phrase
        # anywhere (case-insensitive), it's probably in the workflow section.
        workflow_words = ["always start", "first", "use this", "begin with", "start by"]
        usage_lower = usage.lower()
        return any(word in usage_lower for word in workflow_words)

    @staticmethod
    def _strip_step_number(line: str) -> str | None:
        """Return the text of a numbered step line without its "N. " prefix.

        Handles step numbers of any width (e.g. "12. "), not just one digit.

        Returns:
            The step text, or None when the line is not a numbered step.
        """
        number, separator, text = line.strip().partition(". ")
        if separator and number.isdigit():
            return text
        return None

    def build_mcp_instructions(self) -> str:
        """Build instructions specifically for MCP server.

        Only tools in the SQL category are listed.
        """
        instructions = [
            "This server provides helpful resources and tools that will help you address users queries on their database.",
            "",
        ]

        # Add database discovery
        instructions.append("- Get all databases using `get_databases()`")

        # Add tool-specific instructions
        sql_tools = self.registry.get_all_tools(category=ToolCategory.SQL)
        sorted_tools = self._sort_tools_by_workflow(sql_tools)

        for tool in sorted_tools:
            instructions.append(f"- Call `{tool.name}()` to {tool.description.lower()}")

        # Add workflow guidelines
        instructions.extend(["", "Guidelines:"])

        workflow_instructions = self._build_workflow_instructions(sorted_tools)
        # Skip the "IMPORTANT" header line and turn each numbered step into a
        # bullet; lines that are not numbered steps are ignored.
        for line in workflow_instructions.split("\n")[1:]:
            step_text = self._strip_step_number(line)
            if step_text is not None:
                instructions.append(f"- {step_text}")

        # Add general guidelines
        instructions.extend(
            [
                "- Use proper JOIN syntax and avoid cartesian products",
                "- Include appropriate WHERE clauses to limit results",
                "- Handle errors gracefully and suggest fixes",
            ]
        )

        return "\n".join(instructions)
|