kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +3 -3
- kailash/api/custom_nodes_secure.py +3 -3
- kailash/api/gateway.py +1 -1
- kailash/api/studio.py +1 -1
- kailash/api/workflow_api.py +2 -2
- kailash/core/resilience/bulkhead.py +475 -0
- kailash/core/resilience/circuit_breaker.py +92 -10
- kailash/core/resilience/health_monitor.py +578 -0
- kailash/edge/discovery.py +86 -0
- kailash/mcp_server/__init__.py +309 -33
- kailash/mcp_server/advanced_features.py +1022 -0
- kailash/mcp_server/ai_registry_server.py +27 -2
- kailash/mcp_server/auth.py +789 -0
- kailash/mcp_server/client.py +645 -378
- kailash/mcp_server/discovery.py +1593 -0
- kailash/mcp_server/errors.py +673 -0
- kailash/mcp_server/oauth.py +1727 -0
- kailash/mcp_server/protocol.py +1126 -0
- kailash/mcp_server/registry_integration.py +587 -0
- kailash/mcp_server/server.py +1228 -96
- kailash/mcp_server/transports.py +1169 -0
- kailash/mcp_server/utils/__init__.py +6 -1
- kailash/mcp_server/utils/cache.py +250 -7
- kailash/middleware/auth/auth_manager.py +3 -3
- kailash/middleware/communication/api_gateway.py +1 -1
- kailash/middleware/communication/realtime.py +1 -1
- kailash/middleware/mcp/enhanced_server.py +1 -1
- kailash/nodes/__init__.py +2 -0
- kailash/nodes/admin/audit_log.py +6 -6
- kailash/nodes/admin/permission_check.py +8 -8
- kailash/nodes/admin/role_management.py +32 -28
- kailash/nodes/admin/schema.sql +6 -1
- kailash/nodes/admin/schema_manager.py +13 -13
- kailash/nodes/admin/security_event.py +15 -15
- kailash/nodes/admin/tenant_isolation.py +3 -3
- kailash/nodes/admin/transaction_utils.py +3 -3
- kailash/nodes/admin/user_management.py +21 -21
- kailash/nodes/ai/a2a.py +11 -11
- kailash/nodes/ai/ai_providers.py +9 -12
- kailash/nodes/ai/embedding_generator.py +13 -14
- kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
- kailash/nodes/ai/iterative_llm_agent.py +2 -2
- kailash/nodes/ai/llm_agent.py +210 -33
- kailash/nodes/ai/self_organizing.py +2 -2
- kailash/nodes/alerts/discord.py +4 -4
- kailash/nodes/api/graphql.py +6 -6
- kailash/nodes/api/http.py +10 -10
- kailash/nodes/api/rate_limiting.py +4 -4
- kailash/nodes/api/rest.py +15 -15
- kailash/nodes/auth/mfa.py +3 -3
- kailash/nodes/auth/risk_assessment.py +2 -2
- kailash/nodes/auth/session_management.py +5 -5
- kailash/nodes/auth/sso.py +143 -0
- kailash/nodes/base.py +8 -2
- kailash/nodes/base_async.py +16 -2
- kailash/nodes/base_with_acl.py +2 -2
- kailash/nodes/cache/__init__.py +9 -0
- kailash/nodes/cache/cache.py +1172 -0
- kailash/nodes/cache/cache_invalidation.py +874 -0
- kailash/nodes/cache/redis_pool_manager.py +595 -0
- kailash/nodes/code/async_python.py +2 -1
- kailash/nodes/code/python.py +194 -30
- kailash/nodes/compliance/data_retention.py +6 -6
- kailash/nodes/compliance/gdpr.py +5 -5
- kailash/nodes/data/__init__.py +10 -0
- kailash/nodes/data/async_sql.py +1956 -129
- kailash/nodes/data/optimistic_locking.py +906 -0
- kailash/nodes/data/readers.py +8 -8
- kailash/nodes/data/redis.py +378 -0
- kailash/nodes/data/sql.py +314 -3
- kailash/nodes/data/streaming.py +21 -0
- kailash/nodes/enterprise/__init__.py +8 -0
- kailash/nodes/enterprise/audit_logger.py +285 -0
- kailash/nodes/enterprise/batch_processor.py +22 -3
- kailash/nodes/enterprise/data_lineage.py +1 -1
- kailash/nodes/enterprise/mcp_executor.py +205 -0
- kailash/nodes/enterprise/service_discovery.py +150 -0
- kailash/nodes/enterprise/tenant_assignment.py +108 -0
- kailash/nodes/logic/async_operations.py +2 -2
- kailash/nodes/logic/convergence.py +1 -1
- kailash/nodes/logic/operations.py +1 -1
- kailash/nodes/monitoring/__init__.py +11 -1
- kailash/nodes/monitoring/health_check.py +456 -0
- kailash/nodes/monitoring/log_processor.py +817 -0
- kailash/nodes/monitoring/metrics_collector.py +627 -0
- kailash/nodes/monitoring/performance_benchmark.py +137 -11
- kailash/nodes/rag/advanced.py +7 -7
- kailash/nodes/rag/agentic.py +49 -2
- kailash/nodes/rag/conversational.py +3 -3
- kailash/nodes/rag/evaluation.py +3 -3
- kailash/nodes/rag/federated.py +3 -3
- kailash/nodes/rag/graph.py +3 -3
- kailash/nodes/rag/multimodal.py +3 -3
- kailash/nodes/rag/optimized.py +5 -5
- kailash/nodes/rag/privacy.py +3 -3
- kailash/nodes/rag/query_processing.py +6 -6
- kailash/nodes/rag/realtime.py +1 -1
- kailash/nodes/rag/registry.py +1 -1
- kailash/nodes/rag/router.py +1 -1
- kailash/nodes/rag/similarity.py +7 -7
- kailash/nodes/rag/strategies.py +4 -4
- kailash/nodes/security/abac_evaluator.py +6 -6
- kailash/nodes/security/behavior_analysis.py +5 -5
- kailash/nodes/security/credential_manager.py +1 -1
- kailash/nodes/security/rotating_credentials.py +11 -11
- kailash/nodes/security/threat_detection.py +8 -8
- kailash/nodes/testing/credential_testing.py +2 -2
- kailash/nodes/transform/processors.py +5 -5
- kailash/runtime/local.py +163 -9
- kailash/runtime/parameter_injection.py +425 -0
- kailash/runtime/parameter_injector.py +657 -0
- kailash/runtime/testing.py +2 -2
- kailash/testing/fixtures.py +2 -2
- kailash/workflow/builder.py +99 -14
- kailash/workflow/builder_improvements.py +207 -0
- kailash/workflow/input_handling.py +170 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1126 @@
|
|
1
|
+
"""
|
2
|
+
Complete MCP Protocol Implementation.
|
3
|
+
|
4
|
+
This module implements the full Model Context Protocol (MCP) specification,
|
5
|
+
including all message types, progress reporting, cancellation, completion,
|
6
|
+
sampling, and other advanced protocol features that build on the official
|
7
|
+
MCP Python SDK.
|
8
|
+
|
9
|
+
Features:
|
10
|
+
- Complete protocol message type definitions
|
11
|
+
- Progress reporting with token-based tracking
|
12
|
+
- Request cancellation and cleanup
|
13
|
+
- Completion system for prompts and resources
|
14
|
+
- Sampling system for LLM interactions
|
15
|
+
- Roots system for file system access
|
16
|
+
- Meta field support for protocol metadata
|
17
|
+
- Proper error handling with standard codes
|
18
|
+
|
19
|
+
The implementation follows the official MCP specification while providing
|
20
|
+
enhanced functionality for production use cases.
|
21
|
+
|
22
|
+
Examples:
|
23
|
+
Progress reporting:
|
24
|
+
|
25
|
+
>>> from kailash.mcp_server.protocol import ProgressManager
|
26
|
+
>>> progress = ProgressManager()
|
27
|
+
>>>
|
28
|
+
>>> # Start progress tracking
|
29
|
+
>>> token = progress.start_progress("long_operation", total=100)
|
30
|
+
>>> for i in range(100):
|
31
|
+
... await progress.update_progress(token, progress=i, status=f"Step {i}")
|
32
|
+
>>> await progress.complete_progress(token)
|
33
|
+
|
34
|
+
Request cancellation:
|
35
|
+
|
36
|
+
>>> from kailash.mcp_server.protocol import CancellationManager
|
37
|
+
>>> cancellation = CancellationManager()
|
38
|
+
>>>
|
39
|
+
>>> # Check if request should be cancelled
|
40
|
+
>>> if cancellation.is_cancelled(request_id):
|
41
|
+
... raise CancelledError("Operation was cancelled")
|
42
|
+
|
43
|
+
Completion system:
|
44
|
+
|
45
|
+
>>> from kailash.mcp_server.protocol import CompletionManager
|
46
|
+
>>> completion = CompletionManager()
|
47
|
+
>>>
|
48
|
+
>>> # Get completions for a prompt argument
|
49
|
+
>>> completions = await completion.get_completions(
|
50
|
+
...     ref_type="prompts", ref_name="analyze", partial="fil"
|
51
|
+
... )
|
52
|
+
"""
|
53
|
+
|
54
|
+
import asyncio
|
55
|
+
import json
|
56
|
+
import logging
|
57
|
+
import time
|
58
|
+
import uuid
|
59
|
+
from abc import ABC, abstractmethod
|
60
|
+
from dataclasses import asdict, dataclass, field
|
61
|
+
from enum import Enum
|
62
|
+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
|
63
|
+
|
64
|
+
from .errors import MCPError, MCPErrorCode
|
65
|
+
|
66
|
+
logger = logging.getLogger(__name__)
|
67
|
+
|
68
|
+
|
69
|
+
class MessageType(Enum):
    """Names of the MCP messages handled by this module.

    Every member's value is the method string used for that message.
    """

    # --- Handshake ---
    INITIALIZE = "initialize"
    INITIALIZED = "initialized"

    # --- Tools ---
    TOOLS_LIST = "tools/list"
    TOOLS_CALL = "tools/call"

    # --- Resources ---
    RESOURCES_LIST = "resources/list"
    RESOURCES_READ = "resources/read"
    RESOURCES_SUBSCRIBE = "resources/subscribe"
    RESOURCES_UNSUBSCRIBE = "resources/unsubscribe"
    RESOURCES_UPDATED = "notifications/resources/updated"

    # --- Prompts ---
    PROMPTS_LIST = "prompts/list"
    PROMPTS_GET = "prompts/get"

    # --- Progress notifications ---
    PROGRESS = "notifications/progress"

    # --- Cancellation notifications ---
    CANCELLED = "notifications/cancelled"

    # --- Argument completion ---
    COMPLETION_COMPLETE = "completion/complete"

    # --- Sampling (sent from server to client) ---
    SAMPLING_CREATE_MESSAGE = "sampling/createMessage"

    # --- Roots (file system) ---
    ROOTS_LIST = "roots/list"

    # --- Logging ---
    LOGGING_SET_LEVEL = "logging/setLevel"

    # --- Custom extensions ---
    PING = "ping"
    PONG = "pong"
    REQUEST = "request"  # Generic request type
    NOTIFICATION = "notification"  # Generic notification type
