kailash 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kailash/__init__.py +5 -5
  2. kailash/channels/__init__.py +2 -1
  3. kailash/channels/mcp_channel.py +23 -4
  4. kailash/cli/validate_imports.py +202 -0
  5. kailash/core/resilience/bulkhead.py +15 -5
  6. kailash/core/resilience/circuit_breaker.py +4 -1
  7. kailash/core/resilience/health_monitor.py +312 -84
  8. kailash/edge/migration/edge_migration_service.py +384 -0
  9. kailash/mcp_server/protocol.py +26 -0
  10. kailash/mcp_server/server.py +1081 -8
  11. kailash/mcp_server/subscriptions.py +1560 -0
  12. kailash/mcp_server/transports.py +305 -0
  13. kailash/middleware/gateway/event_store.py +1 -0
  14. kailash/nodes/base.py +77 -1
  15. kailash/nodes/code/python.py +44 -3
  16. kailash/nodes/data/async_sql.py +42 -20
  17. kailash/nodes/edge/edge_migration_node.py +16 -12
  18. kailash/nodes/governance.py +410 -0
  19. kailash/nodes/rag/registry.py +1 -1
  20. kailash/nodes/transaction/distributed_transaction_manager.py +48 -1
  21. kailash/nodes/transaction/saga_state_storage.py +2 -1
  22. kailash/nodes/validation.py +8 -8
  23. kailash/runtime/local.py +30 -0
  24. kailash/runtime/validation/__init__.py +7 -15
  25. kailash/runtime/validation/import_validator.py +446 -0
  26. kailash/runtime/validation/suggestion_engine.py +5 -5
  27. kailash/utils/data_paths.py +74 -0
  28. kailash/workflow/builder.py +183 -4
  29. kailash/workflow/mermaid_visualizer.py +3 -1
  30. kailash/workflow/templates.py +6 -6
  31. kailash/workflow/validation.py +134 -3
  32. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/METADATA +20 -17
  33. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/RECORD +37 -31
  34. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/WHEEL +0 -0
  35. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/entry_points.txt +0 -0
  36. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/licenses/LICENSE +0 -0
  37. {kailash-0.8.5.dist-info → kailash-0.8.7.dist-info}/top_level.txt +0 -0
@@ -54,6 +54,8 @@ Enhanced Production Usage:
54
54
 
55
55
  import asyncio
56
56
  import functools
57
+ import gzip
58
+ import json
57
59
  import logging
58
60
  import time
59
61
  import uuid
@@ -74,6 +76,7 @@ from .errors import (
74
76
  RetryableOperation,
75
77
  ToolError,
76
78
  )
79
+ from .protocol import get_protocol_manager
77
80
  from .utils import CacheManager, ConfigManager, MetricsCollector, format_response
78
81
 
79
82
  logger = logging.getLogger(__name__)
@@ -345,6 +348,11 @@ class MCPServer:
345
348
  self,
346
349
  name: str,
347
350
  config_file: Optional[Union[str, Path]] = None,
351
+ # Transport configuration
352
+ transport: str = "stdio", # "stdio", "websocket", "http", "sse"
353
+ websocket_host: str = "0.0.0.0",
354
+ websocket_port: int = 3001,
355
+ # Caching configuration
348
356
  enable_cache: bool = True,
349
357
  cache_ttl: int = 300,
350
358
  cache_backend: str = "memory", # "memory" or "redis"
@@ -364,6 +372,13 @@ class MCPServer:
364
372
  transport_timeout: float = 30.0,
365
373
  max_request_size: int = 10_000_000, # 10MB
366
374
  enable_streaming: bool = False,
375
+ # Resource subscription configuration
376
+ enable_subscriptions: bool = True,
377
+ event_store=None,
378
+ # WebSocket compression configuration
379
+ enable_websocket_compression: bool = False,
380
+ compression_threshold: int = 1024, # Only compress messages larger than 1KB
381
+ compression_level: int = 6, # 1 (fastest) to 9 (best compression)
367
382
  ):
368
383
  """
369
384
  Initialize enhanced MCP server.
@@ -371,6 +386,9 @@ class MCPServer:
371
386
  Args:
372
387
  name: Server name
373
388
  config_file: Optional configuration file path
389
+ transport: Transport to use ("stdio", "websocket", "http", "sse")
390
+ websocket_host: Host for WebSocket server (default: "0.0.0.0")
391
+ websocket_port: Port for WebSocket server (default: 3001)
374
392
  enable_cache: Whether to enable caching (default: True)
375
393
  cache_ttl: Default cache TTL in seconds (default: 300)
376
394
  cache_backend: Cache backend ("memory" or "redis")
@@ -388,9 +406,24 @@ class MCPServer:
388
406
  transport_timeout: Transport timeout in seconds
389
407
  max_request_size: Maximum request size in bytes
390
408
  enable_streaming: Enable streaming support
409
+ enable_subscriptions: Enable resource subscriptions (default: True)
410
+ event_store: Optional event store for subscription logging
411
+ enable_websocket_compression: Enable gzip compression for WebSocket messages (default: False)
412
+ compression_threshold: Only compress messages larger than this size in bytes (default: 1024)
413
+ compression_level: Compression level from 1 (fastest) to 9 (best compression) (default: 6)
391
414
  """
