kailash 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. kailash/__init__.py +3 -3
  2. kailash/api/custom_nodes_secure.py +3 -3
  3. kailash/api/gateway.py +1 -1
  4. kailash/api/studio.py +2 -3
  5. kailash/api/workflow_api.py +3 -4
  6. kailash/core/resilience/bulkhead.py +460 -0
  7. kailash/core/resilience/circuit_breaker.py +92 -10
  8. kailash/edge/discovery.py +86 -0
  9. kailash/mcp_server/__init__.py +309 -33
  10. kailash/mcp_server/advanced_features.py +1022 -0
  11. kailash/mcp_server/ai_registry_server.py +27 -2
  12. kailash/mcp_server/auth.py +789 -0
  13. kailash/mcp_server/client.py +645 -378
  14. kailash/mcp_server/discovery.py +1593 -0
  15. kailash/mcp_server/errors.py +673 -0
  16. kailash/mcp_server/oauth.py +1727 -0
  17. kailash/mcp_server/protocol.py +1126 -0
  18. kailash/mcp_server/registry_integration.py +587 -0
  19. kailash/mcp_server/server.py +1213 -98
  20. kailash/mcp_server/transports.py +1169 -0
  21. kailash/mcp_server/utils/__init__.py +6 -1
  22. kailash/mcp_server/utils/cache.py +250 -7
  23. kailash/middleware/auth/auth_manager.py +3 -3
  24. kailash/middleware/communication/api_gateway.py +2 -9
  25. kailash/middleware/communication/realtime.py +1 -1
  26. kailash/middleware/mcp/enhanced_server.py +1 -1
  27. kailash/nodes/__init__.py +2 -0
  28. kailash/nodes/admin/audit_log.py +6 -6
  29. kailash/nodes/admin/permission_check.py +8 -8
  30. kailash/nodes/admin/role_management.py +32 -28
  31. kailash/nodes/admin/schema.sql +6 -1
  32. kailash/nodes/admin/schema_manager.py +13 -13
  33. kailash/nodes/admin/security_event.py +16 -20
  34. kailash/nodes/admin/tenant_isolation.py +3 -3
  35. kailash/nodes/admin/transaction_utils.py +3 -3
  36. kailash/nodes/admin/user_management.py +21 -22
  37. kailash/nodes/ai/a2a.py +11 -11
  38. kailash/nodes/ai/ai_providers.py +9 -12
  39. kailash/nodes/ai/embedding_generator.py +13 -14
  40. kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
  41. kailash/nodes/ai/iterative_llm_agent.py +2 -2
  42. kailash/nodes/ai/llm_agent.py +210 -33
  43. kailash/nodes/ai/self_organizing.py +2 -2
  44. kailash/nodes/alerts/discord.py +4 -4
  45. kailash/nodes/api/graphql.py +6 -6
  46. kailash/nodes/api/http.py +12 -17
  47. kailash/nodes/api/rate_limiting.py +4 -4
  48. kailash/nodes/api/rest.py +15 -15
  49. kailash/nodes/auth/mfa.py +3 -4
  50. kailash/nodes/auth/risk_assessment.py +2 -2
  51. kailash/nodes/auth/session_management.py +5 -5
  52. kailash/nodes/auth/sso.py +143 -0
  53. kailash/nodes/base.py +6 -2
  54. kailash/nodes/base_async.py +16 -2
  55. kailash/nodes/base_with_acl.py +2 -2
  56. kailash/nodes/cache/__init__.py +9 -0
  57. kailash/nodes/cache/cache.py +1172 -0
  58. kailash/nodes/cache/cache_invalidation.py +870 -0
  59. kailash/nodes/cache/redis_pool_manager.py +595 -0
  60. kailash/nodes/code/async_python.py +2 -1
  61. kailash/nodes/code/python.py +196 -35
  62. kailash/nodes/compliance/data_retention.py +6 -6
  63. kailash/nodes/compliance/gdpr.py +5 -5
  64. kailash/nodes/data/__init__.py +10 -0
  65. kailash/nodes/data/optimistic_locking.py +906 -0
  66. kailash/nodes/data/readers.py +8 -8
  67. kailash/nodes/data/redis.py +349 -0
  68. kailash/nodes/data/sql.py +314 -3
  69. kailash/nodes/data/streaming.py +21 -0
  70. kailash/nodes/enterprise/__init__.py +8 -0
  71. kailash/nodes/enterprise/audit_logger.py +285 -0
  72. kailash/nodes/enterprise/batch_processor.py +22 -3
  73. kailash/nodes/enterprise/data_lineage.py +1 -1
  74. kailash/nodes/enterprise/mcp_executor.py +205 -0
  75. kailash/nodes/enterprise/service_discovery.py +150 -0
  76. kailash/nodes/enterprise/tenant_assignment.py +108 -0
  77. kailash/nodes/logic/async_operations.py +2 -2
  78. kailash/nodes/logic/convergence.py +1 -1
  79. kailash/nodes/logic/operations.py +1 -1
  80. kailash/nodes/monitoring/__init__.py +11 -1
  81. kailash/nodes/monitoring/health_check.py +456 -0
  82. kailash/nodes/monitoring/log_processor.py +817 -0
  83. kailash/nodes/monitoring/metrics_collector.py +627 -0
  84. kailash/nodes/monitoring/performance_benchmark.py +137 -11
  85. kailash/nodes/rag/advanced.py +7 -7
  86. kailash/nodes/rag/agentic.py +49 -2
  87. kailash/nodes/rag/conversational.py +3 -3
  88. kailash/nodes/rag/evaluation.py +3 -3
  89. kailash/nodes/rag/federated.py +3 -3
  90. kailash/nodes/rag/graph.py +3 -3
  91. kailash/nodes/rag/multimodal.py +3 -3
  92. kailash/nodes/rag/optimized.py +5 -5
  93. kailash/nodes/rag/privacy.py +3 -3
  94. kailash/nodes/rag/query_processing.py +6 -6
  95. kailash/nodes/rag/realtime.py +1 -1
  96. kailash/nodes/rag/registry.py +2 -6
  97. kailash/nodes/rag/router.py +1 -1
  98. kailash/nodes/rag/similarity.py +7 -7
  99. kailash/nodes/rag/strategies.py +4 -4
  100. kailash/nodes/security/abac_evaluator.py +6 -6
  101. kailash/nodes/security/behavior_analysis.py +5 -6
  102. kailash/nodes/security/credential_manager.py +1 -1
  103. kailash/nodes/security/rotating_credentials.py +11 -11
  104. kailash/nodes/security/threat_detection.py +8 -8
  105. kailash/nodes/testing/credential_testing.py +2 -2
  106. kailash/nodes/transform/processors.py +5 -5
  107. kailash/runtime/local.py +162 -14
  108. kailash/runtime/parameter_injection.py +425 -0
  109. kailash/runtime/parameter_injector.py +657 -0
  110. kailash/runtime/testing.py +2 -2
  111. kailash/testing/fixtures.py +2 -2
  112. kailash/workflow/builder.py +99 -18
  113. kailash/workflow/builder_improvements.py +207 -0
  114. kailash/workflow/input_handling.py +170 -0
  115. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/METADATA +22 -9
  116. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/RECORD +120 -94
  117. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/WHEEL +0 -0
  118. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/entry_points.txt +0 -0
  119. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/licenses/LICENSE +0 -0
  120. {kailash-0.6.3.dist-info → kailash-0.6.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1169 @@
1
+ """
2
+ Complete MCP Transport Implementations.
3
+
4
+ This module provides comprehensive transport layer implementations for MCP,
5
+ including enhanced STDIO, SSE, StreamableHTTP, and WebSocket transports.
6
+ All implementations build on the official MCP Python SDK while adding
7
+ production-ready features like security, connection management, and monitoring.
8
+
9
+ Features:
10
+ - Enhanced STDIO transport with proper process management
11
+ - Complete SSE transport with endpoint negotiation
12
+ - StreamableHTTP transport with session management
13
+ - WebSocket transport for real-time communication
14
+ - Transport security and validation
15
+ - Connection pooling and management
16
+ - Health checking and monitoring
17
+
18
+ Examples:
19
+ Enhanced STDIO transport:
20
+
21
+ >>> from kailash.mcp_server.transports import EnhancedStdioTransport
22
+ >>> transport = EnhancedStdioTransport(
23
+ ... command="python",
24
+ ... args=["-m", "my_mcp_server"],
25
+ ... environment_filter=["PATH", "PYTHONPATH"]
26
+ ... )
27
+ >>> async with transport:
+ ...     await transport.send_message({"jsonrpc": "2.0", "id": 1, "method": "ping"})
+ ...     response = await transport.receive_message()
29
+
30
+ SSE transport with security:
31
+
32
+ >>> from kailash.mcp_server.transports import SSETransport
33
+ >>> transport = SSETransport(
34
+ ... base_url="https://api.example.com/mcp",
35
+ ... auth_header="Bearer token123",
36
+ ... validate_origin=True
37
+ ... )
38
+
39
+ StreamableHTTP with session management:
40
+
41
+ >>> from kailash.mcp_server.transports import StreamableHTTPTransport
42
+ >>> transport = StreamableHTTPTransport(
43
+ ... base_url="https://api.example.com/mcp",
44
+ ... session_management=True,
45
+ ... streaming_threshold=1024
46
+ ... )
47
+ """
48
+
49
+ import asyncio
50
+ import json
51
+ import logging
52
+ import os
53
+ import platform
54
+ import signal
55
+ import socket
56
+ import subprocess
57
+ import time
58
+ import uuid
59
+ import weakref
60
+ from abc import ABC, abstractmethod
61
+ from contextlib import AsyncExitStack
62
+ from pathlib import Path
63
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
64
+ from urllib.parse import urljoin, urlparse
65
+
66
+ import aiohttp
67
+ import websockets
68
+
69
+ from .auth import AuthProvider
70
+ from .errors import MCPError, MCPErrorCode, TransportError
71
+ from .protocol import MetaData, ProtocolManager
72
+
73
+ logger = logging.getLogger(__name__)
74
+
75
+
76
class TransportSecurity:
    """Security utilities for MCP transports.

    Provides class-level checks that the HTTP-based transports run before
    opening a connection (URL validation) and that servers can use to vet
    request origins.
    """

    ALLOWED_SCHEMES = {"http", "https", "ws", "wss"}
    BLOCKED_HOSTS = {
        "169.254.169.254",
        "localhost",
        "127.0.0.1",
    }  # Loopback and cloud-metadata hosts (basic SSRF protection)

    # Hosts rejected even when ``allow_localhost=True``: the cloud
    # instance-metadata endpoint is not a local development server and must
    # never be reachable through a transport URL.
    ALWAYS_BLOCKED_HOSTS = {"169.254.169.254"}

    @classmethod
    def _is_local_host(cls, hostname: Optional[str]) -> bool:
        """Return True if *hostname* is a blocked name or a local IP literal.

        Besides the literal ``BLOCKED_HOSTS`` entries this recognises any
        loopback (``127.0.0.0/8``, ``::1``), link-local and unspecified
        (``0.0.0.0``, ``::``) address, including IPv4-mapped IPv6 forms.
        """
        import ipaddress

        if not hostname:
            return False
        if hostname in cls.BLOCKED_HOSTS:
            return True
        try:
            address = ipaddress.ip_address(hostname)
        except ValueError:
            # A DNS name rather than an IP literal: only the explicit list applies.
            return False
        # "::ffff:127.0.0.1" must be judged by its embedded IPv4 address.
        mapped = getattr(address, "ipv4_mapped", None)
        if mapped is not None:
            address = mapped
        return address.is_loopback or address.is_link_local or address.is_unspecified

    @classmethod
    def validate_url(cls, url: str, allow_localhost: bool = False) -> bool:
        """Validate URL for security.

        Args:
            url: URL to validate
            allow_localhost: Allow loopback/link-local connections (for local
                development). The cloud-metadata host stays blocked.

        Returns:
            True if URL is safe; False otherwise (the reason is logged).
        """
        try:
            parsed = urlparse(url)

            # Check scheme
            if parsed.scheme not in cls.ALLOWED_SCHEMES:
                logger.warning(f"Blocked unsafe scheme: {parsed.scheme}")
                return False

            hostname = parsed.hostname

            # Metadata endpoint: never allowed, regardless of allow_localhost.
            if hostname in cls.ALWAYS_BLOCKED_HOSTS:
                logger.warning(f"Blocked potentially unsafe host: {hostname}")
                return False

            # Check host
            if not allow_localhost and cls._is_local_host(hostname):
                logger.warning(f"Blocked potentially unsafe host: {hostname}")
                return False

            # Check for IP address patterns that could be exploited
            if hostname and hostname.startswith("0."):
                logger.warning(f"Blocked potentially unsafe IP: {hostname}")
                return False

            return True

        except Exception as e:
            logger.error(f"URL validation error: {e}")
            return False

    @classmethod
    def validate_origin(cls, origin: str, expected_origins: List[str]) -> bool:
        """Validate request origin.

        Args:
            origin: Request origin
            expected_origins: List of allowed origins. An entry containing
                ``*`` is a wildcard pattern in which ``*`` matches any run of
                characters; every other character is matched literally.

        Returns:
            True if origin is allowed
        """
        import re

        if not origin:
            return False

        # Exact match
        if origin in expected_origins:
            return True

        # Wildcard patterns
        for expected in expected_origins:
            if "*" in expected:
                # Escape *all* regex metacharacters, then turn the escaped "*"
                # back into a wildcard. fullmatch() (unlike "^...$" with
                # re.match) also rejects an origin carrying a trailing newline.
                pattern = re.escape(expected).replace(r"\*", ".*")
                if re.fullmatch(pattern, origin):
                    return True

        return False
150
+
151
+
152
class BaseTransport(ABC):
    """Abstract parent of every MCP transport defined in this module.

    Stores the settings shared by all transports, tracks whether the
    transport is connected, and optionally keeps a dictionary of counters.
    Subclasses provide the four async primitives (connect, disconnect,
    send_message, receive_message); this class adds ``async with`` support
    that connects on entry and disconnects on exit.
    """

    def __init__(
        self,
        name: str,
        auth_provider: Optional[AuthProvider] = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        enable_metrics: bool = True,
    ):
        """Store shared transport settings.

        Args:
            name: Transport name
            auth_provider: Authentication provider
            timeout: Connection timeout
            max_retries: Maximum retry attempts
            enable_metrics: Enable metrics collection
        """
        self.name = name
        self.auth_provider = auth_provider
        self.timeout = timeout
        self.max_retries = max_retries
        self.enable_metrics = enable_metrics

        self._connected = False
        self._sessions: Set[weakref.ref] = set()
        self._metrics: Dict[str, Any] = {}

        if enable_metrics:
            # Every counter starts at zero; start_time anchors uptime.
            counters: Dict[str, Any] = dict.fromkeys(
                (
                    "connections_total",
                    "connections_failed",
                    "messages_sent",
                    "messages_received",
                    "bytes_sent",
                    "bytes_received",
                    "errors_total",
                ),
                0,
            )
            counters["start_time"] = time.time()
            self._metrics = counters

    @abstractmethod
    async def connect(self) -> None:
        """Connect the transport."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Disconnect the transport."""

    @abstractmethod
    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message."""

    @abstractmethod
    async def receive_message(self) -> Dict[str, Any]:
        """Receive a message."""

    async def __aenter__(self):
        """Connect, then hand back this transport for the ``async with`` body."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving the ``async with`` body (exceptions propagate)."""
        await self.disconnect()

    def is_connected(self) -> bool:
        """Report whether the transport currently considers itself connected."""
        return self._connected

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus uptime and session count.

        Returns an empty dict when metrics are disabled.
        """
        if not self.enable_metrics:
            return {}

        snapshot = dict(self._metrics)
        snapshot["uptime"] = time.time() - snapshot["start_time"]
        snapshot["active_sessions"] = len(self._sessions)
        return snapshot

    def _update_metrics(self, metric: str, value: Union[int, float] = 1):
        """Add *value* to an existing counter; unknown names are ignored."""
        if not self.enable_metrics or metric not in self._metrics:
            return
        self._metrics[metric] += value
243
+
244
+
245
class EnhancedStdioTransport(BaseTransport):
    """Enhanced STDIO transport with proper process management.

    Spawns ``command`` as a child process and exchanges newline-delimited
    JSON messages with it. Outgoing messages are written to the child's
    stdin. A background task collects lines from its stdout into an
    in-memory buffer, which :meth:`receive_message` consumes. A second
    background task drains the child's stderr into the debug log, so the
    child can never block on a full stderr pipe.
    """

    def __init__(
        self,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
        working_directory: Optional[str] = None,
        environment_filter: Optional[List[str]] = None,
        read_limit: int = 16 * 1024 * 1024,
        **kwargs,
    ):
        """Initialize enhanced STDIO transport.

        Args:
            command: Command to execute
            args: Command arguments
            env: Environment variables (added on top of the inherited ones)
            working_directory: Working directory
            environment_filter: Names of parent environment variables to
                inherit; when falsy the whole parent environment is inherited
            read_limit: Maximum size in bytes of one stdout line (one JSON
                message). It is passed as ``limit`` to
                ``asyncio.create_subprocess_exec``. asyncio's own default
                is 64 KiB, which large tool results can exceed.
            **kwargs: Base transport arguments
        """
        super().__init__("stdio", **kwargs)

        self.command = command
        self.args = args or []
        self.env = env or {}
        self.working_directory = working_directory
        self.environment_filter = environment_filter
        self.read_limit = read_limit

        # Process management
        self.process: Optional[asyncio.subprocess.Process] = None
        self._read_task: Optional[asyncio.Task] = None
        self._stderr_task: Optional[asyncio.Task] = None
        # NOTE(review): not used by this class; kept for compatibility.
        self._write_queue: asyncio.Queue = asyncio.Queue()
        # Complete stdout lines waiting to be decoded by receive_message().
        self._message_buffer: List[str] = []

    async def connect(self) -> None:
        """Start the subprocess and its reader tasks.

        Raises:
            TransportError: if the process cannot be started.
        """
        if self._connected:
            return

        try:
            # Prepare environment
            process_env = self._prepare_environment()

            # Start process
            self.process = await asyncio.create_subprocess_exec(
                self.command,
                *self.args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=process_env,
                cwd=self.working_directory,
                limit=self.read_limit,
            )

            # Start I/O tasks. stderr is piped so it can be logged, so it must
            # be drained continuously; otherwise the child blocks once the OS
            # pipe buffer is full.
            self._read_task = asyncio.create_task(self._read_loop())
            self._stderr_task = asyncio.create_task(self._drain_stderr())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"STDIO transport connected: {self.command}")

        except Exception as e:
            self._update_metrics("connections_failed")
            raise TransportError(
                f"Failed to start process: {e}", transport_type="stdio"
            ) from e

    async def disconnect(self) -> None:
        """Stop the reader tasks and terminate the subprocess.

        This is safe to call from :meth:`_read_loop` itself, which does so
        when the child closes stdout. The calling task is never cancelled
        or awaited.
        """
        if not self._connected:
            return

        self._connected = False

        # Cancel reader tasks (but never the task running this coroutine).
        current_task = asyncio.current_task()
        for task in (self._read_task, self._stderr_task):
            if task is None or task.done() or task is current_task:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Terminate process
        if self.process:
            try:
                # Only signal a child that is still running; a reaped child
                # would make terminate() raise ProcessLookupError.
                if self.process.returncode is None:
                    # Try graceful termination first
                    self.process.terminate()

                    # Wait with timeout
                    try:
                        await asyncio.wait_for(self.process.wait(), timeout=5.0)
                    except asyncio.TimeoutError:
                        # Force kill if needed
                        if platform.system() != "Windows":
                            self.process.kill()
                        await self.process.wait()

            except ProcessLookupError:
                # The child exited between the returncode check and terminate().
                pass
            except Exception as e:
                logger.error(f"Error terminating process: {e}")

            finally:
                self.process = None

        logger.info("STDIO transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Write *message* as one JSON line to the subprocess's stdin.

        Raises:
            TransportError: if not connected or the write fails.
        """
        if not self._connected or not self.process:
            raise TransportError("Transport not connected", transport_type="stdio")

        try:
            # Serialize message
            message_data = json.dumps(message) + "\n"
            message_bytes = message_data.encode("utf-8")

            # Send to subprocess
            self.process.stdin.write(message_bytes)
            await self.process.stdin.drain()

            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", len(message_bytes))

        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to send message: {e}", transport_type="stdio"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next JSON message read from the subprocess's stdout.

        Polls the line buffer filled by :meth:`_read_loop` every 10 ms.

        Raises:
            TransportError: if not connected, disconnected while waiting, or
                the line is not valid JSON.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="stdio")

        try:
            # Wait for message in buffer
            while not self._message_buffer:
                await asyncio.sleep(0.01)  # Small delay to prevent busy waiting

                if not self._connected:
                    raise TransportError(
                        "Transport disconnected", transport_type="stdio"
                    )

            # Get message from buffer
            message_data = self._message_buffer.pop(0)
            message = json.loads(message_data)

            self._update_metrics("messages_received")
            # Count bytes, not characters, to match bytes_sent.
            self._update_metrics("bytes_received", len(message_data.encode("utf-8")))

            return message

        except json.JSONDecodeError as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Invalid JSON received: {e}", transport_type="stdio"
            ) from e
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="stdio"
            ) from e

    async def _read_loop(self):
        """Background task: append each non-empty stdout line to the buffer.

        Ends on EOF or error and then disconnects the transport.
        """
        if not self.process:
            return

        try:
            while self._connected and self.process:
                # Read line from stdout
                line = await self.process.stdout.readline()

                if not line:
                    # Process ended
                    break

                # Decode and strip
                line_str = line.decode("utf-8").strip()

                if line_str:
                    self._message_buffer.append(line_str)

        except Exception as e:
            logger.error(f"STDIO read loop error: {e}")
        finally:
            if self._connected:
                await self.disconnect()

    async def _drain_stderr(self):
        """Background task: forward the child's stderr to the debug log until EOF.

        Reads fixed-size chunks rather than lines, so an over-long stderr
        line cannot stop the draining.
        """
        process = self.process
        if not process or not process.stderr:
            return

        try:
            while True:
                chunk = await process.stderr.read(4096)
                if not chunk:
                    break
                text = chunk.decode("utf-8", errors="replace").rstrip()
                if text:
                    logger.debug("STDIO server stderr: %s", text)
        except Exception as e:
            logger.debug(f"STDIO stderr reader stopped: {e}")

    def _prepare_environment(self) -> Dict[str, str]:
        """Prepare process environment variables."""
        # Start with filtered parent environment
        if self.environment_filter:
            process_env = {
                key: value
                for key, value in os.environ.items()
                if key in self.environment_filter
            }
        else:
            process_env = os.environ.copy()

        # Add custom environment variables
        process_env.update(self.env)

        return process_env

    async def get_process_info(self) -> Dict[str, Any]:
        """Get information about the subprocess (empty dict when not running)."""
        if not self.process:
            return {}

        return {
            "pid": self.process.pid,
            "returncode": self.process.returncode,
            "command": [self.command] + self.args,
            "working_directory": self.working_directory,
        }
