toolproxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolproxy/__init__.py +90 -0
- toolproxy/agent.py +184 -0
- toolproxy/config.py +98 -0
- toolproxy/exceptions.py +58 -0
- toolproxy/executor.py +123 -0
- toolproxy/llm_client.py +409 -0
- toolproxy/loop.py +180 -0
- toolproxy/planner.py +241 -0
- toolproxy/py.typed +0 -0
- toolproxy/schemas.py +125 -0
- toolproxy/tools.py +221 -0
- toolproxy-0.1.0.dist-info/METADATA +243 -0
- toolproxy-0.1.0.dist-info/RECORD +15 -0
- toolproxy-0.1.0.dist-info/WHEEL +4 -0
- toolproxy-0.1.0.dist-info/licenses/LICENSE +21 -0
toolproxy/__init__.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""
|
|
2
|
+
toolproxy — Universal Tool-Calling Wrapper for Non-Tool-Native LLMs.
|
|
3
|
+
|
|
4
|
+
Public API:
|
|
5
|
+
UniversalAgent — main agent class
|
|
6
|
+
tool — decorator to mark callables as tools
|
|
7
|
+
ToolRegistry — registry for managing tools
|
|
8
|
+
AgentResponse — returned by UniversalAgent.run()
|
|
9
|
+
AgentTrace — full debug trace
|
|
10
|
+
Message — conversation message
|
|
11
|
+
ToolCall — a single tool invocation
|
|
12
|
+
ToolResult — outcome of a tool execution
|
|
13
|
+
Action — emulated-mode structured output
|
|
14
|
+
|
|
15
|
+
Exceptions:
|
|
16
|
+
UniversalAgentError
|
|
17
|
+
ToolNotFoundError
|
|
18
|
+
ToolExecutionError
|
|
19
|
+
MaxStepsExceededError
|
|
20
|
+
SchemaValidationError
|
|
21
|
+
ModelError
|
|
22
|
+
ExecutionPolicyError
|
|
23
|
+
"""
|
|
24
|
+
from .agent import UniversalAgent
|
|
25
|
+
from .config import AgentConfig, ExecutionPolicy
|
|
26
|
+
from .exceptions import (
|
|
27
|
+
ExecutionPolicyError,
|
|
28
|
+
MaxStepsExceededError,
|
|
29
|
+
ModelError,
|
|
30
|
+
SchemaValidationError,
|
|
31
|
+
ToolExecutionError,
|
|
32
|
+
ToolNotFoundError,
|
|
33
|
+
UniversalAgentError,
|
|
34
|
+
)
|
|
35
|
+
from .llm_client import (
|
|
36
|
+
LLMClient,
|
|
37
|
+
MockClient,
|
|
38
|
+
ModelResponse,
|
|
39
|
+
OllamaClient,
|
|
40
|
+
OpenAIClient,
|
|
41
|
+
OpenRouterClient,
|
|
42
|
+
get_client,
|
|
43
|
+
)
|
|
44
|
+
from .schemas import (
|
|
45
|
+
Action,
|
|
46
|
+
AgentResponse,
|
|
47
|
+
AgentTrace,
|
|
48
|
+
Message,
|
|
49
|
+
ToolCall,
|
|
50
|
+
ToolResult,
|
|
51
|
+
TraceEntry,
|
|
52
|
+
)
|
|
53
|
+
from .tools import ToolDefinition, ToolRegistry, tool
|
|
54
|
+
|
|
55
|
+
# Names exported by ``from toolproxy import *``. Every entry is imported by the
# ``from .… import`` statements above, grouped here by area.
__all__ = [
    # Main API
    "UniversalAgent",
    "tool",
    "ToolRegistry",
    "ToolDefinition",
    # Config
    "AgentConfig",
    "ExecutionPolicy",
    # Schemas
    "Action",
    "AgentResponse",
    "AgentTrace",
    "Message",
    "ToolCall",
    "ToolResult",
    "TraceEntry",
    # Clients
    "LLMClient",
    "ModelResponse",
    "OpenRouterClient",
    "OpenAIClient",
    "OllamaClient",
    "MockClient",
    "get_client",
    # Exceptions
    "UniversalAgentError",
    "ToolNotFoundError",
    "ToolExecutionError",
    "MaxStepsExceededError",
    "SchemaValidationError",
    "ModelError",
    "ExecutionPolicyError",
]

# Package version string; matches the 0.1.0 wheel this module ships in.
__version__ = "0.1.0"
|
toolproxy/agent.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UniversalAgent — the main developer-facing API for toolproxy.
|
|
3
|
+
|
|
4
|
+
Usage::
|
|
5
|
+
|
|
6
|
+
from toolproxy import UniversalAgent, tool
|
|
7
|
+
|
|
8
|
+
@tool
|
|
9
|
+
def get_weather(city: str) -> str:
|
|
10
|
+
\"\"\"Get the current weather for a city.\"\"\"
|
|
11
|
+
return f"Sunny, 25°C in {city}"
|
|
12
|
+
|
|
13
|
+
agent = UniversalAgent(
|
|
14
|
+
model="openrouter/mistralai/mistral-7b-instruct",
|
|
15
|
+
tools=[get_weather],
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
result = agent.run("What is the weather in Chennai?")
|
|
19
|
+
print(result.content)
|
|
20
|
+
|
|
21
|
+
# With trace
|
|
22
|
+
result = agent.run("...", return_trace=True)
|
|
23
|
+
for call in result.trace.tool_calls:
|
|
24
|
+
print(call.tool_name, call.arguments)
|
|
25
|
+
"""
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import os
|
|
29
|
+
from typing import Any, Callable, List, Literal, Optional, Union
|
|
30
|
+
|
|
31
|
+
from .config import AgentConfig, ExecutionPolicy
|
|
32
|
+
from .executor import Executor
|
|
33
|
+
from .llm_client import LLMClient, get_client
|
|
34
|
+
from .loop import LoopController
|
|
35
|
+
from .planner import Planner
|
|
36
|
+
from .schemas import AgentResponse, Message
|
|
37
|
+
from .tools import ToolRegistry
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class UniversalAgent:
    """
    Universal tool-calling wrapper that works with any LLM provider.

    Automatically detects whether the underlying model supports native tool
    calling and falls back to structured-output emulation when it does not.

    Parameters
    ----------
    model : str
        Model identifier. Prefix with:
        - ``openrouter/`` for OpenRouter models.
        - ``ollama/`` for local Ollama models.
        - ``mock/`` for unit tests.
        - No prefix for direct OpenAI-compatible endpoints.
    tools : list
        List of @tool-decorated callables to register.
    mode : "auto" | "native_only" | "emulated_only"
        Override capability detection. Defaults to "auto".
    max_steps : int
        Maximum loop iterations before raising MaxStepsExceededError.
        Must be >= 1 (validated by AgentConfig).
    api_key : str, optional
        API key. Falls back to OPENROUTER_API_KEY / OPENAI_API_KEY env vars.
    base_url : str, optional
        Override the provider base URL.
    execution_policy : ExecutionPolicy, optional
        Controls which tools may be called. Defaults to allow_all.
    client : LLMClient, optional
        Inject a pre-built LLMClient (useful for testing). When given, no
        client is built from ``model`` / ``api_key`` / ``base_url``.
    confirm_callback : callable, optional
        ``(tool_name, arguments) -> bool`` consulted by the executor for tools
        listed in ``execution_policy.confirm_tools`` when the policy mode is
        ``confirm_before``. Return True to allow the call, False to block it.
        Without a callback those tools are always blocked.
    """

    def __init__(
        self,
        model: str,
        tools: Optional[List[Any]] = None,
        mode: Literal["auto", "native_only", "emulated_only"] = "auto",
        max_steps: int = 10,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        execution_policy: Optional[ExecutionPolicy] = None,
        client: Optional[LLMClient] = None,
        confirm_callback: Optional[Callable[[str, dict], bool]] = None,
    ) -> None:
        # Build the validated config first so bad settings (e.g. max_steps < 1)
        # fail before any client is constructed.
        self._config = AgentConfig(
            model=model,
            api_key=api_key,
            base_url=base_url,
            mode=mode,
            max_steps=max_steps,
            execution_policy=execution_policy or ExecutionPolicy(),
        )

        # Build or use provided LLM client. Compare against None explicitly so
        # an injected client is never discarded just because it is falsy.
        self._client: LLMClient = (
            client
            if client is not None
            else get_client(model=model, api_key=api_key, base_url=base_url)
        )

        # Build tool registry
        self._registry = ToolRegistry()
        for t in (tools or []):
            self._registry.register(t)

        # Build planner, executor, loop controller
        self._planner = Planner(
            client=self._client,
            registry=self._registry,
            mode=mode,
            parse_retries=self._config.parse_retries,
        )
        self._executor = Executor(
            registry=self._registry,
            policy=execution_policy,
            # Without this, confirm_before policies could only ever block.
            confirm_callback=confirm_callback,
        )
        self._loop = LoopController(
            planner=self._planner,
            executor=self._executor,
            max_steps=max_steps,
        )

    # ------------------------------------------------------------------
    # Main API
    # ------------------------------------------------------------------

    def run(
        self,
        prompt_or_messages: Union[str, List[Message]],
        return_trace: bool = False,
        on_tool_call: Optional[Callable] = None,
        on_tool_result: Optional[Callable] = None,
        on_model_output: Optional[Callable] = None,
    ) -> AgentResponse:
        """
        Run the agent on *prompt_or_messages* and return an AgentResponse.

        Parameters
        ----------
        prompt_or_messages :
            Either a plain string (treated as a user message) or a list of
            Message objects representing an existing conversation. The list
            is copied, so the caller's list is not modified by the loop.
        return_trace :
            If True, the returned AgentResponse will include a full trace
            of all tool calls and results.
        on_tool_call / on_tool_result / on_model_output :
            Optional streaming callbacks, forwarded to the loop controller.
        """
        if isinstance(prompt_or_messages, str):
            messages: List[Message] = [Message(role="user", content=prompt_or_messages)]
        else:
            messages = list(prompt_or_messages)

        # Update loop controller with current return_trace setting.
        # NOTE(review): this mutates state shared by every run() on this agent,
        # so concurrent run() calls on one instance may race on it.
        self._loop._return_trace = return_trace

        return self._loop.run(
            messages=messages,
            on_tool_call=on_tool_call,
            on_tool_result=on_tool_result,
            on_model_output=on_model_output,
        )

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------

    @property
    def model(self) -> str:
        """The model identifier this agent was configured with."""
        return self._config.model

    @property
    def uses_native_tools(self) -> bool:
        """True if the planner is running in native tool-calling mode."""
        return self._planner.use_native

    @property
    def registry(self) -> ToolRegistry:
        """Access the internal ToolRegistry (read-only by convention)."""
        return self._registry

    def __repr__(self) -> str:
        mode_str = "native" if self.uses_native_tools else "emulated"
        return (
            f"UniversalAgent(model={self.model!r}, "
            f"tools={len(self._registry)}, mode={mode_str})"
        )
