sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38) hide show
  1. sqlsaber/agents/__init__.py +2 -4
  2. sqlsaber/agents/base.py +18 -221
  3. sqlsaber/agents/mcp.py +2 -2
  4. sqlsaber/agents/pydantic_ai_agent.py +170 -0
  5. sqlsaber/cli/auth.py +146 -79
  6. sqlsaber/cli/commands.py +22 -7
  7. sqlsaber/cli/database.py +1 -1
  8. sqlsaber/cli/interactive.py +65 -30
  9. sqlsaber/cli/models.py +58 -29
  10. sqlsaber/cli/streaming.py +114 -77
  11. sqlsaber/config/api_keys.py +9 -11
  12. sqlsaber/config/providers.py +116 -0
  13. sqlsaber/config/settings.py +50 -30
  14. sqlsaber/database/connection.py +3 -3
  15. sqlsaber/mcp/mcp.py +43 -51
  16. sqlsaber/models/__init__.py +0 -3
  17. sqlsaber/tools/__init__.py +25 -0
  18. sqlsaber/tools/base.py +85 -0
  19. sqlsaber/tools/enums.py +21 -0
  20. sqlsaber/tools/instructions.py +251 -0
  21. sqlsaber/tools/registry.py +130 -0
  22. sqlsaber/tools/sql_tools.py +275 -0
  23. sqlsaber/tools/visualization_tools.py +144 -0
  24. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
  25. sqlsaber-0.16.0.dist-info/RECORD +51 -0
  26. sqlsaber/agents/anthropic.py +0 -579
  27. sqlsaber/agents/streaming.py +0 -16
  28. sqlsaber/clients/__init__.py +0 -6
  29. sqlsaber/clients/anthropic.py +0 -285
  30. sqlsaber/clients/base.py +0 -31
  31. sqlsaber/clients/exceptions.py +0 -117
  32. sqlsaber/clients/models.py +0 -282
  33. sqlsaber/clients/streaming.py +0 -257
  34. sqlsaber/models/events.py +0 -28
  35. sqlsaber-0.14.0.dist-info/RECORD +0 -51
  36. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/mcp/mcp.py CHANGED
@@ -7,25 +7,17 @@ from fastmcp import FastMCP
7
7
  from sqlsaber.agents.mcp import MCPSQLAgent
8
8
  from sqlsaber.config.database import DatabaseConfigManager
9
9
  from sqlsaber.database.connection import DatabaseConnection
10
+ from sqlsaber.tools import SQLTool, tool_registry
11
+ from sqlsaber.tools.instructions import InstructionBuilder
10
12
 
11
- INSTRUCTIONS = """
12
- This server provides helpful resources and tools that will help you address users queries on their database.
13
+ # Initialize the instruction builder
14
+ instruction_builder = InstructionBuilder(tool_registry)
13
15
 
14
- - Get all databases using `get_databases()`
15
- - Call `list_tables()` to get a list of all tables in the database with row counts. Use this first to discover available tables.
16
- - Call `introspect_schema()` to introspect database schema to understand table structures.
17
- - Call `execute_sql()` to execute SQL queries against the database and retrieve results.
16
+ # Generate dynamic instructions
17
+ DYNAMIC_INSTRUCTIONS = instruction_builder.build_mcp_instructions()
18
18
 
19
- Guidelines:
20
- - Use list_tables first, then introspect_schema for specific tables only
21
- - Use table patterns like 'sample%' or '%experiment%' to filter related tables
22
- - Use proper JOIN syntax and avoid cartesian products
23
- - Include appropriate WHERE clauses to limit results
24
- - Handle errors gracefully and suggest fixes
25
- """
26
-
27
- # Create the FastMCP server instance
28
- mcp = FastMCP(name="SQL Assistant", instructions=INSTRUCTIONS)
19
+ # Create the FastMCP server instance with dynamic instructions
20
+ mcp = FastMCP(name="SQL Assistant", instructions=DYNAMIC_INSTRUCTIONS)
29
21
 
