kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +3 -3
- kailash/api/custom_nodes_secure.py +3 -3
- kailash/api/gateway.py +1 -1
- kailash/api/studio.py +1 -1
- kailash/api/workflow_api.py +2 -2
- kailash/core/resilience/bulkhead.py +475 -0
- kailash/core/resilience/circuit_breaker.py +92 -10
- kailash/core/resilience/health_monitor.py +578 -0
- kailash/edge/discovery.py +86 -0
- kailash/mcp_server/__init__.py +309 -33
- kailash/mcp_server/advanced_features.py +1022 -0
- kailash/mcp_server/ai_registry_server.py +27 -2
- kailash/mcp_server/auth.py +789 -0
- kailash/mcp_server/client.py +645 -378
- kailash/mcp_server/discovery.py +1593 -0
- kailash/mcp_server/errors.py +673 -0
- kailash/mcp_server/oauth.py +1727 -0
- kailash/mcp_server/protocol.py +1126 -0
- kailash/mcp_server/registry_integration.py +587 -0
- kailash/mcp_server/server.py +1228 -96
- kailash/mcp_server/transports.py +1169 -0
- kailash/mcp_server/utils/__init__.py +6 -1
- kailash/mcp_server/utils/cache.py +250 -7
- kailash/middleware/auth/auth_manager.py +3 -3
- kailash/middleware/communication/api_gateway.py +1 -1
- kailash/middleware/communication/realtime.py +1 -1
- kailash/middleware/mcp/enhanced_server.py +1 -1
- kailash/nodes/__init__.py +2 -0
- kailash/nodes/admin/audit_log.py +6 -6
- kailash/nodes/admin/permission_check.py +8 -8
- kailash/nodes/admin/role_management.py +32 -28
- kailash/nodes/admin/schema.sql +6 -1
- kailash/nodes/admin/schema_manager.py +13 -13
- kailash/nodes/admin/security_event.py +15 -15
- kailash/nodes/admin/tenant_isolation.py +3 -3
- kailash/nodes/admin/transaction_utils.py +3 -3
- kailash/nodes/admin/user_management.py +21 -21
- kailash/nodes/ai/a2a.py +11 -11
- kailash/nodes/ai/ai_providers.py +9 -12
- kailash/nodes/ai/embedding_generator.py +13 -14
- kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
- kailash/nodes/ai/iterative_llm_agent.py +2 -2
- kailash/nodes/ai/llm_agent.py +210 -33
- kailash/nodes/ai/self_organizing.py +2 -2
- kailash/nodes/alerts/discord.py +4 -4
- kailash/nodes/api/graphql.py +6 -6
- kailash/nodes/api/http.py +10 -10
- kailash/nodes/api/rate_limiting.py +4 -4
- kailash/nodes/api/rest.py +15 -15
- kailash/nodes/auth/mfa.py +3 -3
- kailash/nodes/auth/risk_assessment.py +2 -2
- kailash/nodes/auth/session_management.py +5 -5
- kailash/nodes/auth/sso.py +143 -0
- kailash/nodes/base.py +8 -2
- kailash/nodes/base_async.py +16 -2
- kailash/nodes/base_with_acl.py +2 -2
- kailash/nodes/cache/__init__.py +9 -0
- kailash/nodes/cache/cache.py +1172 -0
- kailash/nodes/cache/cache_invalidation.py +874 -0
- kailash/nodes/cache/redis_pool_manager.py +595 -0
- kailash/nodes/code/async_python.py +2 -1
- kailash/nodes/code/python.py +194 -30
- kailash/nodes/compliance/data_retention.py +6 -6
- kailash/nodes/compliance/gdpr.py +5 -5
- kailash/nodes/data/__init__.py +10 -0
- kailash/nodes/data/async_sql.py +1956 -129
- kailash/nodes/data/optimistic_locking.py +906 -0
- kailash/nodes/data/readers.py +8 -8
- kailash/nodes/data/redis.py +378 -0
- kailash/nodes/data/sql.py +314 -3
- kailash/nodes/data/streaming.py +21 -0
- kailash/nodes/enterprise/__init__.py +8 -0
- kailash/nodes/enterprise/audit_logger.py +285 -0
- kailash/nodes/enterprise/batch_processor.py +22 -3
- kailash/nodes/enterprise/data_lineage.py +1 -1
- kailash/nodes/enterprise/mcp_executor.py +205 -0
- kailash/nodes/enterprise/service_discovery.py +150 -0
- kailash/nodes/enterprise/tenant_assignment.py +108 -0
- kailash/nodes/logic/async_operations.py +2 -2
- kailash/nodes/logic/convergence.py +1 -1
- kailash/nodes/logic/operations.py +1 -1
- kailash/nodes/monitoring/__init__.py +11 -1
- kailash/nodes/monitoring/health_check.py +456 -0
- kailash/nodes/monitoring/log_processor.py +817 -0
- kailash/nodes/monitoring/metrics_collector.py +627 -0
- kailash/nodes/monitoring/performance_benchmark.py +137 -11
- kailash/nodes/rag/advanced.py +7 -7
- kailash/nodes/rag/agentic.py +49 -2
- kailash/nodes/rag/conversational.py +3 -3
- kailash/nodes/rag/evaluation.py +3 -3
- kailash/nodes/rag/federated.py +3 -3
- kailash/nodes/rag/graph.py +3 -3
- kailash/nodes/rag/multimodal.py +3 -3
- kailash/nodes/rag/optimized.py +5 -5
- kailash/nodes/rag/privacy.py +3 -3
- kailash/nodes/rag/query_processing.py +6 -6
- kailash/nodes/rag/realtime.py +1 -1
- kailash/nodes/rag/registry.py +1 -1
- kailash/nodes/rag/router.py +1 -1
- kailash/nodes/rag/similarity.py +7 -7
- kailash/nodes/rag/strategies.py +4 -4
- kailash/nodes/security/abac_evaluator.py +6 -6
- kailash/nodes/security/behavior_analysis.py +5 -5
- kailash/nodes/security/credential_manager.py +1 -1
- kailash/nodes/security/rotating_credentials.py +11 -11
- kailash/nodes/security/threat_detection.py +8 -8
- kailash/nodes/testing/credential_testing.py +2 -2
- kailash/nodes/transform/processors.py +5 -5
- kailash/runtime/local.py +163 -9
- kailash/runtime/parameter_injection.py +425 -0
- kailash/runtime/parameter_injector.py +657 -0
- kailash/runtime/testing.py +2 -2
- kailash/testing/fixtures.py +2 -2
- kailash/workflow/builder.py +99 -14
- kailash/workflow/builder_improvements.py +207 -0
- kailash/workflow/input_handling.py +170 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1169 @@
|
|
1
|
+
"""
|
2
|
+
Complete MCP Transport Implementations.
|
3
|
+
|
4
|
+
This module provides comprehensive transport layer implementations for MCP,
|
5
|
+
including enhanced STDIO, SSE, StreamableHTTP, and WebSocket transports.
|
6
|
+
All implementations build on the official MCP Python SDK while adding
|
7
|
+
production-ready features like security, connection management, and monitoring.
|
8
|
+
|
9
|
+
Features:
|
10
|
+
- Enhanced STDIO transport with proper process management
|
11
|
+
- Complete SSE transport with endpoint negotiation
|
12
|
+
- StreamableHTTP transport with session management
|
13
|
+
- WebSocket transport for real-time communication
|
14
|
+
- Transport security and validation
|
15
|
+
- Connection pooling and management
|
16
|
+
- Health checking and monitoring
|
17
|
+
|
18
|
+
Examples:
|
19
|
+
Enhanced STDIO transport:
|
20
|
+
|
21
|
+
>>> from kailash.mcp_server.transports import EnhancedStdioTransport
|
22
|
+
>>> transport = EnhancedStdioTransport(
|
23
|
+
... command="python",
|
24
|
+
... args=["-m", "my_mcp_server"],
|
25
|
+
... environment_filter=["PATH", "PYTHONPATH"]
|
26
|
+
... )
|
27
|
+
>>> async with transport:
|
28
|
+
... await transport.send_message({"jsonrpc": "2.0", "id": 1, "method": "ping"})
... response = await transport.receive_message()
|
29
|
+
|
30
|
+
SSE transport with security:
|
31
|
+
|
32
|
+
>>> from kailash.mcp_server.transports import SSETransport
|
33
|
+
>>> transport = SSETransport(
|
34
|
+
... base_url="https://api.example.com/mcp",
|
35
|
+
... auth_header="Bearer token123",
|
36
|
+
... validate_origin=True
|
37
|
+
... )
|
38
|
+
|
39
|
+
StreamableHTTP with session management:
|
40
|
+
|
41
|
+
>>> from kailash.mcp_server.transports import StreamableHTTPTransport
|
42
|
+
>>> transport = StreamableHTTPTransport(
|
43
|
+
... base_url="https://api.example.com/mcp",
|
44
|
+
... session_management=True,
|
45
|
+
... streaming_threshold=1024
|
46
|
+
... )
|
47
|
+
"""
|
48
|
+
|
49
|
+
import asyncio
|
50
|
+
import json
|
51
|
+
import logging
|
52
|
+
import os
|
53
|
+
import platform
|
54
|
+
import signal
|
55
|
+
import socket
|
56
|
+
import subprocess
|
57
|
+
import time
|
58
|
+
import uuid
|
59
|
+
import weakref
|
60
|
+
from abc import ABC, abstractmethod
|
61
|
+
from contextlib import AsyncExitStack
|
62
|
+
from pathlib import Path
|
63
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
64
|
+
from urllib.parse import urljoin, urlparse
|
65
|
+
|
66
|
+
import aiohttp
|
67
|
+
import websockets
|
68
|
+
|
69
|
+
from .auth import AuthProvider
|
70
|
+
from .errors import MCPError, MCPErrorCode, TransportError
|
71
|
+
from .protocol import MetaData, ProtocolManager
|
72
|
+
|
73
|
+
# Module-level logger shared by the transport classes in this module.
logger = logging.getLogger(__name__)
|
74
|
+
|
75
|
+
|
76
|
+
class TransportSecurity:
    """Security utilities for MCP transports.

    Provides URL validation (scheme allow-list plus a small host block-list)
    and ``Origin`` validation supporting exact matches and ``*`` wildcards.
    """

    ALLOWED_SCHEMES = {"http", "https", "ws", "wss"}
    # Hosts rejected unless ``allow_localhost`` is set: loopback names and the
    # link-local cloud metadata address. This is a basic SSRF block-list only;
    # no DNS resolution is performed, so names resolving to these addresses
    # are not caught.
    BLOCKED_HOSTS = {
        "169.254.169.254",
        "localhost",
        "127.0.0.1",
        "::1",
    }

    @classmethod
    def validate_url(cls, url: str, allow_localhost: bool = False) -> bool:
        """Validate URL for security.

        Args:
            url: URL to validate
            allow_localhost: Skip the ``BLOCKED_HOSTS`` check. NOTE: this also
                un-blocks the metadata address listed there, not only loopback.

        Returns:
            True if the URL passes every check. False otherwise; the reason
            (or any parsing error) is logged.
        """
        try:
            parsed = urlparse(url)

            # Check scheme
            if parsed.scheme not in cls.ALLOWED_SCHEMES:
                logger.warning(f"Blocked unsafe scheme: {parsed.scheme}")
                return False

            # urlparse lower-cases the hostname and strips IPv6 brackets; also
            # drop a trailing dot so "localhost." cannot bypass the block-list.
            hostname = (parsed.hostname or "").rstrip(".")

            # Check host
            if not allow_localhost and hostname in cls.BLOCKED_HOSTS:
                logger.warning(f"Blocked potentially unsafe host: {parsed.hostname}")
                return False

            # Check for IP address patterns that could be exploited (0.0.0.0/8)
            if hostname.startswith("0."):
                logger.warning(f"Blocked potentially unsafe IP: {parsed.hostname}")
                return False

            return True

        except Exception as e:
            logger.error(f"URL validation error: {e}")
            return False

    @classmethod
    def validate_origin(cls, origin: str, expected_origins: List[str]) -> bool:
        """Validate request origin.

        Args:
            origin: Request origin
            expected_origins: Allowed origins; ``*`` matches any run of
                characters (including none). All other characters are literal.

        Returns:
            True if origin is allowed
        """
        import re

        if not origin:
            return False

        # Exact match
        if origin in expected_origins:
            return True

        # Wildcard patterns
        for expected in expected_origins:
            if "*" not in expected:
                continue
            # Escape every regex metacharacter, then turn the escaped "*" back
            # into a wildcard. Escaping only "." let characters such as "+" or
            # "?" in a configured origin act as regex operators.
            pattern = re.escape(expected).replace(r"\*", ".*")
            # fullmatch: "$" would also accept a trailing newline.
            if re.fullmatch(pattern, origin):
                return True

        return False