|
114
|
+
|
115
|
+
|
116
|
+
@dataclass
class ProgressToken:
    """Handle for one tracked operation together with its latest progress state.

    Equality and hashing look only at ``value``; the remaining fields are
    mutable bookkeeping and play no part in dictionary lookups.
    """

    value: str
    operation_name: str
    total: Optional[float] = None
    progress: float = 0
    status: Optional[str] = None

    def __eq__(self, other):
        """Return True only for another ``ProgressToken`` carrying the same ``value``."""
        return isinstance(other, ProgressToken) and self.value == other.value

    def __hash__(self):
        """Hash on ``value`` so tokens can be used as dictionary keys."""
        return hash(self.value)
|
135
|
+
|
136
|
+
|
137
|
+
@dataclass
class MetaData:
    """Optional metadata fields that may accompany a protocol message."""

    progress_token: Optional["ProgressToken"] = None
    request_id: Optional[str] = None
    timestamp: Optional[float] = None
    operation_id: Optional[str] = None
    user_id: Optional[str] = None
    additional_data: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Default the timestamp to the current time and the extras to an empty dict."""
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.additional_data is None:
            self.additional_data = {}

    def to_dict(self) -> Dict[str, Any]:
        """Build a dictionary of the populated fields.

        Only truthy fields are emitted. Entries from ``additional_data`` are
        merged last, so they win over a named field that uses the same key.

        Returns:
            Mapping of output key to field value.
        """
        named_fields = (
            ("progressToken", self.progress_token),
            ("requestId", self.request_id),
            ("timestamp", self.timestamp),
            ("operation_id", self.operation_id),
            ("user_id", self.user_id),
        )
        payload = {key: value for key, value in named_fields if value}
        if self.additional_data:
            payload.update(self.additional_data)
        return payload
|
171
|
+
|
172
|
+
|
173
|
+
@dataclass
class ProgressNotification:
    """A ``notifications/progress`` message.

    ``params`` must contain a ``progressToken`` entry; construction fails
    otherwise.
    """

    method: str = "notifications/progress"
    params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Reject notifications that do not carry a progress token.

        Raises:
            ValueError: If ``params`` has no ``progressToken`` key.
        """
        if "progressToken" in self.params:
            return
        raise ValueError("Progress notification requires progressToken")

    @classmethod
    def create(
        cls,
        progress_token: "ProgressToken",
        progress: Optional[float] = None,
        total: Optional[float] = None,
        status: Optional[str] = None,
    ) -> "ProgressNotification":
        """Build a notification from individual values.

        Args:
            progress_token: Token stored under ``progressToken``.
            progress: Current progress; omitted from params when None.
            total: Total progress units; omitted from params when None.
            status: Status text; omitted from params when None.

        Returns:
            The new notification.
        """
        optional_fields = {"progress": progress, "total": total, "status": status}
        payload = {"progressToken": progress_token}
        payload.update(
            {name: value for name, value in optional_fields.items() if value is not None}
        )
        return cls(params=payload)
