tollgate 1.0.4__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,360 @@
1
+ """Policy testing framework for Tollgate.
2
+
3
+ Enables declarative scenario-based testing of Tollgate policies to prevent
4
+ regressions in CI. Test scenarios are defined in YAML and run against a
5
+ policy evaluator.
6
+
7
+ Usage:
8
+
9
+ # From Python:
10
+ from tollgate.policy_testing import PolicyTestRunner
11
+
12
+ runner = PolicyTestRunner("policy.yaml", "test_scenarios.yaml")
13
+ results = runner.run()
14
+ assert results.all_passed
15
+
16
+ # From CLI:
17
+ tollgate test-policy policy.yaml --scenarios test_scenarios.yaml
18
+
19
+ Scenario file format:
20
+
21
+ scenarios:
22
+ - name: "Allow read operations"
23
+ description: "Read effects should be allowed for trusted agents"
24
+ agent:
25
+ agent_id: "agent-1"
26
+ version: "1.0"
27
+ owner: "team-a"
28
+ intent:
29
+ action: "fetch_data"
30
+ reason: "Customer request"
31
+ tool_request:
32
+ tool: "api:fetch"
33
+ action: "get"
34
+ resource_type: "url"
35
+ effect: "read"
36
+ params: {}
37
+ manifest_version: "1.0.0"
38
+ expected:
39
+ decision: "ALLOW" # Required: ALLOW, ASK, or DENY
40
+ reason_contains: "Rule" # Optional: substring match on reason
41
+ policy_id: "allow_read" # Optional: exact match on policy_id
42
+ """
43
+
44
+ import sys
45
+ import time
46
+ from dataclasses import dataclass, field
47
+ from pathlib import Path
48
+ from typing import Any
49
+
50
+ import yaml
51
+
52
+ from .policy import YamlPolicyEvaluator
53
+ from .types import AgentContext, DecisionType, Effect, Intent, ToolRequest
54
+
55
+
56
+ @dataclass
57
+ class ScenarioResult:
58
+ """Result of a single test scenario."""
59
+
60
+ name: str
61
+ passed: bool
62
+ expected_decision: str
63
+ actual_decision: str
64
+ expected_reason_contains: str | None = None
65
+ actual_reason: str | None = None
66
+ expected_policy_id: str | None = None
67
+ actual_policy_id: str | None = None
68
+ errors: list[str] = field(default_factory=list)
69
+ duration_ms: float = 0.0
70
+
71
+ def __str__(self) -> str:
72
+ status = "PASS" if self.passed else "FAIL"
73
+ msg = f" [{status}] {self.name}"
74
+ if not self.passed:
75
+ for err in self.errors:
76
+ msg += f"\n {err}"
77
+ return msg
78
+
79
+
80
@dataclass
class PolicyTestRunResult:
    """Aggregate outcome of a full scenario run."""

    scenario_results: list[ScenarioResult]
    total: int = 0
    passed: int = 0
    failed: int = 0
    duration_ms: float = 0.0

    @property
    def all_passed(self) -> bool:
        """True when no scenario failed."""
        return self.failed == 0

    def summary(self) -> str:
        """Build a human-readable, multi-line report of the run."""
        bar = "=" * 60
        out = [
            "",
            bar,
            f" Policy Test Results: {self.passed}/{self.total} passed",
            bar,
        ]
        out.extend(str(result) for result in self.scenario_results)
        out.append("-" * 60)
        verdict = "ALL PASSED" if self.all_passed else f"{self.failed} FAILED"
        out.append(f" {verdict} ({self.duration_ms:.1f}ms)")
        out.append("")
        return "\n".join(out)
