agent-mcp-gateway 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/proxy.py ADDED
@@ -0,0 +1,649 @@
1
+ """Proxy infrastructure for managing downstream MCP server connections."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import Any
6
+
7
+ from fastmcp.client import Client
8
+ from fastmcp.client.transports import StreamableHttpTransport
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class ProxyManager:
    """Manages connections to downstream MCP servers.

    This class initializes and maintains Client instances for each
    configured downstream MCP server. It supports both stdio (a "command"
    entry, e.g. npx/uvx) and HTTP (a "url" entry) transports, implements a
    lazy connection strategy (clients are created disconnected and opened
    with ``async with client:`` on each use), and records per-server
    errors instead of failing as a whole when one server is misconfigured.

    No locking is performed anywhere in this class; callers are
    responsible for serializing mutating operations such as ``reload``.
    """

    def __init__(self):
        """Initialize ProxyManager with empty client registry."""
        # Successfully constructed (still disconnected) clients by server name.
        self._clients: dict[str, Client] = {}
        # Outcome of the most recent test_connection() per server.
        self._connection_status: dict[str, bool] = {}
        # Last recorded error text per server; "" means "no error".
        self._connection_errors: dict[str, str] = {}
        self._current_config: dict = {}  # Store current config for reload comparison

    @staticmethod
    def _validate_server_config(server_name: str, server_config: Any) -> None:
        """Validate the shape of a single server entry without creating a client.

        Shared by ``_create_client`` and ``reload`` so both apply exactly
        the same rules.

        Args:
            server_name: Name of the server (used only in error messages)
            server_config: The value stored under ``mcpServers[server_name]``

        Raises:
            ValueError: If the entry is not a dict, specifies neither or both
                of "command"/"url", or has a field of the wrong type
        """
        # Reject non-dict entries up front: `"command" in "some string"` would
        # otherwise do substring matching and fail later with a confusing error.
        if not isinstance(server_config, dict):
            raise ValueError(
                f'Server "{server_name}": configuration must be a dict, '
                f"got {type(server_config).__name__}"
            )

        has_command = "command" in server_config
        has_url = "url" in server_config

        if not has_command and not has_url:
            raise ValueError(
                f'Server "{server_name}" must specify either "command" (stdio) '
                f'or "url" (HTTP) transport'
            )

        if has_command and has_url:
            raise ValueError(
                f'Server "{server_name}" cannot have both "command" and "url"'
            )

        # (key, value, required type, wording used in the error message)
        if has_command:
            checks = (
                ("command", server_config["command"], str, "a string"),
                ("args", server_config.get("args", []), list, "a list"),
                ("env", server_config.get("env", {}), dict, "a dict"),
            )
        else:
            checks = (
                ("url", server_config["url"], str, "a string"),
                ("headers", server_config.get("headers", {}), dict, "a dict"),
            )

        for key, value, expected_type, description in checks:
            if not isinstance(value, expected_type):
                raise ValueError(
                    f'Server "{server_name}": "{key}" must be {description}'
                )

    def initialize_connections(self, mcp_config: dict) -> dict[str, Client]:
        """Initialize Client instances from MCP configuration.

        Creates disconnected Client instances for lazy connection strategy.
        Connections are established on first use via async context manager.
        Any previously held clients, statuses and errors are discarded.
        A server whose client cannot be created is logged and recorded in
        the error registry; it does not abort initialization of the others.

        Args:
            mcp_config: MCP servers configuration dictionary with structure:
                {
                    "mcpServers": {
                        "server-name": {
                            "command": "npx",
                            "args": [...],
                            "env": {...}
                        }
                    }
                }

        Returns:
            Dictionary mapping server names to Client instances (the
            manager's own registry, not a copy)

        Raises:
            ValueError: If mcp_config or its "mcpServers" value is not a dict
        """
        if not isinstance(mcp_config, dict):
            raise ValueError(
                f"MCP configuration must be a dict, got {type(mcp_config).__name__}"
            )

        mcp_servers = mcp_config.get("mcpServers", {})
        if not isinstance(mcp_servers, dict):
            raise ValueError(
                f'"mcpServers" must be a dict, got {type(mcp_servers).__name__}'
            )

        # Clear existing clients
        self._clients.clear()
        self._connection_status.clear()
        self._connection_errors.clear()

        # Create a client for each server; failures are recorded per server.
        for server_name, server_config in mcp_servers.items():
            try:
                client = self._create_client(server_name, server_config)
            except Exception as e:
                logger.error(f"Failed to initialize client for {server_name}: {e}")
                self._connection_errors[server_name] = str(e)
            else:
                self._clients[server_name] = client
                self._connection_status[server_name] = False  # Not yet connected
                self._connection_errors[server_name] = ""
                logger.info(f"Initialized ProxyClient for server: {server_name}")

        # Store current config for reload comparison
        self._current_config = mcp_config

        return self._clients

    def _create_client(self, server_name: str, server_config: dict) -> Client:
        """Create Client instance from server configuration.

        Args:
            server_name: Name of the server
            server_config: Server configuration dictionary

        Returns:
            Client instance (disconnected)

        Raises:
            ValueError: If server configuration is invalid
        """
        self._validate_server_config(server_name, server_config)

        # Create stdio client
        if "command" in server_config:
            command = server_config["command"]
            args = server_config.get("args", [])

            logger.debug(
                f"Creating stdio Client for {server_name}: "
                f"command={command}, args={args}"
            )

            # FastMCP Client expects MCPConfig with mcpServers key
            # We need to wrap the single server config
            client_config = {
                "mcpServers": {
                    server_name: server_config
                }
            }
            return Client(transport=client_config)

        # Create HTTP client (validation guarantees "url" is present here)
        url = server_config["url"]
        headers = server_config.get("headers", {})

        # Check if Authorization header is provided (PAT or other auth);
        # header names are compared case-insensitively.
        has_auth_header = any(
            isinstance(k, str) and k.lower() == "authorization" for k in headers
        )

        if has_auth_header:
            # User provided explicit auth - respect it, don't enable OAuth
            # Use StreamableHttpTransport to pass custom headers
            logger.info(
                f"Creating HTTP Client with custom authentication for {server_name}: "
                f"url={url}"
            )
            transport = StreamableHttpTransport(url, headers=headers)
            return Client(transport)

        # No auth provided - enable OAuth auto-detection (for Notion, etc.)
        logger.info(
            f"Creating HTTP Client with OAuth support for {server_name}: "
            f"url={url}"
        )
        return Client(url, auth="oauth")

    def get_client(self, server_name: str) -> Client:
        """Get Client instance for a server.

        Returns the disconnected Client for the specified server.
        The caller should use 'async with client:' to establish a connection
        and perform operations.

        Args:
            server_name: Name of the server

        Returns:
            Client instance for the server

        Raises:
            KeyError: If server_name is not found in initialized clients
            RuntimeError: If the server has no client but a recorded
                initialization error
        """
        client = self._clients.get(server_name)
        if client is not None:
            return client

        # No client: distinguish "failed to initialize" from "never configured".
        recorded_error = self._connection_errors.get(server_name)
        if recorded_error:
            raise RuntimeError(
                f'Server "{server_name}" is unavailable: '
                f'{recorded_error}'
            )
        raise KeyError(
            f'Server "{server_name}" not found in configured servers'
        )

    async def test_connection(
        self,
        server_name: str,
        timeout_ms: int = 5000,
        max_retries: int = 3
    ) -> bool:
        """Test connection to a server with retry logic.

        Attempts to connect to the server and retrieve its tool list.
        Uses exponential backoff between attempts (0.5s, 1s, 2s, ...).
        The outcome is stored in the status/error registries.

        Args:
            server_name: Name of the server to test
            timeout_ms: Per-attempt timeout in milliseconds (default: 5000)
            max_retries: Maximum number of attempts (default: 3)

        Returns:
            True if connection successful, False otherwise

        Raises:
            KeyError: If server_name is not found in initialized clients
            RuntimeError: If the server had an initialization error
        """
        client = self.get_client(server_name)

        timeout_sec = timeout_ms / 1000.0
        base_delay = 0.5  # Start with 500ms delay

        for attempt in range(max_retries):
            try:
                logger.debug(
                    f"Testing connection to {server_name} "
                    f"(attempt {attempt + 1}/{max_retries})"
                )

                async with asyncio.timeout(timeout_sec):
                    async with client:
                        # Try to list tools as connection test
                        await client.list_tools()

                self._connection_status[server_name] = True
                self._connection_errors[server_name] = ""
                logger.info(f"Successfully connected to server: {server_name}")
                return True

            except asyncio.TimeoutError:
                failure = (
                    f"Connection timeout for {server_name} on attempt {attempt + 1}"
                )
            except Exception as e:
                failure = (
                    f"Connection error for {server_name} on attempt {attempt + 1}: {e}"
                )

            # Only announce (and wait for) a retry when one will actually happen.
            if attempt < max_retries - 1:
                delay = base_delay * (2 ** attempt)
                logger.warning(f"{failure}. Retrying in {delay}s...")
                await asyncio.sleep(delay)
            else:
                logger.warning(f"{failure}. No retries left.")

        # All retries failed
        error_msg = f"Failed to connect after {max_retries} attempts"
        self._connection_status[server_name] = False
        self._connection_errors[server_name] = error_msg
        logger.error(f"Failed to connect to server {server_name}: {error_msg}")
        return False

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any],
        timeout_ms: int | None = None
    ) -> Any:
        """Call a tool on a downstream server.

        Opens the client with ``async with`` for the duration of this single
        call, so every call uses its own connection context.

        Args:
            server_name: Name of the server
            tool_name: Name of the tool to call
            arguments: Tool arguments
            timeout_ms: Optional timeout in milliseconds; None or 0 means
                no timeout is applied

        Returns:
            Tool execution result

        Raises:
            KeyError: If server_name is not found
            RuntimeError: If server is unavailable or tool call fails
                (the original exception is attached as ``__cause__``)
            asyncio.TimeoutError: If operation times out
        """
        client = self.get_client(server_name)

        try:
            if timeout_ms:
                async with asyncio.timeout(timeout_ms / 1000.0):
                    async with client:
                        return await client.call_tool(tool_name, arguments)

            async with client:
                return await client.call_tool(tool_name, arguments)

        except asyncio.TimeoutError:
            logger.error(
                f"Timeout calling tool {tool_name} on server {server_name} "
                f"(timeout: {timeout_ms}ms)"
            )
            raise

        except Exception as e:
            logger.error(
                f"Error calling tool {tool_name} on server {server_name}: {e}"
            )
            raise RuntimeError(
                f"Failed to call tool {tool_name} on server {server_name}: {e}"
            ) from e

    async def list_tools(self, server_name: str) -> list[Any]:
        """List all tools available from a server.

        Args:
            server_name: Name of the server

        Returns:
            List of tool definitions

        Raises:
            KeyError: If server_name is not found
            RuntimeError: If server is unavailable or listing fails
                (the original exception is attached as ``__cause__``)
        """
        client = self.get_client(server_name)

        try:
            async with client:
                return await client.list_tools()
        except Exception as e:
            logger.error(f"Error listing tools from server {server_name}: {e}")
            raise RuntimeError(
                f"Failed to list tools from server {server_name}: {e}"
            ) from e

    def get_server_status(self, server_name: str) -> dict[str, Any]:
        """Get connection status for a server.

        Unknown server names do not raise; they yield the default values.

        Args:
            server_name: Name of the server

        Returns:
            Dictionary with status information:
                {
                    "connected": bool,
                    "error": str,
                    "initialized": bool
                }
        """
        return {
            "connected": self._connection_status.get(server_name, False),
            "error": self._connection_errors.get(server_name, ""),
            "initialized": server_name in self._clients
        }

    def get_all_servers(self) -> list[str]:
        """Get list of all initialized server names.

        Returns:
            List of server names that currently have a client
        """
        return list(self._clients)

    def _config_changed(self, server_name: str, new_mcp_config: dict) -> bool:
        """Check if a server's configuration has changed.

        Compares the server's entry in the stored config with its entry in
        the new config. A missing entry on either side is treated as ``{}``.

        Args:
            server_name: Name of the server to check
            new_mcp_config: New MCP configuration to compare against

        Returns:
            True if configuration changed, False otherwise
        """
        old_entry = self._current_config.get("mcpServers", {}).get(server_name, {})
        new_entry = new_mcp_config.get("mcpServers", {}).get(server_name, {})

        # dict equality compares nested structures recursively
        return old_entry != new_entry

    async def reload(self, new_mcp_config: dict) -> tuple[bool, str | None]:
        """Reload proxy client connections with new MCP server configuration.

        Steps:
        1. Validate the whole new configuration; on any validation error
           nothing is modified and (False, message) is returned
        2. Determine which servers need to be added, removed, or updated
        3. Drop clients and tracking state for removed/updated servers
        4. Create new clients for added/updated servers
        5. Leave unchanged servers untouched (no disruption)

        New clients are created in disconnected state and will connect on
        first use. A failure to create an individual client after validation
        is logged and recorded for that server; the reload still reports
        success.

        Args:
            new_mcp_config: New MCP server configuration dictionary with structure:
                {
                    "mcpServers": {
                        "server-name": {
                            "command": "npx",
                            "args": [...],
                            "env": {...}
                        }
                    }
                }

        Returns:
            Tuple of (success: bool, error_message: str | None)
            - (True, None) if reload completed
            - (False, error_message) if validation failed

        Thread Safety:
            This method performs no locking. The caller must ensure that
            reload operations are serialized and that no concurrent
            operations are accessing the ProxyManager during reload.

        Examples:
            >>> manager = ProxyManager()
            >>> manager.initialize_connections(initial_config)
            >>> success, error = await manager.reload(new_config)
            >>> if success:
            ...     logger.info("Reload successful")
            >>> else:
            ...     logger.error(f"Reload failed: {error}")
        """
        logger.info("ProxyManager reload initiated")

        # Validate new configuration structure
        if not isinstance(new_mcp_config, dict):
            error_msg = f"MCP configuration must be a dict, got {type(new_mcp_config).__name__}"
            logger.error(f"Reload validation failed: {error_msg}")
            return False, error_msg

        new_mcp_servers = new_mcp_config.get("mcpServers", {})
        if not isinstance(new_mcp_servers, dict):
            error_msg = f'"mcpServers" must be a dict, got {type(new_mcp_servers).__name__}'
            logger.error(f"Reload validation failed: {error_msg}")
            return False, error_msg

        # Validate each server config before touching any state
        for server_name, server_config in new_mcp_servers.items():
            try:
                self._validate_server_config(server_name, server_config)
            except ValueError as e:
                error_msg = f"Invalid configuration for server '{server_name}': {e}"
                logger.error(f"Reload validation failed: {error_msg}")
                return False, error_msg

        logger.info("Configuration validation passed")

        # Determine server changes
        old_servers = set(self._clients)
        new_servers = set(new_mcp_servers)

        # Servers that failed to initialize have no client but still have
        # tracking entries; include them so removal also clears stale errors.
        tracked_servers = (
            old_servers
            | set(self._connection_errors)
            | set(self._connection_status)
        )

        servers_to_add = new_servers - old_servers
        servers_to_remove = tracked_servers - new_servers

        servers_to_update: list[str] = []
        servers_unchanged: list[str] = []
        for name in sorted(old_servers & new_servers):
            if self._config_changed(name, new_mcp_config):
                servers_to_update.append(name)
            else:
                servers_unchanged.append(name)

        logger.info(
            f"Server changes: "
            f"+{len(servers_to_add)} (add), "
            f"-{len(servers_to_remove)} (remove), "
            f"~{len(servers_to_update)} (update), "
            f"={len(servers_unchanged)} (unchanged)"
        )

        if servers_to_add:
            logger.info(f"Servers to add: {sorted(servers_to_add)}")
        if servers_to_remove:
            logger.info(f"Servers to remove: {sorted(servers_to_remove)}")
        if servers_to_update:
            logger.info(f"Servers to update: {sorted(servers_to_update)}")
        if servers_unchanged:
            logger.debug(f"Servers unchanged: {sorted(servers_unchanged)}")

        # Drop state for removed and updated servers
        servers_to_close = sorted(servers_to_remove) + servers_to_update

        if servers_to_close:
            logger.info(f"Closing connections for {len(servers_to_close)} servers")

            for server_name in servers_to_close:
                # Clients are only opened inside 'async with' blocks, so this
                # manager holds no open connection to close here; dropping the
                # references is the only cleanup performed.
                logger.debug(f"Cleaning up client for server: {server_name}")
                self._clients.pop(server_name, None)
                self._connection_status.pop(server_name, None)
                self._connection_errors.pop(server_name, None)
                logger.debug(f"Removed client for server: {server_name}")

        # Create clients for new and updated servers
        servers_to_create = sorted(servers_to_add) + servers_to_update

        if servers_to_create:
            logger.info(f"Creating clients for {len(servers_to_create)} servers")

            for server_name in servers_to_create:
                try:
                    client = self._create_client(
                        server_name, new_mcp_servers[server_name]
                    )
                except Exception as e:
                    # Log error but don't fail the entire reload
                    # This follows the lazy connection strategy - servers with
                    # initialization errors are recorded and will error on use
                    logger.error(
                        f"Failed to create client for server '{server_name}': {e}",
                        exc_info=True
                    )
                    self._connection_errors[server_name] = str(e)
                else:
                    self._clients[server_name] = client
                    self._connection_status[server_name] = False  # Not yet connected
                    self._connection_errors[server_name] = ""
                    logger.info(f"Created client for server: {server_name}")

        # Update stored config
        self._current_config = new_mcp_config

        logger.info(
            f"ProxyManager reload completed successfully. "
            f"Active servers: {len(self._clients)}"
        )

        return True, None

    async def close_all_connections(self):
        """Close all Client connections.

        This is a cleanup method for graceful shutdown.
        Note: Since clients are only opened inside 'async with' blocks,
        this manager holds no persistent connections to close. The method
        resets the status and error registries; the client registry itself
        is left in place.
        """
        logger.info("Closing all proxy connections")

        # With disconnected clients, each 'async with' creates and closes
        # its own session, so no explicit cleanup needed

        # Clear internal state
        self._connection_status.clear()
        self._connection_errors.clear()

        logger.info(f"Closed connections for {len(self._clients)} servers")