tollgate-1.0.5-py3-none-any.whl → tollgate-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tollgate/__init__.py +34 -2
- tollgate/anomaly_detector.py +396 -0
- tollgate/audit.py +90 -1
- tollgate/backends/__init__.py +37 -0
- tollgate/backends/redis_store.py +411 -0
- tollgate/backends/sqlite_store.py +458 -0
- tollgate/circuit_breaker.py +206 -0
- tollgate/context_monitor.py +292 -0
- tollgate/exceptions.py +20 -0
- tollgate/manifest_signing.py +90 -0
- tollgate/network_guard.py +114 -0
- tollgate/policy.py +37 -0
- tollgate/policy_testing.py +360 -0
- tollgate/rate_limiter.py +162 -0
- tollgate/registry.py +225 -2
- tollgate/tower.py +182 -11
- tollgate/types.py +21 -1
- tollgate/verification.py +81 -0
- tollgate-1.4.0.dist-info/METADATA +393 -0
- tollgate-1.4.0.dist-info/RECORD +33 -0
- tollgate-1.4.0.dist-info/entry_points.txt +2 -0
- tollgate-1.0.5.dist-info/METADATA +0 -144
- tollgate-1.0.5.dist-info/RECORD +0 -21
- {tollgate-1.0.5.dist-info → tollgate-1.4.0.dist-info}/WHEEL +0 -0
- {tollgate-1.0.5.dist-info → tollgate-1.4.0.dist-info}/licenses/LICENSE +0 -0
tollgate/policy_testing.py
ADDED
@@ -0,0 +1,360 @@
"""Policy testing framework for Tollgate.

Enables declarative scenario-based testing of Tollgate policies to prevent
regressions in CI. Test scenarios are defined in YAML and run against a
policy evaluator.

Usage:

    # From Python:
    from tollgate.policy_testing import PolicyTestRunner

    runner = PolicyTestRunner("policy.yaml", "test_scenarios.yaml")
    results = runner.run()
    assert results.all_passed

    # From CLI:
    tollgate test-policy policy.yaml --scenarios test_scenarios.yaml

Scenario file format:

    scenarios:
      - name: "Allow read operations"
        description: "Read effects should be allowed for trusted agents"
        agent:
          agent_id: "agent-1"
          version: "1.0"
          owner: "team-a"
        intent:
          action: "fetch_data"
          reason: "Customer request"
        tool_request:
          tool: "api:fetch"
          action: "get"
          resource_type: "url"
          effect: "read"
          params: {}
          manifest_version: "1.0.0"
        expected:
          decision: "ALLOW"           # Required: ALLOW, ASK, or DENY
          reason_contains: "Rule"     # Optional: substring match on reason
          policy_id: "allow_read"     # Optional: exact match on policy_id
"""

import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from .policy import YamlPolicyEvaluator
from .types import AgentContext, DecisionType, Effect, Intent, ToolRequest


@dataclass
class ScenarioResult:
    """Result of a single test scenario."""

    name: str
    passed: bool
    expected_decision: str
    actual_decision: str
    expected_reason_contains: str | None = None
    actual_reason: str | None = None
    expected_policy_id: str | None = None
    actual_policy_id: str | None = None
    errors: list[str] = field(default_factory=list)
    duration_ms: float = 0.0

    def __str__(self) -> str:
        status = "PASS" if self.passed else "FAIL"
        msg = f" [{status}] {self.name}"
        if not self.passed:
            for err in self.errors:
                msg += f"\n    {err}"
        return msg


@dataclass
class PolicyTestRunResult:
    """Aggregate result of a test run."""

    scenario_results: list[ScenarioResult]
    total: int = 0
    passed: int = 0
    failed: int = 0
    duration_ms: float = 0.0

    @property
    def all_passed(self) -> bool:
        return self.failed == 0

    def summary(self) -> str:
        lines = [
            "",
            "=" * 60,
            f" Policy Test Results: {self.passed}/{self.total} passed",
            "=" * 60,
        ]
        for result in self.scenario_results:
            lines.append(str(result))
        lines.append("-" * 60)
        status = "ALL PASSED" if self.all_passed else f"{self.failed} FAILED"
        lines.append(f" {status} ({self.duration_ms:.1f}ms)")
        lines.append("")
        return "\n".join(lines)


class PolicyTestRunner:
    """Run declarative policy test scenarios.

    Args:
        policy_path: Path to the policy YAML file.
        scenarios_path: Path to the test scenarios YAML file.
        policy_evaluator: Optional pre-configured evaluator (overrides policy_path).
    """

    def __init__(
        self,
        policy_path: str | Path | None = None,
        scenarios_path: str | Path | None = None,
        *,
        policy_evaluator: Any | None = None,
        scenarios: list[dict[str, Any]] | None = None,
    ):
        # Load policy
        if policy_evaluator is not None:
            self._evaluator = policy_evaluator
        elif policy_path is not None:
            self._evaluator = YamlPolicyEvaluator(policy_path)
        else:
            raise ValueError("Either policy_path or policy_evaluator must be provided")

        # Load scenarios
        if scenarios is not None:
            self._scenarios = scenarios
        elif scenarios_path is not None:
            self._scenarios = self._load_scenarios(scenarios_path)
        else:
            raise ValueError("Either scenarios_path or scenarios must be provided")

        self._validate_scenarios()

    @staticmethod
    def _load_scenarios(path: str | Path) -> list[dict[str, Any]]:
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Scenarios file not found: {path}")

        with path.open("r") as f:
            data = yaml.safe_load(f)

        if not data or "scenarios" not in data:
            raise ValueError(f"Scenarios file must contain a 'scenarios' key: {path}")

        return data["scenarios"]

    def _validate_scenarios(self):
        """Validate scenario structure before running."""
        for i, scenario in enumerate(self._scenarios):
            name = scenario.get("name", f"Scenario {i}")
            if "expected" not in scenario:
                raise ValueError(f"Scenario '{name}' is missing 'expected' key")
            if "decision" not in scenario["expected"]:
                raise ValueError(
                    f"Scenario '{name}' expected section must include 'decision'"
                )
            try:
                DecisionType(scenario["expected"]["decision"])
            except ValueError:
                raise ValueError(
                    f"Scenario '{name}' has invalid expected decision: "
                    f"'{scenario['expected']['decision']}'"
                ) from None

    def run(self) -> PolicyTestRunResult:
        """Run all test scenarios and return results."""
        start = time.monotonic()
        results: list[ScenarioResult] = []

        for scenario in self._scenarios:
            result = self._run_scenario(scenario)
            results.append(result)

        total_ms = (time.monotonic() - start) * 1000
        passed = sum(1 for r in results if r.passed)
        failed = len(results) - passed

        return PolicyTestRunResult(
            scenario_results=results,
            total=len(results),
            passed=passed,
            failed=failed,
            duration_ms=total_ms,
        )

    def _run_scenario(self, scenario: dict[str, Any]) -> ScenarioResult:
        """Run a single test scenario."""
        name = scenario.get("name", "Unnamed scenario")
        expected = scenario["expected"]
        expected_decision = expected["decision"]
        expected_reason_contains = expected.get("reason_contains")
        expected_policy_id = expected.get("policy_id")

        start = time.monotonic()
        errors: list[str] = []

        try:
            agent_ctx = self._build_agent_context(scenario.get("agent", {}))
            intent = self._build_intent(scenario.get("intent", {}))
            tool_request = self._build_tool_request(scenario.get("tool_request", {}))

            decision = self._evaluator.evaluate(agent_ctx, intent, tool_request)

            actual_decision = decision.decision.value
            actual_reason = decision.reason
            actual_policy_id = decision.policy_id

            # Check decision
            if actual_decision != expected_decision:
                errors.append(
                    f"Decision: expected '{expected_decision}', got '{actual_decision}'"
                )

            # Check reason (substring match)
            if expected_reason_contains and expected_reason_contains not in (
                actual_reason or ""
            ):
                errors.append(
                    f"Reason: expected to contain '{expected_reason_contains}', "
                    f"got '{actual_reason}'"
                )

            # Check policy_id
            if expected_policy_id and actual_policy_id != expected_policy_id:
                errors.append(
                    f"Policy ID: expected '{expected_policy_id}', "
                    f"got '{actual_policy_id}'"
                )

        except Exception as e:
            actual_decision = "ERROR"
            actual_reason = str(e)
            actual_policy_id = None
            errors.append(f"Exception: {e}")

        duration_ms = (time.monotonic() - start) * 1000

        return ScenarioResult(
            name=name,
            passed=len(errors) == 0,
            expected_decision=expected_decision,
            actual_decision=actual_decision,
            expected_reason_contains=expected_reason_contains,
            actual_reason=actual_reason,
            expected_policy_id=expected_policy_id,
            actual_policy_id=actual_policy_id,
            errors=errors,
            duration_ms=duration_ms,
        )

    @staticmethod
    def _build_agent_context(data: dict[str, Any]) -> AgentContext:
        delegated_by = data.get("delegated_by")
        if delegated_by is not None:
            delegated_by = tuple(delegated_by)
        else:
            delegated_by = ()
        return AgentContext(
            agent_id=data.get("agent_id", "test-agent"),
            version=data.get("version", "1.0"),
            owner=data.get("owner", "test-owner"),
            metadata=data.get("metadata", {}),
            delegated_by=delegated_by,
        )

    @staticmethod
    def _build_intent(data: dict[str, Any]) -> Intent:
        return Intent(
            action=data.get("action", "test_action"),
            reason=data.get("reason", "test reason"),
            confidence=data.get("confidence"),
            metadata=data.get("metadata", {}),
        )

    @staticmethod
    def _build_tool_request(data: dict[str, Any]) -> ToolRequest:
        effect_str = data.get("effect", "unknown")
        try:
            effect = Effect(effect_str)
        except ValueError:
            effect = Effect.UNKNOWN

        return ToolRequest(
            tool=data.get("tool", "unknown"),
            action=data.get("action", "unknown"),
            resource_type=data.get("resource_type", "unknown"),
            effect=effect,
            params=data.get("params", {}),
            metadata=data.get("metadata", {}),
            manifest_version=data.get("manifest_version"),
        )


def cli_main(args: list[str] | None = None) -> int:
    """CLI entry point for ``tollgate test-policy``.

    Usage:
        tollgate test-policy policy.yaml --scenarios test_scenarios.yaml
        tollgate test-policy policy.yaml -s test_scenarios.yaml --strict

    Returns exit code 0 on success, 1 on failure.
    """
    import argparse

    parser = argparse.ArgumentParser(
        prog="tollgate test-policy",
        description="Run declarative policy test scenarios against a Tollgate policy.",
    )
    parser.add_argument(
        "policy_path",
        help="Path to the policy YAML file",
    )
    parser.add_argument(
        "--scenarios", "-s",
        required=True,
        help="Path to the test scenarios YAML file",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Exit with code 1 on any failure (default behavior)",
    )
    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Only show failures and summary",
    )

    parsed = parser.parse_args(args)

    try:
        runner = PolicyTestRunner(parsed.policy_path, parsed.scenarios)
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: {e}", file=sys.stderr)
        return 2

    results = runner.run()

    if parsed.quiet:
        # Only show failures
        for r in results.scenario_results:
            if not r.passed:
                print(str(r))
        print(f"\n{results.passed}/{results.total} passed, {results.failed} failed")
    else:
        print(results.summary())

    return 0 if results.all_passed else 1
tollgate/rate_limiter.py
ADDED
@@ -0,0 +1,162 @@
"""Rate limiting for AI agent tool calls.

Provides a sliding-window rate limiter that tracks per-agent, per-tool
call frequency and blocks calls that exceed configured thresholds.
"""

import asyncio
import time
from typing import Any, Protocol, runtime_checkable

from .types import AgentContext, Effect, ToolRequest


@runtime_checkable
class RateLimiter(Protocol):
    """Protocol for rate limiting backends.

    Implement this protocol to use a custom backend (Redis, etc.).
    The InMemoryRateLimiter serves as the reference implementation.
    """

    async def check_rate_limit(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> tuple[bool, str | None, float | None]:
        """Check whether a tool call should be rate-limited.

        Returns:
            (allowed, reason, retry_after)
            - allowed: True if the call is within limits
            - reason: Human-readable reason if blocked (None if allowed)
            - retry_after: Seconds until the window resets (None if allowed)
        """
        ...


class RateLimitRule:
    """A single rate limit rule parsed from config."""

    def __init__(
        self,
        *,
        agent_id: str = "*",
        tool: str = "*",
        effect: str | None = None,
        max_calls: int,
        window_seconds: int,
    ):
        self.agent_id = agent_id
        self.tool = tool
        self.effect = effect
        self.max_calls = max_calls
        self.window_seconds = window_seconds

    def matches(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> bool:
        """Check if this rule applies to the given request."""
        # Agent match
        if self.agent_id != "*" and self.agent_id != agent_ctx.agent_id:
            return False

        # Tool match (supports prefix wildcard like "mcp:*")
        if self.tool != "*":
            if self.tool.endswith("*"):
                if not tool_request.tool.startswith(self.tool[:-1]):
                    return False
            elif self.tool != tool_request.tool:
                return False

        # Effect match
        if self.effect is not None:
            try:
                if Effect(self.effect) != tool_request.effect:
                    return False
            except ValueError:
                return False

        return True

    def bucket_key(self, agent_ctx: AgentContext) -> str:
        """Generate a unique bucket key for this rule + agent."""
        return f"{self.agent_id}|{self.tool}|{self.effect or '*'}|{agent_ctx.agent_id}"


class InMemoryRateLimiter:
    """Sliding-window rate limiter with in-memory storage.

    Config is a list of rule dicts, typically from policy.yaml:

        rate_limits:
          - agent_id: "*"
            tool: "*"
            max_calls: 100
            window_seconds: 60
          - agent_id: "*"
            effect: "write"
            max_calls: 10
            window_seconds: 60
    """

    def __init__(self, rules: list[dict[str, Any]] | None = None):
        self._rules: list[RateLimitRule] = []
        self._buckets: dict[str, list[float]] = {}
        self._lock = asyncio.Lock()

        if rules:
            for r in rules:
                self._rules.append(
                    RateLimitRule(
                        agent_id=r.get("agent_id", "*"),
                        tool=r.get("tool", "*"),
                        effect=r.get("effect"),
                        max_calls=r["max_calls"],
                        window_seconds=r["window_seconds"],
                    )
                )

    async def check_rate_limit(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> tuple[bool, str | None, float | None]:
        """Check all matching rules. First violation wins."""
        now = time.time()

        async with self._lock:
            for rule in self._rules:
                if not rule.matches(agent_ctx, tool_request):
                    continue

                key = rule.bucket_key(agent_ctx)
                window_start = now - rule.window_seconds

                # Get or create bucket, prune expired entries
                bucket = self._buckets.get(key, [])
                bucket = [t for t in bucket if t > window_start]
                self._buckets[key] = bucket

                if len(bucket) >= rule.max_calls:
                    # Rate limit exceeded
                    oldest_in_window = bucket[0] if bucket else now
                    retry_after = oldest_in_window + rule.window_seconds - now
                    reason = (
                        f"Rate limit exceeded: {len(bucket)}/{rule.max_calls} "
                        f"calls in {rule.window_seconds}s window "
                        f"(agent={agent_ctx.agent_id}, "
                        f"tool={rule.tool}, effect={rule.effect or '*'})"
                    )
                    return False, reason, max(0.0, retry_after)

                # Record this call
                bucket.append(now)

        return True, None, None

    async def reset(self, agent_id: str | None = None) -> None:
        """Clear rate limit state. If agent_id is given, clear only that agent."""
        async with self._lock:
            if agent_id is None:
                self._buckets.clear()
            else:
                to_remove = [k for k in self._buckets if k.endswith(f"|{agent_id}")]
                for k in to_remove:
                    del self._buckets[k]
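
For completeness, a usage sketch for the limiter. The AgentContext and ToolRequest keyword arguments mirror exactly the ones used by policy_testing's _build_* helpers above; the specific agent and tool names are illustrative only:

    import asyncio

    from tollgate.rate_limiter import InMemoryRateLimiter
    from tollgate.types import AgentContext, Effect, ToolRequest


    async def main() -> None:
        # One rule: at most 2 write-effect calls per agent per 60s window.
        limiter = InMemoryRateLimiter(rules=[
            {"effect": "write", "max_calls": 2, "window_seconds": 60},
        ])
        agent = AgentContext(
            agent_id="agent-1", version="1.0", owner="team-a",
            metadata={}, delegated_by=(),
        )
        request = ToolRequest(
            tool="api:update", action="post", resource_type="url",
            effect=Effect("write"), params={}, metadata={},
            manifest_version=None,
        )
        for i in range(3):
            allowed, reason, retry_after = await limiter.check_rate_limit(agent, request)
            # The third call should be blocked, with retry_after <= 60s.
            print(i, allowed, reason, retry_after)


    asyncio.run(main())

One design note visible in check_rate_limit: each matching rule records the call in its bucket before later rules are checked, so a call blocked by a later rule still consumes quota under earlier rules.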