|
150
|
+
|
151
|
+
|
152
|
+
class BaseTransport(ABC):
    """Abstract base shared by every MCP transport.

    Stores common configuration (name, optional auth provider, timeout, retry
    budget), the connection flag, and an optional in-memory metrics dict.
    Subclasses implement connect/disconnect and message send/receive; any
    instance can be used as an ``async with`` context manager.
    """

    def __init__(
        self,
        name: str,
        auth_provider: Optional[AuthProvider] = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        enable_metrics: bool = True,
    ):
        """Store transport configuration and reset state.

        Args:
            name: Transport name
            auth_provider: Authentication provider
            timeout: Connection timeout
            max_retries: Maximum retry attempts
            enable_metrics: Enable metrics collection
        """
        self.name = name
        self.auth_provider = auth_provider
        self.timeout = timeout
        self.max_retries = max_retries
        self.enable_metrics = enable_metrics

        self._connected = False
        self._sessions: Set[weakref.ref] = set()
        # With metrics disabled the dict stays empty, which turns
        # _update_metrics into a no-op.
        self._metrics: Dict[str, Any] = (
            {
                "connections_total": 0,
                "connections_failed": 0,
                "messages_sent": 0,
                "messages_received": 0,
                "bytes_sent": 0,
                "bytes_received": 0,
                "errors_total": 0,
                "start_time": time.time(),
            }
            if enable_metrics
            else {}
        )

    @abstractmethod
    async def connect(self) -> None:
        """Establish the underlying connection."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Tear down the underlying connection."""

    @abstractmethod
    async def send_message(self, message: Dict[str, Any]) -> None:
        """Deliver one message to the peer."""

    @abstractmethod
    async def receive_message(self) -> Dict[str, Any]:
        """Return the next message from the peer."""

    async def __aenter__(self):
        """Connect on entering an ``async with`` block and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving an ``async with`` block."""
        await self.disconnect()

    def is_connected(self) -> bool:
        """Report whether the transport is currently connected."""
        return self._connected

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the metrics plus uptime and session count.

        An empty dict is returned when metrics are disabled.
        """
        if not self.enable_metrics:
            return {}
        return {
            **self._metrics,
            "uptime": time.time() - self._metrics["start_time"],
            "active_sessions": len(self._sessions),
        }

    def _update_metrics(self, metric: str, value: Union[int, float] = 1):
        """Add *value* to a known metric; unknown names are ignored."""
        if not self.enable_metrics or metric not in self._metrics:
            return
        self._metrics[metric] += value
|
243
|
+
|
244
|
+
|
245
|
+
class EnhancedStdioTransport(BaseTransport):
    """Enhanced STDIO transport with proper process management.

    Launches a child process and exchanges newline-delimited JSON messages
    over its stdin/stdout. A background task buffers stdout lines for
    :meth:`receive_message`. A second task drains stderr into the debug log,
    so a chatty child cannot block on a full stderr pipe.
    """

    def __init__(
        self,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
        working_directory: Optional[str] = None,
        environment_filter: Optional[List[str]] = None,
        read_limit: int = 16 * 1024 * 1024,
        **kwargs,
    ):
        """Initialize enhanced STDIO transport.

        Args:
            command: Command to execute
            args: Command arguments
            env: Extra environment variables (applied after filtering)
            working_directory: Working directory
            environment_filter: Names of parent environment variables to
                inherit. ``None`` inherits everything; an empty list inherits
                nothing.
            read_limit: Maximum length in bytes of one stdout line (one
                message). asyncio's default of 64 KiB is too small for large
                tool results.
            **kwargs: Base transport arguments
        """
        super().__init__("stdio", **kwargs)

        self.command = command
        self.args = args or []
        self.env = env or {}
        self.working_directory = working_directory
        self.environment_filter = environment_filter
        self.read_limit = read_limit

        # Process management
        self.process: Optional[asyncio.subprocess.Process] = None
        self._read_task: Optional[asyncio.Task] = None
        self._stderr_task: Optional[asyncio.Task] = None
        # NOTE(review): not used by this class; kept for compatibility.
        self._write_queue: asyncio.Queue = asyncio.Queue()
        self._message_buffer: List[str] = []

    async def connect(self) -> None:
        """Start the subprocess and the background reader tasks.

        Raises:
            TransportError: If the process cannot be started.
        """
        if self._connected:
            return

        try:
            process_env = self._prepare_environment()

            self.process = await asyncio.create_subprocess_exec(
                self.command,
                *self.args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=process_env,
                cwd=self.working_directory,
                limit=self.read_limit,
            )

            # Start I/O tasks
            self._read_task = asyncio.create_task(self._read_loop())
            self._stderr_task = asyncio.create_task(self._drain_stderr())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"STDIO transport connected: {self.command}")

        except Exception as e:
            self._update_metrics("connections_failed")
            raise TransportError(
                f"Failed to start process: {e}", transport_type="stdio"
            ) from e

    async def disconnect(self) -> None:
        """Stop the reader tasks and terminate the subprocess.

        Safe to call from inside the reader task itself. That happens when
        the child closes stdout.
        """
        if not self._connected:
            return

        self._connected = False

        current = asyncio.current_task()
        for task in (self._read_task, self._stderr_task):
            # _read_loop calls disconnect() from within its own task;
            # cancelling and awaiting the running task would raise and skip
            # the process cleanup below.
            if task is None or task is current or task.done():
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        process = self.process
        if process:
            try:
                # Only signal a process that has not been reaped yet; otherwise
                # terminate() raises ProcessLookupError.
                if process.returncode is None:
                    process.terminate()
                    try:
                        await asyncio.wait_for(process.wait(), timeout=5.0)
                    except asyncio.TimeoutError:
                        # Force kill (on Windows kill() is the same call).
                        process.kill()
                        await process.wait()
            except ProcessLookupError:
                pass  # exited between the returncode check and the signal
            except Exception as e:
                logger.error(f"Error terminating process: {e}")
            finally:
                self.process = None

        logger.info("STDIO transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Write *message* as one JSON line to the subprocess stdin.

        Raises:
            TransportError: If not connected or the write fails.
        """
        if not self._connected or not self.process:
            raise TransportError("Transport not connected", transport_type="stdio")

        try:
            # json.dumps escapes newlines, so one message is exactly one line.
            message_bytes = (json.dumps(message) + "\n").encode("utf-8")

            self.process.stdin.write(message_bytes)
            await self.process.stdin.drain()

            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", len(message_bytes))

        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to send message: {e}", transport_type="stdio"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next buffered message from the subprocess.

        Messages already buffered are still returned even if the child exits
        while we wait.

        Raises:
            TransportError: If not connected, the transport disconnects with no
                message pending, or the line is not valid JSON.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="stdio")

        try:
            # Poll the buffer filled by _read_loop.
            while not self._message_buffer:
                if not self._connected:
                    raise TransportError(
                        "Transport disconnected", transport_type="stdio"
                    )
                await asyncio.sleep(0.01)

            message_data = self._message_buffer.pop(0)
            message = json.loads(message_data)

            self._update_metrics("messages_received")
            self._update_metrics(
                "bytes_received", len(message_data.encode("utf-8"))
            )

            return message

        except TransportError:
            # Already descriptive; do not wrap it a second time.
            self._update_metrics("errors_total")
            raise
        except json.JSONDecodeError as e:
            self._update_metrics("errors_total")
            raise TransportError(f"Invalid JSON received: {e}", transport_type="stdio")
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="stdio"
            )

    async def _read_loop(self):
        """Background task: buffer non-empty stdout lines until EOF.

        On EOF or a read error, the transport disconnects itself.
        """
        if not self.process:
            return

        try:
            while self._connected and self.process:
                line = await self.process.stdout.readline()

                if not line:
                    # Process closed stdout (usually: it exited)
                    break

                # A single undecodable byte should not kill the transport; the
                # bad line then surfaces as a JSON error in receive_message.
                line_str = line.decode("utf-8", errors="replace").strip()

                if line_str:
                    self._message_buffer.append(line_str)

        except Exception as e:
            logger.error(f"STDIO read loop error: {e}")
        finally:
            if self._connected:
                await self.disconnect()

    async def _drain_stderr(self):
        """Background task: consume stderr so the child never blocks on it.

        Output is logged at DEBUG level in chunks (lines may be split).
        """
        process = self.process
        if not process or not process.stderr:
            return

        try:
            while True:
                chunk = await process.stderr.read(65536)
                if not chunk:
                    break
                text = chunk.decode("utf-8", errors="replace").rstrip()
                if text:
                    logger.debug(f"STDIO stderr ({self.command}): {text}")
        except Exception as e:
            logger.debug(f"STDIO stderr drain stopped: {e}")

    def _prepare_environment(self) -> Dict[str, str]:
        """Build the child environment: filtered parent env plus ``self.env``."""
        if self.environment_filter is not None:
            allowed = set(self.environment_filter)
            process_env = {
                key: value for key, value in os.environ.items() if key in allowed
            }
        else:
            process_env = os.environ.copy()

        # Explicit variables always win over inherited ones.
        process_env.update(self.env)

        return process_env

    async def get_process_info(self) -> Dict[str, Any]:
        """Get information about the subprocess (empty dict if none)."""
        if not self.process:
            return {}

        return {
            "pid": self.process.pid,
            "returncode": self.process.returncode,
            "command": [self.command] + self.args,
            "working_directory": self.working_directory,
        }
