agent-mcp-gateway 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-mcp-gateway might be problematic. Click here for more details.

src/gateway.py ADDED
@@ -0,0 +1,527 @@
1
+ """Gateway server for Agent MCP Gateway."""
2
+
3
+ import asyncio
4
+ import fnmatch
5
+ from typing import Annotated, Any, Optional
6
+
7
+ from fastmcp import FastMCP
8
+ from fastmcp.exceptions import ToolError
9
+ from pydantic import BaseModel, Field
10
+ from .policy import PolicyEngine
11
+ from .proxy import ProxyManager
12
+
13
+
14
+ # Output schemas for gateway tools
15
class ServerInfo(BaseModel):
    """Server information returned by list_servers."""
    # list_servers always sets `name` and `description` (description may be
    # None); `transport` plus either `command` or `url` are only filled in
    # when the caller passes include_metadata=true.

    name: str = Field(
        description="Server name (use in get_server_tools and execute_tool)"
    )
    description: Optional[str] = Field(
        default=None,
        description="What this server provides (from config or null if not configured)",
    )
    transport: Optional[str] = Field(
        default=None,
        description="How server communicates: stdio or http (only if include_metadata=true)",
    )
    command: Optional[str] = Field(
        default=None,
        description="Command that runs this server (only if include_metadata=true and transport=stdio)",
    )
    url: Optional[str] = Field(
        default=None,
        description="Server endpoint (only if include_metadata=true and transport=http)",
    )
22
+
23
+
24
class ToolDefinition(BaseModel):
    """Tool definition from downstream server."""
    # Only name, description and inputSchema are carried over from the
    # downstream tool object; all three are required.

    name: str = Field(description="Tool name (use in execute_tool)")
    description: str = Field(description="What this tool does")
    inputSchema: dict = Field(
        description="JSON Schema defining required/optional parameters for execute_tool args"
    )
29
+
30
+
31
class GetServerToolsResponse(BaseModel):
    """Response from get_server_tools."""
    # get_server_tools reports failures through this same shape: empty
    # `tools`, zero counts, and a message in `error`. On success `error`
    # stays None. `tokens_used` is only set when a token budget was given.

    tools: list[ToolDefinition] = Field(
        description="Tool definitions you can access"
    )
    server: str = Field(description="Queried server name")
    total_available: int = Field(
        description="Total tools on server (may exceed returned if filtered by permissions/criteria)"
    )
    returned: int = Field(
        description="Count of tools returned (less than total_available is normal due to filtering)"
    )
    tokens_used: Optional[int] = Field(
        default=None,
        description="Tokens used in schemas (if max_schema_tokens was set)",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if request failed",
    )
39
+
40
+
41
class ToolExecutionResponse(BaseModel):
    """Response from execute_tool."""
    # `content` is forwarded from the downstream result; each item must
    # already be a plain dict for validation to succeed.

    content: list[dict] = Field(
        description="Result from the downstream tool (format varies by tool)"
    )
    isError: bool = Field(
        description="True if the downstream tool returned an error"
    )
45
+
46
+
47
class GatewayStatusResponse(BaseModel):
    """Response from get_gateway_status (debug tool)."""
    # Every field is required. `reload_status` may be None when no reload
    # status getter was configured, but it must still be passed explicitly.

    reload_status: Optional[dict] = Field(
        description="Hot reload history with timestamps and errors"
    )
    policy_state: dict = Field(
        description="Policy engine configuration (agent count, defaults)"
    )
    available_servers: list[str] = Field(
        description="All configured server names"
    )
    config_paths: dict = Field(
        description="File paths to gateway configuration"
    )
    message: str = Field(description="Summary status message")
54
+
55
+
56
# The FastMCP server instance; the @gateway.tool functions below register on it.
gateway = FastMCP(name="Agent MCP Gateway")

