chuk-tool-processor 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/__init__.py +1 -0
- chuk_tool_processor/core/__init__.py +1 -0
- chuk_tool_processor/core/exceptions.py +45 -0
- chuk_tool_processor/core/processor.py +268 -0
- chuk_tool_processor/execution/__init__.py +0 -0
- chuk_tool_processor/execution/strategies/__init__.py +0 -0
- chuk_tool_processor/execution/strategies/inprocess_strategy.py +206 -0
- chuk_tool_processor/execution/strategies/subprocess_strategy.py +103 -0
- chuk_tool_processor/execution/tool_executor.py +46 -0
- chuk_tool_processor/execution/wrappers/__init__.py +0 -0
- chuk_tool_processor/execution/wrappers/caching.py +234 -0
- chuk_tool_processor/execution/wrappers/rate_limiting.py +149 -0
- chuk_tool_processor/execution/wrappers/retry.py +176 -0
- chuk_tool_processor/models/__init__.py +1 -0
- chuk_tool_processor/models/execution_strategy.py +19 -0
- chuk_tool_processor/models/tool_call.py +7 -0
- chuk_tool_processor/models/tool_result.py +49 -0
- chuk_tool_processor/plugins/__init__.py +1 -0
- chuk_tool_processor/plugins/discovery.py +205 -0
- chuk_tool_processor/plugins/parsers/__init__.py +1 -0
- chuk_tool_processor/plugins/parsers/function_call_tool.py +105 -0
- chuk_tool_processor/plugins/parsers/json_tool.py +17 -0
- chuk_tool_processor/plugins/parsers/xml_tool.py +41 -0
- chuk_tool_processor/registry/__init__.py +20 -0
- chuk_tool_processor/registry/decorators.py +42 -0
- chuk_tool_processor/registry/interface.py +79 -0
- chuk_tool_processor/registry/metadata.py +36 -0
- chuk_tool_processor/registry/provider.py +44 -0
- chuk_tool_processor/registry/providers/__init__.py +41 -0
- chuk_tool_processor/registry/providers/memory.py +165 -0
- chuk_tool_processor/utils/__init__.py +0 -0
- chuk_tool_processor/utils/logging.py +260 -0
- chuk_tool_processor/utils/validation.py +192 -0
- chuk_tool_processor-0.1.0.dist-info/METADATA +293 -0
- chuk_tool_processor-0.1.0.dist-info/RECORD +37 -0
- chuk_tool_processor-0.1.0.dist-info/WHEEL +5 -0
- chuk_tool_processor-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# chuk_tool_processor/registry/providers/memory.py
|
|
2
|
+
"""
|
|
3
|
+
In-memory implementation of the tool registry.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import inspect
|
|
7
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
from chuk_tool_processor.core.exceptions import ToolNotFoundError
|
|
10
|
+
from chuk_tool_processor.registry.interface import ToolRegistryInterface
|
|
11
|
+
from chuk_tool_processor.registry.metadata import ToolMetadata
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InMemoryToolRegistry(ToolRegistryInterface):
    """
    Process-local ``ToolRegistryInterface`` backed by plain dictionaries.

    Tools and their ``ToolMetadata`` are kept in two parallel nested dicts
    shaped ``{namespace: {name: ...}}``. Nothing is persisted or shared
    between processes, so the registry suits tests and single-process
    applications.
    """

    def __init__(self):
        """Start with no namespaces and no tools."""
        # namespace -> {tool name -> registered tool object}
        self._tools: Dict[str, Dict[str, Any]] = {}
        # namespace -> {tool name -> ToolMetadata}
        self._metadata: Dict[str, Dict[str, ToolMetadata]] = {}

    def register_tool(
        self,
        tool: Any,
        name: Optional[str] = None,
        namespace: str = "default",
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Store *tool* under (*namespace*, name) and build its metadata.

        Registering an existing key again replaces both the tool and its
        metadata.

        Args:
            tool: Tool class or instance. If its ``execute`` attribute is a
                coroutine function, the metadata is marked ``is_async``.
            name: Registration key. Defaults to ``tool.__name__``, then to
                ``repr(tool)``.
            namespace: Namespace to register in (default: "default").
            metadata: Extra metadata fields. They are applied last, so they
                override the generated ``name``/``namespace``/``is_async``/
                ``description`` values.
        """
        if namespace not in self._tools:
            self._tools[namespace] = {}
            self._metadata[namespace] = {}

        key = name or getattr(tool, "__name__", None) or repr(tool)
        # The tool is stored before its metadata is built.
        self._tools[namespace][key] = tool

        execute_attr = getattr(tool, "execute", None)
        fields: Dict[str, Any] = {
            "name": key,
            "namespace": namespace,
            "is_async": inspect.iscoroutinefunction(execute_attr),
        }

        overrides = metadata or {}
        docstring = inspect.getdoc(tool) if getattr(tool, "__doc__", None) else None
        # The docstring is only a fallback for an explicitly supplied description.
        if docstring and "description" not in overrides:
            fields["description"] = docstring
        fields.update(overrides)

        self._metadata[namespace][key] = ToolMetadata(**fields)

    def get_tool(self, name: str, namespace: str = "default") -> Optional[Any]:
        """
        Look up a tool by name within a namespace.

        Args:
            name: Registered tool name.
            namespace: Namespace to search (default: "default").

        Returns:
            The registered tool, or None when the namespace or name is unknown.
        """
        bucket = self._tools.get(namespace)
        return None if bucket is None else bucket.get(name)

    def get_tool_strict(self, name: str, namespace: str = "default") -> Any:
        """
        Look up a tool like ``get_tool`` but fail loudly when it is missing.

        Args:
            name: Registered tool name.
            namespace: Namespace to search (default: "default").

        Returns:
            The registered tool.

        Raises:
            ToolNotFoundError: If no tool is registered under that key; the
                error carries ``"<namespace>.<name>"``.
        """
        found = self.get_tool(name, namespace)
        if found is not None:
            return found
        raise ToolNotFoundError(f"{namespace}.{name}")

    def get_metadata(self, name: str, namespace: str = "default") -> Optional[ToolMetadata]:
        """
        Look up the metadata stored for a tool.

        Args:
            name: Registered tool name.
            namespace: Namespace to search (default: "default").

        Returns:
            The tool's ToolMetadata, or None when it is unknown.
        """
        bucket = self._metadata.get(namespace)
        return None if bucket is None else bucket.get(name)

    def list_tools(self, namespace: Optional[str] = None) -> List[Tuple[str, str]]:
        """
        Enumerate registered tools.

        Args:
            namespace: Restrict the listing to this namespace. A falsy value
                (None or "") lists every namespace.

        Returns:
            ``(namespace, name)`` pairs in registration order.
        """
        if namespace:
            return [(namespace, tool_name) for tool_name in self._tools.get(namespace, {})]
        return [
            (ns, tool_name)
            for ns, bucket in self._tools.items()
            for tool_name in bucket
        ]

    def list_namespaces(self) -> List[str]:
        """
        Enumerate namespaces that have had at least one registration.

        Returns:
            Namespace names in creation order.
        """
        return [ns for ns in self._tools]