461
+
462
+
463
class SSETransport(BaseTransport):
    """Server-Sent Events transport with endpoint negotiation.

    Server-to-client messages arrive on a long-lived ``GET`` to the SSE
    endpoint. Client-to-server messages are sent with ``POST``. If the
    server announces an ``endpoint`` event, later POSTs go to that
    announced URL instead of ``base_url + message_path``. The announced URL
    must be on the same scheme and host as the event stream.
    """

    def __init__(
        self,
        base_url: str,
        auth_header: Optional[str] = None,
        validate_origin: bool = True,
        allowed_origins: Optional[List[str]] = None,
        endpoint_path: str = "/sse",
        message_path: str = "/message",
        allow_localhost: bool = False,
        skip_security_validation: bool = False,
        **kwargs,
    ):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for the server (its path is kept as a prefix)
            auth_header: Authorization header
            validate_origin: Send an Origin header derived from base_url
            allowed_origins: List of allowed origins
            endpoint_path: SSE endpoint path, relative to base_url
            message_path: Message posting path, relative to base_url (used
                until the server announces an endpoint)
            allow_localhost: Allow connections to localhost (for testing)
            skip_security_validation: Skip all security validation (for testing)
            **kwargs: Base transport arguments
        """
        super().__init__("sse", **kwargs)

        self.base_url = base_url.rstrip("/")
        self.auth_header = auth_header
        self.validate_origin = validate_origin
        self.allowed_origins = allowed_origins or [base_url]
        self.endpoint_path = endpoint_path
        self.message_path = message_path
        self.allow_localhost = allow_localhost
        self.skip_security_validation = skip_security_validation

        # Connection state
        self.session: Optional[aiohttp.ClientSession] = None
        self.sse_response: Optional[aiohttp.ClientResponse] = None
        self._read_task: Optional[asyncio.Task] = None
        self._message_queue: asyncio.Queue = asyncio.Queue()
        # URL of the open event stream, and the POST URL announced by the
        # server via an "endpoint" event (None until announced).
        self._sse_url: Optional[str] = None
        self._message_url: Optional[str] = None

    def _build_url(self, path: str) -> str:
        """Append *path* to ``base_url``, keeping the base URL's own path.

        ``urljoin(base, "/sse")`` would replace the whole path of *base*
        (``https://host/mcp`` -> ``https://host/sse``). Absolute URLs are
        returned unchanged.
        """
        if urlparse(path).scheme:
            return path
        return f"{self.base_url}/{path.lstrip('/')}"

    def _origin_header(self) -> str:
        """Return ``scheme://host[:port]`` of ``base_url`` (an Origin has no path)."""
        parsed = urlparse(self.base_url)
        host = parsed.netloc.rpartition("@")[2]  # drop any user:password@
        return f"{parsed.scheme}://{host}"

    async def connect(self) -> None:
        """Open the HTTP session and the SSE event stream.

        Raises:
            TransportError: if the URL is rejected or the stream cannot be opened.
        """
        if self._connected:
            return

        # Validate URL (with configurable security)
        if not self.skip_security_validation:
            if not TransportSecurity.validate_url(
                self.base_url, allow_localhost=self.allow_localhost
            ):
                raise TransportError("Invalid or unsafe URL", transport_type="sse")

        try:
            # Create session
            headers = {}
            if self.auth_header:
                headers["Authorization"] = self.auth_header

            # Add CORS headers if origin validation is enabled
            if self.validate_origin:
                headers["Origin"] = self._origin_header()

            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self.session = aiohttp.ClientSession(headers=headers, timeout=timeout)

            # Connect to SSE endpoint. The stream stays open for the lifetime
            # of the transport, so the session's *total* timeout must not apply
            # to it; only connection establishment is bounded.
            sse_url = self._build_url(self.endpoint_path)
            self._sse_url = sse_url
            stream_timeout = aiohttp.ClientTimeout(
                total=None, connect=self.timeout, sock_read=None
            )
            self.sse_response = await self.session.get(
                sse_url,
                headers={"Accept": "text/event-stream", "Cache-Control": "no-cache"},
                timeout=stream_timeout,
            )

            if self.sse_response.status != 200:
                raise TransportError(
                    f"SSE connection failed: {self.sse_response.status}",
                    transport_type="sse",
                )

            # Start reading SSE events
            self._read_task = asyncio.create_task(self._read_sse_events())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"SSE transport connected: {sse_url}")

        except Exception as e:
            self._update_metrics("connections_failed")
            await self._cleanup_connection()
            raise TransportError(
                f"SSE connection failed: {e}", transport_type="sse"
            ) from e

    async def disconnect(self) -> None:
        """Disconnect from SSE endpoint."""
        if not self._connected:
            return

        self._connected = False
        await self._cleanup_connection()

        logger.info("SSE transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """POST *message* as JSON to the negotiated (or configured) message URL.

        Raises:
            TransportError: if not connected or the server does not answer 200/201/202.
        """
        if not self._connected or not self.session:
            raise TransportError("Transport not connected", transport_type="sse")

        try:
            message_url = self._message_url or self._build_url(self.message_path)

            async with self.session.post(
                message_url, json=message, headers={"Content-Type": "application/json"}
            ) as response:
                if response.status not in (200, 201, 202):
                    raise TransportError(
                        f"Message send failed: {response.status}", transport_type="sse"
                    )

            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", len(json.dumps(message)))

        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to send message: {e}", transport_type="sse"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next JSON message delivered on the event stream.

        Raises:
            TransportError: if not connected or nothing arrives within ``timeout``.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="sse")

        try:
            # Wait for message from queue
            message = await asyncio.wait_for(
                self._message_queue.get(), timeout=self.timeout
            )

            self._update_metrics("messages_received")
            self._update_metrics("bytes_received", len(json.dumps(message)))

            return message

        except asyncio.TimeoutError:
            raise TransportError("Receive timeout", transport_type="sse")
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="sse"
            ) from e

    async def _read_sse_events(self):
        """Background task: parse the event stream and dispatch complete events.

        Follows the SSE line format. ``event:`` sets the event type, and
        successive ``data:`` lines are joined with newlines. A single space
        after the colon is optional. Lines starting with ``:`` are comments,
        and a blank line ends the event. ``id``/``retry`` fields are ignored.
        """
        if not self.sse_response:
            return

        event_type = "message"
        data_lines: List[str] = []

        try:
            async for raw_line in self.sse_response.content:
                if not self._connected:
                    break

                line = raw_line.decode("utf-8").rstrip("\r\n")

                if not line:
                    # Blank line: the event is complete.
                    if data_lines:
                        await self._dispatch_sse_event(event_type, "\n".join(data_lines))
                    event_type, data_lines = "message", []
                    continue

                if line.startswith(":"):
                    continue  # comment / keep-alive

                field, _, value = line.partition(":")
                if value.startswith(" "):
                    value = value[1:]

                if field == "event":
                    event_type = value
                elif field == "data":
                    data_lines.append(value)

        except Exception as e:
            logger.error(f"SSE read error: {e}")
        finally:
            if self._connected:
                await self.disconnect()

    async def _dispatch_sse_event(self, event_type: str, data: str) -> None:
        """Handle one complete SSE event.

        An ``endpoint`` event sets the POST URL, resolved relative to the
        stream URL. Any other event's data is decoded as JSON and queued for
        :meth:`receive_message`.
        """
        if event_type == "endpoint":
            stream_url = self._sse_url or self.base_url
            endpoint_url = urljoin(stream_url, data.strip())
            expected, announced = urlparse(stream_url), urlparse(endpoint_url)
            # Refuse to POST messages (and credentials) to a different host.
            if (announced.scheme, announced.netloc) != (expected.scheme, expected.netloc):
                logger.warning(f"Ignoring cross-origin SSE endpoint: {endpoint_url}")
                return
            self._message_url = endpoint_url
            logger.info(f"SSE message endpoint negotiated: {endpoint_url}")
            return

        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON in SSE event: {data}")
            return
        await self._message_queue.put(message)

    async def _cleanup_connection(self):
        """Clean up connection resources."""
        # Cancel read task; never cancel/await the task running this code
        # (_read_sse_events disconnects from its own finally block).
        task = self._read_task
        if task and not task.done() and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._read_task = None

        # Close SSE response
        if self.sse_response:
            self.sse_response.close()
            self.sse_response = None

        # Close session
        if self.session:
            await self.session.close()
            self.session = None

        # A reconnect must negotiate its endpoint afresh.
        self._sse_url = None
        self._message_url = None
663
+
664
+
665
+ class StreamableHTTPTransport(BaseTransport):
666
+ """StreamableHTTP transport with session management."""
667
+
668
def __init__(
    self,
    base_url: str,
    session_management: bool = True,
    streaming_threshold: int = 1024,
    chunk_size: int = 8192,
    allow_localhost: bool = False,
    skip_security_validation: bool = False,
    **kwargs,
):
    """Configure an HTTP transport that can stream large payloads.

    Args:
        base_url: Base URL for the server (trailing "/" is removed)
        session_management: Create a server-side session on connect
        streaming_threshold: Payload size (characters) above which a
            message is streamed
        chunk_size: Chunk size for streaming
        allow_localhost: Allow connections to localhost (for testing)
        skip_security_validation: Skip all security validation (for testing)
        **kwargs: Base transport arguments
    """
    super().__init__("streamable_http", **kwargs)

    # Endpoint and payload handling.
    self.base_url = base_url.rstrip("/")
    self.session_management = session_management
    self.streaming_threshold = streaming_threshold
    self.chunk_size = chunk_size

    # Security switches consulted by connect().
    self.allow_localhost = allow_localhost
    self.skip_security_validation = skip_security_validation

    # Filled in by connect(); None until then.
    self.session: Optional[aiohttp.ClientSession] = None
    self.session_id: Optional[str] = None
701
+
702
async def connect(self) -> None:
    """Open the HTTP client session and, if enabled, a server-side session.

    Raises:
        TransportError: if the URL fails validation or setup raises.
    """
    if self._connected:
        return

    # validate_url is only consulted when validation is not skipped.
    url_is_acceptable = self.skip_security_validation or TransportSecurity.validate_url(
        self.base_url, allow_localhost=self.allow_localhost
    )
    if not url_is_acceptable:
        raise TransportError(
            "Invalid or unsafe URL", transport_type="streamable_http"
        )

    try:
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        if self.session_management:
            await self._create_server_session()

        self._connected = True
        self._update_metrics("connections_total")
        logger.info(f"StreamableHTTP transport connected: {self.base_url}")

    except Exception as exc:
        self._update_metrics("connections_failed")
        await self._cleanup_connection()
        raise TransportError(
            f"HTTP connection failed: {exc}", transport_type="streamable_http"
        )
736
+
737
async def disconnect(self) -> None:
    """Close the server-side session (if one exists) and the HTTP client."""
    if self._connected:
        self._connected = False

        # Tell the server to drop its session before the client goes away.
        if self.session_management and self.session_id:
            await self._close_server_session()

        await self._cleanup_connection()
        logger.info("StreamableHTTP transport disconnected")
751
+
752
async def send_message(self, message: Dict[str, Any]) -> None:
    """POST *message* to ``{base_url}/message``.

    Payloads longer than ``streaming_threshold`` characters (after JSON
    encoding) go through ``_send_streamed_message``. The ``X-Session-ID``
    header is added when a server session exists.

    Raises:
        TransportError: if not connected or the send fails.
    """
    if not self._connected or not self.session:
        raise TransportError(
            "Transport not connected", transport_type="streamable_http"
        )

    try:
        # Append to base_url: urljoin(base_url, "/message") would discard
        # any path prefix in base_url (e.g. ".../mcp").
        url = f"{self.base_url}/message"

        # Add session ID if using session management
        headers = {"Content-Type": "application/json"}
        if self.session_id:
            headers["X-Session-ID"] = self.session_id

        # Determine if streaming is needed
        message_data = json.dumps(message)
        use_streaming = len(message_data) > self.streaming_threshold

        if use_streaming:
            # Stream large message
            await self._send_streamed_message(url, message_data, headers)
        else:
            # Send normal message
            async with self.session.post(
                url, json=message, headers=headers
            ) as response:
                if response.status not in (200, 201, 202):
                    raise TransportError(
                        f"Message send failed: {response.status}",
                        transport_type="streamable_http",
                    )

        self._update_metrics("messages_sent")
        self._update_metrics("bytes_sent", len(message_data))

    except Exception as e:
        self._update_metrics("errors_total")
        raise TransportError(
            f"Failed to send message: {e}", transport_type="streamable_http"
        ) from e
794
+
795
async def receive_message(self) -> Dict[str, Any]:
    """Poll ``{base_url}/receive`` until the server returns a message.

    A ``204 No Content`` reply means nothing is queued yet. The request is
    repeated every 0.1 s until a ``200`` arrives or the transport is
    disconnected. Bodies whose ``Content-Length`` exceeds
    ``streaming_threshold`` are read via ``_receive_streamed_message``.

    Returns:
        The decoded JSON message.

    Raises:
        TransportError: if not connected, disconnected while polling, or the
            server answers with a status other than 200/204.
    """
    if not self._connected or not self.session:
        raise TransportError(
            "Transport not connected", transport_type="streamable_http"
        )

    try:
        # Append to base_url: urljoin(base_url, "/receive") would discard
        # any path prefix in base_url (e.g. ".../mcp").
        url = f"{self.base_url}/receive"

        # Add session ID if using session management
        headers = {}
        if self.session_id:
            headers["X-Session-ID"] = self.session_id

        while True:
            async with self.session.get(url, headers=headers) as response:
                if response.status != 204:
                    if response.status != 200:
                        raise TransportError(
                            f"Message receive failed: {response.status}",
                            transport_type="streamable_http",
                        )

                    # Check if response is streamed
                    content_length = response.headers.get("Content-Length")
                    if content_length and int(content_length) > self.streaming_threshold:
                        message = await self._receive_streamed_message(response)
                    else:
                        message = await response.json()
                    break

            # 204: no message yet. Loop instead of recursing, and wait only
            # after the response above has been released, so that stack
            # depth and held connections stay constant while idle.
            await asyncio.sleep(0.1)
            if not self._connected or not self.session:
                raise TransportError(
                    "Transport not connected", transport_type="streamable_http"
                )

        self._update_metrics("messages_received")
        self._update_metrics("bytes_received", len(json.dumps(message)))

        return message

    except Exception as e:
        self._update_metrics("errors_total")
        raise TransportError(
            f"Failed to receive message: {e}", transport_type="streamable_http"
        ) from e
841
+
842
+ async def _create_server_session(self):
843
+ """Create session with server."""
844
+ if not self.session:
845
+ return
846
+
847
+ url = urljoin(self.base_url, "/session")
848
+
849
+ async with self.session.post(url) as response:
850
+ if response.status == 201:
851
+ session_data = await response.json()
852
+ self.session_id = session_data.get("session_id")
853
+ logger.info(f"Created server session: {self.session_id}")
854
+ else:
855
+ logger.warning(f"Failed to create server session: {response.status}")
856
+
857
+ async def _close_server_session(self):
858
+ """Close session with server."""
859
+ if not self.session or not self.session_id:
860
+ return
861
+
862
+ url = urljoin(self.base_url, f"/session/{self.session_id}")
863
+
864
+ try:
865
+ async with self.session.delete(url) as response:
866
+ if response.status in (200, 204):
867
+ logger.info(f"Closed server session: {self.session_id}")
868
+ else:
869
+ logger.warning(f"Failed to close server session: {response.status}")
870
+ except Exception as e:
871
+ logger.error(f"Error closing server session: {e}")
872
+ finally:
873
+ self.session_id = None
874
+
875
+ async def _send_streamed_message(
876
+ self, url: str, message_data: str, headers: Dict[str, str]
877
+ ):
878
+ """Send message using streaming."""
879
+ headers["Transfer-Encoding"] = "chunked"
880
+
881
+ async def message_chunks():
882
+ for i in range(0, len(message_data), self.chunk_size):
883
+ yield message_data[i : i + self.chunk_size].encode("utf-8")
884
+
885
+ async with self.session.post(
886
+ url, data=message_chunks(), headers=headers
887
+ ) as response:
888
+ if response.status not in (200, 201, 202):
889
+ raise TransportError(
890
+ f"Streamed message send failed: {response.status}",
891
+ transport_type="streamable_http",
892
+ )
893
+
894
+ async def _receive_streamed_message(
895
+ self, response: aiohttp.ClientResponse
896
+ ) -> Dict[str, Any]:
897
+ """Receive streamed message."""
898
+ chunks = []
899
+
900
+ async for chunk in response.content.iter_chunked(self.chunk_size):
901
+ chunks.append(chunk.decode("utf-8"))
902
+
903
+ message_data = "".join(chunks)
904
+ return json.loads(message_data)
905
+
906
+ async def _cleanup_connection(self):
907
+ """Clean up connection resources."""
908
+ if self.session:
909
+ await self.session.close()
910
+ self.session = None
911
+
912
+
913
class WebSocketTransport(BaseTransport):
    """WebSocket transport for real-time communication.

    After ``connect()``, a background task (``_read_messages``) reads frames
    from the socket, JSON-decodes them and puts them on an internal queue.
    ``receive_message`` consumes that queue.
    """

    def __init__(
        self,
        url: str,
        subprotocols: Optional[List[str]] = None,
        ping_interval: float = 20.0,
        ping_timeout: float = 20.0,
        allow_localhost: bool = False,
        skip_security_validation: bool = False,
        **kwargs,
    ):
        """Initialize WebSocket transport.

        Args:
            url: WebSocket URL
            subprotocols: WebSocket subprotocols (``["mcp-v1"]`` when omitted
                or empty)
            ping_interval: Ping interval in seconds
            ping_timeout: Ping timeout in seconds
            allow_localhost: Allow connections to localhost (for testing)
            skip_security_validation: Skip all security validation (for testing)
            **kwargs: Base transport arguments
        """
        super().__init__("websocket", **kwargs)

        self.url = url
        self.subprotocols = subprotocols or ["mcp-v1"]
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.allow_localhost = allow_localhost
        self.skip_security_validation = skip_security_validation

        # Connection state. websockets.connect() yields a client-side protocol.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        # Background reader started by connect().
        self._read_task: Optional[asyncio.Task] = None
        # Decoded messages produced by _read_messages, consumed by receive_message.
        self._message_queue: asyncio.Queue = asyncio.Queue()

    async def connect(self) -> None:
        """Open the WebSocket connection and start the background reader.

        Does nothing if already connected. Unless ``skip_security_validation``
        is set, the URL is first checked with
        ``TransportSecurity.validate_url``. If an ``auth_provider`` is
        configured, the headers it returns are sent with the handshake.

        Raises:
            TransportError: If the URL is rejected or the connection fails.
        """
        if self._connected:
            return

        # Validate URL (with configurable security)
        if not self.skip_security_validation:
            if not TransportSecurity.validate_url(
                self.url, allow_localhost=self.allow_localhost
            ):
                raise TransportError(
                    "Invalid or unsafe URL", transport_type="websocket"
                )

        try:
            # Connect to WebSocket
            extra_headers = {}
            if self.auth_provider:
                # Add authentication headers
                auth_headers = await self.auth_provider.get_headers()
                extra_headers.update(auth_headers)

            # NOTE(review): ``extra_headers`` is the legacy websockets client
            # keyword; the newer asyncio client calls it ``additional_headers``.
            # Confirm against the pinned websockets version.
            self.websocket = await websockets.connect(
                self.url,
                subprotocols=self.subprotocols,
                extra_headers=extra_headers,
                ping_interval=self.ping_interval,
                ping_timeout=self.ping_timeout,
            )

            # Start reading messages. The task first runs at the next await,
            # after _connected has been set below.
            self._read_task = asyncio.create_task(self._read_messages())

            self._connected = True
            self._update_metrics("connections_total")

            logger.info(f"WebSocket transport connected: {self.url}")

        except Exception as e:
            self._update_metrics("connections_failed")
            await self._cleanup_connection()
            raise TransportError(
                f"WebSocket connection failed: {e}", transport_type="websocket"
            ) from e

    async def disconnect(self) -> None:
        """Disconnect from the WebSocket server. Does nothing if not connected."""
        if not self._connected:
            return

        # Clear the flag first so the reader loop stops and does not call
        # disconnect() again from its finally block.
        self._connected = False
        await self._cleanup_connection()

        logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Serialize ``message`` as JSON and send it as one WebSocket frame.

        Raises:
            TransportError: If not connected or the send fails.
        """
        if not self._connected or not self.websocket:
            raise TransportError("Transport not connected", transport_type="websocket")

        try:
            message_data = json.dumps(message)
            await self.websocket.send(message_data)

            self._update_metrics("messages_sent")
            self._update_metrics("bytes_sent", len(message_data))

        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to send message: {e}", transport_type="websocket"
            ) from e

    async def receive_message(self) -> Dict[str, Any]:
        """Return the next decoded message queued by the background reader.

        Waits at most ``self.timeout`` seconds for a message.

        Raises:
            TransportError: If not connected, on timeout ("Receive timeout"),
                or on any other failure.
        """
        if not self._connected:
            raise TransportError("Transport not connected", transport_type="websocket")

        try:
            message = await asyncio.wait_for(
                self._message_queue.get(), timeout=self.timeout
            )

            self._update_metrics("messages_received")
            self._update_metrics("bytes_received", len(json.dumps(message)))

            return message

        except asyncio.TimeoutError as e:
            raise TransportError("Receive timeout", transport_type="websocket") from e
        except Exception as e:
            self._update_metrics("errors_total")
            raise TransportError(
                f"Failed to receive message: {e}", transport_type="websocket"
            ) from e

    async def _read_messages(self):
        """Background task: decode incoming frames and enqueue them.

        Frames that are not valid JSON are logged and dropped. When the loop
        ends (connection closed or read error) while the transport still
        counts as connected, the transport disconnects itself.
        """
        if not self.websocket:
            return

        try:
            async for message_data in self.websocket:
                if not self._connected:
                    break

                try:
                    message = json.loads(message_data)
                    await self._message_queue.put(message)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON in WebSocket message: {message_data}")

        except websockets.exceptions.ConnectionClosed:
            logger.info("WebSocket connection closed")
        except Exception as e:
            logger.error(f"WebSocket read error: {e}")
        finally:
            if self._connected:
                # Runs inside this task; _cleanup_connection accounts for that.
                await self.disconnect()

    async def _cleanup_connection(self):
        """Stop the background reader and close the socket, if present.

        When called from inside the reader task itself (``_read_messages`` ->
        ``disconnect()``), the task is not cancelled or awaited, because a
        task cannot meaningfully cancel and await itself; it is already
        finishing.
        """
        read_task = self._read_task
        self._read_task = None
        if (
            read_task is not None
            and not read_task.done()
            and read_task is not asyncio.current_task()
        ):
            read_task.cancel()
            try:
                await read_task
            except asyncio.CancelledError:
                pass

        # Close WebSocket; drop the reference first so a failing close() does
        # not leave a broken socket attached to the transport.
        websocket = self.websocket
        self.websocket = None
        if websocket:
            await websocket.close()
