kailash 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +3 -3
- kailash/api/custom_nodes_secure.py +3 -3
- kailash/api/gateway.py +1 -1
- kailash/api/studio.py +2 -3
- kailash/api/workflow_api.py +3 -4
- kailash/core/resilience/bulkhead.py +460 -0
- kailash/core/resilience/circuit_breaker.py +92 -10
- kailash/edge/discovery.py +86 -0
- kailash/mcp_server/__init__.py +334 -0
- kailash/mcp_server/advanced_features.py +1022 -0
- kailash/{mcp → mcp_server}/ai_registry_server.py +29 -4
- kailash/mcp_server/auth.py +789 -0
- kailash/mcp_server/client.py +712 -0
- kailash/mcp_server/discovery.py +1593 -0
- kailash/mcp_server/errors.py +673 -0
- kailash/mcp_server/oauth.py +1727 -0
- kailash/mcp_server/protocol.py +1126 -0
- kailash/mcp_server/registry_integration.py +587 -0
- kailash/mcp_server/server.py +1747 -0
- kailash/{mcp → mcp_server}/servers/ai_registry.py +2 -2
- kailash/mcp_server/transports.py +1169 -0
- kailash/mcp_server/utils/cache.py +510 -0
- kailash/middleware/auth/auth_manager.py +3 -3
- kailash/middleware/communication/api_gateway.py +2 -9
- kailash/middleware/communication/realtime.py +1 -1
- kailash/middleware/mcp/client_integration.py +1 -1
- kailash/middleware/mcp/enhanced_server.py +2 -2
- kailash/nodes/__init__.py +2 -0
- kailash/nodes/admin/audit_log.py +6 -6
- kailash/nodes/admin/permission_check.py +8 -8
- kailash/nodes/admin/role_management.py +32 -28
- kailash/nodes/admin/schema.sql +6 -1
- kailash/nodes/admin/schema_manager.py +13 -13
- kailash/nodes/admin/security_event.py +16 -20
- kailash/nodes/admin/tenant_isolation.py +3 -3
- kailash/nodes/admin/transaction_utils.py +3 -3
- kailash/nodes/admin/user_management.py +21 -22
- kailash/nodes/ai/a2a.py +11 -11
- kailash/nodes/ai/ai_providers.py +9 -12
- kailash/nodes/ai/embedding_generator.py +13 -14
- kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
- kailash/nodes/ai/iterative_llm_agent.py +3 -3
- kailash/nodes/ai/llm_agent.py +213 -36
- kailash/nodes/ai/self_organizing.py +2 -2
- kailash/nodes/alerts/discord.py +4 -4
- kailash/nodes/api/graphql.py +6 -6
- kailash/nodes/api/http.py +12 -17
- kailash/nodes/api/rate_limiting.py +4 -4
- kailash/nodes/api/rest.py +15 -15
- kailash/nodes/auth/mfa.py +3 -4
- kailash/nodes/auth/risk_assessment.py +2 -2
- kailash/nodes/auth/session_management.py +5 -5
- kailash/nodes/auth/sso.py +143 -0
- kailash/nodes/base.py +6 -2
- kailash/nodes/base_async.py +16 -2
- kailash/nodes/base_with_acl.py +2 -2
- kailash/nodes/cache/__init__.py +9 -0
- kailash/nodes/cache/cache.py +1172 -0
- kailash/nodes/cache/cache_invalidation.py +870 -0
- kailash/nodes/cache/redis_pool_manager.py +595 -0
- kailash/nodes/code/async_python.py +2 -1
- kailash/nodes/code/python.py +196 -35
- kailash/nodes/compliance/data_retention.py +6 -6
- kailash/nodes/compliance/gdpr.py +5 -5
- kailash/nodes/data/__init__.py +10 -0
- kailash/nodes/data/optimistic_locking.py +906 -0
- kailash/nodes/data/readers.py +8 -8
- kailash/nodes/data/redis.py +349 -0
- kailash/nodes/data/sql.py +314 -3
- kailash/nodes/data/streaming.py +21 -0
- kailash/nodes/enterprise/__init__.py +8 -0
- kailash/nodes/enterprise/audit_logger.py +285 -0
- kailash/nodes/enterprise/batch_processor.py +22 -3
- kailash/nodes/enterprise/data_lineage.py +1 -1
- kailash/nodes/enterprise/mcp_executor.py +205 -0
- kailash/nodes/enterprise/service_discovery.py +150 -0
- kailash/nodes/enterprise/tenant_assignment.py +108 -0
- kailash/nodes/logic/async_operations.py +2 -2
- kailash/nodes/logic/convergence.py +1 -1
- kailash/nodes/logic/operations.py +1 -1
- kailash/nodes/monitoring/__init__.py +11 -1
- kailash/nodes/monitoring/health_check.py +456 -0
- kailash/nodes/monitoring/log_processor.py +817 -0
- kailash/nodes/monitoring/metrics_collector.py +627 -0
- kailash/nodes/monitoring/performance_benchmark.py +137 -11
- kailash/nodes/rag/advanced.py +7 -7
- kailash/nodes/rag/agentic.py +49 -2
- kailash/nodes/rag/conversational.py +3 -3
- kailash/nodes/rag/evaluation.py +3 -3
- kailash/nodes/rag/federated.py +3 -3
- kailash/nodes/rag/graph.py +3 -3
- kailash/nodes/rag/multimodal.py +3 -3
- kailash/nodes/rag/optimized.py +5 -5
- kailash/nodes/rag/privacy.py +3 -3
- kailash/nodes/rag/query_processing.py +6 -6
- kailash/nodes/rag/realtime.py +1 -1
- kailash/nodes/rag/registry.py +2 -6
- kailash/nodes/rag/router.py +1 -1
- kailash/nodes/rag/similarity.py +7 -7
- kailash/nodes/rag/strategies.py +4 -4
- kailash/nodes/security/abac_evaluator.py +6 -6
- kailash/nodes/security/behavior_analysis.py +5 -6
- kailash/nodes/security/credential_manager.py +1 -1
- kailash/nodes/security/rotating_credentials.py +11 -11
- kailash/nodes/security/threat_detection.py +8 -8
- kailash/nodes/testing/credential_testing.py +2 -2
- kailash/nodes/transform/processors.py +5 -5
- kailash/runtime/local.py +162 -14
- kailash/runtime/parameter_injection.py +425 -0
- kailash/runtime/parameter_injector.py +657 -0
- kailash/runtime/testing.py +2 -2
- kailash/testing/fixtures.py +2 -2
- kailash/workflow/builder.py +99 -18
- kailash/workflow/builder_improvements.py +207 -0
- kailash/workflow/input_handling.py +170 -0
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/METADATA +21 -8
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/RECORD +126 -101
- kailash/mcp/__init__.py +0 -53
- kailash/mcp/client.py +0 -445
- kailash/mcp/server.py +0 -292
- kailash/mcp/server_enhanced.py +0 -449
- kailash/mcp/utils/cache.py +0 -267
- /kailash/{mcp → mcp_server}/client_new.py +0 -0
- /kailash/{mcp → mcp_server}/utils/__init__.py +0 -0
- /kailash/{mcp → mcp_server}/utils/config.py +0 -0
- /kailash/{mcp → mcp_server}/utils/formatters.py +0 -0
- /kailash/{mcp → mcp_server}/utils/metrics.py +0 -0
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/WHEEL +0 -0
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/top_level.txt +0 -0
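
The largest single addition in this release is the new service-discovery module reproduced below (kailash/mcp_server/discovery.py, +1593 lines). As a rough orientation, the following sketch wires its main pieces together in the way the module's own docstring and signatures suggest; the wiring and the asyncio entry point are illustrative assumptions, not an official example from the package:

# Hedged usage sketch based on the classes defined in the discovery module below;
# the composition shown here is an assumption for illustration only.
import asyncio

from kailash.mcp_server.discovery import FileBasedDiscovery, ServiceMesh, ServiceRegistry


async def main():
    # Registry backed by the JSON file-registry backend shown in the diff
    registry = ServiceRegistry(backends=[FileBasedDiscovery("mcp_registry.json")])

    # register_server() accepts a plain config dict and converts it to ServerInfo
    await registry.register_server({
        "name": "weather-server",
        "transport": "http",
        "url": "http://localhost:8080",
        "capabilities": ["weather.get", "weather.forecast"],
    })

    # Capability-based lookup, restricted to servers that look healthy
    servers = await registry.discover_servers(capability="weather.get", healthy_only=True)
    print([s.name for s in servers])

    # The service mesh picks the best server and builds an MCP client for it
    mesh = ServiceMesh(registry)
    client = await mesh.get_client_for_capability("weather.get")


asyncio.run(main())
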
kailash/mcp_server/discovery.py (new file)
@@ -0,0 +1,1593 @@
|
|
1
|
+
"""
|
2
|
+
Service Discovery System for MCP Servers and Clients.
|
3
|
+
|
4
|
+
This module provides automatic discovery of MCP servers, their capabilities,
|
5
|
+
and health status. It supports multiple discovery mechanisms including
|
6
|
+
file-based registry, network scanning, and external service registries.
|
7
|
+
|
8
|
+
Features:
|
9
|
+
- Automatic server registration and deregistration
|
10
|
+
- Health checking and monitoring
|
11
|
+
- Capability-based server filtering
|
12
|
+
- Load balancing and failover
|
13
|
+
- Real-time server status updates
|
14
|
+
- Network-based discovery protocols
|
15
|
+
|
16
|
+
Examples:
|
17
|
+
Basic service registry:
|
18
|
+
|
19
|
+
>>> registry = ServiceRegistry()
|
20
|
+
>>> registry.register_server({
|
21
|
+
... "name": "weather-server",
|
22
|
+
... "transport": "http",
|
23
|
+
... "url": "http://localhost:8080",
|
24
|
+
... "capabilities": ["weather.get", "weather.forecast"]
|
25
|
+
... })
|
26
|
+
>>> servers = registry.discover_servers(capability="weather.get")
|
27
|
+
|
28
|
+
Network discovery:
|
29
|
+
|
30
|
+
>>> discoverer = NetworkDiscovery()
|
31
|
+
>>> servers = await discoverer.scan_network("192.168.1.0/24")
|
32
|
+
|
33
|
+
Service mesh integration:
|
34
|
+
|
35
|
+
>>> mesh = ServiceMesh(registry)
|
36
|
+
>>> client = await mesh.get_client_for_capability("weather.get")
|
37
|
+
"""
|
38
|
+
|
39
|
+
import asyncio
|
40
|
+
import json
|
41
|
+
import logging
|
42
|
+
import socket
|
43
|
+
import time
|
44
|
+
import uuid
|
45
|
+
from abc import ABC, abstractmethod
|
46
|
+
from dataclasses import asdict, dataclass
|
47
|
+
from pathlib import Path
|
48
|
+
from typing import Any, Dict, List, Optional, Set, Union
|
49
|
+
from urllib.parse import urlparse
|
50
|
+
|
51
|
+
from .auth import AuthProvider
|
52
|
+
from .errors import MCPError, MCPErrorCode, ServiceDiscoveryError
|
53
|
+
|
54
|
+
logger = logging.getLogger(__name__)
|
55
|
+
|
56
|
+
|
57
|
+
@dataclass
|
58
|
+
class ServerInfo:
|
59
|
+
"""Information about a discovered MCP server."""
|
60
|
+
|
61
|
+
name: str
|
62
|
+
transport: str # stdio, sse, http
|
63
|
+
capabilities: List[str] = None # List of tool names or capability strings
|
64
|
+
metadata: Dict[str, Any] = None
|
65
|
+
id: Optional[str] = None
|
66
|
+
endpoint: Optional[str] = None # URL or command
|
67
|
+
command: Optional[str] = None # For stdio transport
|
68
|
+
args: Optional[List[str]] = None # For stdio transport
|
69
|
+
url: Optional[str] = None # For HTTP/SSE transport
|
70
|
+
health_endpoint: Optional[str] = None # Health check endpoint path
|
71
|
+
health_status: str = "unknown" # healthy, unhealthy, unknown
|
72
|
+
health: Optional[Dict[str, Any]] = None # Health information dict
|
73
|
+
last_seen: float = 0.0
|
74
|
+
response_time: Optional[float] = None
|
75
|
+
version: str = "1.0.0"
|
76
|
+
auth_required: bool = False
|
77
|
+
|
78
|
+
def __post_init__(self):
|
79
|
+
if self.last_seen == 0.0:
|
80
|
+
self.last_seen = time.time()
|
81
|
+
|
82
|
+
# Auto-generate ID if not provided
|
83
|
+
if self.id is None:
|
84
|
+
self.id = f"{self.name}_{hash(self.name) % 10000}"
|
85
|
+
|
86
|
+
# Initialize metadata if None
|
87
|
+
if self.metadata is None:
|
88
|
+
self.metadata = {}
|
89
|
+
|
90
|
+
# Initialize capabilities if None
|
91
|
+
if self.capabilities is None:
|
92
|
+
self.capabilities = []
|
93
|
+
|
94
|
+
# Extract response_time from health if available
|
95
|
+
if self.health and isinstance(self.health, dict):
|
96
|
+
if "response_time" in self.health and self.response_time is None:
|
97
|
+
self.response_time = self.health["response_time"]
|
98
|
+
if "status" in self.health and self.health_status == "unknown":
|
99
|
+
self.health_status = self.health["status"]
|
100
|
+
|
101
|
+
# Set endpoint based on transport if not provided
|
102
|
+
if self.endpoint is None:
|
103
|
+
if self.transport == "stdio" and self.command:
|
104
|
+
self.endpoint = self.command
|
105
|
+
elif self.transport in ["http", "sse"] and self.url:
|
106
|
+
self.endpoint = self.url
|
107
|
+
else:
|
108
|
+
self.endpoint = "unknown"
|
109
|
+
|
110
|
+
def to_dict(self) -> Dict[str, Any]:
|
111
|
+
"""Convert to dictionary format."""
|
112
|
+
return asdict(self)
|
113
|
+
|
114
|
+
@classmethod
|
115
|
+
def from_dict(cls, data: Dict[str, Any]) -> "ServerInfo":
|
116
|
+
"""Create from dictionary format."""
|
117
|
+
return cls(**data)
|
118
|
+
|
119
|
+
def is_healthy(self, max_age: float = 300.0) -> bool:
|
120
|
+
"""Check if server is considered healthy."""
|
121
|
+
# Check health dict first
|
122
|
+
if self.health and isinstance(self.health, dict):
|
123
|
+
status = self.health.get("status", "unknown")
|
124
|
+
if status == "healthy":
|
125
|
+
age = time.time() - self.last_seen
|
126
|
+
return age < max_age
|
127
|
+
else:
|
128
|
+
return False
|
129
|
+
|
130
|
+
# Fall back to health_status field
|
131
|
+
age = time.time() - self.last_seen
|
132
|
+
return self.health_status == "healthy" and age < max_age
|
133
|
+
|
134
|
+
def matches_capability(self, capability: str) -> bool:
|
135
|
+
"""Check if server provides a specific capability."""
|
136
|
+
return capability in self.capabilities
|
137
|
+
|
138
|
+
def has_capability(self, capability: str) -> bool:
|
139
|
+
"""Check if server provides a specific capability (alias for matches_capability)."""
|
140
|
+
return self.matches_capability(capability)
|
141
|
+
|
142
|
+
def matches_transport(self, transport: str) -> bool:
|
143
|
+
"""Check if server supports a transport type."""
|
144
|
+
return self.transport == transport
|
145
|
+
|
146
|
+
def matches_filter(self, **filters) -> bool:
|
147
|
+
"""Check if server matches all provided filter criteria.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
**filters: Filter criteria (capability, transport, metadata, name, etc)
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
True if all filters match
|
154
|
+
"""
|
155
|
+
# Check capability filter
|
156
|
+
if "capability" in filters:
|
157
|
+
if not self.has_capability(filters["capability"]):
|
158
|
+
return False
|
159
|
+
|
160
|
+
# Check transport filter
|
161
|
+
if "transport" in filters:
|
162
|
+
if not self.matches_transport(filters["transport"]):
|
163
|
+
return False
|
164
|
+
|
165
|
+
# Check name filter
|
166
|
+
if "name" in filters:
|
167
|
+
if self.name != filters["name"]:
|
168
|
+
return False
|
169
|
+
|
170
|
+
# Check metadata filter (as a dict)
|
171
|
+
if "metadata" in filters:
|
172
|
+
filter_metadata = filters["metadata"]
|
173
|
+
if not self.metadata:
|
174
|
+
return False
|
175
|
+
for key, value in filter_metadata.items():
|
176
|
+
if key not in self.metadata or self.metadata[key] != value:
|
177
|
+
return False
|
178
|
+
|
179
|
+
# Check other direct attributes
|
180
|
+
for key, value in filters.items():
|
181
|
+
if key not in ["capability", "transport", "name", "metadata"]:
|
182
|
+
# Check if it's a direct attribute
|
183
|
+
if hasattr(self, key):
|
184
|
+
if getattr(self, key) != value:
|
185
|
+
return False
|
186
|
+
else:
|
187
|
+
# If not an attribute, check in metadata
|
188
|
+
if self.metadata and key in self.metadata:
|
189
|
+
if self.metadata[key] != value:
|
190
|
+
return False
|
191
|
+
else:
|
192
|
+
return False
|
193
|
+
|
194
|
+
return True
|
195
|
+
|
196
|
+
def get_priority_score(self) -> float:
|
197
|
+
"""Calculate priority score for load balancing."""
|
198
|
+
base_score = 1.0
|
199
|
+
|
200
|
+
# Health bonus
|
201
|
+
if self.health_status == "healthy":
|
202
|
+
base_score += 0.5
|
203
|
+
elif self.health_status == "unhealthy":
|
204
|
+
base_score -= 0.5
|
205
|
+
|
206
|
+
# Response time bonus (lower is better)
|
207
|
+
if self.response_time:
|
208
|
+
if self.response_time < 0.1: # < 100ms
|
209
|
+
base_score += 0.3
|
210
|
+
elif self.response_time > 1.0: # > 1s
|
211
|
+
base_score -= 0.3
|
212
|
+
|
213
|
+
# Age penalty
|
214
|
+
age = time.time() - self.last_seen
|
215
|
+
if age > 60: # Over 1 minute old
|
216
|
+
base_score -= min(0.4, age / 300) # Max penalty of 0.4
|
217
|
+
|
218
|
+
return max(0.1, base_score) # Minimum score of 0.1
|
219
|
+
|
220
|
+
|
221
|
+
class DiscoveryBackend(ABC):
|
222
|
+
"""Abstract base class for discovery backends."""
|
223
|
+
|
224
|
+
@abstractmethod
|
225
|
+
async def register_server(self, server_info: ServerInfo) -> bool:
|
226
|
+
"""Register a server with the discovery backend."""
|
227
|
+
pass
|
228
|
+
|
229
|
+
@abstractmethod
|
230
|
+
async def deregister_server(self, server_id: str) -> bool:
|
231
|
+
"""Deregister a server from the discovery backend."""
|
232
|
+
pass
|
233
|
+
|
234
|
+
@abstractmethod
|
235
|
+
async def get_servers(self, **filters) -> List[ServerInfo]:
|
236
|
+
"""Get list of servers matching filters."""
|
237
|
+
pass
|
238
|
+
|
239
|
+
@abstractmethod
|
240
|
+
async def update_server_health(
|
241
|
+
self, server_id: str, health_status: str, response_time: Optional[float] = None
|
242
|
+
) -> bool:
|
243
|
+
"""Update server health status."""
|
244
|
+
pass
|
245
|
+
|
246
|
+
|
247
|
+
class FileBasedDiscovery(DiscoveryBackend):
|
248
|
+
"""File-based service discovery using JSON registry."""
|
249
|
+
|
250
|
+
def __init__(self, registry_path: Union[str, Path] = "mcp_registry.json"):
|
251
|
+
"""Initialize file-based discovery.
|
252
|
+
|
253
|
+
Args:
|
254
|
+
registry_path: Path to the JSON registry file
|
255
|
+
"""
|
256
|
+
self.registry_path = Path(registry_path)
|
257
|
+
self._ensure_registry_file()
|
258
|
+
|
259
|
+
@property
|
260
|
+
def _servers(self) -> Dict[str, ServerInfo]:
|
261
|
+
"""Get servers as a dict for compatibility with tests."""
|
262
|
+
registry = self._read_registry()
|
263
|
+
servers = {}
|
264
|
+
for server_id, server_data in registry["servers"].items():
|
265
|
+
server_info = ServerInfo.from_dict(server_data)
|
266
|
+
servers[server_info.name] = server_info
|
267
|
+
return servers
|
268
|
+
|
269
|
+
def _ensure_registry_file(self):
|
270
|
+
"""Ensure registry file exists."""
|
271
|
+
if not self.registry_path.exists():
|
272
|
+
self.registry_path.write_text(
|
273
|
+
json.dumps(
|
274
|
+
{"servers": {}, "last_updated": time.time(), "version": "1.0"},
|
275
|
+
indent=2,
|
276
|
+
)
|
277
|
+
)
|
278
|
+
|
279
|
+
def _read_registry(self) -> Dict[str, Any]:
|
280
|
+
"""Read registry from file."""
|
281
|
+
try:
|
282
|
+
return json.loads(self.registry_path.read_text())
|
283
|
+
except (json.JSONDecodeError, FileNotFoundError) as e:
|
284
|
+
logger.error(f"Failed to read registry: {e}")
|
285
|
+
return {"servers": {}, "last_updated": time.time(), "version": "1.0"}
|
286
|
+
|
287
|
+
def _write_registry(self, registry: Dict[str, Any]):
|
288
|
+
"""Write registry to file."""
|
289
|
+
registry["last_updated"] = time.time()
|
290
|
+
self.registry_path.write_text(json.dumps(registry, indent=2))
|
291
|
+
|
292
|
+
async def register_server(self, server_info: ServerInfo) -> bool:
|
293
|
+
"""Register a server in the file registry."""
|
294
|
+
try:
|
295
|
+
registry = self._read_registry()
|
296
|
+
registry["servers"][server_info.id] = server_info.to_dict()
|
297
|
+
self._write_registry(registry)
|
298
|
+
logger.info(f"Registered server: {server_info.name} ({server_info.id})")
|
299
|
+
return True
|
300
|
+
except Exception as e:
|
301
|
+
logger.error(f"Failed to register server {server_info.id}: {e}")
|
302
|
+
return False
|
303
|
+
|
304
|
+
async def deregister_server(self, server_id: str) -> bool:
|
305
|
+
"""Deregister a server from the file registry."""
|
306
|
+
try:
|
307
|
+
registry = self._read_registry()
|
308
|
+
if server_id in registry["servers"]:
|
309
|
+
del registry["servers"][server_id]
|
310
|
+
self._write_registry(registry)
|
311
|
+
logger.info(f"Deregistered server: {server_id}")
|
312
|
+
return True
|
313
|
+
return False
|
314
|
+
except Exception as e:
|
315
|
+
logger.error(f"Failed to deregister server {server_id}: {e}")
|
316
|
+
return False
|
317
|
+
|
318
|
+
async def unregister_server(self, server_name: str) -> bool:
|
319
|
+
"""Unregister a server by name (alias for test compatibility)."""
|
320
|
+
# Find server by name
|
321
|
+
registry = self._read_registry()
|
322
|
+
server_id_to_remove = None
|
323
|
+
|
324
|
+
for server_id, server_data in registry["servers"].items():
|
325
|
+
if server_data.get("name") == server_name:
|
326
|
+
server_id_to_remove = server_id
|
327
|
+
break
|
328
|
+
|
329
|
+
if server_id_to_remove:
|
330
|
+
return await self.deregister_server(server_id_to_remove)
|
331
|
+
return False
|
332
|
+
|
333
|
+
async def get_servers(self, **filters) -> List[ServerInfo]:
|
334
|
+
"""Get servers matching filters."""
|
335
|
+
try:
|
336
|
+
registry = self._read_registry()
|
337
|
+
servers = []
|
338
|
+
|
339
|
+
for server_data in registry["servers"].values():
|
340
|
+
server_info = ServerInfo.from_dict(server_data)
|
341
|
+
|
342
|
+
# Apply filters
|
343
|
+
if self._matches_filters(server_info, filters):
|
344
|
+
servers.append(server_info)
|
345
|
+
|
346
|
+
return servers
|
347
|
+
except Exception as e:
|
348
|
+
logger.error(f"Failed to get servers: {e}")
|
349
|
+
return []
|
350
|
+
|
351
|
+
async def discover_servers(self, **filters) -> List[ServerInfo]:
|
352
|
+
"""Discover servers matching filters (alias for get_servers)."""
|
353
|
+
return await self.get_servers(**filters)
|
354
|
+
|
355
|
+
async def get_server(self, server_name: str) -> Optional[ServerInfo]:
|
356
|
+
"""Get a specific server by name."""
|
357
|
+
servers = await self.get_servers()
|
358
|
+
for server in servers:
|
359
|
+
if server.name == server_name:
|
360
|
+
return server
|
361
|
+
return None
|
362
|
+
|
363
|
+
async def update_server_health(
|
364
|
+
self,
|
365
|
+
server_identifier: str,
|
366
|
+
health_info: Union[str, Dict[str, Any]],
|
367
|
+
response_time: Optional[float] = None,
|
368
|
+
) -> bool:
|
369
|
+
"""Update server health in the registry.
|
370
|
+
|
371
|
+
Args:
|
372
|
+
server_identifier: Server ID or name
|
373
|
+
health_info: Health status string or health info dict
|
374
|
+
response_time: Optional response time (if health_info is string)
|
375
|
+
"""
|
376
|
+
try:
|
377
|
+
registry = self._read_registry()
|
378
|
+
|
379
|
+
# Find server by ID or name
|
380
|
+
server_id = None
|
381
|
+
for sid, server_data in registry["servers"].items():
|
382
|
+
if (
|
383
|
+
sid == server_identifier
|
384
|
+
or server_data.get("name") == server_identifier
|
385
|
+
):
|
386
|
+
server_id = sid
|
387
|
+
break
|
388
|
+
|
389
|
+
if not server_id:
|
390
|
+
return False
|
391
|
+
|
392
|
+
# Update health info
|
393
|
+
if isinstance(health_info, dict):
|
394
|
+
# Full health info dict provided
|
395
|
+
registry["servers"][server_id]["health"] = health_info
|
396
|
+
if "status" in health_info:
|
397
|
+
registry["servers"][server_id]["health_status"] = health_info[
|
398
|
+
"status"
|
399
|
+
]
|
400
|
+
if "response_time" in health_info:
|
401
|
+
registry["servers"][server_id]["response_time"] = health_info[
|
402
|
+
"response_time"
|
403
|
+
]
|
404
|
+
else:
|
405
|
+
# Simple string status
|
406
|
+
registry["servers"][server_id]["health_status"] = health_info
|
407
|
+
if response_time is not None:
|
408
|
+
registry["servers"][server_id]["response_time"] = response_time
|
409
|
+
|
410
|
+
registry["servers"][server_id]["last_seen"] = time.time()
|
411
|
+
self._write_registry(registry)
|
412
|
+
return True
|
413
|
+
|
414
|
+
except Exception as e:
|
415
|
+
logger.error(f"Failed to update server health {server_identifier}: {e}")
|
416
|
+
return False
|
417
|
+
|
418
|
+
def _matches_filters(
|
419
|
+
self, server_info: ServerInfo, filters: Dict[str, Any]
|
420
|
+
) -> bool:
|
421
|
+
"""Check if server matches the provided filters."""
|
422
|
+
for key, value in filters.items():
|
423
|
+
if key == "capability":
|
424
|
+
if not server_info.matches_capability(value):
|
425
|
+
return False
|
426
|
+
elif key == "transport":
|
427
|
+
if not server_info.matches_transport(value):
|
428
|
+
return False
|
429
|
+
elif key == "healthy_only":
|
430
|
+
if value and not server_info.is_healthy():
|
431
|
+
return False
|
432
|
+
elif key == "name":
|
433
|
+
if server_info.name != value:
|
434
|
+
return False
|
435
|
+
elif key == "auth_required":
|
436
|
+
if server_info.auth_required != value:
|
437
|
+
return False
|
438
|
+
|
439
|
+
return True
|
440
|
+
|
441
|
+
async def save_registry(self, path: str) -> None:
|
442
|
+
"""Save the current registry to a different file."""
|
443
|
+
# Use async file operations to avoid blocking
|
444
|
+
import json
|
445
|
+
|
446
|
+
import aiofiles
|
447
|
+
|
448
|
+
try:
|
449
|
+
# Read current registry
|
450
|
+
registry = self._read_registry()
|
451
|
+
|
452
|
+
# Write to the specified path asynchronously
|
453
|
+
async with aiofiles.open(path, "w") as f:
|
454
|
+
await f.write(json.dumps(registry, indent=2))
|
455
|
+
|
456
|
+
except Exception as e:
|
457
|
+
logger.error(f"Failed to save registry to {path}: {e}")
|
458
|
+
raise
|
459
|
+
|
460
|
+
async def load_registry(self, path: str) -> None:
|
461
|
+
"""Load registry from a different file."""
|
462
|
+
import json
|
463
|
+
from pathlib import Path
|
464
|
+
|
465
|
+
import aiofiles
|
466
|
+
|
467
|
+
try:
|
468
|
+
if not Path(path).exists():
|
469
|
+
# Create empty registry if file doesn't exist
|
470
|
+
logger.warning(
|
471
|
+
f"Registry file not found: {path}, creating empty registry"
|
472
|
+
)
|
473
|
+
self._ensure_registry_file()
|
474
|
+
return
|
475
|
+
|
476
|
+
# Read from the specified path asynchronously
|
477
|
+
async with aiofiles.open(path, "r") as f:
|
478
|
+
content = await f.read()
|
479
|
+
registry = json.loads(content)
|
480
|
+
|
481
|
+
# Write to our registry path
|
482
|
+
self._write_registry(registry)
|
483
|
+
|
484
|
+
except Exception as e:
|
485
|
+
logger.error(f"Failed to load registry from {path}: {e}")
|
486
|
+
raise
|
487
|
+
|
488
|
+
|
489
|
+
class NetworkDiscovery:
|
490
|
+
"""Network-based discovery using UDP broadcast/multicast."""
|
491
|
+
|
492
|
+
DISCOVERY_PORT = 8765
|
493
|
+
MULTICAST_GROUP = "224.0.0.251"
|
494
|
+
|
495
|
+
def __init__(
|
496
|
+
self,
|
497
|
+
port: int = DISCOVERY_PORT,
|
498
|
+
multicast_group: str = None,
|
499
|
+
interface: str = "0.0.0.0",
|
500
|
+
):
|
501
|
+
"""Initialize network discovery.
|
502
|
+
|
503
|
+
Args:
|
504
|
+
port: UDP port for discovery
|
505
|
+
multicast_group: Multicast group address
|
506
|
+
interface: Network interface to bind to
|
507
|
+
"""
|
508
|
+
self.port = port
|
509
|
+
self.multicast_group = multicast_group or self.MULTICAST_GROUP
|
510
|
+
self.interface = interface
|
511
|
+
self.running = False
|
512
|
+
self._discovered_servers: Dict[str, ServerInfo] = {}
|
513
|
+
self._discovery_socket: Optional[socket.socket] = None
|
514
|
+
self._transport = None
|
515
|
+
self._protocol = None
|
516
|
+
|
517
|
+
async def start(self):
|
518
|
+
"""Start network discovery."""
|
519
|
+
loop = asyncio.get_event_loop()
|
520
|
+
|
521
|
+
# Create UDP endpoint
|
522
|
+
transport, protocol = await loop.create_datagram_endpoint(
|
523
|
+
lambda: self, local_addr=(self.interface, self.port), reuse_port=True
|
524
|
+
)
|
525
|
+
|
526
|
+
self._transport = transport
|
527
|
+
self._protocol = protocol
|
528
|
+
self.running = True
|
529
|
+
|
530
|
+
async def stop(self):
|
531
|
+
"""Stop network discovery."""
|
532
|
+
self.running = False
|
533
|
+
if self._transport:
|
534
|
+
self._transport.close()
|
535
|
+
self._transport = None
|
536
|
+
self._protocol = None
|
537
|
+
|
538
|
+
async def start_discovery_listener(self):
|
539
|
+
"""Start listening for server announcements."""
|
540
|
+
await self.start()
|
541
|
+
|
542
|
+
# Create UDP socket for listening
|
543
|
+
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
544
|
+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
545
|
+
sock.bind(("", self.port))
|
546
|
+
|
547
|
+
# Join multicast group
|
548
|
+
mreq = socket.inet_aton(self.MULTICAST_GROUP) + socket.inet_aton("0.0.0.0")
|
549
|
+
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
|
550
|
+
|
551
|
+
self._discovery_socket = sock
|
552
|
+
|
553
|
+
logger.info(f"Started network discovery listener on port {self.port}")
|
554
|
+
|
555
|
+
while self.running:
|
556
|
+
try:
|
557
|
+
# Set timeout to check running flag periodically
|
558
|
+
sock.settimeout(1.0)
|
559
|
+
data, addr = sock.recvfrom(1024)
|
560
|
+
|
561
|
+
await self._process_announcement(data, addr)
|
562
|
+
|
563
|
+
except socket.timeout:
|
564
|
+
continue
|
565
|
+
except Exception as e:
|
566
|
+
logger.error(f"Error in discovery listener: {e}")
|
567
|
+
|
568
|
+
async def _process_announcement(self, data: bytes, addr: tuple):
|
569
|
+
"""Process server announcement."""
|
570
|
+
try:
|
571
|
+
announcement = json.loads(data.decode())
|
572
|
+
|
573
|
+
if announcement.get("type") == "mcp_server_announcement":
|
574
|
+
server_info = ServerInfo(
|
575
|
+
id=announcement.get("id", str(uuid.uuid4())),
|
576
|
+
name=announcement.get("name", "unknown"),
|
577
|
+
transport=announcement.get("transport", "http"),
|
578
|
+
endpoint=announcement.get("endpoint", f"http://{addr[0]}:8080"),
|
579
|
+
capabilities=announcement.get("capabilities", []),
|
580
|
+
metadata=announcement.get("metadata", {}),
|
581
|
+
health_status="healthy",
|
582
|
+
last_seen=time.time(),
|
583
|
+
version=announcement.get("version", "1.0.0"),
|
584
|
+
auth_required=announcement.get("auth_required", False),
|
585
|
+
)
|
586
|
+
|
587
|
+
self._discovered_servers[server_info.id] = server_info
|
588
|
+
logger.info(f"Discovered server: {server_info.name} at {addr[0]}")
|
589
|
+
|
590
|
+
except (json.JSONDecodeError, KeyError) as e:
|
591
|
+
logger.debug(f"Invalid announcement from {addr[0]}: {e}")
|
592
|
+
|
593
|
+
async def scan_network(
|
594
|
+
self, network: str = "192.168.1.0/24", timeout: float = 5.0
|
595
|
+
) -> List[ServerInfo]:
|
596
|
+
"""Actively scan network for MCP servers.
|
597
|
+
|
598
|
+
Args:
|
599
|
+
network: Network to scan (CIDR notation)
|
600
|
+
timeout: Scan timeout in seconds
|
601
|
+
|
602
|
+
Returns:
|
603
|
+
List of discovered servers
|
604
|
+
"""
|
605
|
+
import ipaddress
|
606
|
+
|
607
|
+
discovered = []
|
608
|
+
network_obj = ipaddress.IPv4Network(network, strict=False)
|
609
|
+
|
610
|
+
logger.info(f"Scanning network {network} for MCP servers...")
|
611
|
+
|
612
|
+
# Create semaphore to limit concurrent connections
|
613
|
+
semaphore = asyncio.Semaphore(50)
|
614
|
+
|
615
|
+
async def scan_host(ip: str):
|
616
|
+
async with semaphore:
|
617
|
+
try:
|
618
|
+
# Try common MCP ports
|
619
|
+
for port in [8080, 8765, 3000, 5000]:
|
620
|
+
try:
|
621
|
+
reader, writer = await asyncio.wait_for(
|
622
|
+
asyncio.open_connection(str(ip), port), timeout=1.0
|
623
|
+
)
|
624
|
+
|
625
|
+
# Send MCP discovery request
|
626
|
+
discovery_request = json.dumps(
|
627
|
+
{"type": "mcp_discovery_request", "version": "1.0"}
|
628
|
+
).encode()
|
629
|
+
|
630
|
+
writer.write(discovery_request)
|
631
|
+
await writer.drain()
|
632
|
+
|
633
|
+
# Read response
|
634
|
+
response_data = await asyncio.wait_for(
|
635
|
+
reader.read(1024), timeout=2.0
|
636
|
+
)
|
637
|
+
|
638
|
+
response = json.loads(response_data.decode())
|
639
|
+
|
640
|
+
if response.get("type") == "mcp_discovery_response":
|
641
|
+
server_info = ServerInfo(
|
642
|
+
id=response.get("id", str(uuid.uuid4())),
|
643
|
+
name=response.get("name", f"server-{ip}"),
|
644
|
+
transport=response.get("transport", "http"),
|
645
|
+
endpoint=f"http://{ip}:{port}",
|
646
|
+
capabilities=response.get("capabilities", []),
|
647
|
+
metadata=response.get("metadata", {}),
|
648
|
+
health_status="healthy",
|
649
|
+
version=response.get("version", "1.0.0"),
|
650
|
+
auth_required=response.get("auth_required", False),
|
651
|
+
)
|
652
|
+
discovered.append(server_info)
|
653
|
+
break
|
654
|
+
|
655
|
+
writer.close()
|
656
|
+
await writer.wait_closed()
|
657
|
+
|
658
|
+
except (
|
659
|
+
asyncio.TimeoutError,
|
660
|
+
ConnectionRefusedError,
|
661
|
+
json.JSONDecodeError,
|
662
|
+
):
|
663
|
+
continue
|
664
|
+
except Exception as e:
|
665
|
+
logger.debug(f"Error scanning {ip}:{port}: {e}")
|
666
|
+
continue
|
667
|
+
|
668
|
+
except Exception as e:
|
669
|
+
logger.debug(f"Error scanning host {ip}: {e}")
|
670
|
+
|
671
|
+
# Scan all hosts in parallel
|
672
|
+
tasks = [scan_host(ip) for ip in network_obj.hosts()]
|
673
|
+
await asyncio.gather(*tasks, return_exceptions=True)
|
674
|
+
|
675
|
+
logger.info(f"Network scan completed. Found {len(discovered)} servers.")
|
676
|
+
return discovered
|
677
|
+
|
678
|
+
def _send_message(self, message: Dict[str, Any], address: tuple = None):
|
679
|
+
"""Send a message over the network."""
|
680
|
+
if not self._transport:
|
681
|
+
logger.warning("Transport not initialized")
|
682
|
+
return
|
683
|
+
|
684
|
+
data = json.dumps(message).encode()
|
685
|
+
|
686
|
+
if address:
|
687
|
+
self._transport.sendto(data, address)
|
688
|
+
else:
|
689
|
+
# Broadcast/multicast
|
690
|
+
self._transport.sendto(data, (self.multicast_group, self.port))
|
691
|
+
|
692
|
+
async def announce_server(self, server_info: ServerInfo):
|
693
|
+
"""Announce a server on the network."""
|
694
|
+
message = {"type": "server_announcement", "server": server_info.to_dict()}
|
695
|
+
self._send_message(message)
|
696
|
+
|
697
|
+
def stop_discovery(self):
|
698
|
+
"""Stop network discovery."""
|
699
|
+
self.running = False
|
700
|
+
if self._discovery_socket:
|
701
|
+
self._discovery_socket.close()
|
702
|
+
self._discovery_socket = None
|
703
|
+
logger.info("Stopped network discovery")
|
704
|
+
|
705
|
+
def get_discovered_servers(self) -> List[ServerInfo]:
|
706
|
+
"""Get list of currently discovered servers."""
|
707
|
+
# Filter out stale servers (older than 5 minutes)
|
708
|
+
current_time = time.time()
|
709
|
+
active_servers = []
|
710
|
+
|
711
|
+
for server in self._discovered_servers.values():
|
712
|
+
if current_time - server.last_seen < 300: # 5 minutes
|
713
|
+
active_servers.append(server)
|
714
|
+
|
715
|
+
return active_servers
|
716
|
+
|
717
|
+
async def discover_servers(self, **filters) -> List[ServerInfo]:
|
718
|
+
"""Discover servers on the network (returns already discovered servers).
|
719
|
+
|
720
|
+
Args:
|
721
|
+
**filters: Filters to apply (capability, transport, etc)
|
722
|
+
|
723
|
+
Returns:
|
724
|
+
List of servers matching filters
|
725
|
+
"""
|
726
|
+
servers = self.get_discovered_servers()
|
727
|
+
|
728
|
+
# Apply filters
|
729
|
+
if filters:
|
730
|
+
filtered = []
|
731
|
+
for server in servers:
|
732
|
+
if server.matches_filter(**filters):
|
733
|
+
filtered.append(server)
|
734
|
+
return filtered
|
735
|
+
|
736
|
+
return servers
|
737
|
+
|
738
|
+
async def _handle_discovery_message(self, message: Dict[str, Any], addr: tuple):
|
739
|
+
"""Handle incoming discovery message."""
|
740
|
+
msg_type = message.get("type")
|
741
|
+
|
742
|
+
if msg_type == "server_announcement":
|
743
|
+
# Handle server announcement
|
744
|
+
server_data = message.get("server", {})
|
745
|
+
server_info = ServerInfo.from_dict(server_data)
|
746
|
+
server_info.last_seen = time.time()
|
747
|
+
|
748
|
+
# Store by name
|
749
|
+
self._discovered_servers[server_info.name] = server_info
|
750
|
+
logger.info(f"Discovered server: {server_info.name} from {addr}")
|
751
|
+
|
752
|
+
elif msg_type == "server_query":
|
753
|
+
# Respond to server queries
|
754
|
+
pass
|
755
|
+
|
756
|
+
else:
|
757
|
+
logger.debug(f"Unknown message type: {msg_type}")
|
758
|
+
|
759
|
+
def datagram_received(self, data: bytes, addr: tuple):
|
760
|
+
"""Handle received datagram (part of asyncio protocol)."""
|
761
|
+
try:
|
762
|
+
message = json.loads(data.decode())
|
763
|
+
# Try to get current event loop
|
764
|
+
try:
|
765
|
+
loop = asyncio.get_running_loop()
|
766
|
+
asyncio.create_task(self._handle_discovery_message(message, addr))
|
767
|
+
except RuntimeError:
|
768
|
+
# No event loop, run synchronously
|
769
|
+
asyncio.run(self._handle_discovery_message(message, addr))
|
770
|
+
except json.JSONDecodeError:
|
771
|
+
logger.warning(f"Invalid JSON received from {addr}")
|
772
|
+
except Exception as e:
|
773
|
+
logger.error(f"Error handling datagram from {addr}: {e}")
|
774
|
+
|
775
|
+
|
776
|
+
class ServiceRegistry:
|
777
|
+
"""Main service registry coordinating multiple discovery backends."""
|
778
|
+
|
779
|
+
def __init__(self, backends: Optional[List[DiscoveryBackend]] = None):
|
780
|
+
"""Initialize service registry.
|
781
|
+
|
782
|
+
Args:
|
783
|
+
backends: List of discovery backends to use
|
784
|
+
"""
|
785
|
+
if backends is None:
|
786
|
+
backends = [FileBasedDiscovery()]
|
787
|
+
|
788
|
+
self.backends = backends
|
789
|
+
self.health_checker = HealthChecker() # Initialize health checker
|
790
|
+
self.load_balancer = LoadBalancer()
|
791
|
+
self.service_mesh = ServiceMesh(self)
|
792
|
+
self._server_cache: Dict[str, ServerInfo] = {}
|
793
|
+
self._cache_expiry = 60.0 # Cache for 1 minute
|
794
|
+
self._last_cache_update = 0.0
|
795
|
+
|
796
|
+
async def register_server(
|
797
|
+
self, server_info: Union[ServerInfo, Dict[str, Any]]
|
798
|
+
) -> bool:
|
799
|
+
"""Register a server with all backends.
|
800
|
+
|
801
|
+
Args:
|
802
|
+
server_info: ServerInfo object or server configuration dictionary
|
803
|
+
|
804
|
+
Returns:
|
805
|
+
True if registration succeeded in at least one backend
|
806
|
+
"""
|
807
|
+
# Convert config to ServerInfo if needed
|
808
|
+
if isinstance(server_info, dict):
|
809
|
+
server_config = server_info
|
810
|
+
server_info = ServerInfo(
|
811
|
+
name=server_config["name"],
|
812
|
+
transport=server_config["transport"],
|
813
|
+
capabilities=server_config.get("capabilities", []),
|
814
|
+
id=server_config.get("id"),
|
815
|
+
endpoint=server_config.get("endpoint"),
|
816
|
+
command=server_config.get("command"),
|
817
|
+
args=server_config.get("args"),
|
818
|
+
url=server_config.get("url"),
|
819
|
+
metadata=server_config.get("metadata", {}),
|
820
|
+
auth_required=server_config.get("auth_required", False),
|
821
|
+
version=server_config.get("version", "1.0.0"),
|
822
|
+
)
|
823
|
+
|
824
|
+
success_count = 0
|
825
|
+
for backend in self.backends:
|
826
|
+
try:
|
827
|
+
if await backend.register_server(server_info):
|
828
|
+
success_count += 1
|
829
|
+
except Exception as e:
|
830
|
+
logger.error(
|
831
|
+
f"Backend {type(backend).__name__} registration failed: {e}"
|
832
|
+
)
|
833
|
+
|
834
|
+
if success_count > 0:
|
835
|
+
# Update cache
|
836
|
+
self._server_cache[server_info.id] = server_info
|
837
|
+
logger.info(
|
838
|
+
f"Successfully registered server {server_info.name} with {success_count} backends"
|
839
|
+
)
|
840
|
+
return True
|
841
|
+
|
842
|
+
return False
|
843
|
+
|
844
|
+
async def deregister_server(self, server_id: str) -> bool:
|
845
|
+
"""Deregister a server from all backends."""
|
846
|
+
success_count = 0
|
847
|
+
for backend in self.backends:
|
848
|
+
try:
|
849
|
+
if await backend.deregister_server(server_id):
|
850
|
+
success_count += 1
|
851
|
+
except Exception as e:
|
852
|
+
logger.error(
|
853
|
+
f"Backend {type(backend).__name__} deregistration failed: {e}"
|
854
|
+
)
|
855
|
+
|
856
|
+
# Remove from cache
|
857
|
+
if server_id in self._server_cache:
|
858
|
+
del self._server_cache[server_id]
|
859
|
+
|
860
|
+
return success_count > 0
|
861
|
+
|
862
|
+
async def discover_servers(self, **filters) -> List[ServerInfo]:
|
863
|
+
"""Discover servers matching the given filters.
|
864
|
+
|
865
|
+
Args:
|
866
|
+
**filters: Filter criteria (capability, transport, healthy_only, etc.)
|
867
|
+
|
868
|
+
Returns:
|
869
|
+
List of matching servers, deduplicated and sorted by priority
|
870
|
+
"""
|
871
|
+
# Check cache first
|
872
|
+
if (
|
873
|
+
time.time() - self._last_cache_update
|
874
|
+
) < self._cache_expiry and not filters.get("force_refresh"):
|
875
|
+
cached_servers = list(self._server_cache.values())
|
876
|
+
return self._filter_and_sort_servers(cached_servers, filters)
|
877
|
+
|
878
|
+
# Fetch from all backends
|
879
|
+
all_servers: Dict[str, ServerInfo] = {}
|
880
|
+
|
881
|
+
for backend in self.backends:
|
882
|
+
try:
|
883
|
+
servers = await backend.get_servers(**filters)
|
884
|
+
for server in servers:
|
885
|
+
# Use latest info if server exists in multiple backends
|
886
|
+
if (
|
887
|
+
server.id not in all_servers
|
888
|
+
or server.last_seen > all_servers[server.id].last_seen
|
889
|
+
):
|
890
|
+
all_servers[server.id] = server
|
891
|
+
except Exception as e:
|
892
|
+
logger.error(f"Backend {type(backend).__name__} discovery failed: {e}")
|
893
|
+
|
894
|
+
# Update cache
|
895
|
+
self._server_cache = all_servers.copy()
|
896
|
+
self._last_cache_update = time.time()
|
897
|
+
|
898
|
+
servers_list = list(all_servers.values())
|
899
|
+
return self._filter_and_sort_servers(servers_list, filters)
|
900
|
+
|
901
|
+
def _filter_and_sort_servers(
|
902
|
+
self, servers: List[ServerInfo], filters: Dict[str, Any]
|
903
|
+
) -> List[ServerInfo]:
|
904
|
+
"""Filter and sort servers by priority."""
|
905
|
+
filtered_servers = []
|
906
|
+
|
907
|
+
for server in servers:
|
908
|
+
# Apply remaining filters not handled by backends
|
909
|
+
if filters.get("healthy_only") and not server.is_healthy():
|
910
|
+
continue
|
911
|
+
if filters.get("capability") and not server.matches_capability(
|
912
|
+
filters["capability"]
|
913
|
+
):
|
914
|
+
continue
|
915
|
+
if filters.get("transport") and not server.matches_transport(
|
916
|
+
filters["transport"]
|
917
|
+
):
|
918
|
+
continue
|
919
|
+
|
920
|
+
filtered_servers.append(server)
|
921
|
+
|
922
|
+
# Sort by priority score (highest first)
|
923
|
+
filtered_servers.sort(key=lambda s: s.get_priority_score(), reverse=True)
|
924
|
+
|
925
|
+
return filtered_servers
|
926
|
+
|
927
|
+
async def get_best_server(
|
928
|
+
self, capability: str, transport: Optional[str] = None
|
929
|
+
) -> Optional[ServerInfo]:
|
930
|
+
"""Get the best server for a specific capability.
|
931
|
+
|
932
|
+
Args:
|
933
|
+
capability: Required capability
|
934
|
+
transport: Preferred transport type
|
935
|
+
|
936
|
+
Returns:
|
937
|
+
Best available server or None
|
938
|
+
"""
|
939
|
+
filters = {"capability": capability, "healthy_only": True}
|
940
|
+
if transport:
|
941
|
+
filters["transport"] = transport
|
942
|
+
|
943
|
+
servers = await self.discover_servers(**filters)
|
944
|
+
return servers[0] if servers else None
|
945
|
+
|
946
|
+
def start_health_checking(self, interval: float = 30.0):
|
947
|
+
"""Start periodic health checking of registered servers.
|
948
|
+
|
949
|
+
Args:
|
950
|
+
interval: Health check interval in seconds
|
951
|
+
"""
|
952
|
+
if not self.health_checker:
|
953
|
+
self.health_checker = HealthChecker(self)
|
954
|
+
|
955
|
+
asyncio.create_task(self.health_checker.start_periodic_checks(interval))
|
956
|
+
|
957
|
+
def stop_health_checking(self):
|
958
|
+
"""Stop health checking."""
|
959
|
+
if self.health_checker:
|
960
|
+
self.health_checker.stop()
|
961
|
+
|
962
|
+
async def start_health_monitoring(self, interval: float = 30.0):
|
963
|
+
"""Start health monitoring (async version)."""
|
964
|
+
if self.health_checker:
|
965
|
+
await self.health_checker.start(self)
|
966
|
+
|
967
|
+
async def stop_health_monitoring(self):
|
968
|
+
"""Stop health monitoring (async version)."""
|
969
|
+
if self.health_checker:
|
970
|
+
self.health_checker.stop()
|
971
|
+
|
972
|
+
async def get_best_server_for_capability(
|
973
|
+
self, capability: str
|
974
|
+
) -> Optional[ServerInfo]:
|
975
|
+
"""Get best server for capability (async version)."""
|
976
|
+
return await self.get_best_server(capability)
|
977
|
+
|
978
|
+
|
979
|
+
class HealthChecker:
|
980
|
+
"""Health checker for registered MCP servers."""
|
981
|
+
|
982
|
+
def __init__(self, registry: ServiceRegistry = None, check_interval: float = 30.0):
|
983
|
+
"""Initialize health checker.
|
984
|
+
|
985
|
+
Args:
|
986
|
+
registry: Service registry to check (optional)
|
987
|
+
check_interval: Default check interval in seconds
|
988
|
+
"""
|
989
|
+
self.registry = registry
|
990
|
+
self.check_interval = check_interval
|
991
|
+
self._running = False
|
992
|
+
self._check_task = None
|
993
|
+
self.running = False # Keep for backward compatibility
|
994
|
+
|
995
|
+
async def start(self, registry: ServiceRegistry = None):
|
996
|
+
"""Start health checking with the registry."""
|
997
|
+
if registry:
|
998
|
+
self.registry = registry
|
999
|
+
if not self.registry:
|
1000
|
+
raise ValueError("No registry provided for health checking")
|
1001
|
+
|
1002
|
+
# Set running state immediately
|
1003
|
+
self._running = True
|
1004
|
+
self.running = True
|
1005
|
+
|
1006
|
+
# Start periodic checks
|
1007
|
+
self._check_task = asyncio.create_task(
|
1008
|
+
self.start_periodic_checks(self.check_interval)
|
1009
|
+
)
|
1010
|
+
|
1011
|
+
async def start_periodic_checks(self, interval: float = 30.0):
|
1012
|
+
"""Start periodic health checks.
|
1013
|
+
|
1014
|
+
Args:
|
1015
|
+
interval: Check interval in seconds
|
1016
|
+
"""
|
1017
|
+
self._running = True
|
1018
|
+
self.running = True # Keep for backward compatibility
|
1019
|
+
logger.info(f"Started health checking with {interval}s interval")
|
1020
|
+
|
1021
|
+
while self._running:
|
1022
|
+
try:
|
1023
|
+
await self.check_all_servers()
|
1024
|
+
await asyncio.sleep(interval)
|
1025
|
+
except Exception as e:
|
1026
|
+
logger.error(f"Error in health checking: {e}")
|
1027
|
+
await asyncio.sleep(min(interval, 10)) # Back off on error
|
1028
|
+
|
1029
|
+
async def check_all_servers(self):
|
1030
|
+
"""Check health of all registered servers."""
|
1031
|
+
# Get all servers without filters
|
1032
|
+
servers = await self.registry.discover_servers(force_refresh=True)
|
1033
|
+
|
1034
|
+
# Check health in parallel with limited concurrency
|
1035
|
+
semaphore = asyncio.Semaphore(10)
|
1036
|
+
|
1037
|
+
async def check_server(server: ServerInfo):
|
1038
|
+
async with semaphore:
|
1039
|
+
health_result = await self.check_server_health(server)
|
1040
|
+
health_status = health_result["status"]
|
1041
|
+
response_time = health_result.get("response_time")
|
1042
|
+
|
1043
|
+
# Update health in all backends
|
1044
|
+
for backend in self.registry.backends:
|
1045
|
+
try:
|
1046
|
+
await backend.update_server_health(
|
1047
|
+
server.id, health_status, response_time
|
1048
|
+
)
|
1049
|
+
except Exception as e:
|
1050
|
+
logger.error(f"Failed to update health for {server.id}: {e}")
|
1051
|
+
|
1052
|
+
# Run health checks in parallel
|
1053
|
+
await asyncio.gather(
|
1054
|
+
*[check_server(server) for server in servers], return_exceptions=True
|
1055
|
+
)
|
1056
|
+
|
1057
|
+
async def check_server_health(self, server: ServerInfo) -> Dict[str, Any]:
|
1058
|
+
"""Check health of a single server.
|
1059
|
+
|
1060
|
+
Args:
|
1061
|
+
server: Server to check
|
1062
|
+
|
1063
|
+
Returns:
|
1064
|
+
Dictionary with status and response_time
|
1065
|
+
"""
|
1066
|
+
start_time = time.time()
|
1067
|
+
|
1068
|
+
try:
|
1069
|
+
if server.transport == "http" or server.transport == "sse":
|
1070
|
+
# HTTP/SSE health check
|
1071
|
+
import aiohttp
|
1072
|
+
|
1073
|
+
async with aiohttp.ClientSession(
|
1074
|
+
timeout=aiohttp.ClientTimeout(total=10)
|
1075
|
+
) as session:
|
1076
|
+
# Try health endpoint first (use configured endpoint or default to /health)
|
1077
|
+
health_path = server.health_endpoint or "/health"
|
1078
|
+
base_url = server.endpoint or server.url or ""
|
1079
|
+
health_url = f"{base_url.rstrip('/')}{health_path}"
|
1080
|
+
try:
|
1081
|
+
async with session.get(health_url) as response:
|
1082
|
+
if response.status == 200:
|
1083
|
+
response_time = time.time() - start_time
|
1084
|
+
return {
|
1085
|
+
"status": "healthy",
|
1086
|
+
"response_time": response_time,
|
1087
|
+
}
|
1088
|
+
except:
|
1089
|
+
pass
|
1090
|
+
|
1091
|
+
# Fallback to main endpoint
|
1092
|
+
try:
|
1093
|
+
async with session.get(server.endpoint) as response:
|
1094
|
+
response_time = time.time() - start_time
|
1095
|
+
if response.status < 500:
|
1096
|
+
return {
|
1097
|
+
"status": "healthy",
|
1098
|
+
"response_time": response_time,
|
1099
|
+
}
|
1100
|
+
else:
|
1101
|
+
return {
|
1102
|
+
"status": "unhealthy",
|
1103
|
+
"response_time": response_time,
|
1104
|
+
}
|
1105
|
+
except:
|
1106
|
+
return {"status": "unhealthy", "response_time": None}
|
1107
|
+
|
1108
|
+
elif server.transport == "stdio":
|
1109
|
+
# For stdio, check if command exists and is executable
|
1110
|
+
if server.command:
|
1111
|
+
try:
|
1112
|
+
# Test if we can run the command
|
1113
|
+
import asyncio
|
1114
|
+
|
1115
|
+
process = await asyncio.create_subprocess_exec(
|
1116
|
+
server.command,
|
1117
|
+
*(server.args if server.args else []),
|
1118
|
+
stdout=asyncio.subprocess.PIPE,
|
1119
|
+
stderr=asyncio.subprocess.PIPE,
|
1120
|
+
)
|
1121
|
+
returncode = await process.wait()
|
1122
|
+
response_time = time.time() - start_time
|
1123
|
+
|
1124
|
+
if returncode == 0:
|
1125
|
+
return {"status": "healthy", "response_time": response_time}
|
1126
|
+
else:
|
1127
|
+
return {
|
1128
|
+
"status": "unhealthy",
|
1129
|
+
"response_time": response_time,
|
1130
|
+
}
|
1131
|
+
except Exception as e:
|
1132
|
+
return {
|
1133
|
+
"status": "unhealthy",
|
1134
|
+
"response_time": None,
|
1135
|
+
"error": str(e),
|
1136
|
+
}
|
1137
|
+
else:
|
1138
|
+
# No command specified, check if recently seen
|
1139
|
+
age = time.time() - server.last_seen
|
1140
|
+
if age < 300: # 5 minutes
|
1141
|
+
return {"status": "healthy", "response_time": None}
|
1142
|
+
else:
|
1143
|
+
return {"status": "unknown", "response_time": None}
|
1144
|
+
|
1145
|
+
else:
|
1146
|
+
logger.warning(
|
1147
|
+
f"Unknown transport type for health check: {server.transport}"
|
1148
|
+
)
|
1149
|
+
return {"status": "unknown", "response_time": None}
|
1150
|
+
|
1151
|
+
except Exception as e:
|
1152
|
+
logger.debug(f"Health check failed for {server.name}: {e}")
|
1153
|
+
return {"status": "unhealthy", "response_time": None}
|
1154
|
+
|
1155
|
+
async def stop(self):
|
1156
|
+
"""Stop health checking."""
|
1157
|
+
self._running = False
|
1158
|
+
self.running = False
|
1159
|
+
if self._check_task:
|
1160
|
+
self._check_task.cancel()
|
1161
|
+
self._check_task = None
|
1162
|
+
logger.info("Stopped health checking")
|
1163
|
+
|
1164
|
+
|
1165
|
+
class ServiceMesh:
|
1166
|
+
"""Service mesh for intelligent client routing and load balancing."""
|
1167
|
+
|
1168
|
+
def __init__(self, registry: ServiceRegistry):
|
1169
|
+
"""Initialize service mesh.
|
1170
|
+
|
1171
|
+
Args:
|
1172
|
+
registry: Service registry to use
|
1173
|
+
"""
|
1174
|
+
self.registry = registry
|
1175
|
+
self._client_cache: Dict[str, Any] = {}
|
1176
|
+
self._load_balancer = LoadBalancer()
|
1177
|
+
|
1178
|
+
async def get_client_for_capability(
|
1179
|
+
self, capability: str, transport: Optional[str] = None
|
1180
|
+
) -> Optional[Any]:
|
1181
|
+
"""Get an MCP client for a specific capability.
|
1182
|
+
|
1183
|
+
Args:
|
1184
|
+
capability: Required capability
|
1185
|
+
transport: Preferred transport type
|
1186
|
+
|
1187
|
+
Returns:
|
1188
|
+
Configured MCP client or None
|
1189
|
+
"""
|
1190
|
+
# Find best server
|
1191
|
+
server = await self.registry.get_best_server(capability, transport)
|
1192
|
+
if not server:
|
1193
|
+
logger.warning(f"No server found for capability: {capability}")
|
1194
|
+
return None
|
1195
|
+
|
1196
|
+
# Check cache
|
1197
|
+
cache_key = f"{server.id}_{capability}"
|
1198
|
+
if cache_key in self._client_cache:
|
1199
|
+
return self._client_cache[cache_key]
|
1200
|
+
|
1201
|
+
# Create new client
|
1202
|
+
try:
|
1203
|
+
client = await self._create_client(server)
|
1204
|
+
self._client_cache[cache_key] = client
|
1205
|
+
logger.info(f"Created MCP client for {server.name} ({capability})")
|
1206
|
+
return client
|
1207
|
+
except Exception as e:
|
1208
|
+
logger.error(f"Failed to create client for {server.name}: {e}")
|
1209
|
+
return None
|
1210
|
+
|
1211
|
+
async def call_with_failover(
|
1212
|
+
self,
|
1213
|
+
capability: str,
|
1214
|
+
tool_name: str,
|
1215
|
+
arguments: Dict[str, Any],
|
1216
|
+
max_retries: int = 3,
|
1217
|
+
) -> Dict[str, Any]:
|
1218
|
+
"""Call a tool with automatic failover to backup servers.
|
1219
|
+
|
1220
|
+
Args:
|
1221
|
+
capability: Required capability
|
1222
|
+
tool_name: Tool to call
|
1223
|
+
arguments: Tool arguments
|
1224
|
+
max_retries: Maximum retry attempts
|
1225
|
+
|
1226
|
+
Returns:
|
1227
|
+
Tool result
|
1228
|
+
"""
|
1229
|
+
servers = await self.registry.discover_servers(
|
1230
|
+
capability=capability, healthy_only=True
|
1231
|
+
)
|
1232
|
+
|
1233
|
+
if not servers:
|
1234
|
+
raise ServiceDiscoveryError(
|
1235
|
+
f"No healthy servers found for capability: {capability}"
|
1236
|
+
)
|
1237
|
+
|
1238
|
+
last_error = None
|
1239
|
+
|
1240
|
+
for attempt in range(max_retries):
|
1241
|
+
# Select server using load balancer
|
1242
|
+
server = self._load_balancer.select_server(servers)
|
1243
|
+
if not server:
|
1244
|
+
break
|
1245
|
+
|
1246
|
+
try:
|
1247
|
+
# Create client for the selected server
|
1248
|
+
client = await self._create_client(server)
|
1249
|
+
|
1250
|
+
# Call the tool
|
1251
|
+
result = await client.call_tool(tool_name, arguments)
|
1252
|
+
|
1253
|
+
# Record successful call
|
1254
|
+
self._load_balancer.record_success(server.id)
|
1255
|
+
return result
|
1256
|
+
|
1257
|
+
except Exception as e:
|
1258
|
+
last_error = e
|
1259
|
+
logger.warning(f"Call to {server.name} failed: {e}")
|
1260
|
+
|
1261
|
+
# Record failure and remove from current attempt
|
1262
|
+
self._load_balancer.record_failure(server.id)
|
1263
|
+
servers = [s for s in servers if s.id != server.id]
|
1264
|
+
|
1265
|
+
# All retries failed
|
1266
|
+
raise ServiceDiscoveryError(
|
1267
|
+
f"All servers failed for capability {capability}: {last_error}"
|
1268
|
+
)
|
1269
|
+
|
1270
|
+
def _create_server_config(self, server: ServerInfo) -> Dict[str, Any]:
|
1271
|
+
"""Create server configuration for MCP client.
|
1272
|
+
|
1273
|
+
Args:
|
1274
|
+
server: Server information
|
1275
|
+
|
1276
|
+
Returns:
|
1277
|
+
Server configuration dictionary
|
1278
|
+
"""
|
1279
|
+
if server.transport == "stdio":
|
1280
|
+
# Parse command from endpoint
|
1281
|
+
if server.endpoint.startswith("python "):
|
1282
|
+
command_parts = server.endpoint.split()
|
1283
|
+
return {
|
1284
|
+
"transport": "stdio",
|
1285
|
+
"command": command_parts[0],
|
1286
|
+
"args": command_parts[1:],
|
1287
|
+
"env": server.metadata.get("env", {}),
|
1288
|
+
}
|
1289
|
+
else:
|
1290
|
+
return {
|
1291
|
+
"transport": "stdio",
|
1292
|
+
"command": server.endpoint,
|
1293
|
+
"args": [],
|
1294
|
+
"env": server.metadata.get("env", {}),
|
1295
|
+
}
|
1296
|
+
|
1297
|
+
elif server.transport in ["http", "sse"]:
|
1298
|
+
config = {"transport": server.transport, "url": server.endpoint}
|
1299
|
+
|
1300
|
+
# Add authentication config if present
|
1301
|
+
if server.auth_required:
|
1302
|
+
auth_config = server.metadata.get("auth_config", {})
|
1303
|
+
if auth_config:
|
1304
|
+
config["auth"] = auth_config
|
1305
|
+
|
1306
|
+
return config
|
1307
|
+
|
1308
|
+
else:
|
1309
|
+
# Fallback configuration
|
1310
|
+
return {"transport": server.transport, "url": server.endpoint}
|
1311
|
+
|
1312
|
+
async def _create_client(self, server: ServerInfo) -> Any:
|
1313
|
+
"""Create a client for the given server.
|
1314
|
+
|
1315
|
+
Args:
|
1316
|
+
server: Server to create client for
|
1317
|
+
|
1318
|
+
Returns:
|
1319
|
+
MCP client instance
|
1320
|
+
"""
|
1321
|
+
try:
|
1322
|
+
from .client import MCPClient
|
1323
|
+
|
1324
|
+
# Create client configuration based on transport
|
1325
|
+
if server.transport == "stdio":
|
1326
|
+
client = MCPClient(
|
1327
|
+
transport="stdio",
|
1328
|
+
command=server.command,
|
1329
|
+
args=server.args,
|
1330
|
+
env=server.metadata.get("env", {}),
|
1331
|
+
)
|
1332
|
+
elif server.transport in ["http", "sse"]:
|
1333
|
+
client = MCPClient(
|
1334
|
+
transport=server.transport, url=server.url or server.endpoint
|
1335
|
+
)
|
1336
|
+
else:
|
1337
|
+
raise ValueError(f"Unsupported transport: {server.transport}")
|
1338
|
+
|
1339
|
+
return client
|
1340
|
+
except Exception as e:
|
1341
|
+
logger.error(f"Failed to create client for {server.name}: {e}")
|
1342
|
+
raise
|
1343
|
+
|
1344
|
+
|
+class LoadBalancer:
+    """Load balancer for distributing requests across servers."""
+
+    def __init__(self):
+        """Initialize load balancer."""
+        self._server_stats: Dict[str, Dict[str, Any]] = {}
+
+    def select_server(self, servers: List[ServerInfo]) -> Optional[ServerInfo]:
+        """Select best server based on load balancing algorithm.
+
+        Args:
+            servers: List of available servers
+
+        Returns:
+            Selected server or None
+        """
+        if not servers:
+            return None
+
+        # Weight servers by priority score and current load
+        weighted_servers = []
+
+        for server in servers:
+            base_weight = server.get_priority_score()
+
+            # Adjust weight based on current load
+            stats = self._server_stats.get(server.id, {})
+            recent_failures = stats.get("recent_failures", 0)
+            recent_calls = stats.get("recent_calls", 0)
+
+            # Penalty for recent failures
+            if recent_failures > 0:
+                base_weight *= 1.0 - min(0.5, recent_failures * 0.1)
+
+            # Penalty for high load
+            if recent_calls > 10:
+                base_weight *= 1.0 - min(0.3, (recent_calls - 10) * 0.02)
+
+            weighted_servers.append((server, max(0.1, base_weight)))
+
+        # Select using weighted random
+        import random
+
+        total_weight = sum(weight for _, weight in weighted_servers)
+
+        if total_weight == 0:
+            return servers[0]  # Fallback to first server
+
+        r = random.uniform(0, total_weight)
+        current_weight = 0
+
+        for server, weight in weighted_servers:
+            current_weight += weight
+            if r <= current_weight:
+                return server
+
+        return servers[0]  # Fallback
+
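select_server above does weighted-random selection: each server's priority score is discounted for recent failures (up to 50%) and for high call volume (up to 30%), clamped to a 0.1 floor, and a uniform draw over the cumulative weights picks the winner. A toy walkthrough of that arithmetic, with invented numbers:

import random

# name -> effective weight after penalties (illustrative values only)
servers = {"a": 1.0, "b": 0.5, "c": 0.1}
# A server with 3 recent failures has its base weight scaled by
# 1.0 - min(0.5, 3 * 0.1) = 0.7; 15 recent calls scale it by a further
# 1.0 - min(0.3, (15 - 10) * 0.02) = 0.9.

total = sum(servers.values())      # 1.6
r = random.uniform(0, total)
cumulative = 0.0
for name, weight in servers.items():
    cumulative += weight
    if r <= cumulative:            # "a" wins about 62% of the time (1.0 / 1.6)
        print("selected", name)
        break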
+    def record_success(self, server_id: str):
+        """Record successful call to server."""
+        if server_id not in self._server_stats:
+            self._server_stats[server_id] = {
+                "recent_calls": 0,
+                "recent_failures": 0,
+                "last_reset": time.time(),
+            }
+
+        stats = self._server_stats[server_id]
+        stats["recent_calls"] += 1
+
+        # Decay recent failures on success
+        stats["recent_failures"] = max(0, stats["recent_failures"] - 1)
+
+        self._maybe_reset_stats(server_id)
+
+    def record_failure(self, server_id: str):
+        """Record failed call to server."""
+        if server_id not in self._server_stats:
+            self._server_stats[server_id] = {
+                "recent_calls": 0,
+                "recent_failures": 0,
+                "last_reset": time.time(),
+            }
+
+        stats = self._server_stats[server_id]
+        stats["recent_failures"] += 1
+
+        self._maybe_reset_stats(server_id)
+
+    def _maybe_reset_stats(self, server_id: str):
+        """Reset stats if they're getting stale."""
+        stats = self._server_stats[server_id]
+        if time.time() - stats["last_reset"] > 300:  # Reset every 5 minutes
+            stats["recent_calls"] = 0
+            stats["recent_failures"] = 0
+            stats["last_reset"] = time.time()
+
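This bookkeeping is what feeds the penalties used by select_server: failures accumulate per server id, each success decays one failure, and both counters reset after five minutes. A minimal usage sketch, assuming LoadBalancer is importable from kailash.mcp_server.discovery as the file list suggests; the server id is invented and the private _server_stats dict is inspected only for illustration:

from kailash.mcp_server.discovery import LoadBalancer  # assumed import path

lb = LoadBalancer()
lb.record_failure("server-1")
lb.record_failure("server-1")
print(lb._server_stats["server-1"]["recent_failures"])   # 2

lb.record_success("server-1")                            # decays one failure
print(lb._server_stats["server-1"]["recent_failures"])   # 1
# After 300 seconds, the next record_* call zeroes both counters
# via _maybe_reset_stats.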
+    def _calculate_priority_score(self, server: ServerInfo) -> float:
+        """Calculate priority score for a server.
+
+        Args:
+            server: Server to calculate score for
+
+        Returns:
+            Priority score (0 means unhealthy)
+        """
+        # Unhealthy servers get 0 score
+        if hasattr(server, "health") and server.health:
+            if server.health.get("status") == "unhealthy":
+                return 0
+        elif server.health_status == "unhealthy":
+            return 0
+
+        # Use server's get_priority_score method
+        return server.get_priority_score()
+
+    def select_best_server(self, servers: List[ServerInfo]) -> Optional[ServerInfo]:
+        """Select the best server based on priority scores.
+
+        Args:
+            servers: List of servers to choose from
+
+        Returns:
+            Best server or None
+        """
+        if not servers:
+            return None
+
+        # Score and sort servers
+        scored_servers = [
+            (server, self._calculate_priority_score(server)) for server in servers
+        ]
+        scored_servers.sort(key=lambda x: x[1], reverse=True)
+
+        # Return best server (highest score)
+        return scored_servers[0][0] if scored_servers else None
+
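select_best_server is the deterministic counterpart to the weighted-random path: servers reporting unhealthy score 0, everything else is ranked by get_priority_score(). A sketch with duck-typed stand-ins that expose only the attributes the scoring code reads; FakeServer is hypothetical and not part of the package, and the import path is assumed from the file list:

from dataclasses import dataclass
from kailash.mcp_server.discovery import LoadBalancer  # assumed import path

@dataclass
class FakeServer:
    # Only the attributes _calculate_priority_score touches.
    name: str
    health_status: str
    score: float

    def get_priority_score(self) -> float:
        return self.score

lb = LoadBalancer()
servers = [
    FakeServer("primary", "healthy", 0.9),
    FakeServer("backup", "healthy", 0.4),
    FakeServer("down", "unhealthy", 1.0),   # scored 0 despite the high base score
]
print(lb.select_best_server(servers).name)  # "primary"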
+    def select_servers_round_robin(
+        self, servers: List[ServerInfo], count: int
+    ) -> List[ServerInfo]:
+        """Select servers using round-robin algorithm.
+
+        Args:
+            servers: List of available servers
+            count: Number of servers to select
+
+        Returns:
+            Selected servers
+        """
+        if not servers:
+            return []
+
+        # Track round-robin state
+        if not hasattr(self, "_round_robin_index"):
+            self._round_robin_index = {}
+
+        # Create a key for this server list
+        server_key = tuple(s.name for s in servers)
+
+        if server_key not in self._round_robin_index:
+            self._round_robin_index[server_key] = 0
+
+        selected = []
+        start_index = self._round_robin_index[server_key]
+
+        for i in range(count):
+            index = (start_index + i) % len(servers)
+            selected.append(servers[index])
+
+        # Update index for next call
+        self._round_robin_index[server_key] = (start_index + count) % len(servers)
+
+        return selected
+
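Round-robin state is keyed by the tuple of server names and advances by count on every call, so repeated calls rotate through the list. A small sketch, again assuming the module path above and using a minimal stand-in that only provides the .name attribute the method reads:

from dataclasses import dataclass
from kailash.mcp_server.discovery import LoadBalancer  # assumed import path

@dataclass
class NamedServer:
    # Hypothetical stand-in; round-robin selection only reads .name.
    name: str

lb = LoadBalancer()
servers = [NamedServer("a"), NamedServer("b"), NamedServer("c")]
print([s.name for s in lb.select_servers_round_robin(servers, count=2)])  # ['a', 'b']
print([s.name for s in lb.select_servers_round_robin(servers, count=2)])  # ['c', 'a']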
+    def select_servers(
+        self, servers: List[ServerInfo], count: int = 1, strategy: str = "priority"
+    ) -> List[ServerInfo]:
+        """Select multiple servers based on strategy.
+
+        Args:
+            servers: List of available servers
+            count: Number of servers to select
+            strategy: Selection strategy ("priority", "round_robin", "random")
+
+        Returns:
+            Selected servers
+        """
+        if not servers:
+            return []
+
+        count = min(count, len(servers))
+
+        if strategy == "round_robin":
+            return self.select_servers_round_robin(servers, count)
+        elif strategy == "random":
+            import random
+
+            return random.sample(servers, count)
+        else:  # priority
+            # Sort by priority score
+            scored = [(s, self._calculate_priority_score(s)) for s in servers]
+            scored.sort(key=lambda x: x[1], reverse=True)
+            return [s[0] for s in scored[:count]]
+
+
+# Convenience functions for easy setup
+def create_default_registry() -> ServiceRegistry:
+    """Create a default service registry with file and network discovery."""
+    file_backend = FileBasedDiscovery()
+    return ServiceRegistry([file_backend])
+
+
+async def discover_mcp_servers(
+    capability: Optional[str] = None, transport: Optional[str] = None
+) -> List[ServerInfo]:
+    """Discover MCP servers with optional filtering.
+
+    Args:
+        capability: Filter by capability
+        transport: Filter by transport type
+
+    Returns:
+        List of discovered servers
+    """
+    registry = create_default_registry()
+
+    filters = {}
+    if capability:
+        filters["capability"] = capability
+    if transport:
+        filters["transport"] = transport
+
+    return await registry.discover_servers(**filters)
+
+
+async def get_mcp_client(capability: str, transport: Optional[str] = None):
+    """Get an MCP client for a specific capability.
+
+    Args:
+        capability: Required capability
+        transport: Preferred transport type
+
+    Returns:
+        Configured MCP client
+    """
+    registry = create_default_registry()
+    mesh = ServiceMesh(registry)
+
+    return await mesh.get_client_for_capability(capability, transport)
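Taken together, these helpers give a short path from discovery to a usable client. A possible end-to-end sketch, assuming they are importable from kailash.mcp_server.discovery; the "search" capability is invented, and actual results depend on which servers are registered locally:

import asyncio
from kailash.mcp_server.discovery import (  # assumed import path
    discover_mcp_servers,
    get_mcp_client,
)

async def main():
    # List registered servers that advertise a given capability over HTTP.
    servers = await discover_mcp_servers(capability="search", transport="http")
    for server in servers:
        print(server.name, server.endpoint)

    # Or go straight to a configured client for that capability.
    client = await get_mcp_client("search")
    # ... call tools on the returned client ...

asyncio.run(main())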