tollgate 1.0.4__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,411 @@
1
+ """Redis-backed persistent stores for Tollgate.
2
+
3
+ Requires the ``redis`` package: ``pip install redis[hiredis]``
4
+
5
+ Suitable for multi-process and multi-host deployments. Uses Redis hashes
6
+ for grant storage and pub/sub for real-time approval notifications.
7
+
8
+ Usage:
9
+
10
+ from tollgate.backends import RedisGrantStore, RedisApprovalStore
11
+
12
+ grant_store = RedisGrantStore(redis_url="redis://localhost:6379/0")
13
+ approval_store = RedisApprovalStore(redis_url="redis://localhost:6379/0")
14
+
15
+ tower = ControlTower(
16
+ ...,
17
+ grant_store=grant_store,
18
+ )
19
+ """
20
+
21
+ import asyncio
22
+ import json
23
+ import time
24
+ import uuid
25
+ from typing import Any
26
+
27
+ try:
28
+ import redis.asyncio as aioredis
29
+ except ImportError:
30
+ raise ImportError(
31
+ "Redis backend requires the 'redis' package. "
32
+ "Install it with: pip install redis[hiredis]"
33
+ )
34
+
35
+ from ..types import AgentContext, ApprovalOutcome, Effect, Grant, ToolRequest
36
+
37
+
38
class RedisGrantStore:
    """Redis-backed GrantStore implementation.

    Storage layout:

    * each grant is a Redis hash at ``<key_prefix><grant_id>`` whose TTL is
      derived from ``grant.expires_at``;
    * the set ``<key_prefix>__index__`` records grant ids so the store can
      enumerate them. The set itself has no TTL; ids whose hash has vanished
      are pruned lazily by the read methods and by ``cleanup_expired``.

    Satisfies the ``GrantStore`` protocol.

    Args:
        redis_url: Redis connection URL (e.g., ``redis://localhost:6379/0``).
            The client built from it uses ``decode_responses=True``.
        redis_client: Pre-configured async Redis client (alternative to URL);
            takes precedence over ``redis_url``. NOTE(review): hash values are
            compared against ``str`` literals, so this client presumably needs
            ``decode_responses=True`` as well -- confirm.
        key_prefix: Prefix for all Redis keys (default ``tollgate:grant:``).

    Raises:
        ValueError: If neither ``redis_url`` nor ``redis_client`` is given.
    """

    # Fields ``_dict_to_grant`` reads unconditionally. A hash lacking any of
    # them is not a complete grant record (e.g. a stray hash re-created by
    # HINCRBY/HSET after the real key expired) and is treated as stale.
    _REQUIRED_FIELDS = (
        "id",
        "agent_id",
        "effect",
        "tool",
        "action",
        "resource_type",
        "expires_at",
        "granted_by",
        "created_at",
        "reason",
    )

    def __init__(
        self,
        redis_url: str | None = None,
        *,
        redis_client: Any | None = None,
        key_prefix: str = "tollgate:grant:",
    ):
        if redis_client is not None:
            self._redis = redis_client
        elif redis_url is not None:
            self._redis = aioredis.from_url(redis_url, decode_responses=True)
        else:
            raise ValueError("Either redis_url or redis_client must be provided")

        self._prefix = key_prefix
        self._index_key = f"{key_prefix}__index__"

    def _grant_key(self, grant_id: str) -> str:
        """Return the Redis key of the hash that stores *grant_id*."""
        return f"{self._prefix}{grant_id}"

    def _grant_to_dict(self, grant: Grant) -> dict[str, str]:
        """Serialize a Grant to a flat ``str -> str`` mapping for ``HSET``.

        Optional fields that are ``None`` are stored as ``""``. The
        bookkeeping fields ``usage_count`` and ``revoked`` start at ``"0"``.
        """
        return {
            "id": grant.id,
            "agent_id": grant.agent_id or "",
            "effect": grant.effect.value if grant.effect else "",
            "tool": grant.tool or "",
            "action": grant.action or "",
            "resource_type": grant.resource_type or "",
            "expires_at": str(grant.expires_at),
            "granted_by": grant.granted_by,
            "created_at": str(grant.created_at),
            "reason": grant.reason or "",
            "usage_count": "0",
            "revoked": "0",
        }

    def _dict_to_grant(self, d: dict[str, str]) -> Grant:
        """Deserialize a hash written by ``_grant_to_dict`` back to a Grant.

        ``""`` is mapped back to ``None``. ``usage_count`` and ``revoked`` are
        not carried over to the Grant.
        """
        return Grant(
            id=d["id"],
            agent_id=d["agent_id"] or None,
            effect=Effect(d["effect"]) if d["effect"] else None,
            tool=d["tool"] or None,
            action=d["action"] or None,
            resource_type=d["resource_type"] or None,
            expires_at=float(d["expires_at"]),
            granted_by=d["granted_by"],
            created_at=float(d["created_at"]),
            reason=d["reason"] or None,
        )

    async def _fetch_indexed(self) -> list[tuple[str, str, dict[str, str] | None]]:
        """Load every indexed grant hash in one pipelined round trip.

        Returns:
            One ``(grant_id, key, data)`` tuple per index entry. ``data`` is
            ``None`` when the hash is gone (TTL expiry) or incomplete (missing
            any of ``_REQUIRED_FIELDS``).
        """
        grant_ids = list(await self._redis.smembers(self._index_key))
        if not grant_ids:
            return []

        keys = [self._grant_key(grant_id) for grant_id in grant_ids]
        pipe = self._redis.pipeline()
        for key in keys:
            pipe.hgetall(key)
        hashes = await pipe.execute()

        entries: list[tuple[str, str, dict[str, str] | None]] = []
        for grant_id, key, data in zip(grant_ids, keys, hashes):
            complete = bool(data) and all(
                field in data for field in self._REQUIRED_FIELDS
            )
            entries.append((grant_id, key, data if complete else None))
        return entries

    async def _forget(
        self, entries: list[tuple[str, str, dict[str, str] | None]]
    ) -> None:
        """Delete the hashes of *entries* and drop their ids from the index."""
        if not entries:
            return
        await self._redis.delete(*(key for _, key, _ in entries))
        await self._redis.srem(self._index_key, *(gid for gid, _, _ in entries))

    @staticmethod
    def _matches(
        data: dict[str, str],
        agent_ctx: AgentContext,
        tool_request: ToolRequest,
        now: float,
    ) -> bool:
        """Return True if the grant hash *data* is live and covers the request.

        An empty grant field acts as a wildcard. A ``tool`` ending in ``*``
        matches any tool name that starts with the part before the ``*``.
        """
        if data.get("revoked") == "1":
            return False
        if float(data["expires_at"]) <= now:
            return False
        if data["agent_id"] and data["agent_id"] != agent_ctx.agent_id:
            return False
        if data["effect"] and data["effect"] != tool_request.effect.value:
            return False

        pattern = data["tool"]
        if pattern:
            if pattern.endswith("*"):
                if not tool_request.tool.startswith(pattern[:-1]):
                    return False
            elif pattern != tool_request.tool:
                return False

        if data["action"] and data["action"] != tool_request.action:
            return False
        if (
            data["resource_type"]
            and data["resource_type"] != tool_request.resource_type
        ):
            return False
        return True

    async def create_grant(self, grant: Grant) -> str:
        """Persist *grant*, set its TTL, and register it in the index.

        The TTL is rounded *up* to whole seconds, so Redis never evicts the
        hash before ``grant.expires_at``. Truncating instead would evict a
        grant up to one second early. Grants that are already expired still
        get the 1-second minimum TTL.

        Returns:
            The grant's id.
        """
        import math  # local import: only needed for the TTL rounding below

        key = self._grant_key(grant.id)
        ttl = max(1, math.ceil(grant.expires_at - time.time()))

        # Queue the hash, its TTL and the index entry in a single pipeline.
        pipe = self._redis.pipeline()
        pipe.hset(key, mapping=self._grant_to_dict(grant))
        pipe.expire(key, ttl)
        pipe.sadd(self._index_key, grant.id)
        await pipe.execute()

        return grant.id

    async def find_matching_grant(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> Grant | None:
        """Return a live grant covering *tool_request*, or ``None``.

        Revoked and expired grants are skipped. Index entries whose hash is
        gone or incomplete are pruned. On a match, the grant's
        ``usage_count`` is incremented. When several grants match, which one
        is returned is unspecified (it follows Redis set iteration order).
        """
        entries = await self._fetch_indexed()
        await self._forget([entry for entry in entries if entry[2] is None])

        now = time.time()
        for _grant_id, key, data in entries:
            if data is None or not self._matches(data, agent_ctx, tool_request, now):
                continue
            # NOTE: read-then-increment is not atomic. If the grant expires in
            # between, HINCRBY re-creates a stray hash holding only
            # ``usage_count``; ``_fetch_indexed`` reports such hashes as stale
            # so they are pruned instead of raising KeyError.
            await self._redis.hincrby(key, "usage_count", 1)
            return self._dict_to_grant(data)

        return None

    async def revoke_grant(self, grant_id: str) -> bool:
        """Mark *grant_id* as revoked.

        The hash is kept, with ``revoked`` set to ``"1"``, until its TTL runs
        out. Readers skip revoked grants.

        Returns:
            ``True`` if the grant existed and was revoked, ``False`` otherwise.
        """
        key = self._grant_key(grant_id)
        if not await self._redis.exists(key):
            return False

        added = await self._redis.hset(key, "revoked", "1")
        if added and not await self._redis.hexists(key, "id"):
            # Every stored grant already has a ``revoked`` field, so HSET only
            # *adds* one when the key expired after the EXISTS check. In that
            # case HSET has just created a stray, TTL-less hash; undo it.
            await self._redis.delete(key)
            return False
        return True

    async def list_active_grants(self, agent_id: str | None = None) -> list[Grant]:
        """Return all non-revoked, unexpired grants.

        Args:
            agent_id: If given, only grants whose stored ``agent_id`` equals
                it exactly are returned. Grants with no agent restriction
                are therefore excluded.
        """
        entries = await self._fetch_indexed()
        await self._forget([entry for entry in entries if entry[2] is None])

        now = time.time()
        return [
            self._dict_to_grant(data)
            for _grant_id, _key, data in entries
            if data is not None
            and data.get("revoked") != "1"
            and float(data["expires_at"]) > now
            and (agent_id is None or data["agent_id"] == agent_id)
        ]

    async def cleanup_expired(self) -> int:
        """Delete expired or missing grants and drop them from the index.

        Returns:
            The number of index entries removed. This includes entries whose
            hash had already been evicted by Redis or was incomplete.
        """
        entries = await self._fetch_indexed()
        now = time.time()
        doomed = [
            entry
            for entry in entries
            if entry[2] is None or float(entry[2]["expires_at"]) <= now
        ]
        await self._forget(doomed)
        return len(doomed)

    async def get_usage_count(self, grant_id: str) -> int:
        """Return how many times *grant_id* has matched (0 if unknown)."""
        key = self._grant_key(grant_id)
        count = await self._redis.hget(key, "usage_count")
        return int(count) if count else 0

    async def close(self) -> None:
        """Close the underlying Redis client.

        This also closes a caller-supplied ``redis_client``. It prefers
        ``aclose()``, which newer redis-py releases provide while deprecating
        ``close()``, and falls back to ``close()`` for older clients.
        """
        closer = getattr(self._redis, "aclose", None) or self._redis.close
        await closer()
232
+
233
+
234
class RedisApprovalStore:
    """Redis-backed ApprovalStore implementation.

    Each approval request is a Redis hash at ``<key_prefix><approval_id>``
    that Redis evicts 60 seconds after the request's ``expiry``. A decision
    is written to the hash and also published on
    ``<channel_prefix><approval_id>``, so ``wait_for_decision`` is woken by
    pub/sub rather than by polling. While waiting, it still re-reads the hash
    about once per second in case a notification was missed.

    Args:
        redis_url: Redis connection URL. The client built from it uses
            ``decode_responses=True``.
        redis_client: Pre-configured async Redis client; takes precedence over
            ``redis_url``. NOTE(review): values are compared against ``str``,
            so this client presumably needs ``decode_responses=True`` -- confirm.
        key_prefix: Prefix for all Redis keys.
        channel_prefix: Prefix for pub/sub channels.

    Raises:
        ValueError: If neither ``redis_url`` nor ``redis_client`` is given.
    """

    # Fields ``get_request`` reads unconditionally. A hash lacking any of them
    # is not a complete request record (e.g. a stray hash re-created by HSET
    # after the real key expired) and is treated as "not found".
    _REQUIRED_FIELDS = (
        "id",
        "agent_json",
        "intent_json",
        "tool_request_json",
        "request_hash",
        "reason",
        "expiry",
        "outcome",
    )

    def __init__(
        self,
        redis_url: str | None = None,
        *,
        redis_client: Any | None = None,
        key_prefix: str = "tollgate:approval:",
        channel_prefix: str = "tollgate:approval_notify:",
    ):
        if redis_client is not None:
            self._redis = redis_client
        elif redis_url is not None:
            self._redis = aioredis.from_url(redis_url, decode_responses=True)
        else:
            raise ValueError("Either redis_url or redis_client must be provided")

        self._prefix = key_prefix
        self._channel_prefix = channel_prefix

    def _request_key(self, approval_id: str) -> str:
        """Return the Redis key of the hash storing *approval_id*."""
        return f"{self._prefix}{approval_id}"

    def _channel_key(self, approval_id: str) -> str:
        """Return the pub/sub channel that announces decisions for *approval_id*."""
        return f"{self._channel_prefix}{approval_id}"

    async def create_request(
        self,
        agent_ctx: AgentContext,
        intent: Any,
        tool_request: ToolRequest,
        request_hash: str,
        reason: str,
        expiry: float,
    ) -> str:
        """Store a new pending (``DEFERRED``) approval request.

        Args:
            agent_ctx: Requesting agent; stored as JSON via ``to_dict()``.
            intent: Any object exposing ``to_dict()``; stored as JSON.
            tool_request: The request awaiting approval; stored as JSON.
            request_hash: Binding hash that ``set_decision`` must match.
            reason: Why approval is needed.
            expiry: End of the approval window, as a wall-clock timestamp
                comparable with ``time.time()``.

        Returns:
            A freshly generated UUID4 string identifying the request.
        """
        approval_id = str(uuid.uuid4())
        key = self._request_key(approval_id)

        data = {
            "id": approval_id,
            "agent_json": json.dumps(agent_ctx.to_dict()),
            "intent_json": json.dumps(intent.to_dict()),
            "tool_request_json": json.dumps(tool_request.to_dict()),
            "request_hash": request_hash,
            "reason": reason,
            "expiry": str(expiry),
            "outcome": ApprovalOutcome.DEFERRED.value,
            "decided_by": "",
            "decided_at": "",
        }

        pipe = self._redis.pipeline()
        pipe.hset(key, mapping=data)
        # Keep the record for the approval window plus a 60s buffer.
        ttl = max(1, int(expiry - time.time()) + 60)
        pipe.expire(key, ttl)
        await pipe.execute()

        return approval_id

    async def set_decision(
        self,
        approval_id: str,
        outcome: ApprovalOutcome,
        decided_by: str,
        decided_at: float,
        request_hash: str,
    ) -> None:
        """Record a decision for *approval_id* and notify waiters.

        Does nothing if the request does not exist (never created, or already
        evicted). An existing decision is overwritten.

        Raises:
            ValueError: If *request_hash* differs from the hash stored at
                creation, i.e. the approval is bound to a different request.
        """
        key = self._request_key(approval_id)

        # Verify request hash (replay protection)
        stored_hash = await self._redis.hget(key, "request_hash")
        if stored_hash is None:
            return  # Request not found

        if stored_hash != request_hash:
            raise ValueError(
                "Request hash mismatch. Approval bound to a different request."
            )

        pipe = self._redis.pipeline()
        pipe.hset(key, mapping={
            "outcome": outcome.value,
            "decided_by": decided_by,
            "decided_at": str(decided_at),
        })
        # Publish notification for waiters
        pipe.publish(self._channel_key(approval_id), outcome.value)
        results = await pipe.execute()

        if results and results[0]:
            # A live request hash already holds all three decision fields, so
            # HSET only reports newly *added* fields when the key expired
            # between HGET and HSET. HSET then created a stray hash with no
            # TTL; remove it so it neither leaks nor looks like a request.
            await self._redis.delete(key)

    async def get_request(self, approval_id: str) -> dict[str, Any] | None:
        """Return the decoded request record, or ``None`` if not found.

        A hash that is missing or incomplete is reported as not found.
        ``decided_by``/``decided_at`` are ``None`` until a decision is made.
        """
        key = self._request_key(approval_id)
        data = await self._redis.hgetall(key)

        if not data or any(field not in data for field in self._REQUIRED_FIELDS):
            return None

        return {
            "id": data["id"],
            "agent": json.loads(data["agent_json"]),
            "intent": json.loads(data["intent_json"]),
            "tool_request": json.loads(data["tool_request_json"]),
            "request_hash": data["request_hash"],
            "reason": data["reason"],
            "expiry": float(data["expiry"]),
            "outcome": ApprovalOutcome(data["outcome"]),
            "decided_by": data.get("decided_by") or None,
            "decided_at": float(data["decided_at"]) if data.get("decided_at") else None,
        }

    async def wait_for_decision(
        self, approval_id: str, timeout: float
    ) -> ApprovalOutcome:
        """Wait up to *timeout* seconds for a decision on *approval_id*.

        Waiting is driven by Redis pub/sub. The stored state is re-read right
        after subscribing and after every quiet wait slice (at most ~1 s),
        so a notification missed in the gap before ``SUBSCRIBE`` completes
        (pub/sub does not replay it) still gets picked up.

        Returns:
            The decided outcome, or ``ApprovalOutcome.TIMEOUT`` if the request
            is unknown, its approval window has passed, or no decision arrives
            before *timeout* elapses.
        """
        # Fast path: no subscription needed for unknown/expired/decided requests.
        req = await self.get_request(approval_id)
        if req is None:
            return ApprovalOutcome.TIMEOUT

        if req["expiry"] < time.time():
            return ApprovalOutcome.TIMEOUT

        if req["outcome"] != ApprovalOutcome.DEFERRED:
            return req["outcome"]

        channel = self._channel_key(approval_id)
        pubsub = self._redis.pubsub()

        try:
            await pubsub.subscribe(channel)

            # Monotonic clock: the local deadline must not jump with wall-clock changes.
            deadline = time.monotonic() + timeout
            while True:
                # Re-read the stored decision. The first pass closes the race
                # with the fast-path check above; later passes are the
                # fallback for any other missed notification.
                req = await self.get_request(approval_id)
                if req is not None and req["outcome"] != ApprovalOutcome.DEFERRED:
                    return req["outcome"]

                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return ApprovalOutcome.TIMEOUT

                # get_message returns None (it does not raise) when nothing
                # arrives within its own timeout; wait_for is only a safety net.
                wait_slice = min(remaining, 1.0)
                try:
                    message = await asyncio.wait_for(
                        pubsub.get_message(
                            ignore_subscribe_messages=True, timeout=wait_slice
                        ),
                        timeout=wait_slice + 1.0,
                    )
                except asyncio.TimeoutError:
                    message = None

                if message and message["type"] == "message":
                    try:
                        return ApprovalOutcome(message["data"])
                    except ValueError:
                        pass  # Not an outcome value; fall back to the stored state.

        finally:
            try:
                await pubsub.unsubscribe(channel)
            finally:
                # Release the pubsub connection even if UNSUBSCRIBE failed.
                closer = getattr(pubsub, "aclose", None) or pubsub.close
                await closer()

    async def close(self) -> None:
        """Close the underlying Redis client.

        This also closes a caller-supplied ``redis_client``. It prefers
        ``aclose()``, which newer redis-py releases provide while deprecating
        ``close()``, and falls back to ``close()`` for older clients.
        """
        closer = getattr(self._redis, "aclose", None) or self._redis.close
        await closer()