108
+
109
+
110
class PolicyTestRunner:
    """Run declarative policy test scenarios.

    Args:
        policy_path: Path to the policy YAML file.
        scenarios_path: Path to the test scenarios YAML file.
        policy_evaluator: Optional pre-configured evaluator (overrides policy_path).
        scenarios: Optional in-memory scenario list (overrides scenarios_path).

    Raises:
        FileNotFoundError: If scenarios_path does not point to an existing file.
        ValueError: If no policy/scenario source is given, or a scenario
            definition is structurally invalid.
    """

    def __init__(
        self,
        policy_path: str | Path | None = None,
        scenarios_path: str | Path | None = None,
        *,
        policy_evaluator: Any | None = None,
        scenarios: list[dict[str, Any]] | None = None,
    ):
        # Load policy: an explicit evaluator wins over a policy file path.
        if policy_evaluator is not None:
            self._evaluator = policy_evaluator
        elif policy_path is not None:
            self._evaluator = YamlPolicyEvaluator(policy_path)
        else:
            raise ValueError("Either policy_path or policy_evaluator must be provided")

        # Load scenarios: an in-memory list wins over a scenarios file path.
        if scenarios is not None:
            self._scenarios = scenarios
        elif scenarios_path is not None:
            self._scenarios = self._load_scenarios(scenarios_path)
        else:
            raise ValueError("Either scenarios_path or scenarios must be provided")

        # Fail fast on malformed scenarios instead of mid-run.
        self._validate_scenarios()

    @staticmethod
    def _load_scenarios(path: str | Path) -> list[dict[str, Any]]:
        """Read the scenarios YAML file and return its 'scenarios' list."""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Scenarios file not found: {path}")

        with path.open("r") as f:
            data = yaml.safe_load(f)

        if not data or "scenarios" not in data:
            raise ValueError(f"Scenarios file must contain a 'scenarios' key: {path}")

        scenarios = data["scenarios"]
        # Robustness fix: a bare "scenarios:" key (YAML null) or a scalar
        # value previously slipped through and surfaced later as an opaque
        # TypeError during validation.
        if not isinstance(scenarios, list):
            raise ValueError(f"'scenarios' must be a list: {path}")

        return scenarios

    def _validate_scenarios(self):
        """Validate scenario structure before running."""
        for i, scenario in enumerate(self._scenarios):
            # Robustness fix: non-mapping entries previously crashed with
            # an AttributeError on .get().
            if not isinstance(scenario, dict):
                raise ValueError(f"Scenario {i} must be a mapping")
            name = scenario.get("name", f"Scenario {i}")
            if "expected" not in scenario:
                raise ValueError(f"Scenario '{name}' is missing 'expected' key")
            if "decision" not in scenario["expected"]:
                raise ValueError(
                    f"Scenario '{name}' expected section must include 'decision'"
                )
            try:
                # The DecisionType enum is the source of truth for valid
                # decision strings (ALLOW / ASK / DENY).
                DecisionType(scenario["expected"]["decision"])
            except ValueError:
                raise ValueError(
                    f"Scenario '{name}' has invalid expected decision: "
                    f"'{scenario['expected']['decision']}'"
                ) from None

    def run(self) -> PolicyTestRunResult:
        """Run all test scenarios and return aggregate results."""
        start = time.monotonic()
        results: list[ScenarioResult] = [
            self._run_scenario(scenario) for scenario in self._scenarios
        ]

        total_ms = (time.monotonic() - start) * 1000
        passed = sum(1 for r in results if r.passed)

        return PolicyTestRunResult(
            scenario_results=results,
            total=len(results),
            passed=passed,
            failed=len(results) - passed,
            duration_ms=total_ms,
        )

    def _run_scenario(self, scenario: dict[str, Any]) -> ScenarioResult:
        """Run one scenario and compare the decision against expectations."""
        name = scenario.get("name", "Unnamed scenario")
        expected = scenario["expected"]
        expected_decision = expected["decision"]
        expected_reason_contains = expected.get("reason_contains")
        expected_policy_id = expected.get("policy_id")

        start = time.monotonic()
        errors: list[str] = []

        try:
            agent_ctx = self._build_agent_context(scenario.get("agent", {}))
            intent = self._build_intent(scenario.get("intent", {}))
            tool_request = self._build_tool_request(scenario.get("tool_request", {}))

            decision = self._evaluator.evaluate(agent_ctx, intent, tool_request)

            actual_decision = decision.decision.value
            actual_reason = decision.reason
            actual_policy_id = decision.policy_id

            # Check decision (exact match, required)
            if actual_decision != expected_decision:
                errors.append(
                    f"Decision: expected '{expected_decision}', got '{actual_decision}'"
                )

            # Check reason (substring match, optional)
            if expected_reason_contains and expected_reason_contains not in (
                actual_reason or ""
            ):
                errors.append(
                    f"Reason: expected to contain '{expected_reason_contains}', "
                    f"got '{actual_reason}'"
                )

            # Check policy_id (exact match, optional)
            if expected_policy_id and actual_policy_id != expected_policy_id:
                errors.append(
                    f"Policy ID: expected '{expected_policy_id}', "
                    f"got '{actual_policy_id}'"
                )

        except Exception as e:
            # An evaluator crash is reported as a failure of this one
            # scenario rather than aborting the whole run.
            actual_decision = "ERROR"
            actual_reason = str(e)
            actual_policy_id = None
            errors.append(f"Exception: {e}")

        duration_ms = (time.monotonic() - start) * 1000

        return ScenarioResult(
            name=name,
            passed=len(errors) == 0,
            expected_decision=expected_decision,
            actual_decision=actual_decision,
            expected_reason_contains=expected_reason_contains,
            actual_reason=actual_reason,
            expected_policy_id=expected_policy_id,
            actual_policy_id=actual_policy_id,
            errors=errors,
            duration_ms=duration_ms,
        )

    @staticmethod
    def _build_agent_context(data: dict[str, Any]) -> AgentContext:
        """Build an AgentContext from scenario data, with test defaults."""
        delegated_by = data.get("delegated_by")
        delegated_by = tuple(delegated_by) if delegated_by is not None else ()
        return AgentContext(
            agent_id=data.get("agent_id", "test-agent"),
            version=data.get("version", "1.0"),
            owner=data.get("owner", "test-owner"),
            metadata=data.get("metadata", {}),
            delegated_by=delegated_by,
        )

    @staticmethod
    def _build_intent(data: dict[str, Any]) -> Intent:
        """Build an Intent from scenario data, with test defaults."""
        return Intent(
            action=data.get("action", "test_action"),
            reason=data.get("reason", "test reason"),
            confidence=data.get("confidence"),
            metadata=data.get("metadata", {}),
        )

    @staticmethod
    def _build_tool_request(data: dict[str, Any]) -> ToolRequest:
        """Build a ToolRequest from scenario data, with test defaults."""
        effect_str = data.get("effect", "unknown")
        try:
            effect = Effect(effect_str)
        except ValueError:
            # Unrecognized effect strings degrade to Effect.UNKNOWN rather
            # than failing the scenario at build time.
            effect = Effect.UNKNOWN

        return ToolRequest(
            tool=data.get("tool", "unknown"),
            action=data.get("action", "unknown"),
            resource_type=data.get("resource_type", "unknown"),
            effect=effect,
            params=data.get("params", {}),
            metadata=data.get("metadata", {}),
            manifest_version=data.get("manifest_version"),
        )