392
415
  self.name = name
393
416
 
417
+ # Transport configuration
418
+ self.transport = transport
419
+ self.websocket_host = websocket_host
420
+ self.websocket_port = websocket_port
421
+
422
+ # WebSocket compression configuration
423
+ self.enable_websocket_compression = enable_websocket_compression
424
+ self.compression_threshold = compression_threshold
425
+ self.compression_level = compression_level
426
+
394
427
  # Enhanced features
395
428
  self.auth_provider = auth_provider
396
429
  self.enable_http_transport = enable_http_transport
@@ -410,7 +443,9 @@ class MCPServer:
410
443
  "server": {
411
444
  "name": name,
412
445
  "version": "1.0.0",
413
- "transport": "stdio",
446
+ "transport": transport,
447
+ "websocket_host": websocket_host,
448
+ "websocket_port": websocket_port,
414
449
  "enable_http": enable_http_transport,
415
450
  "enable_sse": enable_sse_transport,
416
451
  "timeout": transport_timeout,
@@ -500,6 +535,30 @@ class MCPServer:
500
535
  self._resource_registry: Dict[str, Dict[str, Any]] = {}
501
536
  self._prompt_registry: Dict[str, Dict[str, Any]] = {}
502
537
 
538
+ # Client management for new handlers
539
+ self.client_info: Dict[str, Dict[str, Any]] = {}
540
+ self._pending_sampling_requests: Dict[str, Dict[str, Any]] = {}
541
+
542
+ # Resource subscription support
543
+ self.enable_subscriptions = enable_subscriptions
544
+ self.event_store = event_store
545
+ self.subscription_manager = None
546
+ if self.enable_subscriptions:
547
+ from .subscriptions import ResourceSubscriptionManager
548
+
549
+ self.subscription_manager = ResourceSubscriptionManager(
550
+ auth_manager=(
551
+ self.auth_manager if hasattr(self, "auth_manager") else None
552
+ ),
553
+ event_store=event_store,
554
+ rate_limiter=(
555
+ self.rate_limiter if hasattr(self, "rate_limiter") else None
556
+ ),
557
+ )
558
+
559
+ # Transport instance (for WebSocket and other transports)
560
+ self._transport = None
561
+
503
562
  def _init_mcp(self):
504
563
  """Initialize FastMCP server."""
505
564
  if self._mcp is not None:
@@ -1207,10 +1266,24 @@ class MCPServer:
1207
1266
  self._init_mcp()
1208
1267
 
1209
1268
  # Wrap with metrics if enabled
1269
+ wrapped_func = func
1210
1270
  if self.metrics.enabled:
1211
- func = self.metrics.track_tool(f"resource:{uri}")(func)
1271
+ wrapped_func = self.metrics.track_tool(f"resource:{uri}")(func)
1212
1272
 
1213
- return self._mcp.resource(uri)(func)
1273
+ # Register with FastMCP
1274
+ mcp_resource = self._mcp.resource(uri)(wrapped_func)
1275
+
1276
+ # Track in registry
1277
+ self._resource_registry[uri] = {
1278
+ "handler": mcp_resource,
1279
+ "original_handler": func,
1280
+ "name": uri,
1281
+ "description": func.__doc__ or f"Resource: {uri}",
1282
+ "mime_type": "text/plain",
1283
+ "created_at": time.time(),
1284
+ }
1285
+
1286
+ return mcp_resource
1214
1287
 
1215
1288
  return decorator
1216
1289
 
@@ -1230,10 +1303,23 @@ class MCPServer:
1230
1303
  self._init_mcp()
1231
1304
 
1232
1305
  # Wrap with metrics if enabled
1306
+ wrapped_func = func
1233
1307
  if self.metrics.enabled:
1234
- func = self.metrics.track_tool(f"prompt:{name}")(func)
1308
+ wrapped_func = self.metrics.track_tool(f"prompt:{name}")(func)
1235
1309
 
1236
- return self._mcp.prompt(name)(func)
1310
+ # Register with FastMCP
1311
+ mcp_prompt = self._mcp.prompt(name)(wrapped_func)
1312
+
1313
+ # Track in registry
1314
+ self._prompt_registry[name] = {
1315
+ "handler": mcp_prompt,
1316
+ "original_handler": func,
1317
+ "description": func.__doc__ or f"Prompt: {name}",
1318
+ "arguments": [], # Could be extracted from function signature
1319
+ "created_at": time.time(),
1320
+ }
1321
+
1322
+ return mcp_prompt
1237
1323
 
1238
1324
  return decorator
1239
1325
 
@@ -1454,6 +1540,47 @@ class MCPServer:
1454
1540
  return True
1455
1541
  return False
1456
1542
 
1543
+ def _execute_tool(self, tool_name: str, arguments: dict) -> Any:
1544
+ """Execute a tool directly (for testing purposes)."""
1545
+ if tool_name not in self._tool_registry:
1546
+ raise ValueError(f"Tool '{tool_name}' not found in registry")
1547
+
1548
+ tool_info = self._tool_registry[tool_name]
1549
+ if tool_info.get("disabled", False):
1550
+ raise ValueError(f"Tool '{tool_name}' is currently disabled")
1551
+
1552
+ # Get the tool handler (the enhanced function)
1553
+ if "handler" in tool_info:
1554
+ handler = tool_info["handler"]
1555
+ elif "function" in tool_info:
1556
+ handler = tool_info["function"]
1557
+ else:
1558
+ raise ValueError(f"Tool '{tool_name}' has no valid handler")
1559
+
1560
+ # Update statistics
1561
+ tool_info["call_count"] = tool_info.get("call_count", 0) + 1
1562
+ tool_info["last_called"] = time.time()
1563
+
1564
+ try:
1565
+ # Execute the tool
1566
+ if asyncio.iscoroutinefunction(handler):
1567
+ # For async functions, we need to run in event loop
1568
+ try:
1569
+ loop = asyncio.get_event_loop()
1570
+ if loop.is_running():
1571
+ # Already in async context - create task
1572
+ return asyncio.create_task(handler(**arguments))
1573
+ else:
1574
+ return loop.run_until_complete(handler(**arguments))
1575
+ except RuntimeError:
1576
+ # No event loop - create new one
1577
+ return asyncio.run(handler(**arguments))
1578
+ else:
1579
+ return handler(**arguments)
1580
+ except Exception as e:
1581
+ tool_info["error_count"] = tool_info.get("error_count", 0) + 1
1582
+ raise
1583
+
1457
1584
  def run(self):
1458
1585
  """Run the enhanced MCP server with all features."""
1459
1586
  if self._mcp is None:
@@ -1464,6 +1591,7 @@ class MCPServer:
1464
1591
 
1465
1592
  # Log enhanced server startup
1466
1593
  logger.info(f"Starting enhanced MCP server: {self.name}")
1594
+ logger.info(f"Transport: {self.transport}")
1467
1595
  logger.info("Features enabled:")
1468
1596
  logger.info(f" - Cache: {self.cache.enabled if self.cache else False}")
1469
1597
  logger.info(f" - Metrics: {self.metrics.enabled if self.metrics else False}")
@@ -1490,9 +1618,16 @@ class MCPServer:
1490
1618
  if health["status"] != "healthy":
1491
1619
  logger.warning(f"Server health check shows issues: {health['issues']}")
1492
1620
 
1493
- # Run the FastMCP server
1494
- logger.info("Starting FastMCP server...")
1495
- self._mcp.run()
1621
+ # Run server based on transport type
1622
+ if self.transport == "websocket":
1623
+ logger.info(
1624
+ f"Starting WebSocket server on {self.websocket_host}:{self.websocket_port}..."
1625
+ )
1626
+ asyncio.run(self._run_websocket())
1627
+ else:
1628
+ # Default to FastMCP (STDIO) server
1629
+ logger.info("Starting FastMCP server in STDIO mode...")
1630
+ self._mcp.run()
1496
1631
 
1497
1632
  except KeyboardInterrupt:
1498
1633
  logger.info("Server stopped by user")
@@ -1527,6 +1662,944 @@ class MCPServer:
1527
1662
  self._running = False
1528
1663
  logger.info(f"Enhanced MCP server '{self.name}' stopped")
1529
1664
 
1665
async def _run_websocket(self):
    """Serve MCP over a WebSocket transport until cancelled."""
    from .transports import WebSocketServerTransport

    try:
        # Build and start the transport for this server instance.
        metrics_on = self.metrics.enabled if self.metrics else False
        self._transport = WebSocketServerTransport(
            host=self.websocket_host,
            port=self.websocket_port,
            message_handler=self._handle_websocket_message,
            auth_provider=self.auth_provider,
            timeout=self.transport_timeout,
            max_message_size=self.max_request_size,
            enable_metrics=metrics_on,
        )
        await self._transport.connect()
        logger.info(
            f"WebSocket server started on {self.websocket_host}:{self.websocket_port}"
        )

        # Route subscription notifications through this transport.
        if self.subscription_manager:
            await self.subscription_manager.initialize()
            self.subscription_manager.set_notification_callback(
                self._send_websocket_notification
            )

        # Block forever; task cancellation is the normal shutdown path.
        try:
            await asyncio.Future()
        except asyncio.CancelledError:
            logger.info("WebSocket server cancelled")

    finally:
        # Always tear the transport down, even on unexpected errors.
        if self._transport:
            await self._transport.disconnect()
            self._transport = None
async def _handle_websocket_message(
    self, request: Dict[str, Any], client_id: str
) -> Dict[str, Any]:
    """Decompress and route one JSON-RPC message from a WebSocket client."""
    try:
        message = self._decompress_message(request)

        method = message.get("method", "")
        params = message.get("params", {})
        request_id = message.get("id")

        logger.debug(f"WebSocket request from {client_id}: {method}")

        # Handlers that only need (params, request_id).
        simple_routes = {
            "tools/list": self._handle_list_tools,
            "tools/call": self._handle_call_tool,
            "resources/list": self._handle_list_resources,
            "prompts/list": self._handle_list_prompts,
            "prompts/get": self._handle_get_prompt,
            "logging/setLevel": self._handle_logging_set_level,
            "completion/complete": self._handle_completion_complete,
        }
        # Handlers that additionally need the client id as an argument.
        client_routes = {
            "initialize": self._handle_initialize,
            "resources/read": self._handle_read_resource,
            "resources/subscribe": self._handle_subscribe,
            "resources/unsubscribe": self._handle_unsubscribe,
            "resources/batch_subscribe": self._handle_batch_subscribe,
            "resources/batch_unsubscribe": self._handle_batch_unsubscribe,
        }

        if method in simple_routes:
            return await simple_routes[method](params, request_id)
        if method in client_routes:
            return await client_routes[method](params, request_id, client_id)
        if method == "roots/list":
            # roots/list expects the client id inside the params dict.
            return await self._handle_roots_list(
                {**params, "client_id": client_id}, request_id
            )
        if method == "sampling/createMessage":
            # Sampling also receives the client id via params.
            return await self._handle_sampling_create_message(
                {**params, "client_id": client_id}, request_id
            )
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": f"Method not found: {method}"},
            "id": request_id,
        }

    except Exception as e:
        logger.error(f"Error handling WebSocket message: {e}")
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
            "id": request.get("id"),
        }