|
204
|
+
|
205
|
+
|
206
|
+
@dataclass
class CancelledNotification:
    """A ``notifications/cancelled`` message.

    ``params`` must contain a ``requestId`` entry; construction fails otherwise.
    """

    method: str = "notifications/cancelled"
    params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Reject notifications that do not name a request.

        Raises:
            ValueError: If ``params`` has no ``requestId`` key.
        """
        if "requestId" in self.params:
            return
        raise ValueError("Cancellation notification requires requestId")

    @classmethod
    def create(
        cls, request_id: str, reason: Optional[str] = None
    ) -> "CancelledNotification":
        """Build a notification for ``request_id``.

        Args:
            request_id: Identifier stored under ``requestId``.
            reason: Optional explanation; left out when empty or None.

        Returns:
            The new notification.
        """
        payload = {"requestId": request_id}
        # Truthiness check: an empty-string reason is dropped as well as None.
        if not reason:
            return cls(params=payload)
        payload["reason"] = reason
        return cls(params=payload)
|
227
|
+
|
228
|
+
|
229
|
+
@dataclass
class CompletionRequest:
    """A ``completion/complete`` request."""

    method: str = "completion/complete"
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def create(
        cls, ref: Dict[str, Any], argument: Optional[Dict[str, Any]] = None
    ) -> "CompletionRequest":
        """Build a request for the given reference.

        Args:
            ref: Reference stored under ``ref``.
            argument: Stored under ``argument``; left out when empty or None.

        Returns:
            The new request.
        """
        payload = {"ref": ref, "argument": argument} if argument else {"ref": ref}
        return cls(params=payload)
|
245
|
+
|
246
|
+
|
247
|
+
@dataclass
class CompletionResult:
    """Result of a completion request; the payload lives in ``completion``."""

    completion: Dict[str, Any]

    @classmethod
    def create(
        cls, values: List[str], total: Optional[int] = None
    ) -> "CompletionResult":
        """Build a result from a list of completion values.

        Args:
            values: Stored under ``values``.
            total: Stored under ``total`` unless it is None (zero is kept).

        Returns:
            The new result.
        """
        extras = {} if total is None else {"total": total}
        return cls(completion={"values": values, **extras})
|
262
|
+
|
263
|
+
|
264
|
+
@dataclass
class SamplingRequest:
    """A ``sampling/createMessage`` request."""

    method: str = "sampling/createMessage"
    params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def create(
        cls,
        messages: List[Dict[str, Any]],
        model_preferences: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
        include_context: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "SamplingRequest":
        """Build a sampling request.

        ``temperature`` and ``max_tokens`` are included whenever they are not
        None (so ``0`` is kept). Every other optional argument is included
        only when truthy.

        Args:
            messages: Stored under ``messages``.
            model_preferences: Stored under ``modelPreferences``.
            system_prompt: Stored under ``systemPrompt``.
            include_context: Stored under ``includeContext``.
            temperature: Stored under ``temperature``.
            max_tokens: Stored under ``maxTokens``.
            stop_sequences: Stored under ``stopSequences``.
            metadata: Stored under ``metadata``.

        Returns:
            The new request.
        """
        # (wire key, value, include?) listed in the order keys must appear.
        candidates = (
            ("modelPreferences", model_preferences, bool(model_preferences)),
            ("systemPrompt", system_prompt, bool(system_prompt)),
            ("includeContext", include_context, bool(include_context)),
            ("temperature", temperature, temperature is not None),
            ("maxTokens", max_tokens, max_tokens is not None),
            ("stopSequences", stop_sequences, bool(stop_sequences)),
            ("metadata", metadata, bool(metadata)),
        )
        payload = {"messages": messages}
        for wire_key, value, wanted in candidates:
            if wanted:
                payload[wire_key] = value
        return cls(params=payload)
|
302
|
+
|
303
|
+
|
304
|
+
@dataclass
class ResourceTemplate:
    """A resource described by a URI template plus optional descriptive fields."""

    uri_template: str
    name: Optional[str] = None
    description: Optional[str] = None
    mime_type: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary form of the template.

        ``uriTemplate`` is always present; ``name``, ``description`` and
        ``mimeType`` appear only when the matching attribute is truthy.

        Returns:
            Mapping of output key to value.
        """
        optional_fields = (
            ("name", self.name),
            ("description", self.description),
            ("mimeType", self.mime_type),
        )
        payload = {"uriTemplate": self.uri_template}
        payload.update({key: value for key, value in optional_fields if value})
        return payload