30
22
  # Initialize the database config manager
31
23
  config_manager = DatabaseConfigManager()
@@ -70,10 +62,16 @@ def get_databases() -> dict:
70
62
  return {"databases": databases, "count": len(databases)}
71
63
 
72
64
 
73
- @mcp.tool
74
- async def list_tables(database: str) -> str:
75
- """
76
- Get a list of all tables in the database with row counts. Use this first to discover available tables.
65
+ async def _execute_with_connection(tool_name: str, database: str, **kwargs) -> str:
66
+ """Execute a SQL tool with database connection management.
67
+
68
+ Args:
69
+ tool_name: Name of the tool to execute
70
+ database: Database name to connect to
71
+ **kwargs: Tool-specific parameters
72
+
73
+ Returns:
74
+ JSON string with the tool's output
77
75
  """
78
76
  try:
79
77
  agent = await _create_agent_for_database(database)
@@ -82,50 +80,44 @@ async def list_tables(database: str) -> str:
82
80
  {"error": f"Database '{database}' not found or could not connect"}
83
81
  )
84
82
 
85
- result = await agent.list_tables()
83
+ # Get the tool and set up connection
84
+ tool = tool_registry.get_tool(tool_name)
85
+ if isinstance(tool, SQLTool):
86
+ tool.set_connection(agent.db)
87
+
88
+ # Execute the tool
89
+ result = await tool.execute(**kwargs)
86
90
  await agent.db.close()
87
91
  return result
88
92
 
89
93
  except Exception as e:
90
- return json.dumps({"error": f"Error listing tables: {str(e)}"})
94
+ return json.dumps({"error": f"Error in {tool_name}: {str(e)}"})
91
95
 
92
96
 
93
- @mcp.tool
94
- async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
95
- """
96
- Introspect database schema to understand table structures. Use optional pattern to filter tables (e.g., 'public.users', 'user%', '%order%').
97
- """
98
- try:
99
- agent = await _create_agent_for_database(database)
100
- if not agent:
101
- return json.dumps(
102
- {"error": f"Database '{database}' not found or could not connect"}
103
- )
97
+ # SQL Tool Wrappers with explicit signatures
104
98
 
105
- result = await agent.introspect_schema(table_pattern)
106
- await agent.db.close()
107
- return result
108
99
 
109
- except Exception as e:
110
- return json.dumps({"error": f"Error introspecting schema: {str(e)}"})
100
@mcp.tool
async def list_tables(database: str) -> str:
    """Get a list of all tables in the database with row counts. Use this first to discover available tables."""
    # The shared runner resolves the registered "list_tables" tool for the
    # named database; this tool takes no extra keyword arguments.
    tables_json = await _execute_with_connection("list_tables", database)
    return tables_json
111
104
 
112
105
 
113
106
  @mcp.tool
114
- async def execute_sql(database: str, query: str, limit: int | None = 100) -> str:
115
- """Execute a SQL query against the specified database."""
116
- try:
117
- agent = await _create_agent_for_database(database)
118
- if not agent:
119
- return json.dumps(
120
- {"error": f"Database '{database}' not found or could not connect"}
121
- )
107
async def introspect_schema(database: str, table_pattern: str | None = None) -> str:
    """Introspect database schema to understand table structures."""
    # Parameters:
    #   database: name of the configured database to run the tool against.
    #   table_pattern: optional pattern forwarded to the "introspect_schema"
    #       tool. It is annotated `str | None` because its default is None;
    #       a bare `str` annotation would contradict that default.
    # Returns the JSON string produced by `_execute_with_connection`.
    #
    # Forward the pattern only when the caller supplied one, so the tool's
    # own default handling applies otherwise.
    tool_kwargs = {} if table_pattern is None else {"table_pattern": table_pattern}
    return await _execute_with_connection("introspect_schema", database, **tool_kwargs)
122
113
 
123
- result = await agent.execute_sql(query, limit)
124
- await agent.db.close()
125
- return result
126
114
 
127
- except Exception as e:
128
- return json.dumps({"error": f"Error executing SQL: {str(e)}"})
115
@mcp.tool
async def execute_sql(database: str, query: str, limit: int = 100) -> str:
    """Execute a SQL query against the database."""
    # Both arguments are always forwarded to the registered "execute_sql" tool.
    tool_args = {"query": query, "limit": limit}
    return await _execute_with_connection("execute_sql", database, **tool_args)
129
121
 
130
122
 
131
123
  def main():
@@ -1,11 +1,8 @@
1
1
  """Models module for SQLSaber."""
2
2
 
3
- from .events import StreamEvent, SQLResponse
4
3
  from .types import ColumnInfo, ForeignKeyInfo, SchemaInfo, ToolDefinition
5
4
 
6
5
  __all__ = [
7
- "StreamEvent",
8
- "SQLResponse",
9
6
  "ColumnInfo",
10
7
  "ForeignKeyInfo",
11
8
  "SchemaInfo",
@@ -0,0 +1,25 @@
1
+ """SQLSaber tools module."""
2
+
3
+ from .base import Tool
4
+ from .enums import ToolCategory, WorkflowPosition
5
+ from .instructions import InstructionBuilder
6
+ from .registry import ToolRegistry, register_tool, tool_registry
7
+
8
+ # Import concrete tools to register them
9
+ from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
10
+ from .visualization_tools import PlotDataTool
11
+
12
+ __all__ = [
13
+ "Tool",
14
+ "ToolCategory",
15
+ "WorkflowPosition",
16
+ "ToolRegistry",
17
+ "tool_registry",
18
+ "register_tool",
19
+ "InstructionBuilder",
20
+ "SQLTool",
21
+ "ListTablesTool",
22
+ "IntrospectSchemaTool",
23
+ "ExecuteSQLTool",
24
+ "PlotDataTool",
25
+ ]
sqlsaber/tools/base.py ADDED
@@ -0,0 +1,85 @@
1
+ """Base class for SQLSaber tools."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from types import SimpleNamespace
5
+ from typing import Any
6
+
7
+ from .enums import ToolCategory, WorkflowPosition
8
+
9
+
10
class Tool(ABC):
    """Abstract interface that every SQLSaber tool implements.

    Subclasses must provide ``name``, ``description``, ``input_schema`` and
    ``execute``. The remaining members are overridable hooks with defaults.
    """

    def __init__(self):
        """Set up the tool; the base class keeps no state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Name identifying the tool."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Description of the tool."""

    @property
    @abstractmethod
    def input_schema(self) -> dict[str, Any]:
        """Schema describing the inputs the tool accepts."""

    @abstractmethod
    async def execute(self, **kwargs) -> str:
        """Run the tool.

        Args:
            **kwargs: Tool-specific keyword arguments

        Returns:
            JSON string with the tool's output
        """

    def to_definition(self):
        """Describe this tool as an attribute-style definition object.

        A ``SimpleNamespace`` is returned so that callers can use attribute
        access such as ``definition.name``.
        """
        exported = ("name", "description", "input_schema")
        return SimpleNamespace(**{attr: getattr(self, attr) for attr in exported})

    @property
    def category(self) -> ToolCategory:
        """Category this tool belongs to; GENERAL unless overridden."""
        return ToolCategory.GENERAL

    def get_usage_instructions(self) -> str | None:
        """Provide usage guidance for the LLM.

        Returns:
            Usage instructions string, or None when the tool has none
        """
        return None

    def get_priority(self) -> int:
        """Provide the ordering priority used when building instructions.

        Returns:
            Priority number (lower = higher priority); 100 by default
        """
        return 100

    def get_workflow_position(self) -> WorkflowPosition:
        """Provide the workflow stage this tool typically belongs to.

        Returns:
            WorkflowPosition enum value; OTHER by default
        """
        return WorkflowPosition.OTHER