|
toolproxy/config.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration for toolproxy.
|
|
3
|
+
|
|
4
|
+
Includes AgentConfig and the MODEL_TOOL_SUPPORT capability map.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Literal
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# Execution policy constants
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
# Values accepted by ExecutionPolicy.mode; the Executor compares against these.
POLICY_ALLOW_ALL = "allow_all"  # every tool may run
POLICY_ALLOW_ONLY = "allow_only"  # only names in ExecutionPolicy.allowed_tools may run
POLICY_CONFIRM_BEFORE = "confirm_before"  # names in confirm_tools need confirmation first
|
|
18
|
+
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
# Known model → native tool-calling support map
|
|
21
|
+
# Models not listed here default to False (emulated mode).
|
|
22
|
+
# Keys use lowercase and may include partial prefixes.
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
MODEL_TOOL_SUPPORT: dict[str, bool] = {
    # OpenAI
    "gpt-4o": True,
    "gpt-4o-mini": True,
    "gpt-4-turbo": True,
    "gpt-4": True,
    "gpt-3.5-turbo": True,
    # Anthropic (via OpenRouter or direct)
    "claude-3-5-sonnet": True,
    "claude-3-5-haiku": True,
    "claude-3-opus": True,
    "claude-3-sonnet": True,
    "claude-3-haiku": True,
    # Google (via OpenRouter)
    "gemini-pro": True,
    "gemini-1.5-pro": True,
    "gemini-1.5-flash": True,
    # Meta / Llama (typically no native tool calling through OpenRouter free tier)
    "llama-3": False,
    "llama-2": False,
    "mistral": False,
    "mixtral": False,
    # Qwen / DeepSeek free tier
    "deepseek": False,
    "qwen": False,
}


