naive-knowledge-base 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agents/__init__.py +7 -0
- agents/common/__init__.py +28 -0
- agents/common/base.py +187 -0
- agents/common/exceptions.py +62 -0
- agents/common/logging.py +60 -0
- agents/common/schema_converter.py +112 -0
- agents/common/tools.py +195 -0
- agents/common/utils.py +161 -0
- agents/dependency_graph/__init__.py +7 -0
- agents/dependency_graph/agent.py +205 -0
- agents/dependency_graph/model.py +38 -0
- api_models/__init__.py +3 -0
- api_models/flow_api_model.py +390 -0
- cli.py +203 -0
- main.py +68 -0
- naive_knowledge_base-0.1.0.dist-info/METADATA +215 -0
- naive_knowledge_base-0.1.0.dist-info/RECORD +24 -0
- naive_knowledge_base-0.1.0.dist-info/WHEEL +5 -0
- naive_knowledge_base-0.1.0.dist-info/entry_points.txt +2 -0
- naive_knowledge_base-0.1.0.dist-info/licenses/LICENSE +21 -0
- naive_knowledge_base-0.1.0.dist-info/top_level.txt +5 -0
- tools/__init__.py +11 -0
- tools/io.py +59 -0
- tools/tree.py +42 -0
agents/common/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Public API of this package: the exception hierarchy and logging helpers."""

# Relative import: a bare ``from exceptions import ...`` would resolve to a
# top-level ``exceptions`` module instead of the sibling ``exceptions.py``.
from .exceptions import (
    CodeRefactorError,
    ConfigurationError,
    DependencyGraphError,
    SonarQubeError,
    RefactoringError,
    FileOperationError,
    MissingDependencyError,
    TestingError,
    wrap_exceptions,
)

from .logging import configure_logging, get_logger

__all__ = [
    # Exceptions
    "CodeRefactorError",
    "ConfigurationError",
    "DependencyGraphError",
    "SonarQubeError",
    "RefactoringError",
    "FileOperationError",
    "MissingDependencyError",
    "TestingError",
    "wrap_exceptions",

    # Logging
    "configure_logging",
    "get_logger",
]
|
agents/common/base.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Dict, Any, Optional, Union
|
|
3
|
+
from smolagents import Tool, CodeAgent, ToolCallingAgent
|
|
4
|
+
|
|
5
|
+
from .logging import get_logger
|
|
6
|
+
|
|
7
|
+
# Module-level logger for this file (not referenced by any code in this module yet).
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
class BaseAgent(ABC):
    """
    Common abstract interface shared by every agent in the refactor system.

    Concrete subclasses must provide ``run`` together with the read-only
    ``name`` and ``description`` properties.
    """

    @abstractmethod
    def run(self, input_data: Any) -> Any:
        """
        Execute the agent on ``input_data``.

        Args:
            input_data: The payload the concrete agent consumes

        Returns:
            Whatever result the concrete agent produces
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the agent."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable summary of what the agent does."""
        ...
|
|
40
|
+
|
|
41
|
+
class AnalysisAgent(BaseAgent):
    """
    Abstract base for agents that inspect code and report findings.

    Subclasses implement ``analyze`` in addition to the ``BaseAgent`` interface.
    """

    @abstractmethod
    def analyze(self, file_path: str, **kwargs) -> Dict[str, Any]:
        """
        Inspect the source file located at ``file_path``.

        Args:
            file_path: Location of the file to inspect
            **kwargs: Extra, analysis-specific options

        Returns:
            A mapping holding the analysis results
        """
        ...