|
323
|
+
|
324
|
+
|
325
|
+
@dataclass
class ToolResult:
    """Result of a tool call expressed as a list of content items."""

    content: List[Dict[str, Any]]
    is_error: bool = False

    @classmethod
    def text(cls, text: str, is_error: bool = False) -> "ToolResult":
        """Build a result holding a single ``text`` content item."""
        item = {"type": "text", "text": text}
        return cls(content=[item], is_error=is_error)

    @classmethod
    def image(cls, data: str, mime_type: str) -> "ToolResult":
        """Build a result holding a single ``image`` content item."""
        item = {"type": "image", "data": data, "mimeType": mime_type}
        return cls(content=[item])

    @classmethod
    def resource(
        cls, uri: str, text: Optional[str] = None, mime_type: Optional[str] = None
    ) -> "ToolResult":
        """Build a result holding a single ``resource`` content item.

        Args:
            uri: Resource URI, always included.
            text: Included under ``text`` when truthy.
            mime_type: Included under ``mimeType`` when truthy.
        """
        descriptor = {"uri": uri}
        if text:
            descriptor["text"] = text
        if mime_type:
            descriptor["mimeType"] = mime_type
        return cls(content=[{"type": "resource", "resource": descriptor}])

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary form; ``isError`` appears only when ``is_error`` is truthy."""
        if not self.is_error:
            return {"content": self.content}
        return {"content": self.content, "isError": self.is_error}
|
360
|
+
|
361
|
+
|
362
|
+
class ProgressManager:
    """Tracks in-flight operations and notifies listeners as they advance."""

    def __init__(self):
        """Create a manager with nothing being tracked."""
        # Both mappings are keyed by the same tokens: one holds listener
        # lists, the other the per-operation state record.
        self._progress_callbacks: Dict[ProgressToken, List[Callable]] = {}
        self._active_progress: Dict[ProgressToken, Dict[str, Any]] = {}

    def start_progress(
        self,
        operation_name: str,
        total: Optional[float] = None,
        progress_token: Optional[ProgressToken] = None,
    ) -> ProgressToken:
        """Begin tracking an operation.

        Args:
            operation_name: Name recorded for the operation.
            total: Total progress units, if known.
            progress_token: Token to track under; a fresh one with a random
                ``progress_<8 hex chars>`` value is created when None.

        Returns:
            The token the operation is tracked under.
        """
        if progress_token is None:
            progress_token = ProgressToken(
                value=f"progress_{uuid.uuid4().hex[:8]}",
                operation_name=operation_name,
                total=total,
                progress=0,
                status="started",
            )

        state_record = {
            "operation": operation_name,
            "started_at": time.time(),
            "total": total,
            "current": 0,
            "status": "started",
        }
        self._active_progress[progress_token] = state_record
        self._progress_callbacks[progress_token] = []

        logger.debug(
            f"Started progress tracking: {operation_name} ({progress_token.value})"
        )
        return progress_token

    async def update_progress(
        self,
        progress_token: ProgressToken,
        progress: Optional[float] = None,
        status: Optional[str] = None,
        increment: Optional[float] = None,
    ) -> None:
        """Record new progress and/or status, then notify listeners.

        An unknown token is logged as a warning and otherwise ignored.

        Args:
            progress_token: Token of the tracked operation.
            progress: Absolute progress value; wins over ``increment``.
            status: New status text.
            increment: Amount added to the recorded progress; used only when
                ``progress`` is None.
        """
        try:
            state_record = self._active_progress[progress_token]
        except KeyError:
            logger.warning(f"Progress token not found: {progress_token}")
            return

        # Absolute value takes precedence; otherwise apply the delta.
        if progress is not None or increment is not None:
            if progress is not None:
                current_value = progress
            else:
                current_value = state_record.get("current", 0) + increment
            state_record["current"] = current_value
            progress_token.progress = current_value

        if status is not None:
            state_record["status"] = status
            progress_token.status = status

        state_record["updated_at"] = time.time()

        # Listeners receive the token's string value, not the token object.
        notification = ProgressNotification.create(
            progress_token=progress_token.value,
            progress=state_record["current"],
            total=state_record.get("total"),
            status=state_record["status"],
        )

        # A failing listener is logged and does not stop the remaining ones.
        for listener in self._progress_callbacks.get(progress_token, []):
            try:
                if asyncio.iscoroutinefunction(listener):
                    await listener(notification)
                else:
                    listener(notification)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")

    async def complete_progress(
        self, progress_token: ProgressToken, status: str = "completed"
    ) -> None:
        """Finish tracking: send one last update, then forget the token.

        Does nothing when the token is not being tracked.

        Args:
            progress_token: Token of the tracked operation.
            status: Final status text.
        """
        try:
            state_record = self._active_progress[progress_token]
        except KeyError:
            return

        state_record["status"] = status
        state_record["completed_at"] = time.time()
        progress_token.status = status

        # Listeners get a final notification before the bookkeeping is dropped.
        await self.update_progress(progress_token, status=status)

        del self._active_progress[progress_token]
        del self._progress_callbacks[progress_token]

        logger.debug(f"Completed progress tracking: {progress_token.value}")

    def add_progress_callback(
        self, progress_token: ProgressToken, callback: Callable
    ) -> None:
        """Register a listener for a tracked token.

        The callback is silently ignored when the token is not being tracked.

        Args:
            progress_token: Token of the tracked operation.
            callback: Callable (sync or coroutine function) invoked with each
                notification.
        """
        listeners = self._progress_callbacks.get(progress_token)
        if listeners is not None:
            listeners.append(callback)

    def get_progress_info(
        self, progress_token: ProgressToken
    ) -> Optional[Dict[str, Any]]:
        """Return the state record for a token, or None when it is not tracked."""
        return self._active_progress.get(progress_token)

    def list_active_progress(self) -> List[ProgressToken]:
        """Return the tokens currently being tracked."""
        return [token for token in self._active_progress]

    def get_active_progress(self) -> List[ProgressToken]:
        """Alias for :meth:`list_active_progress`."""
        return self.list_active_progress()
|
525
|
+
|
526
|
+
|
527
|
+
class CancellationManager:
    """Tracks cancelled request IDs and runs per-request callbacks and cleanup."""

    def __init__(self):
        """Create a manager with no cancelled requests and no registered hooks."""
        self._cancelled_requests: set[str] = set()
        # Reason given for each cancelled request (None when none was supplied).
        self._cancellation_reasons: Dict[str, Optional[str]] = {}
        self._cancellation_callbacks: Dict[str, List[Callable]] = {}
        self._request_cleanup: Dict[str, List[Callable]] = {}

    async def cancel_request(
        self, request_id: str, reason: Optional[str] = None
    ) -> None:
        """Mark a request as cancelled and run its callbacks and cleanup functions.

        Calling this again for an already-cancelled request is a no-op.
        Callbacks receive a ``CancelledNotification``; cleanup functions are
        called without arguments. Both may be plain callables or coroutine
        functions. An exception raised by one is logged and does not prevent
        the others from running.

        Args:
            request_id: Request ID to cancel.
            reason: Cancellation reason.
        """
        if request_id in self._cancelled_requests:
            return  # Already cancelled

        self._cancelled_requests.add(request_id)
        self._cancellation_reasons[request_id] = reason

        notification = CancelledNotification.create(request_id, reason)

        for callback in self._cancellation_callbacks.get(request_id, []):
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(notification)
                else:
                    callback(notification)
            except Exception as e:
                logger.error(f"Cancellation callback error: {e}")

        for cleanup in self._request_cleanup.get(request_id, []):
            try:
                if asyncio.iscoroutinefunction(cleanup):
                    await cleanup()
                else:
                    cleanup()
            except Exception as e:
                logger.error(f"Cleanup error for {request_id}: {e}")

        # Hooks are one-shot: drop them once the request has been cancelled.
        self._cancellation_callbacks.pop(request_id, None)
        self._request_cleanup.pop(request_id, None)

        logger.info(f"Cancelled request: {request_id}")

    def is_cancelled(self, request_id: str) -> bool:
        """Check if a request is cancelled.

        Args:
            request_id: Request ID to check.

        Returns:
            True if cancelled.
        """
        return request_id in self._cancelled_requests

    def add_cancellation_callback(self, request_id: str, callback: Callable) -> None:
        """Register a callback to run when ``request_id`` is cancelled.

        Args:
            request_id: Request ID.
            callback: Callable or coroutine function taking the notification.
        """
        self._cancellation_callbacks.setdefault(request_id, []).append(callback)

    def add_cleanup_function(self, request_id: str, cleanup: Callable) -> None:
        """Register a zero-argument cleanup function for ``request_id``.

        Args:
            request_id: Request ID.
            cleanup: Callable or coroutine function taking no arguments.
        """
        self._request_cleanup.setdefault(request_id, []).append(cleanup)

    def clear_cancelled_request(self, request_id: str) -> None:
        """Forget that a request was cancelled, including its stored reason.

        Args:
            request_id: Request ID to clear.
        """
        self._cancelled_requests.discard(request_id)
        self._cancellation_reasons.pop(request_id, None)

    def get_cancellation_reason(self, request_id: str) -> Optional[str]:
        """Get the cancellation reason for a request.

        Args:
            request_id: Request ID to check.

        Returns:
            The stored reason, or None when the request is not cancelled or
            was cancelled without a reason.
        """
        return self._cancellation_reasons.get(request_id)
|
639
|
+
|
640
|
+
|
641
|
+
class CompletionManager:
    """Supplies completion candidates for tools, resources and registered providers."""

    def __init__(self):
        """Create a manager with no providers and empty tool/resource lists."""
        self._available_resources = []
        self._available_tools = []
        self._completion_providers: Dict[str, Callable] = {}

    def register_completion_provider(self, ref_type: str, provider: Callable) -> None:
        """Register (or replace) the provider used for ``ref_type``.

        Args:
            ref_type: Reference type (e.g., "prompts", "resources")
            provider: Callable or coroutine function called as
                ``provider(ref_name, filter_text)``.
        """
        self._completion_providers[ref_type] = provider

    @staticmethod
    def _match_prefix(items, key, needle):
        """Return ``items`` unchanged when ``needle`` is falsy, else those whose ``key`` starts with it."""
        if not needle:
            return items
        return [item for item in items if item.get(key, "").startswith(needle)]

    async def get_completions(
        self,
        completion_type: str = None,
        ref_type: str = None,
        ref_name: Optional[str] = None,
        partial: Optional[str] = None,
        prefix: Optional[str] = None,
    ) -> List[Any]:
        """Return completion candidates.

        The lookup type is ``completion_type`` if truthy, else ``ref_type``.
        The filter text is ``prefix`` if truthy, else ``partial``.

        ``"tools"`` and ``"resources"`` are answered from the manager's own
        lists, filtered on the ``name`` and ``uri`` keys respectively. Any
        other type is delegated to the registered provider. A missing
        provider, a non-list provider result, or a provider exception (which
        is logged) all yield an empty list.

        Args:
            completion_type: Type of completion ("tools", "resources", etc)
            ref_type: Reference type (e.g., "tools", "resources", "prompts")
            ref_name: Reference name passed to the provider (optional)
            partial: Partial input to complete (optional)
            prefix: Prefix to filter completions (optional)

        Returns:
            List of completion items.
        """
        wanted_type = completion_type or ref_type
        needle = prefix or partial

        if wanted_type == "tools":
            return self._match_prefix(self._get_available_tools(), "name", needle)
        if wanted_type == "resources":
            return self._match_prefix(self._get_available_resources(), "uri", needle)

        provider = self._completion_providers.get(wanted_type)
        if not provider:
            return []

        try:
            if asyncio.iscoroutinefunction(provider):
                produced = await provider(ref_name, needle)
            else:
                produced = provider(ref_name, needle)
            # Anything other than a list is treated as "no completions".
            return produced if isinstance(produced, list) else []
        except Exception as e:
            logger.error(f"Completion provider error: {e}")
            return []

    def _get_available_tools(self) -> List[Dict[str, Any]]:
        """Return the list of tools offered for completion."""
        return self._available_tools

    def _get_available_resources(self) -> List[Dict[str, Any]]:
        """Return the list of resources offered for completion."""
        return self._available_resources
|
723
|
+
|
724
|
+
|
725
|
+
class SamplingManager:
    """Dispatches LLM sampling requests to registered callbacks and keeps a sample history."""

    def __init__(self):
        """Start with no sampling callbacks and an empty history."""
        self._sampling_callbacks: List[Callable] = []
        self._samples: List[Dict[str, Any]] = []

    def add_sampling_callback(self, callback: Callable) -> None:
        """Register a callback that may answer sampling requests.

        Args:
            callback: Sync or async callable invoked with the sampling request
        """
        self._sampling_callbacks.append(callback)

    async def request_sampling(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Dict[str, Any]:
        """Ask the registered callbacks, in registration order, to handle a sampling request.

        Args:
            messages: Messages for sampling
            **kwargs: Extra parameters forwarded to ``SamplingRequest.create``

        Returns:
            The first non-``None`` value produced by a callback

        Raises:
            MCPError: With ``METHOD_NOT_FOUND`` when no callback yields a result
        """
        request = SamplingRequest.create(messages, **kwargs)

        for handler in self._sampling_callbacks:
            try:
                outcome = (
                    await handler(request)
                    if asyncio.iscoroutinefunction(handler)
                    else handler(request)
                )
            except Exception as e:
                # A failing callback is logged and the next one gets its turn.
                logger.error(f"Sampling callback error: {e}")
                continue

            if outcome is not None:
                return outcome

        raise MCPError(
            "No sampling provider available", error_code=MCPErrorCode.METHOD_NOT_FOUND
        )

    async def create_message_sample(
        self, messages: List[Dict[str, Any]], **kwargs
    ) -> Dict[str, Any]:
        """Build a sample record, append it to the history, and return it.

        Args:
            messages: Messages for sampling
            **kwargs: Only ``model_preferences`` and ``metadata`` are copied
                into the record; any other keyword is ignored

        Returns:
            Dict with ``messages``, a generated ``sample_id``, a ``timestamp``,
            and whichever optional fields were supplied
        """
        sample: Dict[str, Any] = {
            "messages": messages,
            "sample_id": f"sample_{uuid.uuid4().hex[:8]}",
            "timestamp": time.time(),
        }

        # Optional fields are copied only when the caller actually passed them.
        for optional_key in ("model_preferences", "metadata"):
            if optional_key in kwargs:
                sample[optional_key] = kwargs[optional_key]

        self._samples.append(sample)
        return sample

    def get_sample_history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return recorded samples, oldest first.

        Args:
            limit: Maximum number of most-recent samples to return; ``None``
                returns everything, zero or a negative value returns nothing

        Returns:
            A new list of sample records
        """
        if limit is None:
            return list(self._samples)
        if limit > 0:
            return self._samples[-limit:]
        return []

    def clear_sample_history(self) -> None:
        """Drop every recorded sample (the history list is emptied in place)."""
        del self._samples[:]