1085
+
1086
+
1087
class TransportManager:
    """Registry of MCP transport factories and named transport instances.

    Factories, keyed by transport type, build new transports. Instances, keyed
    by a caller-chosen name, are tracked so they can be looked up and
    disconnected together. ``create_transport`` does not register what it
    creates; call ``register_transport`` for that.
    """

    def __init__(self):
        """Initialize the manager with the built-in transport factories."""
        self._transports: Dict[str, BaseTransport] = {}
        self._transport_factories: Dict[str, Callable] = {
            "stdio": EnhancedStdioTransport,
            "sse": SSETransport,
            "streamable_http": StreamableHTTPTransport,
            "websocket": WebSocketTransport,
        }

    def register_transport_factory(self, transport_type: str, factory: Callable):
        """Register transport factory.

        Registering an existing ``transport_type`` replaces its factory.

        Args:
            transport_type: Transport type name
            factory: Callable invoked with the keyword arguments given to
                ``create_transport``
        """
        self._transport_factories[transport_type] = factory

    def create_transport(self, transport_type: str, **kwargs) -> BaseTransport:
        """Create transport instance.

        Args:
            transport_type: Transport type
            **kwargs: Transport arguments, forwarded to the factory

        Returns:
            Transport instance (not registered with this manager)

        Raises:
            ValueError: If no factory is registered for ``transport_type``.
        """
        factory = self._transport_factories.get(transport_type)
        if not factory:
            raise ValueError(f"Unknown transport type: {transport_type}")

        return factory(**kwargs)

    def register_transport(self, name: str, transport: BaseTransport):
        """Register transport instance.

        A transport already registered under ``name`` is replaced without
        being disconnected.

        Args:
            name: Transport name
            transport: Transport instance
        """
        self._transports[name] = transport

    def get_transport(self, name: str) -> Optional[BaseTransport]:
        """Get registered transport.

        Args:
            name: Transport name

        Returns:
            Transport instance or None
        """
        return self._transports.get(name)

    def list_transports(self) -> List[str]:
        """List registered transport names."""
        return list(self._transports.keys())

    async def disconnect_all(self):
        """Disconnect every registered transport, then forget them all.

        A failure to disconnect one transport is logged, with its registered
        name, and does not stop the remaining transports from being
        disconnected.
        """
        # Iterate over a snapshot: each await yields to the event loop, and a
        # coroutine registering a transport meanwhile would otherwise break
        # iteration over the live dict.
        for name, transport in list(self._transports.items()):
            try:
                await transport.disconnect()
            except Exception as e:
                logger.error(f"Error disconnecting transport '{name}': {e}")

        self._transports.clear()
1158
+
1159
+
1160
# Global transport manager: module-level singleton, created lazily by
# get_transport_manager() on first call.
_transport_manager: Optional[TransportManager] = None
1162
+
1163
+
1164
def get_transport_manager() -> TransportManager:
    """Return the module-wide TransportManager, creating it on first use."""
    global _transport_manager
    if _transport_manager is not None:
        return _transport_manager

    _transport_manager = TransportManager()
    return _transport_manager