|
461
|
+
|
462
|
+
|
463
|
+
class SSETransport(BaseTransport):
    """Server-Sent Events transport with endpoint negotiation.

    Server-to-client messages arrive on a long-lived ``GET`` event stream.
    Client-to-server messages are sent with ``POST``. If the server announces
    a message URL in an ``endpoint`` event (same origin only), that URL is
    used. Otherwise ``base_url + message_path`` is used.
    """

    def __init__(
        self,
        base_url: str,
        auth_header: Optional[str] = None,
        validate_origin: bool = True,
        allowed_origins: Optional[List[str]] = None,
        endpoint_path: str = "/sse",
        message_path: str = "/message",
        allow_localhost: bool = False,
        skip_security_validation: bool = False,
        **kwargs,
    ):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for the server; its path is kept as a prefix
                for ``endpoint_path`` and ``message_path``.
            auth_header: Authorization header
            validate_origin: Send an ``Origin`` header derived from base_url
            allowed_origins: List of allowed origins
            endpoint_path: SSE endpoint path
            message_path: Message posting path (fallback when the server sends
                no ``endpoint`` event)
            allow_localhost: Allow connections to localhost (for testing)
            skip_security_validation: Skip all security validation (for testing)
            **kwargs: Base transport arguments
        """
        super().__init__("sse", **kwargs)

        self.base_url = base_url.rstrip("/")
        self.auth_header = auth_header
        self.validate_origin = validate_origin
        self.allowed_origins = allowed_origins or [base_url]
        self.endpoint_path = endpoint_path
        self.message_path = message_path
        self.allow_localhost = allow_localhost
        self.skip_security_validation = skip_security_validation

        # Connection state
        self.session: Optional[aiohttp.ClientSession] = None
        self.sse_response: Optional[aiohttp.ClientResponse] = None
        self._read_task: Optional[asyncio.Task] = None
        self._message_queue: asyncio.Queue = asyncio.Queue()
        # Message URL announced by the server via an "endpoint" event.
        self._message_url: Optional[str] = None

    def _build_url(self, path: str) -> str:
        """Append *path* to base_url, preserving base_url's own path.

        ``urljoin(base, "/sse")`` would discard the base path
        (``https://h/mcp`` -> ``https://h/sse``). Absolute URLs are returned
        unchanged.
        """
        if not path:
            return self.base_url
        if urlparse(path).scheme:
            return path
        return f"{self.base_url}/{path.lstrip('/')}"

    def _origin(self) -> str:
        """Return base_url reduced to ``scheme://host[:port]`` (an Origin)."""
        parsed = urlparse(self.base_url)
        if not parsed.scheme or not parsed.netloc:
            return self.base_url
        # Drop any "user:pass@" component; it is not part of an origin.
        host = parsed.netloc.rpartition("@")[2]
        return f"{parsed.scheme}://{host}"

    async def connect(self) -> None:
        """Open the event stream and start the background reader.

        Raises:
            TransportError: If the URL is rejected or the stream cannot be
                opened (non-200 status, network error, timeout).
        """
        if self._connected:
            return

        # Validate URL (with configurable security)
        if not self.skip_security_validation:
            if not TransportSecurity.validate_url(
                self.base_url, allow_localhost=self.allow_localhost
            ):
                raise TransportError("Invalid or unsafe URL", transport_type="sse")

        sse_url = self._build_url(self.endpoint_path)

        try:
            headers = {}
            if self.auth_header:
                headers["Authorization"] = self.auth_header

            # An Origin header carries no path component.
            if self.validate_origin:
                headers["Origin"] = self._origin()

            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self.session = aiohttp.ClientSession(headers=headers, timeout=timeout)
            self._message_url = None

            # The event stream stays open indefinitely, so the session-wide
            # total timeout must not apply to it. Only bound the wait for the
            # response headers.
            self.sse_response = await asyncio.wait_for(
                self.session.get(
                    sse_url,
                    headers={"Accept": "text/event-stream", "Cache-Control": "no-cache"},
                    timeout=aiohttp.ClientTimeout(total=None),
                ),
                timeout=self.timeout,
            )

            if self.sse_response.status != 200:
                raise TransportError(
                    f"SSE connection failed: {self.sse_response.status}",
                    transport_type="sse",
                )

            # Start reading SSE events
            self._read_task = asyncio.create_task(self._read_sse_events())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"SSE transport connected: {sse_url}")

        except Exception as e:
            self._update_metrics("connections_failed")
            await self._cleanup_connection()
            if isinstance(e, TransportError):
                raise
            raise TransportError(
                f"SSE connection failed: {e}", transport_type="sse"
            ) from e

    async def disconnect(self) -> None:
        """Disconnect from SSE endpoint."""
        if not self._connected:
            return

        self._connected = False
        await self._cleanup_connection()

        logger.info("SSE transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """POST *message* as JSON to the negotiated (or default) message URL.

        Raises:
            TransportError: If not connected, the server answers with a status
                other than 200/201/202, or the request fails.
        """
        if not self._connected or not self.session:
            raise TransportError("Transport not connected", transport_type="sse")

        try:
            message_url = self._message_url or self._build_url(self.message_path)

            async with self.session.post(
                message_url, json=message, headers={"Content-Type": "application/json"}
            ) as response:
                if response.status not in (200, 201, 202):
                    raise TransportError(
                        f"Message send failed: {response.status}", transport_type="sse"
                    )

                self._update_metrics("messages_sent")
                self._update_metrics("bytes_sent", len(json.dumps(message)))

        except Exception as e:
            self._update_metrics("errors_total")
            if isinstance(e, TransportError):
                raise
            raise TransportError(
                f"Failed to send message: {e}", transport_type="sse"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next JSON message from the event stream.

        Raises:
            TransportError: If not connected, or no message arrives within
                ``self.timeout`` seconds.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="sse")

        try:
            # Wait for message from queue
            message = await asyncio.wait_for(
                self._message_queue.get(), timeout=self.timeout
            )

            self._update_metrics("messages_received")
            self._update_metrics("bytes_received", len(json.dumps(message)))

            return message

        except asyncio.TimeoutError:
            raise TransportError("Receive timeout", transport_type="sse")
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="sse"
            )

    async def _read_sse_events(self):
        """Background task: parse the event stream and dispatch events.

        Follows the SSE format: ``field: value`` lines. Multiple ``data``
        lines are joined with ``\\n``. A blank line ends an event, and lines
        starting with ``:`` are comments/keep-alives.
        """
        if not self.sse_response:
            return

        event_type = "message"
        data_lines: List[str] = []

        try:
            async for raw_line in self.sse_response.content:
                if not self._connected:
                    break

                line = raw_line.decode("utf-8", errors="replace").rstrip("\r\n")

                if not line:
                    if data_lines:
                        await self._dispatch_sse_event(event_type, "\n".join(data_lines))
                    event_type, data_lines = "message", []
                    continue

                if line.startswith(":"):
                    continue

                field, _, value = line.partition(":")
                if value.startswith(" "):
                    value = value[1:]

                if field == "event":
                    event_type = value or "message"
                elif field == "data":
                    data_lines.append(value)
                # "id" and "retry" fields are not used by this transport.

            # Be lenient with a stream that ends without a final blank line.
            if data_lines and self._connected:
                await self._dispatch_sse_event(event_type, "\n".join(data_lines))

        except Exception as e:
            logger.error(f"SSE read error: {e}")
        finally:
            if self._connected:
                await self.disconnect()

    async def _dispatch_sse_event(self, event_type: str, data: str) -> None:
        """Handle one complete SSE event.

        An ``endpoint`` event sets the message URL (same origin as the stream
        only). Any other event is parsed as JSON and queued for
        receive_message.
        """
        if event_type == "endpoint":
            stream_url = self._build_url(self.endpoint_path)
            endpoint_url = urljoin(stream_url, data.strip())
            stream, endpoint = urlparse(stream_url), urlparse(endpoint_url)
            # Refuse cross-origin endpoints: the session's Authorization
            # header would otherwise be sent to another host.
            if (endpoint.scheme, endpoint.netloc) != (stream.scheme, stream.netloc):
                logger.error(f"Ignoring SSE endpoint with mismatched origin: {endpoint_url}")
                return
            self._message_url = endpoint_url
            logger.debug(f"SSE message endpoint negotiated: {endpoint_url}")
            return

        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON in SSE event: {data}")
            return
        await self._message_queue.put(message)

    async def _cleanup_connection(self):
        """Cancel the reader task and close the response and HTTP session."""
        task = self._read_task
        # _read_sse_events reaches this method (via disconnect) from inside
        # the read task itself; a task must not cancel and await itself, or
        # the response and session below would never be closed.
        if task and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Close SSE response
        if self.sse_response:
            self.sse_response.close()
            self.sse_response = None

        # Close session
        if self.session:
            await self.session.close()
            self.session = None
|
663
|
+
|
664
|
+
|
665
|
+
class StreamableHTTPTransport(BaseTransport):
|
666
|
+
"""StreamableHTTP transport with session management."""
|
667
|
+
|
668
|
+
def __init__(
    self,
    base_url: str,
    session_management: bool = True,
    streaming_threshold: int = 1024,
    chunk_size: int = 8192,
    allow_localhost: bool = False,
    skip_security_validation: bool = False,
    **kwargs,
):
    """Configure a StreamableHTTP transport; no I/O happens here.

    Args:
        base_url: Server base URL (a trailing "/" is removed)
        session_management: Create a server-side session on connect
        streaming_threshold: Serialized size above which messages are streamed
        chunk_size: Chunk size used when streaming
        allow_localhost: Allow connections to localhost (for testing)
        skip_security_validation: Skip all security validation (for testing)
        **kwargs: Forwarded to the base transport
    """
    super().__init__("streamable_http", **kwargs)

    # Endpoint and validation settings
    self.base_url = base_url.rstrip("/")
    self.allow_localhost = allow_localhost
    self.skip_security_validation = skip_security_validation

    # Session and streaming behaviour
    self.session_management = session_management
    self.streaming_threshold = streaming_threshold
    self.chunk_size = chunk_size

    # Runtime state, populated by connect()
    self.session: Optional[aiohttp.ClientSession] = None
    self.session_id: Optional[str] = None
|
701
|
+
|
702
|
+
async def connect(self) -> None:
    """Open the HTTP client session and, if enabled, a server-side session.

    Raises:
        TransportError: If the base URL fails validation or setup fails.
    """
    if self._connected:
        return

    # validate_url is only consulted when validation is not skipped.
    url_is_safe = self.skip_security_validation or TransportSecurity.validate_url(
        self.base_url, allow_localhost=self.allow_localhost
    )
    if not url_is_safe:
        raise TransportError(
            "Invalid or unsafe URL", transport_type="streamable_http"
        )

    try:
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        if self.session_management:
            await self._create_server_session()

        self._connected = True
        self._update_metrics("connections_total")
        logger.info(f"StreamableHTTP transport connected: {self.base_url}")

    except Exception as e:
        self._update_metrics("connections_failed")
        await self._cleanup_connection()
        raise TransportError(
            f"HTTP connection failed: {e}", transport_type="streamable_http"
        )
|
736
|
+
|
737
|
+
async def disconnect(self) -> None:
    """Close the server-side session (if any), then the HTTP client session.

    Closing the server session is best-effort. A failure there is logged and
    never prevents the local HTTP session from being closed.
    """
    if not self._connected:
        return

    self._connected = False

    try:
        # Close server session
        if self.session_management and self.session_id:
            await self._close_server_session()
    except Exception as e:
        logger.warning(f"Failed to close StreamableHTTP server session: {e}")
    finally:
        # Always release the aiohttp session, even if the step above failed.
        await self._cleanup_connection()

    logger.info("StreamableHTTP transport disconnected")
|
751
|
+
|
752
|
+
async def send_message(self, message: Dict[str, Any]) -> None:
    """POST one JSON message to the server's ``/message`` endpoint.

    Payloads whose serialized length exceeds ``self.streaming_threshold`` are
    handed to ``self._send_streamed_message``. Smaller ones are sent in a
    single request. An ``X-Session-ID`` header is added when a server session
    id is held.

    Args:
        message: JSON-serializable message.

    Raises:
        TransportError: If the transport is not connected, the server answers
            with a status other than 200/201/202, or serialization or the
            request fails.
    """
    if not self._connected or not self.session:
        raise TransportError(
            "Transport not connected", transport_type="streamable_http"
        )

    try:
        # NOTE(review): the leading "/" makes urljoin discard any path in
        # base_url ("http://h/mcp" -> "http://h/message"); confirm intended.
        url = urljoin(self.base_url, "/message")

        headers = {"Content-Type": "application/json"}
        if self.session_id:
            headers["X-Session-ID"] = self.session_id

        # Serialize once. The same string picks the send path, is the request
        # body, and is what gets metered.
        message_data = json.dumps(message)

        if len(message_data) > self.streaming_threshold:
            await self._send_streamed_message(url, message_data, headers)
        else:
            async with self.session.post(
                url, data=message_data, headers=headers
            ) as response:
                if response.status not in (200, 201, 202):
                    raise TransportError(
                        f"Message send failed: {response.status}",
                        transport_type="streamable_http",
                    )

        self._update_metrics("messages_sent")
        self._update_metrics("bytes_sent", len(message_data))

    except TransportError:
        # Already descriptive: count it, but don't re-wrap it in another layer.
        self._update_metrics("errors_total")
        raise
    except Exception as e:
        self._update_metrics("errors_total")
        raise TransportError(
            f"Failed to send message: {e}", transport_type="streamable_http"
        ) from e
|
794
|
+
|
795
|
+
async def receive_message(self) -> Dict[str, Any]:
    """Poll the server's ``/receive`` endpoint until a message is available.

    A 204 response means nothing is queued. In that case the response is
    released, the coroutine sleeps 0.1 s, and the request is retried for as
    long as the transport stays connected. Bodies whose ``Content-Length``
    exceeds ``self.streaming_threshold`` are read via
    ``self._receive_streamed_message``.

    Returns:
        The decoded JSON message.

    Raises:
        TransportError: If the transport is (or becomes) disconnected, the
            server answers with a status other than 200/204, or the request or
            decoding fails.
    """
    if not self._connected or not self.session:
        raise TransportError(
            "Transport not connected", transport_type="streamable_http"
        )

    try:
        while True:
            # NOTE(review): leading "/" drops any path component of base_url.
            url = urljoin(self.base_url, "/receive")
            headers = {}
            if self.session_id:
                headers["X-Session-ID"] = self.session_id

            async with self.session.get(url, headers=headers) as response:
                if response.status != 204:
                    if response.status != 200:
                        raise TransportError(
                            f"Message receive failed: {response.status}",
                            transport_type="streamable_http",
                        )

                    content_length = response.headers.get("Content-Length")
                    if content_length and int(content_length) > self.streaming_threshold:
                        message = await self._receive_streamed_message(response)
                    else:
                        message = await response.json()

                    self._update_metrics("messages_received")
                    self._update_metrics("bytes_received", len(json.dumps(message)))
                    return message

            # 204 (no message yet). Back off only after the response above has
            # been released. Looping instead of recursing means idle polling
            # neither grows the stack nor holds one connection open per retry.
            await asyncio.sleep(0.1)
            if not self._connected or not self.session:
                raise TransportError(
                    "Transport not connected", transport_type="streamable_http"
                )

    except TransportError:
        self._update_metrics("errors_total")
        raise
    except Exception as e:
        self._update_metrics("errors_total")
        raise TransportError(
            f"Failed to receive message: {e}", transport_type="streamable_http"
        ) from e
|
841
|
+
|
842
|
+
async def _create_server_session(self):
    """Ask the server for a session with ``POST /session``.

    On HTTP 201 the ``session_id`` field of the JSON body is stored in
    ``self.session_id``. Any other status is logged as a warning and leaves it
    untouched. Does nothing when no HTTP session is open.
    """
    http = self.session
    if not http:
        return

    endpoint = urljoin(self.base_url, "/session")
    async with http.post(endpoint) as response:
        status = response.status
        if status != 201:
            logger.warning(f"Failed to create server session: {status}")
            return
        payload = await response.json()
        self.session_id = payload.get("session_id")
        logger.info(f"Created server session: {self.session_id}")
|
856
|
+
|
857
|
+
async def _close_server_session(self):
    """Delete the server-side session with ``DELETE /session/<id>``.

    A non-200/204 status or a request error is logged, never raised. Once a
    request has been attempted, ``self.session_id`` is cleared regardless of
    the outcome. Does nothing without both an HTTP session and a session id.
    """
    if not (self.session and self.session_id):
        return

    target = urljoin(self.base_url, f"/session/{self.session_id}")

    try:
        async with self.session.delete(target) as response:
            if response.status not in (200, 204):
                logger.warning(f"Failed to close server session: {response.status}")
            else:
                logger.info(f"Closed server session: {self.session_id}")
    except Exception as exc:
        logger.error(f"Error closing server session: {exc}")
    finally:
        self.session_id = None
|
874
|
+
|
875
|
+
async def _send_streamed_message(
|
876
|
+
self, url: str, message_data: str, headers: Dict[str, str]
|
877
|
+
):
|
878
|
+
"""Send message using streaming."""
|
879
|
+
headers["Transfer-Encoding"] = "chunked"
|
880
|
+
|
881
|
+
async def message_chunks():
|
882
|
+
for i in range(0, len(message_data), self.chunk_size):
|
883
|
+
yield message_data[i : i + self.chunk_size].encode("utf-8")
|
884
|
+
|
885
|
+
async with self.session.post(
|
886
|
+
url, data=message_chunks(), headers=headers
|
887
|
+
) as response:
|
888
|
+
if response.status not in (200, 201, 202):
|
889
|
+
raise TransportError(
|
890
|
+
f"Streamed message send failed: {response.status}",
|
891
|
+
transport_type="streamable_http",
|
892
|
+
)
|
893
|
+
|
894
|
+
async def _receive_streamed_message(
|
895
|
+
self, response: aiohttp.ClientResponse
|
896
|
+
) -> Dict[str, Any]:
|
897
|
+
"""Receive streamed message."""
|
898
|
+
chunks = []
|
899
|
+
|
900
|
+
async for chunk in response.content.iter_chunked(self.chunk_size):
|
901
|
+
chunks.append(chunk.decode("utf-8"))
|
902
|
+
|
903
|
+
message_data = "".join(chunks)
|
904
|
+
return json.loads(message_data)
|
905
|
+
|
906
|
+
async def _cleanup_connection(self):
|
907
|
+
"""Clean up connection resources."""
|
908
|
+
if self.session:
|
909
|
+
await self.session.close()
|
910
|
+
self.session = None
|
911
|
+
|
912
|
+
|
913
|
+
class WebSocketTransport(BaseTransport):
    """WebSocket transport for real-time communication.

    Outgoing messages are JSON-encoded and sent over the socket. A background
    reader task decodes incoming frames as JSON and puts them on an
    ``asyncio.Queue``, which ``receive_message`` drains with a timeout.
    """

    def __init__(
        self,
        url: str,
        subprotocols: Optional[List[str]] = None,
        ping_interval: float = 20.0,
        ping_timeout: float = 20.0,
        allow_localhost: bool = False,
        skip_security_validation: bool = False,
        **kwargs,
    ):
        """Initialize WebSocket transport.

        Args:
            url: WebSocket URL
            subprotocols: WebSocket subprotocols (defaults to ``["mcp-v1"]``)
            ping_interval: Ping interval in seconds
            ping_timeout: Ping timeout in seconds
            allow_localhost: Allow connections to localhost (for testing)
            skip_security_validation: Skip all security validation (for testing)
            **kwargs: Base transport arguments
        """
        super().__init__("websocket", **kwargs)

        self.url = url
        self.subprotocols = subprotocols or ["mcp-v1"]
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.allow_localhost = allow_localhost
        self.skip_security_validation = skip_security_validation

        # Connection state. websockets.connect() opens a *client* connection.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self._read_task: Optional[asyncio.Task] = None
        # Filled by _read_messages, drained by receive_message.
        self._message_queue: asyncio.Queue = asyncio.Queue()

    async def connect(self) -> None:
        """Open the WebSocket and start the background reader task.

        A no-op when already connected.

        Raises:
            TransportError: If URL validation rejects ``self.url`` (unless
                ``skip_security_validation`` is set) or connecting fails.
        """
        if self._connected:
            return

        # Validate URL (with configurable security)
        if not self.skip_security_validation:
            if not TransportSecurity.validate_url(
                self.url, allow_localhost=self.allow_localhost
            ):
                raise TransportError(
                    "Invalid or unsafe URL", transport_type="websocket"
                )

        try:
            extra_headers = {}
            # auth_provider is presumably set by BaseTransport — not visible here.
            if self.auth_provider:
                auth_headers = await self.auth_provider.get_headers()
                extra_headers.update(auth_headers)

            # NOTE(review): ``extra_headers`` is the keyword of the legacy
            # websockets client; newer websockets releases call it
            # ``additional_headers`` — confirm against the pinned version.
            self.websocket = await websockets.connect(
                self.url,
                subprotocols=self.subprotocols,
                extra_headers=extra_headers,
                ping_interval=self.ping_interval,
                ping_timeout=self.ping_timeout,
            )

            # Start reading messages in the background.
            self._read_task = asyncio.create_task(self._read_messages())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"WebSocket transport connected: {self.url}")

        except Exception as e:
            self._update_metrics("connections_failed")
            await self._cleanup_connection()
            raise TransportError(
                f"WebSocket connection failed: {e}", transport_type="websocket"
            ) from e

    async def disconnect(self) -> None:
        """Stop the reader and close the socket; a no-op when not connected."""
        if not self._connected:
            return

        self._connected = False
        await self._cleanup_connection()

        logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send ``message`` as a JSON text frame.

        Raises:
            TransportError: If not connected or the send fails.
        """
        if not self._connected or not self.websocket:
            raise TransportError("Transport not connected", transport_type="websocket")

        try:
            message_data = json.dumps(message)
            await self.websocket.send(message_data)

            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", len(message_data))

        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to send message: {e}", transport_type="websocket"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next queued incoming message.

        Waits at most ``self.timeout`` seconds (presumably configured by
        BaseTransport).

        Raises:
            TransportError: If not connected, on timeout, or on any other
                failure while waiting.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="websocket")

        try:
            message = await asyncio.wait_for(
                self._message_queue.get(), timeout=self.timeout
            )

            self._update_metrics("messages_received")
            self._update_metrics("bytes_received", len(json.dumps(message)))

            return message

        except asyncio.TimeoutError:
            raise TransportError("Receive timeout", transport_type="websocket")
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="websocket"
            ) from e

    async def _read_messages(self):
        """Background task: decode incoming frames and enqueue them.

        Frames that are not valid JSON are logged and skipped. When the stream
        ends (remote close or read error) while the transport is still marked
        connected, the transport disconnects itself.
        """
        if not self.websocket:
            return

        try:
            async for message_data in self.websocket:
                if not self._connected:
                    break

                try:
                    message = json.loads(message_data)
                    await self._message_queue.put(message)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON in WebSocket message: {message_data}")

        except websockets.exceptions.ConnectionClosed:
            logger.info("WebSocket connection closed")
        except Exception as e:
            logger.error(f"WebSocket read error: {e}")
        finally:
            # Self-initiated disconnect: this calls _cleanup_connection from
            # inside this very task (see the current_task guard there).
            if self._connected:
                await self.disconnect()

    async def _cleanup_connection(self):
        """Stop the reader task (when called from another task) and close the socket."""
        task = self._read_task
        # When the reader itself triggers disconnect(), this code runs inside
        # that task. Cancelling and then awaiting it would make the task
        # cancel and await itself. In that case skip it: the reader finishes
        # on its own once this returns.
        if task and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Close WebSocket
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
|
1085
|
+
|
1086
|
+
|
1087
|
+
class TransportManager:
    """Registry of named MCP transport instances plus per-type factories.

    Factories build new transports (``create_transport``). Registered
    instances are looked up by name and can be shut down together
    (``disconnect_all``).
    """

    def __init__(self):
        """Initialize transport manager with the built-in transport factories."""
        self._transports: Dict[str, BaseTransport] = {}
        builtin_factories = (
            ("stdio", EnhancedStdioTransport),
            ("sse", SSETransport),
            ("streamable_http", StreamableHTTPTransport),
            ("websocket", WebSocketTransport),
        )
        self._transport_factories: Dict[str, Callable] = dict(builtin_factories)

    def register_transport_factory(self, transport_type: str, factory: Callable):
        """Add or replace the factory used for ``transport_type``.

        Args:
            transport_type: Transport type name
            factory: Callable invoked with the keyword arguments given to
                ``create_transport``
        """
        self._transport_factories[transport_type] = factory

    def create_transport(self, transport_type: str, **kwargs) -> BaseTransport:
        """Build a new transport of ``transport_type``.

        The instance is not registered.

        Args:
            transport_type: Transport type
            **kwargs: Forwarded to the factory

        Returns:
            Transport instance

        Raises:
            ValueError: If no (truthy) factory is registered for the type.
        """
        factory = self._transport_factories.get(transport_type)
        if factory:
            return factory(**kwargs)
        raise ValueError(f"Unknown transport type: {transport_type}")

    def register_transport(self, name: str, transport: BaseTransport):
        """Store ``transport`` under ``name``, replacing any previous entry.

        Args:
            name: Transport name
            transport: Transport instance
        """
        self._transports[name] = transport

    def get_transport(self, name: str) -> Optional[BaseTransport]:
        """Look up a registered transport.

        Args:
            name: Transport name

        Returns:
            The transport, or None when ``name`` is not registered
        """
        return self._transports.get(name)

    def list_transports(self) -> List[str]:
        """Return the registered transport names in registration order."""
        return list(self._transports)

    async def disconnect_all(self):
        """Disconnect every registered transport, then empty the registry.

        A failure in one transport is logged and does not stop the others.
        """
        for registered in self._transports.values():
            try:
                await registered.disconnect()
            except Exception as exc:
                logger.error(f"Error disconnecting transport: {exc}")

        self._transports.clear()
|
1158
|
+
|
1159
|
+
|
1160
|
+
# Process-wide TransportManager, created lazily by get_transport_manager().
_transport_manager: Optional[TransportManager] = None


def get_transport_manager() -> TransportManager:
    """Return the shared TransportManager, building it on the first call.

    Returns:
        The module-level TransportManager singleton.
    """
    global _transport_manager
    manager = _transport_manager
    if manager is None:
        manager = TransportManager()
        _transport_manager = manager
    return manager
|