kailash 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +5 -5
- kailash/channels/__init__.py +2 -1
- kailash/channels/mcp_channel.py +23 -4
- kailash/cli/validate_imports.py +202 -0
- kailash/core/resilience/bulkhead.py +15 -5
- kailash/core/resilience/circuit_breaker.py +4 -1
- kailash/core/resilience/health_monitor.py +312 -84
- kailash/edge/migration/edge_migration_service.py +384 -0
- kailash/mcp_server/protocol.py +26 -0
- kailash/mcp_server/server.py +1081 -8
- kailash/mcp_server/subscriptions.py +1560 -0
- kailash/mcp_server/transports.py +305 -0
- kailash/middleware/gateway/event_store.py +1 -0
- kailash/nodes/base.py +77 -1
- kailash/nodes/code/python.py +44 -3
- kailash/nodes/data/async_sql.py +42 -20
- kailash/nodes/edge/edge_migration_node.py +16 -12
- kailash/nodes/governance.py +410 -0
- kailash/nodes/rag/registry.py +1 -1
- kailash/nodes/transaction/distributed_transaction_manager.py +48 -1
- kailash/nodes/transaction/saga_state_storage.py +2 -1
- kailash/nodes/validation.py +8 -8
- kailash/runtime/local.py +30 -0
- kailash/runtime/validation/__init__.py +7 -15
- kailash/runtime/validation/import_validator.py +446 -0
- kailash/runtime/validation/suggestion_engine.py +5 -5
- kailash/utils/data_paths.py +74 -0
- kailash/workflow/builder.py +183 -4
- kailash/workflow/mermaid_visualizer.py +3 -1
- kailash/workflow/templates.py +6 -6
- kailash/workflow/validation.py +134 -3
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/METADATA +20 -17
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/RECORD +37 -31
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/WHEEL +0 -0
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/entry_points.txt +0 -0
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/top_level.txt +0 -0
kailash/mcp_server/server.py
CHANGED
@@ -54,6 +54,8 @@ Enhanced Production Usage:
|
|
54
54
|
|
55
55
|
import asyncio
|
56
56
|
import functools
|
57
|
+
import gzip
|
58
|
+
import json
|
57
59
|
import logging
|
58
60
|
import time
|
59
61
|
import uuid
|
@@ -74,6 +76,7 @@ from .errors import (
|
|
74
76
|
RetryableOperation,
|
75
77
|
ToolError,
|
76
78
|
)
|
79
|
+
from .protocol import get_protocol_manager
|
77
80
|
from .utils import CacheManager, ConfigManager, MetricsCollector, format_response
|
78
81
|
|
79
82
|
logger = logging.getLogger(__name__)
|
@@ -345,6 +348,11 @@ class MCPServer:
|
|
345
348
|
self,
|
346
349
|
name: str,
|
347
350
|
config_file: Optional[Union[str, Path]] = None,
|
351
|
+
# Transport configuration
|
352
|
+
transport: str = "stdio", # "stdio", "websocket", "http", "sse"
|
353
|
+
websocket_host: str = "0.0.0.0",
|
354
|
+
websocket_port: int = 3001,
|
355
|
+
# Caching configuration
|
348
356
|
enable_cache: bool = True,
|
349
357
|
cache_ttl: int = 300,
|
350
358
|
cache_backend: str = "memory", # "memory" or "redis"
|
@@ -364,6 +372,13 @@ class MCPServer:
|
|
364
372
|
transport_timeout: float = 30.0,
|
365
373
|
max_request_size: int = 10_000_000, # 10MB
|
366
374
|
enable_streaming: bool = False,
|
375
|
+
# Resource subscription configuration
|
376
|
+
enable_subscriptions: bool = True,
|
377
|
+
event_store=None,
|
378
|
+
# WebSocket compression configuration
|
379
|
+
enable_websocket_compression: bool = False,
|
380
|
+
compression_threshold: int = 1024, # Only compress messages larger than 1KB
|
381
|
+
compression_level: int = 6, # 1 (fastest) to 9 (best compression)
|
367
382
|
):
|
368
383
|
"""
|
369
384
|
Initialize enhanced MCP server.
|
@@ -371,6 +386,9 @@ class MCPServer:
|
|
371
386
|
Args:
|
372
387
|
name: Server name
|
373
388
|
config_file: Optional configuration file path
|
389
|
+
transport: Transport to use ("stdio", "websocket", "http", "sse")
|
390
|
+
websocket_host: Host for WebSocket server (default: "0.0.0.0")
|
391
|
+
websocket_port: Port for WebSocket server (default: 3001)
|
374
392
|
enable_cache: Whether to enable caching (default: True)
|
375
393
|
cache_ttl: Default cache TTL in seconds (default: 300)
|
376
394
|
cache_backend: Cache backend ("memory" or "redis")
|
@@ -388,9 +406,24 @@ class MCPServer:
|
|
388
406
|
transport_timeout: Transport timeout in seconds
|
389
407
|
max_request_size: Maximum request size in bytes
|
390
408
|
enable_streaming: Enable streaming support
|
409
|
+
enable_subscriptions: Enable resource subscriptions (default: True)
|
410
|
+
event_store: Optional event store for subscription logging
|
411
|
+
enable_websocket_compression: Enable gzip compression for WebSocket messages (default: False)
|
412
|
+
compression_threshold: Only compress messages larger than this size in bytes (default: 1024)
|
413
|
+
compression_level: Compression level from 1 (fastest) to 9 (best compression) (default: 6)
|
391
414
|
"""
|
392
415
|
self.name = name
|
393
416
|
|
417
|
+
# Transport configuration
|
418
|
+
self.transport = transport
|
419
|
+
self.websocket_host = websocket_host
|
420
|
+
self.websocket_port = websocket_port
|
421
|
+
|
422
|
+
# WebSocket compression configuration
|
423
|
+
self.enable_websocket_compression = enable_websocket_compression
|
424
|
+
self.compression_threshold = compression_threshold
|
425
|
+
self.compression_level = compression_level
|
426
|
+
|
394
427
|
# Enhanced features
|
395
428
|
self.auth_provider = auth_provider
|
396
429
|
self.enable_http_transport = enable_http_transport
|
@@ -410,7 +443,9 @@ class MCPServer:
|
|
410
443
|
"server": {
|
411
444
|
"name": name,
|
412
445
|
"version": "1.0.0",
|
413
|
-
"transport":
|
446
|
+
"transport": transport,
|
447
|
+
"websocket_host": websocket_host,
|
448
|
+
"websocket_port": websocket_port,
|
414
449
|
"enable_http": enable_http_transport,
|
415
450
|
"enable_sse": enable_sse_transport,
|
416
451
|
"timeout": transport_timeout,
|
@@ -500,6 +535,30 @@ class MCPServer:
|
|
500
535
|
self._resource_registry: Dict[str, Dict[str, Any]] = {}
|
501
536
|
self._prompt_registry: Dict[str, Dict[str, Any]] = {}
|
502
537
|
|
538
|
+
# Client management for new handlers
|
539
|
+
self.client_info: Dict[str, Dict[str, Any]] = {}
|
540
|
+
self._pending_sampling_requests: Dict[str, Dict[str, Any]] = {}
|
541
|
+
|
542
|
+
# Resource subscription support
|
543
|
+
self.enable_subscriptions = enable_subscriptions
|
544
|
+
self.event_store = event_store
|
545
|
+
self.subscription_manager = None
|
546
|
+
if self.enable_subscriptions:
|
547
|
+
from .subscriptions import ResourceSubscriptionManager
|
548
|
+
|
549
|
+
self.subscription_manager = ResourceSubscriptionManager(
|
550
|
+
auth_manager=(
|
551
|
+
self.auth_manager if hasattr(self, "auth_manager") else None
|
552
|
+
),
|
553
|
+
event_store=event_store,
|
554
|
+
rate_limiter=(
|
555
|
+
self.rate_limiter if hasattr(self, "rate_limiter") else None
|
556
|
+
),
|
557
|
+
)
|
558
|
+
|
559
|
+
# Transport instance (for WebSocket and other transports)
|
560
|
+
self._transport = None
|
561
|
+
|
503
562
|
def _init_mcp(self):
|
504
563
|
"""Initialize FastMCP server."""
|
505
564
|
if self._mcp is not None:
|
@@ -1207,10 +1266,24 @@ class MCPServer:
|
|
1207
1266
|
self._init_mcp()
|
1208
1267
|
|
1209
1268
|
# Wrap with metrics if enabled
|
1269
|
+
wrapped_func = func
|
1210
1270
|
if self.metrics.enabled:
|
1211
|
-
|
1271
|
+
wrapped_func = self.metrics.track_tool(f"resource:{uri}")(func)
|
1212
1272
|
|
1213
|
-
|
1273
|
+
# Register with FastMCP
|
1274
|
+
mcp_resource = self._mcp.resource(uri)(wrapped_func)
|
1275
|
+
|
1276
|
+
# Track in registry
|
1277
|
+
self._resource_registry[uri] = {
|
1278
|
+
"handler": mcp_resource,
|
1279
|
+
"original_handler": func,
|
1280
|
+
"name": uri,
|
1281
|
+
"description": func.__doc__ or f"Resource: {uri}",
|
1282
|
+
"mime_type": "text/plain",
|
1283
|
+
"created_at": time.time(),
|
1284
|
+
}
|
1285
|
+
|
1286
|
+
return mcp_resource
|
1214
1287
|
|
1215
1288
|
return decorator
|
1216
1289
|
|
@@ -1230,10 +1303,23 @@ class MCPServer:
|
|
1230
1303
|
self._init_mcp()
|
1231
1304
|
|
1232
1305
|
# Wrap with metrics if enabled
|
1306
|
+
wrapped_func = func
|
1233
1307
|
if self.metrics.enabled:
|
1234
|
-
|
1308
|
+
wrapped_func = self.metrics.track_tool(f"prompt:{name}")(func)
|
1235
1309
|
|
1236
|
-
|
1310
|
+
# Register with FastMCP
|
1311
|
+
mcp_prompt = self._mcp.prompt(name)(wrapped_func)
|
1312
|
+
|
1313
|
+
# Track in registry
|
1314
|
+
self._prompt_registry[name] = {
|
1315
|
+
"handler": mcp_prompt,
|
1316
|
+
"original_handler": func,
|
1317
|
+
"description": func.__doc__ or f"Prompt: {name}",
|
1318
|
+
"arguments": [], # Could be extracted from function signature
|
1319
|
+
"created_at": time.time(),
|
1320
|
+
}
|
1321
|
+
|
1322
|
+
return mcp_prompt
|
1237
1323
|
|
1238
1324
|
return decorator
|
1239
1325
|
|
@@ -1454,6 +1540,47 @@ class MCPServer:
|
|
1454
1540
|
return True
|
1455
1541
|
return False
|
1456
1542
|
|
1543
|
+
def _execute_tool(self, tool_name: str, arguments: dict) -> Any:
|
1544
|
+
"""Execute a tool directly (for testing purposes)."""
|
1545
|
+
if tool_name not in self._tool_registry:
|
1546
|
+
raise ValueError(f"Tool '{tool_name}' not found in registry")
|
1547
|
+
|
1548
|
+
tool_info = self._tool_registry[tool_name]
|
1549
|
+
if tool_info.get("disabled", False):
|
1550
|
+
raise ValueError(f"Tool '{tool_name}' is currently disabled")
|
1551
|
+
|
1552
|
+
# Get the tool handler (the enhanced function)
|
1553
|
+
if "handler" in tool_info:
|
1554
|
+
handler = tool_info["handler"]
|
1555
|
+
elif "function" in tool_info:
|
1556
|
+
handler = tool_info["function"]
|
1557
|
+
else:
|
1558
|
+
raise ValueError(f"Tool '{tool_name}' has no valid handler")
|
1559
|
+
|
1560
|
+
# Update statistics
|
1561
|
+
tool_info["call_count"] = tool_info.get("call_count", 0) + 1
|
1562
|
+
tool_info["last_called"] = time.time()
|
1563
|
+
|
1564
|
+
try:
|
1565
|
+
# Execute the tool
|
1566
|
+
if asyncio.iscoroutinefunction(handler):
|
1567
|
+
# For async functions, we need to run in event loop
|
1568
|
+
try:
|
1569
|
+
loop = asyncio.get_event_loop()
|
1570
|
+
if loop.is_running():
|
1571
|
+
# Already in async context - create task
|
1572
|
+
return asyncio.create_task(handler(**arguments))
|
1573
|
+
else:
|
1574
|
+
return loop.run_until_complete(handler(**arguments))
|
1575
|
+
except RuntimeError:
|
1576
|
+
# No event loop - create new one
|
1577
|
+
return asyncio.run(handler(**arguments))
|
1578
|
+
else:
|
1579
|
+
return handler(**arguments)
|
1580
|
+
except Exception as e:
|
1581
|
+
tool_info["error_count"] = tool_info.get("error_count", 0) + 1
|
1582
|
+
raise
|
1583
|
+
|
1457
1584
|
def run(self):
|
1458
1585
|
"""Run the enhanced MCP server with all features."""
|
1459
1586
|
if self._mcp is None:
|
@@ -1464,6 +1591,7 @@ class MCPServer:
|
|
1464
1591
|
|
1465
1592
|
# Log enhanced server startup
|
1466
1593
|
logger.info(f"Starting enhanced MCP server: {self.name}")
|
1594
|
+
logger.info(f"Transport: {self.transport}")
|
1467
1595
|
logger.info("Features enabled:")
|
1468
1596
|
logger.info(f" - Cache: {self.cache.enabled if self.cache else False}")
|
1469
1597
|
logger.info(f" - Metrics: {self.metrics.enabled if self.metrics else False}")
|
@@ -1490,9 +1618,16 @@ class MCPServer:
|
|
1490
1618
|
if health["status"] != "healthy":
|
1491
1619
|
logger.warning(f"Server health check shows issues: {health['issues']}")
|
1492
1620
|
|
1493
|
-
# Run
|
1494
|
-
|
1495
|
-
|
1621
|
+
# Run server based on transport type
|
1622
|
+
if self.transport == "websocket":
|
1623
|
+
logger.info(
|
1624
|
+
f"Starting WebSocket server on {self.websocket_host}:{self.websocket_port}..."
|
1625
|
+
)
|
1626
|
+
asyncio.run(self._run_websocket())
|
1627
|
+
else:
|
1628
|
+
# Default to FastMCP (STDIO) server
|
1629
|
+
logger.info("Starting FastMCP server in STDIO mode...")
|
1630
|
+
self._mcp.run()
|
1496
1631
|
|
1497
1632
|
except KeyboardInterrupt:
|
1498
1633
|
logger.info("Server stopped by user")
|
@@ -1527,6 +1662,944 @@ class MCPServer:
|
|
1527
1662
|
self._running = False
|
1528
1663
|
logger.info(f"Enhanced MCP server '{self.name}' stopped")
|
1529
1664
|
|
1665
|
+
async def _run_websocket(self):
    """Serve MCP over WebSocket until the enclosing task is cancelled.

    Steps:
    1. Build a ``WebSocketServerTransport`` from the server's host, port,
       auth, timeout, size and metrics settings, and start it.
    2. If subscriptions are enabled, initialise the subscription manager and
       route its notifications through ``_send_websocket_notification``.
    3. Park on a never-completing future.

    Whatever happens, the transport is disconnected (if one was set) and
    ``self._transport`` is reset to ``None`` on the way out.
    """
    from .transports import WebSocketServerTransport

    try:
        self._transport = WebSocketServerTransport(
            host=self.websocket_host,
            port=self.websocket_port,
            message_handler=self._handle_websocket_message,
            auth_provider=self.auth_provider,
            timeout=self.transport_timeout,
            max_message_size=self.max_request_size,
            enable_metrics=self.metrics.enabled if self.metrics else False,
        )
        await self._transport.connect()
        logger.info(
            f"WebSocket server started on {self.websocket_host}:{self.websocket_port}"
        )

        manager = self.subscription_manager
        if manager:
            await manager.initialize()
            manager.set_notification_callback(self._send_websocket_notification)

        # A bare future is never resolved, so this waits until cancellation.
        never_done = asyncio.Future()
        try:
            await never_done
        except asyncio.CancelledError:
            logger.info("WebSocket server cancelled")

    finally:
        if self._transport:
            await self._transport.disconnect()
        self._transport = None
|
1705
|
+
|
1706
|
+
async def _handle_websocket_message(
    self, request: Dict[str, Any], client_id: str
) -> Dict[str, Any]:
    """Decompress one WebSocket JSON-RPC request and dispatch it by method.

    Args:
        request: Raw message from the transport, possibly a compressed
            envelope that ``_decompress_message`` unpacks.
        client_id: Connection id assigned by the transport. It is forwarded
            to handlers that keep per-connection state.

    Returns:
        A JSON-RPC 2.0 response dict. An unknown method yields error -32601.
        Any exception raised while decoding or handling yields -32603, and
        is logged with its traceback.
    """
    # Capture the id before anything can fail so the error reply is still
    # correlated. Non-dict payloads (e.g. a JSON array) carry no usable id.
    request_id = request.get("id") if isinstance(request, dict) else None
    try:
        decompressed_request = self._decompress_message(request)

        method = decompressed_request.get("method", "")
        # JSON null (or any other falsy value) for params means "no params".
        params = decompressed_request.get("params") or {}
        request_id = decompressed_request.get("id")

        logger.debug(f"WebSocket request from {client_id}: {method}")

        # method -> (handler attribute, calling convention):
        #   "plain":  handler(params, request_id)
        #   "client": handler(params, request_id, client_id)
        #   "inject": handler({**params, "client_id": client_id}, request_id)
        routes = {
            "initialize": ("_handle_initialize", "client"),
            "tools/list": ("_handle_list_tools", "plain"),
            "tools/call": ("_handle_call_tool", "plain"),
            "resources/list": ("_handle_list_resources", "plain"),
            "resources/read": ("_handle_read_resource", "client"),
            "resources/subscribe": ("_handle_subscribe", "client"),
            "resources/unsubscribe": ("_handle_unsubscribe", "client"),
            "resources/batch_subscribe": ("_handle_batch_subscribe", "client"),
            "resources/batch_unsubscribe": ("_handle_batch_unsubscribe", "client"),
            "prompts/list": ("_handle_list_prompts", "plain"),
            "prompts/get": ("_handle_get_prompt", "plain"),
            "logging/setLevel": ("_handle_logging_set_level", "plain"),
            "roots/list": ("_handle_roots_list", "inject"),
            "completion/complete": ("_handle_completion_complete", "plain"),
            "sampling/createMessage": ("_handle_sampling_create_message", "inject"),
        }
        # Unhashable/non-string methods simply fall through to "not found".
        route = routes.get(method) if isinstance(method, str) else None
        if route is None:
            return {
                "jsonrpc": "2.0",
                "error": {"code": -32601, "message": f"Method not found: {method}"},
                "id": request_id,
            }

        handler_name, convention = route
        # Resolved lazily so only the selected handler must exist.
        handler = getattr(self, handler_name)
        if convention == "client":
            return await handler(params, request_id, client_id)
        if convention == "inject":
            return await handler({**params, "client_id": client_id}, request_id)
        return await handler(params, request_id)

    except Exception as e:
        logger.exception(f"Error handling WebSocket message: {e}")
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
            "id": request_id,
        }
