kailash 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +4 -4
- kailash/mcp_server/protocol.py +26 -0
- kailash/mcp_server/server.py +746 -16
- kailash/mcp_server/subscriptions.py +1560 -0
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/METADATA +2 -1
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/RECORD +10 -9
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/WHEEL +0 -0
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/entry_points.txt +0 -0
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.8.6.dist-info → kailash-0.8.7.dist-info}/top_level.txt +0 -0
kailash/mcp_server/server.py
CHANGED
@@ -54,6 +54,8 @@ Enhanced Production Usage:
|
|
54
54
|
|
55
55
|
import asyncio
|
56
56
|
import functools
|
57
|
+
import gzip
|
58
|
+
import json
|
57
59
|
import logging
|
58
60
|
import time
|
59
61
|
import uuid
|
@@ -74,6 +76,7 @@ from .errors import (
|
|
74
76
|
RetryableOperation,
|
75
77
|
ToolError,
|
76
78
|
)
|
79
|
+
from .protocol import get_protocol_manager
|
77
80
|
from .utils import CacheManager, ConfigManager, MetricsCollector, format_response
|
78
81
|
|
79
82
|
logger = logging.getLogger(__name__)
|
@@ -369,6 +372,13 @@ class MCPServer:
|
|
369
372
|
transport_timeout: float = 30.0,
|
370
373
|
max_request_size: int = 10_000_000, # 10MB
|
371
374
|
enable_streaming: bool = False,
|
375
|
+
# Resource subscription configuration
|
376
|
+
enable_subscriptions: bool = True,
|
377
|
+
event_store=None,
|
378
|
+
# WebSocket compression configuration
|
379
|
+
enable_websocket_compression: bool = False,
|
380
|
+
compression_threshold: int = 1024, # Only compress messages larger than 1KB
|
381
|
+
compression_level: int = 6, # 1 (fastest) to 9 (best compression)
|
372
382
|
):
|
373
383
|
"""
|
374
384
|
Initialize enhanced MCP server.
|
@@ -396,6 +406,11 @@ class MCPServer:
|
|
396
406
|
transport_timeout: Transport timeout in seconds
|
397
407
|
max_request_size: Maximum request size in bytes
|
398
408
|
enable_streaming: Enable streaming support
|
409
|
+
enable_subscriptions: Enable resource subscriptions (default: True)
|
410
|
+
event_store: Optional event store for subscription logging
|
411
|
+
enable_websocket_compression: Enable gzip compression for WebSocket messages (default: False)
|
412
|
+
compression_threshold: Only compress messages larger than this size in bytes (default: 1024)
|
413
|
+
compression_level: Compression level from 1 (fastest) to 9 (best compression) (default: 6)
|
399
414
|
"""
|
400
415
|
self.name = name
|
401
416
|
|
@@ -404,6 +419,11 @@ class MCPServer:
|
|
404
419
|
self.websocket_host = websocket_host
|
405
420
|
self.websocket_port = websocket_port
|
406
421
|
|
422
|
+
# WebSocket compression configuration
|
423
|
+
self.enable_websocket_compression = enable_websocket_compression
|
424
|
+
self.compression_threshold = compression_threshold
|
425
|
+
self.compression_level = compression_level
|
426
|
+
|
407
427
|
# Enhanced features
|
408
428
|
self.auth_provider = auth_provider
|
409
429
|
self.enable_http_transport = enable_http_transport
|
@@ -515,6 +535,27 @@ class MCPServer:
|
|
515
535
|
self._resource_registry: Dict[str, Dict[str, Any]] = {}
|
516
536
|
self._prompt_registry: Dict[str, Dict[str, Any]] = {}
|
517
537
|
|
538
|
+
# Client management for new handlers
|
539
|
+
self.client_info: Dict[str, Dict[str, Any]] = {}
|
540
|
+
self._pending_sampling_requests: Dict[str, Dict[str, Any]] = {}
|
541
|
+
|
542
|
+
# Resource subscription support
|
543
|
+
self.enable_subscriptions = enable_subscriptions
|
544
|
+
self.event_store = event_store
|
545
|
+
self.subscription_manager = None
|
546
|
+
if self.enable_subscriptions:
|
547
|
+
from .subscriptions import ResourceSubscriptionManager
|
548
|
+
|
549
|
+
self.subscription_manager = ResourceSubscriptionManager(
|
550
|
+
auth_manager=(
|
551
|
+
self.auth_manager if hasattr(self, "auth_manager") else None
|
552
|
+
),
|
553
|
+
event_store=event_store,
|
554
|
+
rate_limiter=(
|
555
|
+
self.rate_limiter if hasattr(self, "rate_limiter") else None
|
556
|
+
),
|
557
|
+
)
|
558
|
+
|
518
559
|
# Transport instance (for WebSocket and other transports)
|
519
560
|
self._transport = None
|
520
561
|
|
@@ -1643,6 +1684,13 @@ class MCPServer:
|
|
1643
1684
|
f"WebSocket server started on {self.websocket_host}:{self.websocket_port}"
|
1644
1685
|
)
|
1645
1686
|
|
1687
|
+
# Set up subscription notification callback
|
1688
|
+
if self.subscription_manager:
|
1689
|
+
await self.subscription_manager.initialize()
|
1690
|
+
self.subscription_manager.set_notification_callback(
|
1691
|
+
self._send_websocket_notification
|
1692
|
+
)
|
1693
|
+
|
1646
1694
|
# Keep server running
|
1647
1695
|
try:
|
1648
1696
|
await asyncio.Future() # Run forever
|
@@ -1658,18 +1706,21 @@ class MCPServer:
|
|
1658
1706
|
async def _handle_websocket_message(
|
1659
1707
|
self, request: Dict[str, Any], client_id: str
|
1660
1708
|
) -> Dict[str, Any]:
|
1661
|
-
"""Handle incoming WebSocket message."""
|
1709
|
+
"""Handle incoming WebSocket message with decompression support."""
|
1662
1710
|
try:
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1711
|
+
# Decompress message if needed
|
1712
|
+
decompressed_request = self._decompress_message(request)
|
1713
|
+
|
1714
|
+
method = decompressed_request.get("method", "")
|
1715
|
+
params = decompressed_request.get("params", {})
|
1716
|
+
request_id = decompressed_request.get("id")
|
1666
1717
|
|
1667
1718
|
# Log request
|
1668
1719
|
logger.debug(f"WebSocket request from {client_id}: {method}")
|
1669
1720
|
|
1670
1721
|
# Route to appropriate handler
|
1671
1722
|
if method == "initialize":
|
1672
|
-
return await self._handle_initialize(params, request_id)
|
1723
|
+
return await self._handle_initialize(params, request_id, client_id)
|
1673
1724
|
elif method == "tools/list":
|
1674
1725
|
return await self._handle_list_tools(params, request_id)
|
1675
1726
|
elif method == "tools/call":
|
@@ -1677,11 +1728,35 @@ class MCPServer:
|
|
1677
1728
|
elif method == "resources/list":
|
1678
1729
|
return await self._handle_list_resources(params, request_id)
|
1679
1730
|
elif method == "resources/read":
|
1680
|
-
return await self._handle_read_resource(params, request_id)
|
1731
|
+
return await self._handle_read_resource(params, request_id, client_id)
|
1732
|
+
elif method == "resources/subscribe":
|
1733
|
+
return await self._handle_subscribe(params, request_id, client_id)
|
1734
|
+
elif method == "resources/unsubscribe":
|
1735
|
+
return await self._handle_unsubscribe(params, request_id, client_id)
|
1736
|
+
elif method == "resources/batch_subscribe":
|
1737
|
+
return await self._handle_batch_subscribe(params, request_id, client_id)
|
1738
|
+
elif method == "resources/batch_unsubscribe":
|
1739
|
+
return await self._handle_batch_unsubscribe(
|
1740
|
+
params, request_id, client_id
|
1741
|
+
)
|
1681
1742
|
elif method == "prompts/list":
|
1682
1743
|
return await self._handle_list_prompts(params, request_id)
|
1683
1744
|
elif method == "prompts/get":
|
1684
1745
|
return await self._handle_get_prompt(params, request_id)
|
1746
|
+
elif method == "logging/setLevel":
|
1747
|
+
return await self._handle_logging_set_level(params, request_id)
|
1748
|
+
elif method == "roots/list":
|
1749
|
+
# Add client_id to params for roots/list handler
|
1750
|
+
params_with_client = {**params, "client_id": client_id}
|
1751
|
+
return await self._handle_roots_list(params_with_client, request_id)
|
1752
|
+
elif method == "completion/complete":
|
1753
|
+
return await self._handle_completion_complete(params, request_id)
|
1754
|
+
elif method == "sampling/createMessage":
|
1755
|
+
# Add client_id to params for sampling handler
|
1756
|
+
params_with_client = {**params, "client_id": client_id}
|
1757
|
+
return await self._handle_sampling_create_message(
|
1758
|
+
params_with_client, request_id
|
1759
|
+
)
|
1685
1760
|
else:
|
1686
1761
|
return {
|
1687
1762
|
"jsonrpc": "2.0",
|
@@ -1698,17 +1773,42 @@ class MCPServer:
|
|
1698
1773
|
}
|
1699
1774
|
|
1700
1775
|
async def _handle_initialize(
|
1701
|
-
self, params: Dict[str, Any], request_id: Any
|
1776
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str = None
|
1702
1777
|
) -> Dict[str, Any]:
|
1703
1778
|
"""Handle initialize request."""
|
1779
|
+
# Store client information for capability checks
|
1780
|
+
if client_id:
|
1781
|
+
self.client_info[client_id] = {
|
1782
|
+
"capabilities": params.get("capabilities", {}),
|
1783
|
+
"name": params.get("clientInfo", {}).get("name", "unknown"),
|
1784
|
+
"version": params.get("clientInfo", {}).get("version", "unknown"),
|
1785
|
+
"initialized_at": time.time(),
|
1786
|
+
}
|
1787
|
+
|
1704
1788
|
return {
|
1705
1789
|
"jsonrpc": "2.0",
|
1706
1790
|
"result": {
|
1707
1791
|
"protocolVersion": "2024-11-05",
|
1708
1792
|
"capabilities": {
|
1709
1793
|
"tools": {"listSupported": True, "callSupported": True},
|
1710
|
-
"resources": {
|
1794
|
+
"resources": {
|
1795
|
+
"listSupported": True,
|
1796
|
+
"readSupported": True,
|
1797
|
+
"subscribe": self.enable_subscriptions,
|
1798
|
+
"listChanged": self.enable_subscriptions,
|
1799
|
+
"batch_subscribe": self.enable_subscriptions,
|
1800
|
+
"batch_unsubscribe": self.enable_subscriptions,
|
1801
|
+
},
|
1711
1802
|
"prompts": {"listSupported": True, "getSupported": True},
|
1803
|
+
"logging": {"setLevel": True},
|
1804
|
+
"roots": {"list": True},
|
1805
|
+
"experimental": {
|
1806
|
+
"progressNotifications": True,
|
1807
|
+
"cancellation": True,
|
1808
|
+
"completion": True,
|
1809
|
+
"sampling": True,
|
1810
|
+
"websocketCompression": self.enable_websocket_compression,
|
1811
|
+
},
|
1712
1812
|
},
|
1713
1813
|
"serverInfo": {
|
1714
1814
|
"name": self.name,
|
@@ -1764,10 +1864,14 @@ class MCPServer:
|
|
1764
1864
|
async def _handle_list_resources(
|
1765
1865
|
self, params: Dict[str, Any], request_id: Any
|
1766
1866
|
) -> Dict[str, Any]:
|
1767
|
-
"""Handle resources/list request."""
|
1768
|
-
|
1867
|
+
"""Handle resources/list request with cursor-based pagination."""
|
1868
|
+
cursor = params.get("cursor")
|
1869
|
+
limit = params.get("limit")
|
1870
|
+
|
1871
|
+
# Get all resources
|
1872
|
+
all_resources = []
|
1769
1873
|
for uri, info in self._resource_registry.items():
|
1770
|
-
|
1874
|
+
all_resources.append(
|
1771
1875
|
{
|
1772
1876
|
"uri": uri,
|
1773
1877
|
"name": info.get("name", uri),
|
@@ -1776,15 +1880,66 @@ class MCPServer:
|
|
1776
1880
|
}
|
1777
1881
|
)
|
1778
1882
|
|
1883
|
+
# Handle pagination if subscription manager is available
|
1884
|
+
if self.subscription_manager:
|
1885
|
+
cursor_manager = self.subscription_manager.cursor_manager
|
1886
|
+
|
1887
|
+
# Determine starting position
|
1888
|
+
start_pos = 0
|
1889
|
+
if cursor:
|
1890
|
+
if cursor_manager.is_valid(cursor):
|
1891
|
+
start_pos = cursor_manager.get_cursor_position(cursor) or 0
|
1892
|
+
else:
|
1893
|
+
return {
|
1894
|
+
"jsonrpc": "2.0",
|
1895
|
+
"error": {
|
1896
|
+
"code": -32602,
|
1897
|
+
"message": "Invalid or expired cursor",
|
1898
|
+
},
|
1899
|
+
"id": request_id,
|
1900
|
+
}
|
1901
|
+
|
1902
|
+
# Apply pagination
|
1903
|
+
if limit:
|
1904
|
+
end_pos = start_pos + limit
|
1905
|
+
resources = all_resources[start_pos:end_pos]
|
1906
|
+
|
1907
|
+
# Generate next cursor if there are more resources
|
1908
|
+
next_cursor = None
|
1909
|
+
if end_pos < len(all_resources):
|
1910
|
+
next_cursor = cursor_manager.create_cursor_for_position(
|
1911
|
+
all_resources, end_pos
|
1912
|
+
)
|
1913
|
+
|
1914
|
+
result = {"resources": resources}
|
1915
|
+
if next_cursor:
|
1916
|
+
result["nextCursor"] = next_cursor
|
1917
|
+
|
1918
|
+
return {"jsonrpc": "2.0", "result": result, "id": request_id}
|
1919
|
+
else:
|
1920
|
+
resources = all_resources[start_pos:]
|
1921
|
+
else:
|
1922
|
+
# No pagination support
|
1923
|
+
resources = all_resources
|
1924
|
+
|
1779
1925
|
return {"jsonrpc": "2.0", "result": {"resources": resources}, "id": request_id}
|
1780
1926
|
|
1781
1927
|
async def _handle_read_resource(
|
1782
|
-
self, params: Dict[str, Any], request_id: Any
|
1928
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str = None
|
1783
1929
|
) -> Dict[str, Any]:
|
1784
|
-
"""Handle resources/read request."""
|
1930
|
+
"""Handle resources/read request with change detection."""
|
1785
1931
|
uri = params.get("uri")
|
1786
1932
|
|
1787
|
-
|
1933
|
+
# First try exact match
|
1934
|
+
resource_info = None
|
1935
|
+
resource_params = {}
|
1936
|
+
if uri in self._resource_registry:
|
1937
|
+
resource_info = self._resource_registry[uri]
|
1938
|
+
else:
|
1939
|
+
# Try template matching
|
1940
|
+
resource_info, resource_params = self._match_resource_template(uri)
|
1941
|
+
|
1942
|
+
if resource_info is None:
|
1788
1943
|
return {
|
1789
1944
|
"jsonrpc": "2.0",
|
1790
1945
|
"error": {"code": -32602, "message": f"Resource not found: {uri}"},
|
@@ -1792,16 +1947,38 @@ class MCPServer:
|
|
1792
1947
|
}
|
1793
1948
|
|
1794
1949
|
try:
|
1795
|
-
resource_info = self._resource_registry[uri]
|
1796
1950
|
handler = resource_info.get("handler")
|
1951
|
+
original_handler = resource_info.get("original_handler")
|
1797
1952
|
|
1798
1953
|
if handler:
|
1799
|
-
|
1954
|
+
# Use original handler with parameters if available
|
1955
|
+
if original_handler and resource_params:
|
1956
|
+
content = original_handler(**resource_params)
|
1957
|
+
else:
|
1958
|
+
content = handler()
|
1800
1959
|
if asyncio.iscoroutine(content):
|
1801
1960
|
content = await content
|
1802
1961
|
else:
|
1803
1962
|
content = ""
|
1804
1963
|
|
1964
|
+
# Process change detection if subscription manager is available
|
1965
|
+
if self.subscription_manager:
|
1966
|
+
resource_data = {
|
1967
|
+
"uri": uri,
|
1968
|
+
"text": str(content),
|
1969
|
+
"mimeType": resource_info.get("mime_type", "text/plain"),
|
1970
|
+
}
|
1971
|
+
|
1972
|
+
# Check for changes and notify subscribers
|
1973
|
+
change = (
|
1974
|
+
await self.subscription_manager.resource_monitor.check_for_changes(
|
1975
|
+
uri, resource_data
|
1976
|
+
)
|
1977
|
+
)
|
1978
|
+
|
1979
|
+
if change:
|
1980
|
+
await self.subscription_manager.process_resource_change(change)
|
1981
|
+
|
1805
1982
|
return {
|
1806
1983
|
"jsonrpc": "2.0",
|
1807
1984
|
"result": {"contents": [{"uri": uri, "text": str(content)}]},
|
@@ -1814,6 +1991,24 @@ class MCPServer:
|
|
1814
1991
|
"id": request_id,
|
1815
1992
|
}
|
1816
1993
|
|
1994
|
+
def _match_resource_template(self, uri: str) -> tuple:
|
1995
|
+
"""Match URI against resource templates and extract parameters."""
|
1996
|
+
import re
|
1997
|
+
|
1998
|
+
for template_uri, resource_info in self._resource_registry.items():
|
1999
|
+
# Convert template to regex pattern
|
2000
|
+
# Replace {param} with named capture groups
|
2001
|
+
pattern = re.sub(r"\{([^}]+)\}", r"(?P<\1>[^/]+)", template_uri)
|
2002
|
+
pattern = f"^{pattern}$"
|
2003
|
+
|
2004
|
+
match = re.match(pattern, uri)
|
2005
|
+
if match:
|
2006
|
+
# Extract parameters from the match
|
2007
|
+
params = match.groupdict()
|
2008
|
+
return resource_info, params
|
2009
|
+
|
2010
|
+
return None, {}
|
2011
|
+
|
1817
2012
|
async def _handle_list_prompts(
|
1818
2013
|
self, params: Dict[str, Any], request_id: Any
|
1819
2014
|
) -> Dict[str, Any]:
|
@@ -1870,6 +2065,541 @@ class MCPServer:
|
|
1870
2065
|
"id": request_id,
|
1871
2066
|
}
|
1872
2067
|
|
2068
|
+
async def _handle_subscribe(
|
2069
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2070
|
+
) -> Dict[str, Any]:
|
2071
|
+
"""Handle resources/subscribe request with GraphQL-style field selection."""
|
2072
|
+
if not self.subscription_manager:
|
2073
|
+
return {
|
2074
|
+
"jsonrpc": "2.0",
|
2075
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2076
|
+
"id": request_id,
|
2077
|
+
}
|
2078
|
+
|
2079
|
+
uri_pattern = params.get("uri")
|
2080
|
+
cursor = params.get("cursor")
|
2081
|
+
# Extract field selection parameters for GraphQL-style filtering
|
2082
|
+
fields = params.get("fields") # e.g., ["uri", "content.text", "metadata.size"]
|
2083
|
+
fragments = params.get("fragments") # e.g., {"basicInfo": ["uri", "name"]}
|
2084
|
+
|
2085
|
+
if not uri_pattern:
|
2086
|
+
return {
|
2087
|
+
"jsonrpc": "2.0",
|
2088
|
+
"error": {"code": -32602, "message": "Missing required parameter: uri"},
|
2089
|
+
"id": request_id,
|
2090
|
+
}
|
2091
|
+
|
2092
|
+
try:
|
2093
|
+
# Create subscription with auth context and field selection
|
2094
|
+
user_context = {"user_id": client_id, "connection_id": client_id}
|
2095
|
+
subscription_id = await self.subscription_manager.create_subscription(
|
2096
|
+
connection_id=client_id,
|
2097
|
+
uri_pattern=uri_pattern,
|
2098
|
+
cursor=cursor,
|
2099
|
+
user_context=user_context,
|
2100
|
+
fields=fields,
|
2101
|
+
fragments=fragments,
|
2102
|
+
)
|
2103
|
+
|
2104
|
+
return {
|
2105
|
+
"jsonrpc": "2.0",
|
2106
|
+
"result": {"subscriptionId": subscription_id},
|
2107
|
+
"id": request_id,
|
2108
|
+
}
|
2109
|
+
except Exception as e:
|
2110
|
+
error_code = -32603
|
2111
|
+
if "permission" in str(e).lower() or "not authorized" in str(e).lower():
|
2112
|
+
error_code = -32601
|
2113
|
+
elif "rate limit" in str(e).lower():
|
2114
|
+
error_code = -32601
|
2115
|
+
|
2116
|
+
return {
|
2117
|
+
"jsonrpc": "2.0",
|
2118
|
+
"error": {"code": error_code, "message": str(e)},
|
2119
|
+
"id": request_id,
|
2120
|
+
}
|
2121
|
+
|
2122
|
+
async def _handle_unsubscribe(
|
2123
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2124
|
+
) -> Dict[str, Any]:
|
2125
|
+
"""Handle resources/unsubscribe request."""
|
2126
|
+
if not self.subscription_manager:
|
2127
|
+
return {
|
2128
|
+
"jsonrpc": "2.0",
|
2129
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2130
|
+
"id": request_id,
|
2131
|
+
}
|
2132
|
+
|
2133
|
+
subscription_id = params.get("subscriptionId")
|
2134
|
+
|
2135
|
+
if not subscription_id:
|
2136
|
+
return {
|
2137
|
+
"jsonrpc": "2.0",
|
2138
|
+
"error": {
|
2139
|
+
"code": -32602,
|
2140
|
+
"message": "Missing required parameter: subscriptionId",
|
2141
|
+
},
|
2142
|
+
"id": request_id,
|
2143
|
+
}
|
2144
|
+
|
2145
|
+
try:
|
2146
|
+
success = await self.subscription_manager.remove_subscription(
|
2147
|
+
subscription_id, client_id
|
2148
|
+
)
|
2149
|
+
|
2150
|
+
return {
|
2151
|
+
"jsonrpc": "2.0",
|
2152
|
+
"result": {"success": success},
|
2153
|
+
"id": request_id,
|
2154
|
+
}
|
2155
|
+
except Exception as e:
|
2156
|
+
return {
|
2157
|
+
"jsonrpc": "2.0",
|
2158
|
+
"error": {"code": -32603, "message": str(e)},
|
2159
|
+
"id": request_id,
|
2160
|
+
}
|
2161
|
+
|
2162
|
+
async def _handle_batch_subscribe(
|
2163
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2164
|
+
) -> Dict[str, Any]:
|
2165
|
+
"""Handle resources/batch_subscribe request."""
|
2166
|
+
if not self.subscription_manager:
|
2167
|
+
return {
|
2168
|
+
"jsonrpc": "2.0",
|
2169
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2170
|
+
"id": request_id,
|
2171
|
+
}
|
2172
|
+
|
2173
|
+
subscriptions = params.get("subscriptions")
|
2174
|
+
if not subscriptions or not isinstance(subscriptions, list):
|
2175
|
+
return {
|
2176
|
+
"jsonrpc": "2.0",
|
2177
|
+
"error": {
|
2178
|
+
"code": -32602,
|
2179
|
+
"message": "Missing or invalid parameter: subscriptions",
|
2180
|
+
},
|
2181
|
+
"id": request_id,
|
2182
|
+
}
|
2183
|
+
|
2184
|
+
try:
|
2185
|
+
# Create batch subscriptions with auth context
|
2186
|
+
user_context = {"user_id": client_id, "connection_id": client_id}
|
2187
|
+
results = await self.subscription_manager.create_batch_subscriptions(
|
2188
|
+
subscriptions=subscriptions,
|
2189
|
+
connection_id=client_id,
|
2190
|
+
user_context=user_context,
|
2191
|
+
)
|
2192
|
+
|
2193
|
+
return {
|
2194
|
+
"jsonrpc": "2.0",
|
2195
|
+
"result": results,
|
2196
|
+
"id": request_id,
|
2197
|
+
}
|
2198
|
+
except Exception as e:
|
2199
|
+
return {
|
2200
|
+
"jsonrpc": "2.0",
|
2201
|
+
"error": {"code": -32603, "message": str(e)},
|
2202
|
+
"id": request_id,
|
2203
|
+
}
|
2204
|
+
|
2205
|
+
async def _handle_batch_unsubscribe(
|
2206
|
+
self, params: Dict[str, Any], request_id: Any, client_id: str
|
2207
|
+
) -> Dict[str, Any]:
|
2208
|
+
"""Handle resources/batch_unsubscribe request."""
|
2209
|
+
if not self.subscription_manager:
|
2210
|
+
return {
|
2211
|
+
"jsonrpc": "2.0",
|
2212
|
+
"error": {"code": -32601, "message": "Subscriptions not enabled"},
|
2213
|
+
"id": request_id,
|
2214
|
+
}
|
2215
|
+
|
2216
|
+
subscription_ids = params.get("subscriptionIds")
|
2217
|
+
if not subscription_ids or not isinstance(subscription_ids, list):
|
2218
|
+
return {
|
2219
|
+
"jsonrpc": "2.0",
|
2220
|
+
"error": {
|
2221
|
+
"code": -32602,
|
2222
|
+
"message": "Missing or invalid parameter: subscriptionIds",
|
2223
|
+
},
|
2224
|
+
"id": request_id,
|
2225
|
+
}
|
2226
|
+
|
2227
|
+
try:
|
2228
|
+
# Remove batch subscriptions
|
2229
|
+
results = await self.subscription_manager.remove_batch_subscriptions(
|
2230
|
+
subscription_ids=subscription_ids, connection_id=client_id
|
2231
|
+
)
|
2232
|
+
|
2233
|
+
return {
|
2234
|
+
"jsonrpc": "2.0",
|
2235
|
+
"result": results,
|
2236
|
+
"id": request_id,
|
2237
|
+
}
|
2238
|
+
except Exception as e:
|
2239
|
+
return {
|
2240
|
+
"jsonrpc": "2.0",
|
2241
|
+
"error": {"code": -32603, "message": str(e)},
|
2242
|
+
"id": request_id,
|
2243
|
+
}
|
2244
|
+
|
2245
|
+
async def _handle_connection_close(self, client_id: str):
|
2246
|
+
"""Handle WebSocket connection close."""
|
2247
|
+
if self.subscription_manager:
|
2248
|
+
removed_count = await self.subscription_manager.cleanup_connection(
|
2249
|
+
client_id
|
2250
|
+
)
|
2251
|
+
if removed_count > 0:
|
2252
|
+
logger.info(
|
2253
|
+
f"Cleaned up {removed_count} subscriptions for client {client_id}"
|
2254
|
+
)
|
2255
|
+
|
2256
|
+
def _compress_message(
|
2257
|
+
self, message: Dict[str, Any]
|
2258
|
+
) -> Union[Dict[str, Any], bytes]:
|
2259
|
+
"""Compress message if compression is enabled and message exceeds threshold.
|
2260
|
+
|
2261
|
+
Args:
|
2262
|
+
message: The message to potentially compress
|
2263
|
+
|
2264
|
+
Returns:
|
2265
|
+
Either the original dict or compressed bytes with metadata
|
2266
|
+
"""
|
2267
|
+
if not self.enable_websocket_compression:
|
2268
|
+
return message
|
2269
|
+
|
2270
|
+
# Serialize message to determine size
|
2271
|
+
message_json = json.dumps(message, separators=(",", ":")).encode("utf-8")
|
2272
|
+
|
2273
|
+
# Only compress if message exceeds threshold
|
2274
|
+
if len(message_json) < self.compression_threshold:
|
2275
|
+
return message
|
2276
|
+
|
2277
|
+
try:
|
2278
|
+
# Compress the message
|
2279
|
+
compressed_data = gzip.compress(
|
2280
|
+
message_json, compresslevel=self.compression_level
|
2281
|
+
)
|
2282
|
+
|
2283
|
+
# Calculate compression ratio
|
2284
|
+
compression_ratio = len(compressed_data) / len(message_json)
|
2285
|
+
|
2286
|
+
# Only use compression if it actually reduces size significantly
|
2287
|
+
if compression_ratio > 0.9: # Less than 10% improvement
|
2288
|
+
return message
|
2289
|
+
|
2290
|
+
# Return compressed message with metadata
|
2291
|
+
return {
|
2292
|
+
"__compressed": True,
|
2293
|
+
"__original_size": len(message_json),
|
2294
|
+
"__compressed_size": len(compressed_data),
|
2295
|
+
"__compression_ratio": compression_ratio,
|
2296
|
+
"data": compressed_data.hex(), # Hex encode for JSON transport
|
2297
|
+
}
|
2298
|
+
|
2299
|
+
except Exception as e:
|
2300
|
+
logger.warning(f"Failed to compress message: {e}")
|
2301
|
+
return message
|
2302
|
+
|
2303
|
+
def _decompress_message(self, compressed_message: Dict[str, Any]) -> Dict[str, Any]:
|
2304
|
+
"""Decompress a compressed message.
|
2305
|
+
|
2306
|
+
Args:
|
2307
|
+
compressed_message: The compressed message with metadata
|
2308
|
+
|
2309
|
+
Returns:
|
2310
|
+
The original decompressed message
|
2311
|
+
"""
|
2312
|
+
if not compressed_message.get("__compressed"):
|
2313
|
+
return compressed_message
|
2314
|
+
|
2315
|
+
try:
|
2316
|
+
# Decode hex data and decompress
|
2317
|
+
compressed_data = bytes.fromhex(compressed_message["data"])
|
2318
|
+
decompressed_json = gzip.decompress(compressed_data)
|
2319
|
+
|
2320
|
+
# Parse back to dict
|
2321
|
+
return json.loads(decompressed_json.decode("utf-8"))
|
2322
|
+
|
2323
|
+
except Exception as e:
|
2324
|
+
logger.error(f"Failed to decompress message: {e}")
|
2325
|
+
# Return a sensible error message
|
2326
|
+
return {
|
2327
|
+
"jsonrpc": "2.0",
|
2328
|
+
"error": {
|
2329
|
+
"code": -32603,
|
2330
|
+
"message": f"Failed to decompress message: {e}",
|
2331
|
+
},
|
2332
|
+
}
|
2333
|
+
|
2334
|
+
async def _send_websocket_notification(
|
2335
|
+
self, client_id: str, notification: Dict[str, Any]
|
2336
|
+
):
|
2337
|
+
"""Send notification to WebSocket client with optional compression."""
|
2338
|
+
if self._transport and hasattr(self._transport, "send_message"):
|
2339
|
+
try:
|
2340
|
+
# Apply compression if enabled
|
2341
|
+
message_to_send = self._compress_message(notification)
|
2342
|
+
|
2343
|
+
# Log compression stats if compression was applied
|
2344
|
+
if isinstance(message_to_send, dict) and message_to_send.get(
|
2345
|
+
"__compressed"
|
2346
|
+
):
|
2347
|
+
ratio = message_to_send["__compression_ratio"]
|
2348
|
+
logger.debug(
|
2349
|
+
f"Compressed notification for client {client_id}: "
|
2350
|
+
f"{message_to_send['__original_size']} -> "
|
2351
|
+
f"{message_to_send['__compressed_size']} bytes "
|
2352
|
+
f"({ratio:.2%} ratio)"
|
2353
|
+
)
|
2354
|
+
|
2355
|
+
await self._transport.send_message(message_to_send, client_id=client_id)
|
2356
|
+
logger.debug(
|
2357
|
+
f"Sent notification to client {client_id}: {notification['method']}"
|
2358
|
+
)
|
2359
|
+
except Exception as e:
|
2360
|
+
logger.error(f"Failed to send notification to client {client_id}: {e}")
|
2361
|
+
|
2362
|
+
async def _handle_logging_set_level(
|
2363
|
+
self, params: Dict[str, Any], request_id: Any
|
2364
|
+
) -> Dict[str, Any]:
|
2365
|
+
"""Handle logging/setLevel request to dynamically adjust log levels."""
|
2366
|
+
level = params.get("level", "INFO").upper()
|
2367
|
+
|
2368
|
+
# Validate log level
|
2369
|
+
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
2370
|
+
if level not in valid_levels:
|
2371
|
+
return {
|
2372
|
+
"jsonrpc": "2.0",
|
2373
|
+
"error": {
|
2374
|
+
"code": -32602,
|
2375
|
+
"message": f"Invalid log level: {level}. Must be one of {valid_levels}",
|
2376
|
+
},
|
2377
|
+
"id": request_id,
|
2378
|
+
}
|
2379
|
+
|
2380
|
+
# Set the log level
|
2381
|
+
logging.getLogger().setLevel(getattr(logging, level))
|
2382
|
+
logger.info(f"Log level changed to {level}")
|
2383
|
+
|
2384
|
+
# Track in event store if available
|
2385
|
+
if self.event_store:
|
2386
|
+
from kailash.middleware.gateway.event_store import EventType
|
2387
|
+
|
2388
|
+
await self.event_store.append(
|
2389
|
+
event_type=EventType.REQUEST_COMPLETED,
|
2390
|
+
request_id=str(request_id),
|
2391
|
+
data={
|
2392
|
+
"type": "log_level_changed",
|
2393
|
+
"level": level,
|
2394
|
+
"timestamp": time.time(),
|
2395
|
+
"changed_by": params.get("client_id", "unknown"),
|
2396
|
+
},
|
2397
|
+
)
|
2398
|
+
|
2399
|
+
return {
|
2400
|
+
"jsonrpc": "2.0",
|
2401
|
+
"result": {"level": level, "levels": valid_levels},
|
2402
|
+
"id": request_id,
|
2403
|
+
}
|
2404
|
+
|
2405
|
+
async def _handle_roots_list(
|
2406
|
+
self, params: Dict[str, Any], request_id: Any
|
2407
|
+
) -> Dict[str, Any]:
|
2408
|
+
"""Handle roots/list request to get file system access roots."""
|
2409
|
+
protocol_mgr = get_protocol_manager()
|
2410
|
+
|
2411
|
+
# Check if client supports roots
|
2412
|
+
client_info = self.client_info.get(params.get("client_id", ""))
|
2413
|
+
if (
|
2414
|
+
not client_info.get("capabilities", {})
|
2415
|
+
.get("roots", {})
|
2416
|
+
.get("listChanged", False)
|
2417
|
+
):
|
2418
|
+
return {
|
2419
|
+
"jsonrpc": "2.0",
|
2420
|
+
"error": {
|
2421
|
+
"code": -32601,
|
2422
|
+
"message": "Client does not support roots capability",
|
2423
|
+
},
|
2424
|
+
"id": request_id,
|
2425
|
+
}
|
2426
|
+
|
2427
|
+
roots = protocol_mgr.roots.list_roots()
|
2428
|
+
|
2429
|
+
# Apply access control if auth manager is available
|
2430
|
+
if self.auth_manager and params.get("client_id"):
|
2431
|
+
filtered_roots = []
|
2432
|
+
for root in roots:
|
2433
|
+
if await protocol_mgr.roots.validate_access(
|
2434
|
+
root["uri"],
|
2435
|
+
operation="list",
|
2436
|
+
user_context=self.client_info.get(params["client_id"], {}),
|
2437
|
+
):
|
2438
|
+
filtered_roots.append(root)
|
2439
|
+
roots = filtered_roots
|
2440
|
+
|
2441
|
+
return {"jsonrpc": "2.0", "result": {"roots": roots}, "id": request_id}
|
2442
|
+
|
2443
|
+
async def _handle_completion_complete(
|
2444
|
+
self, params: Dict[str, Any], request_id: Any
|
2445
|
+
) -> Dict[str, Any]:
|
2446
|
+
"""Handle completion/complete request for auto-completion."""
|
2447
|
+
ref = params.get("ref", {})
|
2448
|
+
argument = params.get("argument", {})
|
2449
|
+
|
2450
|
+
# Extract completion parameters
|
2451
|
+
ref_type = ref.get("type") # "resource", "prompt", "tool"
|
2452
|
+
ref_name = ref.get("name") # Optional specific name
|
2453
|
+
partial_value = argument.get("value", "")
|
2454
|
+
|
2455
|
+
try:
|
2456
|
+
values = []
|
2457
|
+
|
2458
|
+
if ref_type == "resource":
|
2459
|
+
# Search through registered resources
|
2460
|
+
for uri, resource_info in self._resource_registry.items():
|
2461
|
+
if partial_value in uri: # Simple prefix/substring matching
|
2462
|
+
values.append(
|
2463
|
+
{
|
2464
|
+
"uri": uri,
|
2465
|
+
"name": resource_info.get("name", uri),
|
2466
|
+
"description": resource_info.get("description", ""),
|
2467
|
+
}
|
2468
|
+
)
|
2469
|
+
|
2470
|
+
elif ref_type == "prompt":
|
2471
|
+
# Search through registered prompts
|
2472
|
+
for name, prompt_info in self._prompt_registry.items():
|
2473
|
+
if partial_value in name: # Simple prefix/substring matching
|
2474
|
+
values.append(
|
2475
|
+
{
|
2476
|
+
"name": name,
|
2477
|
+
"description": prompt_info.get("description", ""),
|
2478
|
+
"arguments": prompt_info.get("arguments", []),
|
2479
|
+
}
|
2480
|
+
)
|
2481
|
+
|
2482
|
+
elif ref_type == "tool":
|
2483
|
+
# Search through registered tools
|
2484
|
+
for name, tool_info in self._tool_registry.items():
|
2485
|
+
if partial_value in name:
|
2486
|
+
values.append(
|
2487
|
+
{
|
2488
|
+
"name": name,
|
2489
|
+
"description": tool_info.get("description", ""),
|
2490
|
+
"inputSchema": tool_info.get("inputSchema", {}),
|
2491
|
+
}
|
2492
|
+
)
|
2493
|
+
|
2494
|
+
# Limit to 100 items and add hasMore flag if needed
|
2495
|
+
total_matches = len(values)
|
2496
|
+
has_more = total_matches > 100
|
2497
|
+
if has_more:
|
2498
|
+
values = values[:100]
|
2499
|
+
|
2500
|
+
result = {
|
2501
|
+
"completion": {
|
2502
|
+
"values": values,
|
2503
|
+
"total": total_matches,
|
2504
|
+
"hasMore": has_more,
|
2505
|
+
}
|
2506
|
+
}
|
2507
|
+
|
2508
|
+
return {"jsonrpc": "2.0", "result": result, "id": request_id}
|
2509
|
+
|
2510
|
+
except Exception as e:
|
2511
|
+
logger.error(f"Completion error: {e}")
|
2512
|
+
return {
|
2513
|
+
"jsonrpc": "2.0",
|
2514
|
+
"error": {"code": -32603, "message": f"Completion failed: {str(e)}"},
|
2515
|
+
"id": request_id,
|
2516
|
+
}
|
2517
|
+
|
2518
|
+
async def _handle_sampling_create_message(
|
2519
|
+
self, params: Dict[str, Any], request_id: Any
|
2520
|
+
) -> Dict[str, Any]:
|
2521
|
+
"""Handle sampling/createMessage - this is typically server-to-client."""
|
2522
|
+
# This is usually initiated by the server to request LLM sampling from the client
|
2523
|
+
# For server-side handling, we can validate and forward to connected clients
|
2524
|
+
|
2525
|
+
protocol_mgr = get_protocol_manager()
|
2526
|
+
|
2527
|
+
# Check if any client supports sampling
|
2528
|
+
sampling_clients = [
|
2529
|
+
client_id
|
2530
|
+
for client_id, info in self.client_info.items()
|
2531
|
+
if info.get("capabilities", {})
|
2532
|
+
.get("experimental", {})
|
2533
|
+
.get("sampling", False)
|
2534
|
+
]
|
2535
|
+
|
2536
|
+
if not sampling_clients:
|
2537
|
+
return {
|
2538
|
+
"jsonrpc": "2.0",
|
2539
|
+
"error": {
|
2540
|
+
"code": -32601,
|
2541
|
+
"message": "No connected clients support sampling",
|
2542
|
+
},
|
2543
|
+
"id": request_id,
|
2544
|
+
}
|
2545
|
+
|
2546
|
+
# Create sampling request
|
2547
|
+
messages = params.get("messages", [])
|
2548
|
+
sampling_params = {
|
2549
|
+
"messages": messages,
|
2550
|
+
"model_preferences": params.get("modelPreferences"),
|
2551
|
+
"system_prompt": params.get("systemPrompt"),
|
2552
|
+
"temperature": params.get("temperature"),
|
2553
|
+
"max_tokens": params.get("maxTokens"),
|
2554
|
+
"metadata": params.get("metadata"),
|
2555
|
+
}
|
2556
|
+
|
2557
|
+
# Send to first available sampling client (or implement selection logic)
|
2558
|
+
target_client = sampling_clients[0]
|
2559
|
+
|
2560
|
+
# Create server-to-client request
|
2561
|
+
sampling_request = {
|
2562
|
+
"jsonrpc": "2.0",
|
2563
|
+
"method": "sampling/createMessage",
|
2564
|
+
"params": sampling_params,
|
2565
|
+
"id": f"sampling_{uuid.uuid4().hex[:8]}",
|
2566
|
+
}
|
2567
|
+
|
2568
|
+
# Send via WebSocket to client
|
2569
|
+
if self._transport and hasattr(self._transport, "send_message"):
|
2570
|
+
await self._transport.send_message(
|
2571
|
+
sampling_request, client_id=target_client
|
2572
|
+
)
|
2573
|
+
|
2574
|
+
# Store pending sampling request
|
2575
|
+
if not hasattr(self, "_pending_sampling_requests"):
|
2576
|
+
self._pending_sampling_requests = {}
|
2577
|
+
|
2578
|
+
self._pending_sampling_requests[sampling_request["id"]] = {
|
2579
|
+
"original_request_id": request_id,
|
2580
|
+
"client_id": params.get("client_id"),
|
2581
|
+
"timestamp": time.time(),
|
2582
|
+
}
|
2583
|
+
|
2584
|
+
return {
|
2585
|
+
"jsonrpc": "2.0",
|
2586
|
+
"result": {
|
2587
|
+
"status": "sampling_requested",
|
2588
|
+
"sampling_id": sampling_request["id"],
|
2589
|
+
"target_client": target_client,
|
2590
|
+
},
|
2591
|
+
"id": request_id,
|
2592
|
+
}
|
2593
|
+
else:
|
2594
|
+
return {
|
2595
|
+
"jsonrpc": "2.0",
|
2596
|
+
"error": {
|
2597
|
+
"code": -32603,
|
2598
|
+
"message": "Transport does not support sampling",
|
2599
|
+
},
|
2600
|
+
"id": request_id,
|
2601
|
+
}
|
2602
|
+
|
1873
2603
|
async def run_stdio(self):
|
1874
2604
|
"""Run the server using stdio transport for testing."""
|
1875
2605
|
if self._mcp is None:
|