tollgate 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tollgate/__init__.py +66 -0
- tollgate/approvals.py +227 -0
- tollgate/audit.py +47 -0
- tollgate/exceptions.py +28 -0
- tollgate/helpers.py +72 -0
- tollgate/integrations/__init__.py +0 -0
- tollgate/integrations/mcp.py +58 -0
- tollgate/integrations/strands.py +89 -0
- tollgate/interceptors/__init__.py +12 -0
- tollgate/interceptors/base.py +41 -0
- tollgate/interceptors/langchain.py +91 -0
- tollgate/interceptors/openai.py +87 -0
- tollgate/policy.py +152 -0
- tollgate/registry.py +58 -0
- tollgate/tower.py +224 -0
- tollgate/types.py +124 -0
- tollgate-1.0.0.dist-info/METADATA +98 -0
- tollgate-1.0.0.dist-info/RECORD +20 -0
- tollgate-1.0.0.dist-info/WHEEL +4 -0
- tollgate-1.0.0.dist-info/licenses/LICENSE +176 -0
tollgate/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from .approvals import (
|
|
2
|
+
ApprovalOutcome,
|
|
3
|
+
ApprovalStore,
|
|
4
|
+
Approver,
|
|
5
|
+
AsyncQueueApprover,
|
|
6
|
+
AutoApprover,
|
|
7
|
+
CliApprover,
|
|
8
|
+
InMemoryApprovalStore,
|
|
9
|
+
compute_request_hash,
|
|
10
|
+
)
|
|
11
|
+
from .audit import AuditSink, JsonlAuditSink
|
|
12
|
+
from .exceptions import (
|
|
13
|
+
TollgateApprovalDenied,
|
|
14
|
+
TollgateDeferred,
|
|
15
|
+
TollgateDenied,
|
|
16
|
+
TollgateError,
|
|
17
|
+
)
|
|
18
|
+
from .helpers import guard, wrap_tool
|
|
19
|
+
from .policy import PolicyEvaluator, YamlPolicyEvaluator
|
|
20
|
+
from .registry import ToolRegistry
|
|
21
|
+
from .tower import ControlTower
|
|
22
|
+
from .types import (
|
|
23
|
+
AgentContext,
|
|
24
|
+
AuditEvent,
|
|
25
|
+
Decision,
|
|
26
|
+
DecisionType,
|
|
27
|
+
Effect,
|
|
28
|
+
Intent,
|
|
29
|
+
NormalizedToolCall,
|
|
30
|
+
Outcome,
|
|
31
|
+
ToolRequest,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
__version__ = "1.0.0"

# Public API of the ``tollgate`` package. Every name below is imported at the
# top of this module; the groups mirror the module each import comes from.
__all__ = [
    # .tower
    "ControlTower",
    # .types
    "AgentContext",
    "Intent",
    "ToolRequest",
    "NormalizedToolCall",
    "Decision",
    "DecisionType",
    "Effect",
    "AuditEvent",
    "Outcome",
    # .approvals
    "ApprovalOutcome",
    "ApprovalStore",
    "Approver",
    "InMemoryApprovalStore",
    "AsyncQueueApprover",
    "AutoApprover",
    "CliApprover",
    "compute_request_hash",
    # .audit
    "AuditSink",
    "JsonlAuditSink",
    # .registry / .policy
    "ToolRegistry",
    "PolicyEvaluator",
    "YamlPolicyEvaluator",
    # .exceptions
    "TollgateError",
    "TollgateDenied",
    "TollgateApprovalDenied",
    "TollgateDeferred",
    # .helpers
    "wrap_tool",
    "guard",
]
|
tollgate/approvals.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any, Protocol
|
|
8
|
+
|
|
9
|
+
from .types import AgentContext, ApprovalOutcome, Effect, Intent, ToolRequest
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ApprovalStore(ABC):
    """Abstract storage backend for approval requests.

    All four coroutines are abstract, so the class cannot be instantiated
    directly; concrete stores (e.g. an in-memory or database-backed one)
    must implement each of them.
    """

    @abstractmethod
    async def create_request(
        self,
        agent_ctx: AgentContext,
        intent: Intent,
        tool_request: ToolRequest,
        request_hash: str,
        reason: str,
        expiry: float,
    ) -> str:
        """Persist a new pending approval request and return its identifier.

        Callers in this module pass ``expiry`` as an absolute ``time.time()``
        timestamp.
        """

    @abstractmethod
    async def set_decision(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ) -> None:
        """Store the outcome decided for ``approval_id``.

        ``request_hash`` identifies the exact request the decision applies to.
        """

    @abstractmethod
    async def get_request(self, approval_id: str) -> dict[str, Any] | None:
        """Fetch the stored record for ``approval_id``, or ``None`` if absent."""

    @abstractmethod
    async def wait_for_decision(
        self, approval_id: str, timeout: float
    ) -> ApprovalOutcome:
        """Block (asynchronously) until ``approval_id`` is decided or times out."""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class InMemoryApprovalStore(ApprovalStore):
    """In-memory approval store with replay protection and expiry.

    Requests are kept in process-local dicts for the lifetime of the store
    (nothing is evicted). Each request is bound to the ``request_hash`` it was
    created with and expires at its absolute ``expiry`` timestamp, compared
    against ``time.time()``. After expiry, decisions are refused and waiters
    receive ``ApprovalOutcome.TIMEOUT``.
    """

    def __init__(self):
        # approval_id -> request record built in create_request().
        self._requests: dict[str, dict[str, Any]] = {}
        # approval_id -> event that is set once a decision has been recorded.
        self._events: dict[str, asyncio.Event] = {}

    async def create_request(
        self, agent_ctx, intent, tool_request, request_hash, reason, expiry
    ) -> str:
        """Store a new request in the DEFERRED state and return its UUID4 id."""
        approval_id = str(uuid.uuid4())
        self._requests[approval_id] = {
            "id": approval_id,
            "agent": agent_ctx.to_dict(),
            "intent": intent.to_dict(),
            "tool_request": tool_request.to_dict(),
            "request_hash": request_hash,
            "reason": reason,
            "expiry": expiry,
            "outcome": ApprovalOutcome.DEFERRED,
        }
        self._events[approval_id] = asyncio.Event()
        return approval_id

    async def set_decision(
        self, approval_id, outcome, decided_by, decided_at, request_hash
    ):
        """Record ``outcome`` for ``approval_id`` and wake any waiter.

        Unknown ids are silently ignored.

        Raises:
            ValueError: if ``request_hash`` differs from the hash the request
                was created with, or if the request has already expired.
        """
        req = self._requests.get(approval_id)
        if req is None:
            return

        # Replay protection: the decision must be for exactly this request.
        if req["request_hash"] != request_hash:
            raise ValueError(
                "Request hash mismatch. Approval bound to a different request."
            )
        # Expiry: a decision arriving after the deadline must never turn an
        # expired request into an APPROVED one.
        if req["expiry"] < time.time():
            raise ValueError("Approval request expired. Decision not recorded.")

        req["outcome"] = outcome
        req["decided_by"] = decided_by
        req["decided_at"] = decided_at
        event = self._events.get(approval_id)
        if event is not None:
            event.set()

    async def get_request(self, approval_id):
        """Return the stored record (the live dict) for ``approval_id``, or None."""
        return self._requests.get(approval_id)

    async def wait_for_decision(self, approval_id, timeout):
        """Wait up to ``timeout`` seconds for a decision and return its outcome.

        The wait is additionally capped at the request's own expiry, and
        ``timeout=None`` means "until expiry".

        Returns:
            ``ApprovalOutcome.TIMEOUT`` in any of these cases:
            - the id is unknown,
            - the request is already past its expiry,
            - no decision arrives in time.
            Otherwise, the recorded outcome.
        """
        event = self._events.get(approval_id)
        if event is None:
            return ApprovalOutcome.TIMEOUT

        req = self._requests.get(approval_id)
        if req is not None:
            remaining = req["expiry"] - time.time()
            if remaining < 0:
                return ApprovalOutcome.TIMEOUT
            if event.is_set():
                # Already decided: answer without scheduling a wait.
                return req["outcome"]
            # Never wait past expiry. set_decision() refuses late decisions,
            # so waiting longer could only end in a timeout anyway.
            if timeout is None or remaining < timeout:
                timeout = remaining

        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return ApprovalOutcome.TIMEOUT
        return self._requests[approval_id]["outcome"]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class Approver(Protocol):
    """Structural (duck-typed) interface for approval providers.

    Any object exposing a matching ``request_approval_async`` coroutine
    satisfies it; no inheritance is required.
    """

    async def request_approval_async(
        self,
        agent_ctx: AgentContext,
        intent: Intent,
        tool_request: ToolRequest,
        request_hash: str,
        reason: str,
    ) -> ApprovalOutcome:
        """Ask a human or an external system to rule on the call; return the verdict."""
        ...
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class AsyncQueueApprover:
    """Approver that parks each request in an ``ApprovalStore`` until decided.

    Each request is stored with a deadline ``timeout`` seconds from now. If
    the store reports ``ApprovalOutcome.TIMEOUT``, ``default_outcome`` is
    returned instead (``DENIED`` unless overridden).
    """

    def __init__(
        self,
        store: ApprovalStore,
        timeout: float = 3600.0,
        default_outcome: ApprovalOutcome = ApprovalOutcome.DENIED,
    ):
        self.store = store  # backend holding pending requests
        self.timeout = timeout  # seconds; used for both expiry and the wait
        self.default_outcome = default_outcome  # substituted on TIMEOUT

    async def request_approval_async(
        self, agent_ctx, intent, tool_request, request_hash, reason
    ) -> ApprovalOutcome:
        """Queue the request, wait for a decision, and return the outcome."""
        deadline = time.time() + self.timeout
        pending_id = await self.store.create_request(
            agent_ctx, intent, tool_request, request_hash, reason, deadline
        )

        decision = await self.store.wait_for_decision(pending_id, self.timeout)
        # A timeout is not a verdict: fall back to the configured outcome.
        return (
            self.default_outcome
            if decision == ApprovalOutcome.TIMEOUT
            else decision
        )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class AutoApprover:
    """Non-interactive approver for tests and examples.

    Approves read-only calls and denies everything else without prompting.
    """

    async def request_approval_async(
        self,
        _agent_ctx: AgentContext,
        _intent: Intent,
        tool_request: ToolRequest,
        _request_hash: str,
        _reason: str,
    ) -> ApprovalOutcome:
        """Return APPROVED iff ``tool_request.effect == Effect.READ``, else DENIED."""
        is_read_only = tool_request.effect == Effect.READ
        return ApprovalOutcome.APPROVED if is_read_only else ApprovalOutcome.DENIED
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class CliApprover:
    """Interactive terminal approver for development.

    The blocking ``input()`` prompt runs in the loop's default executor, so
    the event loop itself is not blocked while the operator decides.
    """

    def __init__(self, show_emojis: bool = True):
        # Whether to prefix the banner with a traffic-light emoji.
        self.show_emojis = show_emojis

    async def request_approval_async(
        self, agent_ctx, intent, tool_request, _hash, reason
    ) -> ApprovalOutcome:
        """Prompt on stdin/stdout and return APPROVED or DENIED.

        ``_hash`` is accepted for protocol compatibility and is not displayed.
        """
        # Inside a coroutine, get_running_loop() is the preferred API over
        # get_event_loop().
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            self._sync_request,
            agent_ctx,
            intent,
            tool_request,
            reason,
        )

    def _sync_request(self, agent_ctx, intent, tool_request, reason) -> ApprovalOutcome:
        """Print the request and read a y/N answer; anything else denies."""
        prefix = "🚦 " if self.show_emojis else ""
        print("\n" + "=" * 40)
        print(f"{prefix}TOLLGATE APPROVAL REQUESTED")
        print("=" * 40)
        print(f"Reason: {reason}")
        print(f"Agent: {agent_ctx.agent_id} (v{agent_ctx.version})")
        print(f"Intent: {intent.action} - {intent.reason}")
        print(f"Tool: {tool_request.tool}.{tool_request.action}")
        print(f"Params: {tool_request.params}")
        print("-" * 40)
        try:
            choice = input("Approve this tool call? (y/N): ").strip().lower()
        except EOFError:
            # No interactive stdin (closed pipe, CI, daemon): fail closed
            # instead of crashing the approval flow.
            return ApprovalOutcome.DENIED
        if choice in ("y", "yes"):
            return ApprovalOutcome.APPROVED
        return ApprovalOutcome.DENIED
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def compute_request_hash(
    agent_ctx: "AgentContext", intent: "Intent", tool_request: "ToolRequest"
) -> str:
    """Compute a deterministic SHA-256 hex digest identifying a tool request.

    The digest covers these fields (``intent.reason`` is not included):
    - agent id and version,
    - intent action,
    - tool, action, effect value and resource type,
    - JSON-canonicalized (key-sorted) params and metadata.

    Fields are joined with ``|``. Any ``\\`` or ``|`` inside the free-form
    string fields is backslash-escaped, so field boundaries cannot be forged
    (e.g. ``agent_id="a|b", version="c"`` vs ``agent_id="a", version="b|c"``).
    For fields containing neither character, the digest is identical to the
    plain ``|``-joined encoding.

    Raises:
        TypeError: if params/metadata are not JSON-serializable, or a string
            field is not a ``str``.
    """

    def canonicalize(d: dict[str, Any]) -> str:
        return json.dumps(d, sort_keys=True)

    def escape(value):
        # Escape the backslash first so an escaped "|" ("\|") can never be
        # confused with an original backslash followed by a delimiter.
        # Non-str values pass through untouched so str.join raises TypeError.
        if not isinstance(value, str):
            return value
        return value.replace("\\", "\\\\").replace("|", "\\|")

    string_fields = [
        agent_ctx.agent_id,
        agent_ctx.version,
        intent.action,
        tool_request.tool,
        tool_request.action,
        tool_request.effect.value,
        tool_request.resource_type,
    ]
    # The trailing JSON values need no escaping: each json.dumps output is a
    # complete JSON value whose end is unambiguous.
    payload = "|".join(
        [
            *(escape(field) for field in string_fields),
            canonicalize(tool_request.params),
            canonicalize(tool_request.metadata),
        ]
    )
    return hashlib.sha256(payload.encode()).hexdigest()