|
1774
|
+
|
1775
|
+
async def _handle_initialize(
|
1776
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str = None
|
1777
|
+
) -> Dict[str, Any]:
|
1778
|
+
"""Handle initialize request."""
|
1779
|
+
# Store client information for capability checks
|
1780
|
+
if client_id:
|
1781
|
+
self.client_info[client_id] = {
|
1782
|
+
"capabilities": params.get("capabilities", {}),
|
1783
|
+
"name": params.get("clientInfo", {}).get("name", "unknown"),
|
1784
|
+
"version": params.get("clientInfo", {}).get("version", "unknown"),
|
1785
|
+
"initialized_at": time.time(),
|
1786
|
+
}
|
1787
|
+
|
1788
|
+
return {
|
1789
|
+
"jsonrpc": "2.0",
|
1790
|
+
"result": {
|
1791
|
+
"protocolVersion": "2024-11-05",
|
1792
|
+
"capabilities": {
|
1793
|
+
"tools": {"listSupported": True, "callSupported": True},
|
1794
|
+
"resources": {
|
1795
|
+
"listSupported": True,
|
1796
|
+
"readSupported": True,
|
1797
|
+
"subscribe": self.enable_subscriptions,
|
1798
|
+
"listChanged": self.enable_subscriptions,
|
1799
|
+
"batch_subscribe": self.enable_subscriptions,
|
1800
|
+
"batch_unsubscribe": self.enable_subscriptions,
|
1801
|
+
},
|
1802
|
+
"prompts": {"listSupported": True, "getSupported": True},
|
1803
|
+
"logging": {"setLevel": True},
|
1804
|
+
"roots": {"list": True},
|
1805
|
+
"experimental": {
|
1806
|
+
"progressNotifications": True,
|
1807
|
+
"cancellation": True,
|
1808
|
+
"completion": True,
|
1809
|
+
"sampling": True,
|
1810
|
+
"websocketCompression": self.enable_websocket_compression,
|
1811
|
+
},
|
1812
|
+
},
|
1813
|
+
"serverInfo": {
|
1814
|
+
"name": self.name,
|
1815
|
+
"version": self.config.get("server.version", "1.0.0"),
|
1816
|
+
},
|
1817
|
+
},
|
1818
|
+
"id": request_id,
|
1819
|
+
}
|
1820
|
+
|
1821
|
+
async def _handle_list_tools(
|
1822
|
+
self, params: Dict[str, Any], request_id: Any
|
1823
|
+
) -> Dict[str, Any]:
|
1824
|
+
"""Handle tools/list request."""
|
1825
|
+
tools = []
|
1826
|
+
for name, info in self._tool_registry.items():
|
1827
|
+
if not info.get("disabled", False):
|
1828
|
+
tools.append(
|
1829
|
+
{
|
1830
|
+
"name": name,
|
1831
|
+
"description": info.get("description", ""),
|
1832
|
+
"inputSchema": info.get("input_schema", {}),
|
1833
|
+
}
|
1834
|
+
)
|
1835
|
+
|
1836
|
+
return {"jsonrpc": "2.0", "result": {"tools": tools}, "id": request_id}
|
1837
|
+
|
1838
|
+
async def _handle_call_tool(
|
1839
|
+
self, params: Dict[str, Any], request_id: Any
|
1840
|
+
) -> Dict[str, Any]:
|
1841
|
+
"""Handle tools/call request."""
|
1842
|
+
tool_name = params.get("name")
|
1843
|
+
arguments = params.get("arguments", {})
|
1844
|
+
|
1845
|
+
try:
|
1846
|
+
result = self._execute_tool(tool_name, arguments)
|
1847
|
+
|
1848
|
+
# Handle async results
|
1849
|
+
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
|
1850
|
+
result = await result
|
1851
|
+
|
1852
|
+
return {
|
1853
|
+
"jsonrpc": "2.0",
|
1854
|
+
"result": {"content": [{"type": "text", "text": str(result)}]},
|
1855
|
+
"id": request_id,
|
1856
|
+
}
|
1857
|
+
except Exception as e:
|
1858
|
+
return {
|
1859
|
+
"jsonrpc": "2.0",
|
1860
|
+
"error": {"code": -32603, "message": f"Tool execution error: {str(e)}"},
|
1861
|
+
"id": request_id,
|
1862
|
+
}
|
1863
|
+
|
1864
|
+
async def _handle_list_resources(
|
1865
|
+
self, params: Dict[str, Any], request_id: Any
|
1866
|
+
) -> Dict[str, Any]:
|
1867
|
+
"""Handle resources/list request with cursor-based pagination."""
|
1868
|
+
cursor = params.get("cursor")
|
1869
|
+
limit = params.get("limit")
|
1870
|
+
|
1871
|
+
# Get all resources
|
1872
|
+
all_resources = []
|
1873
|
+
for uri, info in self._resource_registry.items():
|
1874
|
+
all_resources.append(
|
1875
|
+
{
|
1876
|
+
"uri": uri,
|
1877
|
+
"name": info.get("name", uri),
|
1878
|
+
"description": info.get("description", ""),
|
1879
|
+
"mimeType": info.get("mime_type", "text/plain"),
|
1880
|
+
}
|
1881
|
+
)
|
1882
|
+
|
1883
|
+
# Handle pagination if subscription manager is available
|
1884
|
+
if self.subscription_manager:
|
1885
|
+
cursor_manager = self.subscription_manager.cursor_manager
|
1886
|
+
|
1887
|
+
# Determine starting position
|
1888
|
+
start_pos = 0
|
1889
|
+
if cursor:
|
1890
|
+
if cursor_manager.is_valid(cursor):
|
1891
|
+
start_pos = cursor_manager.get_cursor_position(cursor) or 0
|
1892
|
+
else:
|
1893
|
+
return {
|
1894
|
+
"jsonrpc": "2.0",
|
1895
|
+
"error": {
|
1896
|
+
"code": -32602,
|
1897
|
+
"message": "Invalid or expired cursor",
|
1898
|
+
},
|
1899
|
+
"id": request_id,
|
1900
|
+
}
|
1901
|
+
|
1902
|
+
# Apply pagination
|
1903
|
+
if limit:
|
1904
|
+
end_pos = start_pos + limit
|
1905
|
+
resources = all_resources[start_pos:end_pos]
|
1906
|
+
|
1907
|
+
# Generate next cursor if there are more resources
|
1908
|
+
next_cursor = None
|
1909
|
+
if end_pos < len(all_resources):
|
1910
|
+
next_cursor = cursor_manager.create_cursor_for_position(
|
1911
|
+
all_resources, end_pos
|
1912
|
+
)
|
1913
|
+
|
1914
|
+
result = {"resources": resources}
|
1915
|
+
if next_cursor:
|
1916
|
+
result["nextCursor"] = next_cursor
|
1917
|
+
|
1918
|
+
return {"jsonrpc": "2.0", "result": result, "id": request_id}
|
1919
|
+
else:
|
1920
|
+
resources = all_resources[start_pos:]
|
1921
|
+
else:
|
1922
|
+
# No pagination support
|
1923
|
+
resources = all_resources
|
1924
|
+
|
1925
|
+
return {"jsonrpc": "2.0", "result": {"resources": resources}, "id": request_id}
|
1926
|
+
|
1927
|
+
async def _handle_read_resource(
|
1928
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str = None
|
1929
|
+
) -> Dict[str, Any]:
|
1930
|
+
"""Handle resources/read request with change detection."""
|
1931
|
+
uri = params.get("uri")
|
1932
|
+
|
1933
|
+
# First try exact match
|
1934
|
+
resource_info = None
|
1935
|
+
resource_params = {}
|
1936
|
+
if uri in self._resource_registry:
|
1937
|
+
resource_info = self._resource_registry[uri]
|
1938
|
+
else:
|
1939
|
+
# Try template matching
|
1940
|
+
resource_info, resource_params = self._match_resource_template(uri)
|
1941
|
+
|
1942
|
+
if resource_info is None:
|
1943
|
+
return {
|
1944
|
+
"jsonrpc": "2.0",
|
1945
|
+
"error": {"code": -32602, "message": f"Resource not found: {uri}"},
|
1946
|
+
"id": request_id,
|
1947
|
+
}
|
1948
|
+
|
1949
|
+
try:
|
1950
|
+
handler = resource_info.get("handler")
|
1951
|
+
original_handler = resource_info.get("original_handler")
|
1952
|
+
|
1953
|
+
if handler:
|
1954
|
+
# Use original handler with parameters if available
|
1955
|
+
if original_handler and resource_params:
|
1956
|
+
content = original_handler(**resource_params)
|
1957
|
+
else:
|
1958
|
+
content = handler()
|
1959
|
+
if asyncio.iscoroutine(content):
|
1960
|
+
content = await content
|
1961
|
+
else:
|
1962
|
+
content = ""
|
1963
|
+
|
1964
|
+
# Process change detection if subscription manager is available
|
1965
|
+
if self.subscription_manager:
|
1966
|
+
resource_data = {
|
1967
|
+
"uri": uri,
|
1968
|
+
"text": str(content),
|
1969
|
+
"mimeType": resource_info.get("mime_type", "text/plain"),
|
1970
|
+
}
|
1971
|
+
|
1972
|
+
# Check for changes and notify subscribers
|
1973
|
+
change = (
|
1974
|
+
await self.subscription_manager.resource_monitor.check_for_changes(
|
1975
|
+
uri, resource_data
|
1976
|
+
)
|
1977
|
+
)
|
1978
|
+
|
1979
|
+
if change:
|
1980
|
+
await self.subscription_manager.process_resource_change(change)
|
1981
|
+
|
1982
|
+
return {
|
1983
|
+
"jsonrpc": "2.0",
|
1984
|
+
"result": {"contents": [{"uri": uri, "text": str(content)}]},
|
1985
|
+
"id": request_id,
|
1986
|
+
}
|
1987
|
+
except Exception as e:
|
1988
|
+
return {
|
1989
|
+
"jsonrpc": "2.0",
|
1990
|
+
"error": {"code": -32603, "message": f"Resource read error: {str(e)}"},
|
1991
|
+
"id": request_id,
|
1992
|
+
}
|
1993
|
+
|
1994
|
+
def _match_resource_template(self, uri: str) -> tuple:
|
1995
|
+
"""Match URI against resource templates and extract parameters."""
|
1996
|
+
import re
|
1997
|
+
|
1998
|
+
for template_uri, resource_info in self._resource_registry.items():
|
1999
|
+
# Convert template to regex pattern
|
2000
|
+
# Replace {param} with named capture groups
|
2001
|
+
pattern = re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", template_uri)
|
2002
|
+
pattern = f"^{pattern}$"
|
2003
|
+
|
2004
|
+
match = re.match(pattern, uri)
|
2005
|
+
if match:
|
2006
|
+
# Extract parameters from the match
|
2007
|
+
params = match.groupdict()
|
2008
|
+
return resource_info, params
|
2009
|
+
|
2010
|
+
return None, {}
|
2011
|
+
|
2012
|
+
async def _handle_list_prompts(
|
2013
|
+
self, params: Dict[str, Any], request_id: Any
|
2014
|
+
) -> Dict[str, Any]:
|
2015
|
+
"""Handle prompts/list request."""
|
2016
|
+
prompts = []
|
2017
|
+
for name, info in self._prompt_registry.items():
|
2018
|
+
prompts.append(
|
2019
|
+
{
|
2020
|
+
"name": name,
|
2021
|
+
"description": info.get("description", ""),
|
2022
|
+
"arguments": info.get("arguments", []),
|
2023
|
+
}
|
2024
|
+
)
|
2025
|
+
|
2026
|
+
return {"jsonrpc": "2.0", "result": {"prompts": prompts}, "id": request_id}
|
2027
|
+
|
2028
|
+
async def _handle_get_prompt(
|
2029
|
+
self, params: Dict[str, Any], request_id: Any
|
2030
|
+
) -> Dict[str, Any]:
|
2031
|
+
"""Handle prompts/get request."""
|
2032
|
+
name = params.get("name")
|
2033
|
+
arguments = params.get("arguments", {})
|
2034
|
+
|
2035
|
+
if name not in self._prompt_registry:
|
2036
|
+
return {
|
2037
|
+
"jsonrpc": "2.0",
|
2038
|
+
"error": {"code": -32602, "message": f"Prompt not found: {name}"},
|
2039
|
+
"id": request_id,
|
2040
|
+
}
|
2041
|
+
|
2042
|
+
try:
|
2043
|
+
prompt_info = self._prompt_registry[name]
|
2044
|
+
handler = prompt_info.get("handler")
|
2045
|
+
|
2046
|
+
if handler:
|
2047
|
+
messages = handler(**arguments)
|
2048
|
+
if asyncio.iscoroutine(messages):
|
2049
|
+
messages = await messages
|
2050
|
+
else:
|
2051
|
+
messages = []
|
2052
|
+
|
2053
|
+
return {
|
2054
|
+
"jsonrpc": "2.0",
|
2055
|
+
"result": {"messages": messages},
|
2056
|
+
"id": request_id,
|
2057
|
+
}
|
2058
|
+
except Exception as e:
|
2059
|
+
return {
|
2060
|
+
"jsonrpc": "2.0",
|
2061
|
+
"error": {
|
2062
|
+
"code": -32603,
|
2063
|
+
"message": f"Prompt generation error: {str(e)}",
|
2064
|
+
},
|
2065
|
+
"id": request_id,
|
2066
|
+
}
|
2067
|
+
|
2068
|
+
async def _handle_subscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Answer ``resources/subscribe`` with optional GraphQL-style field selection.

    Creates a subscription for ``params["uri"]`` owned by ``client_id``. The
    optional ``cursor``, ``fields`` (e.g. ``["uri", "content.text"]``) and
    ``fragments`` (e.g. ``{"basicInfo": ["uri", "name"]}``) are forwarded
    unchanged. The client id is used as both user and connection identity.

    Returns:
        ``{"subscriptionId": ...}`` on success. Error codes:
        - -32601: subscriptions disabled;
        - -32602: ``uri`` missing;
        - -32601: the manager's message mentions permission, authorization
          or rate limits;
        - -32603: any other manager failure.
    """

    def reply_error(code: int, message: str) -> Dict[str, Any]:
        return {"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": request_id}

    manager = self.subscription_manager
    if not manager:
        return reply_error(-32601, "Subscriptions not enabled")

    uri_pattern = params.get("uri")
    cursor = params.get("cursor")
    fields = params.get("fields")
    fragments = params.get("fragments")

    if not uri_pattern:
        return reply_error(-32602, "Missing required parameter: uri")

    try:
        subscription_id = await manager.create_subscription(
            connection_id=client_id,
            uri_pattern=uri_pattern,
            cursor=cursor,
            user_context={"user_id": client_id, "connection_id": client_id},
            fields=fields,
            fragments=fragments,
        )
    except Exception as exc:
        text = str(exc)
        lowered = text.lower()
        # NOTE(review): -32601 is JSON-RPC "Method not found"; an
        # implementation-defined server error code (-32000..-32099) is
        # probably intended for auth/rate-limit failures. Kept as-is.
        refused = (
            "permission" in lowered
            or "not authorized" in lowered
            or "rate limit" in lowered
        )
        return reply_error(-32601 if refused else -32603, text)

    return {
        "jsonrpc": "2.0",
        "result": {"subscriptionId": subscription_id},
        "id": request_id,
    }
|
2121
|
+
|
2122
|
+
async def _handle_unsubscribe(
|
2123
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2124
|
+
) -> Dict[str, Any]:
|
2125
|
+
"""Handle resources/unsubscribe request."""
|
2126
|
+
if not self.subscription_manager:
|
2127
|
+
return {
|
2128
|
+
"jsonrpc": "2.0",
|
2129
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2130
|
+
"id": request_id,
|
2131
|
+
}
|
2132
|
+
|
2133
|
+
subscription_id = params.get("subscriptionId")
|
2134
|
+
|
2135
|
+
if not subscription_id:
|
2136
|
+
return {
|
2137
|
+
"jsonrpc": "2.0",
|
2138
|
+
"error": {
|
2139
|
+
"code": -32602,
|
2140
|
+
"message": "Missing required parameter: subscriptionId",
|
2141
|
+
},
|
2142
|
+
"id": request_id,
|
2143
|
+
}
|
2144
|
+
|
2145
|
+
try:
|
2146
|
+
success = await self.subscription_manager.remove_subscription(
|
2147
|
+
subscription_id, client_id
|
2148
|
+
)
|
2149
|
+
|
2150
|
+
return {
|
2151
|
+
"jsonrpc": "2.0",
|
2152
|
+
"result": {"success": success},
|
2153
|
+
"id": request_id,
|
2154
|
+
}
|
2155
|
+
except Exception as e:
|
2156
|
+
return {
|
2157
|
+
"jsonrpc": "2.0",
|
2158
|
+
"error": {"code": -32603, "message": str(e)},
|
2159
|
+
"id": request_id,
|
2160
|
+
}
|
2161
|
+
|
2162
|
+
async def _handle_batch_subscribe(
|
2163
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2164
|
+
) -> Dict[str, Any]:
|
2165
|
+
"""Handle resources/batch_subscribe request."""
|
2166
|
+
if not self.subscription_manager:
|
2167
|
+
return {
|
2168
|
+
"jsonrpc": "2.0",
|
2169
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2170
|
+
"id": request_id,
|
2171
|
+
}
|
2172
|
+
|
2173
|
+
subscriptions = params.get("subscriptions")
|
2174
|
+
if not subscriptions or not isinstance(subscriptions, list):
|
2175
|
+
return {
|
2176
|
+
"jsonrpc": "2.0",
|
2177
|
+
"error": {
|
2178
|
+
"code": -32602,
|
2179
|
+
"message": "Missing or invalid parameter: subscriptions",
|
2180
|
+
},
|
2181
|
+
"id": request_id,
|
2182
|
+
}
|
2183
|
+
|
2184
|
+
try:
|
2185
|
+
# Create batch subscriptions with auth context
|
2186
|
+
user_context = {"user_id": client_id, "connection_id": client_id}
|
2187
|
+
results = await self.subscription_manager.create_batch_subscriptions(
|
2188
|
+
subscriptions=subscriptions,
|
2189
|
+
connection_id=client_id,
|
2190
|
+
user_context=user_context,
|
2191
|
+
)
|
2192
|
+
|
2193
|
+
return {
|
2194
|
+
"jsonrpc": "2.0",
|
2195
|
+
"result": results,
|
2196
|
+
"id": request_id,
|
2197
|
+
}
|
2198
|
+
except Exception as e:
|
2199
|
+
return {
|
2200
|
+
"jsonrpc": "2.0",
|
2201
|
+
"error": {"code": -32603, "message": str(e)},
|
2202
|
+
"id": request_id,
|
2203
|
+
}
|
2204
|
+
|
2205
|
+
async def _handle_batch_unsubscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Remove several of a client's subscriptions in one request.

    Implements the ``resources/batch_unsubscribe`` JSON-RPC method.

    Args:
        params: Request parameters; ``params["subscriptionIds"]`` must be a
            non-empty list of subscription ids.
        request_id: JSON-RPC id echoed back in the response.
        client_id: Connection identifier, passed to the manager as
            ``connection_id``.

    Returns:
        A JSON-RPC response dict. ``result`` carries the manager's return
        value from ``remove_batch_subscriptions``. Otherwise ``error`` has
        one of these codes:
        - -32601: subscriptions are disabled.
        - -32602: ``subscriptionIds`` is missing or invalid.
        - -32603: the manager raised; the message is the exception text.
    """

    def _error(code: int, message: str) -> Dict[str, Any]:
        return {
            "jsonrpc": "2.0",
            "error": {"code": code, "message": message},
            "id": request_id,
        }

    manager = self.subscription_manager
    if not manager:
        return _error(-32601, "Subscriptions not enabled")

    ids_to_remove = params.get("subscriptionIds")
    if not isinstance(ids_to_remove, list) or not ids_to_remove:
        return _error(-32602, "Missing or invalid parameter: subscriptionIds")

    try:
        outcome = await manager.remove_batch_subscriptions(
            subscription_ids=ids_to_remove, connection_id=client_id
        )
    except Exception as exc:
        return _error(-32603, str(exc))

    return {"jsonrpc": "2.0", "result": outcome, "id": request_id}