@@ -0,0 +1,21 @@
1
+ """Enums for tool categories and workflow positions."""
2
+
3
+ from enum import Enum
4
+
5
+
6
class ToolCategory(Enum):
    """Categories used to group tools and to filter them."""

    # Member values are the plain strings accepted by ToolCategory("<value>").
    GENERAL = "general"
    SQL = "sql"
    VISUALIZATION = "visualization"
13
+
14
class WorkflowPosition(Enum):
    """Stages of a typical workflow, used to order tools by when they are used."""

    # Members are declared in workflow order, with OTHER last.
    DISCOVERY = "discovery"
    ANALYSIS = "analysis"
    EXECUTION = "execution"
    VISUALIZATION = "visualization"
    OTHER = "other"
@@ -0,0 +1,251 @@
1
+ """Dynamic instruction builder for tools."""
2
+
3
+ from .base import Tool
4
+ from .enums import ToolCategory, WorkflowPosition
5
+ from .registry import ToolRegistry
6
+
7
+
8
class InstructionBuilder:
    """Builds dynamic LLM instructions based on the tools in a registry."""

    # Substrings that mark a usage instruction as workflow guidance.
    _WORKFLOW_MARKERS = (
        "always start",
        "first",
        "use this",
        "begin with",
        "start by",
    )

    def __init__(self, tool_registry: ToolRegistry):
        """Initialize with a tool registry.

        Args:
            tool_registry: Registry queried for the available tools
        """
        self.registry = tool_registry

    def build_instructions(
        self,
        db_type: str = "database",
        category: str | ToolCategory | None = None,
        include_base_instructions: bool = True,
    ) -> str:
        """Build dynamic instructions from available tools.

        Args:
            db_type: Type of database (PostgreSQL, MySQL, SQLite, etc.)
            category: Optional category to filter tools by (string or ToolCategory enum)
            include_base_instructions: Whether to include base SQL assistant instructions

        Returns:
            Complete instruction string for LLM. When no tools match, only the
            base instructions are returned, regardless of
            ``include_base_instructions``.
        """
        tools = self.registry.get_all_tools(category)
        if not tools:
            return self._get_base_instructions(db_type)

        sorted_tools = self._sort_tools_by_workflow(tools)

        sections = []
        if include_base_instructions:
            sections.append(self._get_base_instructions(db_type))

        # Workflow steps, then tool-specific notes, then general guidelines;
        # empty sections are left out so no blank paragraphs appear.
        for section in (
            self._build_workflow_instructions(sorted_tools),
            self._build_tool_guidelines(sorted_tools),
            self._build_general_guidelines(sorted_tools),
        ):
            if section:
                sections.append(section)

        return "\n\n".join(sections)

    def _get_base_instructions(self, db_type: str) -> str:
        """Get base SQL assistant instructions for the given database type."""
        return "\n".join(
            [
                f"You are also a helpful SQL assistant that helps users query their {db_type} database.",
                "",
                "Your responsibilities:",
                "1. Understand user's natural language requests, think and convert them to SQL",
                "2. Use the provided tools efficiently to explore database schema",
                "3. Generate appropriate SQL queries",
                "4. Execute queries safely - queries that modify the database are not allowed",
                "5. Format and explain results clearly",
                "6. Create visualizations when requested or when they would be helpful",
            ]
        )

    def _sort_tools_by_workflow(self, tools: list[Tool]) -> list[Tool]:
        """Sort tools by workflow position, then priority, then name."""
        position_order = {
            WorkflowPosition.DISCOVERY: 1,
            WorkflowPosition.ANALYSIS: 2,
            WorkflowPosition.EXECUTION: 3,
            WorkflowPosition.VISUALIZATION: 4,
            WorkflowPosition.OTHER: 5,
        }

        def sort_key(tool: Tool):
            # Unknown positions sort together with OTHER (rank 5).
            rank = position_order.get(tool.get_workflow_position(), 5)
            return (rank, tool.get_priority(), tool.name)

        return sorted(tools, key=sort_key)

    def _workflow_step_text(self, tool: Tool, position: WorkflowPosition) -> str:
        """Return the text of one workflow step for ``tool``."""
        usage = tool.get_usage_instructions()
        if usage:
            return usage
        if position == WorkflowPosition.VISUALIZATION:
            return f"Use '{tool.name}' when creating visualizations"
        return f"Use '{tool.name}' to {tool.description.lower()}"

    def _build_workflow_instructions(self, sorted_tools: list[Tool]) -> str:
        """Build the numbered "Tool Usage Strategy" section.

        Tools are listed by workflow position (discovery, analysis, execution,
        visualization), keeping their incoming order within each position.
        Tools in any other position are not listed.

        Returns:
            Header line followed by numbered steps, or "" when there are no steps
        """
        # Group once so each tool's position is queried a single time.
        workflow_groups: dict = {}
        for tool in sorted_tools:
            workflow_groups.setdefault(tool.get_workflow_position(), []).append(tool)

        listed_positions = (
            WorkflowPosition.DISCOVERY,
            WorkflowPosition.ANALYSIS,
            WorkflowPosition.EXECUTION,
            WorkflowPosition.VISUALIZATION,
        )

        instructions = ["IMPORTANT - Tool Usage Strategy:"]
        step = 1
        for position in listed_positions:
            for tool in workflow_groups.get(position, []):
                text = self._workflow_step_text(tool, position)
                instructions.append(f"{step}. {text}")
                step += 1

        return "\n".join(instructions) if len(instructions) > 1 else ""

    def _build_tool_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build guidelines from usage instructions not shown as workflow steps."""
        guidelines = []
        for tool in sorted_tools:
            usage = tool.get_usage_instructions()
            if usage and not self._is_usage_in_workflow(usage):
                guidelines.append(f"- {tool.name}: {usage}")

        if guidelines:
            return "Tool-Specific Guidelines:\n" + "\n".join(guidelines)
        return ""

    def _build_general_guidelines(self, sorted_tools: list[Tool]) -> str:
        """Build general guidelines plus lines for the categories present."""
        guidelines = [
            "Guidelines:",
            "- Use proper JOIN syntax and avoid cartesian products",
            "- Include appropriate WHERE clauses to limit results",
            "- Explain what the query does in simple terms",
            "- Handle errors gracefully and suggest fixes",
            "- Be security conscious - use parameterized queries when needed",
        ]

        categories = {tool.category for tool in sorted_tools}

        if ToolCategory.SQL in categories:
            guidelines.extend(
                [
                    "- Timestamp columns must be converted to text when you write queries",
                    "- Use table patterns like 'sample%' or '%experiment%' to filter related tables",
                ]
            )

        if ToolCategory.VISUALIZATION in categories:
            guidelines.append(
                "- Create visualizations when they would enhance understanding of the data"
            )

        return "\n".join(guidelines)

    def _is_usage_in_workflow(self, usage: str) -> bool:
        """Check if a usage instruction is probably covered in the workflow section.

        Simple heuristic: true when the lower-cased text contains any of the
        workflow marker phrases anywhere, not only at the start.
        """
        usage_lower = usage.lower()
        return any(marker in usage_lower for marker in self._WORKFLOW_MARKERS)

    def build_mcp_instructions(self) -> str:
        """Build instructions specifically for the MCP server.

        Returns:
            Intro line, one bullet per SQL-category tool, then a "Guidelines:"
            list made of the workflow steps (as bullets) and fixed general
            guidelines
        """
        instructions = [
            "This server provides helpful resources and tools that will help you address users queries on their database.",
            "",
            "- Get all databases using `get_databases()`",
        ]

        sql_tools = self.registry.get_all_tools(category=ToolCategory.SQL)
        sorted_tools = self._sort_tools_by_workflow(sql_tools)

        for tool in sorted_tools:
            instructions.append(f"- Call `{tool.name}()` to {tool.description.lower()}")

        instructions.extend(["", "Guidelines:"])

        workflow_instructions = self._build_workflow_instructions(sorted_tools)
        if workflow_instructions:
            # Skip the "IMPORTANT" header and turn each "N. text" step into a
            # bullet. Splitting on the first ". " instead of slicing a fixed
            # three characters keeps steps numbered 10 and above intact.
            # Lines that are not numbered steps are skipped.
            for line in workflow_instructions.split("\n")[1:]:
                number, separator, text = line.strip().partition(". ")
                if separator and number.isdigit():
                    instructions.append(f"- {text}")

        instructions.extend(
            [
                "- Use proper JOIN syntax and avoid cartesian products",
                "- Include appropriate WHERE clauses to limit results",
                "- Handle errors gracefully and suggest fixes",
            ]
        )

        return "\n".join(instructions)
@@ -0,0 +1,130 @@
1
+ """Tool registry for managing available tools."""
2
+
3
+ from typing import Type
4
+
5
+ from .base import Tool
6
+ from .enums import ToolCategory
7
+
8
+
9
class ToolRegistry:
    """Tracks tool classes by name and hands out one shared instance per tool."""

    def __init__(self):
        """Start with no registered classes and no cached instances."""
        self._tools: dict[str, Type[Tool]] = {}
        self._instances: dict[str, Tool] = {}

    def register(self, tool_class: Type[Tool]) -> None:
        """Add a tool class under the name reported by one of its instances.

        Args:
            tool_class: The tool class to register

        Raises:
            ValueError: If a tool with the same name is already registered
        """
        # A throwaway instance is needed because the name is an instance property.
        tool_name = tool_class().name
        if tool_name in self._tools:
            raise ValueError(f"Tool '{tool_name}' is already registered")
        self._tools[tool_name] = tool_class

    def unregister(self, name: str) -> None:
        """Remove a tool class and any cached instance; unknown names are ignored.

        Args:
            name: Name of the tool to unregister
        """
        self._tools.pop(name, None)
        self._instances.pop(name, None)

    def get_tool(self, name: str) -> Tool:
        """Fetch the shared instance of a tool, creating it on first use.

        Args:
            name: Name of the tool

        Returns:
            Tool instance

        Raises:
            KeyError: If tool is not found
        """
        if name not in self._tools:
            raise KeyError(f"Tool '{name}' not found in registry")

        instance = self._instances.get(name)
        if instance is None:
            # First request for this tool: build it and remember it.
            instance = self._tools[name]()
            self._instances[name] = instance
        return instance

    def list_tools(self, category: str | ToolCategory | None = None) -> list[str]:
        """Give the names of registered tools, optionally limited to one category.

        Args:
            category: Optional category to filter by (string or ToolCategory enum)

        Returns:
            List of tool names; empty when a category string matches no enum value
        """
        if category is None:
            return list(self._tools)

        if isinstance(category, str):
            try:
                category = ToolCategory(category)
            except ValueError:
                return []

        return [
            tool_name
            for tool_name in self._tools
            if self.get_tool(tool_name).category == category
        ]

    def get_all_tools(self, category: str | ToolCategory | None = None) -> list[Tool]:
        """Give the shared instances of registered tools.

        Args:
            category: Optional category to filter by (string or ToolCategory enum)

        Returns:
            List of tool instances
        """
        return [self.get_tool(tool_name) for tool_name in self.list_tools(category)]

    def get_tool_definitions(self, category: str | ToolCategory | None = None) -> list:
        """Give the definition object of each registered tool.

        Args:
            category: Optional category to filter by (string or ToolCategory enum)

        Returns:
            List of objects produced by each tool's ``to_definition()``
        """
        return [tool.to_definition() for tool in self.get_all_tools(category)]
+
116
+
117
+ # Global registry instance
118
+ tool_registry = ToolRegistry()
119
+
120
+
121
def register_tool(tool_class: Type[Tool]) -> Type[Tool]:
    """Class decorator that adds a tool class to the global registry.

    The class is returned unchanged, so the decorated name still refers to it.

    Usage:
        @register_tool
        class MyTool(Tool):
            ...
    """
    registered_class = tool_class
    tool_registry.register(registered_class)
    return registered_class