|
|
61
|
+
|
|
62
|
+
class RefactoringAgent(BaseAgent):
    """
    Abstract base for agents that rewrite code using analysis results.

    Subclasses implement a two-step workflow: ``generate_plan`` followed by
    ``execute_plan``.
    """

    @abstractmethod
    def generate_plan(self, file_path: str, analysis_results: Dict[str, Any]) -> str:
        """
        Produce a refactoring plan for ``file_path`` from ``analysis_results``.

        Args:
            file_path: Location of the file to be refactored
            analysis_results: Findings gathered by analysis agents

        Returns:
            The refactoring plan as text
        """
        ...

    @abstractmethod
    def execute_plan(self, file_path: str, plan: str) -> bool:
        """
        Apply ``plan`` to the file at ``file_path``.

        Args:
            file_path: Location of the file to be refactored
            plan: The refactoring plan to carry out

        Returns:
            True when the plan was applied successfully, False otherwise
        """
        ...
|
|
96
|
+
|
|
97
|
+
class SmolagentsAdapter(BaseAgent):
    """
    Presents a smolagents ``CodeAgent``/``ToolCallingAgent`` as a ``BaseAgent``.

    Every call is forwarded unchanged to the wrapped smolagents agent.
    """

    def __init__(self, agent: Union[CodeAgent, ToolCallingAgent]):
        """
        Remember the smolagents agent to delegate to.

        Args:
            agent: The smolagents agent being adapted
        """
        self._agent = agent

    def run(self, input_data: Any) -> Any:
        """
        Forward ``input_data`` to the wrapped agent's ``run``.

        Args:
            input_data: Payload handed straight to the smolagents agent

        Returns:
            Whatever the smolagents agent returns
        """
        delegate = self._agent
        return delegate.run(input_data)

    @property
    def name(self) -> str:
        """Name reported by the wrapped agent."""
        delegate = self._agent
        return delegate.name

    @property
    def description(self) -> str:
        """Description reported by the wrapped agent."""
        delegate = self._agent
        return delegate.description

    @property
    def agent(self) -> Union[CodeAgent, ToolCallingAgent]:
        """The wrapped smolagents agent itself."""
        return self._agent
|
|
139
|
+
|
|
140
|
+
class AgentFactory:
    """
    Factory for creating agents by type name.

    Agent implementations are imported lazily inside each method to avoid
    circular imports.
    """

    @staticmethod
    def create_analysis_agent(agent_type: str, **kwargs) -> AnalysisAgent:
        """
        Create an analysis agent of the specified type.

        Only the module backing the requested type is imported. A missing or
        broken implementation of one agent type therefore does not prevent
        creating another. An unknown type raises ``ValueError`` rather than an
        ``ImportError`` from an unrelated module.

        Args:
            agent_type: Type of analysis agent to create ("sonar" or "dependency_graph")
            **kwargs: Additional arguments forwarded to the agent's creation function

        Returns:
            An instance of the requested analysis agent

        Raises:
            ValueError: If the agent type is not supported
        """
        if agent_type == "sonar":
            # Imported here to avoid circular imports.
            from agents.sonar import create_sonar_agent

            return create_sonar_agent(**kwargs)

        if agent_type == "dependency_graph":
            # Imported here to avoid circular imports.
            from agents.dependency_graph import create_dependency_graph_agent

            return create_dependency_graph_agent(**kwargs)

        raise ValueError(f"Unsupported analysis agent type: {agent_type}")

    @staticmethod
    def create_refactoring_agent(**kwargs) -> RefactoringAgent:
        """
        Create a refactoring agent.

        Args:
            **kwargs: Additional arguments forwarded to the agent's creation function

        Returns:
            An instance of the refactoring agent
        """
        # Imported here to avoid circular imports.
        # NOTE(review): no ``agents/coder`` module appears in this package's
        # file listing; confirm it is provided elsewhere, otherwise this
        # raises ImportError.
        from agents.coder import create_coder_agent

        return create_coder_agent(**kwargs)