def model_supports_native_tools(model: str) -> bool:
    """
    Check whether *model* supports native tool/function calling.

    Performs a case-insensitive substring search against MODEL_TOOL_SUPPORT.
    When several keys occur in *model*, the longest (most specific) key wins.
    A more specific entry such as ``"llama-3.1"`` therefore overrides a generic
    ``"llama-3"`` no matter where it sits in the dict. Ties between equally
    long keys go to the one inserted first.

    Returns False for unknown models (safe default → emulated mode).
    """
    lower = model.lower()
    # Lower-case keys too, so entries added with mixed case still match.
    matches = [key for key in MODEL_TOOL_SUPPORT if key.lower() in lower]
    if not matches:
        return False
    return MODEL_TOOL_SUPPORT[max(matches, key=len)]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# Agent configuration
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
class ExecutionPolicy(BaseModel):
    """
    Defines what tools the executor is allowed to run.

    ``mode`` selects the rule. The two lists hold tool names, consulted in
    ``allow_only`` and ``confirm_before`` mode respectively.
    """

    # One of the POLICY_* values: "allow_all", "allow_only" or "confirm_before".
    mode: Literal["allow_all", "allow_only", "confirm_before"] = "allow_all"
    # Tool names permitted in "allow_only" mode; any other tool is blocked.
    allowed_tools: list[str] = Field(default_factory=list)
    # Tool names that require confirmation in "confirm_before" mode; tools not
    # listed here run without confirmation.
    confirm_tools: list[str] = Field(default_factory=list)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class AgentConfig(BaseModel):
    """
    Full configuration for a UniversalAgent instance.

    Validated on construction: ``max_steps`` and ``parse_retries`` must be
    >= 1 and ``mode`` must be one of the listed literals.
    """

    # Model identifier (may carry a provider prefix such as ``openrouter/``).
    model: str
    # Optional credentials / endpoint override for the LLM client.
    api_key: str | None = None
    base_url: str | None = None

    # Capability override
    # "auto" keeps capability detection; the other values force one mode.
    mode: Literal["auto", "native_only", "emulated_only"] = "auto"

    # Loop settings
    max_steps: int = Field(default=10, ge=1)

    # Execution policy
    execution_policy: ExecutionPolicy = Field(default_factory=ExecutionPolicy)

    # Tracing
    # NOTE(review): UniversalAgent.run() takes return_trace per call and does
    # not appear to read this field; confirm whether it is still used.
    return_trace: bool = False

    # Retry attempts for schema parsing in emulated mode
    parse_retries: int = Field(default=3, ge=1)