1775
async def _handle_initialize(
    self, params: Dict[str, Any], request_id: Any, client_id: str = None
) -> Dict[str, Any]:
    """Handle the MCP ``initialize`` handshake and advertise capabilities."""
    # Remember per-client metadata so later requests can consult it.
    if client_id:
        client_meta = params.get("clientInfo", {})
        self.client_info[client_id] = {
            "capabilities": params.get("capabilities", {}),
            "name": client_meta.get("name", "unknown"),
            "version": client_meta.get("version", "unknown"),
            "initialized_at": time.time(),
        }

    subs = self.enable_subscriptions
    capabilities = {
        "tools": {"listSupported": True, "callSupported": True},
        "resources": {
            "listSupported": True,
            "readSupported": True,
            "subscribe": subs,
            "listChanged": subs,
            "batch_subscribe": subs,
            "batch_unsubscribe": subs,
        },
        "prompts": {"listSupported": True, "getSupported": True},
        "logging": {"setLevel": True},
        "roots": {"list": True},
        "experimental": {
            "progressNotifications": True,
            "cancellation": True,
            "completion": True,
            "sampling": True,
            "websocketCompression": self.enable_websocket_compression,
        },
    }

    return {
        "jsonrpc": "2.0",
        "result": {
            "protocolVersion": "2024-11-05",
            "capabilities": capabilities,
            "serverInfo": {
                "name": self.name,
                "version": self.config.get("server.version", "1.0.0"),
            },
        },
        "id": request_id,
    }