|
819
|
+
|
820
|
+
|
821
|
+
class RootsManager:
    """Manages file system roots access.

    Each root is a dict with a required ``"uri"`` key and optional ``"name"``
    and ``"description"`` keys. A URI counts as being under a root only when
    it matches the root at a URI boundary, so a root of ``file:///home/user``
    does not cover ``file:///home/username/...``.
    """

    def __init__(self):
        """Initialize roots manager."""
        self._roots: List[Dict[str, Any]] = []
        self._access_validators: List[Callable] = []

    @staticmethod
    def _is_within_root(uri: str, root_uri: str) -> bool:
        """Return True when ``uri`` equals ``root_uri`` or lies beneath it.

        A bare prefix test would let ``file:///home/user`` match
        ``file:///home/username``; the character following the root prefix
        must therefore be a path, query, or fragment delimiter.

        NOTE(review): the comparison is purely textual; ``..`` segments and
        percent-encoding are not normalized here, so callers should pass
        canonical URIs.
        """
        if not uri.startswith(root_uri):
            return False
        # Exact match, a root that already ends on a separator, or the
        # degenerate empty root (which matches everything, as before).
        if not root_uri or root_uri.endswith("/") or len(uri) == len(root_uri):
            return True
        return uri[len(root_uri)] in "/?#"

    def add_root(
        self, uri: str, name: Optional[str] = None, description: Optional[str] = None
    ) -> None:
        """Add a file system root.

        Args:
            uri: Root URI
            name: Optional name for the root (omitted from the entry when falsy)
            description: Optional description for the root (omitted when falsy)
        """
        root = {"uri": uri}
        if name:
            root["name"] = name
        if description:
            root["description"] = description

        self._roots.append(root)
        logger.info(f"Added root: {uri}")

    def remove_root(self, uri: str) -> bool:
        """Remove the first root whose URI equals ``uri``.

        Args:
            uri: Root URI to remove

        Returns:
            True if a root was removed, False if none matched
        """
        for i, root in enumerate(self._roots):
            if root["uri"] == uri:
                # Safe to delete during iteration because we return immediately.
                del self._roots[i]
                logger.info(f"Removed root: {uri}")
                return True
        return False

    def list_roots(self) -> List[Dict[str, Any]]:
        """List all file system roots.

        Returns:
            Shallow copy of the list of root objects
        """
        return self._roots.copy()

    def find_root_for_uri(self, uri: str) -> Optional[Dict[str, Any]]:
        """Find the root that contains the given URI.

        Roots are checked in insertion order and the first match wins.

        Args:
            uri: URI to find root for

        Returns:
            Root object if found, None otherwise
        """
        for root in self._roots:
            if self._is_within_root(uri, root["uri"]):
                return root
        return None

    def add_access_validator(self, validator: Callable) -> None:
        """Add access validator for roots.

        Args:
            validator: Sync or async callable taking ``(uri, operation)`` and
                returning a truthy value to allow access
        """
        self._access_validators.append(validator)

    async def validate_access(self, uri: str, operation: str = "read") -> bool:
        """Validate access to a URI.

        Access is granted only if the URI lies under a registered root and
        every registered validator allows it.

        Args:
            uri: URI to validate
            operation: Operation type

        Returns:
            True if access is allowed
        """
        if not any(self._is_within_root(uri, root["uri"]) for root in self._roots):
            return False

        for validator in self._access_validators:
            try:
                if asyncio.iscoroutinefunction(validator):
                    allowed = await validator(uri, operation)
                else:
                    allowed = validator(uri, operation)

                if not allowed:
                    return False

            except Exception as e:
                # Fail closed: a broken validator denies access.
                logger.error(f"Access validator error: {e}")
                return False

        return True