304
+
305
+
306
def cli_main(args: list[str] | None = None) -> int:
    """CLI entry point for ``tollgate test-policy``.

    Usage:
        tollgate test-policy policy.yaml --scenarios test_scenarios.yaml
        tollgate test-policy policy.yaml -s test_scenarios.yaml --strict

    Returns exit code 0 on success, 1 on failure.
    """
    import argparse

    parser = argparse.ArgumentParser(
        prog="tollgate test-policy",
        description="Run declarative policy test scenarios against a Tollgate policy.",
    )
    parser.add_argument(
        "policy_path",
        help="Path to the policy YAML file",
    )
    parser.add_argument(
        "--scenarios", "-s",
        required=True,
        help="Path to the test scenarios YAML file",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Exit with code 1 on any failure (default behavior)",
    )
    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Only show failures and summary",
    )

    opts = parser.parse_args(args)

    # Setup problems (missing files, malformed scenarios) exit with code 2
    # to distinguish them from test failures (code 1).
    try:
        runner = PolicyTestRunner(opts.policy_path, opts.scenarios)
    except (FileNotFoundError, ValueError) as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 2

    run_result = runner.run()

    if not opts.quiet:
        print(run_result.summary())
    else:
        # Quiet mode: failed scenarios only, then a one-line tally.
        for res in run_result.scenario_results:
            if not res.passed:
                print(str(res))
        print(f"\n{run_result.passed}/{run_result.total} passed, {run_result.failed} failed")

    return 0 if run_result.all_passed else 1