|
toolproxy/exceptions.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Custom exceptions for toolproxy.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class UniversalAgentError(Exception):
    """
    Base exception for all toolproxy errors.

    Subclasses with custom constructors define ``__reduce__``. Default
    ``BaseException`` pickling re-calls the class with the already formatted
    message as its only argument, which either fails (wrong arity) or
    double-formats the message. Rebuilding from the original constructor
    arguments keeps these exceptions picklable, e.g. across process pools.
    """


class ToolNotFoundError(UniversalAgentError):
    """Raised when a requested tool is not in the registry."""

    def __init__(self, tool_name: str) -> None:
        super().__init__(f"Tool '{tool_name}' not found in registry.")
        self.tool_name = tool_name

    def __reduce__(self):
        # Rebuild from the constructor argument; __dict__ carries any extra state.
        return (type(self), (self.tool_name,), self.__dict__)


class ToolExecutionError(UniversalAgentError):
    """Raised when a tool callable raises an exception during execution."""

    def __init__(self, tool_name: str, cause: Exception) -> None:
        super().__init__(f"Tool '{tool_name}' raised an error: {cause}")
        self.tool_name = tool_name
        self.cause = cause

    def __reduce__(self):
        return (type(self), (self.tool_name, self.cause), self.__dict__)


class MaxStepsExceededError(UniversalAgentError):
    """Raised when the agent loop exceeds max_steps without a final answer."""

    def __init__(self, max_steps: int) -> None:
        super().__init__(
            f"Agent did not produce a final answer within {max_steps} step(s)."
        )
        self.max_steps = max_steps

    def __reduce__(self):
        return (type(self), (self.max_steps,), self.__dict__)


class SchemaValidationError(UniversalAgentError):
    """Raised when the model's output cannot be validated against the Action schema."""

    def __init__(self, message: str) -> None:
        super().__init__(f"Schema validation failed: {message}")
        # Keep the raw detail, like sibling exceptions keep their arguments.
        self.message = message

    def __reduce__(self):
        return (type(self), (self.message,), self.__dict__)


class ModelError(UniversalAgentError):
    """Raised when the underlying LLM returns an error or unexpected response."""

    def __init__(self, message: str) -> None:
        super().__init__(f"Model error: {message}")
        # Keep the raw detail, like sibling exceptions keep their arguments.
        self.message = message

    def __reduce__(self):
        return (type(self), (self.message,), self.__dict__)


class ExecutionPolicyError(UniversalAgentError):
    """Raised when a tool is blocked by the execution policy."""

    def __init__(self, tool_name: str) -> None:
        super().__init__(
            f"Tool '{tool_name}' is blocked by the current execution policy."
        )
        self.tool_name = tool_name

    def __reduce__(self):
        return (type(self), (self.tool_name,), self.__dict__)