|
2244
|
+
|
2245
|
+
async def _handle_connection_close(self, client_id: str):
    """Drop every subscription owned by a WebSocket client that disconnected.

    Does nothing when subscriptions are disabled. An INFO line is logged
    only when the manager reports that at least one subscription was removed.
    """
    manager = self.subscription_manager
    if not manager:
        return

    removed = await manager.cleanup_connection(client_id)
    if removed > 0:
        logger.info(f"Cleaned up {removed} subscriptions for client {client_id}")
|
2255
|
+
|
2256
|
+
def _compress_message(
    self, message: Dict[str, Any]
) -> Union[Dict[str, Any], bytes]:
    """Gzip-compress a large outgoing message when that actually saves space.

    Compression is best-effort. The original ``message`` object is returned
    unchanged when any of these holds:
    - compression is disabled;
    - its compact JSON encoding is shorter than ``self.compression_threshold``
      bytes;
    - it cannot be serialized or compressed;
    - the payload actually transmitted (hex-encoded gzip) would not be at
      least 10% smaller than the uncompressed JSON.

    Args:
        message: JSON-serializable message to send.

    Returns:
        Either ``message`` itself or an envelope dict with ``__compressed``,
        size metadata, ``__compression_ratio`` (raw gzip size / JSON size) and
        the gzip bytes hex-encoded under ``data``. A dict is always returned;
        ``bytes`` remains in the annotation only for backward compatibility.
    """
    if not self.enable_websocket_compression:
        return message

    try:
        message_json = json.dumps(message, separators=(",", ":")).encode("utf-8")
        original_size = len(message_json)  # always >= 1 for valid JSON
        if original_size < self.compression_threshold:
            return message
        compressed_data = gzip.compress(
            message_json, compresslevel=self.compression_level
        )
    except Exception as e:
        # Serialization or compression failed: send the message uncompressed.
        logger.warning(f"Failed to compress message: {e}")
        return message

    # Hex encoding doubles the payload. Judge the saving on what is actually
    # transmitted; otherwise a 0.5-0.9 gzip ratio inflates the message.
    encoded_data = compressed_data.hex()
    if len(encoded_data) > 0.9 * original_size:  # less than 10% improvement
        return message

    return {
        "__compressed": True,
        "__original_size": original_size,
        "__compressed_size": len(compressed_data),
        "__compression_ratio": len(compressed_data) / original_size,
        "data": encoded_data,  # hex keeps the envelope JSON-safe
    }
|
2302
|
+
|
2303
|
+
def _decompress_message(self, compressed_message: Dict[str, Any]) -> Dict[str, Any]:
    """Unpack a hex-encoded gzip envelope back into the original message.

    Args:
        compressed_message: An envelope whose ``__compressed`` flag is truthy
            and whose ``data`` holds hex-encoded gzip of the JSON message.
            Any other dict is returned unchanged.

    Returns:
        The decoded message. If decoding fails, a JSON-RPC -32603 error
        response is returned instead, with ``id`` set to None because the
        original id is inside the undecodable payload.
    """
    if not compressed_message.get("__compressed"):
        return compressed_message

    # NOTE(review): gzip.decompress has no output-size cap. If envelopes can
    # come from untrusted peers, consider a bounded zlib.decompressobj.
    try:
        compressed_data = bytes.fromhex(compressed_message["data"])
        decompressed_json = gzip.decompress(compressed_data)
        return json.loads(decompressed_json.decode("utf-8"))
    except Exception as e:
        logger.error(f"Failed to decompress message: {e}")
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32603,
                "message": f"Failed to decompress message: {e}",
            },
            # JSON-RPC 2.0 requires "id" on every response. It is null when
            # the request id cannot be determined.
            "id": None,
        }
|
2333
|
+
|
2334
|
+
async def _send_websocket_notification(
    self, client_id: str, notification: Dict[str, Any]
):
    """Push a notification to one WebSocket client, compressing it if worthwhile.

    The message is passed through ``self._compress_message`` first. Delivery
    is best-effort: failures are logged, never raised. When the transport has
    no ``send_message``, nothing is sent and only a debug line is logged.

    Args:
        client_id: Target connection identifier passed to the transport.
        notification: JSON-RPC notification. Its ``method`` is used only for
            logging and may be absent.
    """
    transport = self._transport
    if not (transport and hasattr(transport, "send_message")):
        logger.debug(
            "No message-capable transport; notification for client %s not sent",
            client_id,
        )
        return

    try:
        message_to_send = self._compress_message(notification)

        if isinstance(message_to_send, dict) and message_to_send.get("__compressed"):
            logger.debug(
                "Compressed notification for client %s: %s -> %s bytes (%.2f%% ratio)",
                client_id,
                message_to_send["__original_size"],
                message_to_send["__compressed_size"],
                message_to_send["__compression_ratio"] * 100,
            )

        await transport.send_message(message_to_send, client_id=client_id)
        # .get(): a missing "method" must not turn a successful send into an error.
        logger.debug(
            "Sent notification to client %s: %s",
            client_id,
            notification.get("method"),
        )
    except Exception as e:
        logger.error("Failed to send notification to client %s: %s", client_id, e)