@@ -0,0 +1,162 @@
1
+ """Rate limiting for AI agent tool calls.
2
+
3
+ Provides a sliding-window rate limiter that tracks per-agent, per-tool
4
+ call frequency and blocks calls that exceed configured thresholds.
5
+ """
6
+
7
+ import asyncio
8
+ import time
9
+ from typing import Any, Protocol, runtime_checkable
10
+
11
+ from .types import AgentContext, Effect, ToolRequest
12
+
13
+
14
@runtime_checkable
class RateLimiter(Protocol):
    """Structural interface for rate limiting backends.

    Any object providing a matching ``check_rate_limit`` coroutine
    satisfies this protocol — implement it to plug in a custom backend
    (Redis, etc.). InMemoryRateLimiter serves as the reference
    implementation.
    """

    async def check_rate_limit(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> tuple[bool, str | None, float | None]:
        """Decide whether a tool call should be rate-limited.

        Returns:
            A ``(allowed, reason, retry_after)`` triple:
            - allowed: True if the call is within limits
            - reason: Human-readable reason if blocked (None if allowed)
            - retry_after: Seconds until the window resets (None if allowed)
        """
        ...
34
+
35
+
36
class RateLimitRule:
    """One rate-limit rule parsed from configuration."""

    def __init__(
        self,
        *,
        agent_id: str = "*",
        tool: str = "*",
        effect: str | None = None,
        max_calls: int,
        window_seconds: int,
    ):
        self.agent_id = agent_id
        self.tool = tool
        self.effect = effect
        self.max_calls = max_calls
        self.window_seconds = window_seconds

    def matches(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> bool:
        """Return True when this rule applies to the given request."""
        # Agent filter: "*" matches any agent.
        if self.agent_id not in ("*", agent_ctx.agent_id):
            return False

        # Tool filter: "*" matches all tools; a trailing "*" acts as a
        # prefix wildcard (e.g. "mcp:*" matches "mcp:fetch").
        if self.tool != "*":
            pattern = self.tool
            if pattern.endswith("*"):
                if not tool_request.tool.startswith(pattern[:-1]):
                    return False
            elif tool_request.tool != pattern:
                return False

        # Effect filter: an effect string that is not a valid Effect value
        # matches nothing.
        if self.effect is not None:
            try:
                wanted = Effect(self.effect)
            except ValueError:
                return False
            if wanted != tool_request.effect:
                return False

        return True

    def bucket_key(self, agent_ctx: AgentContext) -> str:
        """Unique sliding-window bucket key for this rule + calling agent."""
        return f"{self.agent_id}|{self.tool}|{self.effect or '*'}|{agent_ctx.agent_id}"
83
+
84
+
85
class InMemoryRateLimiter:
    """Sliding-window rate limiter with in-memory storage.

    Config is a list of rule dicts, typically from policy.yaml:

        rate_limits:
          - agent_id: "*"
            tool: "*"
            max_calls: 100
            window_seconds: 60
          - agent_id: "*"
            effect: "write"
            max_calls: 10
            window_seconds: 60
    """

    def __init__(self, rules: list[dict[str, Any]] | None = None):
        # Rules are evaluated in config order; the first violation wins.
        self._rules: list[RateLimitRule] = []
        # bucket key -> call timestamps (appended in time order).
        self._buckets: dict[str, list[float]] = {}
        self._lock = asyncio.Lock()

        if rules:
            for r in rules:
                self._rules.append(
                    RateLimitRule(
                        agent_id=r.get("agent_id", "*"),
                        tool=r.get("tool", "*"),
                        effect=r.get("effect"),
                        max_calls=r["max_calls"],
                        window_seconds=r["window_seconds"],
                    )
                )

    async def check_rate_limit(
        self, agent_ctx: AgentContext, tool_request: ToolRequest
    ) -> tuple[bool, str | None, float | None]:
        """Check all matching rules. First violation wins.

        Bug fix: the call is recorded only after ALL matching rules have
        allowed it. Previously a call blocked by a later rule had already
        been appended to earlier (more permissive) rules' buckets, so
        repeatedly-denied attempts consumed quota for calls that never ran.
        """
        now = time.time()

        async with self._lock:
            # Phase 1: prune each matching rule's window and look for a
            # violation WITHOUT recording the call.
            allowed_buckets: list[list[float]] = []
            for rule in self._rules:
                if not rule.matches(agent_ctx, tool_request):
                    continue

                key = rule.bucket_key(agent_ctx)
                window_start = now - rule.window_seconds

                # Drop timestamps that have aged out of the window.
                bucket = [t for t in self._buckets.get(key, []) if t > window_start]
                self._buckets[key] = bucket

                if len(bucket) >= rule.max_calls:
                    # Rate limit exceeded; bucket[0] is the oldest call in
                    # the window, so the window frees up when it expires.
                    oldest_in_window = bucket[0] if bucket else now
                    retry_after = oldest_in_window + rule.window_seconds - now
                    reason = (
                        f"Rate limit exceeded: {len(bucket)}/{rule.max_calls} "
                        f"calls in {rule.window_seconds}s window "
                        f"(agent={agent_ctx.agent_id}, "
                        f"tool={rule.tool}, effect={rule.effect or '*'})"
                    )
                    return False, reason, max(0.0, retry_after)

                allowed_buckets.append(bucket)

            # Phase 2: every matching rule allowed the call — record it
            # once per matching bucket.
            for bucket in allowed_buckets:
                bucket.append(now)

        return True, None, None

    async def reset(self, agent_id: str | None = None) -> None:
        """Clear rate limit state. If agent_id is given, clear only that agent."""
        async with self._lock:
            if agent_id is None:
                self._buckets.clear()
            else:
                # Bucket keys end with "|<agent_id>" (see RateLimitRule.bucket_key).
                stale = [k for k in self._buckets if k.endswith(f"|{agent_id}")]
                for k in stale:
                    del self._buckets[k]