|
|
File without changes
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# chuk_tool_processor/utils/logging.py
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
import time
|
|
6
|
+
import uuid
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any, Dict, Optional, Union
|
|
10
|
+
|
|
11
|
+
# Configure the package-level "chuk_tool_processor" logger. Despite the
# variable name below, this is NOT the stdlib root logger: it is the parent
# of the "chuk_tool_processor.*" loggers this module obtains via get_logger().
# These statements run at import time.
root_logger = logging.getLogger("chuk_tool_processor")
root_logger.setLevel(logging.INFO)

# Create a handler for stderr. It gets its JSON formatter and is attached to
# root_logger further down in this module.
# NOTE(review): propagation is left enabled, so if the application also
# configures the stdlib root logger these records may be emitted twice —
# confirm this is intended.
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.INFO)
|
|
18
|
+
|
|
19
|
+
# Create a formatter for structured logging
class StructuredFormatter(logging.Formatter):
    """
    Render each log record as a single-line JSON object.

    Standard fields (timestamp, level, message, logger, pid, thread, file,
    line, function) are always present. A ``traceback`` is added when
    exception info is attached. A ``record.extra`` mapping, if present, is
    merged into the top level. ``record.context`` (set when a call passes
    ``extra={"context": ...}``) is emitted under ``"context"``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """
        Format *record* as a JSON string.

        Values the ``json`` module cannot encode natively (arbitrary tool
        results, datetimes, exceptions, custom objects, ...) are rendered
        with ``str()``. Without this, a single such value makes the whole
        record fail to serialize.
        """
        log_data: Dict[str, Any] = {
            # UTC ISO-8601 with a trailing "Z" instead of "+00:00".
            "timestamp": datetime.fromtimestamp(record.created, timezone.utc)
            .isoformat().replace("+00:00", "Z"),
            "level": record.levelname,
            "message": record.getMessage(),
            "logger": record.name,
            "pid": record.process,
            "thread": record.thread,
            "file": record.filename,
            "line": record.lineno,
            "function": record.funcName,
        }

        # Add traceback if present
        if record.exc_info:
            log_data["traceback"] = self.formatException(record.exc_info)

        # Only set when a caller literally passed extra={"extra": {...}}.
        if hasattr(record, "extra"):
            log_data.update(record.extra)

        # Add structured logging context if present
        if hasattr(record, "context"):
            log_data["context"] = record.context

        # default=str: never lose a log line because of one unserializable value.
        return json.dumps(log_data, default=str)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Configure the formatter: every record reaching `handler` is rendered as one
# line of JSON.
formatter = StructuredFormatter()
handler.setFormatter(formatter)

# Add the handler to the package-level "chuk_tool_processor" logger (not the
# stdlib root logger). This is an import-time side effect: re-executing this
# module (e.g. via importlib.reload) would add a second handler to the same
# logger.
root_logger.addHandler(handler)
|
|
63
|
+
|
|
64
|
+
# Shared (NOT thread-local) context storage
class LogContext:
    """
    Mutable key/value context attached to structured log records.

    Values live in ordinary instance attributes. An instance shared between
    threads or asyncio tasks is therefore shared by all of them; this is not
    thread-local storage.

    Attributes:
        context: Current key/value pairs.
        request_id: Identifier of the active request, or None.
    """

    def __init__(self):
        # Same initial state that clear() produces.
        self.clear()

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, replacing any existing entry."""
        self.context.update({key: value})

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored for *key*, or *default* if it is absent."""
        return self.context[key] if key in self.context else default

    def update(self, values: Dict[str, Any]) -> None:
        """Merge every pair from *values* into the context."""
        self.context.update(values)

    def clear(self) -> None:
        """Forget all context values and the active request id."""
        # Rebind to a new dict rather than emptying in place, so snapshots
        # taken elsewhere are left untouched.
        self.context = {}
        self.request_id = None

    def start_request(self, request_id: Optional[str] = None) -> str:
        """
        Begin a request and record its id in the context.

        Args:
            request_id: Id to use; a falsy value generates a UUID4 string.

        Returns:
            The request id now in effect.
        """
        rid = request_id if request_id else str(uuid.uuid4())
        self.request_id = rid
        self.context["request_id"] = rid
        return rid

    def end_request(self) -> None:
        """Finish the current request. This wipes the entire context."""
        self.clear()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Create the global log context: a single module-level LogContext that
# StructuredAdapter, log_context_span and request_logging all read and mutate.
# As a plain global it is shared by every thread and asyncio task.
log_context = LogContext()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class StructuredAdapter(logging.LoggerAdapter):
    """
    Logger adapter that attaches the ambient ``log_context`` to every record.

    The current ``log_context.context`` is merged into ``extra["context"]``
    so the formatter can emit it. Fields passed explicitly by the call take
    precedence over ambient ones. The caller's ``extra`` and ``context``
    dicts are never modified.
    """

    def process(self, msg: str, kwargs: Dict[str, Any]) -> tuple:
        """
        Return *msg* unchanged and a copy of *kwargs* carrying the context.

        Args:
            msg: The log message.
            kwargs: Keyword arguments of the logging call (may contain
                ``extra``; ``extra`` may be None).

        Returns:
            ``(msg, kwargs)``. The returned ``kwargs["extra"]`` is always a
            new dict. When ambient context exists, its ``"context"`` entry
            merges ambient and per-call values; a non-dict per-call context
            value is left as-is because a mapping cannot be merged into it.
        """
        kwargs = dict(kwargs) if kwargs else {}
        # Copy so a caller that reuses its `extra` dict never sees it grow.
        extra = dict(kwargs.get("extra") or {})
        ambient = log_context.context
        if ambient:
            call_context = extra.get("context")
            if call_context is None:
                extra["context"] = dict(ambient)
            elif isinstance(call_context, dict):
                # Explicit per-call fields (e.g. "duration", "error") win.
                extra["context"] = {**ambient, **call_context}
        kwargs["extra"] = extra
        return msg, kwargs
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_logger(name: str) -> StructuredAdapter:
    """
    Wrap the stdlib logger called *name* in a ``StructuredAdapter``.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        A new adapter created with an empty ``extra`` mapping.
    """
    return StructuredAdapter(logging.getLogger(name), {})
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@contextmanager
def log_context_span(
    operation: str,
    extra: Optional[Dict[str, Any]] = None,
    log_duration: bool = True,
    level: int = logging.INFO
):
    """
    Log the start and end of *operation* and scope span fields to it.

    While the ``with`` body runs, ``log_context`` additionally holds:
    ``span_id`` (a fresh UUID4 string), ``operation``, ``start_time`` (UTC
    ISO-8601 with a ``Z`` suffix) and any *extra* fields. On exit, the
    previous context and request id are restored, so nesting inside another
    span or a request keeps the outer values intact.

    Args:
        operation: Operation name; also the logger suffix
            (``chuk_tool_processor.span.<operation>``).
        extra: Optional additional fields added to the span context.
        log_duration: If true, the completion record carries the elapsed
            seconds in ``context["duration"]``.
        level: Level of the start/completion records. Failures are always
            logged via ``logger.exception`` (ERROR level).

    Raises:
        Any ``Exception`` raised inside the span, re-raised after logging.
    """
    logger = get_logger(f"chuk_tool_processor.span.{operation}")
    start_time = time.time()
    # Monotonic clock for durations: unaffected by system clock adjustments.
    start_counter = time.monotonic()
    span_context = {
        "span_id": str(uuid.uuid4()),
        "operation": operation,
        "start_time": datetime.fromtimestamp(start_time, timezone.utc)
        .isoformat().replace("+00:00", "Z"),
    }
    if extra:
        span_context.update(extra)

    # Snapshot everything log_context.clear() resets, including the request
    # id. Otherwise a span inside request_logging would drop the request id.
    previous_context = log_context.context.copy() if log_context.context else {}
    previous_request_id = log_context.request_id
    log_context.update(span_context)
    logger.log(level, f"Starting {operation}")
    try:
        yield
        if log_duration:
            duration = time.monotonic() - start_counter
            logger.log(level, f"Completed {operation}", extra={"context": {"duration": duration}})
        else:
            logger.log(level, f"Completed {operation}")
    except Exception as e:
        duration = time.monotonic() - start_counter
        logger.exception(
            f"Error in {operation}: {str(e)}",
            extra={"context": {"duration": duration, "error": str(e)}}
        )
        raise
    finally:
        log_context.clear()
        if previous_context:
            log_context.update(previous_context)
        log_context.request_id = previous_request_id
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@contextmanager
def request_logging(request_id: Optional[str] = None):
    """
    Scope a request id to the ``with`` body and log its start and end.

    Args:
        request_id: Id to use; a falsy value makes ``log_context`` generate
            a UUID4 string.

    Yields:
        The request id in effect.

    Raises:
        Any ``Exception`` raised inside the block, re-raised after being
        logged at ERROR level with its duration.

    On exit, the context and request id that were active before the request
    are restored. At top level that means an empty context and no request id.
    """
    logger = get_logger("chuk_tool_processor.request")
    # end_request() wipes the whole context; snapshot the outer state so an
    # enclosing scope's fields survive this request.
    previous_context = log_context.context.copy()
    previous_request_id = log_context.request_id
    request_id = log_context.start_request(request_id)
    # Monotonic clock for durations: unaffected by system clock adjustments.
    start_counter = time.monotonic()
    logger.info(f"Starting request {request_id}")
    try:
        yield request_id
        duration = time.monotonic() - start_counter
        logger.info(
            f"Completed request {request_id}",
            extra={"context": {"duration": duration}}
        )
    except Exception as e:
        duration = time.monotonic() - start_counter
        logger.exception(
            f"Error in request {request_id}: {str(e)}",
            extra={"context": {"duration": duration, "error": str(e)}}
        )
        raise
    finally:
        log_context.end_request()
        log_context.update(previous_context)
        log_context.request_id = previous_request_id
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def log_tool_call(tool_call, tool_result):
    """
    Emit one structured record describing a finished tool call.

    A truthy ``tool_result.error`` produces an ERROR record. Otherwise an
    INFO record states the duration. The record context always holds tool,
    arguments, result, error, duration, machine and pid. It also holds
    ``cached``/``attempts`` when the result reports truthy values for them.

    Args:
        tool_call: Object exposing ``tool`` and ``arguments``.
        tool_result: Object exposing ``result``, ``error``, ``start_time``,
            ``end_time``, ``machine`` and ``pid``. The two times are
            presumably datetimes: their difference must offer
            ``total_seconds()``.
    """
    logger = get_logger("chuk_tool_processor.tool_call")
    elapsed = (tool_result.end_time - tool_result.start_time).total_seconds()
    ctx = dict(
        tool=tool_call.tool,
        arguments=tool_call.arguments,
        result=tool_result.result,
        error=tool_result.error,
        duration=elapsed,
        machine=tool_result.machine,
        pid=tool_result.pid,
    )
    if getattr(tool_result, "cached", None):
        ctx["cached"] = True
    attempts = getattr(tool_result, "attempts", None)
    if attempts:
        ctx["attempts"] = attempts

    if not tool_result.error:
        logger.info(
            f"Tool {tool_call.tool} succeeded in {elapsed:.3f}s",
            extra={"context": ctx}
        )
    else:
        logger.error(
            f"Tool {tool_call.tool} failed: {tool_result.error}",
            extra={"context": ctx}
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class MetricsLogger:
    """
    Small helper that emits metric events as structured INFO records.

    Every record goes through the ``chuk_tool_processor.metrics`` logger.
    Its ``context`` dict has a ``metric_type`` field that distinguishes the
    kinds of metric.
    """

    def __init__(self):
        self.logger = get_logger("chuk_tool_processor.metrics")

    def log_tool_execution(
        self,
        tool: str,
        success: bool,
        duration: float,
        error: Optional[str] = None,
        cached: bool = False,
        attempts: int = 1
    ):
        """
        Record the outcome of one tool execution.

        Args:
            tool: Tool name (also embedded in the message).
            success: Whether the execution succeeded.
            duration: Elapsed time as supplied by the caller (units are not
                enforced here; presumably seconds).
            error: Error text, if any.
            cached: Whether the result came from a cache.
            attempts: Number of attempts made.
        """
        payload = dict(
            metric_type="tool_execution",
            tool=tool,
            success=success,
            duration=duration,
            error=error,
            cached=cached,
            attempts=attempts,
        )
        self.logger.info(f"Tool execution metric: {tool}", extra={"context": payload})

    def log_parser_metric(
        self,
        parser: str,
        success: bool,
        duration: float,
        num_calls: int
    ):
        """
        Record the outcome of one parser run.

        Args:
            parser: Parser name (also embedded in the message).
            success: Whether parsing succeeded.
            duration: Elapsed time as supplied by the caller.
            num_calls: Number of tool calls the parser produced.
        """
        payload = dict(
            metric_type="parser",
            parser=parser,
            success=success,
            duration=duration,
            num_calls=num_calls,
        )
        self.logger.info(f"Parser metric: {parser}", extra={"context": payload})
|
|
259
|
+
|
|
260
|
+
# Module-level MetricsLogger instance; emits metric records on the
# "chuk_tool_processor.metrics" logger.
metrics = MetricsLogger()
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# chuk_tool_processor/utils/validation.py
|
|
2
|
+
from typing import Any, Dict, Optional, Type, get_type_hints, Union, List, Callable
|
|
3
|
+
from pydantic import BaseModel, ValidationError, create_model
|
|
4
|
+
import inspect
|
|
5
|
+
from functools import wraps
|
|
6
|
+
|
|
7
|
+
from chuk_tool_processor.core.exceptions import ToolValidationError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def validate_arguments(tool_name: str, tool_func: Callable, args: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate tool arguments against function signature.

    A pydantic model is built from the signature of *tool_func*:
    - Annotated parameters use their annotation; unannotated ones accept
      ``Any``, so they are no longer silently dropped.
    - Parameters with a default become optional.
    - A leading ``self``/``cls`` (unbound method) and ``*args``/``**kwargs``
      do not become fields.
    - When *tool_func* accepts ``**kwargs``, arguments that match no named
      parameter are passed through unvalidated.

    Args:
        tool_name: Name of the tool for error reporting.
        tool_func: Tool function to validate against.
        args: Arguments to validate.

    Returns:
        Validated (and possibly coerced) argument values, including defaults
        for omitted optional parameters. Values are the validated objects
        themselves; for example, a parameter annotated with a pydantic model
        receives a model instance, not a dict.

    Raises:
        ToolValidationError: If validation fails, or if the signature cannot
            be introspected (chained to the underlying exception).
    """
    try:
        type_hints = get_type_hints(tool_func)
        params = list(inspect.signature(tool_func).parameters.values())

        # An unbound method (e.g. ``cls.execute`` as used by with_validation)
        # lists its receiver first; callers never supply it as an argument.
        if params and params[0].name in ("self", "cls"):
            params = params[1:]

        field_definitions: Dict[str, Any] = {}
        accepts_var_kwargs = False
        for param in params:
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                accepts_var_kwargs = True
                continue
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                continue
            annotation = type_hints.get(param.name, Any)
            # Ellipsis marks a required field for pydantic's create_model.
            default = ... if param.default is inspect.Parameter.empty else param.default
            field_definitions[param.name] = (annotation, default)

        model = create_model(f"{tool_name}Args", **field_definitions)
        validated = model(**args)

        # dict(model) is a shallow field->value mapping. Unlike .dict(), it
        # does not turn nested model values into plain dicts.
        result = dict(validated)
        if accepts_var_kwargs:
            for key, value in args.items():
                result.setdefault(key, value)
        return result

    except ValidationError as e:
        raise ToolValidationError(tool_name, e.errors()) from e
    except Exception as e:
        raise ToolValidationError(tool_name, {"general": str(e)}) from e
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def validate_result(tool_name: str, tool_func: Callable, result: Any) -> Any:
    """
    Validate tool result against function return type.

    If *tool_func* has no return annotation, *result* is returned unchanged.
    Otherwise *result* is validated (and possibly coerced) as a single-field
    pydantic model whose field type is the return annotation.

    Args:
        tool_name: Name of the tool for error reporting.
        tool_func: Tool function to validate against.
        result: Result to validate.

    Returns:
        The validated (possibly coerced) result.

    Raises:
        ToolValidationError: If validation fails, or if the annotations
            cannot be resolved (chained to the underlying exception).
    """
    try:
        type_hints = get_type_hints(tool_func)
        return_type = type_hints.get('return')

        if return_type is None:
            # No return annotation at all (``-> None`` resolves to NoneType,
            # not None, so it is still validated).
            return result

        model = create_model(
            f"{tool_name}Result",
            result=(return_type, ...)
        )

        validated = model(result=result)
        return validated.result

    except ValidationError as e:
        raise ToolValidationError(tool_name, e.errors()) from e
    except Exception as e:
        raise ToolValidationError(tool_name, {"general": str(e)}) from e
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def with_validation(cls):
    """
    Class decorator to add type validation to tool classes.

    ``cls.execute`` is replaced by a wrapper that:
    1. validates the keyword arguments with ``validate_arguments``;
    2. calls the original method with the validated values;
    3. passes the outcome through ``validate_result``.

    A coroutine ``execute`` gets an ``async`` wrapper. That way the awaited
    value, not the coroutine object, is validated, and the method is still
    recognized as a coroutine function by ``inspect.iscoroutinefunction``.

    Example:
        @with_validation
        class MyTool:
            def execute(self, x: int, y: str) -> float:
                return float(x) + float(y)
    """
    original_execute = cls.execute
    # Resolved once at decoration time instead of on every call.
    tool_name = getattr(cls, "__name__", repr(cls))

    if inspect.iscoroutinefunction(original_execute):
        @wraps(original_execute)
        async def execute_with_validation(self, **kwargs):
            validated_args = validate_arguments(tool_name, original_execute, kwargs)
            result = await original_execute(self, **validated_args)
            return validate_result(tool_name, original_execute, result)
    else:
        @wraps(original_execute)
        def execute_with_validation(self, **kwargs):
            validated_args = validate_arguments(tool_name, original_execute, kwargs)
            result = original_execute(self, **validated_args)
            return validate_result(tool_name, original_execute, result)

    cls.execute = execute_with_validation
    return cls
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class ValidatedTool(BaseModel):
    """
    Base class for tools with built-in validation.

    Subclasses override the nested ``Arguments`` and ``Result`` models and
    implement ``_execute``. The public ``execute`` validates the keyword
    arguments, calls ``_execute`` and coerces its return value into
    ``Result``. Overriding ``execute`` itself would bypass validation.

    Example:
        class AddTool(ValidatedTool):
            class Arguments(BaseModel):
                x: int
                y: int

            class Result(BaseModel):
                sum: int

            def _execute(self, x: int, y: int) -> Result:
                return self.Result(sum=x + y)
    """

    class Arguments(BaseModel):
        """Base arguments model to be overridden by subclasses."""
        pass

    class Result(BaseModel):
        """Base result model to be overridden by subclasses."""
        pass

    def execute(self, **kwargs) -> Any:
        """
        Execute the tool with validated arguments.

        Handling of ``_execute``'s return value:
        - an existing ``Result`` instance is returned as-is;
        - a dict is unpacked into ``Result``;
        - anything else is passed as ``Result(value=...)``.
        NOTE(review): unless ``Result`` declares a ``value`` field, pydantic's
        default unknown-field handling presumably drops that value — confirm.

        Args:
            **kwargs: Arguments to validate against the Arguments model.

        Returns:
            A ``Result`` instance.

        Raises:
            ToolValidationError: If the arguments or the result fail
                validation (chained to the pydantic ValidationError).
        """
        try:
            validated_args = self.Arguments(**kwargs)

            result = self._execute(**validated_args.dict())

            if not isinstance(result, self.Result):
                payload = result if isinstance(result, dict) else {"value": result}
                result = self.Result(**payload)

            return result

        except ValidationError as e:
            raise ToolValidationError(self.__class__.__name__, e.errors()) from e

    def _execute(self, **kwargs) -> Any:
        """
        Implementation method to be overridden by subclasses.

        Args:
            **kwargs: Validated arguments.

        Returns:
            Result that will be validated against the Result model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement _execute")
|