|
2361
|
+
|
2362
|
+
async def _handle_logging_set_level(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``logging/setLevel`` by changing the root logger's level at runtime.

    Args:
        params: Request parameters.
            ``level``: case-insensitive level name; defaults to ``"INFO"``.
            ``client_id``: recorded as ``changed_by`` in the audit event
            when present.
        request_id: JSON-RPC id echoed back in the response.

    Returns:
        On success, a JSON-RPC response whose result is
        ``{"level": ..., "levels": [...]}``. A -32602 error is returned when
        ``level`` is not a supported name, including when it is not a
        string at all.
    """
    valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

    raw_level = params.get("level", "INFO")
    # Non-string values (e.g. null or a number) have no .upper(); report
    # them as an invalid level instead of raising.
    level = raw_level.upper() if isinstance(raw_level, str) else str(raw_level)

    if level not in valid_levels:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32602,
                "message": f"Invalid log level: {level}. Must be one of {valid_levels}",
            },
            "id": request_id,
        }

    logging.getLogger().setLevel(getattr(logging, level))
    logger.info("Log level changed to %s", level)

    # Record the change in the event store when one is configured.
    if self.event_store:
        from kailash.middleware.gateway.event_store import EventType

        await self.event_store.append(
            event_type=EventType.REQUEST_COMPLETED,
            request_id=str(request_id),
            data={
                "type": "log_level_changed",
                "level": level,
                "timestamp": time.time(),
                "changed_by": params.get("client_id", "unknown"),
            },
        )

    return {
        "jsonrpc": "2.0",
        "result": {"level": level, "levels": valid_levels},
        "id": request_id,
    }
|
2404
|
+
|
2405
|
+
async def _handle_roots_list(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``roots/list``: return the file-system roots the client may see.

    Args:
        params: Request parameters. ``client_id`` selects the entry in
            ``self.client_info`` used for the capability check and for
            access control.
        request_id: JSON-RPC id echoed back in the response.

    Returns:
        ``{"roots": [...]}`` from the protocol manager. When an auth manager
        is configured and a ``client_id`` is given, the list is filtered by
        ``roots.validate_access(..., operation="list")``. A -32601 error is
        returned when the client is unknown or lacks the capability.
    """
    protocol_mgr = get_protocol_manager()

    # Unknown clients (or a missing client_id) have no info; treat them as
    # having no capabilities instead of failing on None.
    client_info = self.client_info.get(params.get("client_id", "")) or {}
    roots_capability = (client_info.get("capabilities") or {}).get("roots") or {}
    # NOTE(review): support is gated on roots.listChanged being truthy, so a
    # client that declares "roots" without listChanged is rejected. Confirm
    # this is intended.
    if not roots_capability.get("listChanged", False):
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32601,
                "message": "Client does not support roots capability",
            },
            "id": request_id,
        }

    roots = protocol_mgr.roots.list_roots()

    # Apply access control if an auth manager is available.
    if self.auth_manager and params.get("client_id"):
        filtered_roots = []
        for root in roots:
            if await protocol_mgr.roots.validate_access(
                root["uri"],
                operation="list",
                user_context=client_info,
            ):
                filtered_roots.append(root)
        roots = filtered_roots

    return {"jsonrpc": "2.0", "result": {"roots": roots}, "id": request_id}