|
931
|
+
|
932
|
+
|
933
|
+
class ProtocolManager:
    """Central manager for all MCP protocol features.

    Owns one instance of each feature manager (progress, cancellation,
    completion, sampling, roots) plus the negotiated capability state and a
    method-name -> handler table used by ``handle_request``.
    """

    def __init__(self):
        """Initialize protocol manager."""
        self.progress = ProgressManager()
        self.cancellation = CancellationManager()
        self.completion = CompletionManager()
        self.sampling = SamplingManager()
        self.roots = RootsManager()

        # Protocol state
        self._initialized = False
        self._client_capabilities: Dict[str, Any] = {}
        self._server_capabilities: Dict[str, Any] = {}
        self._handlers: Dict[str, Callable] = {}

    def set_initialized(self, client_capabilities: Dict[str, Any]) -> None:
        """Set protocol as initialized with client capabilities.

        Args:
            client_capabilities: Client capability advertisement
        """
        self._initialized = True
        self._client_capabilities = client_capabilities
        logger.info("MCP protocol initialized")

    def is_initialized(self) -> bool:
        """Check if protocol is initialized."""
        return self._initialized

    def get_client_capabilities(self) -> Dict[str, Any]:
        """Get a shallow copy of the client capabilities."""
        return self._client_capabilities.copy()

    def set_server_capabilities(self, capabilities: Dict[str, Any]) -> None:
        """Set server capabilities.

        Args:
            capabilities: Server capabilities
        """
        self._server_capabilities = capabilities

    def get_server_capabilities(self) -> Dict[str, Any]:
        """Get a shallow copy of the server capabilities."""
        return self._server_capabilities.copy()

    def _client_flag(self, section: str, flag: str) -> Any:
        """Look up ``flag`` inside the ``section`` group of the client capabilities.

        Returns False when the section is missing or is not a mapping (for
        example a client that sends ``"experimental": null``) instead of
        raising ``AttributeError``.
        """
        group = self._client_capabilities.get(section)
        lookup = getattr(group, "get", None)
        if not callable(lookup):
            return False
        return lookup(flag, False)

    def supports_progress(self) -> bool:
        """Check if client supports progress reporting."""
        return self._client_flag("experimental", "progressNotifications")

    def supports_cancellation(self) -> bool:
        """Check if client supports cancellation."""
        return True  # Basic support assumed

    def supports_completion(self) -> bool:
        """Check if client supports completion."""
        return self._client_flag("experimental", "completion")

    def supports_sampling(self) -> bool:
        """Check if client supports sampling."""
        return self._client_flag("experimental", "sampling")

    def supports_roots(self) -> bool:
        """Check if client supports roots."""
        return self._client_flag("roots", "listChanged")

    def _get_handler(self, method: str) -> Optional[Callable]:
        """Get handler for a method.

        Args:
            method: Method name

        Returns:
            Handler function or None
        """
        return self._handlers.get(method)

    def register_handler(self, method: str, handler: Callable) -> None:
        """Register a handler for a method, replacing any previous one.

        Args:
            method: Method name
            handler: Sync or async callable invoked with the full request dict
        """
        self._handlers[method] = handler

    async def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an incoming request.

        Args:
            request: Request message; ``method`` selects the handler and
                ``id`` is echoed back in the response

        Returns:
            JSON-RPC 2.0 response message wrapping the handler's result

        Raises:
            MCPError: ``METHOD_NOT_FOUND`` when no handler is registered for
                the method; ``INTERNAL_ERROR`` when the handler raises a
                non-MCP exception (the original is kept as ``__cause__``).
                ``MCPError`` raised by the handler propagates unchanged.
        """
        method = request.get("method")
        request_id = request.get("id")

        handler = self._get_handler(method)
        if not handler:
            raise MCPError(
                f"Method not found: {method}", error_code=MCPErrorCode.METHOD_NOT_FOUND
            )

        try:
            # The handler receives the whole request, not just its params.
            if asyncio.iscoroutinefunction(handler):
                result = await handler(request)
            else:
                result = handler(request)

            return {"jsonrpc": "2.0", "result": result, "id": request_id}

        except MCPError:
            raise
        except Exception as e:
            logger.error(f"Handler error for {method}: {e}")
            # Chain the cause so the original traceback is not lost.
            raise MCPError(str(e), error_code=MCPErrorCode.INTERNAL_ERROR) from e

    def validate_message_type(self, message: Dict[str, Any]) -> MessageType:
        """Validate and determine message type.

        Args:
            message: Message to validate

        Returns:
            ``MessageType.REQUEST`` when the message carries an ``id``,
            otherwise ``MessageType.NOTIFICATION``

        Raises:
            MCPError: If the ``jsonrpc`` field is missing or not ``"2.0"``,
                or the ``method`` field is missing
        """
        if "jsonrpc" not in message or message["jsonrpc"] != "2.0":
            raise MCPError(
                "Invalid JSON-RPC version", error_code=MCPErrorCode.INVALID_REQUEST
            )

        if "method" not in message:
            raise MCPError(
                "Missing method field", error_code=MCPErrorCode.INVALID_REQUEST
            )

        # Check if it's a request or notification
        if "id" in message:
            return MessageType.REQUEST
        return MessageType.NOTIFICATION
|
1087
|
+
|
1088
|
+
|
1089
|
+
# Process-wide ProtocolManager, created on first use by get_protocol_manager()
_protocol_manager: Optional[ProtocolManager] = None


def get_protocol_manager() -> ProtocolManager:
    """Return the shared ProtocolManager, creating it on the first call."""
    global _protocol_manager
    manager = _protocol_manager
    if manager is None:
        manager = ProtocolManager()
        _protocol_manager = manager
    return manager
|
1099
|
+
|
1100
|
+
|
1101
|
+
# Convenience functions
|
1102
|
+
def start_progress(operation_name: str, total: Optional[float] = None) -> ProgressToken:
    """Begin progress tracking via the global protocol manager.

    Args:
        operation_name: Name of the operation, forwarded unchanged
        total: Optional total, forwarded unchanged

    Returns:
        Whatever the global progress manager's ``start_progress`` returns
    """
    progress_manager = get_protocol_manager().progress
    return progress_manager.start_progress(operation_name, total)
|
1105
|
+
|
1106
|
+
|
1107
|
+
async def update_progress(
    token: ProgressToken, progress: Optional[float] = None, status: Optional[str] = None
) -> None:
    """Forward a progress update to the global protocol manager.

    Args:
        token: Progress token identifying the tracked operation
        progress: Optional progress value, forwarded unchanged
        status: Optional status text, forwarded unchanged
    """
    progress_manager = get_protocol_manager().progress
    await progress_manager.update_progress(token, progress, status)
|
1112
|
+
|
1113
|
+
|
1114
|
+
async def complete_progress(token: ProgressToken, status: str = "completed") -> None:
    """Mark a tracked operation as finished via the global protocol manager.

    Args:
        token: Progress token identifying the tracked operation
        status: Final status, forwarded unchanged
    """
    progress_manager = get_protocol_manager().progress
    await progress_manager.complete_progress(token, status)
|
1117
|
+
|
1118
|
+
|
1119
|
+
def is_cancelled(request_id: str) -> bool:
    """Ask the global protocol manager whether a request has been cancelled.

    Args:
        request_id: Identifier of the request to check

    Returns:
        The result of the global cancellation manager's ``is_cancelled``
    """
    cancellation_manager = get_protocol_manager().cancellation
    return cancellation_manager.is_cancelled(request_id)
|
1122
|
+
|
1123
|
+
|
1124
|
+
async def cancel_request(request_id: str, reason: Optional[str] = None) -> None:
    """Cancel a request through the global protocol manager.

    Args:
        request_id: Identifier of the request to cancel
        reason: Optional reason, forwarded unchanged
    """
    cancellation_manager = get_protocol_manager().cancellation
    await cancellation_manager.cancel_request(request_id, reason)
|