1821
async def _handle_list_tools(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``tools/list``: describe every enabled tool in the registry."""
    tools = [
        {
            "name": name,
            "description": info.get("description", ""),
            "inputSchema": info.get("input_schema", {}),
        }
        for name, info in self._tool_registry.items()
        if not info.get("disabled", False)
    ]
    return {"jsonrpc": "2.0", "result": {"tools": tools}, "id": request_id}
1838
async def _handle_call_tool(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``tools/call``: execute the named tool and wrap its output."""
    tool_name = params.get("name")
    arguments = params.get("arguments", {})

    try:
        outcome = self._execute_tool(tool_name, arguments)
        # _execute_tool may hand back an awaitable for async handlers.
        if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
            outcome = await outcome
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": f"Tool execution error: {str(e)}"},
            "id": request_id,
        }

    return {
        "jsonrpc": "2.0",
        "result": {"content": [{"type": "text", "text": str(outcome)}]},
        "id": request_id,
    }
1864
async def _handle_list_resources(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``resources/list`` with optional cursor-based pagination."""
    cursor = params.get("cursor")
    limit = params.get("limit")

    # Snapshot of every registered resource, in registry order.
    all_resources = [
        {
            "uri": uri,
            "name": info.get("name", uri),
            "description": info.get("description", ""),
            "mimeType": info.get("mime_type", "text/plain"),
        }
        for uri, info in self._resource_registry.items()
    ]

    if not self.subscription_manager:
        # Pagination relies on the subscription manager's cursor support;
        # without it, return the full listing.
        return {
            "jsonrpc": "2.0",
            "result": {"resources": all_resources},
            "id": request_id,
        }

    cursor_manager = self.subscription_manager.cursor_manager

    # Resolve the starting offset from the (optional) cursor.
    start_pos = 0
    if cursor:
        if not cursor_manager.is_valid(cursor):
            return {
                "jsonrpc": "2.0",
                "error": {
                    "code": -32602,
                    "message": "Invalid or expired cursor",
                },
                "id": request_id,
            }
        start_pos = cursor_manager.get_cursor_position(cursor) or 0

    if limit:
        end_pos = start_pos + limit
        result = {"resources": all_resources[start_pos:end_pos]}

        # Only hand out a continuation cursor when items remain.
        if end_pos < len(all_resources):
            next_cursor = cursor_manager.create_cursor_for_position(
                all_resources, end_pos
            )
            if next_cursor:
                result["nextCursor"] = next_cursor

        return {"jsonrpc": "2.0", "result": result, "id": request_id}

    # Cursor without a limit: everything from the offset onward.
    return {
        "jsonrpc": "2.0",
        "result": {"resources": all_resources[start_pos:]},
        "id": request_id,
    }
1927
async def _handle_read_resource(
    self, params: Dict[str, Any], request_id: Any, client_id: str = None
) -> Dict[str, Any]:
    """Handle ``resources/read`` with template matching and change detection.

    Fix: the result of a parameterized (template) handler is now awaited
    when it is a coroutine, the same as the plain-handler path; previously
    an async template resource could leak an un-awaited coroutine into the
    response text.
    """
    uri = params.get("uri")

    # Exact registry hit first, then fall back to template matching.
    resource_info = None
    resource_params = {}
    if uri in self._resource_registry:
        resource_info = self._resource_registry[uri]
    else:
        resource_info, resource_params = self._match_resource_template(uri)

    if resource_info is None:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32602, "message": f"Resource not found: {uri}"},
            "id": request_id,
        }

    try:
        handler = resource_info.get("handler")
        original_handler = resource_info.get("original_handler")

        if handler:
            if original_handler and resource_params:
                # Template match: call the raw handler with the extracted
                # URI parameters.
                content = original_handler(**resource_params)
            else:
                content = handler()
            # Await async handlers regardless of which path produced them.
            if asyncio.iscoroutine(content):
                content = await content
        else:
            content = ""

        # Feed the read through change detection so subscribers are
        # notified when the content differs from the previous read.
        if self.subscription_manager:
            resource_data = {
                "uri": uri,
                "text": str(content),
                "mimeType": resource_info.get("mime_type", "text/plain"),
            }
            change = (
                await self.subscription_manager.resource_monitor.check_for_changes(
                    uri, resource_data
                )
            )
            if change:
                await self.subscription_manager.process_resource_change(change)

        return {
            "jsonrpc": "2.0",
            "result": {"contents": [{"uri": uri, "text": str(content)}]},
            "id": request_id,
        }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": f"Resource read error: {str(e)}"},
            "id": request_id,
        }
1994
def _match_resource_template(self, uri: str) -> tuple:
    """Match *uri* against registered resource URI templates.

    ``{param}`` placeholders match exactly one path segment. Literal text
    in the template is regex-escaped — the previous implementation left
    metacharacters (``.``, ``?``, ``+`` …) live, so a URI could match a
    template it should not.

    Returns:
        ``(resource_info, params)`` for the first matching template, or
        ``(None, {})`` when nothing matches.
    """
    import re

    for template_uri, resource_info in self._resource_registry.items():
        # re.split with a capturing group alternates literal text (even
        # indices) and placeholder names (odd indices).
        parts = re.split(r"\{([^}]+)\}", template_uri)
        pieces = []
        for index, piece in enumerate(parts):
            if index % 2 == 0:
                pieces.append(re.escape(piece))  # literal template text
            else:
                pieces.append(f"(?P<{piece}>[^/]+)")  # one path segment
        pattern = "^" + "".join(pieces) + "$"

        match = re.match(pattern, uri)
        if match:
            return resource_info, match.groupdict()

    return None, {}
2012
async def _handle_list_prompts(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``prompts/list``: describe every registered prompt."""
    prompts = [
        {
            "name": name,
            "description": info.get("description", ""),
            "arguments": info.get("arguments", []),
        }
        for name, info in self._prompt_registry.items()
    ]
    return {"jsonrpc": "2.0", "result": {"prompts": prompts}, "id": request_id}
2028
async def _handle_get_prompt(
    self, params: Dict[str, Any], request_id: Any
) -> Dict[str, Any]:
    """Handle ``prompts/get``: render a prompt with the caller's arguments."""
    name = params.get("name")
    arguments = params.get("arguments", {})

    if name not in self._prompt_registry:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32602, "message": f"Prompt not found: {name}"},
            "id": request_id,
        }

    try:
        handler = self._prompt_registry[name].get("handler")
        if handler:
            messages = handler(**arguments)
            # Prompt handlers may be async; resolve before responding.
            if asyncio.iscoroutine(messages):
                messages = await messages
        else:
            messages = []

        return {
            "jsonrpc": "2.0",
            "result": {"messages": messages},
            "id": request_id,
        }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32603,
                "message": f"Prompt generation error: {str(e)}",
            },
            "id": request_id,
        }
