tollgate 1.0.4__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tollgate/__init__.py +36 -3
- tollgate/anomaly_detector.py +396 -0
- tollgate/audit.py +90 -1
- tollgate/backends/__init__.py +37 -0
- tollgate/backends/redis_store.py +411 -0
- tollgate/backends/sqlite_store.py +458 -0
- tollgate/circuit_breaker.py +206 -0
- tollgate/context_monitor.py +292 -0
- tollgate/exceptions.py +20 -0
- tollgate/grants.py +46 -0
- tollgate/manifest_signing.py +90 -0
- tollgate/network_guard.py +114 -0
- tollgate/policy.py +37 -0
- tollgate/policy_testing.py +360 -0
- tollgate/rate_limiter.py +162 -0
- tollgate/registry.py +225 -2
- tollgate/tower.py +184 -12
- tollgate/types.py +21 -1
- tollgate/verification.py +81 -0
- tollgate-1.4.0.dist-info/METADATA +393 -0
- tollgate-1.4.0.dist-info/RECORD +33 -0
- tollgate-1.4.0.dist-info/entry_points.txt +2 -0
- tollgate-1.0.4.dist-info/METADATA +0 -144
- tollgate-1.0.4.dist-info/RECORD +0 -21
- {tollgate-1.0.4.dist-info → tollgate-1.4.0.dist-info}/WHEEL +0 -0
- {tollgate-1.0.4.dist-info → tollgate-1.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,458 @@
|
|
|
1
|
+
"""SQLite-backed persistent stores for Tollgate.
|
|
2
|
+
|
|
3
|
+
Zero additional dependencies — uses Python's built-in ``sqlite3`` module.
|
|
4
|
+
Suitable for single-process deployments. For multi-process or multi-host
|
|
5
|
+
deployments, use the Redis backends instead.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
|
|
9
|
+
from tollgate.backends import SQLiteGrantStore, SQLiteApprovalStore
|
|
10
|
+
|
|
11
|
+
grant_store = SQLiteGrantStore("tollgate.db")
|
|
12
|
+
approval_store = SQLiteApprovalStore("tollgate.db")
|
|
13
|
+
|
|
14
|
+
tower = ControlTower(
|
|
15
|
+
...,
|
|
16
|
+
grant_store=grant_store,
|
|
17
|
+
)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import asyncio
import json
import sqlite3
import threading
import time
import uuid
from pathlib import Path
from typing import Any

from ..types import AgentContext, ApprovalOutcome, Effect, Grant, ToolRequest
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SQLiteGrantStore:
    """SQLite-backed GrantStore implementation.

    Satisfies the ``GrantStore`` protocol. Uses WAL mode for concurrent
    reads. Every public coroutine delegates to a synchronous helper that
    runs in a worker thread so the event loop is never blocked.

    A single connection (opened with ``check_same_thread=False``) is shared
    by all worker threads, so every access to it is serialized through an
    internal ``threading.Lock``; without that, statements and commits
    issued by concurrent coroutines could interleave on the connection.

    Args:
        db_path: Path to the SQLite database file. Use ``:memory:`` for
            testing (non-persistent).
        table_name: Name of the grants table (default ``tollgate_grants``).
            The name is interpolated into SQL statements, so it must come
            from trusted configuration, never from user input.
    """

    def __init__(self, db_path: str | Path = "tollgate_grants.db", *, table_name: str = "tollgate_grants"):
        self._db_path = str(db_path)
        self._table = table_name
        # Guards self._conn: worker threads must not interleave statements.
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._create_table()

    def _create_table(self):
        """Create the grants table and its indexes if they do not exist."""
        self._conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self._table} (
                id TEXT PRIMARY KEY,
                agent_id TEXT,
                effect TEXT,
                tool TEXT,
                action TEXT,
                resource_type TEXT,
                expires_at REAL NOT NULL,
                granted_by TEXT NOT NULL,
                created_at REAL NOT NULL,
                reason TEXT,
                usage_count INTEGER DEFAULT 0,
                revoked INTEGER DEFAULT 0
            )
        """)
        # Index for fast expiry cleanup and matching
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_expires
            ON {self._table} (expires_at, revoked)
        """)
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_agent
            ON {self._table} (agent_id, revoked)
        """)
        self._conn.commit()

    def _grant_to_row(self, grant: Grant) -> dict[str, Any]:
        """Flatten a ``Grant`` into the named parameters used by the INSERT."""
        return {
            "id": grant.id,
            "agent_id": grant.agent_id,
            "effect": grant.effect.value if grant.effect else None,
            "tool": grant.tool,
            "action": grant.action,
            "resource_type": grant.resource_type,
            "expires_at": grant.expires_at,
            "granted_by": grant.granted_by,
            "created_at": grant.created_at,
            "reason": grant.reason,
        }

    def _row_to_grant(self, row: sqlite3.Row | tuple) -> Grant:
        """Rebuild a ``Grant`` from a table row.

        Accepts either a ``sqlite3.Row`` or a plain tuple in table column
        order. ``usage_count`` and ``revoked`` are not passed to ``Grant``.
        """
        if isinstance(row, sqlite3.Row):
            d = dict(row)
        else:
            cols = [
                "id", "agent_id", "effect", "tool", "action",
                "resource_type", "expires_at", "granted_by",
                "created_at", "reason", "usage_count", "revoked",
            ]
            d = dict(zip(cols, row))

        return Grant(
            id=d["id"],
            agent_id=d["agent_id"],
            effect=Effect(d["effect"]) if d["effect"] else None,
            tool=d["tool"],
            action=d["action"],
            resource_type=d["resource_type"],
            expires_at=d["expires_at"],
            granted_by=d["granted_by"],
            created_at=d["created_at"],
            reason=d["reason"],
        )

    async def create_grant(self, grant: Grant) -> str:
        """Persist ``grant`` and return its id."""
        row = self._grant_to_row(grant)
        await asyncio.to_thread(self._insert_grant, row)
        return grant.id

    def _insert_grant(self, row: dict[str, Any]):
        """Insert (or replace, when the id already exists) one grant row."""
        # ``with self._conn`` commits on success and rolls back on error.
        with self._lock, self._conn:
            self._conn.execute(
                f"""INSERT OR REPLACE INTO {self._table}
                    (id, agent_id, effect, tool, action, resource_type,
                     expires_at, granted_by, created_at, reason)
                    VALUES (:id, :agent_id, :effect, :tool, :action, :resource_type,
                            :expires_at, :granted_by, :created_at, :reason)""",
                row,
            )

    async def find_matching_grant(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> Grant | None:
        """Return the newest active grant covering the request, or ``None``.

        A successful match increments the grant's ``usage_count``.
        """
        return await asyncio.to_thread(
            self._find_matching_grant_sync, agent_ctx, tool_request
        )

    @staticmethod
    def _matches(
        d: dict[str, Any], agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> bool:
        """Tell whether stored grant ``d`` covers the given request.

        A ``None`` column acts as a wildcard. ``tool`` additionally supports
        prefix matching when the stored value ends with ``*``.
        """
        if d["agent_id"] is not None and d["agent_id"] != agent_ctx.agent_id:
            return False

        if d["effect"] is not None and d["effect"] != tool_request.effect.value:
            return False

        pattern = d["tool"]
        if pattern is not None:
            if pattern.endswith("*"):
                if not tool_request.tool.startswith(pattern[:-1]):
                    return False
            elif pattern != tool_request.tool:
                return False

        if d["action"] is not None and d["action"] != tool_request.action:
            return False

        if (
            d["resource_type"] is not None
            and d["resource_type"] != tool_request.resource_type
        ):
            return False

        return True

    def _find_matching_grant_sync(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> Grant | None:
        """Scan unexpired, unrevoked grants (newest first) for a match."""
        now = time.time()
        with self._lock:
            cursor = self._conn.cursor()
            cursor.row_factory = sqlite3.Row
            try:
                cursor.execute(
                    f"""SELECT * FROM {self._table}
                        WHERE expires_at > ? AND revoked = 0
                        ORDER BY created_at DESC""",
                    (now,),
                )
                match = next(
                    (
                        row
                        for row in cursor
                        if self._matches(dict(row), agent_ctx, tool_request)
                    ),
                    None,
                )
            finally:
                # Finish the read before issuing the write below.
                cursor.close()

            if match is None:
                return None

            # Match found — increment usage
            with self._conn:
                self._conn.execute(
                    f"UPDATE {self._table} SET usage_count = usage_count + 1 WHERE id = ?",
                    (match["id"],),
                )
            return self._row_to_grant(match)

    async def revoke_grant(self, grant_id: str) -> bool:
        """Revoke a grant; return ``True`` only if it was active before."""
        return await asyncio.to_thread(self._revoke_sync, grant_id)

    def _revoke_sync(self, grant_id: str) -> bool:
        with self._lock, self._conn:
            cursor = self._conn.execute(
                f"UPDATE {self._table} SET revoked = 1 WHERE id = ? AND revoked = 0",
                (grant_id,),
            )
            return cursor.rowcount > 0

    async def list_active_grants(self, agent_id: str | None = None) -> list[Grant]:
        """List unexpired, unrevoked grants, optionally for one agent only."""
        return await asyncio.to_thread(self._list_active_sync, agent_id)

    def _list_active_sync(self, agent_id: str | None) -> list[Grant]:
        now = time.time()
        with self._lock:
            cursor = self._conn.cursor()
            cursor.row_factory = sqlite3.Row
            if agent_id is not None:
                cursor.execute(
                    f"""SELECT * FROM {self._table}
                        WHERE expires_at > ? AND revoked = 0 AND agent_id = ?""",
                    (now, agent_id),
                )
            else:
                cursor.execute(
                    f"SELECT * FROM {self._table} WHERE expires_at > ? AND revoked = 0",
                    (now,),
                )
            return [self._row_to_grant(row) for row in cursor]

    async def cleanup_expired(self) -> int:
        """Delete every expired grant and return how many were removed."""
        return await asyncio.to_thread(self._cleanup_sync)

    def _cleanup_sync(self) -> int:
        now = time.time()
        with self._lock, self._conn:
            cursor = self._conn.execute(
                f"DELETE FROM {self._table} WHERE expires_at <= ?", (now,)
            )
            return cursor.rowcount

    async def get_usage_count(self, grant_id: str) -> int:
        """Return the grant's usage counter, or 0 if the id is unknown."""
        return await asyncio.to_thread(self._get_usage_sync, grant_id)

    def _get_usage_sync(self, grant_id: str) -> int:
        with self._lock:
            cursor = self._conn.execute(
                f"SELECT usage_count FROM {self._table} WHERE id = ?", (grant_id,)
            )
            row = cursor.fetchone()
        return row[0] if row else 0

    def close(self):
        """Close the database connection."""
        with self._lock:
            self._conn.close()
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class SQLiteApprovalStore:
    """SQLite-backed ApprovalStore implementation.

    Satisfies the ``ApprovalStore`` ABC. Uses polling-based wait_for_decision
    since SQLite doesn't support notifications.

    Database work runs in worker threads against one shared connection
    (opened with ``check_same_thread=False``); an internal
    ``threading.Lock`` serializes access so that, for example, the
    hash check and the update in ``set_decision`` happen atomically with
    respect to other calls on this store.

    Args:
        db_path: Path to the SQLite database file.
        table_name: Name of the approvals table. The name is interpolated
            into SQL statements, so it must come from trusted configuration.
        poll_interval: Seconds between polls when waiting for a decision.
    """

    def __init__(
        self,
        db_path: str | Path = "tollgate_approvals.db",
        *,
        table_name: str = "tollgate_approvals",
        poll_interval: float = 0.5,
    ):
        self._db_path = str(db_path)
        self._table = table_name
        self._poll_interval = poll_interval
        # Guards self._conn: worker threads must not interleave statements.
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._create_table()

    def _create_table(self):
        """Create the approvals table and its hash index if missing."""
        self._conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self._table} (
                id TEXT PRIMARY KEY,
                agent_json TEXT NOT NULL,
                intent_json TEXT NOT NULL,
                tool_request_json TEXT NOT NULL,
                request_hash TEXT NOT NULL,
                reason TEXT NOT NULL,
                expiry REAL NOT NULL,
                outcome TEXT NOT NULL DEFAULT 'deferred',
                decided_by TEXT,
                decided_at REAL
            )
        """)
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_hash
            ON {self._table} (request_hash)
        """)
        self._conn.commit()

    async def create_request(
        self,
        agent_ctx: AgentContext,
        intent: Any,
        tool_request: ToolRequest,
        request_hash: str,
        reason: str,
        expiry: float,
    ) -> str:
        """Store a new pending approval request and return its generated id.

        ``agent_ctx``, ``intent`` and ``tool_request`` are serialized to JSON
        through their ``to_dict()`` methods.
        """
        approval_id = str(uuid.uuid4())
        await asyncio.to_thread(
            self._insert_request,
            approval_id,
            json.dumps(agent_ctx.to_dict()),
            json.dumps(intent.to_dict()),
            json.dumps(tool_request.to_dict()),
            request_hash,
            reason,
            expiry,
        )
        return approval_id

    def _insert_request(
        self,
        approval_id: str,
        agent_json: str,
        intent_json: str,
        tool_request_json: str,
        request_hash: str,
        reason: str,
        expiry: float,
    ):
        # ``with self._conn`` commits on success and rolls back on error.
        with self._lock, self._conn:
            self._conn.execute(
                f"""INSERT INTO {self._table}
                    (id, agent_json, intent_json, tool_request_json,
                     request_hash, reason, expiry)
                    VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    approval_id,
                    agent_json,
                    intent_json,
                    tool_request_json,
                    request_hash,
                    reason,
                    expiry,
                ),
            )

    async def set_decision(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ) -> None:
        """Record a decision for an approval request.

        Does nothing when ``approval_id`` is unknown.

        Raises:
            ValueError: If ``request_hash`` differs from the hash stored
                with the approval request.
        """
        await asyncio.to_thread(
            self._set_decision_sync,
            approval_id,
            outcome,
            decided_by,
            decided_at,
            request_hash,
        )

    def _set_decision_sync(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ):
        # The lock keeps the hash check and the update atomic with respect
        # to other operations on this store.
        with self._lock:
            cursor = self._conn.execute(
                f"SELECT request_hash FROM {self._table} WHERE id = ?",
                (approval_id,),
            )
            row = cursor.fetchone()
            if row is None:
                return

            stored_hash = row[0]
            if stored_hash != request_hash:
                raise ValueError(
                    "Request hash mismatch. Approval bound to a different request."
                )

            with self._conn:
                self._conn.execute(
                    f"""UPDATE {self._table}
                        SET outcome = ?, decided_by = ?, decided_at = ?
                        WHERE id = ?""",
                    (outcome.value, decided_by, decided_at, approval_id),
                )

    async def get_request(self, approval_id: str) -> dict[str, Any] | None:
        """Return the stored request as a dict, or ``None`` if unknown."""
        return await asyncio.to_thread(self._get_request_sync, approval_id)

    def _get_request_sync(self, approval_id: str) -> dict[str, Any] | None:
        with self._lock:
            cursor = self._conn.cursor()
            cursor.row_factory = sqlite3.Row
            cursor.execute(
                f"SELECT * FROM {self._table} WHERE id = ?", (approval_id,)
            )
            row = cursor.fetchone()
        if row is None:
            return None

        d = dict(row)
        return {
            "id": d["id"],
            "agent": json.loads(d["agent_json"]),
            "intent": json.loads(d["intent_json"]),
            "tool_request": json.loads(d["tool_request_json"]),
            "request_hash": d["request_hash"],
            "reason": d["reason"],
            "expiry": d["expiry"],
            "outcome": ApprovalOutcome(d["outcome"]),
            "decided_by": d.get("decided_by"),
            "decided_at": d.get("decided_at"),
        }

    async def wait_for_decision(
        self, approval_id: str, timeout: float
    ) -> ApprovalOutcome:
        """Poll the database for a decision, with timeout.

        Returns ``ApprovalOutcome.TIMEOUT`` when the request is unknown, when
        its stored expiry has passed, or when ``timeout`` seconds elapse
        without a decision; otherwise returns the recorded outcome.
        """
        deadline = time.time() + timeout

        while time.time() < deadline:
            req = await self.get_request(approval_id)
            if req is None:
                return ApprovalOutcome.TIMEOUT

            if req["expiry"] < time.time():
                return ApprovalOutcome.TIMEOUT

            if req["outcome"] != ApprovalOutcome.DEFERRED:
                return req["outcome"]

            # Never sleep past the deadline: a full poll interval here could
            # overshoot the caller's timeout by up to ``poll_interval``.
            remaining = max(0.0, deadline - time.time())
            await asyncio.sleep(min(self._poll_interval, remaining))

        return ApprovalOutcome.TIMEOUT

    def close(self):
        """Close the database connection."""
        with self._lock:
            self._conn.close()
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Circuit breaker for AI agent tool calls.
|
|
2
|
+
|
|
3
|
+
Tracks consecutive failures per (tool, action) pair. After a configured
|
|
4
|
+
threshold, the circuit "opens" and all subsequent calls are auto-denied
|
|
5
|
+
for a cooldown period.
|
|
6
|
+
|
|
7
|
+
States:
|
|
8
|
+
CLOSED → normal operation, failures counted
|
|
9
|
+
OPEN → blocking all calls, waiting for cooldown
|
|
10
|
+
HALF_OPEN → cooldown expired, next call is a probe
|
|
11
|
+
- if probe succeeds → CLOSED
|
|
12
|
+
- if probe fails → OPEN (cooldown resets)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import time
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import Any, Protocol, runtime_checkable
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CircuitState(str, Enum):
|
|
22
|
+
CLOSED = "closed"
|
|
23
|
+
OPEN = "open"
|
|
24
|
+
HALF_OPEN = "half_open"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _CircuitEntry:
|
|
28
|
+
"""Internal state for a single circuit."""
|
|
29
|
+
|
|
30
|
+
__slots__ = ("state", "failure_count", "last_failure_at", "opened_at")
|
|
31
|
+
|
|
32
|
+
def __init__(self):
|
|
33
|
+
self.state: CircuitState = CircuitState.CLOSED
|
|
34
|
+
self.failure_count: int = 0
|
|
35
|
+
self.last_failure_at: float = 0.0
|
|
36
|
+
self.opened_at: float = 0.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@runtime_checkable
|
|
40
|
+
class CircuitBreaker(Protocol):
|
|
41
|
+
"""Protocol for circuit breaker backends."""
|
|
42
|
+
|
|
43
|
+
async def before_call(self, tool: str, action: str) -> tuple[bool, str | None]:
|
|
44
|
+
"""Check if a call is allowed.
|
|
45
|
+
|
|
46
|
+
Returns (allowed, reason). If not allowed, reason explains why.
|
|
47
|
+
"""
|
|
48
|
+
...
|
|
49
|
+
|
|
50
|
+
async def record_success(self, tool: str, action: str) -> None:
|
|
51
|
+
"""Record a successful tool execution."""
|
|
52
|
+
...
|
|
53
|
+
|
|
54
|
+
async def record_failure(self, tool: str, action: str) -> None:
|
|
55
|
+
"""Record a failed tool execution."""
|
|
56
|
+
...
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class InMemoryCircuitBreaker:
|
|
60
|
+
"""In-memory circuit breaker with configurable thresholds.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
failure_threshold: Number of consecutive failures before opening.
|
|
64
|
+
cooldown_seconds: Seconds to wait before attempting a probe (HALF_OPEN).
|
|
65
|
+
half_open_max_calls: Number of successful probes needed to close.
|
|
66
|
+
|
|
67
|
+
Usage:
|
|
68
|
+
breaker = InMemoryCircuitBreaker(failure_threshold=5, cooldown_seconds=60)
|
|
69
|
+
|
|
70
|
+
tower = ControlTower(
|
|
71
|
+
...,
|
|
72
|
+
circuit_breaker=breaker,
|
|
73
|
+
)
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
*,
|
|
79
|
+
failure_threshold: int = 5,
|
|
80
|
+
cooldown_seconds: float = 60.0,
|
|
81
|
+
half_open_max_calls: int = 1,
|
|
82
|
+
):
|
|
83
|
+
if failure_threshold < 1:
|
|
84
|
+
raise ValueError("failure_threshold must be >= 1")
|
|
85
|
+
if cooldown_seconds <= 0:
|
|
86
|
+
raise ValueError("cooldown_seconds must be > 0")
|
|
87
|
+
|
|
88
|
+
self.failure_threshold = failure_threshold
|
|
89
|
+
self.cooldown_seconds = cooldown_seconds
|
|
90
|
+
self.half_open_max_calls = half_open_max_calls
|
|
91
|
+
self._circuits: dict[str, _CircuitEntry] = {}
|
|
92
|
+
self._lock = asyncio.Lock()
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def _key(tool: str, action: str) -> str:
|
|
96
|
+
return f"{tool}:{action}"
|
|
97
|
+
|
|
98
|
+
async def before_call(
|
|
99
|
+
self, tool: str, action: str
|
|
100
|
+
) -> tuple[bool, str | None]:
|
|
101
|
+
"""Check if the circuit allows a call through."""
|
|
102
|
+
now = time.time()
|
|
103
|
+
key = self._key(tool, action)
|
|
104
|
+
|
|
105
|
+
async with self._lock:
|
|
106
|
+
entry = self._circuits.get(key)
|
|
107
|
+
if entry is None:
|
|
108
|
+
return True, None
|
|
109
|
+
|
|
110
|
+
if entry.state == CircuitState.CLOSED:
|
|
111
|
+
return True, None
|
|
112
|
+
|
|
113
|
+
if entry.state == CircuitState.OPEN:
|
|
114
|
+
# Check if cooldown has elapsed
|
|
115
|
+
elapsed = now - entry.opened_at
|
|
116
|
+
if elapsed >= self.cooldown_seconds:
|
|
117
|
+
# Transition to HALF_OPEN — allow a probe
|
|
118
|
+
entry.state = CircuitState.HALF_OPEN
|
|
119
|
+
entry.failure_count = 0
|
|
120
|
+
return True, None
|
|
121
|
+
else:
|
|
122
|
+
remaining = self.cooldown_seconds - elapsed
|
|
123
|
+
return False, (
|
|
124
|
+
f"Circuit OPEN for {tool}.{action}: "
|
|
125
|
+
f"{remaining:.1f}s remaining in cooldown "
|
|
126
|
+
f"(opened after {self.failure_threshold} consecutive failures)"
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
if entry.state == CircuitState.HALF_OPEN:
|
|
130
|
+
# Allow the probe call through
|
|
131
|
+
return True, None
|
|
132
|
+
|
|
133
|
+
return True, None
|
|
134
|
+
|
|
135
|
+
async def record_success(self, tool: str, action: str) -> None:
|
|
136
|
+
"""Record a success — close the circuit if in HALF_OPEN."""
|
|
137
|
+
key = self._key(tool, action)
|
|
138
|
+
|
|
139
|
+
async with self._lock:
|
|
140
|
+
entry = self._circuits.get(key)
|
|
141
|
+
if entry is None:
|
|
142
|
+
return
|
|
143
|
+
|
|
144
|
+
if entry.state == CircuitState.HALF_OPEN:
|
|
145
|
+
# Probe succeeded — close the circuit
|
|
146
|
+
entry.state = CircuitState.CLOSED
|
|
147
|
+
entry.failure_count = 0
|
|
148
|
+
|
|
149
|
+
elif entry.state == CircuitState.CLOSED:
|
|
150
|
+
# Reset failure count on success
|
|
151
|
+
entry.failure_count = 0
|
|
152
|
+
|
|
153
|
+
async def record_failure(self, tool: str, action: str) -> None:
|
|
154
|
+
"""Record a failure — may open the circuit."""
|
|
155
|
+
now = time.time()
|
|
156
|
+
key = self._key(tool, action)
|
|
157
|
+
|
|
158
|
+
async with self._lock:
|
|
159
|
+
entry = self._circuits.get(key)
|
|
160
|
+
if entry is None:
|
|
161
|
+
entry = _CircuitEntry()
|
|
162
|
+
self._circuits[key] = entry
|
|
163
|
+
|
|
164
|
+
entry.failure_count += 1
|
|
165
|
+
entry.last_failure_at = now
|
|
166
|
+
|
|
167
|
+
if entry.state == CircuitState.HALF_OPEN:
|
|
168
|
+
# Probe failed — back to OPEN
|
|
169
|
+
entry.state = CircuitState.OPEN
|
|
170
|
+
entry.opened_at = now
|
|
171
|
+
|
|
172
|
+
elif entry.state == CircuitState.CLOSED:
|
|
173
|
+
if entry.failure_count >= self.failure_threshold:
|
|
174
|
+
entry.state = CircuitState.OPEN
|
|
175
|
+
entry.opened_at = now
|
|
176
|
+
|
|
177
|
+
async def get_state(self, tool: str, action: str) -> CircuitState:
|
|
178
|
+
"""Get the current circuit state (for monitoring/testing)."""
|
|
179
|
+
key = self._key(tool, action)
|
|
180
|
+
async with self._lock:
|
|
181
|
+
entry = self._circuits.get(key)
|
|
182
|
+
if entry is None:
|
|
183
|
+
return CircuitState.CLOSED
|
|
184
|
+
return entry.state
|
|
185
|
+
|
|
186
|
+
async def reset(self, tool: str | None = None, action: str | None = None) -> None:
|
|
187
|
+
"""Reset circuit state. If tool/action given, reset only that circuit."""
|
|
188
|
+
async with self._lock:
|
|
189
|
+
if tool is not None and action is not None:
|
|
190
|
+
key = self._key(tool, action)
|
|
191
|
+
self._circuits.pop(key, None)
|
|
192
|
+
else:
|
|
193
|
+
self._circuits.clear()
|
|
194
|
+
|
|
195
|
+
async def get_all_states(self) -> dict[str, dict[str, Any]]:
|
|
196
|
+
"""Get all circuit states (for monitoring dashboards)."""
|
|
197
|
+
async with self._lock:
|
|
198
|
+
return {
|
|
199
|
+
key: {
|
|
200
|
+
"state": entry.state.value,
|
|
201
|
+
"failure_count": entry.failure_count,
|
|
202
|
+
"last_failure_at": entry.last_failure_at,
|
|
203
|
+
"opened_at": entry.opened_at,
|
|
204
|
+
}
|
|
205
|
+
for key, entry in self._circuits.items()
|
|
206
|
+
}
|