agent-mcp-gateway 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-mcp-gateway might be problematic. Click here for more details.

src/policy.py ADDED
@@ -0,0 +1,494 @@
1
+ """Policy evaluation engine for Agent MCP Gateway.
2
+
3
+ This module implements the core policy engine that evaluates agent permissions
4
+ against configured rules. It enforces a strict deny-before-allow precedence
5
+ and supports wildcard pattern matching for flexible rule definitions.
6
+
7
+ Precedence Order (CRITICAL - DO NOT CHANGE):
8
+ 1. Explicit deny rules (specific tool names)
9
+ 2. Explicit allow rules (specific tool names)
10
+ 3. Wildcard deny rules (patterns like drop_*)
11
+ 4. Wildcard allow rules (patterns like get_* or *)
12
+ 5. Default policy (from defaults.deny_on_missing_agent)
13
+ """
14
+
15
+ import fnmatch
16
+ import logging
17
+ import threading
18
+ from typing import Literal, Optional
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class PolicyEngine:
25
+ """Evaluates agent permissions against configured rules.
26
+
27
+ This class implements the security policy evaluation logic for the gateway.
28
+ It determines whether agents can access specific servers and tools based on
29
+ configurable allow/deny rules with wildcard pattern support.
30
+ """
31
+
32
+ def __init__(self, rules: dict):
33
+ """Initialize policy engine with rules dictionary.
34
+
35
+ Args:
36
+ rules: Gateway rules configuration with structure:
37
+ {
38
+ "agents": {
39
+ "agent_id": {
40
+ "allow": {"servers": [...], "tools": {...}},
41
+ "deny": {"servers": [...], "tools": {...}}
42
+ }
43
+ },
44
+ "defaults": {"deny_on_missing_agent": bool}
45
+ }
46
+ """
47
+ self.rules = rules
48
+ self.agents = rules.get("agents", {})
49
+ self.defaults = rules.get("defaults", {})
50
+ self._lock = threading.RLock() # Reentrant lock for nested calls
51
+
52
+ def can_access_server(self, agent_id: str, server: str) -> bool:
53
+ """Check if agent can access a server.
54
+
55
+ An agent can access a server if:
56
+ - The agent exists in the rules
57
+ - The server is in the agent's allow.servers list (or "*" is present)
58
+ - The server is not in the agent's deny.servers list
59
+
60
+ Args:
61
+ agent_id: Agent identifier
62
+ server: Server name
63
+
64
+ Returns:
65
+ True if agent can access server, False otherwise
66
+ """
67
+ with self._lock:
68
+ # Check if agent exists in rules
69
+ if agent_id not in self.agents:
70
+ # Unknown agent - check default policy
71
+ return not self.defaults.get("deny_on_missing_agent", True)
72
+
73
+ agent_rules = self.agents[agent_id]
74
+
75
+ # Check deny rules first (deny takes precedence)
76
+ deny_servers = agent_rules.get("deny", {}).get("servers", [])
77
+ if server in deny_servers or "*" in deny_servers:
78
+ return False
79
+
80
+ # Check for wildcard deny patterns
81
+ for pattern in deny_servers:
82
+ if self._matches_pattern(server, pattern):
83
+ return False
84
+
85
+ # Check allow rules
86
+ allow_servers = agent_rules.get("allow", {}).get("servers", [])
87
+
88
+ # Explicit allow or wildcard allow
89
+ if server in allow_servers or "*" in allow_servers:
90
+ return True
91
+
92
+ # Check for wildcard allow patterns
93
+ for pattern in allow_servers:
94
+ if self._matches_pattern(server, pattern):
95
+ return True
96
+
97
+ # Not explicitly allowed
98
+ return False
99
+
100
+ def can_access_tool(self, agent_id: str, server: str, tool: str) -> bool:
101
+ """Check if agent can access a specific tool.
102
+
103
+ Applies deny-before-allow precedence:
104
+ 1. Explicit deny rules (specific tool names)
105
+ 2. Explicit allow rules (specific tool names)
106
+ 3. Wildcard deny rules (patterns like drop_*)
107
+ 4. Wildcard allow rules (patterns like get_* or *)
108
+ 5. Default policy
109
+
110
+ Args:
111
+ agent_id: Agent identifier
112
+ server: Server name
113
+ tool: Tool name
114
+
115
+ Returns:
116
+ True if agent can access tool, False otherwise
117
+ """
118
+ with self._lock:
119
+ # First, agent must have access to the server
120
+ if not self.can_access_server(agent_id, server):
121
+ return False
122
+
123
+ # Check if agent exists in rules
124
+ if agent_id not in self.agents:
125
+ # Unknown agent but has server access - check default policy
126
+ return not self.defaults.get("deny_on_missing_agent", True)
127
+
128
+ agent_rules = self.agents[agent_id]
129
+
130
+ # Get tool rules for this server
131
+ deny_tools = agent_rules.get("deny", {}).get("tools", {}).get(server, [])
132
+ allow_tools = agent_rules.get("allow", {}).get("tools", {}).get(server, [])
133
+
134
+ # Separate explicit rules from wildcard patterns
135
+ explicit_deny = []
136
+ wildcard_deny = []
137
+ explicit_allow = []
138
+ wildcard_allow = []
139
+
140
+ for rule in deny_tools:
141
+ if "*" in rule:
142
+ wildcard_deny.append(rule)
143
+ else:
144
+ explicit_deny.append(rule)
145
+
146
+ for rule in allow_tools:
147
+ if "*" in rule:
148
+ wildcard_allow.append(rule)
149
+ else:
150
+ explicit_allow.append(rule)
151
+
152
+ # Apply precedence order (CRITICAL - DO NOT CHANGE)
153
+
154
+ # 1. Explicit deny rules
155
+ if tool in explicit_deny:
156
+ return False
157
+
158
+ # 2. Explicit allow rules
159
+ if tool in explicit_allow:
160
+ return True
161
+
162
+ # 3. Wildcard deny rules
163
+ for pattern in wildcard_deny:
164
+ if self._matches_pattern(tool, pattern):
165
+ return False
166
+
167
+ # 4. Wildcard allow rules
168
+ for pattern in wildcard_allow:
169
+ if self._matches_pattern(tool, pattern):
170
+ return True
171
+
172
+ # 5. Default policy - if no rules match, deny
173
+ return False
174
+
175
+ def get_allowed_servers(self, agent_id: str) -> list[str]:
176
+ """Get list of servers this agent can access.
177
+
178
+ Note: This returns the configured server list, not all possible servers.
179
+ If wildcard "*" is present, returns ["*"] to indicate all servers allowed.
180
+
181
+ Args:
182
+ agent_id: Agent identifier
183
+
184
+ Returns:
185
+ List of server names the agent can access, or ["*"] for wildcard
186
+ """
187
+ with self._lock:
188
+ # Check if agent exists in rules
189
+ if agent_id not in self.agents:
190
+ # Unknown agent - check default policy
191
+ if self.defaults.get("deny_on_missing_agent", True):
192
+ return []
193
+ else:
194
+ # If not denying unknown agents, return empty list
195
+ # (caller should interpret this as "depends on what servers exist")
196
+ return []
197
+
198
+ agent_rules = self.agents[agent_id]
199
+ allow_servers = agent_rules.get("allow", {}).get("servers", [])
200
+ deny_servers = agent_rules.get("deny", {}).get("servers", [])
201
+
202
+ # If wildcard allow and no wildcard deny, return wildcard
203
+ if "*" in allow_servers and "*" not in deny_servers:
204
+ return ["*"]
205
+
206
+ # Filter out denied servers
207
+ allowed = []
208
+ for server in allow_servers:
209
+ if server == "*":
210
+ continue
211
+
212
+ # Check if this server is denied
213
+ is_denied = False
214
+ if server in deny_servers:
215
+ is_denied = True
216
+ else:
217
+ # Check wildcard deny patterns
218
+ for pattern in deny_servers:
219
+ if self._matches_pattern(server, pattern):
220
+ is_denied = True
221
+ break
222
+
223
+ if not is_denied:
224
+ allowed.append(server)
225
+
226
+ return allowed
227
+
228
+ def get_allowed_tools(self, agent_id: str, server: str) -> list[str] | Literal["*"]:
229
+ """Get list of allowed tools for agent on server.
230
+
231
+ Returns either a list of specific tool names or "*" to indicate
232
+ all tools are allowed (subject to deny rules being checked at access time).
233
+
234
+ Args:
235
+ agent_id: Agent identifier
236
+ server: Server name
237
+
238
+ Returns:
239
+ List of tool names or "*" for wildcard access
240
+ """
241
+ with self._lock:
242
+ # Agent must have server access first
243
+ if not self.can_access_server(agent_id, server):
244
+ return []
245
+
246
+ # Check if agent exists in rules
247
+ if agent_id not in self.agents:
248
+ # Unknown agent but has server access
249
+ if not self.defaults.get("deny_on_missing_agent", True):
250
+ return "*"
251
+ return []
252
+
253
+ agent_rules = self.agents[agent_id]
254
+ allow_tools = agent_rules.get("allow", {}).get("tools", {}).get(server, [])
255
+
256
+ # If wildcard allow, return "*"
257
+ if "*" in allow_tools:
258
+ return "*"
259
+
260
+ # Return list of allowed tools (including patterns)
261
+ return allow_tools
262
+
263
+ def get_policy_decision_reason(self, agent_id: str, server: str, tool: str | None = None) -> str:
264
+ """Get human-readable reason for policy decision.
265
+
266
+ Provides clear explanation of why access was allowed or denied,
267
+ useful for debugging and audit logs.
268
+
269
+ Args:
270
+ agent_id: Agent identifier
271
+ server: Server name
272
+ tool: Optional tool name
273
+
274
+ Returns:
275
+ String explaining why access was allowed/denied
276
+ """
277
+ with self._lock:
278
+ # Check if agent exists
279
+ if agent_id not in self.agents:
280
+ if self.defaults.get("deny_on_missing_agent", True):
281
+ return f"Agent '{agent_id}' not found in rules; default policy denies access"
282
+ else:
283
+ return f"Agent '{agent_id}' not found in rules; default policy allows access"
284
+
285
+ agent_rules = self.agents[agent_id]
286
+
287
+ # Check server access
288
+ deny_servers = agent_rules.get("deny", {}).get("servers", [])
289
+ allow_servers = agent_rules.get("allow", {}).get("servers", [])
290
+
291
+ # Check explicit server deny
292
+ if server in deny_servers:
293
+ return f"Server '{server}' explicitly denied for agent '{agent_id}'"
294
+
295
+ # Check wildcard server deny
296
+ for pattern in deny_servers:
297
+ if self._matches_pattern(server, pattern):
298
+ return f"Server '{server}' denied by pattern '{pattern}' for agent '{agent_id}'"
299
+
300
+ # Check server allow
301
+ server_allowed = False
302
+ server_allow_reason = ""
303
+
304
+ if server in allow_servers:
305
+ server_allowed = True
306
+ server_allow_reason = f"Server '{server}' explicitly allowed"
307
+ elif "*" in allow_servers:
308
+ server_allowed = True
309
+ server_allow_reason = "Server allowed by wildcard '*'"
310
+ else:
311
+ # Check wildcard patterns
312
+ for pattern in allow_servers:
313
+ if self._matches_pattern(server, pattern):
314
+ server_allowed = True
315
+ server_allow_reason = f"Server '{server}' allowed by pattern '{pattern}'"
316
+ break
317
+
318
+ if not server_allowed:
319
+ return f"Server '{server}' not in allowed list for agent '{agent_id}'"
320
+
321
+ # If no tool specified, return server access reason
322
+ if tool is None:
323
+ return server_allow_reason
324
+
325
+ # Check tool access
326
+ deny_tools = agent_rules.get("deny", {}).get("tools", {}).get(server, [])
327
+ allow_tools = agent_rules.get("allow", {}).get("tools", {}).get(server, [])
328
+
329
+ # Check explicit tool deny
330
+ if tool in deny_tools:
331
+ return f"Tool '{tool}' explicitly denied for agent '{agent_id}' on server '{server}'"
332
+
333
+ # Check explicit tool allow
334
+ if tool in allow_tools:
335
+ return f"Tool '{tool}' explicitly allowed for agent '{agent_id}' on server '{server}'"
336
+
337
+ # Check wildcard deny patterns
338
+ for pattern in deny_tools:
339
+ if "*" in pattern and self._matches_pattern(tool, pattern):
340
+ return f"Tool '{tool}' denied by pattern '{pattern}' for agent '{agent_id}' on server '{server}'"
341
+
342
+ # Check wildcard allow patterns
343
+ for pattern in allow_tools:
344
+ if "*" in pattern and self._matches_pattern(tool, pattern):
345
+ return f"Tool '{tool}' allowed by pattern '{pattern}' for agent '{agent_id}' on server '{server}'"
346
+
347
+ # No matching rule
348
+ return f"Tool '{tool}' not in allowed list for agent '{agent_id}' on server '{server}'"
349
+
350
+ def _matches_pattern(self, name: str, pattern: str) -> bool:
351
+ """Check if name matches wildcard pattern.
352
+
353
+ Uses glob-style pattern matching:
354
+ - * matches any sequence of characters
355
+ - ? matches any single character
356
+ - [seq] matches any character in seq
357
+ - [!seq] matches any character not in seq
358
+
359
+ Args:
360
+ name: String to match
361
+ pattern: Pattern with wildcards (*, get_*, etc.)
362
+
363
+ Returns:
364
+ True if name matches pattern
365
+ """
366
+ return fnmatch.fnmatch(name, pattern)
367
+
368
+ def _compute_rule_diff(self, old_rules: dict, new_rules: dict) -> dict[str, list[str]]:
369
+ """Compute differences between old and new rules.
370
+
371
+ Args:
372
+ old_rules: Current rules dictionary
373
+ new_rules: New rules dictionary
374
+
375
+ Returns:
376
+ Dictionary with keys 'added', 'removed', 'modified' containing lists of agent IDs
377
+ """
378
+ old_agents = set(old_rules.get("agents", {}).keys())
379
+ new_agents = set(new_rules.get("agents", {}).keys())
380
+
381
+ added = sorted(new_agents - old_agents)
382
+ removed = sorted(old_agents - new_agents)
383
+
384
+ # Check for modified agents (agents present in both but with different configs)
385
+ modified = []
386
+ for agent_id in sorted(old_agents & new_agents):
387
+ old_config = old_rules.get("agents", {}).get(agent_id)
388
+ new_config = new_rules.get("agents", {}).get(agent_id)
389
+ if old_config != new_config:
390
+ modified.append(agent_id)
391
+
392
+ # Check if defaults changed
393
+ defaults_changed = old_rules.get("defaults") != new_rules.get("defaults")
394
+
395
+ return {
396
+ "added": added,
397
+ "removed": removed,
398
+ "modified": modified,
399
+ "defaults_changed": defaults_changed
400
+ }
401
+
402
+ def reload(self, new_rules: dict) -> tuple[bool, Optional[str]]:
403
+ """Reload policy rules with validation and atomic swap.
404
+
405
+ This method validates new rules before applying them. If validation fails,
406
+ the current rules remain unchanged. If validation succeeds, rules are
407
+ atomically swapped to the new configuration.
408
+
409
+ Thread-safety: This method is thread-safe and uses an internal lock to
410
+ prevent race conditions during reload operations.
411
+
412
+ Args:
413
+ new_rules: New gateway rules configuration with structure:
414
+ {
415
+ "agents": {
416
+ "agent_id": {
417
+ "allow": {"servers": [...], "tools": {...}},
418
+ "deny": {"servers": [...], "tools": {...}}
419
+ }
420
+ },
421
+ "defaults": {"deny_on_missing_agent": bool}
422
+ }
423
+
424
+ Returns:
425
+ Tuple of (success, error_message):
426
+ - (True, None) if reload successful
427
+ - (False, error_message) if validation failed
428
+
429
+ Example:
430
+ >>> engine = PolicyEngine(old_rules)
431
+ >>> success, error = engine.reload(new_rules)
432
+ >>> if success:
433
+ ... print("Rules reloaded successfully")
434
+ ... else:
435
+ ... print(f"Reload failed: {error}")
436
+ """
437
+ with self._lock:
438
+ logger.info("PolicyEngine reload initiated")
439
+
440
+ # Import validation function to avoid circular dependency at module level
441
+ from src.config import validate_gateway_rules
442
+
443
+ # Validate new rules structure
444
+ valid, error_msg = validate_gateway_rules(new_rules)
445
+ if not valid:
446
+ logger.error(f"PolicyEngine reload failed: Validation error: {error_msg}")
447
+ return False, f"Validation error: {error_msg}"
448
+
449
+ logger.info("PolicyEngine reload: Validation passed")
450
+
451
+ # Store old rules for potential rollback
452
+ old_rules = self.rules
453
+
454
+ try:
455
+ # Compute diff for logging
456
+ diff = self._compute_rule_diff(old_rules, new_rules)
457
+
458
+ # Log changes
459
+ if diff["added"]:
460
+ logger.info(f"PolicyEngine reload: Added agents: {', '.join(diff['added'])}")
461
+ if diff["removed"]:
462
+ logger.info(f"PolicyEngine reload: Removed agents: {', '.join(diff['removed'])}")
463
+ if diff["modified"]:
464
+ logger.info(f"PolicyEngine reload: Modified agents: {', '.join(diff['modified'])}")
465
+ if diff["defaults_changed"]:
466
+ logger.info("PolicyEngine reload: Default policy changed")
467
+
468
+ # Summarize changes
469
+ total_changes = len(diff["added"]) + len(diff["removed"]) + len(diff["modified"])
470
+ if total_changes == 0 and not diff["defaults_changed"]:
471
+ logger.info("PolicyEngine reload: No changes detected in rules")
472
+ else:
473
+ logger.info(
474
+ f"PolicyEngine reload: Rules updated - "
475
+ f"{len(diff['added'])} agents added, "
476
+ f"{len(diff['removed'])} removed, "
477
+ f"{len(diff['modified'])} modified"
478
+ )
479
+
480
+ # Atomic swap: Update internal state
481
+ self.rules = new_rules
482
+ self.agents = new_rules.get("agents", {})
483
+ self.defaults = new_rules.get("defaults", {})
484
+
485
+ logger.info("PolicyEngine reload complete")
486
+ return True, None
487
+
488
+ except Exception as e:
489
+ # Rollback on any error during swap
490
+ logger.error(f"PolicyEngine reload failed: Unexpected error during swap: {e}")
491
+ self.rules = old_rules
492
+ self.agents = old_rules.get("agents", {})
493
+ self.defaults = old_rules.get("defaults", {})
494
+ return False, f"Unexpected error during reload: {str(e)}"