2068
async def _handle_subscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Handle ``resources/subscribe`` with GraphQL-style field selection."""
    if not self.subscription_manager:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": "Subscriptions not enabled"},
            "id": request_id,
        }

    uri_pattern = params.get("uri")
    if not uri_pattern:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32602, "message": "Missing required parameter: uri"},
            "id": request_id,
        }

    try:
        subscription_id = await self.subscription_manager.create_subscription(
            connection_id=client_id,
            uri_pattern=uri_pattern,
            cursor=params.get("cursor"),
            # Auth context is keyed by the websocket connection id.
            user_context={"user_id": client_id, "connection_id": client_id},
            # Field selection, e.g. ["uri", "content.text", "metadata.size"]
            fields=params.get("fields"),
            # Fragments, e.g. {"basicInfo": ["uri", "name"]}
            fragments=params.get("fragments"),
        )

        return {
            "jsonrpc": "2.0",
            "result": {"subscriptionId": subscription_id},
            "id": request_id,
        }
    except Exception as e:
        # NOTE(review): auth and rate-limit failures both map to -32601,
        # mirroring the previous behavior.
        reason = str(e).lower()
        error_code = -32603
        if "permission" in reason or "not authorized" in reason:
            error_code = -32601
        elif "rate limit" in reason:
            error_code = -32601

        return {
            "jsonrpc": "2.0",
            "error": {"code": error_code, "message": str(e)},
            "id": request_id,
        }
2122
async def _handle_unsubscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Handle ``resources/unsubscribe`` for a single subscription id."""
    if not self.subscription_manager:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": "Subscriptions not enabled"},
            "id": request_id,
        }

    subscription_id = params.get("subscriptionId")
    if not subscription_id:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32602,
                "message": "Missing required parameter: subscriptionId",
            },
            "id": request_id,
        }

    try:
        removed = await self.subscription_manager.remove_subscription(
            subscription_id, client_id
        )
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": str(e)},
            "id": request_id,
        }

    return {
        "jsonrpc": "2.0",
        "result": {"success": removed},
        "id": request_id,
    }
