kailash 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -54,6 +54,8 @@ Enhanced Production Usage:
 
 import asyncio
 import functools
+import gzip
+import json
 import logging
 import time
 import uuid
@@ -74,6 +76,7 @@ from .errors import (
     RetryableOperation,
     ToolError,
 )
+from .protocol import get_protocol_manager
 from .utils import CacheManager, ConfigManager, MetricsCollector, format_response
 
 logger = logging.getLogger(__name__)
@@ -369,6 +372,13 @@ class MCPServer:
         transport_timeout: float = 30.0,
         max_request_size: int = 10_000_000,  # 10MB
         enable_streaming: bool = False,
+        # Resource subscription configuration
+        enable_subscriptions: bool = True,
+        event_store=None,
+        # WebSocket compression configuration
+        enable_websocket_compression: bool = False,
+        compression_threshold: int = 1024,  # Only compress messages larger than 1KB
+        compression_level: int = 6,  # 1 (fastest) to 9 (best compression)
     ):
         """
         Initialize enhanced MCP server.
@@ -396,6 +406,11 @@ class MCPServer:
             transport_timeout: Transport timeout in seconds
             max_request_size: Maximum request size in bytes
             enable_streaming: Enable streaming support
+            enable_subscriptions: Enable resource subscriptions (default: True)
+            event_store: Optional event store for subscription logging
+            enable_websocket_compression: Enable gzip compression for WebSocket messages (default: False)
+            compression_threshold: Only compress messages larger than this size in bytes (default: 1024)
+            compression_level: Compression level from 1 (fastest) to 9 (best compression) (default: 6)
         """
         self.name = name
 
@@ -404,6 +419,11 @@ class MCPServer:
         self.websocket_host = websocket_host
         self.websocket_port = websocket_port
 
+        # WebSocket compression configuration
+        self.enable_websocket_compression = enable_websocket_compression
+        self.compression_threshold = compression_threshold
+        self.compression_level = compression_level
+
         # Enhanced features
         self.auth_provider = auth_provider
         self.enable_http_transport = enable_http_transport
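Taken together, the new constructor parameters above make resource subscriptions and WebSocket compression opt-in per server instance. A minimal usage sketch follows; the import path and the positional name argument are assumptions for illustration, and only the keyword parameters come from this diff:

```python
# Hypothetical usage; the import path is an assumption, parameter names are from the diff.
from kailash.mcp_server import MCPServer  # assumed module path

server = MCPServer(
    "example-server",                   # assumed server name argument
    enable_subscriptions=True,          # resource subscriptions (new, default True)
    enable_websocket_compression=True,  # gzip-compress large WebSocket messages
    compression_threshold=1024,         # only compress payloads above 1 KB
    compression_level=6,                # gzip level: 1 (fastest) to 9 (best compression)
)
```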
@@ -515,6 +535,27 @@ class MCPServer:
         self._resource_registry: Dict[str, Dict[str, Any]] = {}
         self._prompt_registry: Dict[str, Dict[str, Any]] = {}
 
+        # Client management for new handlers
+        self.client_info: Dict[str, Dict[str, Any]] = {}
+        self._pending_sampling_requests: Dict[str, Dict[str, Any]] = {}
+
+        # Resource subscription support
+        self.enable_subscriptions = enable_subscriptions
+        self.event_store = event_store
+        self.subscription_manager = None
+        if self.enable_subscriptions:
+            from .subscriptions import ResourceSubscriptionManager
+
+            self.subscription_manager = ResourceSubscriptionManager(
+                auth_manager=(
+                    self.auth_manager if hasattr(self, "auth_manager") else None
+                ),
+                event_store=event_store,
+                rate_limiter=(
+                    self.rate_limiter if hasattr(self, "rate_limiter") else None
+                ),
+            )
+
         # Transport instance (for WebSocket and other transports)
         self._transport = None
 
@@ -1643,6 +1684,13 @@ class MCPServer:
             f"WebSocket server started on {self.websocket_host}:{self.websocket_port}"
         )
 
+        # Set up subscription notification callback
+        if self.subscription_manager:
+            await self.subscription_manager.initialize()
+            self.subscription_manager.set_notification_callback(
+                self._send_websocket_notification
+            )
+
         # Keep server running
         try:
             await asyncio.Future()  # Run forever
@@ -1658,18 +1706,21 @@ class MCPServer:
     async def _handle_websocket_message(
         self, request: Dict[str, Any], client_id: str
     ) -> Dict[str, Any]:
-        """Handle incoming WebSocket message."""
+        """Handle incoming WebSocket message with decompression support."""
         try:
-            method = request.get("method", "")
-            params = request.get("params", {})
-            request_id = request.get("id")
+            # Decompress message if needed
+            decompressed_request = self._decompress_message(request)
+
+            method = decompressed_request.get("method", "")
+            params = decompressed_request.get("params", {})
+            request_id = decompressed_request.get("id")
 
             # Log request
             logger.debug(f"WebSocket request from {client_id}: {method}")
 
             # Route to appropriate handler
             if method == "initialize":
-                return await self._handle_initialize(params, request_id)
+                return await self._handle_initialize(params, request_id, client_id)
             elif method == "tools/list":
                 return await self._handle_list_tools(params, request_id)
             elif method == "tools/call":
@@ -1677,11 +1728,35 @@ class MCPServer:
             elif method == "resources/list":
                 return await self._handle_list_resources(params, request_id)
             elif method == "resources/read":
-                return await self._handle_read_resource(params, request_id)
+                return await self._handle_read_resource(params, request_id, client_id)
+            elif method == "resources/subscribe":
+                return await self._handle_subscribe(params, request_id, client_id)
+            elif method == "resources/unsubscribe":
+                return await self._handle_unsubscribe(params, request_id, client_id)
+            elif method == "resources/batch_subscribe":
+                return await self._handle_batch_subscribe(params, request_id, client_id)
+            elif method == "resources/batch_unsubscribe":
+                return await self._handle_batch_unsubscribe(
+                    params, request_id, client_id
+                )
             elif method == "prompts/list":
                 return await self._handle_list_prompts(params, request_id)
             elif method == "prompts/get":
                 return await self._handle_get_prompt(params, request_id)
+            elif method == "logging/setLevel":
+                return await self._handle_logging_set_level(params, request_id)
+            elif method == "roots/list":
+                # Add client_id to params for roots/list handler
+                params_with_client = {**params, "client_id": client_id}
+                return await self._handle_roots_list(params_with_client, request_id)
+            elif method == "completion/complete":
+                return await self._handle_completion_complete(params, request_id)
+            elif method == "sampling/createMessage":
+                # Add client_id to params for sampling handler
+                params_with_client = {**params, "client_id": client_id}
+                return await self._handle_sampling_create_message(
+                    params_with_client, request_id
+                )
             else:
                 return {
                     "jsonrpc": "2.0",
@@ -1698,17 +1773,42 @@ class MCPServer:
             }
 
     async def _handle_initialize(
-        self, params: Dict[str, Any], request_id: Any
+        self, params: Dict[str, Any], request_id: Any, client_id: str = None
     ) -> Dict[str, Any]:
         """Handle initialize request."""
+        # Store client information for capability checks
+        if client_id:
+            self.client_info[client_id] = {
+                "capabilities": params.get("capabilities", {}),
+                "name": params.get("clientInfo", {}).get("name", "unknown"),
+                "version": params.get("clientInfo", {}).get("version", "unknown"),
+                "initialized_at": time.time(),
+            }
+
         return {
             "jsonrpc": "2.0",
             "result": {
                 "protocolVersion": "2024-11-05",
                 "capabilities": {
                     "tools": {"listSupported": True, "callSupported": True},
-                    "resources": {"listSupported": True, "readSupported": True},
+                    "resources": {
+                        "listSupported": True,
+                        "readSupported": True,
+                        "subscribe": self.enable_subscriptions,
+                        "listChanged": self.enable_subscriptions,
+                        "batch_subscribe": self.enable_subscriptions,
+                        "batch_unsubscribe": self.enable_subscriptions,
+                    },
                     "prompts": {"listSupported": True, "getSupported": True},
+                    "logging": {"setLevel": True},
+                    "roots": {"list": True},
+                    "experimental": {
+                        "progressNotifications": True,
+                        "cancellation": True,
+                        "completion": True,
+                        "sampling": True,
+                        "websocketCompression": self.enable_websocket_compression,
+                    },
                 },
                 "serverInfo": {
                     "name": self.name,
@@ -1764,10 +1864,14 @@ class MCPServer:
     async def _handle_list_resources(
         self, params: Dict[str, Any], request_id: Any
    ) -> Dict[str, Any]:
-        """Handle resources/list request."""
-        resources = []
+        """Handle resources/list request with cursor-based pagination."""
+        cursor = params.get("cursor")
+        limit = params.get("limit")
+
+        # Get all resources
+        all_resources = []
         for uri, info in self._resource_registry.items():
-            resources.append(
+            all_resources.append(
                 {
                     "uri": uri,
                     "name": info.get("name", uri),
@@ -1776,15 +1880,66 @@ class MCPServer:
                 }
             )
 
+        # Handle pagination if subscription manager is available
+        if self.subscription_manager:
+            cursor_manager = self.subscription_manager.cursor_manager
+
+            # Determine starting position
+            start_pos = 0
+            if cursor:
+                if cursor_manager.is_valid(cursor):
+                    start_pos = cursor_manager.get_cursor_position(cursor) or 0
+                else:
+                    return {
+                        "jsonrpc": "2.0",
+                        "error": {
+                            "code": -32602,
+                            "message": "Invalid or expired cursor",
+                        },
+                        "id": request_id,
+                    }
+
+            # Apply pagination
+            if limit:
+                end_pos = start_pos + limit
+                resources = all_resources[start_pos:end_pos]
+
+                # Generate next cursor if there are more resources
+                next_cursor = None
+                if end_pos < len(all_resources):
+                    next_cursor = cursor_manager.create_cursor_for_position(
+                        all_resources, end_pos
+                    )
+
+                result = {"resources": resources}
+                if next_cursor:
+                    result["nextCursor"] = next_cursor
+
+                return {"jsonrpc": "2.0", "result": result, "id": request_id}
+            else:
+                resources = all_resources[start_pos:]
+        else:
+            # No pagination support
+            resources = all_resources
+
         return {"jsonrpc": "2.0", "result": {"resources": resources}, "id": request_id}
 
     async def _handle_read_resource(
-        self, params: Dict[str, Any], request_id: Any
+        self, params: Dict[str, Any], request_id: Any, client_id: str = None
     ) -> Dict[str, Any]:
-        """Handle resources/read request."""
+        """Handle resources/read request with change detection."""
         uri = params.get("uri")
 
-        if uri not in self._resource_registry:
+        # First try exact match
+        resource_info = None
+        resource_params = {}
+        if uri in self._resource_registry:
+            resource_info = self._resource_registry[uri]
+        else:
+            # Try template matching
+            resource_info, resource_params = self._match_resource_template(uri)
+
+        if resource_info is None:
            return {
                "jsonrpc": "2.0",
                "error": {"code": -32602, "message": f"Resource not found: {uri}"},
@@ -1792,16 +1947,38 @@ class MCPServer:
             }
 
         try:
-            resource_info = self._resource_registry[uri]
             handler = resource_info.get("handler")
+            original_handler = resource_info.get("original_handler")
 
             if handler:
-                content = handler()
+                # Use original handler with parameters if available
+                if original_handler and resource_params:
+                    content = original_handler(**resource_params)
+                else:
+                    content = handler()
                 if asyncio.iscoroutine(content):
                     content = await content
             else:
                 content = ""
 
+            # Process change detection if subscription manager is available
+            if self.subscription_manager:
+                resource_data = {
+                    "uri": uri,
+                    "text": str(content),
+                    "mimeType": resource_info.get("mime_type", "text/plain"),
+                }
+
+                # Check for changes and notify subscribers
+                change = (
+                    await self.subscription_manager.resource_monitor.check_for_changes(
+                        uri, resource_data
+                    )
+                )
+
+                if change:
+                    await self.subscription_manager.process_resource_change(change)
+
             return {
                 "jsonrpc": "2.0",
                 "result": {"contents": [{"uri": uri, "text": str(content)}]},
@@ -1814,6 +1991,24 @@ class MCPServer:
                 "id": request_id,
             }
 
+    def _match_resource_template(self, uri: str) -> tuple:
+        """Match URI against resource templates and extract parameters."""
+        import re
+
+        for template_uri, resource_info in self._resource_registry.items():
+            # Convert template to regex pattern
+            # Replace {param} with named capture groups
+            pattern = re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", template_uri)
+            pattern = f"^{pattern}$"
+
+            match = re.match(pattern, uri)
+            if match:
+                # Extract parameters from the match
+                params = match.groupdict()
+                return resource_info, params
+
+        return None, {}
+
     async def _handle_list_prompts(
         self, params: Dict[str, Any], request_id: Any
     ) -> Dict[str, Any]:
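The _match_resource_template helper added above turns {param} placeholders in registry keys into named regex groups, so template URIs match concrete URIs and yield the extracted parameters. A standalone sketch of the same transformation (the example URIs are illustrative, not from the package):

```python
import re
from typing import Optional

def match_template(template_uri: str, uri: str) -> Optional[dict]:
    """Mirror of the diff's template matching: {param} becomes a named capture group."""
    pattern = "^" + re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", template_uri) + "$"
    match = re.match(pattern, uri)
    return match.groupdict() if match else None

print(match_template("resource://users/{user_id}", "resource://users/42"))
# -> {'user_id': '42'}
```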
@@ -1870,6 +2065,541 @@ class MCPServer:
                 "id": request_id,
             }
 
+    async def _handle_subscribe(
+        self, params: Dict[str, Any], request_id: Any, client_id: str
+    ) -> Dict[str, Any]:
+        """Handle resources/subscribe request with GraphQL-style field selection."""
+        if not self.subscription_manager:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32601, "message": "Subscriptions not enabled"},
+                "id": request_id,
+            }
+
+        uri_pattern = params.get("uri")
+        cursor = params.get("cursor")
+        # Extract field selection parameters for GraphQL-style filtering
+        fields = params.get("fields")  # e.g., ["uri", "content.text", "metadata.size"]
+        fragments = params.get("fragments")  # e.g., {"basicInfo": ["uri", "name"]}
+
+        if not uri_pattern:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32602, "message": "Missing required parameter: uri"},
+                "id": request_id,
+            }
+
+        try:
+            # Create subscription with auth context and field selection
+            user_context = {"user_id": client_id, "connection_id": client_id}
+            subscription_id = await self.subscription_manager.create_subscription(
+                connection_id=client_id,
+                uri_pattern=uri_pattern,
+                cursor=cursor,
+                user_context=user_context,
+                fields=fields,
+                fragments=fragments,
+            )
+
+            return {
+                "jsonrpc": "2.0",
+                "result": {"subscriptionId": subscription_id},
+                "id": request_id,
+            }
+        except Exception as e:
+            error_code = -32603
+            if "permission" in str(e).lower() or "not authorized" in str(e).lower():
+                error_code = -32601
+            elif "rate limit" in str(e).lower():
+                error_code = -32601
+
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": error_code, "message": str(e)},
+                "id": request_id,
+            }
+
+    async def _handle_unsubscribe(
+        self, params: Dict[str, Any], request_id: Any, client_id: str
+    ) -> Dict[str, Any]:
+        """Handle resources/unsubscribe request."""
+        if not self.subscription_manager:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32601, "message": "Subscriptions not enabled"},
+                "id": request_id,
+            }
+
+        subscription_id = params.get("subscriptionId")
+
+        if not subscription_id:
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32602,
+                    "message": "Missing required parameter: subscriptionId",
+                },
+                "id": request_id,
+            }
+
+        try:
+            success = await self.subscription_manager.remove_subscription(
+                subscription_id, client_id
+            )
+
+            return {
+                "jsonrpc": "2.0",
+                "result": {"success": success},
+                "id": request_id,
+            }
+        except Exception as e:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32603, "message": str(e)},
+                "id": request_id,
+            }
+
+    async def _handle_batch_subscribe(
+        self, params: Dict[str, Any], request_id: Any, client_id: str
+    ) -> Dict[str, Any]:
+        """Handle resources/batch_subscribe request."""
+        if not self.subscription_manager:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32601, "message": "Subscriptions not enabled"},
+                "id": request_id,
+            }
+
+        subscriptions = params.get("subscriptions")
+        if not subscriptions or not isinstance(subscriptions, list):
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32602,
+                    "message": "Missing or invalid parameter: subscriptions",
+                },
+                "id": request_id,
+            }
+
+        try:
+            # Create batch subscriptions with auth context
+            user_context = {"user_id": client_id, "connection_id": client_id}
+            results = await self.subscription_manager.create_batch_subscriptions(
+                subscriptions=subscriptions,
+                connection_id=client_id,
+                user_context=user_context,
+            )
+
+            return {
+                "jsonrpc": "2.0",
+                "result": results,
+                "id": request_id,
+            }
+        except Exception as e:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32603, "message": str(e)},
+                "id": request_id,
+            }
+
+    async def _handle_batch_unsubscribe(
+        self, params: Dict[str, Any], request_id: Any, client_id: str
+    ) -> Dict[str, Any]:
+        """Handle resources/batch_unsubscribe request."""
+        if not self.subscription_manager:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32601, "message": "Subscriptions not enabled"},
+                "id": request_id,
+            }
+
+        subscription_ids = params.get("subscriptionIds")
+        if not subscription_ids or not isinstance(subscription_ids, list):
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32602,
+                    "message": "Missing or invalid parameter: subscriptionIds",
+                },
+                "id": request_id,
+            }
+
+        try:
+            # Remove batch subscriptions
+            results = await self.subscription_manager.remove_batch_subscriptions(
+                subscription_ids=subscription_ids, connection_id=client_id
+            )
+
+            return {
+                "jsonrpc": "2.0",
+                "result": results,
+                "id": request_id,
+            }
+        except Exception as e:
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32603, "message": str(e)},
+                "id": request_id,
+            }
+
+    async def _handle_connection_close(self, client_id: str):
+        """Handle WebSocket connection close."""
+        if self.subscription_manager:
+            removed_count = await self.subscription_manager.cleanup_connection(
+                client_id
+            )
+            if removed_count > 0:
+                logger.info(
+                    f"Cleaned up {removed_count} subscriptions for client {client_id}"
+                )
+
+    def _compress_message(
+        self, message: Dict[str, Any]
+    ) -> Union[Dict[str, Any], bytes]:
+        """Compress message if compression is enabled and message exceeds threshold.
+
+        Args:
+            message: The message to potentially compress
+
+        Returns:
+            Either the original dict or compressed bytes with metadata
+        """
+        if not self.enable_websocket_compression:
+            return message
+
+        # Serialize message to determine size
+        message_json = json.dumps(message, separators=(",", ":")).encode("utf-8")
+
+        # Only compress if message exceeds threshold
+        if len(message_json) < self.compression_threshold:
+            return message
+
+        try:
+            # Compress the message
+            compressed_data = gzip.compress(
+                message_json, compresslevel=self.compression_level
+            )
+
+            # Calculate compression ratio
+            compression_ratio = len(compressed_data) / len(message_json)
+
+            # Only use compression if it actually reduces size significantly
+            if compression_ratio > 0.9:  # Less than 10% improvement
+                return message
+
+            # Return compressed message with metadata
+            return {
+                "__compressed": True,
+                "__original_size": len(message_json),
+                "__compressed_size": len(compressed_data),
+                "__compression_ratio": compression_ratio,
+                "data": compressed_data.hex(),  # Hex encode for JSON transport
+            }
+
+        except Exception as e:
+            logger.warning(f"Failed to compress message: {e}")
+            return message
+
+    def _decompress_message(self, compressed_message: Dict[str, Any]) -> Dict[str, Any]:
+        """Decompress a compressed message.
+
+        Args:
+            compressed_message: The compressed message with metadata
+
+        Returns:
+            The original decompressed message
+        """
+        if not compressed_message.get("__compressed"):
+            return compressed_message
+
+        try:
+            # Decode hex data and decompress
+            compressed_data = bytes.fromhex(compressed_message["data"])
+            decompressed_json = gzip.decompress(compressed_data)
+
+            # Parse back to dict
+            return json.loads(decompressed_json.decode("utf-8"))
+
+        except Exception as e:
+            logger.error(f"Failed to decompress message: {e}")
+            # Return a sensible error message
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32603,
+                    "message": f"Failed to decompress message: {e}",
+                },
+            }
+
+    async def _send_websocket_notification(
+        self, client_id: str, notification: Dict[str, Any]
+    ):
+        """Send notification to WebSocket client with optional compression."""
+        if self._transport and hasattr(self._transport, "send_message"):
+            try:
+                # Apply compression if enabled
+                message_to_send = self._compress_message(notification)
+
+                # Log compression stats if compression was applied
+                if isinstance(message_to_send, dict) and message_to_send.get(
+                    "__compressed"
+                ):
+                    ratio = message_to_send["__compression_ratio"]
+                    logger.debug(
+                        f"Compressed notification for client {client_id}: "
+                        f"{message_to_send['__original_size']} -> "
+                        f"{message_to_send['__compressed_size']} bytes "
+                        f"({ratio:.2%} ratio)"
+                    )
+
+                await self._transport.send_message(message_to_send, client_id=client_id)
+                logger.debug(
+                    f"Sent notification to client {client_id}: {notification['method']}"
+                )
+            except Exception as e:
+                logger.error(f"Failed to send notification to client {client_id}: {e}")
+
+    async def _handle_logging_set_level(
+        self, params: Dict[str, Any], request_id: Any
+    ) -> Dict[str, Any]:
+        """Handle logging/setLevel request to dynamically adjust log levels."""
+        level = params.get("level", "INFO").upper()
+
+        # Validate log level
+        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+        if level not in valid_levels:
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32602,
+                    "message": f"Invalid log level: {level}. Must be one of {valid_levels}",
+                },
+                "id": request_id,
+            }
+
+        # Set the log level
+        logging.getLogger().setLevel(getattr(logging, level))
+        logger.info(f"Log level changed to {level}")
+
+        # Track in event store if available
+        if self.event_store:
+            from kailash.middleware.gateway.event_store import EventType
+
+            await self.event_store.append(
+                event_type=EventType.REQUEST_COMPLETED,
+                request_id=str(request_id),
+                data={
+                    "type": "log_level_changed",
+                    "level": level,
+                    "timestamp": time.time(),
+                    "changed_by": params.get("client_id", "unknown"),
+                },
+            )
+
+        return {
+            "jsonrpc": "2.0",
+            "result": {"level": level, "levels": valid_levels},
+            "id": request_id,
+        }
+
+    async def _handle_roots_list(
+        self, params: Dict[str, Any], request_id: Any
+    ) -> Dict[str, Any]:
+        """Handle roots/list request to get file system access roots."""
+        protocol_mgr = get_protocol_manager()
+
+        # Check if client supports roots
+        client_info = self.client_info.get(params.get("client_id", ""))
+        if (
+            not client_info.get("capabilities", {})
+            .get("roots", {})
+            .get("listChanged", False)
+        ):
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32601,
+                    "message": "Client does not support roots capability",
+                },
+                "id": request_id,
+            }
+
+        roots = protocol_mgr.roots.list_roots()
+
+        # Apply access control if auth manager is available
+        if self.auth_manager and params.get("client_id"):
+            filtered_roots = []
+            for root in roots:
+                if await protocol_mgr.roots.validate_access(
+                    root["uri"],
+                    operation="list",
+                    user_context=self.client_info.get(params["client_id"], {}),
+                ):
+                    filtered_roots.append(root)
+            roots = filtered_roots
+
+        return {"jsonrpc": "2.0", "result": {"roots": roots}, "id": request_id}
+
+    async def _handle_completion_complete(
+        self, params: Dict[str, Any], request_id: Any
+    ) -> Dict[str, Any]:
+        """Handle completion/complete request for auto-completion."""
+        ref = params.get("ref", {})
+        argument = params.get("argument", {})
+
+        # Extract completion parameters
+        ref_type = ref.get("type")  # "resource", "prompt", "tool"
+        ref_name = ref.get("name")  # Optional specific name
+        partial_value = argument.get("value", "")
+
+        try:
+            values = []
+
+            if ref_type == "resource":
+                # Search through registered resources
+                for uri, resource_info in self._resource_registry.items():
+                    if partial_value in uri:  # Simple prefix/substring matching
+                        values.append(
+                            {
+                                "uri": uri,
+                                "name": resource_info.get("name", uri),
+                                "description": resource_info.get("description", ""),
+                            }
+                        )
+
+            elif ref_type == "prompt":
+                # Search through registered prompts
+                for name, prompt_info in self._prompt_registry.items():
+                    if partial_value in name:  # Simple prefix/substring matching
+                        values.append(
+                            {
+                                "name": name,
+                                "description": prompt_info.get("description", ""),
+                                "arguments": prompt_info.get("arguments", []),
+                            }
+                        )
+
+            elif ref_type == "tool":
+                # Search through registered tools
+                for name, tool_info in self._tool_registry.items():
+                    if partial_value in name:
+                        values.append(
+                            {
+                                "name": name,
+                                "description": tool_info.get("description", ""),
+                                "inputSchema": tool_info.get("inputSchema", {}),
+                            }
+                        )
+
+            # Limit to 100 items and add hasMore flag if needed
+            total_matches = len(values)
+            has_more = total_matches > 100
+            if has_more:
+                values = values[:100]
+
+            result = {
+                "completion": {
+                    "values": values,
+                    "total": total_matches,
+                    "hasMore": has_more,
+                }
+            }
+
+            return {"jsonrpc": "2.0", "result": result, "id": request_id}
+
+        except Exception as e:
+            logger.error(f"Completion error: {e}")
+            return {
+                "jsonrpc": "2.0",
+                "error": {"code": -32603, "message": f"Completion failed: {str(e)}"},
+                "id": request_id,
+            }
+
+    async def _handle_sampling_create_message(
+        self, params: Dict[str, Any], request_id: Any
+    ) -> Dict[str, Any]:
+        """Handle sampling/createMessage - this is typically server-to-client."""
+        # This is usually initiated by the server to request LLM sampling from the client
+        # For server-side handling, we can validate and forward to connected clients
+
+        protocol_mgr = get_protocol_manager()
+
+        # Check if any client supports sampling
+        sampling_clients = [
+            client_id
+            for client_id, info in self.client_info.items()
+            if info.get("capabilities", {})
+            .get("experimental", {})
+            .get("sampling", False)
+        ]
+
+        if not sampling_clients:
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32601,
+                    "message": "No connected clients support sampling",
+                },
+                "id": request_id,
+            }
+
+        # Create sampling request
+        messages = params.get("messages", [])
+        sampling_params = {
+            "messages": messages,
+            "model_preferences": params.get("modelPreferences"),
+            "system_prompt": params.get("systemPrompt"),
+            "temperature": params.get("temperature"),
+            "max_tokens": params.get("maxTokens"),
+            "metadata": params.get("metadata"),
+        }
+
+        # Send to first available sampling client (or implement selection logic)
+        target_client = sampling_clients[0]
+
+        # Create server-to-client request
+        sampling_request = {
+            "jsonrpc": "2.0",
+            "method": "sampling/createMessage",
+            "params": sampling_params,
+            "id": f"sampling_{uuid.uuid4().hex[:8]}",
+        }
+
+        # Send via WebSocket to client
+        if self._transport and hasattr(self._transport, "send_message"):
+            await self._transport.send_message(
+                sampling_request, client_id=target_client
+            )
+
+            # Store pending sampling request
+            if not hasattr(self, "_pending_sampling_requests"):
+                self._pending_sampling_requests = {}
+
+            self._pending_sampling_requests[sampling_request["id"]] = {
+                "original_request_id": request_id,
+                "client_id": params.get("client_id"),
+                "timestamp": time.time(),
+            }
+
+            return {
+                "jsonrpc": "2.0",
+                "result": {
+                    "status": "sampling_requested",
+                    "sampling_id": sampling_request["id"],
+                    "target_client": target_client,
+                },
+                "id": request_id,
+            }
+        else:
+            return {
+                "jsonrpc": "2.0",
+                "error": {
+                    "code": -32603,
+                    "message": "Transport does not support sampling",
+                },
+                "id": request_id,
+            }
+
     async def run_stdio(self):
         """Run the server using stdio transport for testing."""
         if self._mcp is None:
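The compression envelope introduced in this release is plain gzip over the serialized JSON body, hex-encoded so the envelope itself remains valid JSON. A self-contained round trip mirroring _compress_message/_decompress_message (threshold and level are hard-coded here; the envelope keys match the diff, everything else is illustrative):

```python
import gzip
import json

def compress_envelope(message: dict, threshold: int = 1024, level: int = 6) -> dict:
    """Gzip-compress a JSON-RPC message into the hex-encoded envelope used by the server."""
    raw = json.dumps(message, separators=(",", ":")).encode("utf-8")
    if len(raw) < threshold:
        return message  # small messages travel uncompressed
    data = gzip.compress(raw, compresslevel=level)
    if len(data) / len(raw) > 0.9:
        return message  # compression did not help enough, send as-is
    return {
        "__compressed": True,
        "__original_size": len(raw),
        "__compressed_size": len(data),
        "__compression_ratio": len(data) / len(raw),
        "data": data.hex(),  # hex so the envelope stays JSON-serializable
    }

def decompress_envelope(envelope: dict) -> dict:
    """Reverse of compress_envelope; uncompressed messages pass through unchanged."""
    if not envelope.get("__compressed"):
        return envelope
    return json.loads(gzip.decompress(bytes.fromhex(envelope["data"])))

# Round trip with a payload large enough to cross the threshold:
message = {"jsonrpc": "2.0", "method": "notifications/resources/updated", "params": {"text": "x" * 4000}}
assert decompress_envelope(compress_envelope(message)) == message
```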