tollgate 1.0.5__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,458 @@
1
+ """SQLite-backed persistent stores for Tollgate.
2
+
3
+ Zero additional dependencies — uses Python's built-in ``sqlite3`` module.
4
+ Suitable for single-process deployments. For multi-process or multi-host
5
+ deployments, use the Redis backends instead.
6
+
7
+ Usage:
8
+
9
+ from tollgate.backends import SQLiteGrantStore, SQLiteApprovalStore
10
+
11
+ grant_store = SQLiteGrantStore("tollgate.db")
12
+ approval_store = SQLiteApprovalStore("tollgate.db")
13
+
14
+ tower = ControlTower(
15
+ ...,
16
+ grant_store=grant_store,
17
+ )
18
+ """
19
+
20
import asyncio
import json
import sqlite3
import threading
import time
import uuid
from pathlib import Path
from typing import Any

from ..types import AgentContext, ApprovalOutcome, Effect, Grant, ToolRequest
29
+
30
+
31
class SQLiteGrantStore:
    """SQLite-backed GrantStore implementation.

    Satisfies the ``GrantStore`` protocol. Uses WAL mode for concurrent
    reads. Every public coroutine runs its database work in the running
    loop's default thread executor so the event loop is never blocked.
    Those executor threads share a single ``sqlite3`` connection (opened
    with ``check_same_thread=False``), so all access to it is serialised
    through an internal ``threading.Lock``.

    Args:
        db_path: Path to the SQLite database file. Use ``:memory:`` for
            testing (non-persistent).
        table_name: Name of the grants table (default ``tollgate_grants``).
            The name is interpolated verbatim into SQL statements, so it
            must be a trusted, plain SQL identifier.
    """

    def __init__(
        self,
        db_path: str | Path = "tollgate_grants.db",
        *,
        table_name: str = "tollgate_grants",
    ):
        self._db_path = str(db_path)
        self._table = table_name
        # With check_same_thread=False the sqlite3 module leaves serialising
        # cross-thread use of the connection to the caller; every executor
        # job takes this lock before touching self._conn.
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA foreign_keys=ON")
            self._create_table()
        except BaseException:
            # Do not leak the connection when schema setup fails.
            self._conn.close()
            raise

    def _create_table(self):
        """Create the grants table and its indexes if they do not exist."""
        self._conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self._table} (
                id TEXT PRIMARY KEY,
                agent_id TEXT,
                effect TEXT,
                tool TEXT,
                action TEXT,
                resource_type TEXT,
                expires_at REAL NOT NULL,
                granted_by TEXT NOT NULL,
                created_at REAL NOT NULL,
                reason TEXT,
                usage_count INTEGER DEFAULT 0,
                revoked INTEGER DEFAULT 0
            )
        """)
        # Index for fast expiry cleanup and active-grant scans
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_expires
            ON {self._table} (expires_at, revoked)
        """)
        # Index for agent-scoped lookups
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_agent
            ON {self._table} (agent_id, revoked)
        """)
        self._conn.commit()

    def _grant_to_row(self, grant: Grant) -> dict[str, Any]:
        """Convert a Grant into the named SQL parameters used by ``_insert_grant``.

        ``usage_count`` and ``revoked`` are not included; they take their
        column defaults on insert.
        """
        return {
            "id": grant.id,
            "agent_id": grant.agent_id,
            "effect": grant.effect.value if grant.effect else None,
            "tool": grant.tool,
            "action": grant.action,
            "resource_type": grant.resource_type,
            "expires_at": grant.expires_at,
            "granted_by": grant.granted_by,
            "created_at": grant.created_at,
            "reason": grant.reason,
        }

    def _row_to_grant(self, row: sqlite3.Row | tuple) -> Grant:
        """Rebuild a Grant from a table row.

        Accepts an ``sqlite3.Row`` or a plain tuple in table column order.
        ``usage_count`` and ``revoked`` are not passed to ``Grant``.
        """
        if isinstance(row, sqlite3.Row):
            d = dict(row)
        else:
            cols = [
                "id", "agent_id", "effect", "tool", "action",
                "resource_type", "expires_at", "granted_by",
                "created_at", "reason", "usage_count", "revoked",
            ]
            d = dict(zip(cols, row))

        return Grant(
            id=d["id"],
            agent_id=d["agent_id"],
            effect=Effect(d["effect"]) if d["effect"] else None,
            tool=d["tool"],
            action=d["action"],
            resource_type=d["resource_type"],
            expires_at=d["expires_at"],
            granted_by=d["granted_by"],
            created_at=d["created_at"],
            reason=d["reason"],
        )

    @staticmethod
    def _tool_matches(pattern: str | None, tool: str) -> bool:
        """Return True if a grant's tool *pattern* covers *tool*.

        ``None`` matches any tool; a trailing ``*`` turns the rest of the
        pattern into a prefix; anything else must match exactly.
        """
        if pattern is None:
            return True
        if pattern.endswith("*"):
            return tool.startswith(pattern[:-1])
        return pattern == tool

    async def create_grant(self, grant: Grant) -> str:
        """Persist *grant* and return its id."""
        row = self._grant_to_row(grant)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._insert_grant, row)
        return grant.id

    def _insert_grant(self, row: dict[str, Any]):
        """Blocking insert of one grant row (runs in the executor)."""
        # NOTE(review): INSERT OR REPLACE deletes any existing row with this
        # id before inserting, so usage_count and revoked fall back to their
        # column defaults (0). Confirm that re-creating a revoked grant id is
        # meant to make it active again.
        with self._lock, self._conn:  # the connection context commits or rolls back
            self._conn.execute(
                f"""INSERT OR REPLACE INTO {self._table}
                (id, agent_id, effect, tool, action, resource_type,
                 expires_at, granted_by, created_at, reason)
                VALUES (:id, :agent_id, :effect, :tool, :action, :resource_type,
                        :expires_at, :granted_by, :created_at, :reason)""",
                row,
            )

    async def find_matching_grant(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> Grant | None:
        """Return the newest active grant covering the request, or None.

        Active means ``expires_at`` is in the future and the grant is not
        revoked. A NULL grant field acts as a wildcard, and ``tool`` also
        supports a trailing ``*`` prefix pattern. The matched grant's
        ``usage_count`` is incremented.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._find_matching_grant_sync, agent_ctx, tool_request
        )

    def _find_matching_grant_sync(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> Grant | None:
        """Blocking lookup + usage increment (runs in the executor)."""
        effect = tool_request.effect
        effect_value = effect.value if effect is not None else None
        # Equality wildcards are filtered in SQL so only candidate rows reach
        # Python; the tool prefix pattern is checked by _tool_matches.
        params = (
            time.time(),
            agent_ctx.agent_id,
            effect_value,
            tool_request.action,
            tool_request.resource_type,
        )
        with self._lock:
            cursor = self._conn.execute(
                f"""SELECT * FROM {self._table}
                WHERE expires_at > ? AND revoked = 0
                  AND (agent_id IS NULL OR agent_id = ?)
                  AND (effect IS NULL OR effect = ?)
                  AND (action IS NULL OR action = ?)
                  AND (resource_type IS NULL OR resource_type = ?)
                ORDER BY created_at DESC""",
                params,
            )
            cursor.row_factory = sqlite3.Row
            try:
                match = next(
                    (
                        row
                        for row in cursor
                        if self._tool_matches(row["tool"], tool_request.tool)
                    ),
                    None,
                )
            finally:
                cursor.close()

            if match is None:
                return None

            # Still under the lock, so the match and the increment are atomic
            # with respect to other store operations.
            with self._conn:
                self._conn.execute(
                    f"UPDATE {self._table} SET usage_count = usage_count + 1 WHERE id = ?",
                    (match["id"],),
                )
        return self._row_to_grant(match)

    async def revoke_grant(self, grant_id: str) -> bool:
        """Mark a grant revoked; return True only if it was not already revoked."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._revoke_sync, grant_id)

    def _revoke_sync(self, grant_id: str) -> bool:
        """Blocking revoke (runs in the executor)."""
        with self._lock, self._conn:
            cursor = self._conn.execute(
                f"UPDATE {self._table} SET revoked = 1 WHERE id = ? AND revoked = 0",
                (grant_id,),
            )
        return cursor.rowcount > 0

    async def list_active_grants(self, agent_id: str | None = None) -> list[Grant]:
        """Return unexpired, unrevoked grants, optionally only those stored for *agent_id*."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._list_active_sync, agent_id
        )

    def _list_active_sync(self, agent_id: str | None) -> list[Grant]:
        """Blocking listing (runs in the executor)."""
        query = f"SELECT * FROM {self._table} WHERE expires_at > ? AND revoked = 0"
        params: tuple[Any, ...] = (time.time(),)
        if agent_id is not None:
            query += " AND agent_id = ?"
            params += (agent_id,)
        with self._lock:
            cursor = self._conn.execute(query, params)
            cursor.row_factory = sqlite3.Row
            rows = cursor.fetchall()
        return [self._row_to_grant(row) for row in rows]

    async def cleanup_expired(self) -> int:
        """Delete grants whose ``expires_at`` has passed; return the number deleted."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._cleanup_sync)

    def _cleanup_sync(self) -> int:
        """Blocking delete of expired rows (runs in the executor)."""
        now = time.time()
        with self._lock, self._conn:
            cursor = self._conn.execute(
                f"DELETE FROM {self._table} WHERE expires_at <= ?", (now,)
            )
        return cursor.rowcount

    async def get_usage_count(self, grant_id: str) -> int:
        """Return how many times a grant has matched (0 if the id is unknown)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_usage_sync, grant_id)

    def _get_usage_sync(self, grant_id: str) -> int:
        """Blocking usage lookup (runs in the executor)."""
        with self._lock:
            row = self._conn.execute(
                f"SELECT usage_count FROM {self._table} WHERE id = ?", (grant_id,)
            ).fetchone()
        return row[0] if row else 0

    def close(self):
        """Close the database connection."""
        with self._lock:
            self._conn.close()
257
+
258
+
259
class SQLiteApprovalStore:
    """SQLite-backed ApprovalStore implementation.

    Satisfies the ``ApprovalStore`` ABC. Uses polling-based wait_for_decision
    since SQLite doesn't support notifications. Blocking database work runs
    in the running loop's default thread executor; because those threads
    share one ``sqlite3`` connection (``check_same_thread=False``), access
    to it is serialised through an internal ``threading.Lock``.

    Args:
        db_path: Path to the SQLite database file.
        table_name: Name of the approvals table. It is interpolated verbatim
            into SQL statements, so it must be a trusted, plain identifier.
        poll_interval: Seconds between polls when waiting for a decision.
    """

    def __init__(
        self,
        db_path: str | Path = "tollgate_approvals.db",
        *,
        table_name: str = "tollgate_approvals",
        poll_interval: float = 0.5,
    ):
        self._db_path = str(db_path)
        self._table = table_name
        self._poll_interval = poll_interval
        # sqlite3 leaves serialising cross-thread use of a connection opened
        # with check_same_thread=False to the caller.
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._create_table()
        except BaseException:
            # Do not leak the connection when schema setup fails.
            self._conn.close()
            raise

    def _create_table(self):
        """Create the approvals table and its request-hash index if missing."""
        self._conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self._table} (
                id TEXT PRIMARY KEY,
                agent_json TEXT NOT NULL,
                intent_json TEXT NOT NULL,
                tool_request_json TEXT NOT NULL,
                request_hash TEXT NOT NULL,
                reason TEXT NOT NULL,
                expiry REAL NOT NULL,
                outcome TEXT NOT NULL DEFAULT 'deferred',
                decided_by TEXT,
                decided_at REAL
            )
        """)
        self._conn.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_{self._table}_hash
            ON {self._table} (request_hash)
        """)
        self._conn.commit()

    async def create_request(
        self,
        agent_ctx: AgentContext,
        intent: Any,
        tool_request: ToolRequest,
        request_hash: str,
        reason: str,
        expiry: float,
    ) -> str:
        """Store a new pending approval request and return its generated id.

        ``agent_ctx``, ``intent`` and ``tool_request`` are stored as the JSON
        of their ``to_dict()`` results. The row's outcome starts at the
        column default ``'deferred'``.
        """
        approval_id = str(uuid.uuid4())
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            self._insert_request,
            approval_id,
            json.dumps(agent_ctx.to_dict()),
            json.dumps(intent.to_dict()),
            json.dumps(tool_request.to_dict()),
            request_hash,
            reason,
            expiry,
        )
        return approval_id

    def _insert_request(
        self,
        approval_id: str,
        agent_json: str,
        intent_json: str,
        tool_request_json: str,
        request_hash: str,
        reason: str,
        expiry: float,
    ):
        """Blocking insert of one approval row (runs in the executor)."""
        with self._lock, self._conn:  # the connection context commits or rolls back
            self._conn.execute(
                f"""INSERT INTO {self._table}
                (id, agent_json, intent_json, tool_request_json,
                 request_hash, reason, expiry)
                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    approval_id,
                    agent_json,
                    intent_json,
                    tool_request_json,
                    request_hash,
                    reason,
                    expiry,
                ),
            )

    async def set_decision(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ) -> None:
        """Record a decision for an approval request.

        Does nothing if *approval_id* is unknown.

        Raises:
            ValueError: if *request_hash* differs from the hash stored with
                the request.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            self._set_decision_sync,
            approval_id,
            outcome,
            decided_by,
            decided_at,
            request_hash,
        )

    def _set_decision_sync(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ):
        """Blocking hash check + update (runs in the executor)."""
        # The lock keeps the hash check and the update atomic with respect to
        # other store operations on this connection.
        with self._lock, self._conn:
            row = self._conn.execute(
                f"SELECT request_hash FROM {self._table} WHERE id = ?",
                (approval_id,),
            ).fetchone()
            if row is None:
                return

            stored_hash = row[0]
            if stored_hash != request_hash:
                raise ValueError(
                    "Request hash mismatch. Approval bound to a different request."
                )

            self._conn.execute(
                f"""UPDATE {self._table}
                SET outcome = ?, decided_by = ?, decided_at = ?
                WHERE id = ?""",
                (outcome.value, decided_by, decided_at, approval_id),
            )

    async def get_request(self, approval_id: str) -> dict[str, Any] | None:
        """Return the stored request (JSON fields decoded), or None if unknown."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._get_request_sync, approval_id
        )

    def _get_request_sync(self, approval_id: str) -> dict[str, Any] | None:
        """Blocking fetch + decode (runs in the executor)."""
        with self._lock:
            cursor = self._conn.execute(
                f"SELECT * FROM {self._table} WHERE id = ?", (approval_id,)
            )
            cursor.row_factory = sqlite3.Row
            row = cursor.fetchone()
        if row is None:
            return None

        d = dict(row)
        return {
            "id": d["id"],
            "agent": json.loads(d["agent_json"]),
            "intent": json.loads(d["intent_json"]),
            "tool_request": json.loads(d["tool_request_json"]),
            "request_hash": d["request_hash"],
            "reason": d["reason"],
            "expiry": d["expiry"],
            "outcome": ApprovalOutcome(d["outcome"]),
            "decided_by": d.get("decided_by"),
            "decided_at": d.get("decided_at"),
        }

    async def wait_for_decision(
        self, approval_id: str, timeout: float
    ) -> ApprovalOutcome:
        """Poll the database for a decision, with timeout.

        Returns the stored outcome as soon as it is no longer ``DEFERRED``.
        Returns ``ApprovalOutcome.TIMEOUT`` when *timeout* seconds pass, when
        the request does not exist, or when its ``expiry`` (compared with
        ``time.time()``) has passed.
        """
        # The wait uses the monotonic clock so wall-clock adjustments cannot
        # stretch or cut it short; the stored expiry is compared against
        # time.time() below.
        deadline = time.monotonic() + timeout
        while (remaining := deadline - time.monotonic()) > 0:
            req = await self.get_request(approval_id)
            if req is None:
                return ApprovalOutcome.TIMEOUT

            if req["expiry"] < time.time():
                return ApprovalOutcome.TIMEOUT

            if req["outcome"] != ApprovalOutcome.DEFERRED:
                return req["outcome"]

            # Never sleep past the deadline.
            await asyncio.sleep(min(self._poll_interval, remaining))

        return ApprovalOutcome.TIMEOUT

    def close(self):
        """Close the database connection."""
        with self._lock:
            self._conn.close()
@@ -0,0 +1,206 @@
1
+ """Circuit breaker for AI agent tool calls.
2
+
3
+ Tracks consecutive failures per (tool, action) pair. After a configured
4
+ threshold, the circuit "opens" and all subsequent calls are auto-denied
5
+ for a cooldown period.
6
+
7
+ States:
8
+ CLOSED → normal operation, failures counted
9
+ OPEN → blocking all calls, waiting for cooldown
10
+ HALF_OPEN → cooldown expired, next call is a probe
11
+ - if probe succeeds → CLOSED
12
+ - if probe fails → OPEN (cooldown resets)
13
+ """
14
+
15
+ import asyncio
16
+ import time
17
+ from enum import Enum
18
+ from typing import Any, Protocol, runtime_checkable
19
+
20
+
21
+ class CircuitState(str, Enum):
22
+ CLOSED = "closed"
23
+ OPEN = "open"
24
+ HALF_OPEN = "half_open"
25
+
26
+
27
+ class _CircuitEntry:
28
+ """Internal state for a single circuit."""
29
+
30
+ __slots__ = ("state", "failure_count", "last_failure_at", "opened_at")
31
+
32
+ def __init__(self):
33
+ self.state: CircuitState = CircuitState.CLOSED
34
+ self.failure_count: int = 0
35
+ self.last_failure_at: float = 0.0
36
+ self.opened_at: float = 0.0
37
+
38
+
39
+ @runtime_checkable
40
+ class CircuitBreaker(Protocol):
41
+ """Protocol for circuit breaker backends."""
42
+
43
+ async def before_call(self, tool: str, action: str) -> tuple[bool, str | None]:
44
+ """Check if a call is allowed.
45
+
46
+ Returns (allowed, reason). If not allowed, reason explains why.
47
+ """
48
+ ...
49
+
50
+ async def record_success(self, tool: str, action: str) -> None:
51
+ """Record a successful tool execution."""
52
+ ...
53
+
54
+ async def record_failure(self, tool: str, action: str) -> None:
55
+ """Record a failed tool execution."""
56
+ ...
57
+
58
+
59
+ class InMemoryCircuitBreaker:
60
+ """In-memory circuit breaker with configurable thresholds.
61
+
62
+ Args:
63
+ failure_threshold: Number of consecutive failures before opening.
64
+ cooldown_seconds: Seconds to wait before attempting a probe (HALF_OPEN).
65
+ half_open_max_calls: Number of successful probes needed to close.
66
+
67
+ Usage:
68
+ breaker = InMemoryCircuitBreaker(failure_threshold=5, cooldown_seconds=60)
69
+
70
+ tower = ControlTower(
71
+ ...,
72
+ circuit_breaker=breaker,
73
+ )
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ *,
79
+ failure_threshold: int = 5,
80
+ cooldown_seconds: float = 60.0,
81
+ half_open_max_calls: int = 1,
82
+ ):
83
+ if failure_threshold < 1:
84
+ raise ValueError("failure_threshold must be >= 1")
85
+ if cooldown_seconds <= 0:
86
+ raise ValueError("cooldown_seconds must be > 0")
87
+
88
+ self.failure_threshold = failure_threshold
89
+ self.cooldown_seconds = cooldown_seconds
90
+ self.half_open_max_calls = half_open_max_calls
91
+ self._circuits: dict[str, _CircuitEntry] = {}
92
+ self._lock = asyncio.Lock()
93
+
94
+ @staticmethod
95
+ def _key(tool: str, action: str) -> str:
96
+ return f"{tool}:{action}"
97
+
98
+ async def before_call(
99
+ self, tool: str, action: str
100
+ ) -> tuple[bool, str | None]:
101
+ """Check if the circuit allows a call through."""
102
+ now = time.time()
103
+ key = self._key(tool, action)
104
+
105
+ async with self._lock:
106
+ entry = self._circuits.get(key)
107
+ if entry is None:
108
+ return True, None
109
+
110
+ if entry.state == CircuitState.CLOSED:
111
+ return True, None
112
+
113
+ if entry.state == CircuitState.OPEN:
114
+ # Check if cooldown has elapsed
115
+ elapsed = now - entry.opened_at
116
+ if elapsed >= self.cooldown_seconds:
117
+ # Transition to HALF_OPEN — allow a probe
118
+ entry.state = CircuitState.HALF_OPEN
119
+ entry.failure_count = 0
120
+ return True, None
121
+ else:
122
+ remaining = self.cooldown_seconds - elapsed
123
+ return False, (
124
+ f"Circuit OPEN for {tool}.{action}: "
125
+ f"{remaining:.1f}s remaining in cooldown "
126
+ f"(opened after {self.failure_threshold} consecutive failures)"
127
+ )
128
+
129
+ if entry.state == CircuitState.HALF_OPEN:
130
+ # Allow the probe call through
131
+ return True, None
132
+
133
+ return True, None
134
+
135
+ async def record_success(self, tool: str, action: str) -> None:
136
+ """Record a success — close the circuit if in HALF_OPEN."""
137
+ key = self._key(tool, action)
138
+
139
+ async with self._lock:
140
+ entry = self._circuits.get(key)
141
+ if entry is None:
142
+ return
143
+
144
+ if entry.state == CircuitState.HALF_OPEN:
145
+ # Probe succeeded — close the circuit
146
+ entry.state = CircuitState.CLOSED
147
+ entry.failure_count = 0
148
+
149
+ elif entry.state == CircuitState.CLOSED:
150
+ # Reset failure count on success
151
+ entry.failure_count = 0
152
+
153
+ async def record_failure(self, tool: str, action: str) -> None:
154
+ """Record a failure — may open the circuit."""
155
+ now = time.time()
156
+ key = self._key(tool, action)
157
+
158
+ async with self._lock:
159
+ entry = self._circuits.get(key)
160
+ if entry is None:
161
+ entry = _CircuitEntry()
162
+ self._circuits[key] = entry
163
+
164
+ entry.failure_count += 1
165
+ entry.last_failure_at = now
166
+
167
+ if entry.state == CircuitState.HALF_OPEN:
168
+ # Probe failed — back to OPEN
169
+ entry.state = CircuitState.OPEN
170
+ entry.opened_at = now
171
+
172
+ elif entry.state == CircuitState.CLOSED:
173
+ if entry.failure_count >= self.failure_threshold:
174
+ entry.state = CircuitState.OPEN
175
+ entry.opened_at = now
176
+
177
+ async def get_state(self, tool: str, action: str) -> CircuitState:
178
+ """Get the current circuit state (for monitoring/testing)."""
179
+ key = self._key(tool, action)
180
+ async with self._lock:
181
+ entry = self._circuits.get(key)
182
+ if entry is None:
183
+ return CircuitState.CLOSED
184
+ return entry.state
185
+
186
+ async def reset(self, tool: str | None = None, action: str | None = None) -> None:
187
+ """Reset circuit state. If tool/action given, reset only that circuit."""
188
+ async with self._lock:
189
+ if tool is not None and action is not None:
190
+ key = self._key(tool, action)
191
+ self._circuits.pop(key, None)
192
+ else:
193
+ self._circuits.clear()
194
+
195
+ async def get_all_states(self) -> dict[str, dict[str, Any]]:
196
+ """Get all circuit states (for monitoring dashboards)."""
197
+ async with self._lock:
198
+ return {
199
+ key: {
200
+ "state": entry.state.value,
201
+ "failure_count": entry.failure_count,
202
+ "last_failure_at": entry.last_failure_at,
203
+ "opened_at": entry.opened_at,
204
+ }
205
+ for key, entry in self._circuits.items()
206
+ }