|
tollgate/audit.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
from .types import AuditEvent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AuditSink(Protocol):
    """Structural interface for anything that consumes tool-execution audit events."""

    def emit(self, event: AuditEvent) -> None:
        """Record a single audit event."""
        ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class JsonlAuditSink:
|
|
17
|
+
"""Audit sink that writes to a JSONL file with buffering."""
|
|
18
|
+
|
|
19
|
+
def __init__(self, log_path: str | Path):
|
|
20
|
+
"""Initialize the sink and ensure the log directory exists."""
|
|
21
|
+
self.log_path = Path(log_path)
|
|
22
|
+
self.log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
23
|
+
self._f = None
|
|
24
|
+
|
|
25
|
+
def _get_file(self):
|
|
26
|
+
if self._f is None or self._f.closed:
|
|
27
|
+
self._f = self.log_path.open("a", encoding="utf-8", buffering=1)
|
|
28
|
+
return self._f
|
|
29
|
+
|
|
30
|
+
def emit(self, event: AuditEvent) -> None:
|
|
31
|
+
"""Append an audit event to the JSONL file."""
|
|
32
|
+
f = self._get_file()
|
|
33
|
+
f.write(json.dumps(event.to_dict(), ensure_ascii=False) + "\n")
|
|
34
|
+
|
|
35
|
+
def close(self):
|
|
36
|
+
"""Close the file handle."""
|
|
37
|
+
if self._f and not self._f.closed:
|
|
38
|
+
self._f.close()
|
|
39
|
+
|
|
40
|
+
def __enter__(self):
|
|
41
|
+
return self
|
|
42
|
+
|
|
43
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
44
|
+
self.close()
|
|
45
|
+
|
|
46
|
+
def __del__(self):
|
|
47
|
+
self.close()
|
tollgate/exceptions.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
class TollgateError(Exception):
    """Root of the Tollgate exception hierarchy.

    Catch this type to handle any error raised by Tollgate itself.
    """
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TollgateDenied(TollgateError):  # noqa: N818
    """Raised when a tool call is explicitly denied by policy.

    Attributes:
        reason: the raw denial reason. The exception message is
            ``"Tool call denied: <reason>"``.
    """

    def __init__(self, reason: str):
        self.reason = reason
        super().__init__(f"Tool call denied: {reason}")

    def __reduce__(self):
        # BaseException's default reduction replays ``self.args`` (the already
        # formatted message) into __init__, which would prefix it a second time
        # after unpickling (e.g. across multiprocessing). Rebuild from the raw
        # reason and restore any extra instance state.
        return (type(self), (self.reason,), self.__dict__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TollgateApprovalDenied(TollgateError):  # noqa: N818
    """Raised when a human-in-the-loop approver rejects a tool call.

    The exception message is the reason text itself, with no prefix.
    """

    def __init__(self, reason: str = "Approval denied by human."):
        super().__init__(reason)
        self.reason = reason
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TollgateDeferred(TollgateError):  # noqa: N818
    """Raised when a tool call is deferred (e.g., waiting for async approval).

    Attributes:
        approval_id: identifier of the pending approval. The exception message
            is ``"Tool call deferred. Approval ID: <approval_id>"``.
    """

    def __init__(self, approval_id: str):
        self.approval_id = approval_id
        super().__init__(f"Tool call deferred. Approval ID: {approval_id}")

    def __reduce__(self):
        # The default reduction would pass the formatted message back into
        # __init__ as ``approval_id``, duplicating the message prefix after
        # unpickling. Rebuild from the raw id and restore instance state.
        return (type(self), (self.approval_id,), self.__dict__)
|
tollgate/helpers.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from functools import wraps
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .tower import ControlTower
|
|
7
|
+
from .types import AgentContext, Effect, Intent, ToolRequest
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def wrap_tool(
    tower: ControlTower,
    tool_callable: Callable,
    *,
    tool: str,
    action: str,
    resource_type: str,
    effect: Effect,
):
    """
    Return a synchronous wrapper that routes ``tool_callable`` through ``tower``.

    The wrapper is called as ``wrapper(agent_ctx, intent, metadata=None, **params)``.
    ``params`` become both the ``ToolRequest`` params and the keyword arguments
    forwarded to ``tool_callable``; as a consequence, tool parameters named
    ``agent_ctx``, ``intent`` or ``metadata`` cannot be passed through.

    Plain callables go through ``tower.execute``. Coroutine functions go
    through ``tower.execute_async`` driven by ``asyncio.run``, so in that case
    the wrapper must not be called from inside a running event loop.

    Maintained for backward compatibility with v0.
    """
    is_coroutine_tool = asyncio.iscoroutinefunction(tool_callable)

    @wraps(tool_callable)
    def wrapper(
        agent_ctx: AgentContext,
        intent: Intent,
        metadata: dict[str, Any] | None = None,
        **params,
    ) -> Any:
        request = ToolRequest(
            tool=tool,
            action=action,
            resource_type=resource_type,
            effect=effect,
            params=params,
            metadata=metadata or {},
        )

        if not is_coroutine_tool:

            def run_sync():
                return tool_callable(**params)

            return tower.execute(agent_ctx, intent, request, run_sync)

        async def run_async():
            return await tool_callable(**params)

        return asyncio.run(tower.execute_async(agent_ctx, intent, request, run_async))

    return wrapper
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def guard(
    tower: ControlTower,
    *,
    tool: str,
    action: str,
    resource_type: str,
    effect: Effect,
):
    """Build a decorator that gates the decorated tool function through ``tower``.

    ``@guard(tower, tool=..., action=..., resource_type=..., effect=...)`` is
    shorthand for ``wrap_tool(tower, func, ...)`` with the same keywords.
    """
    spec = dict(tool=tool, action=action, resource_type=resource_type, effect=effect)

    def apply_guard(func: Callable):
        return wrap_tool(tower, func, **spec)

    return apply_guard
|
|
File without changes
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from ..registry import ToolRegistry
|
|
4
|
+
from ..tower import ControlTower
|
|
5
|
+
from ..types import AgentContext, Intent, ToolRequest
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TollgateMCPClient:
    """Proxy around an MCP client whose ``call_tool`` is gated by Tollgate."""

    def __init__(
        self,
        client: Any,
        server_name: str,
        tower: ControlTower,
        registry: ToolRegistry,
    ):
        """
        Wrap ``client`` so that each tool call is checked by ``tower``.

        :param client: Underlying MCP client; must expose an awaitable
            ``call_tool(name, arguments)``.
        :param server_name: MCP server name, used in the registry key
            ``mcp:<server_name>.<tool_name>``.
        :param tower: ControlTower that decides on and executes each call.
        :param registry: ToolRegistry resolving effect, resource type and
            manifest version per tool.
        """
        self.client, self.server_name = client, server_name
        self.tower, self.registry = tower, registry

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        agent_ctx: AgentContext,
        intent: Intent,
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Resolve ``tool_name`` in the registry, then run it through the tower.

        The resulting ``ToolRequest`` uses ``tool="mcp"`` and
        ``action=tool_name``; the server name only takes part in the registry
        lookup key, not in the request itself.
        """
        registry_key = f"mcp:{self.server_name}.{tool_name}"
        effect, resource_type, manifest_version = self.registry.resolve_tool(
            registry_key
        )

        gated_request = ToolRequest(
            tool="mcp",
            action=tool_name,
            resource_type=resource_type,
            effect=effect,
            params=arguments,
            metadata=metadata or {},
            manifest_version=manifest_version,
        )

        async def forward():
            return await self.client.call_tool(tool_name, arguments)

        return await self.tower.execute_async(agent_ctx, intent, gated_request, forward)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from ..registry import ToolRegistry
|
|
5
|
+
from ..tower import ControlTower
|
|
6
|
+
from ..types import AgentContext, Intent, ToolRequest
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GuardedStrandsTool:
    """A wrapper for Strands tools that enforces Tollgate gating."""

    def __init__(self, tool: Any, tower: ControlTower, registry: ToolRegistry):
        self.tool = tool
        self.tower = tower
        self.registry = registry

        # Tool name: a callable's __name__, else a ``name`` attribute, else the
        # class name. It forms the registry key ``strands:<name>``.
        if callable(tool) and hasattr(tool, "__name__"):
            self.name = tool.__name__
        elif hasattr(tool, "name"):
            self.name = tool.name
        else:
            self.name = tool.__class__.__name__

        self.description = getattr(tool, "description", f"Strands tool: {self.name}")

    async def __call__(
        self,
        tool_input: Any,
        agent_ctx: AgentContext,
        intent: Intent,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """Alias for :meth:`run_async`."""
        return await self.run_async(tool_input, agent_ctx, intent, metadata, **kwargs)

    async def run_async(
        self,
        tool_input: Any,
        agent_ctx: AgentContext,
        intent: Intent,
        metadata: dict[str, Any] | None = None,
        **kwargs,
    ) -> Any:
        """Gate one invocation of the wrapped tool through the ControlTower.

        The request params are ``tool_input`` (or ``{"input": tool_input}`` when
        it is not a dict) merged with ``kwargs``. The tool itself is invoked
        with the original ``tool_input`` and ``kwargs``, trying in order:
        ``ainvoke``, ``arun``, a coroutine function, then ``invoke``, ``run``,
        or a plain call in the default executor.
        """
        tool_key = f"strands:{self.name}"
        effect, resource_type, manifest_version = self.registry.resolve_tool(tool_key)

        # Build a fresh dict. Updating the caller's ``tool_input`` in place
        # would mutate it, and that same dict is what the tool receives below,
        # so kwargs would be delivered twice.
        if isinstance(tool_input, dict):
            params = {**tool_input, **kwargs}
        else:
            params = {"input": tool_input, **kwargs}

        request = ToolRequest(
            tool="strands",
            action=self.name,
            resource_type=resource_type,
            effect=effect,
            params=params,
            metadata=metadata or {},
            manifest_version=manifest_version,
        )

        async def _exec_async():
            # Support various calling conventions
            if hasattr(self.tool, "ainvoke"):
                return await self.tool.ainvoke(tool_input, **kwargs)
            if hasattr(self.tool, "arun"):
                return await self.tool.arun(tool_input, **kwargs)
            if asyncio.iscoroutinefunction(self.tool):
                return await self.tool(tool_input, **kwargs)

            # Sync fallback, run off the event loop thread.
            def _sync_call():
                if hasattr(self.tool, "invoke"):
                    return self.tool.invoke(tool_input, **kwargs)
                if hasattr(self.tool, "run"):
                    return self.tool.run(tool_input, **kwargs)
                return self.tool(tool_input, **kwargs)

            # Inside a coroutine, get_running_loop() is the preferred API.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, _sync_call)

        return await self.tower.execute_async(agent_ctx, intent, request, _exec_async)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def guard_tools(
    tools: list[Any], tower: ControlTower, registry: ToolRegistry
) -> list[GuardedStrandsTool]:
    """Return a new list with each Strands tool wrapped in ``GuardedStrandsTool``.

    All wrappers share ``tower`` and ``registry``; the input order is kept.
    """

    def _guard(raw_tool: Any) -> GuardedStrandsTool:
        return GuardedStrandsTool(raw_tool, tower, registry)

    return list(map(_guard, tools))
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .base import TollgateInterceptor, ToolAdapter
|
|
2
|
+
from .langchain import LangChainAdapter, guard_tools
|
|
3
|
+
from .openai import OpenAIAdapter, OpenAIToolRunner
|
|
4
|
+
|
|
5
|
+
# Public API of ``tollgate.interceptors``; every name is imported above.
__all__ = [
    # Framework-neutral core (.base)
    "ToolAdapter",
    "TollgateInterceptor",
    # LangChain integration (.langchain)
    "LangChainAdapter",
    "guard_tools",
    # OpenAI integration (.openai)
    "OpenAIAdapter",
    "OpenAIToolRunner",
]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from typing import Any, Protocol
|
|
3
|
+
|
|
4
|
+
from ..tower import ControlTower
|
|
5
|
+
from ..types import AgentContext, Intent, NormalizedToolCall
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ToolAdapter(Protocol):
    """Structural interface that converts a framework's tool call into Tollgate form."""

    def normalize(self, tool_call: Any) -> NormalizedToolCall:
        """Translate the framework-specific ``tool_call`` into a ``NormalizedToolCall``."""
        ...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TollgateInterceptor:
    """Gate framework tool calls through a ControlTower using a ToolAdapter."""

    def __init__(self, tower: ControlTower, adapter: ToolAdapter):
        self.tower = tower  # decides on and executes each call
        self.adapter = adapter  # turns framework calls into NormalizedToolCall

    async def intercept_async(
        self, agent_ctx: AgentContext, intent: Intent, tool_call: Any
    ) -> Any:
        """Normalize ``tool_call`` and run it through ``tower.execute_async``."""
        call = self.adapter.normalize(tool_call)
        request, runner = call.request, call.exec_async
        return await self.tower.execute_async(agent_ctx, intent, request, runner)

    def intercept(self, agent_ctx: AgentContext, intent: Intent, tool_call: Any) -> Any:
        """Normalize ``tool_call`` and run it through the synchronous ``tower.execute``.

        The normalized coroutine executor is driven with ``asyncio.run`` when
        the tower invokes the callable. Per the original note, ``tower.execute``
        is expected to perform the event-loop safety check (not verifiable
        from this module).
        """
        call = self.adapter.normalize(tool_call)

        def run_to_completion():
            return asyncio.run(call.exec_async())

        return self.tower.execute(agent_ctx, intent, call.request, run_to_completion)
|