|
|
agents/common/exceptions.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from typing import Dict, Any, Optional
|
|
2
|
+
|
|
3
|
+
class CodeRefactorError(Exception):
    """
    Root of the code-refactor exception hierarchy.

    Attributes (set in ``__init__``):
        message: Human-readable description of the failure; also the
            exception's single positional argument, so ``str(err) == message``
        details: Extra structured context; an empty dict when none (or an
            empty/falsy value) was supplied
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = {} if not details else details
        super().__init__(message)
|
|
9
|
+
|
|
10
|
+
class ConfigurationError(CodeRefactorError):
    """Signals invalid or missing configuration."""
|
|
13
|
+
|
|
14
|
+
class DependencyGraphError(CodeRefactorError):
    """Signals a failure while building or processing a dependency graph."""
|
|
17
|
+
|
|
18
|
+
class SonarQubeError(CodeRefactorError):
    """Signals a failure in the SonarQube integration."""
|
|
21
|
+
|
|
22
|
+
class RefactoringError(CodeRefactorError):
    """Signals a failure during the refactoring process."""
|
|
25
|
+
|
|
26
|
+
class FileOperationError(CodeRefactorError):
    """Signals a failed file operation (e.g. missing file or denied access)."""
|
|
29
|
+
|
|
30
|
+
class MissingDependencyError(CodeRefactorError):
    """Signals that a required dependency could not be found."""
|
|
33
|
+
|
|
34
|
+
class TestingError(CodeRefactorError):
    """Signals a failure during the testing process."""
|
|
37
|
+
|
|
38
|
+
def wrap_exceptions(func):
    """
    Decorator that translates exceptions raised by ``func`` into this module's types.

    Mapping:

    * ``FileNotFoundError`` / ``PermissionError`` -> ``FileOperationError``
      (``details["path"]`` is the offending filename when the error carries
      one, otherwise the error text)
    * ``ImportError`` -> ``MissingDependencyError``
    * any ``CodeRefactorError`` -> re-raised unchanged
    * any other ``Exception`` -> ``CodeRefactorError``

    The original exception is chained as ``__cause__`` of the new one. The
    wrapper keeps ``func``'s name, docstring and other metadata.

    Example usage::

        @wrap_exceptions
        def function_that_might_fail():
            ...  # function body
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except FileNotFoundError as e:
            path = str(e.filename) if e.filename is not None else str(e)
            raise FileOperationError(f"File not found: {str(e)}", {"path": path}) from e
        except PermissionError as e:
            path = str(e.filename) if e.filename is not None else str(e)
            raise FileOperationError(f"Permission denied: {str(e)}", {"path": path}) from e
        except ImportError as e:
            raise MissingDependencyError(f"Missing dependency: {str(e)}") from e
        except Exception as e:
            # Our own exceptions already carry the right type: re-raise as-is.
            if isinstance(e, CodeRefactorError):
                raise
            raise CodeRefactorError(
                f"Unexpected error: {str(e)}", {"original_error": str(e)}
            ) from e

    return wrapper
|
agents/common/logging.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
# Configure the root logger
|
|
7
|
+
def configure_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
):
    """
    Configure the root logger with a stdout handler and an optional file handler.

    Any handlers already attached to the root logger are removed and closed
    first. Calling this repeatedly therefore neither duplicates output nor
    leaves earlier log files open.

    Args:
        log_level: The log level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            case-insensitive
        log_file: Path to the log file; if None (or empty), logs go to the
            console only. Missing parent directories are created.
        log_format: ``logging.Formatter`` format string used by every handler

    Raises:
        ValueError: If ``log_level`` is not a known level name
    """
    # Convert string log level to logging constant
    numeric_level = getattr(logging, log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {log_level}")

    root_logger = logging.getLogger()
    root_logger.setLevel(numeric_level)

    # Detach *and close* existing handlers. Merely emptying the list would
    # drop them without closing, leaking any open log files.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(log_format)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    if log_file:
        log_dir = os.path.dirname(log_file)
        if log_dir:
            # exist_ok avoids a race between an existence check and creation.
            os.makedirs(log_dir, exist_ok=True)

        # Explicit encoding so log contents do not depend on the platform locale.
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)
|
|
49
|
+
|
|
50
|
+
def get_logger(name: str) -> logging.Logger:
    """
    Return the ``logging.Logger`` registered under ``name``.

    This is a thin alias for ``logging.getLogger``, which hands back the same
    logger object for the same name.

    Args:
        name: Logger name, conventionally the caller's ``__name__``

    Returns:
        The logger instance
    """
    named_logger = logging.getLogger(name)
    return named_logger