# ---------------------------------------------------------------------------
# Module-level gateway state.
#
# Everything starts unset. initialize_gateway() (called by main.py) fills it
# in before requests are served, and the tool functions read it on each call.
# ---------------------------------------------------------------------------
_policy_engine: Optional[PolicyEngine] = None   # access-control rules
_mcp_config: Optional[dict] = None              # parsed MCP servers configuration
_proxy_manager: Optional[ProxyManager] = None   # downstream server connections
_check_config_changes_fn: Optional[Any] = None  # fallback reload checker (polled by get_server_tools)
_get_reload_status_fn: Optional[Any] = None     # reload status getter for diagnostics
_default_agent_id: Optional[str] = None         # default agent for the fallback chain
_debug_mode: bool = False                       # True once debug-only tools are enabled
67
+
68
+
69
def initialize_gateway(
    policy_engine: PolicyEngine,
    mcp_config: dict,
    proxy_manager: ProxyManager | None = None,
    check_config_changes_fn: Any = None,
    get_reload_status_fn: Any = None,
    default_agent_id: str | None = None,
    debug_mode: bool = False
):
    """Store the gateway's runtime dependencies in module-level state.

    Call this before the gateway starts accepting requests. The tool functions
    read these globals on every call.

    Args:
        policy_engine: Access-control engine consulted by every tool.
        mcp_config: MCP servers configuration; list_servers reads its
            ``mcpServers`` mapping.
        proxy_manager: Downstream connection manager. Without it,
            get_server_tools and execute_tool report an error.
        check_config_changes_fn: Optional zero-argument callable.
            get_server_tools invokes it as a fallback config-change check.
        get_reload_status_fn: Optional zero-argument callable returning
            reload diagnostics for get_gateway_status.
        default_agent_id: Optional default agent ID (from the
            GATEWAY_DEFAULT_AGENT env var) used in the agent fallback chain.
        debug_mode: When true, the debug-only tools (get_gateway_status)
            are also registered on the server.
    """
    global _policy_engine, _mcp_config, _proxy_manager
    global _check_config_changes_fn, _get_reload_status_fn
    global _default_agent_id, _debug_mode

    (
        _policy_engine,
        _mcp_config,
        _proxy_manager,
        _check_config_changes_fn,
        _get_reload_status_fn,
        _default_agent_id,
        _debug_mode,
    ) = (
        policy_engine,
        mcp_config,
        proxy_manager,
        check_config_changes_fn,
        get_reload_status_fn,
        default_agent_id,
        debug_mode,
    )

    # Diagnostic tools are opt-in; outside debug mode nothing else to do.
    if not debug_mode:
        return
    _register_debug_tools()
103
+
104
+
105
def get_default_agent_id() -> str | None:
    """Return the fallback agent ID stored by initialize_gateway.

    This is the GATEWAY_DEFAULT_AGENT value handed to initialize_gateway. It
    is None when no default was configured or the gateway is not yet
    initialized.
    """
    configured_default = _default_agent_id
    return configured_default
112
+
113
+
114
def _register_debug_tools():
    """Attach the diagnostic-only tools to the gateway server.

    initialize_gateway() calls this only when debug_mode is true. The tool
    functions themselves are always defined at module level, which lets tests
    call them directly. Only their registration as gateway tools is gated.
    """
    debug_tools = (get_gateway_status,)
    for debug_tool in debug_tools:
        gateway.tool(debug_tool)
125
+
126
+
127
@gateway.tool
async def list_servers(
    agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None,
    include_metadata: Annotated[bool, "Include technical details (transport, command, url)"] = False
) -> list[dict]:
    """Discover downstream MCP servers available through this gateway. Your access is determined by gateway policy rules. Workflow: 1) Call list_servers to discover servers, 2) Call get_server_tools to see available tools, 3) Call execute_tool to use them."""
    # Implementation notes. They live here, not in the docstring, because
    # FastMCP publishes the docstring as the tool description.
    # - Returns one ServerInfo dict per configured server the agent may
    #   access. Policy order is kept; wildcard access uses config order.
    # - Servers named by policy but missing from the config are skipped.
    # - Raises ToolError if agent_id was not resolved, and RuntimeError if
    #   the gateway state was never initialized.

    # Defensive check (middleware should have resolved agent_id)
    if agent_id is None:
        raise ToolError("Internal error: agent_id not resolved by middleware")

    policy_engine = _policy_engine
    mcp_config = _mcp_config

    # Compare against None: an empty config dict means "initialized with no
    # servers", not "uninitialized".
    if policy_engine is None:
        raise RuntimeError("PolicyEngine not initialized in gateway state")
    if mcp_config is None:
        raise RuntimeError("MCP configuration not initialized in gateway state")

    allowed_servers = policy_engine.get_allowed_servers(agent_id)
    all_servers = mcp_config.get("mcpServers", {})

    # A "*" entry grants every configured server, even when the policy also
    # lists specific names next to it.
    if "*" in allowed_servers:
        allowed_servers = list(all_servers)

    server_list = []
    # dict.fromkeys de-duplicates while preserving order, so a server listed
    # twice by policy is reported once.
    for server_name in dict.fromkeys(allowed_servers):
        if server_name not in all_servers:
            continue
        server_config = all_servers[server_name]

        # A `command` key means a locally spawned (stdio) server; anything
        # else is treated as an HTTP endpoint.
        transport = "stdio" if "command" in server_config else "http"

        # Name and description are always included (description may be None).
        server_info_kwargs = {
            "name": server_name,
            "description": server_config.get("description"),
        }

        if include_metadata:
            server_info_kwargs["transport"] = transport
            if transport == "stdio":
                server_info_kwargs["command"] = server_config.get("command")
            else:
                server_info_kwargs["url"] = server_config.get("url")

        server_list.append(ServerInfo(**server_info_kwargs).model_dump())

    return server_list
184
+
185
+
186
+ def _matches_pattern(tool_name: str, pattern: str) -> bool:
187
+ """Check if tool name matches wildcard pattern.
188
+
189
+ Uses glob-style pattern matching:
190
+ - * matches any sequence of characters
191
+ - ? matches any single character
192
+ - [seq] matches any character in seq
193
+ - [!seq] matches any character not in seq
194
+
195
+ Args:
196
+ tool_name: Name of the tool to match
197
+ pattern: Pattern with wildcards (e.g., "get_*", "*_user")
198
+
199
+ Returns:
200
+ True if tool_name matches pattern, False otherwise
201
+
202
+ Example:
203
+ >>> _matches_pattern("get_user", "get_*")
204
+ True
205
+ >>> _matches_pattern("delete_user", "get_*")
206
+ False
207
+ >>> _matches_pattern("list_items", "*_items")
208
+ True
209
+ """
210
+ return fnmatch.fnmatch(tool_name, pattern)
211
+
212
+
213
+ def _estimate_tool_tokens(tool: Any) -> int:
214
+ """Estimate token count for a tool definition.
215
+
216
+ Estimates tokens based on name, description, and input schema JSON length.
217
+ Uses rough approximation: characters / 4 = tokens (typical for English text).
218
+
219
+ Args:
220
+ tool: Tool object with name, description, and inputSchema attributes
221
+
222
+ Returns:
223
+ Estimated token count for the tool definition
224
+
225
+ Example:
226
+ >>> tool = Tool(name="get_user", description="Get user by ID", inputSchema={...})
227
+ >>> _estimate_tool_tokens(tool)
228
+ 42
229
+ """
230
+ # Count name length
231
+ name_len = len(tool.name) if hasattr(tool, 'name') and tool.name else 0
232
+
233
+ # Count description length
234
+ desc_len = len(tool.description) if hasattr(tool, 'description') and tool.description else 0
235
+
236
+ # Count input schema length (convert to string for estimation)
237
+ schema_len = 0
238
+ if hasattr(tool, 'inputSchema') and tool.inputSchema:
239
+ # Convert schema dict to string for rough character count
240
+ import json
241
+ try:
242
+ schema_len = len(json.dumps(tool.inputSchema))
243
+ except Exception:
244
+ # If serialization fails, use a default estimate
245
+ schema_len = 100
246
+
247
+ # Total characters / 4 = rough token estimate
248
+ total_chars = name_len + desc_len + schema_len
249
+ return max(1, total_chars // 4)
250
+
251
+
252
def _tools_error_response(server: str, message: str) -> dict:
    """Build the serialized GetServerToolsResponse used by every failure path.

    Failures report no tools, zero counts and no token usage, with *message*
    in ``error``. Callers always get the same response shape.
    """
    return GetServerToolsResponse(
        tools=[],
        server=server,
        total_available=0,
        returned=0,
        tokens_used=None,
        error=message,
    ).model_dump()


@gateway.tool
async def get_server_tools(
    agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None,
    server: Annotated[str, "Server name from list_servers"] = "",
    names: Annotated[Optional[str], "Filter: comma-separated tool names"] = None,
    pattern: Annotated[Optional[str], "Filter: wildcard pattern (e.g., 'get_*')"] = None,
    max_schema_tokens: Annotated[Optional[int], "Limit total tokens in returned schemas"] = None
) -> dict:
    """Discover tools available on a downstream MCP server accessed through this gateway. Returns only tools you have permission to use (filtered by policy rules). Use the returned tool definitions to call execute_tool."""
    # Implementation notes. They live here, not in the docstring, because
    # FastMCP publishes the docstring as the tool description.
    # - Filters are applied in this order:
    #   1. the explicit `names` list;
    #   2. the glob `pattern`;
    #   3. policy permissions;
    #   4. the optional `max_schema_tokens` budget. Tools are taken in server
    #      order, and listing stops at the first tool that would overflow it.
    # - Gateway and lookup failures are reported in the `error` field rather
    #   than raised. Only an unresolved agent_id raises (ToolError).

    # Defensive check (middleware should have resolved agent_id)
    if agent_id is None:
        raise ToolError("Internal error: agent_id not resolved by middleware")

    # Fallback config-change check for when file watching doesn't work. It is
    # best-effort by design: a failing check must not break tool discovery.
    if _check_config_changes_fn:
        try:
            _check_config_changes_fn()
        except Exception:
            pass

    # Parse the comma-separated names, trimming whitespace. A blank string,
    # or one containing only separators, means "no name filter". A set gives
    # O(1) membership checks in the loop below.
    names_filter: Optional[set[str]] = None
    if names is not None:
        parsed_names = {name.strip() for name in names.split(",") if name.strip()}
        names_filter = parsed_names or None

    policy_engine = _policy_engine
    proxy_manager = _proxy_manager

    if policy_engine is None:
        return _tools_error_response(server, "PolicyEngine not initialized in gateway state")
    if proxy_manager is None:
        return _tools_error_response(server, "ProxyManager not initialized in gateway state")

    if not policy_engine.can_access_server(agent_id, server):
        return _tools_error_response(
            server, f"Access denied: Agent '{agent_id}' cannot access server '{server}'"
        )

    try:
        all_tools = await proxy_manager.list_tools(server)
    except KeyError:
        return _tools_error_response(server, f"Server '{server}' not found in configured servers")
    except RuntimeError as e:
        return _tools_error_response(server, f"Server unavailable: {str(e)}")
    except Exception as e:
        return _tools_error_response(server, f"Failed to retrieve tools: {str(e)}")

    total_available = len(all_tools)

    filtered_tools: list[ToolDefinition] = []
    token_count = 0
    for tool in all_tools:
        tool_name = tool.name if hasattr(tool, "name") else str(tool)

        if names_filter is not None and tool_name not in names_filter:
            continue
        if pattern is not None and not _matches_pattern(tool_name, pattern):
            continue
        if not policy_engine.can_access_tool(agent_id, server, tool_name):
            continue

        if max_schema_tokens is not None:
            tool_tokens = _estimate_tool_tokens(tool)
            if token_count + tool_tokens > max_schema_tokens:
                # Budget exhausted: stop listing further tools.
                break
            token_count += tool_tokens

        filtered_tools.append(
            ToolDefinition(
                name=tool_name,
                description=getattr(tool, "description", None) or "",
                # ToolDefinition requires a dict. Normalize a missing or None
                # schema to {} so one odd tool cannot fail the whole listing
                # with a validation error.
                inputSchema=getattr(tool, "inputSchema", None) or {},
            )
        )

    return GetServerToolsResponse(
        tools=filtered_tools,
        server=server,
        total_available=total_available,
        returned=len(filtered_tools),
        tokens_used=token_count if max_schema_tokens is not None else None,
    ).model_dump()
392
+
393
+
394
def _content_item_to_dict(item: Any) -> Any:
    """Return *item* as a plain dict when it is a pydantic model, else unchanged.

    Downstream content blocks may arrive as pydantic models (e.g. MCP content
    types) rather than dicts. ToolExecutionResponse.content is list[dict],
    which rejects model instances, so they are dumped first. Wire-format keys
    are used (aliases, no None fields), JSON-safe values.
    """
    if isinstance(item, dict):
        return item
    model_dump = getattr(item, "model_dump", None)
    if callable(model_dump):
        return model_dump(mode="json", by_alias=True, exclude_none=True)
    return item


def _to_execution_response(result: Any) -> dict:
    """Normalize a downstream call_tool result into a ToolExecutionResponse dict.

    Accepted inputs:
    - objects exposing ``content``, plus an optional error flag;
    - plain dicts, which fall back to stringifying the dict if it has no
      ``content`` key;
    - any other value, wrapped as a single text item.
    """
    if hasattr(result, "content"):
        content = result.content
        # The MCP wire name is isError. Some client result types (presumably
        # fastmcp's CallToolResult; verify) expose is_error instead.
        is_error = getattr(result, "isError", None)
        if is_error is None:
            is_error = getattr(result, "is_error", False)
    elif isinstance(result, dict):
        content = result.get("content", [{"type": "text", "text": str(result)}])
        is_error = result.get("isError", False)
    else:
        content = [{"type": "text", "text": str(result)}]
        is_error = False

    if isinstance(content, (list, tuple)):
        content = [_content_item_to_dict(item) for item in content]

    return ToolExecutionResponse(content=content, isError=is_error).model_dump()


@gateway.tool
async def execute_tool(
    agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None,
    server: Annotated[str, "Server name from list_servers"] = "",
    tool: Annotated[str, "Tool name from get_server_tools"] = "",
    args: Annotated[dict, "Arguments matching tool's inputSchema"] = {},
    timeout_ms: Annotated[Optional[int], "Execution timeout in milliseconds"] = None
) -> dict:
    """Execute a tool on a downstream MCP server accessed through this gateway. Gateway validates permissions then forwards your request to the server. Returns the server's response directly."""
    # Implementation notes. They live here, not in the docstring, because
    # FastMCP publishes the docstring as the tool description.
    # - Checks server access, then tool access, then forwards the call.
    # - Every failure surfaces as ToolError; the original cause is chained.
    # - The `{}` default for `args` is kept so the published input schema
    #   does not change. The body copies it before use so the shared default
    #   is never handed downstream.

    # Defensive check (middleware should have resolved agent_id)
    if agent_id is None:
        raise ToolError("Internal error: agent_id not resolved by middleware")

    policy_engine = _policy_engine
    proxy_manager = _proxy_manager

    if policy_engine is None:
        raise ToolError("PolicyEngine not initialized in gateway state")
    if proxy_manager is None:
        raise ToolError("ProxyManager not initialized in gateway state")

    # 1. Validate agent can access server
    if not policy_engine.can_access_server(agent_id, server):
        raise ToolError(f"Agent '{agent_id}' cannot access server '{server}'")

    # 2. Validate agent can access tool
    if not policy_engine.can_access_tool(agent_id, server, tool):
        raise ToolError(f"Agent '{agent_id}' not authorized to call tool '{tool}' on server '{server}'")

    # 3. Execute on the downstream server and 4. return the result transparently
    try:
        # Shallow copy: downstream code must not be able to mutate the shared
        # default dict or the caller's arguments.
        call_args = dict(args) if args else {}
        result = await proxy_manager.call_tool(server, tool, call_args, timeout_ms)
        return _to_execution_response(result)
    except (asyncio.TimeoutError, TimeoutError) as e:
        # Before Python 3.11, asyncio.TimeoutError is distinct from the
        # builtin TimeoutError. Catch both so either is reported as a timeout.
        if timeout_ms is not None:
            raise ToolError(f"Tool execution timed out after {timeout_ms}ms") from e
        raise ToolError("Tool execution timed out") from e
    except KeyError as e:
        # Server not found
        raise ToolError(f"Server '{server}' not found in configured servers") from e
    except RuntimeError as e:
        # Server unavailable or tool execution failed
        error_msg = str(e)
        lowered = error_msg.lower()
        if "not found" in lowered or "unavailable" in lowered:
            raise ToolError(error_msg) from e
        raise ToolError(f"Tool execution failed: {error_msg}") from e
    except Exception as e:
        raise ToolError(f"Tool execution failed: {str(e)}") from e
465
+
466
+
467
def _serialize_reload_status(status: dict) -> dict:
    """Return a copy of *status* with reload timestamps rendered as ISO strings.

    Only the ``last_attempt`` and ``last_success`` entries of the
    ``mcp_config`` and ``gateway_rules`` sections are converted. The input is
    never mutated. The getter may hand back its live state; converting that
    in place would corrupt it, and the next call would then invoke
    ``.isoformat()`` on an already-converted string.
    """
    serialized = dict(status)
    for config_type in ("mcp_config", "gateway_rules"):
        if config_type not in serialized:
            continue
        section = dict(serialized[config_type])
        for key in ("last_attempt", "last_success"):
            value = section.get(key)
            # Values that are already strings, or are otherwise not
            # datetime-like, are left untouched.
            if value and hasattr(value, "isoformat"):
                section[key] = value.isoformat()
        serialized[config_type] = section
    return serialized


async def get_gateway_status(
    agent_id: Annotated[Optional[str], "Your agent identifier (leave empty if not provided to you)"] = None
) -> dict:
    """Get gateway status, configuration state, and hot reload diagnostics.

    NOTE: Only available when debug mode is enabled."""
    # Implementation notes. The docstring above becomes the tool description
    # when debug mode registers this function.
    # - Each section degrades independently to an {"error": ...} dict if
    #   gathering it fails, so one broken component does not hide the rest.

    # Defensive check (middleware should have resolved agent_id)
    if agent_id is None:
        raise ToolError("Internal error: agent_id not resolved by middleware")

    # Reload diagnostics, with datetimes converted for JSON serialization
    reload_status = None
    if _get_reload_status_fn:
        try:
            reload_status = _get_reload_status_fn()
            if reload_status:
                reload_status = _serialize_reload_status(reload_status)
        except Exception:
            reload_status = {"error": "Failed to retrieve reload status"}

    # PolicyEngine summary
    policy_state = {}
    if _policy_engine is not None:
        try:
            policy_state = {
                "total_agents": len(_policy_engine.agents),
                "agent_ids": list(_policy_engine.agents.keys()),
                "defaults": _policy_engine.defaults,
            }
        except Exception:
            policy_state = {"error": "Failed to retrieve policy state"}

    # Configured server names
    available_servers = []
    if _mcp_config and "mcpServers" in _mcp_config:
        available_servers = list(_mcp_config["mcpServers"].keys())

    # Config file paths from the sibling config module
    config_paths = {}
    try:
        # Relative import, matching this module's other sibling imports
        # (.policy, .proxy). The absolute `src.config` form only works when
        # the package is importable as top-level `src`.
        from .config import get_stored_config_paths
        mcp_path, rules_path = get_stored_config_paths()
        config_paths = {
            "mcp_config": mcp_path,
            "gateway_rules": rules_path,
        }
    except Exception:
        config_paths = {"error": "Failed to retrieve config paths"}

    return GatewayStatusResponse(
        reload_status=reload_status,
        policy_state=policy_state,
        available_servers=available_servers,
        config_paths=config_paths,
        message="Gateway is operational. Check reload_status for hot reload health."
    ).model_dump()