kailash 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (131)
  1. kailash/__init__.py +3 -3
  2. kailash/api/custom_nodes_secure.py +3 -3
  3. kailash/api/gateway.py +1 -1
  4. kailash/api/studio.py +2 -3
  5. kailash/api/workflow_api.py +3 -4
  6. kailash/core/resilience/bulkhead.py +460 -0
  7. kailash/core/resilience/circuit_breaker.py +92 -10
  8. kailash/edge/discovery.py +86 -0
  9. kailash/mcp_server/__init__.py +334 -0
  10. kailash/mcp_server/advanced_features.py +1022 -0
  11. kailash/{mcp → mcp_server}/ai_registry_server.py +29 -4
  12. kailash/mcp_server/auth.py +789 -0
  13. kailash/mcp_server/client.py +712 -0
  14. kailash/mcp_server/discovery.py +1593 -0
  15. kailash/mcp_server/errors.py +673 -0
  16. kailash/mcp_server/oauth.py +1727 -0
  17. kailash/mcp_server/protocol.py +1126 -0
  18. kailash/mcp_server/registry_integration.py +587 -0
  19. kailash/mcp_server/server.py +1747 -0
  20. kailash/{mcp → mcp_server}/servers/ai_registry.py +2 -2
  21. kailash/mcp_server/transports.py +1169 -0
  22. kailash/mcp_server/utils/cache.py +510 -0
  23. kailash/middleware/auth/auth_manager.py +3 -3
  24. kailash/middleware/communication/api_gateway.py +2 -9
  25. kailash/middleware/communication/realtime.py +1 -1
  26. kailash/middleware/mcp/client_integration.py +1 -1
  27. kailash/middleware/mcp/enhanced_server.py +2 -2
  28. kailash/nodes/__init__.py +2 -0
  29. kailash/nodes/admin/audit_log.py +6 -6
  30. kailash/nodes/admin/permission_check.py +8 -8
  31. kailash/nodes/admin/role_management.py +32 -28
  32. kailash/nodes/admin/schema.sql +6 -1
  33. kailash/nodes/admin/schema_manager.py +13 -13
  34. kailash/nodes/admin/security_event.py +16 -20
  35. kailash/nodes/admin/tenant_isolation.py +3 -3
  36. kailash/nodes/admin/transaction_utils.py +3 -3
  37. kailash/nodes/admin/user_management.py +21 -22
  38. kailash/nodes/ai/a2a.py +11 -11
  39. kailash/nodes/ai/ai_providers.py +9 -12
  40. kailash/nodes/ai/embedding_generator.py +13 -14
  41. kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
  42. kailash/nodes/ai/iterative_llm_agent.py +3 -3
  43. kailash/nodes/ai/llm_agent.py +213 -36
  44. kailash/nodes/ai/self_organizing.py +2 -2
  45. kailash/nodes/alerts/discord.py +4 -4
  46. kailash/nodes/api/graphql.py +6 -6
  47. kailash/nodes/api/http.py +12 -17
  48. kailash/nodes/api/rate_limiting.py +4 -4
  49. kailash/nodes/api/rest.py +15 -15
  50. kailash/nodes/auth/mfa.py +3 -4
  51. kailash/nodes/auth/risk_assessment.py +2 -2
  52. kailash/nodes/auth/session_management.py +5 -5
  53. kailash/nodes/auth/sso.py +143 -0
  54. kailash/nodes/base.py +6 -2
  55. kailash/nodes/base_async.py +16 -2
  56. kailash/nodes/base_with_acl.py +2 -2
  57. kailash/nodes/cache/__init__.py +9 -0
  58. kailash/nodes/cache/cache.py +1172 -0
  59. kailash/nodes/cache/cache_invalidation.py +870 -0
  60. kailash/nodes/cache/redis_pool_manager.py +595 -0
  61. kailash/nodes/code/async_python.py +2 -1
  62. kailash/nodes/code/python.py +196 -35
  63. kailash/nodes/compliance/data_retention.py +6 -6
  64. kailash/nodes/compliance/gdpr.py +5 -5
  65. kailash/nodes/data/__init__.py +10 -0
  66. kailash/nodes/data/optimistic_locking.py +906 -0
  67. kailash/nodes/data/readers.py +8 -8
  68. kailash/nodes/data/redis.py +349 -0
  69. kailash/nodes/data/sql.py +314 -3
  70. kailash/nodes/data/streaming.py +21 -0
  71. kailash/nodes/enterprise/__init__.py +8 -0
  72. kailash/nodes/enterprise/audit_logger.py +285 -0
  73. kailash/nodes/enterprise/batch_processor.py +22 -3
  74. kailash/nodes/enterprise/data_lineage.py +1 -1
  75. kailash/nodes/enterprise/mcp_executor.py +205 -0
  76. kailash/nodes/enterprise/service_discovery.py +150 -0
  77. kailash/nodes/enterprise/tenant_assignment.py +108 -0
  78. kailash/nodes/logic/async_operations.py +2 -2
  79. kailash/nodes/logic/convergence.py +1 -1
  80. kailash/nodes/logic/operations.py +1 -1
  81. kailash/nodes/monitoring/__init__.py +11 -1
  82. kailash/nodes/monitoring/health_check.py +456 -0
  83. kailash/nodes/monitoring/log_processor.py +817 -0
  84. kailash/nodes/monitoring/metrics_collector.py +627 -0
  85. kailash/nodes/monitoring/performance_benchmark.py +137 -11
  86. kailash/nodes/rag/advanced.py +7 -7
  87. kailash/nodes/rag/agentic.py +49 -2
  88. kailash/nodes/rag/conversational.py +3 -3
  89. kailash/nodes/rag/evaluation.py +3 -3
  90. kailash/nodes/rag/federated.py +3 -3
  91. kailash/nodes/rag/graph.py +3 -3
  92. kailash/nodes/rag/multimodal.py +3 -3
  93. kailash/nodes/rag/optimized.py +5 -5
  94. kailash/nodes/rag/privacy.py +3 -3
  95. kailash/nodes/rag/query_processing.py +6 -6
  96. kailash/nodes/rag/realtime.py +1 -1
  97. kailash/nodes/rag/registry.py +2 -6
  98. kailash/nodes/rag/router.py +1 -1
  99. kailash/nodes/rag/similarity.py +7 -7
  100. kailash/nodes/rag/strategies.py +4 -4
  101. kailash/nodes/security/abac_evaluator.py +6 -6
  102. kailash/nodes/security/behavior_analysis.py +5 -6
  103. kailash/nodes/security/credential_manager.py +1 -1
  104. kailash/nodes/security/rotating_credentials.py +11 -11
  105. kailash/nodes/security/threat_detection.py +8 -8
  106. kailash/nodes/testing/credential_testing.py +2 -2
  107. kailash/nodes/transform/processors.py +5 -5
  108. kailash/runtime/local.py +162 -14
  109. kailash/runtime/parameter_injection.py +425 -0
  110. kailash/runtime/parameter_injector.py +657 -0
  111. kailash/runtime/testing.py +2 -2
  112. kailash/testing/fixtures.py +2 -2
  113. kailash/workflow/builder.py +99 -18
  114. kailash/workflow/builder_improvements.py +207 -0
  115. kailash/workflow/input_handling.py +170 -0
  116. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/METADATA +21 -8
  117. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/RECORD +126 -101
  118. kailash/mcp/__init__.py +0 -53
  119. kailash/mcp/client.py +0 -445
  120. kailash/mcp/server.py +0 -292
  121. kailash/mcp/server_enhanced.py +0 -449
  122. kailash/mcp/utils/cache.py +0 -267
  123. /kailash/{mcp → mcp_server}/client_new.py +0 -0
  124. /kailash/{mcp → mcp_server}/utils/__init__.py +0 -0
  125. /kailash/{mcp → mcp_server}/utils/config.py +0 -0
  126. /kailash/{mcp → mcp_server}/utils/formatters.py +0 -0
  127. /kailash/{mcp → mcp_server}/utils/metrics.py +0 -0
  128. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/WHEEL +0 -0
  129. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/entry_points.txt +0 -0
  130. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/licenses/LICENSE +0 -0
  131. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/top_level.txt +0 -0
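
The listing shows the old kailash/mcp package being removed (kailash/mcp/__init__.py, client.py, server.py, server_enhanced.py, and utils/cache.py are deleted) and replaced by the new kailash/mcp_server package, with several files relocated via {mcp → mcp_server}. For downstream code this is primarily an import-path change; a minimal migration sketch, assuming MCPClient remains the exported client class in the new location:

    # 0.6.2 layout (removed in 0.6.4):
    #   from kailash.mcp.client import MCPClient   # kailash/mcp/client.py is deleted
    # 0.6.4 layout (kailash/mcp_server/ is new; the MCPClient export is assumed):
    from kailash.mcp_server.client import MCPClient
    from kailash.mcp_server.discovery import ServiceRegistry, ServiceMesh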
kailash/mcp_server/discovery.py
@@ -0,0 +1,1593 @@
1
+ """
2
+ Service Discovery System for MCP Servers and Clients.
3
+
4
+ This module provides automatic discovery of MCP servers, their capabilities,
5
+ and health status. It supports multiple discovery mechanisms including
6
+ file-based registry, network scanning, and external service registries.
7
+
8
+ Features:
9
+ - Automatic server registration and deregistration
10
+ - Health checking and monitoring
11
+ - Capability-based server filtering
12
+ - Load balancing and failover
13
+ - Real-time server status updates
14
+ - Network-based discovery protocols
15
+
16
+ Examples:
17
+ Basic service registry:
18
+
19
+ >>> registry = ServiceRegistry()
20
+ >>> await registry.register_server({
21
+ ... "name": "weather-server",
22
+ ... "transport": "http",
23
+ ... "url": "http://localhost:8080",
24
+ ... "capabilities": ["weather.get", "weather.forecast"]
25
+ ... })
26
+ >>> servers = await registry.discover_servers(capability="weather.get")
27
+
28
+ Network discovery:
29
+
30
+ >>> discoverer = NetworkDiscovery()
31
+ >>> servers = await discoverer.scan_network("192.168.1.0/24")
32
+
33
+ Service mesh integration:
34
+
35
+ >>> mesh = ServiceMesh(registry)
36
+ >>> client = await mesh.get_client_for_capability("weather.get")
37
+ """
38
+
39
+ import asyncio
40
+ import json
41
+ import logging
42
+ import socket
43
+ import time
44
+ import uuid
45
+ from abc import ABC, abstractmethod
46
+ from dataclasses import asdict, dataclass
47
+ from pathlib import Path
48
+ from typing import Any, Dict, List, Optional, Set, Union
49
+ from urllib.parse import urlparse
50
+
51
+ from .auth import AuthProvider
52
+ from .errors import MCPError, MCPErrorCode, ServiceDiscoveryError
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ @dataclass
58
+ class ServerInfo:
59
+ """Information about a discovered MCP server."""
60
+
61
+ name: str
62
+ transport: str # stdio, sse, http
63
+ capabilities: List[str] = None # List of tool names or capability strings
64
+ metadata: Dict[str, Any] = None
65
+ id: Optional[str] = None
66
+ endpoint: Optional[str] = None # URL or command
67
+ command: Optional[str] = None # For stdio transport
68
+ args: Optional[List[str]] = None # For stdio transport
69
+ url: Optional[str] = None # For HTTP/SSE transport
70
+ health_endpoint: Optional[str] = None # Health check endpoint path
71
+ health_status: str = "unknown" # healthy, unhealthy, unknown
72
+ health: Optional[Dict[str, Any]] = None # Health information dict
73
+ last_seen: float = 0.0
74
+ response_time: Optional[float] = None
75
+ version: str = "1.0.0"
76
+ auth_required: bool = False
77
+
78
+ def __post_init__(self):
79
+ if self.last_seen == 0.0:
80
+ self.last_seen = time.time()
81
+
82
+ # Auto-generate ID if not provided
83
+ if self.id is None:
84
+ self.id = f"{self.name}_{hash(self.name) % 10000}"
85
+
86
+ # Initialize metadata if None
87
+ if self.metadata is None:
88
+ self.metadata = {}
89
+
90
+ # Initialize capabilities if None
91
+ if self.capabilities is None:
92
+ self.capabilities = []
93
+
94
+ # Extract response_time from health if available
95
+ if self.health and isinstance(self.health, dict):
96
+ if "response_time" in self.health and self.response_time is None:
97
+ self.response_time = self.health["response_time"]
98
+ if "status" in self.health and self.health_status == "unknown":
99
+ self.health_status = self.health["status"]
100
+
101
+ # Set endpoint based on transport if not provided
102
+ if self.endpoint is None:
103
+ if self.transport == "stdio" and self.command:
104
+ self.endpoint = self.command
105
+ elif self.transport in ["http", "sse"] and self.url:
106
+ self.endpoint = self.url
107
+ else:
108
+ self.endpoint = "unknown"
109
+
110
+ def to_dict(self) -> Dict[str, Any]:
111
+ """Convert to dictionary format."""
112
+ return asdict(self)
113
+
114
+ @classmethod
115
+ def from_dict(cls, data: Dict[str, Any]) -> "ServerInfo":
116
+ """Create from dictionary format."""
117
+ return cls(**data)
118
+
119
+ def is_healthy(self, max_age: float = 300.0) -> bool:
120
+ """Check if server is considered healthy."""
121
+ # Check health dict first
122
+ if self.health and isinstance(self.health, dict):
123
+ status = self.health.get("status", "unknown")
124
+ if status == "healthy":
125
+ age = time.time() - self.last_seen
126
+ return age < max_age
127
+ else:
128
+ return False
129
+
130
+ # Fall back to health_status field
131
+ age = time.time() - self.last_seen
132
+ return self.health_status == "healthy" and age < max_age
133
+
134
+ def matches_capability(self, capability: str) -> bool:
135
+ """Check if server provides a specific capability."""
136
+ return capability in self.capabilities
137
+
138
+ def has_capability(self, capability: str) -> bool:
139
+ """Check if server provides a specific capability (alias for matches_capability)."""
140
+ return self.matches_capability(capability)
141
+
142
+ def matches_transport(self, transport: str) -> bool:
143
+ """Check if server supports a transport type."""
144
+ return self.transport == transport
145
+
146
+ def matches_filter(self, **filters) -> bool:
147
+ """Check if server matches all provided filter criteria.
148
+
149
+ Args:
150
+ **filters: Filter criteria (capability, transport, metadata, name, etc)
151
+
152
+ Returns:
153
+ True if all filters match
154
+ """
155
+ # Check capability filter
156
+ if "capability" in filters:
157
+ if not self.has_capability(filters["capability"]):
158
+ return False
159
+
160
+ # Check transport filter
161
+ if "transport" in filters:
162
+ if not self.matches_transport(filters["transport"]):
163
+ return False
164
+
165
+ # Check name filter
166
+ if "name" in filters:
167
+ if self.name != filters["name"]:
168
+ return False
169
+
170
+ # Check metadata filter (as a dict)
171
+ if "metadata" in filters:
172
+ filter_metadata = filters["metadata"]
173
+ if not self.metadata:
174
+ return False
175
+ for key, value in filter_metadata.items():
176
+ if key not in self.metadata or self.metadata[key] != value:
177
+ return False
178
+
179
+ # Check other direct attributes
180
+ for key, value in filters.items():
181
+ if key not in ["capability", "transport", "name", "metadata"]:
182
+ # Check if it's a direct attribute
183
+ if hasattr(self, key):
184
+ if getattr(self, key) != value:
185
+ return False
186
+ else:
187
+ # If not an attribute, check in metadata
188
+ if self.metadata and key in self.metadata:
189
+ if self.metadata[key] != value:
190
+ return False
191
+ else:
192
+ return False
193
+
194
+ return True
195
+
196
+ def get_priority_score(self) -> float:
197
+ """Calculate priority score for load balancing."""
198
+ base_score = 1.0
199
+
200
+ # Health bonus
201
+ if self.health_status == "healthy":
202
+ base_score += 0.5
203
+ elif self.health_status == "unhealthy":
204
+ base_score -= 0.5
205
+
206
+ # Response time bonus (lower is better)
207
+ if self.response_time:
208
+ if self.response_time < 0.1: # < 100ms
209
+ base_score += 0.3
210
+ elif self.response_time > 1.0: # > 1s
211
+ base_score -= 0.3
212
+
213
+ # Age penalty
214
+ age = time.time() - self.last_seen
215
+ if age > 60: # Over 1 minute old
216
+ base_score -= min(0.4, age / 300) # Max penalty of 0.4
217
+
218
+ return max(0.1, base_score) # Minimum score of 0.1
219
+
220
+
221
+ class DiscoveryBackend(ABC):
222
+ """Abstract base class for discovery backends."""
223
+
224
+ @abstractmethod
225
+ async def register_server(self, server_info: ServerInfo) -> bool:
226
+ """Register a server with the discovery backend."""
227
+ pass
228
+
229
+ @abstractmethod
230
+ async def deregister_server(self, server_id: str) -> bool:
231
+ """Deregister a server from the discovery backend."""
232
+ pass
233
+
234
+ @abstractmethod
235
+ async def get_servers(self, **filters) -> List[ServerInfo]:
236
+ """Get list of servers matching filters."""
237
+ pass
238
+
239
+ @abstractmethod
240
+ async def update_server_health(
241
+ self, server_id: str, health_status: str, response_time: Optional[float] = None
242
+ ) -> bool:
243
+ """Update server health status."""
244
+ pass
245
+
246
+
247
+ class FileBasedDiscovery(DiscoveryBackend):
248
+ """File-based service discovery using JSON registry."""
249
+
250
+ def __init__(self, registry_path: Union[str, Path] = "mcp_registry.json"):
251
+ """Initialize file-based discovery.
252
+
253
+ Args:
254
+ registry_path: Path to the JSON registry file
255
+ """
256
+ self.registry_path = Path(registry_path)
257
+ self._ensure_registry_file()
258
+
259
+ @property
260
+ def _servers(self) -> Dict[str, ServerInfo]:
261
+ """Get servers as a dict for compatibility with tests."""
262
+ registry = self._read_registry()
263
+ servers = {}
264
+ for server_id, server_data in registry["servers"].items():
265
+ server_info = ServerInfo.from_dict(server_data)
266
+ servers[server_info.name] = server_info
267
+ return servers
268
+
269
+ def _ensure_registry_file(self):
270
+ """Ensure registry file exists."""
271
+ if not self.registry_path.exists():
272
+ self.registry_path.write_text(
273
+ json.dumps(
274
+ {"servers": {}, "last_updated": time.time(), "version": "1.0"},
275
+ indent=2,
276
+ )
277
+ )
278
+
279
+ def _read_registry(self) -> Dict[str, Any]:
280
+ """Read registry from file."""
281
+ try:
282
+ return json.loads(self.registry_path.read_text())
283
+ except (json.JSONDecodeError, FileNotFoundError) as e:
284
+ logger.error(f"Failed to read registry: {e}")
285
+ return {"servers": {}, "last_updated": time.time(), "version": "1.0"}
286
+
287
+ def _write_registry(self, registry: Dict[str, Any]):
288
+ """Write registry to file."""
289
+ registry["last_updated"] = time.time()
290
+ self.registry_path.write_text(json.dumps(registry, indent=2))
291
+
292
+ async def register_server(self, server_info: ServerInfo) -> bool:
293
+ """Register a server in the file registry."""
294
+ try:
295
+ registry = self._read_registry()
296
+ registry["servers"][server_info.id] = server_info.to_dict()
297
+ self._write_registry(registry)
298
+ logger.info(f"Registered server: {server_info.name} ({server_info.id})")
299
+ return True
300
+ except Exception as e:
301
+ logger.error(f"Failed to register server {server_info.id}: {e}")
302
+ return False
303
+
304
+ async def deregister_server(self, server_id: str) -> bool:
305
+ """Deregister a server from the file registry."""
306
+ try:
307
+ registry = self._read_registry()
308
+ if server_id in registry["servers"]:
309
+ del registry["servers"][server_id]
310
+ self._write_registry(registry)
311
+ logger.info(f"Deregistered server: {server_id}")
312
+ return True
313
+ return False
314
+ except Exception as e:
315
+ logger.error(f"Failed to deregister server {server_id}: {e}")
316
+ return False
317
+
318
+ async def unregister_server(self, server_name: str) -> bool:
319
+ """Unregister a server by name (alias for test compatibility)."""
320
+ # Find server by name
321
+ registry = self._read_registry()
322
+ server_id_to_remove = None
323
+
324
+ for server_id, server_data in registry["servers"].items():
325
+ if server_data.get("name") == server_name:
326
+ server_id_to_remove = server_id
327
+ break
328
+
329
+ if server_id_to_remove:
330
+ return await self.deregister_server(server_id_to_remove)
331
+ return False
332
+
333
+ async def get_servers(self, **filters) -> List[ServerInfo]:
334
+ """Get servers matching filters."""
335
+ try:
336
+ registry = self._read_registry()
337
+ servers = []
338
+
339
+ for server_data in registry["servers"].values():
340
+ server_info = ServerInfo.from_dict(server_data)
341
+
342
+ # Apply filters
343
+ if self._matches_filters(server_info, filters):
344
+ servers.append(server_info)
345
+
346
+ return servers
347
+ except Exception as e:
348
+ logger.error(f"Failed to get servers: {e}")
349
+ return []
350
+
351
+ async def discover_servers(self, **filters) -> List[ServerInfo]:
352
+ """Discover servers matching filters (alias for get_servers)."""
353
+ return await self.get_servers(**filters)
354
+
355
+ async def get_server(self, server_name: str) -> Optional[ServerInfo]:
356
+ """Get a specific server by name."""
357
+ servers = await self.get_servers()
358
+ for server in servers:
359
+ if server.name == server_name:
360
+ return server
361
+ return None
362
+
363
+ async def update_server_health(
364
+ self,
365
+ server_identifier: str,
366
+ health_info: Union[str, Dict[str, Any]],
367
+ response_time: Optional[float] = None,
368
+ ) -> bool:
369
+ """Update server health in the registry.
370
+
371
+ Args:
372
+ server_identifier: Server ID or name
373
+ health_info: Health status string or health info dict
374
+ response_time: Optional response time (if health_info is string)
375
+ """
376
+ try:
377
+ registry = self._read_registry()
378
+
379
+ # Find server by ID or name
380
+ server_id = None
381
+ for sid, server_data in registry["servers"].items():
382
+ if (
383
+ sid == server_identifier
384
+ or server_data.get("name") == server_identifier
385
+ ):
386
+ server_id = sid
387
+ break
388
+
389
+ if not server_id:
390
+ return False
391
+
392
+ # Update health info
393
+ if isinstance(health_info, dict):
394
+ # Full health info dict provided
395
+ registry["servers"][server_id]["health"] = health_info
396
+ if "status" in health_info:
397
+ registry["servers"][server_id]["health_status"] = health_info[
398
+ "status"
399
+ ]
400
+ if "response_time" in health_info:
401
+ registry["servers"][server_id]["response_time"] = health_info[
402
+ "response_time"
403
+ ]
404
+ else:
405
+ # Simple string status
406
+ registry["servers"][server_id]["health_status"] = health_info
407
+ if response_time is not None:
408
+ registry["servers"][server_id]["response_time"] = response_time
409
+
410
+ registry["servers"][server_id]["last_seen"] = time.time()
411
+ self._write_registry(registry)
412
+ return True
413
+
414
+ except Exception as e:
415
+ logger.error(f"Failed to update server health {server_identifier}: {e}")
416
+ return False
417
+
418
+ def _matches_filters(
419
+ self, server_info: ServerInfo, filters: Dict[str, Any]
420
+ ) -> bool:
421
+ """Check if server matches the provided filters."""
422
+ for key, value in filters.items():
423
+ if key == "capability":
424
+ if not server_info.matches_capability(value):
425
+ return False
426
+ elif key == "transport":
427
+ if not server_info.matches_transport(value):
428
+ return False
429
+ elif key == "healthy_only":
430
+ if value and not server_info.is_healthy():
431
+ return False
432
+ elif key == "name":
433
+ if server_info.name != value:
434
+ return False
435
+ elif key == "auth_required":
436
+ if server_info.auth_required != value:
437
+ return False
438
+
439
+ return True
440
+
441
+ async def save_registry(self, path: str) -> None:
442
+ """Save the current registry to a different file."""
443
+ # Use async file operations to avoid blocking
444
+ import json
445
+
446
+ import aiofiles
447
+
448
+ try:
449
+ # Read current registry
450
+ registry = self._read_registry()
451
+
452
+ # Write to the specified path asynchronously
453
+ async with aiofiles.open(path, "w") as f:
454
+ await f.write(json.dumps(registry, indent=2))
455
+
456
+ except Exception as e:
457
+ logger.error(f"Failed to save registry to {path}: {e}")
458
+ raise
459
+
460
+ async def load_registry(self, path: str) -> None:
461
+ """Load registry from a different file."""
462
+ import json
463
+ from pathlib import Path
464
+
465
+ import aiofiles
466
+
467
+ try:
468
+ if not Path(path).exists():
469
+ # Create empty registry if file doesn't exist
470
+ logger.warning(
471
+ f"Registry file not found: {path}, creating empty registry"
472
+ )
473
+ self._ensure_registry_file()
474
+ return
475
+
476
+ # Read from the specified path asynchronously
477
+ async with aiofiles.open(path, "r") as f:
478
+ content = await f.read()
479
+ registry = json.loads(content)
480
+
481
+ # Write to our registry path
482
+ self._write_registry(registry)
483
+
484
+ except Exception as e:
485
+ logger.error(f"Failed to load registry from {path}: {e}")
486
+ raise
487
+
488
+
489
+ class NetworkDiscovery:
490
+ """Network-based discovery using UDP broadcast/multicast."""
491
+
492
+ DISCOVERY_PORT = 8765
493
+ MULTICAST_GROUP = "224.0.0.251"
494
+
495
+ def __init__(
496
+ self,
497
+ port: int = DISCOVERY_PORT,
498
+ multicast_group: str = None,
499
+ interface: str = "0.0.0.0",
500
+ ):
501
+ """Initialize network discovery.
502
+
503
+ Args:
504
+ port: UDP port for discovery
505
+ multicast_group: Multicast group address
506
+ interface: Network interface to bind to
507
+ """
508
+ self.port = port
509
+ self.multicast_group = multicast_group or self.MULTICAST_GROUP
510
+ self.interface = interface
511
+ self.running = False
512
+ self._discovered_servers: Dict[str, ServerInfo] = {}
513
+ self._discovery_socket: Optional[socket.socket] = None
514
+ self._transport = None
515
+ self._protocol = None
516
+
517
+ async def start(self):
518
+ """Start network discovery."""
519
+ loop = asyncio.get_event_loop()
520
+
521
+ # Create UDP endpoint
522
+ transport, protocol = await loop.create_datagram_endpoint(
523
+ lambda: self, local_addr=(self.interface, self.port), reuse_port=True
524
+ )
525
+
526
+ self._transport = transport
527
+ self._protocol = protocol
528
+ self.running = True
529
+
530
+ async def stop(self):
531
+ """Stop network discovery."""
532
+ self.running = False
533
+ if self._transport:
534
+ self._transport.close()
535
+ self._transport = None
536
+ self._protocol = None
537
+
538
+ async def start_discovery_listener(self):
539
+ """Start listening for server announcements."""
540
+ await self.start()
541
+
542
+ # Create UDP socket for listening
543
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
544
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
545
+ sock.bind(("", self.port))
546
+
547
+ # Join multicast group
548
+ mreq = socket.inet_aton(self.multicast_group) + socket.inet_aton("0.0.0.0")
549
+ sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
550
+
551
+ self._discovery_socket = sock
552
+
553
+ logger.info(f"Started network discovery listener on port {self.port}")
554
+
555
+ while self.running:
556
+ try:
557
+ # Set timeout to check running flag periodically
558
+ sock.settimeout(1.0)
559
+ data, addr = sock.recvfrom(1024)
560
+
561
+ await self._process_announcement(data, addr)
562
+
563
+ except socket.timeout:
564
+ continue
565
+ except Exception as e:
566
+ logger.error(f"Error in discovery listener: {e}")
567
+
568
+ async def _process_announcement(self, data: bytes, addr: tuple):
569
+ """Process server announcement."""
570
+ try:
571
+ announcement = json.loads(data.decode())
572
+
573
+ if announcement.get("type") == "mcp_server_announcement":
574
+ server_info = ServerInfo(
575
+ id=announcement.get("id", str(uuid.uuid4())),
576
+ name=announcement.get("name", "unknown"),
577
+ transport=announcement.get("transport", "http"),
578
+ endpoint=announcement.get("endpoint", f"http://{addr[0]}:8080"),
579
+ capabilities=announcement.get("capabilities", []),
580
+ metadata=announcement.get("metadata", {}),
581
+ health_status="healthy",
582
+ last_seen=time.time(),
583
+ version=announcement.get("version", "1.0.0"),
584
+ auth_required=announcement.get("auth_required", False),
585
+ )
586
+
587
+ self._discovered_servers[server_info.id] = server_info
588
+ logger.info(f"Discovered server: {server_info.name} at {addr[0]}")
589
+
590
+ except (json.JSONDecodeError, KeyError) as e:
591
+ logger.debug(f"Invalid announcement from {addr[0]}: {e}")
592
+
593
+ async def scan_network(
594
+ self, network: str = "192.168.1.0/24", timeout: float = 5.0
595
+ ) -> List[ServerInfo]:
596
+ """Actively scan network for MCP servers.
597
+
598
+ Args:
599
+ network: Network to scan (CIDR notation)
600
+ timeout: Scan timeout in seconds
601
+
602
+ Returns:
603
+ List of discovered servers
604
+ """
605
+ import ipaddress
606
+
607
+ discovered = []
608
+ network_obj = ipaddress.IPv4Network(network, strict=False)
609
+
610
+ logger.info(f"Scanning network {network} for MCP servers...")
611
+
612
+ # Create semaphore to limit concurrent connections
613
+ semaphore = asyncio.Semaphore(50)
614
+
615
+ async def scan_host(ip: str):
616
+ async with semaphore:
617
+ try:
618
+ # Try common MCP ports
619
+ for port in [8080, 8765, 3000, 5000]:
620
+ try:
621
+ reader, writer = await asyncio.wait_for(
622
+ asyncio.open_connection(str(ip), port), timeout=1.0
623
+ )
624
+
625
+ # Send MCP discovery request
626
+ discovery_request = json.dumps(
627
+ {"type": "mcp_discovery_request", "version": "1.0"}
628
+ ).encode()
629
+
630
+ writer.write(discovery_request)
631
+ await writer.drain()
632
+
633
+ # Read response
634
+ response_data = await asyncio.wait_for(
635
+ reader.read(1024), timeout=2.0
636
+ )
637
+
638
+ response = json.loads(response_data.decode())
639
+
640
+ if response.get("type") == "mcp_discovery_response":
641
+ server_info = ServerInfo(
642
+ id=response.get("id", str(uuid.uuid4())),
643
+ name=response.get("name", f"server-{ip}"),
644
+ transport=response.get("transport", "http"),
645
+ endpoint=f"http://{ip}:{port}",
646
+ capabilities=response.get("capabilities", []),
647
+ metadata=response.get("metadata", {}),
648
+ health_status="healthy",
649
+ version=response.get("version", "1.0.0"),
650
+ auth_required=response.get("auth_required", False),
651
+ )
652
+ discovered.append(server_info)
653
+ break
654
+
655
+ writer.close()
656
+ await writer.wait_closed()
657
+
658
+ except (
659
+ asyncio.TimeoutError,
660
+ ConnectionRefusedError,
661
+ json.JSONDecodeError,
662
+ ):
663
+ continue
664
+ except Exception as e:
665
+ logger.debug(f"Error scanning {ip}:{port}: {e}")
666
+ continue
667
+
668
+ except Exception as e:
669
+ logger.debug(f"Error scanning host {ip}: {e}")
670
+
671
+ # Scan all hosts in parallel
672
+ tasks = [scan_host(ip) for ip in network_obj.hosts()]
673
+ await asyncio.gather(*tasks, return_exceptions=True)
674
+
675
+ logger.info(f"Network scan completed. Found {len(discovered)} servers.")
676
+ return discovered
677
+
678
+ def _send_message(self, message: Dict[str, Any], address: tuple = None):
679
+ """Send a message over the network."""
680
+ if not self._transport:
681
+ logger.warning("Transport not initialized")
682
+ return
683
+
684
+ data = json.dumps(message).encode()
685
+
686
+ if address:
687
+ self._transport.sendto(data, address)
688
+ else:
689
+ # Broadcast/multicast
690
+ self._transport.sendto(data, (self.multicast_group, self.port))
691
+
692
+ async def announce_server(self, server_info: ServerInfo):
693
+ """Announce a server on the network."""
694
+ message = {"type": "server_announcement", "server": server_info.to_dict()}
695
+ self._send_message(message)
696
+
697
+ def stop_discovery(self):
698
+ """Stop network discovery."""
699
+ self.running = False
700
+ if self._discovery_socket:
701
+ self._discovery_socket.close()
702
+ self._discovery_socket = None
703
+ logger.info("Stopped network discovery")
704
+
705
+ def get_discovered_servers(self) -> List[ServerInfo]:
706
+ """Get list of currently discovered servers."""
707
+ # Filter out stale servers (older than 5 minutes)
708
+ current_time = time.time()
709
+ active_servers = []
710
+
711
+ for server in self._discovered_servers.values():
712
+ if current_time - server.last_seen < 300: # 5 minutes
713
+ active_servers.append(server)
714
+
715
+ return active_servers
716
+
717
+ async def discover_servers(self, **filters) -> List[ServerInfo]:
718
+ """Discover servers on the network (returns already discovered servers).
719
+
720
+ Args:
721
+ **filters: Filters to apply (capability, transport, etc)
722
+
723
+ Returns:
724
+ List of servers matching filters
725
+ """
726
+ servers = self.get_discovered_servers()
727
+
728
+ # Apply filters
729
+ if filters:
730
+ filtered = []
731
+ for server in servers:
732
+ if server.matches_filter(**filters):
733
+ filtered.append(server)
734
+ return filtered
735
+
736
+ return servers
737
+
738
+ async def _handle_discovery_message(self, message: Dict[str, Any], addr: tuple):
739
+ """Handle incoming discovery message."""
740
+ msg_type = message.get("type")
741
+
742
+ if msg_type == "server_announcement":
743
+ # Handle server announcement
744
+ server_data = message.get("server", {})
745
+ server_info = ServerInfo.from_dict(server_data)
746
+ server_info.last_seen = time.time()
747
+
748
+ # Store by name
749
+ self._discovered_servers[server_info.name] = server_info
750
+ logger.info(f"Discovered server: {server_info.name} from {addr}")
751
+
752
+ elif msg_type == "server_query":
753
+ # Respond to server queries
754
+ pass
755
+
756
+ else:
757
+ logger.debug(f"Unknown message type: {msg_type}")
758
+
759
+ def datagram_received(self, data: bytes, addr: tuple):
760
+ """Handle received datagram (part of asyncio protocol)."""
761
+ try:
762
+ message = json.loads(data.decode())
763
+ # Try to get current event loop
764
+ try:
765
+ loop = asyncio.get_running_loop()
766
+ asyncio.create_task(self._handle_discovery_message(message, addr))
767
+ except RuntimeError:
768
+ # No event loop, run synchronously
769
+ asyncio.run(self._handle_discovery_message(message, addr))
770
+ except json.JSONDecodeError:
771
+ logger.warning(f"Invalid JSON received from {addr}")
772
+ except Exception as e:
773
+ logger.error(f"Error handling datagram from {addr}: {e}")
774
+
775
+
776
+ class ServiceRegistry:
777
+ """Main service registry coordinating multiple discovery backends."""
778
+
779
+ def __init__(self, backends: Optional[List[DiscoveryBackend]] = None):
780
+ """Initialize service registry.
781
+
782
+ Args:
783
+ backends: List of discovery backends to use
784
+ """
785
+ if backends is None:
786
+ backends = [FileBasedDiscovery()]
787
+
788
+ self.backends = backends
789
+ self.health_checker = HealthChecker() # Initialize health checker
790
+ self.load_balancer = LoadBalancer()
791
+ self.service_mesh = ServiceMesh(self)
792
+ self._server_cache: Dict[str, ServerInfo] = {}
793
+ self._cache_expiry = 60.0 # Cache for 1 minute
794
+ self._last_cache_update = 0.0
795
+
796
+ async def register_server(
797
+ self, server_info: Union[ServerInfo, Dict[str, Any]]
798
+ ) -> bool:
799
+ """Register a server with all backends.
800
+
801
+ Args:
802
+ server_info: ServerInfo object or server configuration dictionary
803
+
804
+ Returns:
805
+ True if registration succeeded in at least one backend
806
+ """
807
+ # Convert config to ServerInfo if needed
808
+ if isinstance(server_info, dict):
809
+ server_config = server_info
810
+ server_info = ServerInfo(
811
+ name=server_config["name"],
812
+ transport=server_config["transport"],
813
+ capabilities=server_config.get("capabilities", []),
814
+ id=server_config.get("id"),
815
+ endpoint=server_config.get("endpoint"),
816
+ command=server_config.get("command"),
817
+ args=server_config.get("args"),
818
+ url=server_config.get("url"),
819
+ metadata=server_config.get("metadata", {}),
820
+ auth_required=server_config.get("auth_required", False),
821
+ version=server_config.get("version", "1.0.0"),
822
+ )
823
+
824
+ success_count = 0
825
+ for backend in self.backends:
826
+ try:
827
+ if await backend.register_server(server_info):
828
+ success_count += 1
829
+ except Exception as e:
830
+ logger.error(
831
+ f"Backend {type(backend).__name__} registration failed: {e}"
832
+ )
833
+
834
+ if success_count > 0:
835
+ # Update cache
836
+ self._server_cache[server_info.id] = server_info
837
+ logger.info(
838
+ f"Successfully registered server {server_info.name} with {success_count} backends"
839
+ )
840
+ return True
841
+
842
+ return False
843
+
844
+ async def deregister_server(self, server_id: str) -> bool:
845
+ """Deregister a server from all backends."""
846
+ success_count = 0
847
+ for backend in self.backends:
848
+ try:
849
+ if await backend.deregister_server(server_id):
850
+ success_count += 1
851
+ except Exception as e:
852
+ logger.error(
853
+ f"Backend {type(backend).__name__} deregistration failed: {e}"
854
+ )
855
+
856
+ # Remove from cache
857
+ if server_id in self._server_cache:
858
+ del self._server_cache[server_id]
859
+
860
+ return success_count > 0
861
+
862
+ async def discover_servers(self, **filters) -> List[ServerInfo]:
863
+ """Discover servers matching the given filters.
864
+
865
+ Args:
866
+ **filters: Filter criteria (capability, transport, healthy_only, etc.)
867
+
868
+ Returns:
869
+ List of matching servers, deduplicated and sorted by priority
870
+ """
871
+ # Check cache first
872
+ if (
873
+ time.time() - self._last_cache_update
874
+ ) < self._cache_expiry and not filters.get("force_refresh"):
875
+ cached_servers = list(self._server_cache.values())
876
+ return self._filter_and_sort_servers(cached_servers, filters)
877
+
878
+ # Fetch from all backends
879
+ all_servers: Dict[str, ServerInfo] = {}
880
+
881
+ for backend in self.backends:
882
+ try:
883
+ servers = await backend.get_servers(**filters)
884
+ for server in servers:
885
+ # Use latest info if server exists in multiple backends
886
+ if (
887
+ server.id not in all_servers
888
+ or server.last_seen > all_servers[server.id].last_seen
889
+ ):
890
+ all_servers[server.id] = server
891
+ except Exception as e:
892
+ logger.error(f"Backend {type(backend).__name__} discovery failed: {e}")
893
+
894
+ # Update cache
895
+ self._server_cache = all_servers.copy()
896
+ self._last_cache_update = time.time()
897
+
898
+ servers_list = list(all_servers.values())
899
+ return self._filter_and_sort_servers(servers_list, filters)
900
+
901
+ def _filter_and_sort_servers(
902
+ self, servers: List[ServerInfo], filters: Dict[str, Any]
903
+ ) -> List[ServerInfo]:
904
+ """Filter and sort servers by priority."""
905
+ filtered_servers = []
906
+
907
+ for server in servers:
908
+ # Apply remaining filters not handled by backends
909
+ if filters.get("healthy_only") and not server.is_healthy():
910
+ continue
911
+ if filters.get("capability") and not server.matches_capability(
912
+ filters["capability"]
913
+ ):
914
+ continue
915
+ if filters.get("transport") and not server.matches_transport(
916
+ filters["transport"]
917
+ ):
918
+ continue
919
+
920
+ filtered_servers.append(server)
921
+
922
+ # Sort by priority score (highest first)
923
+ filtered_servers.sort(key=lambda s: s.get_priority_score(), reverse=True)
924
+
925
+ return filtered_servers
926
+
927
+ async def get_best_server(
928
+ self, capability: str, transport: Optional[str] = None
929
+ ) -> Optional[ServerInfo]:
930
+ """Get the best server for a specific capability.
931
+
932
+ Args:
933
+ capability: Required capability
934
+ transport: Preferred transport type
935
+
936
+ Returns:
937
+ Best available server or None
938
+ """
939
+ filters = {"capability": capability, "healthy_only": True}
940
+ if transport:
941
+ filters["transport"] = transport
942
+
943
+ servers = await self.discover_servers(**filters)
944
+ return servers[0] if servers else None
945
+
946
+ def start_health_checking(self, interval: float = 30.0):
947
+ """Start periodic health checking of registered servers.
948
+
949
+ Args:
950
+ interval: Health check interval in seconds
951
+ """
952
+ if not self.health_checker:
953
+ self.health_checker = HealthChecker(self)
954
+
955
+ asyncio.create_task(self.health_checker.start_periodic_checks(interval))
956
+
957
+ def stop_health_checking(self):
958
+ """Stop health checking."""
959
+ if self.health_checker:
960
+ self.health_checker.stop()
961
+
962
+ async def start_health_monitoring(self, interval: float = 30.0):
963
+ """Start health monitoring (async version)."""
964
+ if self.health_checker:
965
+ await self.health_checker.start(self)
966
+
967
+ async def stop_health_monitoring(self):
968
+ """Stop health monitoring (async version)."""
969
+ if self.health_checker:
970
+ await self.health_checker.stop()
971
+
972
+ async def get_best_server_for_capability(
973
+ self, capability: str
974
+ ) -> Optional[ServerInfo]:
975
+ """Get best server for capability (async version)."""
976
+ return await self.get_best_server(capability)
977
+
978
+
979
+ class HealthChecker:
980
+ """Health checker for registered MCP servers."""
981
+
982
+ def __init__(self, registry: ServiceRegistry = None, check_interval: float = 30.0):
983
+ """Initialize health checker.
984
+
985
+ Args:
986
+ registry: Service registry to check (optional)
987
+ check_interval: Default check interval in seconds
988
+ """
989
+ self.registry = registry
990
+ self.check_interval = check_interval
991
+ self._running = False
992
+ self._check_task = None
993
+ self.running = False # Keep for backward compatibility
994
+
995
+ async def start(self, registry: ServiceRegistry = None):
996
+ """Start health checking with the registry."""
997
+ if registry:
998
+ self.registry = registry
999
+ if not self.registry:
1000
+ raise ValueError("No registry provided for health checking")
1001
+
1002
+ # Set running state immediately
1003
+ self._running = True
1004
+ self.running = True
1005
+
1006
+ # Start periodic checks
1007
+ self._check_task = asyncio.create_task(
1008
+ self.start_periodic_checks(self.check_interval)
1009
+ )
1010
+
1011
+ async def start_periodic_checks(self, interval: float = 30.0):
1012
+ """Start periodic health checks.
1013
+
1014
+ Args:
1015
+ interval: Check interval in seconds
1016
+ """
1017
+ self._running = True
1018
+ self.running = True # Keep for backward compatibility
1019
+ logger.info(f"Started health checking with {interval}s interval")
1020
+
1021
+ while self._running:
1022
+ try:
1023
+ await self.check_all_servers()
1024
+ await asyncio.sleep(interval)
1025
+ except Exception as e:
1026
+ logger.error(f"Error in health checking: {e}")
1027
+ await asyncio.sleep(min(interval, 10)) # Back off on error
1028
+
1029
+ async def check_all_servers(self):
1030
+ """Check health of all registered servers."""
1031
+ # Get all servers without filters
1032
+ servers = await self.registry.discover_servers(force_refresh=True)
1033
+
1034
+ # Check health in parallel with limited concurrency
1035
+ semaphore = asyncio.Semaphore(10)
1036
+
1037
+ async def check_server(server: ServerInfo):
1038
+ async with semaphore:
1039
+ health_result = await self.check_server_health(server)
1040
+ health_status = health_result["status"]
1041
+ response_time = health_result.get("response_time")
1042
+
1043
+ # Update health in all backends
1044
+ for backend in self.registry.backends:
1045
+ try:
1046
+ await backend.update_server_health(
1047
+ server.id, health_status, response_time
1048
+ )
1049
+ except Exception as e:
1050
+ logger.error(f"Failed to update health for {server.id}: {e}")
1051
+
1052
+ # Run health checks in parallel
1053
+ await asyncio.gather(
1054
+ *[check_server(server) for server in servers], return_exceptions=True
1055
+ )
1056
+
1057
+ async def check_server_health(self, server: ServerInfo) -> Dict[str, Any]:
1058
+ """Check health of a single server.
1059
+
1060
+ Args:
1061
+ server: Server to check
1062
+
1063
+ Returns:
1064
+ Dictionary with status and response_time
1065
+ """
1066
+ start_time = time.time()
1067
+
1068
+ try:
1069
+ if server.transport == "http" or server.transport == "sse":
1070
+ # HTTP/SSE health check
1071
+ import aiohttp
1072
+
1073
+ async with aiohttp.ClientSession(
1074
+ timeout=aiohttp.ClientTimeout(total=10)
1075
+ ) as session:
1076
+ # Try health endpoint first (use configured endpoint or default to /health)
1077
+ health_path = server.health_endpoint or "/health"
1078
+ base_url = server.endpoint or server.url or ""
1079
+ health_url = f"{base_url.rstrip('/')}{health_path}"
1080
+ try:
1081
+ async with session.get(health_url) as response:
1082
+ if response.status == 200:
1083
+ response_time = time.time() - start_time
1084
+ return {
1085
+ "status": "healthy",
1086
+ "response_time": response_time,
1087
+ }
1088
+ except:
1089
+ pass
1090
+
1091
+ # Fallback to main endpoint
1092
+ try:
1093
+ async with session.get(server.endpoint) as response:
1094
+ response_time = time.time() - start_time
1095
+ if response.status < 500:
1096
+ return {
1097
+ "status": "healthy",
1098
+ "response_time": response_time,
1099
+ }
1100
+ else:
1101
+ return {
1102
+ "status": "unhealthy",
1103
+ "response_time": response_time,
1104
+ }
1105
+ except:
1106
+ return {"status": "unhealthy", "response_time": None}
1107
+
1108
+ elif server.transport == "stdio":
1109
+ # For stdio, check if command exists and is executable
1110
+ if server.command:
1111
+ try:
1112
+ # Test if we can run the command
1113
+ import asyncio
1114
+
1115
+ process = await asyncio.create_subprocess_exec(
1116
+ server.command,
1117
+ *(server.args if server.args else []),
1118
+ stdout=asyncio.subprocess.PIPE,
1119
+ stderr=asyncio.subprocess.PIPE,
1120
+ )
1121
+ returncode = await process.wait()
1122
+ response_time = time.time() - start_time
1123
+
1124
+ if returncode == 0:
1125
+ return {"status": "healthy", "response_time": response_time}
1126
+ else:
1127
+ return {
1128
+ "status": "unhealthy",
1129
+ "response_time": response_time,
1130
+ }
1131
+ except Exception as e:
1132
+ return {
1133
+ "status": "unhealthy",
1134
+ "response_time": None,
1135
+ "error": str(e),
1136
+ }
1137
+ else:
1138
+ # No command specified, check if recently seen
1139
+ age = time.time() - server.last_seen
1140
+ if age < 300: # 5 minutes
1141
+ return {"status": "healthy", "response_time": None}
1142
+ else:
1143
+ return {"status": "unknown", "response_time": None}
1144
+
1145
+ else:
1146
+ logger.warning(
1147
+ f"Unknown transport type for health check: {server.transport}"
1148
+ )
1149
+ return {"status": "unknown", "response_time": None}
1150
+
1151
+ except Exception as e:
1152
+ logger.debug(f"Health check failed for {server.name}: {e}")
1153
+ return {"status": "unhealthy", "response_time": None}
1154
+
1155
+ async def stop(self):
1156
+ """Stop health checking."""
1157
+ self._running = False
1158
+ self.running = False
1159
+ if self._check_task:
1160
+ self._check_task.cancel()
1161
+ self._check_task = None
1162
+ logger.info("Stopped health checking")
1163
+
1164
+
1165
+ class ServiceMesh:
1166
+ """Service mesh for intelligent client routing and load balancing."""
1167
+
1168
+ def __init__(self, registry: ServiceRegistry):
1169
+ """Initialize service mesh.
1170
+
1171
+ Args:
1172
+ registry: Service registry to use
1173
+ """
1174
+ self.registry = registry
1175
+ self._client_cache: Dict[str, Any] = {}
1176
+ self._load_balancer = LoadBalancer()
1177
+
1178
+ async def get_client_for_capability(
1179
+ self, capability: str, transport: Optional[str] = None
1180
+ ) -> Optional[Any]:
1181
+ """Get an MCP client for a specific capability.
1182
+
1183
+ Args:
1184
+ capability: Required capability
1185
+ transport: Preferred transport type
1186
+
1187
+ Returns:
1188
+ Configured MCP client or None
1189
+ """
1190
+ # Find best server
1191
+ server = await self.registry.get_best_server(capability, transport)
1192
+ if not server:
1193
+ logger.warning(f"No server found for capability: {capability}")
1194
+ return None
1195
+
1196
+ # Check cache
1197
+ cache_key = f"{server.id}_{capability}"
1198
+ if cache_key in self._client_cache:
1199
+ return self._client_cache[cache_key]
1200
+
1201
+ # Create new client
1202
+ try:
1203
+ client = await self._create_client(server)
1204
+ self._client_cache[cache_key] = client
1205
+ logger.info(f"Created MCP client for {server.name} ({capability})")
1206
+ return client
1207
+ except Exception as e:
1208
+ logger.error(f"Failed to create client for {server.name}: {e}")
1209
+ return None
1210
+
1211
+ async def call_with_failover(
1212
+ self,
1213
+ capability: str,
1214
+ tool_name: str,
1215
+ arguments: Dict[str, Any],
1216
+ max_retries: int = 3,
1217
+ ) -> Dict[str, Any]:
1218
+ """Call a tool with automatic failover to backup servers.
1219
+
1220
+ Args:
1221
+ capability: Required capability
1222
+ tool_name: Tool to call
1223
+ arguments: Tool arguments
1224
+ max_retries: Maximum retry attempts
1225
+
1226
+ Returns:
1227
+ Tool result
1228
+ """
1229
+ servers = await self.registry.discover_servers(
1230
+ capability=capability, healthy_only=True
1231
+ )
1232
+
1233
+ if not servers:
1234
+ raise ServiceDiscoveryError(
1235
+ f"No healthy servers found for capability: {capability}"
1236
+ )
1237
+
1238
+ last_error = None
1239
+
1240
+ for attempt in range(max_retries):
1241
+ # Select server using load balancer
1242
+ server = self._load_balancer.select_server(servers)
1243
+ if not server:
1244
+ break
1245
+
1246
+ try:
1247
+ # Create client for the selected server
1248
+ client = await self._create_client(server)
1249
+
1250
+ # Call the tool
1251
+ result = await client.call_tool(tool_name, arguments)
1252
+
1253
+ # Record successful call
1254
+ self._load_balancer.record_success(server.id)
1255
+ return result
1256
+
1257
+ except Exception as e:
1258
+ last_error = e
1259
+ logger.warning(f"Call to {server.name} failed: {e}")
1260
+
1261
+ # Record failure and remove from current attempt
1262
+ self._load_balancer.record_failure(server.id)
1263
+ servers = [s for s in servers if s.id != server.id]
1264
+
1265
+ # All retries failed
1266
+ raise ServiceDiscoveryError(
1267
+ f"All servers failed for capability {capability}: {last_error}"
1268
+ )
1269
+
1270
+ def _create_server_config(self, server: ServerInfo) -> Dict[str, Any]:
1271
+ """Create server configuration for MCP client.
1272
+
1273
+ Args:
1274
+ server: Server information
1275
+
1276
+ Returns:
1277
+ Server configuration dictionary
1278
+ """
1279
+ if server.transport == "stdio":
1280
+ # Parse command from endpoint
1281
+ if server.endpoint.startswith("python "):
1282
+ command_parts = server.endpoint.split()
1283
+ return {
1284
+ "transport": "stdio",
1285
+ "command": command_parts[0],
1286
+ "args": command_parts[1:],
1287
+ "env": server.metadata.get("env", {}),
1288
+ }
1289
+ else:
1290
+ return {
1291
+ "transport": "stdio",
1292
+ "command": server.endpoint,
1293
+ "args": [],
1294
+ "env": server.metadata.get("env", {}),
1295
+ }
1296
+
1297
+ elif server.transport in ["http", "sse"]:
1298
+ config = {"transport": server.transport, "url": server.endpoint}
1299
+
1300
+ # Add authentication config if present
1301
+ if server.auth_required:
1302
+ auth_config = server.metadata.get("auth_config", {})
1303
+ if auth_config:
1304
+ config["auth"] = auth_config
1305
+
1306
+ return config
1307
+
1308
+ else:
1309
+ # Fallback configuration
1310
+ return {"transport": server.transport, "url": server.endpoint}
1311
+
1312
+ async def _create_client(self, server: ServerInfo) -> Any:
1313
+ """Create a client for the given server.
1314
+
1315
+ Args:
1316
+ server: Server to create client for
1317
+
1318
+ Returns:
1319
+ MCP client instance
1320
+ """
1321
+ try:
1322
+ from .client import MCPClient
1323
+
1324
+ # Create client configuration based on transport
1325
+ if server.transport == "stdio":
1326
+ client = MCPClient(
1327
+ transport="stdio",
1328
+ command=server.command,
1329
+ args=server.args,
1330
+ env=server.metadata.get("env", {}),
1331
+ )
1332
+ elif server.transport in ["http", "sse"]:
1333
+ client = MCPClient(
1334
+ transport=server.transport, url=server.url or server.endpoint
1335
+ )
1336
+ else:
1337
+ raise ValueError(f"Unsupported transport: {server.transport}")
1338
+
1339
+ return client
1340
+ except Exception as e:
1341
+ logger.error(f"Failed to create client for {server.name}: {e}")
1342
+ raise
1343
+
1344
+
1345
+ class LoadBalancer:
1346
+ """Load balancer for distributing requests across servers."""
1347
+
1348
+ def __init__(self):
1349
+ """Initialize load balancer."""
1350
+ self._server_stats: Dict[str, Dict[str, Any]] = {}
1351
+
1352
+ def select_server(self, servers: List[ServerInfo]) -> Optional[ServerInfo]:
1353
+ """Select best server based on load balancing algorithm.
1354
+
1355
+ Args:
1356
+ servers: List of available servers
1357
+
1358
+ Returns:
1359
+ Selected server or None
1360
+ """
1361
+ if not servers:
1362
+ return None
1363
+
1364
+ # Weight servers by priority score and current load
1365
+ weighted_servers = []
1366
+
1367
+ for server in servers:
1368
+ base_weight = server.get_priority_score()
1369
+
1370
+ # Adjust weight based on current load
1371
+ stats = self._server_stats.get(server.id, {})
1372
+ recent_failures = stats.get("recent_failures", 0)
1373
+ recent_calls = stats.get("recent_calls", 0)
1374
+
1375
+ # Penalty for recent failures
1376
+ if recent_failures > 0:
1377
+ base_weight *= 1.0 - min(0.5, recent_failures * 0.1)
1378
+
1379
+ # Penalty for high load
1380
+ if recent_calls > 10:
1381
+ base_weight *= 1.0 - min(0.3, (recent_calls - 10) * 0.02)
1382
+
1383
+ weighted_servers.append((server, max(0.1, base_weight)))
1384
+
1385
+ # Select using weighted random
1386
+ import random
1387
+
1388
+ total_weight = sum(weight for _, weight in weighted_servers)
1389
+
1390
+ if total_weight == 0:
1391
+ return servers[0] # Fallback to first server
1392
+
1393
+ r = random.uniform(0, total_weight)
1394
+ current_weight = 0
1395
+
1396
+ for server, weight in weighted_servers:
1397
+ current_weight += weight
1398
+ if r <= current_weight:
1399
+ return server
1400
+
1401
+ return servers[0] # Fallback
1402
+
1403
+ def record_success(self, server_id: str):
1404
+ """Record successful call to server."""
1405
+ if server_id not in self._server_stats:
1406
+ self._server_stats[server_id] = {
1407
+ "recent_calls": 0,
1408
+ "recent_failures": 0,
1409
+ "last_reset": time.time(),
1410
+ }
1411
+
1412
+ stats = self._server_stats[server_id]
1413
+ stats["recent_calls"] += 1
1414
+
1415
+ # Decay recent failures on success
1416
+ stats["recent_failures"] = max(0, stats["recent_failures"] - 1)
1417
+
1418
+ self._maybe_reset_stats(server_id)
1419
+
1420
+ def record_failure(self, server_id: str):
1421
+ """Record failed call to server."""
1422
+ if server_id not in self._server_stats:
1423
+ self._server_stats[server_id] = {
1424
+ "recent_calls": 0,
1425
+ "recent_failures": 0,
1426
+ "last_reset": time.time(),
1427
+ }
1428
+
1429
+ stats = self._server_stats[server_id]
1430
+ stats["recent_failures"] += 1
1431
+
1432
+ self._maybe_reset_stats(server_id)
1433
+
1434
+ def _maybe_reset_stats(self, server_id: str):
1435
+ """Reset stats if they're getting stale."""
1436
+ stats = self._server_stats[server_id]
1437
+ if time.time() - stats["last_reset"] > 300: # Reset every 5 minutes
1438
+ stats["recent_calls"] = 0
1439
+ stats["recent_failures"] = 0
1440
+ stats["last_reset"] = time.time()
1441
+
1442
+ def _calculate_priority_score(self, server: ServerInfo) -> float:
1443
+ """Calculate priority score for a server.
1444
+
1445
+ Args:
1446
+ server: Server to calculate score for
1447
+
1448
+ Returns:
1449
+ Priority score (0 means unhealthy)
1450
+ """
1451
+ # Unhealthy servers get 0 score
1452
+ if hasattr(server, "health") and server.health:
1453
+ if server.health.get("status") == "unhealthy":
1454
+ return 0
1455
+ elif server.health_status == "unhealthy":
1456
+ return 0
1457
+
1458
+ # Use server's get_priority_score method
1459
+ return server.get_priority_score()
1460
+
1461
+ def select_best_server(self, servers: List[ServerInfo]) -> Optional[ServerInfo]:
1462
+ """Select the best server based on priority scores.
1463
+
1464
+ Args:
1465
+ servers: List of servers to choose from
1466
+
1467
+ Returns:
1468
+ Best server or None
1469
+ """
1470
+ if not servers:
1471
+ return None
1472
+
1473
+ # Score and sort servers
1474
+ scored_servers = [
1475
+ (server, self._calculate_priority_score(server)) for server in servers
1476
+ ]
1477
+ scored_servers.sort(key=lambda x: x[1], reverse=True)
1478
+
1479
+ # Return best server (highest score)
1480
+ return scored_servers[0][0] if scored_servers else None
1481
+
1482
+ def select_servers_round_robin(
1483
+ self, servers: List[ServerInfo], count: int
1484
+ ) -> List[ServerInfo]:
1485
+ """Select servers using round-robin algorithm.
1486
+
1487
+ Args:
1488
+ servers: List of available servers
1489
+ count: Number of servers to select
1490
+
1491
+ Returns:
1492
+ Selected servers
1493
+ """
1494
+ if not servers:
1495
+ return []
1496
+
1497
+ # Track round-robin state
1498
+ if not hasattr(self, "_round_robin_index"):
1499
+ self._round_robin_index = {}
1500
+
1501
+ # Create a key for this server list
1502
+ server_key = tuple(s.name for s in servers)
1503
+
1504
+ if server_key not in self._round_robin_index:
1505
+ self._round_robin_index[server_key] = 0
1506
+
1507
+ selected = []
1508
+ start_index = self._round_robin_index[server_key]
1509
+
1510
+ for i in range(count):
1511
+ index = (start_index + i) % len(servers)
1512
+ selected.append(servers[index])
1513
+
1514
+ # Update index for next call
1515
+ self._round_robin_index[server_key] = (start_index + count) % len(servers)
1516
+
1517
+ return selected
1518
+
1519
+ def select_servers(
1520
+ self, servers: List[ServerInfo], count: int = 1, strategy: str = "priority"
1521
+ ) -> List[ServerInfo]:
1522
+ """Select multiple servers based on strategy.
1523
+
1524
+ Args:
1525
+ servers: List of available servers
1526
+ count: Number of servers to select
1527
+ strategy: Selection strategy ("priority", "round_robin", "random")
1528
+
1529
+ Returns:
1530
+ Selected servers
1531
+ """
1532
+ if not servers:
1533
+ return []
1534
+
1535
+ count = min(count, len(servers))
1536
+
1537
+ if strategy == "round_robin":
1538
+ return self.select_servers_round_robin(servers, count)
1539
+ elif strategy == "random":
1540
+ import random
1541
+
1542
+ return random.sample(servers, count)
1543
+ else: # priority
1544
+ # Sort by priority score
1545
+ scored = [(s, self._calculate_priority_score(s)) for s in servers]
1546
+ scored.sort(key=lambda x: x[1], reverse=True)
1547
+ return [s[0] for s in scored[:count]]
1548
+
1549
+
1550
+ # Convenience functions for easy setup
1551
+ def create_default_registry() -> ServiceRegistry:
1552
+ """Create a default service registry with file and network discovery."""
1553
+ file_backend = FileBasedDiscovery()
1554
+ return ServiceRegistry([file_backend])
1555
+
1556
+
1557
+ async def discover_mcp_servers(
1558
+ capability: Optional[str] = None, transport: Optional[str] = None
1559
+ ) -> List[ServerInfo]:
1560
+ """Discover MCP servers with optional filtering.
1561
+
1562
+ Args:
1563
+ capability: Filter by capability
1564
+ transport: Filter by transport type
1565
+
1566
+ Returns:
1567
+ List of discovered servers
1568
+ """
1569
+ registry = create_default_registry()
1570
+
1571
+ filters = {}
1572
+ if capability:
1573
+ filters["capability"] = capability
1574
+ if transport:
1575
+ filters["transport"] = transport
1576
+
1577
+ return await registry.discover_servers(**filters)
1578
+
1579
+
1580
+ async def get_mcp_client(capability: str, transport: Optional[str] = None):
1581
+ """Get an MCP client for a specific capability.
1582
+
1583
+ Args:
1584
+ capability: Required capability
1585
+ transport: Preferred transport type
1586
+
1587
+ Returns:
1588
+ Configured MCP client
1589
+ """
1590
+ registry = create_default_registry()
1591
+ mesh = ServiceMesh(registry)
1592
+
1593
+ return await mesh.get_client_for_capability(capability, transport)
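
A minimal end-to-end sketch of the discovery API defined in this file, based on the classes above. The registry path, server details, and the weather.get capability are illustrative; the mesh lookup returns None unless a matching server is actually reachable and reported healthy.

    import asyncio

    from kailash.mcp_server.discovery import FileBasedDiscovery, ServiceMesh, ServiceRegistry


    async def main():
        # File-backed registry (the JSON path is illustrative).
        registry = ServiceRegistry([FileBasedDiscovery("mcp_registry.json")])

        # Register a server from a plain config dict; it is converted to ServerInfo.
        await registry.register_server({
            "name": "weather-server",
            "transport": "http",
            "url": "http://localhost:8080",
            "capabilities": ["weather.get", "weather.forecast"],
        })

        # Capability-based discovery across all configured backends.
        servers = await registry.discover_servers(capability="weather.get")
        print([s.name for s in servers])

        # Mesh routing only considers healthy servers, so start health monitoring
        # first; get_client_for_capability returns None if nothing qualifies.
        await registry.start_health_monitoring()
        mesh = ServiceMesh(registry)
        client = await mesh.get_client_for_capability("weather.get")
        print(client)


    asyncio.run(main())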