|
toolproxy/executor.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Executor module for toolproxy.
|
|
3
|
+
|
|
4
|
+
Validates tool arguments and executes tool callables with policy enforcement.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, Callable, List, Optional
|
|
9
|
+
|
|
10
|
+
from pydantic import ValidationError
|
|
11
|
+
|
|
12
|
+
from .config import ExecutionPolicy, POLICY_ALLOW_ALL, POLICY_ALLOW_ONLY, POLICY_CONFIRM_BEFORE
|
|
13
|
+
from .exceptions import ExecutionPolicyError, ToolExecutionError, ToolNotFoundError
|
|
14
|
+
from .schemas import ToolCall, ToolResult
|
|
15
|
+
from .tools import ToolRegistry
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Executor:
    """
    Executes ToolCall objects against the registry, with policy enforcement.

    Execution policies:
    - allow_all — any tool may be called.
    - allow_only([...]) — only tools in the allow-list may be called.
    - confirm_before([...]) — ask ``confirm_callback`` before running listed
      tools; without a callback (non-interactive mode) these are blocked.
    - any other mode value — every call is blocked (fail closed).
    """

    def __init__(
        self,
        registry: ToolRegistry,
        policy: Optional[ExecutionPolicy] = None,
        confirm_callback: Optional[Callable[[str, dict], bool]] = None,
    ) -> None:
        """
        Parameters
        ----------
        registry:
            The ToolRegistry to look up tools from.
        policy:
            Execution policy (defaults to allow_all).
        confirm_callback:
            Called for confirm_before tools. Receives (tool_name, arguments).
            Should return True to allow, False to block.
            If None, confirm_before tools are blocked (safe default).
        """
        self._registry = registry
        self._policy = policy or ExecutionPolicy()
        self._confirm_callback = confirm_callback

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def execute(self, tool_call: ToolCall) -> ToolResult:
        """
        Validate and execute *tool_call*.

        Argument-validation failures and exceptions raised by the tool are
        returned as a ToolResult with ``error`` set, not raised.

        Raises ExecutionPolicyError when the policy blocks the call. Errors
        from the registry lookup (ToolNotFoundError) propagate unchanged.
        Both are raised before the tool runs.
        """
        # 1. Check policy (raises ExecutionPolicyError when blocked).
        self._check_policy(tool_call.tool_name, tool_call.arguments)

        # 2. Look up tool; lookup errors propagate to the caller.
        defn = self._registry.get(tool_call.tool_name)

        # 3. Validate arguments against the tool's pydantic schema.
        try:
            validated_args = defn.args_schema.model_validate(tool_call.arguments)
        except ValidationError as exc:
            return ToolResult(
                tool_name=tool_call.tool_name,
                call_id=tool_call.call_id,
                error=f"Argument validation failed: {exc}",
            )

        # 4. Execute the callable; any exception becomes an error result so the
        #    loop can report it back to the model instead of crashing.
        try:
            output = defn.callable(**validated_args.model_dump())
        except Exception as exc:
            return ToolResult(
                tool_name=tool_call.tool_name,
                call_id=tool_call.call_id,
                error=f"Tool raised exception: {type(exc).__name__}: {exc}",
            )

        return ToolResult(
            tool_name=tool_call.tool_name,
            call_id=tool_call.call_id,
            output=str(output),
        )

    def execute_many(self, tool_calls: List[ToolCall]) -> List[ToolResult]:
        """Execute a list of tool calls sequentially, collecting all results."""
        return [self.execute(tc) for tc in tool_calls]

    # ------------------------------------------------------------------
    # Policy enforcement
    # ------------------------------------------------------------------

    def _check_policy(self, tool_name: str, arguments: dict) -> None:
        """Return if the policy permits *tool_name*; raise ExecutionPolicyError otherwise."""
        mode = self._policy.mode

        if mode == POLICY_ALLOW_ALL:
            return

        if mode == POLICY_ALLOW_ONLY:
            if tool_name not in self._policy.allowed_tools:
                raise ExecutionPolicyError(tool_name)
            return

        if mode == POLICY_CONFIRM_BEFORE:
            if tool_name not in self._policy.confirm_tools:
                # Tools not on the confirm list run without asking.
                return
            if self._confirm_callback is None:
                raise ExecutionPolicyError(tool_name)
            if not self._confirm_callback(tool_name, arguments):
                raise ExecutionPolicyError(tool_name)
            return

        # Unrecognized mode: pydantic does not re-validate attribute assignment
        # by default, so a typo assigned after construction could land here.
        # Fail closed instead of silently allowing every tool.
        raise ExecutionPolicyError(tool_name)
|