|
|
agents/common/schema_converter.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional, Union, Callable
|
|
2
|
+
import json
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
class SchemaConverter:
    """
    Converts between different schema formats for tool functions.

    Handles conversion of parameters and return values between plain dicts,
    JSON strings and Pydantic models.
    """

    @staticmethod
    def convert_pydantic_to_dict(model_instance: BaseModel) -> Dict[str, Any]:
        """
        Convert a Pydantic model instance to a dictionary.

        Uses ``model_dump()`` when available (Pydantic v2, where ``dict()`` is
        deprecated) and falls back to ``dict()`` (Pydantic v1).

        Args:
            model_instance: The Pydantic model instance to convert

        Returns:
            The model as a dictionary
        """
        model_dump = getattr(model_instance, "model_dump", None)
        if callable(model_dump):
            return model_dump()
        return model_instance.dict()

    @staticmethod
    def convert_dict_to_pydantic(data: Dict[str, Any], model_class: type) -> BaseModel:
        """
        Convert a dictionary to a Pydantic model instance.

        Args:
            data: The dictionary data to convert
            model_class: The Pydantic model class to convert to

        Returns:
            An instance of the specified Pydantic model
        """
        return model_class(**data)

    @staticmethod
    def convert_json_to_pydantic(json_str: str, model_class: type) -> BaseModel:
        """
        Convert a JSON string to a Pydantic model instance.

        Args:
            json_str: The JSON string to convert
            model_class: The Pydantic model class to convert to

        Returns:
            An instance of the specified Pydantic model
        """
        data = json.loads(json_str)
        return SchemaConverter.convert_dict_to_pydantic(data, model_class)

    @staticmethod
    def wrap_function_for_schema_conversion(
        func: Callable,
        input_conversions: Optional[Dict[str, type]] = None,
        output_conversion: Optional[type] = None
    ) -> Callable:
        """
        Wraps a function to handle schema conversions automatically.

        Input conversion applies only to keyword arguments named in
        ``input_conversions``; positional arguments are passed through unchanged.
        A string value starting with ``'{'`` is parsed as JSON into the model; it
        is left as-is if parsing or validation fails. A dict value is validated
        into the model, and validation errors propagate.

        Output conversion, when ``output_conversion`` is set, works as follows.
        A ``BaseModel`` result is returned as a dict. A dict result is validated
        against ``output_conversion`` and returned as a dict.

        Args:
            func: The function to wrap
            input_conversions: Mapping of parameter names to Pydantic model classes for conversion
            output_conversion: Pydantic model class for converting the function output

        Returns:
            A wrapped function that handles schema conversions. It keeps
            ``func``'s name, docstring and other metadata.
        """
        import functools

        input_conversions = input_conversions or {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            converted_kwargs = kwargs.copy()

            for param_name, model_class in input_conversions.items():
                if param_name not in kwargs:
                    continue
                value = kwargs[param_name]
                if isinstance(value, str) and value.startswith('{'):
                    try:
                        converted_kwargs[param_name] = SchemaConverter.convert_json_to_pydantic(
                            value, model_class
                        )
                    except (json.JSONDecodeError, ValueError):
                        # Invalid JSON, or data failing model validation
                        # (pydantic's ValidationError subclasses ValueError):
                        # keep the original string.
                        pass
                elif isinstance(value, dict):
                    converted_kwargs[param_name] = SchemaConverter.convert_dict_to_pydantic(
                        value, model_class
                    )

            result = func(*args, **converted_kwargs)

            if output_conversion and isinstance(result, BaseModel):
                return SchemaConverter.convert_pydantic_to_dict(result)
            if output_conversion and isinstance(result, dict):
                # Round-trip through the model to enforce schema compliance.
                model_instance = SchemaConverter.convert_dict_to_pydantic(result, output_conversion)
                return SchemaConverter.convert_pydantic_to_dict(model_instance)

            return result

        return wrapper
|
agents/common/tools.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
from typing import Callable, Dict, Any, TypeVar, Protocol, Optional, List, Union
from functools import wraps
import inspect
from smolagents import tool as smolagents_tool
from pydantic import create_model, BaseModel, Field

# Relative imports. Inside this package a bare ``from logging import ...``
# resolves to the standard-library ``logging`` module, which has no
# ``get_logger``. A bare ``from exceptions import ...`` looks for a
# top-level module instead of the sibling ``exceptions.py``.
from .logging import get_logger
from .exceptions import CodeRefactorError

logger = get_logger(__name__)

T = TypeVar('T')
|
|
13
|
+
|
|
14
|
+
class Tool(Protocol):
    """Structural type for callable tools exposing a ``name`` and a ``description``."""

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the tool with arbitrary arguments."""
        ...

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        ...

    @property
    def description(self) -> str:
        """Human-readable explanation of what the tool does."""
        ...
|
|
30
|
+
|
|
31
|
+
def _extract_param_description(docstring: Optional[str], param_name: str) -> str:
    """Return the text after ``param_name:`` on the first matching docstring line, else ""."""
    if not docstring:
        return ""
    for line in docstring.split('\n'):
        stripped = line.strip()
        # Anchor at the start of the line (``name:`` or ``name (type):``) so a
        # parameter such as ``path`` does not match a ``file_path:`` line.
        if stripped.startswith(f"{param_name}:") or stripped.startswith(f"{param_name} ("):
            parts = stripped.split(':', 1)
            if len(parts) > 1:
                return parts[1].strip()
    return ""


def tool(func: Callable) -> Tool:
    """
    Decorator for tool functions.

    Wraps ``func`` with pydantic-based argument validation, debug logging and
    uniform error handling. It then copies the public attributes of the
    smolagents tool built from ``func`` onto the wrapper.

    Annotated parameters are validated, whether passed positionally or by
    keyword. Parameters with defaults are optional. Unannotated parameters
    (a warning is logged for each) and ``*args``/``**kwargs`` are passed
    through unvalidated.

    The returned wrapper raises ``CodeRefactorError`` if argument binding,
    validation or ``func`` itself fails. Other exception types are wrapped,
    with the original chained as ``__cause__``.

    Args:
        func: The function to wrap

    Returns:
        The wrapped function
    """
    sig = inspect.signature(func)

    # Build pydantic field definitions for annotated, named parameters.
    fields = {}
    for param_name, param in sig.parameters.items():
        if param.annotation is inspect.Parameter.empty:
            logger.warning(f"Parameter {param_name} in {func.__name__} missing type hint")
            continue
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            # The annotation describes individual elements, not a single field.
            continue
        # ``...`` marks a required field; otherwise keep the function's default.
        default = ... if param.default is inspect.Parameter.empty else param.default
        description = _extract_param_description(func.__doc__, param_name)
        fields[param_name] = (param.annotation, Field(default=default, description=description))

    model = create_model(f"{func.__name__.title()}Model", **fields)

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            # Map positional and keyword arguments onto parameter names so that
            # positional calls are validated too; omitted parameters keep defaults.
            bound = sig.bind(*args, **kwargs)
            supplied = {k: v for k, v in bound.arguments.items() if k in fields}
            validated = model(**supplied)
            for key in supplied:
                # getattr keeps nested models as model instances (unlike .dict()).
                bound.arguments[key] = getattr(validated, key)

            logger.debug("Executing tool %s", func.__name__)
            result = func(*bound.args, **bound.kwargs)
            logger.debug("Tool %s executed successfully", func.__name__)

            return result
        except Exception as e:
            logger.error(f"Error executing tool {func.__name__}: {str(e)}")
            if isinstance(e, CodeRefactorError):
                raise
            raise CodeRefactorError(f"Tool {func.__name__} failed: {str(e)}") from e

    # Add tool name and description. Same-named attributes of the smolagents
    # tool copied below (if it defines them) overwrite these values.
    wrapper.name = func.__name__
    wrapper.description = func.__doc__ or ""

    decorated_func = smolagents_tool(func)

    # Copy non-dunder attributes from the smolagents tool onto the wrapper.
    for attr in dir(decorated_func):
        if not attr.startswith('__'):
            setattr(wrapper, attr, getattr(decorated_func, attr))

    return wrapper
|
|
111
|
+
|
|
112
|
+
def create_tool_registry():
    """
    Build a fresh, empty tool registry.

    Returns:
        A new dict whose single ``'tools'`` key maps tool names to tools
        (initially empty)
    """
    registry = dict(tools={})
    return registry
|
|
120
|
+
|
|
121
|
+
def register_tool(registry: Dict[str, Dict[str, Tool]], tool_func: Tool) -> Tool:
    """
    Store ``tool_func`` in ``registry`` under its ``name``.

    A tool previously registered under the same name is replaced.

    Args:
        registry: Registry dict as built by ``create_tool_registry``
        tool_func: Tool to add; must expose a ``name`` attribute

    Returns:
        ``tool_func`` itself, so the call can be chained
    """
    tools_by_name = registry['tools']
    tools_by_name[tool_func.name] = tool_func
    return tool_func
|
|
134
|
+
|
|
135
|
+
def get_tool(registry: Dict[str, Dict[str, Tool]], name: str) -> Optional[Tool]:
    """
    Look up a tool in ``registry`` by name.

    Args:
        registry: Registry dict as built by ``create_tool_registry``
        name: Name the tool was registered under

    Returns:
        The matching tool, or None when no tool has that name
    """
    tools_by_name = registry['tools']
    return tools_by_name.get(name)
|
|
147
|
+
|
|
148
|
+
def get_all_tools(registry: Dict[str, Dict[str, Tool]]) -> List[Tool]:
    """
    Return every tool stored in ``registry``.

    Args:
        registry: Registry dict as built by ``create_tool_registry``

    Returns:
        A new list of the registered tools, in registration order
    """
    tools_by_name = registry['tools']
    return [*tools_by_name.values()]
|
|
159
|
+
|
|
160
|
+
# Module-level registry shared by the ``register``, ``get`` and ``get_all`` helpers in this module.
tool_registry = create_tool_registry()
|
|
162
|
+
|
|
163
|
+
def register(func: Callable) -> Tool:
    """
    Turn ``func`` into a tool via ``tool`` and add it to the global ``tool_registry``.

    Args:
        func: Plain function to wrap and register

    Returns:
        The wrapped tool that was registered
    """
    return register_tool(tool_registry, tool(func))
|
|
175
|
+
|
|
176
|
+
def get(name: str) -> Optional[Tool]:
    """
    Look up a tool in the global ``tool_registry`` by name.

    Args:
        name: Name the tool was registered under

    Returns:
        The matching tool, or None when no tool has that name
    """
    found = get_tool(tool_registry, name)
    return found
|
|
187
|
+
|
|
188
|
+
def get_all() -> List[Tool]:
    """
    Return every tool stored in the global ``tool_registry``.

    Returns:
        A new list of the registered tools
    """
    everything = get_all_tools(tool_registry)
    return everything
|