|
2442
|
+
|
2443
|
+
async def _handle_completion_complete(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``completion/complete`` by substring-matching registered items.

    ``params["ref"]["type"]`` selects the registry to search: ``"resource"``,
    ``"prompt"`` or ``"tool"``; any other value yields no matches.
    ``params["argument"]["value"]`` is the text that must occur somewhere in
    the resource URI or the prompt/tool name. A missing or null value
    matches everything.

    Returns:
        A JSON-RPC response whose ``result.completion`` holds ``values`` (at
        most 100 entries), ``total`` (all matches) and ``hasMore``. A -32603
        error is returned if matching fails.
    """
    max_values = 100  # cap on the number of completion values returned

    ref = params.get("ref") or {}
    argument = params.get("argument") or {}
    ref_type = ref.get("type")  # "resource", "prompt" or "tool"
    partial_value = argument.get("value", "")
    if partial_value is None:
        # An explicit null is treated like an empty query.
        partial_value = ""

    try:
        if ref_type == "resource":
            values = [
                {
                    "uri": uri,
                    "name": resource_info.get("name", uri),
                    "description": resource_info.get("description", ""),
                }
                for uri, resource_info in self._resource_registry.items()
                if partial_value in uri  # substring match
            ]
        elif ref_type == "prompt":
            values = [
                {
                    "name": name,
                    "description": prompt_info.get("description", ""),
                    "arguments": prompt_info.get("arguments", []),
                }
                for name, prompt_info in self._prompt_registry.items()
                if partial_value in name
            ]
        elif ref_type == "tool":
            values = [
                {
                    "name": name,
                    "description": tool_info.get("description", ""),
                    "inputSchema": tool_info.get("inputSchema", {}),
                }
                for name, tool_info in self._tool_registry.items()
                if partial_value in name
            ]
        else:
            values = []

        total_matches = len(values)
        result = {
            "completion": {
                "values": values[:max_values],
                "total": total_matches,
                "hasMore": total_matches > max_values,
            }
        }
        return {"jsonrpc": "2.0", "result": result, "id": request_id}

    except Exception as e:
        logger.error(f"Completion error: {e}")
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": f"Completion failed: {str(e)}"},
            "id": request_id,
        }
|
2517
|
+
|
2518
|
+
async def _handle_sampling_create_message(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Forward a ``sampling/createMessage`` request to a sampling-capable client.

    Sampling is usually a server-to-client request. Here the server:
    - picks the first connected client whose
      ``capabilities.experimental.sampling`` is truthy;
    - sends it a new ``sampling/createMessage`` request with a fresh
      ``sampling_<hex>`` id;
    - records that id in ``self._pending_sampling_requests`` so a later reply
      can be correlated with ``request_id``.

    Returns:
        On success, a JSON-RPC response with ``status``, ``sampling_id`` and
        ``target_client``. Otherwise an error:
        - -32601: no client supports sampling.
        - -32603: the transport cannot send messages, or the send failed.
    """
    sampling_clients = [
        client_id
        for client_id, info in self.client_info.items()
        if info.get("capabilities", {})
        .get("experimental", {})
        .get("sampling", False)
    ]
    if not sampling_clients:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32601,
                "message": "No connected clients support sampling",
            },
            "id": request_id,
        }

    transport = self._transport
    if not (transport and hasattr(transport, "send_message")):
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32603,
                "message": "Transport does not support sampling",
            },
            "id": request_id,
        }

    # Simple selection policy: the first capable client.
    target_client = sampling_clients[0]
    sampling_id = f"sampling_{uuid.uuid4().hex[:8]}"

    # NOTE(review): incoming params are camelCase (modelPreferences,
    # systemPrompt, maxTokens) but are forwarded with snake_case keys. The MCP
    # spec uses camelCase on the wire; confirm receiving clients expect this.
    sampling_request = {
        "jsonrpc": "2.0",
        "method": "sampling/createMessage",
        "params": {
            "messages": params.get("messages", []),
            "model_preferences": params.get("modelPreferences"),
            "system_prompt": params.get("systemPrompt"),
            "temperature": params.get("temperature"),
            "max_tokens": params.get("maxTokens"),
            "metadata": params.get("metadata"),
        },
        "id": sampling_id,
    }

    pending = getattr(self, "_pending_sampling_requests", None)
    if pending is None:
        pending = self._pending_sampling_requests = {}
    # Register before sending: the client's reply may be handled by another
    # task while send_message is still awaiting.
    # NOTE(review): nothing in this method expires entries that never get a reply.
    pending[sampling_id] = {
        "original_request_id": request_id,
        "client_id": params.get("client_id"),
        "timestamp": time.time(),
    }

    try:
        await transport.send_message(sampling_request, client_id=target_client)
    except Exception as e:
        # Nothing was delivered, so no reply can arrive; drop the bookkeeping.
        pending.pop(sampling_id, None)
        logger.error(f"Failed to send sampling request to client {target_client}: {e}")
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32603,
                "message": f"Failed to send sampling request: {e}",
            },
            "id": request_id,
        }

    return {
        "jsonrpc": "2.0",
        "result": {
            "status": "sampling_requested",
            "sampling_id": sampling_id,
            "target_client": target_client,
        },
        "id": request_id,
    }
|
2602
|
+
|
1530
2603
|
async def run_stdio(self):
|
1531
2604
|
"""Run the server using stdio transport for testing."""
|
1532
2605
|
if self._mcp is None:
|