2162
async def _handle_batch_subscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Handle ``resources/batch_subscribe`` for a list of subscription specs."""
    if not self.subscription_manager:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": "Subscriptions not enabled"},
            "id": request_id,
        }

    subscriptions = params.get("subscriptions")
    if not isinstance(subscriptions, list) or not subscriptions:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32602,
                "message": "Missing or invalid parameter: subscriptions",
            },
            "id": request_id,
        }

    try:
        results = await self.subscription_manager.create_batch_subscriptions(
            subscriptions=subscriptions,
            connection_id=client_id,
            # Auth context is keyed by the websocket connection id.
            user_context={"user_id": client_id, "connection_id": client_id},
        )
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": str(e)},
            "id": request_id,
        }

    return {
        "jsonrpc": "2.0",
        "result": results,
        "id": request_id,
    }
2205
async def _handle_batch_unsubscribe(
    self, params: Dict[str, Any], request_id: Any, client_id: str
) -> Dict[str, Any]:
    """Handle ``resources/batch_unsubscribe`` for a list of subscription ids."""
    if not self.subscription_manager:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32601, "message": "Subscriptions not enabled"},
            "id": request_id,
        }

    subscription_ids = params.get("subscriptionIds")
    if not isinstance(subscription_ids, list) or not subscription_ids:
        return {
            "jsonrpc": "2.0",
            "error": {
                "code": -32602,
                "message": "Missing or invalid parameter: subscriptionIds",
            },
            "id": request_id,
        }

    try:
        results = await self.subscription_manager.remove_batch_subscriptions(
            subscription_ids=subscription_ids, connection_id=client_id
        )
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "error": {"code": -32603, "message": str(e)},
            "id": request_id,
        }

    return {
        "jsonrpc": "2.0",
        "result": results,
        "id": request_id,
    }
2245
async def _handle_connection_close(self, client_id: str):
    """Drop all subscriptions owned by a closed WebSocket connection."""
    if not self.subscription_manager:
        return
    removed_count = await self.subscription_manager.cleanup_connection(client_id)
    if removed_count > 0:
        logger.info(
            f"Cleaned up {removed_count} subscriptions for client {client_id}"
        )
2256
def _compress_message(
    self, message: Dict[str, Any]
) -> Union[Dict[str, Any], bytes]:
    """Gzip-compress *message* when it is worthwhile.

    Args:
        message: The outbound message to potentially compress.

    Returns:
        The original dict when compression is disabled, the payload is
        below the size threshold, compression barely helps, or compression
        fails; otherwise an envelope dict carrying hex-encoded gzip data
        plus size metadata.
    """
    if not self.enable_websocket_compression:
        return message

    # Serialize compactly to measure the real wire size.
    payload = json.dumps(message, separators=(",", ":")).encode("utf-8")
    if len(payload) < self.compression_threshold:
        return message

    try:
        packed = gzip.compress(payload, compresslevel=self.compression_level)

        ratio = len(packed) / len(payload)
        # Skip compression when the saving is under ~10%.
        if ratio > 0.9:
            return message

        return {
            "__compressed": True,
            "__original_size": len(payload),
            "__compressed_size": len(packed),
            "__compression_ratio": ratio,
            # Hex so the envelope stays JSON-serializable.
            "data": packed.hex(),
        }

    except Exception as e:
        logger.warning(f"Failed to compress message: {e}")
        return message
2303
+ def _decompress_message(self, compressed_message: Dict[str, Any]) -> Dict[str, Any]:
2304
+ """Decompress a compressed message.
2305
+
2306
+ Args:
2307
+ compressed_message: The compressed message with metadata
2308
+
2309
+ Returns:
2310
+ The original decompressed message
2311
+ """
2312
+ if not compressed_message.get("__compressed"):
2313
+ return compressed_message
2314
+
2315
+ try:
2316
+ # Decode hex data and decompress
2317
+ compressed_data = bytes.fromhex(compressed_message["data"])
2318
+ decompressed_json = gzip.decompress(compressed_data)
2319
+
2320
+ # Parse back to dict
2321
+ return json.loads(decompressed_json.decode("utf-8"))
2322
+
2323
+ except Exception as e:
2324
+ logger.error(f"Failed to decompress message: {e}")
2325
+ # Return a sensible error message
2326
+ return {
2327
+ "jsonrpc": "2.0",
2328
+ "error": {
2329
+ "code": -32603,
2330
+ "message": f"Failed to decompress message: {e}",
2331
+ },
2332
+ }
2333
+
2334
+ async def _send_websocket_notification(
2335
+ self, client_id: str, notification: Dict[str, Any]
2336
+ ):
2337
+ """Send notification to WebSocket client with optional compression."""
2338
+ if self._transport and hasattr(self._transport, "send_message"):
2339
+ try:
2340
+ # Apply compression if enabled
2341
+ message_to_send = self._compress_message(notification)
2342
+
2343
+ # Log compression stats if compression was applied
2344
+ if isinstance(message_to_send, dict) and message_to_send.get(
2345
+ "__compressed"
2346
+ ):
2347
+ ratio = message_to_send["__compression_ratio"]
2348
+ logger.debug(
2349
+ f"Compressed notification for client {client_id}: "
2350
+ f"{message_to_send['__original_size']} -> "
2351
+ f"{message_to_send['__compressed_size']} bytes "
2352
+ f"({ratio:.2%} ratio)"
2353
+ )
2354
+
2355
+ await self._transport.send_message(message_to_send, client_id=client_id)
2356
+ logger.debug(
2357
+ f"Sent notification to client {client_id}: {notification['method']}"
2358
+ )
2359
+ except Exception as e:
2360
+ logger.error(f"Failed to send notification to client {client_id}: {e}")
2361
+
2362
+ async def _handle_logging_set_level(
2363
+ self, params: Dict[str, Any], request_id: Any
2364
+ ) -> Dict[str, Any]:
2365
+ """Handle logging/setLevel request to dynamically adjust log levels."""
2366
+ level = params.get("level", "INFO").upper()
2367
+
2368
+ # Validate log level
2369
+ valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
2370
+ if level not in valid_levels:
2371
+ return {
2372
+ "jsonrpc": "2.0",
2373
+ "error": {
2374
+ "code": -32602,
2375
+ "message": f"Invalid log level: {level}. Must be one of {valid_levels}",
2376
+ },
2377
+ "id": request_id,
2378
+ }
2379
+
2380
+ # Set the log level
2381
+ logging.getLogger().setLevel(getattr(logging, level))
2382
+ logger.info(f"Log level changed to {level}")
2383
+
2384
+ # Track in event store if available
2385
+ if self.event_store:
2386
+ from kailash.middleware.gateway.event_store import EventType
2387
+
2388
+ await self.event_store.append(
2389
+ event_type=EventType.REQUEST_COMPLETED,
2390
+ request_id=str(request_id),
2391
+ data={
2392
+ "type": "log_level_changed",
2393
+ "level": level,
2394
+ "timestamp": time.time(),
2395
+ "changed_by": params.get("client_id", "unknown"),
2396
+ },
2397
+ )
2398
+
2399
+ return {
2400
+ "jsonrpc": "2.0",
2401
+ "result": {"level": level, "levels": valid_levels},
2402
+ "id": request_id,
2403
+ }
2404
+
2405
+ async def _handle_roots_list(
2406
+ self, params: Dict[str, Any], request_id: Any
2407
+ ) -> Dict[str, Any]:
2408
+ """Handle roots/list request to get file system access roots."""
2409
+ protocol_mgr = get_protocol_manager()
2410
+
2411
+ # Check if client supports roots
2412
+ client_info = self.client_info.get(params.get("client_id", ""))
2413
+ if (
2414
+ not client_info.get("capabilities", {})
2415
+ .get("roots", {})
2416
+ .get("listChanged", False)
2417
+ ):
2418
+ return {
2419
+ "jsonrpc": "2.0",
2420
+ "error": {
2421
+ "code": -32601,
2422
+ "message": "Client does not support roots capability",
2423
+ },
2424
+ "id": request_id,
2425
+ }
2426
+
2427
+ roots = protocol_mgr.roots.list_roots()
2428
+
2429
+ # Apply access control if auth manager is available
2430
+ if self.auth_manager and params.get("client_id"):
2431
+ filtered_roots = []
2432
+ for root in roots:
2433
+ if await protocol_mgr.roots.validate_access(
2434
+ root["uri"],
2435
+ operation="list",
2436
+ user_context=self.client_info.get(params["client_id"], {}),
2437
+ ):
2438
+ filtered_roots.append(root)
2439
+ roots = filtered_roots
2440
+
2441
+ return {"jsonrpc": "2.0", "result": {"roots": roots}, "id": request_id}
2442
+
2443
+ async def _handle_completion_complete(
2444
+ self, params: Dict[str, Any], request_id: Any
2445
+ ) -> Dict[str, Any]:
2446
+ """Handle completion/complete request for auto-completion."""
2447
+ ref = params.get("ref", {})
2448
+ argument = params.get("argument", {})
2449
+
2450
+ # Extract completion parameters
2451
+ ref_type = ref.get("type") # "resource", "prompt", "tool"
2452
+ ref_name = ref.get("name") # Optional specific name
2453
+ partial_value = argument.get("value", "")
2454
+
2455
+ try:
2456
+ values = []
2457
+
2458
+ if ref_type == "resource":
2459
+ # Search through registered resources
2460
+ for uri, resource_info in self._resource_registry.items():
2461
+ if partial_value in uri: # Simple prefix/substring matching
2462
+ values.append(
2463
+ {
2464
+ "uri": uri,
2465
+ "name": resource_info.get("name", uri),
2466
+ "description": resource_info.get("description", ""),
2467
+ }
2468
+ )
2469
+
2470
+ elif ref_type == "prompt":
2471
+ # Search through registered prompts
2472
+ for name, prompt_info in self._prompt_registry.items():
2473
+ if partial_value in name: # Simple prefix/substring matching
2474
+ values.append(
2475
+ {
2476
+ "name": name,
2477
+ "description": prompt_info.get("description", ""),
2478
+ "arguments": prompt_info.get("arguments", []),
2479
+ }
2480
+ )
2481
+
2482
+ elif ref_type == "tool":
2483
+ # Search through registered tools
2484
+ for name, tool_info in self._tool_registry.items():
2485
+ if partial_value in name:
2486
+ values.append(
2487
+ {
2488
+ "name": name,
2489
+ "description": tool_info.get("description", ""),
2490
+ "inputSchema": tool_info.get("inputSchema", {}),
2491
+ }
2492
+ )
2493
+
2494
+ # Limit to 100 items and add hasMore flag if needed
2495
+ total_matches = len(values)
2496
+ has_more = total_matches > 100
2497
+ if has_more:
2498
+ values = values[:100]
2499
+
2500
+ result = {
2501
+ "completion": {
2502
+ "values": values,
2503
+ "total": total_matches,
2504
+ "hasMore": has_more,
2505
+ }
2506
+ }
2507
+
2508
+ return {"jsonrpc": "2.0", "result": result, "id": request_id}
2509
+
2510
+ except Exception as e:
2511
+ logger.error(f"Completion error: {e}")
2512
+ return {
2513
+ "jsonrpc": "2.0",
2514
+ "error": {"code": -32603, "message": f"Completion failed: {str(e)}"},
2515
+ "id": request_id,
2516
+ }
2517
+
2518
+ async def _handle_sampling_create_message(
2519
+ self, params: Dict[str, Any], request_id: Any
2520
+ ) -> Dict[str, Any]:
2521
+ """Handle sampling/createMessage - this is typically server-to-client."""
2522
+ # This is usually initiated by the server to request LLM sampling from the client
2523
+ # For server-side handling, we can validate and forward to connected clients
2524
+
2525
+ protocol_mgr = get_protocol_manager()
2526
+
2527
+ # Check if any client supports sampling
2528
+ sampling_clients = [
2529
+ client_id
2530
+ for client_id, info in self.client_info.items()
2531
+ if info.get("capabilities", {})
2532
+ .get("experimental", {})
2533
+ .get("sampling", False)
2534
+ ]
2535
+
2536
+ if not sampling_clients:
2537
+ return {
2538
+ "jsonrpc": "2.0",
2539
+ "error": {
2540
+ "code": -32601,
2541
+ "message": "No connected clients support sampling",
2542
+ },
2543
+ "id": request_id,
2544
+ }
2545
+
2546
+ # Create sampling request
2547
+ messages = params.get("messages", [])
2548
+ sampling_params = {
2549
+ "messages": messages,
2550
+ "model_preferences": params.get("modelPreferences"),
2551
+ "system_prompt": params.get("systemPrompt"),
2552
+ "temperature": params.get("temperature"),
2553
+ "max_tokens": params.get("maxTokens"),
2554
+ "metadata": params.get("metadata"),
2555
+ }
2556
+
2557
+ # Send to first available sampling client (or implement selection logic)
2558
+ target_client = sampling_clients[0]
2559
+
2560
+ # Create server-to-client request
2561
+ sampling_request = {
2562
+ "jsonrpc": "2.0",
2563
+ "method": "sampling/createMessage",
2564
+ "params": sampling_params,
2565
+ "id": f"sampling_{uuid.uuid4().hex[:8]}",
2566
+ }
2567
+
2568
+ # Send via WebSocket to client
2569
+ if self._transport and hasattr(self._transport, "send_message"):
2570
+ await self._transport.send_message(
2571
+ sampling_request, client_id=target_client
2572
+ )
2573
+
2574
+ # Store pending sampling request
2575
+ if not hasattr(self, "_pending_sampling_requests"):
2576
+ self._pending_sampling_requests = {}
2577
+
2578
+ self._pending_sampling_requests[sampling_request["id"]] = {
2579
+ "original_request_id": request_id,
2580
+ "client_id": params.get("client_id"),
2581
+ "timestamp": time.time(),
2582
+ }
2583
+
2584
+ return {
2585
+ "jsonrpc": "2.0",
2586
+ "result": {
2587
+ "status": "sampling_requested",
2588
+ "sampling_id": sampling_request["id"],
2589
+ "target_client": target_client,
2590
+ },
2591
+ "id": request_id,
2592
+ }
2593
+ else:
2594
+ return {
2595
+ "jsonrpc": "2.0",
2596
+ "error": {
2597
+ "code": -32603,
2598
+ "message": "Transport does not support sampling",
2599
+ },
2600
+ "id": request_id,
2601
+ }
2602
+
1530
2603
  async def run_stdio(self):
1531
2604
  """Run the server using stdio transport for testing."""
1532
2605
  if self._mcp is None: