agent-mcp-gateway 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agent-mcp-gateway might be problematic. Click here for more details.

src/proxy.py ADDED
@@ -0,0 +1,647 @@
1
+ """Proxy infrastructure for managing downstream MCP server connections."""
2
+
3
import asyncio
import copy
import logging
from typing import Any

from fastmcp.client import Client
8
+
9
+
10
# Module-level logger named after this module (``__name__``); used by ProxyManager below.
logger = logging.getLogger(__name__)
11
+
12
+
13
class ProxyManager:
    """Manages connections to downstream MCP servers.

    This class initializes and maintains Client instances for each
    configured downstream MCP server. It supports both stdio (npx/uvx)
    and HTTP transports, implements lazy connection strategy, and provides
    graceful error handling for unreachable servers.
    """

    def __init__(self):
        """Initialize ProxyManager with empty client registry."""
        # server name -> disconnected Client (only servers whose client was created)
        self._clients: dict[str, Client] = {}
        # server name -> True once test_connection() succeeded, False otherwise
        self._connection_status: dict[str, bool] = {}
        # server name -> last recorded error ("" means no error)
        self._connection_errors: dict[str, str] = {}
        # Private snapshot of the last applied config, used by reload() to diff.
        # A deep copy, so later mutation of the caller's dict cannot hide changes.
        self._current_config: dict = {}

    @staticmethod
    def _validate_server_config(server_name: str, server_config: Any) -> None:
        """Validate one entry of the "mcpServers" mapping.

        Args:
            server_name: Name of the server (used in error messages)
            server_config: The configuration value for that server

        Raises:
            ValueError: If the entry is not a dict, specifies neither or both
                of "command"/"url", or has a field of the wrong type
        """
        # Without this check, `"url" in some_string` would do a substring test
        # and indexing would fail later with an unrelated TypeError.
        if not isinstance(server_config, dict):
            raise ValueError(
                f'Server "{server_name}" configuration must be a dict, '
                f"got {type(server_config).__name__}"
            )

        has_command = "command" in server_config
        has_url = "url" in server_config

        if not has_command and not has_url:
            raise ValueError(
                f'Server "{server_name}" must specify either "command" (stdio) '
                f'or "url" (HTTP) transport'
            )
        if has_command and has_url:
            raise ValueError(
                f'Server "{server_name}" cannot have both "command" and "url"'
            )

        if has_command:
            if not isinstance(server_config["command"], str):
                raise ValueError(f'Server "{server_name}": "command" must be a string')
            if not isinstance(server_config.get("args", []), list):
                raise ValueError(f'Server "{server_name}": "args" must be a list')
            if not isinstance(server_config.get("env", {}), dict):
                raise ValueError(f'Server "{server_name}": "env" must be a dict')
        else:
            if not isinstance(server_config["url"], str):
                raise ValueError(f'Server "{server_name}": "url" must be a string')
            if not isinstance(server_config.get("headers", {}), dict):
                raise ValueError(f'Server "{server_name}": "headers" must be a dict')

    def initialize_connections(self, mcp_config: dict) -> dict[str, Client]:
        """Initialize Client instances from MCP configuration.

        Creates disconnected Client instances for lazy connection strategy.
        Connections are established on first use via async context manager.
        Any previously created clients and status/error records are discarded.
        A server whose client cannot be created is skipped and its error is
        recorded, so get_client() reports it as unavailable.

        Args:
            mcp_config: MCP servers configuration dictionary with structure:
                {
                    "mcpServers": {
                        "server-name": {
                            "command": "npx",
                            "args": [...],
                            "env": {...}
                        }
                    }
                }

        Returns:
            Dictionary mapping server names to Client instances (the manager's
            internal registry, not a copy)

        Raises:
            ValueError: If mcp_config is invalid or malformed
        """
        if not isinstance(mcp_config, dict):
            raise ValueError(
                f"MCP configuration must be a dict, got {type(mcp_config).__name__}"
            )

        mcp_servers = mcp_config.get("mcpServers", {})
        if not isinstance(mcp_servers, dict):
            raise ValueError(
                f'"mcpServers" must be a dict, got {type(mcp_servers).__name__}'
            )

        # Clear existing clients
        self._clients.clear()
        self._connection_status.clear()
        self._connection_errors.clear()

        # Create a Client for each server
        for server_name, server_config in mcp_servers.items():
            try:
                client = self._create_client(server_name, server_config)
            except Exception as e:
                logger.error(f"Failed to initialize client for {server_name}: {e}")
                # Never store an empty message: get_client() treats "" as
                # "no error" and would report the server as unknown instead.
                self._connection_errors[server_name] = str(e) or type(e).__name__
                continue

            self._clients[server_name] = client
            self._connection_status[server_name] = False  # Not yet connected
            self._connection_errors[server_name] = ""
            logger.info(f"Initialized ProxyClient for server: {server_name}")

        # Snapshot (not alias) the config so reload() compares against what
        # was actually applied, even if the caller mutates its dict later.
        self._current_config = copy.deepcopy(mcp_config)

        return self._clients

    def _create_client(self, server_name: str, server_config: dict) -> Client:
        """Create Client instance from server configuration.

        Args:
            server_name: Name of the server
            server_config: Server configuration dictionary

        Returns:
            Client instance (disconnected)

        Raises:
            ValueError: If server configuration is invalid
        """
        self._validate_server_config(server_name, server_config)

        # Create stdio client
        if "command" in server_config:
            command = server_config["command"]
            args = server_config.get("args", [])
            logger.debug(
                f"Creating stdio Client for {server_name}: "
                f"command={command}, args={args}"
            )

            # FastMCP Client expects MCPConfig with mcpServers key,
            # so the single server config is wrapped.
            client_config = {"mcpServers": {server_name: server_config}}
            return Client(transport=client_config)

        # Create HTTP client (validation guarantees "url" is present here)
        url = server_config["url"]
        headers = server_config.get("headers", {})

        logger.info(
            f"Creating HTTP Client with OAuth support for {server_name}: url={url}"
        )

        # OAuth is auto-detected by the MCP protocol - servers that don't
        # require OAuth will return 200, servers that do will return 401.
        # Custom headers are not supported with OAuth auto-detection, and the
        # client below is created without them, so configured headers are
        # dropped. The warning says so explicitly rather than implying they
        # are sent.
        if headers:
            logger.warning(
                f"Server {server_name} has custom headers configured, but they "
                "are ignored: the HTTP client is created with OAuth "
                "auto-detection, which does not pass custom headers. Remove the "
                "headers or handle them at the application level."
            )

        # Pass URL directly (not as MCPConfig) to enable auth parameter
        return Client(url, auth="oauth")

    def get_client(self, server_name: str) -> Client:
        """Get Client instance for a server.

        Returns the disconnected Client for the specified server.
        The caller should use 'async with client:' to establish a connection
        and perform operations.

        Args:
            server_name: Name of the server

        Returns:
            Client instance for the server

        Raises:
            KeyError: If server_name is not found in initialized clients
            RuntimeError: If server had initialization errors
        """
        client = self._clients.get(server_name)
        if client is not None:
            return client

        # No client: distinguish "configured but failed to initialize" from
        # "never configured".
        error = self._connection_errors.get(server_name)
        if error:
            raise RuntimeError(f'Server "{server_name}" is unavailable: {error}')
        raise KeyError(f'Server "{server_name}" not found in configured servers')

    async def test_connection(
        self,
        server_name: str,
        timeout_ms: int = 5000,
        max_retries: int = 3
    ) -> bool:
        """Test connection to a server with retry logic.

        Attempts to connect to the server and retrieve its tool list.
        Uses exponential backoff between attempts (0.5s, 1s, 2s, ...).

        Args:
            server_name: Name of the server to test
            timeout_ms: Per-attempt timeout in milliseconds (default: 5000)
            max_retries: Total number of connection attempts (default: 3);
                values below 1 are treated as 1, so at least one attempt is made

        Returns:
            True if connection successful, False otherwise. On failure the
            server's recorded error includes the last underlying cause.

        Raises:
            KeyError: If server_name is not found in initialized clients
            RuntimeError: If server had initialization errors
        """
        client = self.get_client(server_name)

        timeout_sec = timeout_ms / 1000.0
        base_delay = 0.5  # Start with 500ms delay
        attempts = max(1, max_retries)
        last_error = "unknown error"

        for attempt in range(attempts):
            logger.debug(
                f"Testing connection to {server_name} "
                f"(attempt {attempt + 1}/{attempts})"
            )

            try:
                async with asyncio.timeout(timeout_sec):
                    async with client:
                        # Listing tools doubles as the connection test
                        await client.list_tools()
            except asyncio.TimeoutError:
                last_error = f"timed out after {timeout_ms}ms"
                failure = f"Connection timeout for {server_name} on attempt {attempt + 1}"
            except Exception as e:
                last_error = str(e) or type(e).__name__
                failure = (
                    f"Connection error for {server_name} on attempt {attempt + 1}: {e}"
                )
            else:
                self._connection_status[server_name] = True
                self._connection_errors[server_name] = ""
                logger.info(f"Successfully connected to server: {server_name}")
                return True

            # Only announce (and wait for) a retry when another attempt follows.
            if attempt < attempts - 1:
                delay = base_delay * (2 ** attempt)
                logger.warning(f"{failure}. Retrying in {delay}s...")
                await asyncio.sleep(delay)
            else:
                logger.warning(f"{failure}.")

        # All attempts failed
        error_msg = f"Failed to connect after {attempts} attempts: {last_error}"
        self._connection_status[server_name] = False
        self._connection_errors[server_name] = error_msg
        logger.error(f"Failed to connect to server {server_name}: {error_msg}")
        return False

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any],
        timeout_ms: int | None = None
    ) -> Any:
        """Call a tool on a downstream server.

        Enters the server's Client context for the duration of the call.

        Args:
            server_name: Name of the server
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout_ms: Optional timeout in milliseconds; None or 0 means no
                timeout

        Returns:
            Tool execution result

        Raises:
            KeyError: If server_name is not found
            RuntimeError: If server is unavailable or tool call fails
            asyncio.TimeoutError: If operation times out
        """
        client = self.get_client(server_name)

        # asyncio.timeout(None) imposes no deadline, which preserves the
        # original "falsy timeout_ms means no timeout" behavior in one code path.
        timeout_sec = timeout_ms / 1000.0 if timeout_ms else None

        try:
            async with asyncio.timeout(timeout_sec):
                async with client:
                    return await client.call_tool(tool_name, arguments)

        except asyncio.TimeoutError:
            logger.error(
                f"Timeout calling tool {tool_name} on server {server_name} "
                f"(timeout: {timeout_ms}ms)"
            )
            raise

        except Exception as e:
            logger.error(
                f"Error calling tool {tool_name} on server {server_name}: {e}"
            )
            raise RuntimeError(
                f"Failed to call tool {tool_name} on server {server_name}: {e}"
            ) from e

    async def list_tools(self, server_name: str) -> list[Any]:
        """List all tools available from a server.

        Args:
            server_name: Name of the server

        Returns:
            List of tool definitions

        Raises:
            KeyError: If server_name is not found
            RuntimeError: If server is unavailable or listing fails
        """
        client = self.get_client(server_name)

        try:
            async with client:
                return await client.list_tools()
        except Exception as e:
            logger.error(f"Error listing tools from server {server_name}: {e}")
            raise RuntimeError(
                f"Failed to list tools from server {server_name}: {e}"
            ) from e

    def get_server_status(self, server_name: str) -> dict[str, Any]:
        """Get connection status for a server.

        Args:
            server_name: Name of the server

        Returns:
            Dictionary with status information:
            {
                "connected": bool,
                "error": str,
                "initialized": bool
            }
        """
        return {
            "connected": self._connection_status.get(server_name, False),
            "error": self._connection_errors.get(server_name, ""),
            "initialized": server_name in self._clients
        }

    def get_all_servers(self) -> list[str]:
        """Get list of all initialized server names.

        Returns:
            List of server names that currently have a Client
        """
        return list(self._clients.keys())

    def _config_changed(self, server_name: str, new_mcp_config: dict) -> bool:
        """Check if a server's configuration has changed.

        Compares the server configuration in the stored config snapshot with
        the new config to determine if the server needs to be reloaded.

        Args:
            server_name: Name of the server to check
            new_mcp_config: New MCP configuration to compare against

        Returns:
            True if configuration changed, False otherwise
        """
        old_servers = self._current_config.get("mcpServers", {})
        new_servers = new_mcp_config.get("mcpServers", {})

        old_config = old_servers.get(server_name, {})
        new_config = new_servers.get(server_name, {})

        # Dict equality compares nested values recursively
        return old_config != new_config

    def _check_reload_config(self, new_mcp_config: Any) -> str | None:
        """Validate a candidate reload config without touching any state.

        Args:
            new_mcp_config: Candidate MCP configuration

        Returns:
            None if the configuration is valid, otherwise an error message
        """
        if not isinstance(new_mcp_config, dict):
            return f"MCP configuration must be a dict, got {type(new_mcp_config).__name__}"

        new_mcp_servers = new_mcp_config.get("mcpServers", {})
        if not isinstance(new_mcp_servers, dict):
            return f'"mcpServers" must be a dict, got {type(new_mcp_servers).__name__}'

        for server_name, server_config in new_mcp_servers.items():
            try:
                self._validate_server_config(server_name, server_config)
            except Exception as e:
                return f"Invalid configuration for server '{server_name}': {e}"

        return None

    async def reload(self, new_mcp_config: dict) -> tuple[bool, str | None]:
        """Reload proxy client connections with new MCP server configuration.

        This method updates the configuration by:
        1. Validating the new configuration (before any state is mutated, so an
           invalid config leaves the manager untouched)
        2. Determining which servers need to be added, removed, or updated
        3. Dropping clients for removed/updated servers, plus any leftover
           status/error records for servers no longer configured
        4. Creating new clients for added/updated servers
        5. Preserving unchanged servers (no disruption)

        The reload follows the lazy connection strategy - new clients are created
        in disconnected state and will connect on first use. A client that fails
        to be created after validation is recorded as an error for that server
        and does not fail the reload.

        Args:
            new_mcp_config: New MCP server configuration dictionary with structure:
                {
                    "mcpServers": {
                        "server-name": {
                            "command": "npx",
                            "args": [...],
                            "env": {...}
                        }
                    }
                }

        Returns:
            Tuple of (success: bool, error_message: str | None)
            - (True, None) if reload successful
            - (False, error_message) if validation failed

        Thread Safety:
            This method is NOT thread-safe. The caller must ensure that reload
            operations are serialized and that no concurrent operations are
            accessing the ProxyManager during reload.

        Examples:
            >>> manager = ProxyManager()
            >>> manager.initialize_connections(initial_config)
            >>> success, error = await manager.reload(new_config)
            >>> if success:
            ...     logger.info("Reload successful")
            >>> else:
            ...     logger.error(f"Reload failed: {error}")
        """
        logger.info("ProxyManager reload initiated")

        try:
            error_msg = self._check_reload_config(new_mcp_config)
        except Exception as e:
            error_msg = f"Configuration validation error: {e}"
        if error_msg is not None:
            logger.error(f"Reload validation failed: {error_msg}")
            return False, error_msg

        logger.info("Configuration validation passed")

        # Determine server changes
        new_mcp_servers = new_mcp_config.get("mcpServers", {})
        old_servers = set(self._clients)
        new_servers = set(new_mcp_servers)

        servers_to_add = new_servers - old_servers
        servers_to_remove = old_servers - new_servers

        servers_to_update: list[str] = []
        servers_unchanged: list[str] = []
        for name in sorted(old_servers & new_servers):
            if self._config_changed(name, new_mcp_config):
                servers_to_update.append(name)
            else:
                servers_unchanged.append(name)

        logger.info(
            f"Server changes: "
            f"+{len(servers_to_add)} (add), "
            f"-{len(servers_to_remove)} (remove), "
            f"~{len(servers_to_update)} (update), "
            f"={len(servers_unchanged)} (unchanged)"
        )

        if servers_to_add:
            logger.info(f"Servers to add: {sorted(servers_to_add)}")
        if servers_to_remove:
            logger.info(f"Servers to remove: {sorted(servers_to_remove)}")
        if servers_to_update:
            logger.info(f"Servers to update: {sorted(servers_to_update)}")
        if servers_unchanged:
            logger.debug(f"Servers unchanged: {sorted(servers_unchanged)}")

        # Drop clients for removed and updated servers. The Client is only
        # entered via per-operation `async with` blocks in this class, so
        # removing the reference is the only cleanup done here.
        servers_to_close = sorted(servers_to_remove) + servers_to_update
        if servers_to_close:
            logger.info(f"Closing connections for {len(servers_to_close)} servers")

        for server_name in servers_to_close:
            logger.debug(f"Cleaning up client for server: {server_name}")
            self._clients.pop(server_name, None)
            self._connection_status.pop(server_name, None)
            self._connection_errors.pop(server_name, None)
            logger.debug(f"Removed client for server: {server_name}")

        # Servers that failed client creation earlier never entered _clients,
        # so they are not in servers_to_remove. Prune their records here when
        # they are no longer configured; otherwise get_client() keeps raising
        # "unavailable" and get_server_status() keeps reporting them.
        for tracking in (self._connection_status, self._connection_errors):
            for stale in [name for name in tracking if name not in new_servers]:
                del tracking[stale]

        # Create clients for new and updated servers
        servers_to_create = sorted(servers_to_add) + servers_to_update
        if servers_to_create:
            logger.info(f"Creating clients for {len(servers_to_create)} servers")

        for server_name in servers_to_create:
            try:
                client = self._create_client(server_name, new_mcp_servers[server_name])
            except Exception as e:
                # Log error but don't fail the entire reload: servers with
                # initialization errors are recorded and will error on use.
                logger.error(
                    f"Failed to create client for server '{server_name}': {e}",
                    exc_info=True
                )
                # Non-empty so get_client() reports the server as unavailable
                self._connection_errors[server_name] = str(e) or type(e).__name__
                continue

            self._clients[server_name] = client
            self._connection_status[server_name] = False  # Not yet connected
            self._connection_errors[server_name] = ""
            logger.info(f"Created client for server: {server_name}")

        # Snapshot (not alias) so the next reload diffs against what was applied
        self._current_config = copy.deepcopy(new_mcp_config)

        logger.info(
            f"ProxyManager reload completed successfully. "
            f"Active servers: {len(self._clients)}"
        )

        return True, None

    async def close_all_connections(self):
        """Close all Client connections.

        This is a cleanup method for graceful shutdown. Clients are only
        entered via per-operation `async with` blocks in this class, so there
        are no sessions held open here. The method clears the connection
        status and error records. The Client instances themselves are kept
        in the registry.
        """
        logger.info("Closing all proxy connections")

        # Clear internal state
        self._connection_status.clear()
        self._connection_errors.clear()

        logger.info(f"Closed connections for {len(self._clients)} servers")