toolproxy 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {toolproxy-0.1.0 → toolproxy-0.2.0}/PKG-INFO +1 -1
- {toolproxy-0.1.0 → toolproxy-0.2.0}/pyproject.toml +1 -1
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/__init__.py +10 -3
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/agent.py +91 -3
- toolproxy-0.2.0/src/toolproxy/config.py +201 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/executor.py +79 -4
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/llm_client.py +332 -49
- toolproxy-0.2.0/src/toolproxy/loop.py +379 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/planner.py +69 -5
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/schemas.py +36 -0
- toolproxy-0.2.0/tests/test_async_agent.py +480 -0
- toolproxy-0.1.0/src/toolproxy/config.py +0 -98
- toolproxy-0.1.0/src/toolproxy/loop.py +0 -180
- {toolproxy-0.1.0 → toolproxy-0.2.0}/.gitignore +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/LICENSE +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/README.md +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/examples/basic_chat.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/examples/local_ollama.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/examples/openrouter_tools.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/exceptions.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/py.typed +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/src/toolproxy/tools.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/conftest.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/test_agent_basic.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/test_emulated_mode.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/test_error_handling.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/test_native_mode.py +0 -0
- {toolproxy-0.1.0 → toolproxy-0.2.0}/tests/test_tool_registry.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: toolproxy
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Universal tool-calling wrapper for non-tool-native LLMs — emulates function calling via structured JSON planning
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/toolproxy
|
|
6
6
|
Project-URL: Repository, https://github.com/yourusername/toolproxy
|
|
@@ -5,12 +5,16 @@ Public API:
|
|
|
5
5
|
UniversalAgent — main agent class
|
|
6
6
|
tool — decorator to mark callables as tools
|
|
7
7
|
ToolRegistry — registry for managing tools
|
|
8
|
-
AgentResponse — returned by UniversalAgent.run()
|
|
8
|
+
AgentResponse — returned by UniversalAgent.run() / arun()
|
|
9
9
|
AgentTrace — full debug trace
|
|
10
10
|
Message — conversation message
|
|
11
11
|
ToolCall — a single tool invocation
|
|
12
12
|
ToolResult — outcome of a tool execution
|
|
13
13
|
Action — emulated-mode structured output
|
|
14
|
+
StreamChunk — (v0.2) streaming event from agent.stream()
|
|
15
|
+
|
|
16
|
+
Utilities:
|
|
17
|
+
probe_tool_support — (v0.2) async helper to detect native tool support
|
|
14
18
|
|
|
15
19
|
Exceptions:
|
|
16
20
|
UniversalAgentError
|
|
@@ -22,7 +26,7 @@ Exceptions:
|
|
|
22
26
|
ExecutionPolicyError
|
|
23
27
|
"""
|
|
24
28
|
from .agent import UniversalAgent
|
|
25
|
-
from .config import AgentConfig, ExecutionPolicy
|
|
29
|
+
from .config import AgentConfig, ExecutionPolicy, probe_tool_support
|
|
26
30
|
from .exceptions import (
|
|
27
31
|
ExecutionPolicyError,
|
|
28
32
|
MaxStepsExceededError,
|
|
@@ -46,6 +50,7 @@ from .schemas import (
|
|
|
46
50
|
AgentResponse,
|
|
47
51
|
AgentTrace,
|
|
48
52
|
Message,
|
|
53
|
+
StreamChunk,
|
|
49
54
|
ToolCall,
|
|
50
55
|
ToolResult,
|
|
51
56
|
TraceEntry,
|
|
@@ -61,11 +66,13 @@ __all__ = [
|
|
|
61
66
|
# Config
|
|
62
67
|
"AgentConfig",
|
|
63
68
|
"ExecutionPolicy",
|
|
69
|
+
"probe_tool_support",
|
|
64
70
|
# Schemas
|
|
65
71
|
"Action",
|
|
66
72
|
"AgentResponse",
|
|
67
73
|
"AgentTrace",
|
|
68
74
|
"Message",
|
|
75
|
+
"StreamChunk",
|
|
69
76
|
"ToolCall",
|
|
70
77
|
"ToolResult",
|
|
71
78
|
"TraceEntry",
|
|
@@ -87,4 +94,4 @@ __all__ = [
|
|
|
87
94
|
"ExecutionPolicyError",
|
|
88
95
|
]
|
|
89
96
|
|
|
90
|
-
__version__ = "0.
|
|
97
|
+
__version__ = "0.2.0"
|
|
@@ -15,9 +15,21 @@ Usage::
|
|
|
15
15
|
tools=[get_weather],
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
+
# Synchronous (unchanged API)
|
|
18
19
|
result = agent.run("What is the weather in Chennai?")
|
|
19
20
|
print(result.content)
|
|
20
21
|
|
|
22
|
+
# Async (new in v0.2)
|
|
23
|
+
result = await agent.arun("What is the weather in Chennai?")
|
|
24
|
+
print(result.content)
|
|
25
|
+
|
|
26
|
+
# Streaming (new in v0.2)
|
|
27
|
+
async for chunk in agent.stream("What is the weather in Chennai?"):
|
|
28
|
+
if chunk.done:
|
|
29
|
+
print()
|
|
30
|
+
else:
|
|
31
|
+
print(chunk.delta, end="", flush=True)
|
|
32
|
+
|
|
21
33
|
# With trace
|
|
22
34
|
result = agent.run("...", return_trace=True)
|
|
23
35
|
for call in result.trace.tool_calls:
|
|
@@ -26,14 +38,14 @@ Usage::
|
|
|
26
38
|
from __future__ import annotations
|
|
27
39
|
|
|
28
40
|
import os
|
|
29
|
-
from typing import Any, Callable, List, Literal, Optional, Union
|
|
41
|
+
from typing import Any, AsyncIterator, Callable, List, Literal, Optional, Union
|
|
30
42
|
|
|
31
43
|
from .config import AgentConfig, ExecutionPolicy
|
|
32
44
|
from .executor import Executor
|
|
33
45
|
from .llm_client import LLMClient, get_client
|
|
34
46
|
from .loop import LoopController
|
|
35
47
|
from .planner import Planner
|
|
36
|
-
from .schemas import AgentResponse, Message
|
|
48
|
+
from .schemas import AgentResponse, Message, StreamChunk
|
|
37
49
|
from .tools import ToolRegistry
|
|
38
50
|
|
|
39
51
|
|
|
@@ -118,7 +130,7 @@ class UniversalAgent:
|
|
|
118
130
|
)
|
|
119
131
|
|
|
120
132
|
# ------------------------------------------------------------------
|
|
121
|
-
#
|
|
133
|
+
# Synchronous API (backwards-compatible)
|
|
122
134
|
# ------------------------------------------------------------------
|
|
123
135
|
|
|
124
136
|
def run(
|
|
@@ -158,6 +170,82 @@ class UniversalAgent:
|
|
|
158
170
|
on_model_output=on_model_output,
|
|
159
171
|
)
|
|
160
172
|
|
|
173
|
+
# ------------------------------------------------------------------
|
|
174
|
+
# Async API (new in v0.2)
|
|
175
|
+
# ------------------------------------------------------------------
|
|
176
|
+
|
|
177
|
+
async def arun(
    self,
    prompt_or_messages: Union[str, List[Message]],
    return_trace: bool = False,
    on_tool_call: Optional[Callable] = None,
    on_tool_result: Optional[Callable] = None,
    on_model_output: Optional[Callable] = None,
) -> AgentResponse:
    """
    Run the agent asynchronously; the awaitable counterpart of ``run``.

    I/O is non-blocking end to end, and when the model requests several
    tools in one turn they are dispatched concurrently via asyncio.gather.

    Parameters
    ----------
    prompt_or_messages :
        A user prompt string, or an existing list of Message objects.
    return_trace :
        When True, the returned response carries the full execution trace.
    on_tool_call / on_tool_result / on_model_output :
        Optional hooks, invoked synchronously while the loop runs.

    Returns
    -------
    AgentResponse
        The response produced by ``self._loop.arun`` for this conversation.
    """
    # A bare string becomes a single user turn; a list is shallow-copied so
    # the caller's list object is not handed to the loop directly.
    conversation: List[Message] = (
        [Message(role="user", content=prompt_or_messages)]
        if isinstance(prompt_or_messages, str)
        else list(prompt_or_messages)
    )

    # NOTE(review): trace mode is stored on the shared loop controller, so
    # overlapping arun() calls on one agent can observe each other's value.
    self._loop._return_trace = return_trace

    hooks = dict(
        on_tool_call=on_tool_call,
        on_tool_result=on_tool_result,
        on_model_output=on_model_output,
    )
    return await self._loop.arun(messages=conversation, **hooks)
|
|
213
|
+
|
|
214
|
+
# ------------------------------------------------------------------
|
|
215
|
+
# Streaming API (new in v0.2)
|
|
216
|
+
# ------------------------------------------------------------------
|
|
217
|
+
|
|
218
|
+
def stream(
    self,
    prompt_or_messages: Union[str, List[Message]],
) -> AsyncIterator[StreamChunk]:
    """
    Stream the agent's execution as an async iterator of StreamChunk events.

    This is a plain (non-async) method: the input is normalised immediately
    and the async iterator from the loop controller is returned for the
    caller to consume with ``async for``.

    Usage::

        async for chunk in agent.stream("What's the weather in Chennai?"):
            if chunk.done:
                print()  # newline after streaming
            elif chunk.tool_call:
                print(f"\\n[calling {chunk.tool_call.tool_name}]")
            elif chunk.tool_result:
                print(f"\\n[result: {chunk.tool_result.output}]")
            else:
                print(chunk.delta, end="", flush=True)

    Yields
    ------
    StreamChunk
        One chunk per event: text token, tool call, tool result, or done.
    """
    if not isinstance(prompt_or_messages, str):
        # Shallow copy so the caller's list is not passed through as-is.
        return self._loop.stream(list(prompt_or_messages))
    # A bare prompt string becomes a single user message.
    return self._loop.stream([Message(role="user", content=prompt_or_messages)])
|
|
248
|
+
|
|
161
249
|
# ------------------------------------------------------------------
|
|
162
250
|
# Convenience properties
|
|
163
251
|
# ------------------------------------------------------------------
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration for toolproxy.
|
|
3
|
+
|
|
4
|
+
Includes AgentConfig, ExecutionPolicy, the MODEL_TOOL_SUPPORT capability map,
|
|
5
|
+
and an async ``probe_tool_support()`` helper that can dynamically detect
|
|
6
|
+
whether a model/endpoint truly supports native tool calling.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
import time
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any, Dict, Literal, Optional
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
# ---------------------------------------------------------------------------
# Execution policy constants
# ---------------------------------------------------------------------------
# These match the string values accepted by ExecutionPolicy.mode.
POLICY_ALLOW_ALL = "allow_all"
POLICY_ALLOW_ONLY = "allow_only"
POLICY_CONFIRM_BEFORE = "confirm_before"

# ---------------------------------------------------------------------------
# Known model → native tool-calling support map
# Models not listed here default to False (emulated mode).
# Keys are lowercase substrings matched against the model name (they may be
# partial prefixes). When several keys match, the longest (most specific)
# key decides, so entries can be added in any order.
# ---------------------------------------------------------------------------
MODEL_TOOL_SUPPORT: Dict[str, bool] = {
    # OpenAI
    "gpt-4o": True,
    "gpt-4o-mini": True,
    "gpt-4-turbo": True,
    "gpt-4": True,
    "gpt-3.5-turbo": True,
    # Anthropic (via OpenRouter or direct)
    "claude-3-5-sonnet": True,
    "claude-3-5-haiku": True,
    "claude-3-opus": True,
    "claude-3-sonnet": True,
    "claude-3-haiku": True,
    # Google (via OpenRouter)
    "gemini-pro": True,
    "gemini-1.5-pro": True,
    "gemini-1.5-flash": True,
    "gemini-2": True,
    # Meta / Llama (typically no native tool calling through OpenRouter free tier)
    "llama-3": False,
    "llama-2": False,
    "mistral": False,
    "mixtral": False,
    # Qwen / DeepSeek free tier
    "deepseek": False,
    "qwen": False,
}


def model_supports_native_tools(model: str) -> bool:
    """
    Check whether *model* supports native tool/function calling.

    Performs a case-insensitive substring search against MODEL_TOOL_SUPPORT.
    If more than one key occurs in the model name, the longest key wins. A
    specific entry such as ``"llama-3.1"`` therefore overrides a broader one
    such as ``"llama-3"``, regardless of dict insertion order.

    Returns False for unknown models (safe default → emulated mode).
    """
    name = model.lower()
    # Lower-case the keys too, in case user-added entries are not lowercase.
    matches = [key for key in MODEL_TOOL_SUPPORT if key.lower() in name]
    if not matches:
        return False
    return MODEL_TOOL_SUPPORT[max(matches, key=len)]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ---------------------------------------------------------------------------
# Dynamic capability probe
# ---------------------------------------------------------------------------


def _default_cache_dir() -> Path:
    """
    Return the directory that holds the capability cache.

    A non-empty ``TOOLPROXY_CACHE_DIR`` takes precedence; otherwise
    ``~/.toolproxy`` is used. ``Path.home()`` raises RuntimeError when no home
    directory can be determined (e.g. minimal containers). In that case a
    ``toolproxy`` folder in the system temp dir is used, so importing the
    package never fails because of the cache location.
    """
    override = os.environ.get("TOOLPROXY_CACHE_DIR")
    if override:
        return Path(override)
    try:
        return Path.home() / ".toolproxy"
    except RuntimeError:
        import tempfile  # only needed on this rare fallback path

        return Path(tempfile.gettempdir()) / "toolproxy"


# Simple JSON cache stored in the user's home dir (or TOOLPROXY_CACHE_DIR)
_CACHE_PATH = _default_cache_dir() / "capability_cache.json"
_CACHE_TTL_SECONDS = 86_400  # 24 hours


def _load_cache() -> Dict[str, Any]:
    """
    Load the capability cache from disk.

    Best-effort: a missing, unreadable or malformed file, or one whose
    top-level JSON value is not an object, yields ``{}``. Callers can
    therefore always treat the result as a mutable mapping.
    """
    try:
        data = json.loads(_CACHE_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON/Unicode decode errors
        return {}
    return data if isinstance(data, dict) else {}


def _save_cache(cache: Dict[str, Any]) -> None:
    """
    Persist *cache* to disk as pretty-printed JSON.

    Best-effort: the cache is only an optimisation, so filesystem errors
    (read-only home, permissions, ...) and unserialisable content are ignored
    rather than propagated.
    """
    try:
        payload = json.dumps(cache, indent=2)
        _CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
        _CACHE_PATH.write_text(payload, encoding="utf-8")
    except (OSError, TypeError, ValueError):
        pass
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def probe_tool_support(
    client: Any,  # LLMClient — circular import avoided via Any
    cache_key: Optional[str] = None,
    force: bool = False,
) -> bool:
    """
    Dynamically probe whether *client*'s model supports native tool calling.

    Strategy
    --------
    1. Check the on-disk cache (TTL: 24 h) unless ``force=True``.
    2. Send a minimal tool-schema test request that asks the model to call a
       dummy ``ping`` tool.
    3. If the response contains tool_calls → native mode supported.
    4. If the response is plain text → emulated mode.
    5. Cache and return the result.

    If the request itself raises, False is returned (safe default → emulated
    mode) but the outcome is *not* cached. Causes include a network failure,
    timeout, auth error, or a provider rejecting the ``tools`` parameter. A
    transient outage therefore cannot pin the model to emulated mode for the
    whole TTL; the next call probes again.

    Parameters
    ----------
    client:
        Any LLMClient instance. It must expose an awaitable
        ``agenerate(messages, tools=...)`` whose result has a
        ``has_tool_calls`` attribute.
    cache_key:
        Optional cache key string (defaults to ``client.model``).
    force:
        If True, bypass the cache lookup and always probe live.

    Returns
    -------
    bool
        True if the model answered the probe with a tool call.
    """
    from .schemas import Message  # local import to avoid circular

    key = cache_key or getattr(client, "model", "unknown")
    now = time.time()

    if not force:
        cache = _load_cache()
        entry = cache.get(key) if isinstance(cache, dict) else None
        # Ignore malformed entries (hand-edited files, older formats).
        ts = entry.get("ts") if isinstance(entry, dict) else None
        if (
            isinstance(ts, (int, float))
            and "supported" in entry
            and (now - ts) < _CACHE_TTL_SECONDS
        ):
            return bool(entry["supported"])

    # Minimal probe: explicitly ask the model to call a trivial "ping" tool.
    ping_tool = [
        {
            "type": "function",
            "function": {
                "name": "ping",
                "description": "A diagnostic ping.",
                "parameters": {"type": "object", "properties": {}, "required": []},
            },
        }
    ]
    probe_messages = [
        Message(role="user", content="Call the ping tool."),
    ]

    try:
        response = await client.agenerate(probe_messages, tools=ping_tool)
        supported = bool(response.has_tool_calls)
    except Exception:
        # Any error → assume no native tool support (safe default). Not
        # cached, because the failure may be transient.
        return False

    # Persist the definitive answer.
    cache = _load_cache()
    if not isinstance(cache, dict):
        cache = {}
    cache[key] = {"supported": supported, "ts": now}
    _save_cache(cache)

    return supported
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# ---------------------------------------------------------------------------
# Agent configuration
# ---------------------------------------------------------------------------

class ExecutionPolicy(BaseModel):
    """
    Defines what tools the executor is allowed to run.

    ``mode`` takes the same string values as the POLICY_* constants in this
    module. Per the Executor's documentation:

    - ``allow_only`` lets only allow-listed tools be called.
    - ``confirm_before`` prompts for confirmation before listed tools and
      blocks them in non-interactive mode.
    - ``allow_all`` presumably imposes no restriction (TODO confirm in
      Executor._check_policy).
    """

    # Policy mode; defaults to "allow_all".
    mode: Literal["allow_all", "allow_only", "confirm_before"] = "allow_all"
    # Presumably the allow-list consulted when mode == "allow_only".
    allowed_tools: list[str] = Field(default_factory=list)
    # Presumably the tools requiring confirmation when mode == "confirm_before".
    confirm_tools: list[str] = Field(default_factory=list)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class AgentConfig(BaseModel):
    """
    Full configuration for a UniversalAgent instance.

    As a pydantic model, field constraints such as ``ge=1`` are checked when
    an instance is constructed, and violations raise pydantic's
    ValidationError.
    """

    # Model name/identifier (e.g. "gpt-4o"); the only required field.
    model: str
    # NOTE(review): None presumably means "use the client's default / env
    # var" — confirm in llm_client.get_client.
    api_key: str | None = None
    # Optional endpoint override; None presumably means the provider default.
    base_url: str | None = None

    # Capability override: "auto" (default) or force "native_only" /
    # "emulated_only" tool calling.
    mode: Literal["auto", "native_only", "emulated_only"] = "auto"

    # Loop settings: presumably the cap on agent loop steps; must be >= 1.
    max_steps: int = Field(default=10, ge=1)

    # Execution policy; a fresh ExecutionPolicy (mode "allow_all") per config.
    execution_policy: ExecutionPolicy = Field(default_factory=ExecutionPolicy)

    # Tracing: whether to include a full execution trace in responses.
    return_trace: bool = False

    # Retry attempts for schema parsing in emulated mode (must be >= 1)
    parse_retries: int = Field(default=3, ge=1)
|
|
@@ -2,9 +2,14 @@
|
|
|
2
2
|
Executor module for toolproxy.
|
|
3
3
|
|
|
4
4
|
Validates tool arguments and executes tool callables with policy enforcement.
|
|
5
|
+
|
|
6
|
+
Supports both synchronous execution (execute / execute_many) and async
|
|
7
|
+
parallel execution (aexecute / aexecute_many) via asyncio.gather.
|
|
5
8
|
"""
|
|
6
9
|
from __future__ import annotations
|
|
7
10
|
|
|
11
|
+
import asyncio
|
|
12
|
+
import inspect
|
|
8
13
|
from typing import Any, Callable, List, Optional
|
|
9
14
|
|
|
10
15
|
from pydantic import ValidationError
|
|
@@ -24,6 +29,10 @@ class Executor:
|
|
|
24
29
|
- allow_only([...]) — only tools in the allow-list may be called.
|
|
25
30
|
- confirm_before([...]) — prompt user for confirmation before listed tools;
|
|
26
31
|
in non-interactive mode these are blocked.
|
|
32
|
+
|
|
33
|
+
Both sync and async execution are supported:
|
|
34
|
+
- ``execute`` / ``execute_many`` — synchronous
|
|
35
|
+
- ``aexecute`` / ``aexecute_many`` — async, with parallel dispatch
|
|
27
36
|
"""
|
|
28
37
|
|
|
29
38
|
def __init__(
|
|
@@ -49,7 +58,7 @@ class Executor:
|
|
|
49
58
|
self._confirm_callback = confirm_callback
|
|
50
59
|
|
|
51
60
|
# ------------------------------------------------------------------
|
|
52
|
-
#
|
|
61
|
+
# Sync public interface
|
|
53
62
|
# ------------------------------------------------------------------
|
|
54
63
|
|
|
55
64
|
def execute(self, tool_call: ToolCall) -> ToolResult:
|
|
@@ -78,9 +87,14 @@ class Executor:
|
|
|
78
87
|
error=f"Argument validation failed: {exc}",
|
|
79
88
|
)
|
|
80
89
|
|
|
81
|
-
# 4. Execute the callable
|
|
90
|
+
# 4. Execute the callable (handles both sync and async tools)
|
|
82
91
|
try:
|
|
83
|
-
|
|
92
|
+
raw_output = defn.callable(**validated_args.model_dump())
|
|
93
|
+
# If the tool is an async function, run it in the event loop
|
|
94
|
+
if inspect.isawaitable(raw_output):
|
|
95
|
+
loop = asyncio.get_event_loop()
|
|
96
|
+
raw_output = loop.run_until_complete(raw_output)
|
|
97
|
+
output = str(raw_output)
|
|
84
98
|
except Exception as exc:
|
|
85
99
|
return ToolResult(
|
|
86
100
|
tool_name=tool_call.tool_name,
|
|
@@ -91,13 +105,74 @@ class Executor:
|
|
|
91
105
|
return ToolResult(
|
|
92
106
|
tool_name=tool_call.tool_name,
|
|
93
107
|
call_id=tool_call.call_id,
|
|
94
|
-
output=
|
|
108
|
+
output=output,
|
|
95
109
|
)
|
|
96
110
|
|
|
97
111
|
def execute_many(self, tool_calls: List[ToolCall]) -> List[ToolResult]:
|
|
98
112
|
"""Execute a list of tool calls sequentially, collecting all results."""
|
|
99
113
|
return [self.execute(tc) for tc in tool_calls]
|
|
100
114
|
|
|
115
|
+
# ------------------------------------------------------------------
|
|
116
|
+
# Async public interface
|
|
117
|
+
# ------------------------------------------------------------------
|
|
118
|
+
|
|
119
|
+
async def aexecute(self, tool_call: ToolCall) -> ToolResult:
    """
    Async version of ``execute``.

    The steps are:

    1. Check the policy.
    2. Look up the tool.
    3. Validate the arguments against the tool's ``args_schema``.
    4. Run the tool.

    ``async def`` tools are awaited directly. Other callables run in a worker
    thread via ``asyncio.to_thread`` so they do not block the event loop. If
    such a callable still returns an awaitable, that awaitable is awaited too
    instead of being stringified. This happens with a plain wrapper around an
    async function, or an object with an ``async def __call__``.

    Raises
    ------
    ToolNotFoundError
        If the registry has no tool named ``tool_call.tool_name``.
    Whatever ``self._check_policy`` raises for a disallowed call.

    Returns
    -------
    ToolResult
        ``output`` is ``str()`` of the tool's return value on success.
        ``error`` describes an argument-validation failure or an exception
        raised by the tool.
    """
    # 1. Check policy — violations propagate to the caller.
    self._check_policy(tool_call.tool_name, tool_call.arguments)

    # 2. Look up tool — ToolNotFoundError propagates unchanged.
    defn = self._registry.get(tool_call.tool_name)

    # 3. Validate arguments
    try:
        validated_args = defn.args_schema.model_validate(tool_call.arguments)
    except ValidationError as exc:
        return ToolResult(
            tool_name=tool_call.tool_name,
            call_id=tool_call.call_id,
            error=f"Argument validation failed: {exc}",
        )

    # 4. Execute (async-aware)
    try:
        kwargs = validated_args.model_dump()
        if inspect.iscoroutinefunction(defn.callable):
            raw_output = await defn.callable(**kwargs)
        else:
            raw_output = await asyncio.to_thread(defn.callable, **kwargs)
            # Async callables not detected by iscoroutinefunction hand back
            # an awaitable from the thread; await it rather than str() it.
            if inspect.isawaitable(raw_output):
                raw_output = await raw_output
        output = str(raw_output)
    except Exception as exc:
        return ToolResult(
            tool_name=tool_call.tool_name,
            call_id=tool_call.call_id,
            error=f"Tool raised exception: {type(exc).__name__}: {exc}",
        )

    return ToolResult(
        tool_name=tool_call.tool_name,
        call_id=tool_call.call_id,
        output=output,
    )
|
|
165
|
+
|
|
166
|
+
async def aexecute_many(self, tool_calls: List[ToolCall]) -> List[ToolResult]:
    """
    Execute *tool_calls* in parallel using ``asyncio.gather``.

    All tool calls are dispatched concurrently, each in its own task, and
    results are returned in the same order as the input list. An empty input
    returns an empty list.

    Exceptions raised by a tool itself are already turned into error
    ToolResults by ``aexecute``. What can still propagate is, for example, a
    policy violation or an unknown tool. In that case the sibling tasks are
    cancelled and drained before the exception is re-raised. No task is left
    running unobserved, and no "exception was never retrieved" warning is
    emitted. A sync tool already running in a worker thread cannot be
    interrupted; its result is simply discarded.
    """
    tasks = [asyncio.create_task(self.aexecute(tc)) for tc in tool_calls]
    try:
        results = await asyncio.gather(*tasks)
    except Exception:
        for task in tasks:
            task.cancel()  # no-op for tasks that already finished
        # Retrieve every outcome so nothing is reported as unhandled.
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    return list(results)
|
|
175
|
+
|
|
101
176
|
# ------------------------------------------------------------------
|
|
102
177
|
# Policy enforcement
|
|
103
178
|
# ------------------------------------------------------------------
|