kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. kailash/__init__.py +3 -3
  2. kailash/api/custom_nodes_secure.py +3 -3
  3. kailash/api/gateway.py +1 -1
  4. kailash/api/studio.py +1 -1
  5. kailash/api/workflow_api.py +2 -2
  6. kailash/core/resilience/bulkhead.py +475 -0
  7. kailash/core/resilience/circuit_breaker.py +92 -10
  8. kailash/core/resilience/health_monitor.py +578 -0
  9. kailash/edge/discovery.py +86 -0
  10. kailash/mcp_server/__init__.py +309 -33
  11. kailash/mcp_server/advanced_features.py +1022 -0
  12. kailash/mcp_server/ai_registry_server.py +27 -2
  13. kailash/mcp_server/auth.py +789 -0
  14. kailash/mcp_server/client.py +645 -378
  15. kailash/mcp_server/discovery.py +1593 -0
  16. kailash/mcp_server/errors.py +673 -0
  17. kailash/mcp_server/oauth.py +1727 -0
  18. kailash/mcp_server/protocol.py +1126 -0
  19. kailash/mcp_server/registry_integration.py +587 -0
  20. kailash/mcp_server/server.py +1228 -96
  21. kailash/mcp_server/transports.py +1169 -0
  22. kailash/mcp_server/utils/__init__.py +6 -1
  23. kailash/mcp_server/utils/cache.py +250 -7
  24. kailash/middleware/auth/auth_manager.py +3 -3
  25. kailash/middleware/communication/api_gateway.py +1 -1
  26. kailash/middleware/communication/realtime.py +1 -1
  27. kailash/middleware/mcp/enhanced_server.py +1 -1
  28. kailash/nodes/__init__.py +2 -0
  29. kailash/nodes/admin/audit_log.py +6 -6
  30. kailash/nodes/admin/permission_check.py +8 -8
  31. kailash/nodes/admin/role_management.py +32 -28
  32. kailash/nodes/admin/schema.sql +6 -1
  33. kailash/nodes/admin/schema_manager.py +13 -13
  34. kailash/nodes/admin/security_event.py +15 -15
  35. kailash/nodes/admin/tenant_isolation.py +3 -3
  36. kailash/nodes/admin/transaction_utils.py +3 -3
  37. kailash/nodes/admin/user_management.py +21 -21
  38. kailash/nodes/ai/a2a.py +11 -11
  39. kailash/nodes/ai/ai_providers.py +9 -12
  40. kailash/nodes/ai/embedding_generator.py +13 -14
  41. kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
  42. kailash/nodes/ai/iterative_llm_agent.py +2 -2
  43. kailash/nodes/ai/llm_agent.py +210 -33
  44. kailash/nodes/ai/self_organizing.py +2 -2
  45. kailash/nodes/alerts/discord.py +4 -4
  46. kailash/nodes/api/graphql.py +6 -6
  47. kailash/nodes/api/http.py +10 -10
  48. kailash/nodes/api/rate_limiting.py +4 -4
  49. kailash/nodes/api/rest.py +15 -15
  50. kailash/nodes/auth/mfa.py +3 -3
  51. kailash/nodes/auth/risk_assessment.py +2 -2
  52. kailash/nodes/auth/session_management.py +5 -5
  53. kailash/nodes/auth/sso.py +143 -0
  54. kailash/nodes/base.py +8 -2
  55. kailash/nodes/base_async.py +16 -2
  56. kailash/nodes/base_with_acl.py +2 -2
  57. kailash/nodes/cache/__init__.py +9 -0
  58. kailash/nodes/cache/cache.py +1172 -0
  59. kailash/nodes/cache/cache_invalidation.py +874 -0
  60. kailash/nodes/cache/redis_pool_manager.py +595 -0
  61. kailash/nodes/code/async_python.py +2 -1
  62. kailash/nodes/code/python.py +194 -30
  63. kailash/nodes/compliance/data_retention.py +6 -6
  64. kailash/nodes/compliance/gdpr.py +5 -5
  65. kailash/nodes/data/__init__.py +10 -0
  66. kailash/nodes/data/async_sql.py +1956 -129
  67. kailash/nodes/data/optimistic_locking.py +906 -0
  68. kailash/nodes/data/readers.py +8 -8
  69. kailash/nodes/data/redis.py +378 -0
  70. kailash/nodes/data/sql.py +314 -3
  71. kailash/nodes/data/streaming.py +21 -0
  72. kailash/nodes/enterprise/__init__.py +8 -0
  73. kailash/nodes/enterprise/audit_logger.py +285 -0
  74. kailash/nodes/enterprise/batch_processor.py +22 -3
  75. kailash/nodes/enterprise/data_lineage.py +1 -1
  76. kailash/nodes/enterprise/mcp_executor.py +205 -0
  77. kailash/nodes/enterprise/service_discovery.py +150 -0
  78. kailash/nodes/enterprise/tenant_assignment.py +108 -0
  79. kailash/nodes/logic/async_operations.py +2 -2
  80. kailash/nodes/logic/convergence.py +1 -1
  81. kailash/nodes/logic/operations.py +1 -1
  82. kailash/nodes/monitoring/__init__.py +11 -1
  83. kailash/nodes/monitoring/health_check.py +456 -0
  84. kailash/nodes/monitoring/log_processor.py +817 -0
  85. kailash/nodes/monitoring/metrics_collector.py +627 -0
  86. kailash/nodes/monitoring/performance_benchmark.py +137 -11
  87. kailash/nodes/rag/advanced.py +7 -7
  88. kailash/nodes/rag/agentic.py +49 -2
  89. kailash/nodes/rag/conversational.py +3 -3
  90. kailash/nodes/rag/evaluation.py +3 -3
  91. kailash/nodes/rag/federated.py +3 -3
  92. kailash/nodes/rag/graph.py +3 -3
  93. kailash/nodes/rag/multimodal.py +3 -3
  94. kailash/nodes/rag/optimized.py +5 -5
  95. kailash/nodes/rag/privacy.py +3 -3
  96. kailash/nodes/rag/query_processing.py +6 -6
  97. kailash/nodes/rag/realtime.py +1 -1
  98. kailash/nodes/rag/registry.py +1 -1
  99. kailash/nodes/rag/router.py +1 -1
  100. kailash/nodes/rag/similarity.py +7 -7
  101. kailash/nodes/rag/strategies.py +4 -4
  102. kailash/nodes/security/abac_evaluator.py +6 -6
  103. kailash/nodes/security/behavior_analysis.py +5 -5
  104. kailash/nodes/security/credential_manager.py +1 -1
  105. kailash/nodes/security/rotating_credentials.py +11 -11
  106. kailash/nodes/security/threat_detection.py +8 -8
  107. kailash/nodes/testing/credential_testing.py +2 -2
  108. kailash/nodes/transform/processors.py +5 -5
  109. kailash/runtime/local.py +163 -9
  110. kailash/runtime/parameter_injection.py +425 -0
  111. kailash/runtime/parameter_injector.py +657 -0
  112. kailash/runtime/testing.py +2 -2
  113. kailash/testing/fixtures.py +2 -2
  114. kailash/workflow/builder.py +99 -14
  115. kailash/workflow/builder_improvements.py +207 -0
  116. kailash/workflow/input_handling.py +170 -0
  117. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
  118. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
  119. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
  120. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
  121. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
  122. {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1126 @@
1
+ """
2
+ Complete MCP Protocol Implementation.
3
+
4
+ This module implements the full Model Context Protocol (MCP) specification,
5
+ including all message types, progress reporting, cancellation, completion,
6
+ sampling, and other advanced protocol features that build on the official
7
+ MCP Python SDK.
8
+
9
+ Features:
10
+ - Complete protocol message type definitions
11
+ - Progress reporting with token-based tracking
12
+ - Request cancellation and cleanup
13
+ - Completion system for prompts and resources
14
+ - Sampling system for LLM interactions
15
+ - Roots system for file system access
16
+ - Meta field support for protocol metadata
17
+ - Proper error handling with standard codes
18
+
19
+ The implementation follows the official MCP specification while providing
20
+ enhanced functionality for production use cases.
21
+
22
+ Examples:
23
+ Progress reporting:
24
+
25
+ >>> from kailash.mcp_server.protocol import ProgressManager
26
+ >>> progress = ProgressManager()
27
+ >>>
28
+ >>> # Start progress tracking
29
+ >>> token = progress.start_progress("long_operation", total=100)
30
+ >>> for i in range(100):
31
+ ... await progress.update_progress(token, progress=i, status=f"Step {i}")
32
+ >>> await progress.complete_progress(token)
33
+
34
+ Request cancellation:
35
+
36
+ >>> from kailash.mcp_server.protocol import CancellationManager
37
+ >>> cancellation = CancellationManager()
38
+ >>>
39
+ >>> # Check if request should be cancelled
40
+ >>> if await cancellation.is_cancelled(request_id):
41
+ ... raise CancelledError("Operation was cancelled")
42
+
43
+ Completion system:
44
+
45
+ >>> from kailash.mcp_server.protocol import CompletionManager
46
+ >>> completion = CompletionManager()
47
+ >>>
48
+ >>> # Get completions for a prompt argument
49
+ >>> completions = await completion.get_completions(
50
+ ... "prompts/analyze", "data_source", "fil"
51
+ ... )
52
+ """
53
+
54
import asyncio
import inspect
import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union

from .errors import MCPError, MCPErrorCode
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
class MessageType(Enum):
    """MCP method names used by this module.

    Each value is a JSON-RPC ``method`` string; for example ``PROGRESS`` and
    ``CANCELLED`` match the ``method`` defaults of ``ProgressNotification`` and
    ``CancelledNotification`` defined below.
    """

    # Core protocol
    INITIALIZE = "initialize"
    # NOTE(review): recent MCP spec revisions name the post-handshake
    # notification "notifications/initialized"; confirm which form peers expect.
    INITIALIZED = "initialized"

    # Tool operations
    TOOLS_LIST = "tools/list"
    TOOLS_CALL = "tools/call"

    # Resource operations
    RESOURCES_LIST = "resources/list"
    RESOURCES_READ = "resources/read"
    RESOURCES_SUBSCRIBE = "resources/subscribe"
    RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"
    RESOURCES_UPDATED = "notifications/resources/updated"

    # Prompt operations
    PROMPTS_LIST = "prompts/list"
    PROMPTS_GET = "prompts/get"

    # Progress operations
    PROGRESS = "notifications/progress"

    # Cancellation
    CANCELLED = "notifications/cancelled"

    # Completion
    COMPLETION_COMPLETE = "completion/complete"

    # Sampling (Server to Client)
    SAMPLING_CREATE_MESSAGE = "sampling/createMessage"

    # Roots (File system)
    ROOTS_LIST = "roots/list"

    # Logging
    LOGGING_SET_LEVEL = "logging/setLevel"

    # Custom extensions
    # NOTE(review): "pong", "request" and "notification" look project-specific
    # rather than MCP method names — confirm before sending them to other peers.
    PING = "ping"
    PONG = "pong"
    REQUEST = "request"  # Generic request type
    NOTIFICATION = "notification"  # Generic notification type
114
+
115
+
116
@dataclass
class ProgressToken:
    """Progress token plus the mutable tracking state attached to it.

    Identity is defined solely by ``value``. Two tokens with the same value
    compare equal and hash the same, whatever their progress or status. This
    lets a token serve as a dictionary key while ``progress`` and ``status``
    are updated in place.

    Attributes:
        value: Token string. ProgressManager sends it as the notification's
            ``progressToken``.
        operation_name: Name of the tracked operation.
        total: Total progress units, if known.
        progress: Current progress value.
        status: Latest status message, if any.
    """

    value: str
    operation_name: str
    total: Optional[float] = None
    progress: float = 0
    status: Optional[str] = None

    def __hash__(self) -> int:
        """Hash by ``value`` only, consistent with ``__eq__``."""
        return hash(self.value)

    def __eq__(self, other: object) -> bool:
        """Compare tokens by ``value``.

        Returns ``NotImplemented`` for non-token operands. Python then tries the
        reflected comparison and falls back to identity, instead of this method
        forcing ``False``.
        """
        if not isinstance(other, ProgressToken):
            return NotImplemented
        return self.value == other.value
135
+
136
+
137
@dataclass
class MetaData:
    """Meta fields for protocol messages.

    ``timestamp`` defaults to the creation time, and ``additional_data``
    defaults to a fresh empty dict. :meth:`to_dict` emits only the fields
    that are set.
    """

    # Either a ProgressToken (anything exposing ``.value``) or a raw token value.
    progress_token: Optional[Union["ProgressToken", str, int]] = None
    request_id: Optional[str] = None
    timestamp: Optional[float] = None
    operation_id: Optional[str] = None
    user_id: Optional[str] = None
    # Extra keys merged into the serialized dict; they override standard keys.
    additional_data: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Fill in the default timestamp and a per-instance ``additional_data``."""
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.additional_data is None:
            self.additional_data = {}

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary.

        - A token object is reduced to its ``value``; the dataclass instance
          itself is not JSON-serializable. Raw token values are emitted as
          given.
        - ``timestamp`` is emitted whenever it is not None, so ``0.0`` is kept.
        - String fields are emitted only when non-empty.
        - ``additional_data`` entries are merged last and win on key clashes.

        Returns:
            Dictionary using ``progressToken``/``requestId``/``timestamp`` plus
            ``operation_id``/``user_id`` keys for the populated fields.
        """
        result: Dict[str, Any] = {}
        if self.progress_token is not None:
            # Duck-typed unwrap: ProgressToken exposes ``.value``; str/int do not.
            result["progressToken"] = getattr(
                self.progress_token, "value", self.progress_token
            )
        if self.request_id:
            result["requestId"] = self.request_id
        if self.timestamp is not None:
            result["timestamp"] = self.timestamp
        if self.operation_id:
            result["operation_id"] = self.operation_id
        if self.user_id:
            result["user_id"] = self.user_id
        if self.additional_data:
            result.update(self.additional_data)
        return result
171
+
172
+
173
@dataclass
class ProgressNotification:
    """``notifications/progress`` message.

    ``params`` must contain ``progressToken``. :meth:`create` adds
    ``progress``, ``total`` and ``status`` only when they are supplied.
    """

    method: str = "notifications/progress"
    params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Reject notifications that do not name their progress token.

        Raises:
            ValueError: If ``params`` lacks ``progressToken``.
        """
        if "progressToken" not in self.params:
            raise ValueError("Progress notification requires progressToken")

    @classmethod
    def create(
        cls,
        progress_token: Union["ProgressToken", str, int],
        progress: Optional[float] = None,
        total: Optional[float] = None,
        status: Optional[str] = None,
    ) -> "ProgressNotification":
        """Create a progress notification.

        Args:
            progress_token: A ``ProgressToken`` (any object with a ``value``
                attribute) or the raw token value. Token objects are reduced to
                their ``value`` so ``params`` stays JSON-serializable.
            progress: Current progress value; omitted when None.
            total: Total progress units; omitted when None.
            status: Status message; omitted when None.

        Returns:
            The notification with the assembled ``params``.
        """
        params: Dict[str, Any] = {
            "progressToken": getattr(progress_token, "value", progress_token)
        }

        if progress is not None:
            params["progress"] = progress
        if total is not None:
            params["total"] = total
        if status is not None:
            params["status"] = status

        return cls(params=params)
204
+
205
+
206
@dataclass
class CancelledNotification:
    """``notifications/cancelled`` message naming the request to abandon.

    ``params`` must carry ``requestId``; ``reason`` is optional.
    """

    method: str = "notifications/cancelled"
    params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Reject a notification that does not identify the cancelled request.

        Raises:
            ValueError: If ``params`` lacks ``requestId``.
        """
        if "requestId" in self.params:
            return
        raise ValueError("Cancellation notification requires requestId")

    @classmethod
    def create(
        cls, request_id: str, reason: Optional[str] = None
    ) -> "CancelledNotification":
        """Build a cancellation notification.

        Args:
            request_id: ID of the request being cancelled.
            reason: Optional explanation; left out when None or empty.

        Returns:
            The populated notification.
        """
        extras = {"reason": reason} if reason else {}
        return cls(params={"requestId": request_id, **extras})
227
+
228
+
229
@dataclass
class CompletionRequest:
    """``completion/complete`` request for prompts and resources.

    ``params["ref"]`` is always present. ``params["argument"]`` is present only
    when a non-empty argument dict was supplied.
    """

    method: str = "completion/complete"
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def create(
        cls, ref: Dict[str, Any], argument: Optional[Dict[str, Any]] = None
    ) -> "CompletionRequest":
        """Build a completion request.

        Args:
            ref: Reference object placed under ``ref``.
            argument: Optional argument object; omitted when None or empty.

        Returns:
            The populated request.
        """
        body: Dict[str, Any] = {"ref": ref}
        if not argument:
            return cls(params=body)
        body["argument"] = argument
        return cls(params=body)
245
+
246
+
247
@dataclass
class CompletionResult:
    """Result payload for a completion request.

    ``completion`` always holds ``values``. It also holds ``total`` when a
    total was supplied.
    """

    completion: Dict[str, Any]

    @classmethod
    def create(
        cls, values: List[str], total: Optional[int] = None
    ) -> "CompletionResult":
        """Build a result.

        Args:
            values: Completion values, stored as given.
            total: Optional count; included whenever it is not None (0 is kept).

        Returns:
            The populated result.
        """
        payload: Dict[str, Any] = (
            {"values": values} if total is None else {"values": values, "total": total}
        )
        return cls(completion=payload)
262
+
263
+
264
@dataclass
class SamplingRequest:
    """``sampling/createMessage`` request sent from server to client.

    ``messages`` is always present. The optional settings are added under
    their camelCase wire names when supplied.
    """

    method: str = "sampling/createMessage"
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def create(
        cls,
        messages: List[Dict[str, Any]],
        model_preferences: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
        include_context: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "SamplingRequest":
        """Build a sampling request.

        ``temperature`` and ``max_tokens`` are included whenever they are not
        None, so ``0`` is kept. Every other option is included only when
        truthy, so empty strings, lists and dicts are dropped.

        Returns:
            The populated request.
        """
        # (wire key, value, include?) listed in the order keys enter params.
        candidates = (
            ("modelPreferences", model_preferences, bool(model_preferences)),
            ("systemPrompt", system_prompt, bool(system_prompt)),
            ("includeContext", include_context, bool(include_context)),
            ("temperature", temperature, temperature is not None),
            ("maxTokens", max_tokens, max_tokens is not None),
            ("stopSequences", stop_sequences, bool(stop_sequences)),
            ("metadata", metadata, bool(metadata)),
        )
        params: Dict[str, Any] = {"messages": messages}
        params.update((key, value) for key, value, keep in candidates if keep)
        return cls(params=params)
302
+
303
+
304
@dataclass
class ResourceTemplate:
    """Resource template described by a URI template.

    Optional descriptive fields are serialized only when non-empty.
    """

    uri_template: str
    name: Optional[str] = None
    description: Optional[str] = None
    mime_type: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize using camelCase keys (``uriTemplate``, ``mimeType``).

        Returns:
            Dict with ``uriTemplate`` first, followed by whichever of ``name``,
            ``description`` and ``mimeType`` are set.
        """
        optional_fields = {
            "name": self.name,
            "description": self.description,
            "mimeType": self.mime_type,
        }
        serialized: Dict[str, Any] = {"uriTemplate": self.uri_template}
        serialized.update({key: val for key, val in optional_fields.items() if val})
        return serialized
323
+
324
+
325
@dataclass
class ToolResult:
    """Tool result made of structured content items.

    Each factory produces a single content item. ``is_error`` is serialized
    as ``isError`` only when it is truthy.
    """

    content: List[Dict[str, Any]]
    is_error: bool = False

    @classmethod
    def text(cls, text: str, is_error: bool = False) -> "ToolResult":
        """Return a result holding one text item."""
        item = {"type": "text", "text": text}
        return cls(content=[item], is_error=is_error)

    @classmethod
    def image(cls, data: str, mime_type: str) -> "ToolResult":
        """Return a result holding one image item with the given data and MIME type."""
        item = {"type": "image", "data": data, "mimeType": mime_type}
        return cls(content=[item])

    @classmethod
    def resource(
        cls, uri: str, text: Optional[str] = None, mime_type: Optional[str] = None
    ) -> "ToolResult":
        """Return a result embedding a resource.

        ``text`` and ``mime_type`` are attached to the resource only when
        non-empty.
        """
        resource_body: Dict[str, Any] = {"uri": uri}
        if text:
            resource_body["text"] = text
        if mime_type:
            resource_body["mimeType"] = mime_type
        return cls(content=[{"type": "resource", "resource": resource_body}])

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict, adding ``isError`` only for error results."""
        if not self.is_error:
            return {"content": self.content}
        return {"content": self.content, "isError": self.is_error}
360
+
361
+
362
class ProgressManager:
    """Manages progress reporting for long-running operations.

    Active operations are keyed by ``ProgressToken``. Every update builds a
    ``ProgressNotification`` and hands it to the callbacks registered for that
    token. Callbacks may be sync functions, async functions, or sync callables
    that return an awaitable.
    """

    def __init__(self):
        """Initialize empty progress-state and callback registries."""
        self._active_progress: Dict[ProgressToken, Dict[str, Any]] = {}
        self._progress_callbacks: Dict[ProgressToken, List[Callable]] = {}

    def start_progress(
        self,
        operation_name: str,
        total: Optional[float] = None,
        progress_token: Optional[ProgressToken] = None,
    ) -> ProgressToken:
        """Start progress tracking for an operation.

        Starting with a token that is already active resets its recorded state
        and discards its registered callbacks.

        Args:
            operation_name: Name of the operation
            total: Total progress units (if known)
            progress_token: Custom progress token (a new one is generated if None)

        Returns:
            Progress token for tracking
        """
        if progress_token is None:
            token_value = f"progress_{uuid.uuid4().hex[:8]}"
            progress_token = ProgressToken(
                value=token_value,
                operation_name=operation_name,
                total=total,
                progress=0,
                status="started",
            )

        self._active_progress[progress_token] = {
            "operation": operation_name,
            "started_at": time.time(),
            "total": total,
            "current": 0,
            "status": "started",
        }
        self._progress_callbacks[progress_token] = []

        logger.debug(
            f"Started progress tracking: {operation_name} ({progress_token.value})"
        )
        return progress_token

    async def update_progress(
        self,
        progress_token: ProgressToken,
        progress: Optional[float] = None,
        status: Optional[str] = None,
        increment: Optional[float] = None,
    ) -> None:
        """Update progress for an operation and notify its callbacks.

        Unknown tokens are logged at WARNING and otherwise ignored. A callback
        that raises is logged with its traceback and does not stop the
        remaining callbacks.

        Args:
            progress_token: Progress token
            progress: Absolute progress value (takes precedence over ``increment``)
            status: Status message
            increment: Amount to add to the current progress
        """
        if progress_token not in self._active_progress:
            logger.warning(f"Progress token not found: {progress_token}")
            return

        progress_info = self._active_progress[progress_token]

        # Update progress value (state dict and token are kept in sync).
        if progress is not None:
            progress_info["current"] = progress
            progress_token.progress = progress
        elif increment is not None:
            new_progress = progress_info.get("current", 0) + increment
            progress_info["current"] = new_progress
            progress_token.progress = new_progress

        if status is not None:
            progress_info["status"] = status
            progress_token.status = status

        progress_info["updated_at"] = time.time()

        notification = ProgressNotification.create(
            progress_token=progress_token.value,
            progress=progress_info["current"],
            total=progress_info.get("total"),
            status=progress_info["status"],
        )

        for callback in self._progress_callbacks.get(progress_token, []):
            try:
                # Await anything awaitable: covers async functions as well as
                # sync callables (partials, lambdas) that return a coroutine.
                outcome = callback(notification)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.exception("Progress callback error: %s", e)

    async def complete_progress(
        self, progress_token: ProgressToken, status: str = "completed"
    ) -> None:
        """Complete progress tracking, send a final update and forget the token.

        Args:
            progress_token: Progress token
            status: Final status message
        """
        progress_info = self._active_progress.get(progress_token)
        if progress_info is None:
            return

        progress_info["status"] = status
        progress_info["completed_at"] = time.time()
        progress_token.status = status

        # Send final progress update
        await self.update_progress(progress_token, status=status)

        # pop() instead of del: another coroutine may have completed this token
        # while the callbacks above were being awaited.
        self._active_progress.pop(progress_token, None)
        self._progress_callbacks.pop(progress_token, None)

        logger.debug(f"Completed progress tracking: {progress_token.value}")

    def add_progress_callback(
        self, progress_token: ProgressToken, callback: Callable
    ) -> None:
        """Add a callback for progress updates.

        Callbacks for tokens that are not active are silently ignored.

        Args:
            progress_token: Progress token
            callback: Called with each ``ProgressNotification``; may be async
        """
        if progress_token in self._progress_callbacks:
            self._progress_callbacks[progress_token].append(callback)

    def get_progress_info(
        self, progress_token: ProgressToken
    ) -> Optional[Dict[str, Any]]:
        """Get current progress information.

        Args:
            progress_token: Progress token

        Returns:
            The live state dict for the token, or None if it is not active
        """
        return self._active_progress.get(progress_token)

    def list_active_progress(self) -> List[ProgressToken]:
        """List all active progress tokens."""
        return list(self._active_progress.keys())

    def get_active_progress(self) -> List[ProgressToken]:
        """Get all active progress tokens (alias for list_active_progress)."""
        return self.list_active_progress()
525
+
526
+
527
class CancellationManager:
    """Manages request cancellation and cleanup.

    Cancellation callbacks receive a ``CancelledNotification``; cleanup
    functions are called with no arguments. Either may be sync or async, or a
    sync callable returning an awaitable. Both kinds of hook run once, when the
    request is cancelled, and are then discarded. Cancelled request IDs and
    their reasons are retained until :meth:`clear_cancelled_request` is called.
    """

    def __init__(self):
        """Initialize empty cancellation state."""
        self._cancelled_requests: set[str] = set()
        # Reason recorded per cancelled request (None when no reason was given).
        self._cancellation_reasons: Dict[str, Optional[str]] = {}
        self._cancellation_callbacks: Dict[str, List[Callable]] = {}
        self._request_cleanup: Dict[str, List[Callable]] = {}

    async def cancel_request(
        self, request_id: str, reason: Optional[str] = None
    ) -> None:
        """Cancel a request and run its callbacks and cleanup functions.

        Cancelling an already-cancelled request is a no-op. A hook that raises
        is logged with its traceback and does not stop the remaining hooks.

        Args:
            request_id: Request ID to cancel
            reason: Cancellation reason
        """
        if request_id in self._cancelled_requests:
            return  # Already cancelled

        self._cancelled_requests.add(request_id)
        self._cancellation_reasons[request_id] = reason

        notification = CancelledNotification.create(request_id, reason)

        for callback in self._cancellation_callbacks.get(request_id, []):
            try:
                outcome = callback(notification)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.exception("Cancellation callback error: %s", e)

        for cleanup in self._request_cleanup.get(request_id, []):
            try:
                outcome = cleanup()
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                logger.exception("Cleanup error for %s: %s", request_id, e)

        # Hooks are one-shot: drop them once they have run.
        self._cancellation_callbacks.pop(request_id, None)
        self._request_cleanup.pop(request_id, None)

        logger.info(f"Cancelled request: {request_id}")

    def is_cancelled(self, request_id: str) -> bool:
        """Check whether a request has been cancelled (synchronous).

        Args:
            request_id: Request ID to check

        Returns:
            True if cancelled
        """
        return request_id in self._cancelled_requests

    def add_cancellation_callback(self, request_id: str, callback: Callable) -> None:
        """Add a callback to run when the request is cancelled.

        Callbacks added after the request was already cancelled are stored but
        never invoked.

        Args:
            request_id: Request ID
            callback: Called with the ``CancelledNotification``
        """
        self._cancellation_callbacks.setdefault(request_id, []).append(callback)

    def add_cleanup_function(self, request_id: str, cleanup: Callable) -> None:
        """Add a zero-argument cleanup function to run on cancellation.

        Cleanup functions added after the request was already cancelled are
        stored but never invoked.

        Args:
            request_id: Request ID
            cleanup: Cleanup function
        """
        self._request_cleanup.setdefault(request_id, []).append(cleanup)

    def clear_cancelled_request(self, request_id: str) -> None:
        """Forget that a request was cancelled, including its reason.

        Args:
            request_id: Request ID to clear
        """
        self._cancelled_requests.discard(request_id)
        self._cancellation_reasons.pop(request_id, None)

    def get_cancellation_reason(self, request_id: str) -> Optional[str]:
        """Get the cancellation reason for a request.

        Args:
            request_id: Request ID to check

        Returns:
            Cancellation reason if one was recorded, None otherwise
        """
        return self._cancellation_reasons.get(request_id)
639
+
640
+
641
class CompletionManager:
    """Manages auto-completion for prompts, resources and tools.

    ``"tools"`` and ``"resources"`` are answered from the manager's own lists.
    Every other reference type is delegated to a provider registered via
    :meth:`register_completion_provider`.
    """

    def __init__(self):
        """Initialize with no providers and empty tool/resource lists."""
        self._completion_providers: Dict[str, Callable] = {}
        # NOTE(review): nothing in this class fills these lists; presumably the
        # owning server assigns them — confirm.
        self._available_tools: List[Dict[str, Any]] = []
        self._available_resources: List[Dict[str, Any]] = []

    def register_completion_provider(self, ref_type: str, provider: Callable) -> None:
        """Register (or replace) the completion provider for a reference type.

        Providers registered for "tools" or "resources" are never consulted,
        because :meth:`get_completions` answers those types from built-in lists.

        Args:
            ref_type: Reference type (e.g., "prompts")
            provider: Called as ``provider(ref_name, filter_text)``; may be sync
                or async and should return a list
        """
        self._completion_providers[ref_type] = provider

    async def get_completions(
        self,
        completion_type: Optional[str] = None,
        ref_type: Optional[str] = None,
        ref_name: Optional[str] = None,
        partial: Optional[str] = None,
        prefix: Optional[str] = None,
    ) -> List[Any]:
        """Get completions for a reference.

        Args:
            completion_type: Type of completion ("tools", "resources", etc);
                takes precedence over ``ref_type`` when both are truthy
            ref_type: Alias for ``completion_type``
            ref_name: Reference name, forwarded to registered providers
            partial: Partial input to complete
            prefix: Alias for ``partial``; takes precedence when both are truthy

        Returns:
            For "tools"/"resources": a new list containing the items whose
            ``name``/``uri`` starts with the filter text (all items when there
            is no filter). Otherwise: the provider's list result, or ``[]`` when
            no provider is registered, it raises, or it returns a non-list.
        """
        type_to_use = completion_type or ref_type
        filter_text = prefix or partial

        if type_to_use == "tools":
            return self._filter_by_prefix(
                self._get_available_tools(), "name", filter_text
            )
        if type_to_use == "resources":
            return self._filter_by_prefix(
                self._get_available_resources(), "uri", filter_text
            )

        provider = self._completion_providers.get(type_to_use)
        if not provider:
            return []

        try:
            completions = provider(ref_name, filter_text)
            if inspect.isawaitable(completions):
                completions = await completions
        except Exception as e:
            logger.exception("Completion provider error: %s", e)
            return []

        return completions if isinstance(completions, list) else []

    @staticmethod
    def _filter_by_prefix(
        items: List[Dict[str, Any]], key: str, filter_text: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Return a new list of items whose ``item[key]`` starts with ``filter_text``."""
        # Always build a new list so callers cannot mutate the manager's state.
        if not filter_text:
            return list(items)
        return [item for item in items if item.get(key, "").startswith(filter_text)]

    def _get_available_tools(self) -> List[Dict[str, Any]]:
        """Get available tools for completion."""
        return self._available_tools

    def _get_available_resources(self) -> List[Dict[str, Any]]:
        """Get available resources for completion."""
        return self._available_resources
723
+
724
+
725
class SamplingManager:
    """Manages LLM sampling requests from server to client.

    :meth:`request_sampling` asks the registered callbacks in registration
    order and returns the first non-None result. :meth:`create_message_sample`
    only records a sample locally. The history grows until
    :meth:`clear_sample_history` is called.
    """

    def __init__(self):
        """Initialize with no callbacks and an empty sample history."""
        self._sampling_callbacks: List[Callable] = []
        self._samples: List[Dict[str, Any]] = []

    def add_sampling_callback(self, callback: Callable) -> None:
        """Register a sampling callback.

        Args:
            callback: Called with a ``SamplingRequest``; may be sync or async.
                Returning None passes the request on to the next callback.
        """
        self._sampling_callbacks.append(callback)

    async def request_sampling(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Dict[str, Any]:
        """Request LLM sampling from the registered callbacks.

        Args:
            messages: Messages for sampling
            **kwargs: Keyword arguments forwarded to ``SamplingRequest.create``

        Returns:
            The first non-None callback result

        Raises:
            MCPError: If no callback produced a result (including when every
                callback raised).
            TypeError: If ``kwargs`` contains a name ``SamplingRequest.create``
                does not accept.
        """
        request = SamplingRequest.create(messages, **kwargs)

        for callback in self._sampling_callbacks:
            try:
                result = callback(request)
                # Await async callbacks and sync callables returning awaitables,
                # so a coroutine object is never returned as the "result".
                if inspect.isawaitable(result):
                    result = await result
            except Exception as e:
                # A failing callback must not prevent later ones from answering.
                logger.exception("Sampling callback error: %s", e)
                continue

            if result is not None:
                return result

        raise MCPError(
            "No sampling provider available", error_code=MCPErrorCode.METHOD_NOT_FOUND
        )

    async def create_message_sample(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Dict[str, Any]:
        """Record a message sample in the local history.

        Args:
            messages: Messages for sampling
            **kwargs: Only ``model_preferences`` and ``metadata`` are stored;
                other keys are ignored

        Returns:
            The stored sample, with ``sample_id`` and ``timestamp`` added
        """
        sample = {
            "messages": messages,
            "sample_id": f"sample_{uuid.uuid4().hex[:8]}",
            "timestamp": time.time(),
        }

        if "model_preferences" in kwargs:
            sample["model_preferences"] = kwargs["model_preferences"]
        if "metadata" in kwargs:
            sample["metadata"] = kwargs["metadata"]

        self._samples.append(sample)

        return sample

    def get_sample_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get sampling history.

        Args:
            limit: Maximum number of most recent samples to return; None returns
                all, zero or negative returns none

        Returns:
            A new list of sample history entries, oldest first
        """
        if limit is None:
            return self._samples.copy()
        return self._samples[-limit:] if limit > 0 else []

    def clear_sample_history(self) -> None:
        """Clear sampling history."""
        self._samples.clear()
819
+
820
+
821
class RootsManager:
    """Manages file system roots access.

    A URI is considered to lie inside a root when it equals the root URI or
    continues it past a ``/`` boundary without any ``..`` path segment.
    """

    def __init__(self):
        """Initialize roots manager."""
        self._roots: List[Dict[str, Any]] = []
        self._access_validators: List[Callable] = []

    @staticmethod
    def _is_within_root(uri: str, root_uri: str) -> bool:
        """Return True if *uri* is *root_uri* itself or lies beneath it.

        A plain ``startswith`` test is not enough for access control. With it,
        root ``file:///home/user`` would also match ``file:///home/user2/...``,
        and ``<root>/../x`` would escape the root. Both cases are rejected here.
        """
        from urllib.parse import unquote

        if uri == root_uri:
            return True
        if not root_uri:
            # An empty root is not treated as "everything".
            return False

        prefix = root_uri if root_uri.endswith("/") else root_uri + "/"
        if not uri.startswith(prefix):
            return False

        # Decode percent-escapes (e.g. "%2e%2e") and treat backslashes as
        # separators before looking for parent-directory segments.
        remainder = unquote(uri[len(prefix):]).replace("\\", "/")
        return ".." not in remainder.split("/")

    def add_root(
        self, uri: str, name: Optional[str] = None, description: Optional[str] = None
    ) -> None:
        """Add a file system root.

        Re-adding a URI that is already registered replaces its entry instead
        of creating a duplicate, so one ``remove_root`` call fully revokes it.

        Args:
            uri: Root URI
            name: Optional name for the root
            description: Optional description for the root
        """
        root: Dict[str, Any] = {"uri": uri}
        if name:
            root["name"] = name
        if description:
            root["description"] = description

        for index, existing in enumerate(self._roots):
            if existing["uri"] == uri:
                self._roots[index] = root
                logger.info(f"Updated root: {uri}")
                return

        self._roots.append(root)
        logger.info(f"Added root: {uri}")

    def remove_root(self, uri: str) -> bool:
        """Remove a file system root.

        Every entry with a matching URI is removed, so no stale duplicate can
        keep granting access.

        Args:
            uri: Root URI to remove

        Returns:
            True if at least one entry was removed
        """
        remaining = [root for root in self._roots if root["uri"] != uri]
        if len(remaining) == len(self._roots):
            return False
        self._roots[:] = remaining
        logger.info(f"Removed root: {uri}")
        return True

    def list_roots(self) -> List[Dict[str, Any]]:
        """List all file system roots.

        Returns:
            Copies of the root objects; mutating them does not alter the
            registered roots.
        """
        return [dict(root) for root in self._roots]

    def find_root_for_uri(self, uri: str) -> Optional[Dict[str, Any]]:
        """Find the root that contains the given URI.

        Args:
            uri: URI to find root for

        Returns:
            The first registered root containing *uri*, or None
        """
        for root in self._roots:
            if self._is_within_root(uri, root["uri"]):
                return root
        return None

    def add_access_validator(self, validator: Callable) -> None:
        """Add access validator for roots.

        Args:
            validator: Callable ``(uri, operation)``. It may be sync or async
                and returns a truthy value to allow access.
        """
        self._access_validators.append(validator)

    async def validate_access(self, uri: str, operation: str = "read") -> bool:
        """Validate access to a URI.

        Access requires that *uri* lie inside a registered root and that every
        validator allow it. A validator that raises denies access.

        Args:
            uri: URI to validate
            operation: Operation type

        Returns:
            True if access is allowed
        """
        import inspect

        if self.find_root_for_uri(uri) is None:
            return False

        for validator in self._access_validators:
            try:
                allowed = validator(uri, operation)
                # Check the result instead of the function type. That also covers
                # async callable objects and partials. Without awaiting, their
                # coroutine would be truthy and silently grant access.
                if inspect.isawaitable(allowed):
                    allowed = await allowed

                if not allowed:
                    return False

            except Exception as e:
                logger.error(f"Access validator error: {e}")
                return False

        return True
931
+
932
+
933
class ProtocolManager:
    """Central manager for all MCP protocol features."""

    def __init__(self):
        """Initialize protocol manager."""
        self.progress = ProgressManager()
        self.cancellation = CancellationManager()
        self.completion = CompletionManager()
        self.sampling = SamplingManager()
        self.roots = RootsManager()

        # Protocol state
        self._initialized = False
        self._client_capabilities: Dict[str, Any] = {}
        self._server_capabilities: Dict[str, Any] = {}
        self._handlers: Dict[str, Callable] = {}

    @staticmethod
    def _capability_enabled(container: Any, key: str) -> bool:
        """Return whether *key* is advertised in a capability mapping.

        MCP advertises capabilities as objects that are often empty (e.g.
        ``"sampling": {}``). So any present, non-null value counts as support,
        except an explicit boolean, which is taken at face value. Non-dict
        containers (e.g. ``"experimental": null``) advertise nothing.
        """
        if not isinstance(container, dict) or key not in container:
            return False
        value = container[key]
        if isinstance(value, bool):
            return value
        return value is not None

    def set_initialized(self, client_capabilities: Dict[str, Any]) -> None:
        """Set protocol as initialized with client capabilities.

        Args:
            client_capabilities: Client capability advertisement. A shallow
                copy is stored so later caller mutations do not leak in.
        """
        self._initialized = True
        self._client_capabilities = dict(client_capabilities or {})
        logger.info("MCP protocol initialized")

    def is_initialized(self) -> bool:
        """Check if protocol is initialized."""
        return self._initialized

    def get_client_capabilities(self) -> Dict[str, Any]:
        """Get a shallow copy of the client capabilities."""
        return self._client_capabilities.copy()

    def set_server_capabilities(self, capabilities: Dict[str, Any]) -> None:
        """Set server capabilities.

        Args:
            capabilities: Server capabilities (stored as a shallow copy)
        """
        self._server_capabilities = dict(capabilities or {})

    def get_server_capabilities(self) -> Dict[str, Any]:
        """Get a shallow copy of the server capabilities."""
        return self._server_capabilities.copy()

    def supports_progress(self) -> bool:
        """Check if client supports progress reporting (experimental flag)."""
        return self._capability_enabled(
            self._client_capabilities.get("experimental"), "progressNotifications"
        )

    def supports_cancellation(self) -> bool:
        """Check if client supports cancellation."""
        return True  # Basic support assumed

    def supports_completion(self) -> bool:
        """Check if client supports completion (experimental flag)."""
        return self._capability_enabled(
            self._client_capabilities.get("experimental"), "completion"
        )

    def supports_sampling(self) -> bool:
        """Check if client supports sampling.

        The MCP spec advertises sampling as a top-level ``"sampling": {}``
        capability. The legacy ``experimental.sampling`` flag is still honoured.
        """
        caps = self._client_capabilities
        return self._capability_enabled(caps, "sampling") or self._capability_enabled(
            caps.get("experimental"), "sampling"
        )

    def supports_roots(self) -> bool:
        """Check if client supports roots change notifications (``roots.listChanged``)."""
        return self._capability_enabled(
            self._client_capabilities.get("roots"), "listChanged"
        )

    def _get_handler(self, method: str) -> Optional[Callable]:
        """Get handler for a method.

        Args:
            method: Method name

        Returns:
            Handler function or None
        """
        return self._handlers.get(method)

    def register_handler(self, method: str, handler: Callable) -> None:
        """Register a handler for a method.

        Args:
            method: Method name
            handler: Handler function (sync or async); it receives the whole
                request dict
        """
        self._handlers[method] = handler

    async def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an incoming request.

        Args:
            request: Request message

        Returns:
            JSON-RPC response message with the handler's result

        Raises:
            MCPError: METHOD_NOT_FOUND if no handler is registered. Any
                non-MCPError raised by the handler is wrapped as
                INTERNAL_ERROR, with the original exception chained.
        """
        import inspect

        method = request.get("method")
        request_id = request.get("id")

        handler = self._get_handler(method)
        if handler is None:
            raise MCPError(
                f"Method not found: {method}", error_code=MCPErrorCode.METHOD_NOT_FOUND
            )

        try:
            # Await based on the result rather than the function type, so async
            # callable objects and partials are awaited as well.
            result = handler(request)
            if inspect.isawaitable(result):
                result = await result
        except MCPError:
            raise
        except Exception as e:
            logger.error(f"Handler error for {method}: {e}")
            raise MCPError(str(e), error_code=MCPErrorCode.INTERNAL_ERROR) from e

        return {"jsonrpc": "2.0", "result": result, "id": request_id}

    def validate_message_type(self, message: Dict[str, Any]) -> "MessageType":
        """Validate and determine message type.

        Args:
            message: Message to validate

        Returns:
            REQUEST if the message carries an ``id``, otherwise NOTIFICATION

        Raises:
            MCPError: If message is invalid
        """
        if not isinstance(message, dict):
            raise MCPError(
                "Invalid request: message must be a JSON object",
                error_code=MCPErrorCode.INVALID_REQUEST,
            )

        if message.get("jsonrpc") != "2.0":
            raise MCPError(
                "Invalid JSON-RPC version", error_code=MCPErrorCode.INVALID_REQUEST
            )

        if "method" not in message:
            raise MCPError(
                "Missing method field", error_code=MCPErrorCode.INVALID_REQUEST
            )

        # Requests carry an id; notifications do not.
        if "id" in message:
            return MessageType.REQUEST
        return MessageType.NOTIFICATION
1087
+
1088
+
1089
# Module-level shared protocol manager, created lazily by get_protocol_manager()
_protocol_manager: Optional[ProtocolManager] = None


def get_protocol_manager() -> ProtocolManager:
    """Return the shared ProtocolManager, constructing it on the first call."""
    global _protocol_manager
    manager = _protocol_manager
    if manager is None:
        manager = ProtocolManager()
        _protocol_manager = manager
    return manager
1099
+
1100
+
1101
# Convenience functions
def start_progress(operation_name: str, total: Optional[float] = None) -> ProgressToken:
    """Begin progress tracking through the shared protocol manager.

    Args:
        operation_name: Name of the operation being tracked.
        total: Optional total amount of work.

    Returns:
        The token returned by the manager's progress tracker.
    """
    tracker = get_protocol_manager().progress
    return tracker.start_progress(operation_name, total)
1105
+
1106
+
1107
async def update_progress(
    token: ProgressToken, progress: Optional[float] = None, status: Optional[str] = None
) -> None:
    """Forward a progress update for *token* to the shared protocol manager.

    Args:
        token: Token returned by ``start_progress``.
        progress: Optional new progress value.
        status: Optional status text.
    """
    tracker = get_protocol_manager().progress
    await tracker.update_progress(token, progress, status)
1112
+
1113
+
1114
async def complete_progress(token: ProgressToken, status: str = "completed") -> None:
    """Mark the progress identified by *token* as finished.

    Args:
        token: Token returned by ``start_progress``.
        status: Final status text (defaults to ``"completed"``).
    """
    tracker = get_protocol_manager().progress
    await tracker.complete_progress(token, status)
1117
+
1118
+
1119
def is_cancelled(request_id: str) -> bool:
    """Ask the shared cancellation manager whether *request_id* was cancelled."""
    canceller = get_protocol_manager().cancellation
    return canceller.is_cancelled(request_id)
1122
+
1123
+
1124
async def cancel_request(request_id: str, reason: Optional[str] = None) -> None:
    """Request cancellation of *request_id* via the shared cancellation manager.

    Args:
        request_id: Identifier of the request to cancel.
        reason: Optional human-readable reason.
    """
    canceller = get_protocol_manager().cancellation
    await canceller.cancel_request(request_id, reason)