mcp-use 1.3.10__py3-none-any.whl → 1.3.12__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
This version of mcp-use is flagged as a potentially problematic release.
- mcp_use/adapters/langchain_adapter.py +9 -52
- mcp_use/agents/mcpagent.py +88 -37
- mcp_use/agents/prompts/templates.py +1 -10
- mcp_use/agents/remote.py +154 -128
- mcp_use/auth/__init__.py +6 -0
- mcp_use/auth/bearer.py +17 -0
- mcp_use/auth/oauth.py +625 -0
- mcp_use/auth/oauth_callback.py +214 -0
- mcp_use/client.py +25 -1
- mcp_use/config.py +7 -2
- mcp_use/connectors/base.py +25 -12
- mcp_use/connectors/http.py +135 -27
- mcp_use/connectors/sandbox.py +12 -3
- mcp_use/connectors/stdio.py +11 -3
- mcp_use/connectors/websocket.py +15 -6
- mcp_use/exceptions.py +31 -0
- mcp_use/middleware/__init__.py +50 -0
- mcp_use/middleware/logging.py +31 -0
- mcp_use/middleware/metrics.py +314 -0
- mcp_use/middleware/middleware.py +262 -0
- mcp_use/task_managers/base.py +13 -23
- mcp_use/task_managers/sse.py +5 -0
- mcp_use/task_managers/streamable_http.py +5 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/METADATA +21 -25
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/RECORD +28 -19
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/WHEEL +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/entry_points.txt +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/licenses/LICENSE +0 -0
mcp_use/middleware/metrics.py
ADDED

@@ -0,0 +1,314 @@
+"""
+Metrics middleware for MCP contexts.
+
+Classes for collecting comprehensive metrics about MCP context patterns,
+performance, and errors with simple instantiation.
+"""
+
+import asyncio
+import time
+from collections import Counter, defaultdict
+from typing import Any
+
+from .middleware import Middleware, MiddlewareContext, NextFunctionT
+
+# Constants for performance thresholds
+SLOW_THRESHOLD_MS = 1000  # 1 second
+FAST_THRESHOLD_MS = 50  # 50ms
+MAX_SLOW_CONTEXTS = 100
+MAX_FAST_CONTEXTS = 100
+MAX_ERROR_TIMESTAMPS = 1000
+MAX_RECENT_ERRORS = 50
+SECONDS_PER_HOUR = 3600
+SECONDS_PER_MINUTE = 60
+
+
+class MetricsMiddleware(Middleware):
+    """Collects basic metrics about MCP contexts including counts, durations, and errors."""
+
+    def __init__(self):
+        self.metrics = {
+            "total_contexts": 0,
+            "total_errors": 0,
+            "method_counts": defaultdict(int),
+            "method_durations": defaultdict(list),
+            "active_contexts": 0,
+            "start_time": time.time(),
+        }
+        self.lock = asyncio.Lock()
+
+    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
+        async with self.lock:
+            self.metrics["total_contexts"] += 1
+            self.metrics["active_contexts"] += 1
+            self.metrics["method_counts"][context.method] = self.metrics["method_counts"].get(context.method, 0) + 1
+
+        try:
+            result = await call_next(context)
+            duration = time.time() - context.timestamp
+
+            async with self.lock:
+                self.metrics["active_contexts"] -= 1
+                self._record_duration(context.method, duration)
+
+            return result
+        except Exception:
+            duration = time.time() - context.timestamp
+            async with self.lock:
+                self.metrics["total_errors"] += 1
+                self.metrics["active_contexts"] -= 1
+                self._record_duration(context.method, duration)
+            raise
+
+    def _record_duration(self, method: str, duration: float) -> None:
+        """Record duration for a method in a thread-safe manner."""
+        self.metrics["method_durations"][method].append(duration)
+
+    def get_metrics(self) -> dict[str, Any]:
+        """Get current metrics snapshot."""
+        uptime = time.time() - self.metrics["start_time"]
+
+        return {
+            **self.metrics,
+            "uptime_seconds": uptime,
+            "contexts_per_second": self.metrics["total_contexts"] / uptime if uptime > 0 else 0,
+            "error_rate": self.metrics["total_errors"] / self.metrics["total_contexts"]
+            if self.metrics["total_contexts"] > 0
+            else 0,
+            "method_avg_duration": {
+                method: sum(durations) / len(durations) if durations else 0
+                for method, durations in self.metrics["method_durations"].items()
+            },
+            "method_min_duration": {
+                method: min(durations) if durations else 0
+                for method, durations in self.metrics["method_durations"].items()
+            },
+            "method_max_duration": {
+                method: max(durations) if durations else 0
+                for method, durations in self.metrics["method_durations"].items()
+            },
+        }
+
+
+class PerformanceMetricsMiddleware(Middleware):
+    """Advanced performance metrics including percentiles, throughput, and performance trends."""
+
+    def __init__(self):
+        self.performance_data = {
+            "context_times": defaultdict(list),
+            "hourly_counts": defaultdict(int),
+            "connector_performance": defaultdict(list),
+            "slow_contexts": [],  # contexts over threshold
+            "fast_contexts": [],  # Fastest contexts
+            "slow_threshold_ms": SLOW_THRESHOLD_MS,
+            "fast_threshold_ms": FAST_THRESHOLD_MS,
+        }
+        self.lock = asyncio.Lock()
+
+    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
+        start_time = time.time()
+
+        try:
+            result = await call_next(context)
+            duration_ms = (time.time() - start_time) * 1000
+
+            async with self.lock:
+                # Track performance by method and connector
+                self.performance_data["context_times"][context.method].append(duration_ms)
+                self.performance_data["connector_performance"][context.connection_id].append(duration_ms)
+
+                # Track hourly patterns
+                hour = int(time.time() // SECONDS_PER_HOUR)
+                self.performance_data["hourly_counts"][hour] += 1
+
+                # Identify slow/fast contexts
+                if duration_ms > self.performance_data["slow_threshold_ms"]:
+                    self.performance_data["slow_contexts"].append(
+                        {
+                            "method": context.method,
+                            "connector": context.connection_id,
+                            "duration_ms": duration_ms,
+                            "timestamp": context.timestamp,
+                        }
+                    )
+                    # Keep only last MAX_SLOW_CONTEXTS slow contexts
+                    if len(self.performance_data["slow_contexts"]) > MAX_SLOW_CONTEXTS:
+                        self.performance_data["slow_contexts"] = self.performance_data["slow_contexts"][
+                            -MAX_SLOW_CONTEXTS:
+                        ]
+
+                if duration_ms < self.performance_data["fast_threshold_ms"]:
+                    self.performance_data["fast_contexts"].append(
+                        {
+                            "method": context.method,
+                            "connector": context.connection_id,
+                            "duration_ms": duration_ms,
+                            "timestamp": context.timestamp,
+                        }
+                    )
+                    # Keep only last MAX_FAST_CONTEXTS fast contexts
+                    if len(self.performance_data["fast_contexts"]) > MAX_FAST_CONTEXTS:
+                        self.performance_data["fast_contexts"] = self.performance_data["fast_contexts"][
+                            -MAX_FAST_CONTEXTS:
+                        ]
+
+            return result
+
+        except Exception:
+            # Still track duration even for failed contexts
+            duration_ms = (time.time() - start_time) * 1000
+            async with self.lock:
+                self.performance_data["context_times"][context.method].append(duration_ms)
+                self.performance_data["connector_performance"][context.connection_id].append(duration_ms)
+            raise
+
+    def get_performance_metrics(self) -> dict[str, Any]:
+        """Get detailed performance statistics."""
+
+        def calculate_percentiles(values):
+            if not values:
+                return {"p50": 0, "p90": 0, "p95": 0, "p99": 0}
+            sorted_values = sorted(values)
+            n = len(sorted_values)
+            return {
+                "p50": sorted_values[int(n * 0.5)],
+                "p90": sorted_values[int(n * 0.9)],
+                "p95": sorted_values[int(n * 0.95)],
+                "p99": sorted_values[int(n * 0.99)],
+            }
+
+        method_stats = {}
+        for method, times in self.performance_data["context_times"].items():
+            if times:
+                method_stats[method] = {
+                    "count": len(times),
+                    "avg_ms": sum(times) / len(times),
+                    "min_ms": min(times),
+                    "max_ms": max(times),
+                    **calculate_percentiles(times),
+                }
+
+        connector_stats = {}
+        for connector, times in self.performance_data["connector_performance"].items():
+            if times:
+                connector_stats[connector] = {
+                    "count": len(times),
+                    "avg_ms": sum(times) / len(times),
+                    "min_ms": min(times),
+                    "max_ms": max(times),
+                    **calculate_percentiles(times),
+                }
+
+        return {
+            "method_performance": method_stats,
+            "connector_performance": connector_stats,
+            "slow_contexts": self.performance_data["slow_contexts"][-10:],  # Last 10 slow contexts
+            "fast_contexts": self.performance_data["fast_contexts"][-10:],  # Last 10 fast contexts
+            "hourly_distribution": dict(self.performance_data["hourly_counts"]),
+            "thresholds": {
+                "slow_ms": self.performance_data["slow_threshold_ms"],
+                "fast_ms": self.performance_data["fast_threshold_ms"],
+            },
+        }
+
+
+class ErrorTrackingMiddleware(Middleware):
+    """Error tracking and analysis middleware for detailed error analytics."""
+
+    def __init__(self):
+        self.error_data = {
+            "error_counts": Counter(),
+            "error_by_method": defaultdict(Counter),
+            "error_by_connector": defaultdict(Counter),
+            "recent_errors": [],
+            "error_timestamps": [],
+        }
+        self.lock = asyncio.Lock()
+
+    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
+        try:
+            return await call_next(context)
+
+        except Exception as e:
+            async with self.lock:
+                error_type = type(e).__name__
+                error_msg = str(e)
+
+                # Track error patterns
+                self.error_data["error_counts"][error_type] += 1
+                self.error_data["error_by_method"][context.method][error_type] += 1
+                self.error_data["error_by_connector"][context.connection_id][error_type] += 1
+                self.error_data["error_timestamps"].append(time.time())
+
+                # Keep recent errors for analysis
+                error_info = {
+                    "timestamp": context.timestamp,
+                    "method": context.method,
+                    "connector": context.connection_id,
+                    "error_type": error_type,
+                    "error_message": error_msg,
+                    "context_id": context.id,
+                }
+                self.error_data["recent_errors"].append(error_info)
+
+                # Keep only last MAX_RECENT_ERRORS errors
+                if len(self.error_data["recent_errors"]) > MAX_RECENT_ERRORS:
+                    self.error_data["recent_errors"] = self.error_data["recent_errors"][-MAX_RECENT_ERRORS:]
+
+                # Keep only last MAX_ERROR_TIMESTAMPS timestamps
+                if len(self.error_data["error_timestamps"]) > MAX_ERROR_TIMESTAMPS:
+                    self.error_data["error_timestamps"] = self.error_data["error_timestamps"][-MAX_ERROR_TIMESTAMPS:]
+
+            raise  # Re-raise the error
+
+    def get_error_analytics(self) -> dict[str, Any]:
+        """Get detailed error analytics."""
+
+        # Calculate error rate over time windows
+        now = time.time()
+        recent_errors = [t for t in self.error_data["error_timestamps"] if now - t < SECONDS_PER_HOUR]  # Last hour
+        very_recent_errors = [
+            t for t in self.error_data["error_timestamps"] if now - t < 5 * SECONDS_PER_MINUTE
+        ]  # Last 5 min
+
+        return {
+            "total_errors": sum(self.error_data["error_counts"].values()),
+            "error_types": dict(self.error_data["error_counts"]),
+            "errors_by_method": {method: dict(errors) for method, errors in self.error_data["error_by_method"].items()},
+            "errors_by_connector": {
+                connector: dict(errors) for connector, errors in self.error_data["error_by_connector"].items()
+            },
+            "recent_errors": self.error_data["recent_errors"][-10:],  # Last 10 errors
+            "error_rate_last_hour": len(recent_errors),
+            "error_rate_last_5min": len(very_recent_errors),
+            "most_common_errors": self.error_data["error_counts"].most_common(5),
+        }
+
+
+class CombinedAnalyticsMiddleware(Middleware):
+    """Comprehensive middleware combining metrics, performance, and error tracking."""
+
+    def __init__(self):
+        self.metrics_mw = MetricsMiddleware()
+        self.perf_mw = PerformanceMetricsMiddleware()
+        self.error_mw = ErrorTrackingMiddleware()
+
+    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
+        # Chain the middleware in the desired order: Metrics -> Errors -> Performance
+        async def chain(ctx):
+            # The final call in the chain is the original `call_next`
+            return await self.perf_mw.on_request(ctx, call_next)

+        async def error_chain(ctx):
+            return await self.error_mw.on_request(ctx, chain)
+
+        return await self.metrics_mw.on_request(context, error_chain)
+
+    def get_combined_analytics(self) -> dict[str, Any]:
+        """Get all analytics data in one comprehensive report."""
+        return {
+            "metrics": self.metrics_mw.get_metrics(),
+            "performance": self.perf_mw.get_performance_metrics(),
+            "errors": self.error_mw.get_error_analytics(),
+            "generated_at": time.time(),
+        }
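The analytics middlewares above are plain objects whose `on_request` hook takes a context and a `call_next` coroutine, so they can be exercised on their own. Below is a minimal sketch (not part of the diff) that builds one `MiddlewareContext` by hand (the dataclass added in `middleware.py`, shown next), runs it through a `MetricsMiddleware` with a stand-in `call_next`, and reads the snapshot back; the connector id is illustrative.

    import asyncio
    import time
    import uuid

    from mcp_use.middleware.metrics import MetricsMiddleware
    from mcp_use.middleware.middleware import MiddlewareContext


    async def main() -> None:
        metrics = MetricsMiddleware()

        async def fake_call(ctx: MiddlewareContext) -> str:
            # Stand-in for the real call_next that would reach the MCP server.
            await asyncio.sleep(0.01)
            return "ok"

        ctx = MiddlewareContext(
            id=str(uuid.uuid4()),
            method="tools/list",
            params=None,
            connection_id="demo-connector",  # illustrative connector id
            timestamp=time.time(),
        )
        await metrics.on_request(ctx, fake_call)

        snapshot = metrics.get_metrics()
        print(snapshot["total_contexts"], dict(snapshot["method_counts"]))


    asyncio.run(main())

The same pattern applies to PerformanceMetricsMiddleware, ErrorTrackingMiddleware, and CombinedAnalyticsMiddleware, whose get_* methods return the report dictionaries built above.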
mcp_use/middleware/middleware.py
ADDED

@@ -0,0 +1,262 @@
+"""
+Core middleware system for MCP requests.
+
+This module provides a robust and extensible middleware architecture:
+- A typed MiddlewareContext to carry request data.
+- A Middleware base class with a dispatcher that routes to strongly-typed hooks.
+- A MiddlewareManager to build and execute the processing chain.
+- A CallbackClientSession that acts as an adapter, creating the initial context
+  without requiring changes to upstream callers like HttpConnector.
+"""
+
+import time
+import uuid
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any, Generic, Protocol, TypeVar
+
+from mcp import ClientSession
+from mcp.types import (
+    CallToolRequestParams,
+    CallToolResult,
+    GetPromptRequestParams,
+    GetPromptResult,
+    InitializeRequestParams,
+    InitializeResult,
+    JSONRPCResponse,
+    ListPromptsRequest,
+    ListPromptsResult,
+    ListResourcesRequest,
+    ListResourcesResult,
+    ListToolsRequest,
+    ListToolsResult,
+    ReadResourceRequestParams,
+    ReadResourceResult,
+)
+
+# Generic TypeVars for context and results
+T = TypeVar("T")
+R = TypeVar("R", covariant=True)
+
+
+@dataclass
+class MiddlewareContext(Generic[T]):
+    """Unified, typed context for all middleware operations."""
+
+    id: str
+    method: str  # The JSON-RPC method name, e.g., "tools/call"
+    params: T  # The typed parameters for the method
+    connection_id: str
+    timestamp: float
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MCPResponseContext:
+    """Extended context for MCP responses with middleware metadata."""
+
+    request_id: str
+    result: Any
+    error: Exception | None
+    duration: float
+    metadata: dict[str, Any]
+    jsonrpc_response: JSONRPCResponse | None = None
+
+    @classmethod
+    def create(cls, request_id: str, result: Any = None, error: Exception = None) -> "MCPResponseContext":
+        return cls(request_id=request_id, result=result, error=error, duration=0.0, metadata={})
+
+
+# Protocol definition for middleware
+class NextFunctionT(Protocol[T, R]):
+    """Protocol for the `call_next` function passed to middleware."""
+
+    async def __call__(self, context: MiddlewareContext[T]) -> R: ...
+
+
+class Middleware:
+    """Base class for middlewares with hooks."""
+
+    async def __call__(self, context: MiddlewareContext[T], call_next: NextFunctionT[T, Any]) -> Any:
+        """Main entry point that orchestrates the chain"""
+        handler_chain = await self._dispatch_handler(context, call_next)
+        return await handler_chain(context)
+
+    async def _dispatch_handler(
+        self, context: MiddlewareContext[Any], call_next: NextFunctionT[Any, Any]
+    ) -> NextFunctionT[Any, Any]:
+        """Build a chain of handlers"""
+        handler = call_next
+
+        method_map = {
+            "initialize": self.on_initialize,
+            "tools/call": self.on_call_tool,
+            "tools/list": self.on_list_tools,
+            "resources/list": self.on_list_resources,
+            "resources/read": self.on_read_resource,
+            "prompts/list": self.on_list_prompts,
+            "prompts/get": self.on_get_prompt,
+        }
+
+        if hook := method_map.get(context.method):
+            handler = partial(hook, call_next=handler)
+
+        # We can assume that all intercepted calls are requests
+        handler = partial(self.on_request, call_next=handler)
+
+        return handler
+
+    # Default implementations for all hooks
+    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
+        return await call_next(context)
+
+    async def on_initialize(
+        self, context: MiddlewareContext[InitializeRequestParams], call_next: NextFunctionT
+    ) -> InitializeResult:
+        return await call_next(context)
+
+    async def on_call_tool(
+        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
+    ) -> CallToolResult:
+        return await call_next(context)
+
+    async def on_read_resource(
+        self, context: MiddlewareContext[ReadResourceRequestParams], call_next: NextFunctionT
+    ) -> ReadResourceResult:
+        return await call_next(context)
+
+    async def on_get_prompt(
+        self, context: MiddlewareContext[GetPromptRequestParams], call_next: NextFunctionT
+    ) -> GetPromptResult:
+        return await call_next(context)
+
+    async def on_list_tools(
+        self, context: MiddlewareContext[ListToolsRequest], call_next: NextFunctionT
+    ) -> ListToolsResult:
+        return await call_next(context)
+
+    async def on_list_resources(
+        self, context: MiddlewareContext[ListResourcesRequest], call_next: NextFunctionT
+    ) -> ListResourcesResult:
+        return await call_next(context)
+
+    async def on_list_prompts(
+        self, context: MiddlewareContext[ListPromptsRequest], call_next: NextFunctionT
+    ) -> ListPromptsResult:
+        return await call_next(context)
+
+
+class MiddlewareManager:
+    """Manages middleware callbacks for MCP requests."""
+
+    def __init__(self):
+        self.middlewares: list[Middleware] = []
+
+    def add_middleware(self, callback: Middleware) -> None:
+        """Add a middleware callback."""
+        self.middlewares.append(callback)
+
+    async def process_request(self, context: MiddlewareContext, original_call: Callable) -> MCPResponseContext:
+        """
+        Runs the full middleware chain, captures timing and errors,
+        and returns a structured MCPResponseContext.
+        """
+
+        try:
+            # Chain middleware callbacks
+            async def execute_call(_: MiddlewareContext) -> Any:
+                return await original_call()
+
+            call_chain = execute_call
+            for middleware in reversed(self.middlewares):
+                call_chain = partial(middleware, call_next=call_chain)
+
+            # Execute the chain
+            start_time = time.time()
+
+            # The result of the chain is the raw result (e.g., CallToolResult)
+            raw_result = await call_chain(context)
+
+            # Success, now wrap the result in response context
+            duration = time.time() - start_time
+
+            response = MCPResponseContext.create(request_id=context.id, result=raw_result)
+            response.duration = duration
+            return response
+
+        except Exception as error:
+            duration = time.time() - context.timestamp
+            response = MCPResponseContext.create(request_id=context.id, error=error)
+            response.duration = duration
+            return response
+
+
+class CallbackClientSession:
+    """ClientSession wrapper that uses callback-based middleware."""
+
+    def __init__(self, client_session: ClientSession, connector_id: str, middleware_manager: MiddlewareManager):
+        self._client_session = client_session
+        self.connector_id = connector_id
+        self.middleware_manager = middleware_manager
+
+    async def _intercept_call(self, method_name: str, params: Any, original_call: Callable) -> Any:
+        """
+        Creates the context, runs it through the manager, and unwraps the final response.
+        """
+        context = MiddlewareContext(
+            id=str(uuid.uuid4()),
+            method=method_name,
+            params=params,
+            connection_id=self.connector_id,
+            timestamp=time.time(),
+        )
+
+        # This now returns a rich MCPResponseContext
+        response_context = await self.middleware_manager.process_request(context, original_call)
+
+        # If there is an error, raise it
+        if response_context.error:
+            raise response_context.error
+
+        return response_context.result
+
+    # Wrap all MCP methods with specific params
+    async def initialize(self, *args, **kwargs) -> InitializeResult:
+        return await self._intercept_call("initialize", None, lambda: self._client_session.initialize(*args, **kwargs))
+
+    # List requests usually don't have parameters
+    async def list_tools(self, *args, **kwargs) -> ListToolsResult:
+        return await self._intercept_call("tools/list", None, lambda: self._client_session.list_tools(*args, **kwargs))
+
+    async def call_tool(self, name: str, arguments: dict[str, Any], *args, **kwargs) -> CallToolResult:
+        params = CallToolRequestParams(name=name, arguments=arguments)
+        return await self._intercept_call(
+            "tools/call", params, lambda: self._client_session.call_tool(name, arguments, *args, **kwargs)
+        )
+
+    async def list_resources(self, *args, **kwargs) -> ListResourcesResult:
+        return await self._intercept_call(
+            "resources/list", None, lambda: self._client_session.list_resources(*args, **kwargs)
+        )
+
+    async def read_resource(self, uri: str, *args, **kwargs) -> ReadResourceResult:
+        params = ReadResourceRequestParams(uri=uri)
+        return await self._intercept_call(
+            "resources/read", params, lambda: self._client_session.read_resource(uri, *args, **kwargs)
+        )
+
+    async def list_prompts(self, *args, **kwargs) -> ListPromptsResult:
+        return await self._intercept_call(
+            "prompts/list", None, lambda: self._client_session.list_prompts(*args, **kwargs)
+        )
+
+    async def get_prompt(self, name: str, *args, **kwargs) -> GetPromptResult:
+        params = GetPromptRequestParams(name=name, **kwargs)
+        return await self._intercept_call(
+            "prompts/get", params, lambda: self._client_session.get_prompt(name, *args, **kwargs)
+        )
+
+    def __getattr__(self, name: str) -> Any:
+        """Delegate other attributes to the wrapped session."""
+        return getattr(self._client_session, name)
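To sketch how these pieces compose (again, not part of the diff): a custom hook-based middleware is registered on a `MiddlewareManager`, and an existing `mcp.ClientSession` can be wrapped in `CallbackClientSession` so that every request flows through the chain. The `session` object and the connector id below are placeholders.

    from mcp.types import CallToolRequestParams, CallToolResult

    from mcp_use.middleware.middleware import (
        CallbackClientSession,
        Middleware,
        MiddlewareContext,
        MiddlewareManager,
        NextFunctionT,
    )


    class ToolLoggingMiddleware(Middleware):
        """Prints every tools/call before delegating to the rest of the chain."""

        async def on_call_tool(
            self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
        ) -> CallToolResult:
            print(f"tool={context.params.name} connector={context.connection_id}")
            return await call_next(context)


    manager = MiddlewareManager()
    manager.add_middleware(ToolLoggingMiddleware())

    # `session` is a placeholder for an already-initialized mcp.ClientSession:
    # wrapped = CallbackClientSession(session, connector_id="my-server", middleware_manager=manager)
    # result = await wrapped.call_tool("search", {"query": "mcp"})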
mcp_use/task_managers/base.py
CHANGED
@@ -22,13 +22,14 @@ class ConnectionManager(Generic[T], ABC):
     used with MCP connectors.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         """Initialize a new connection manager."""
         self._ready_event = asyncio.Event()
         self._done_event = asyncio.Event()
+        self._stop_event = asyncio.Event()
         self._exception: Exception | None = None
         self._connection: T | None = None
-        self._task: asyncio.Task | None = None
+        self._task: asyncio.Task[None] | None = None
 
     @abstractmethod
     async def _establish_connection(self) -> T:

@@ -86,20 +87,15 @@ class ConnectionManager(Generic[T], ABC):
 
     async def stop(self) -> None:
         """Stop the connection manager and close the connection."""
+        # Signal stop to the connection task instead of cancelling it, avoids
+        # propagating CancelledError to unrelated tasks.
         if self._task and not self._task.done():
-            [... 6 removed lines not captured in this diff view ...]
-                await self._task
-            except asyncio.CancelledError:
-                logger.debug(f"{self.__class__.__name__} task cancelled successfully")
-            except Exception as e:
-                logger.warning(f"Error stopping {self.__class__.__name__} task: {e}")
-
-        # Wait for the connection to be done
+            logger.debug(f"Signaling stop to {self.__class__.__name__} task")
+            self._stop_event.set()
+            # Wait for it to finish gracefully
+            await self._task
+
+        # Ensure cleanup completed
         await self._done_event.wait()
         logger.debug(f"{self.__class__.__name__} task completed")
 

@@ -125,14 +121,8 @@ class ConnectionManager(Generic[T], ABC):
             # Signal that the connection is ready
             self._ready_event.set()
 
-            # Wait
-
-            # This keeps the connection open until cancelled
-            await asyncio.Event().wait()
-        except asyncio.CancelledError:
-            # Expected when stopping
-            logger.debug(f"{self.__class__.__name__} task received cancellation")
-            pass
+            # Wait until stop is requested
+            await self._stop_event.wait()
 
         except Exception as e:
             # Store the exception
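The change above swaps task cancellation for a cooperative stop signal: the connection task now parks on `self._stop_event.wait()`, and `stop()` sets the event, awaits the task, then waits for `_done_event`. A standalone sketch of the same pattern (illustrative only, not the library's code):

    import asyncio


    class Worker:
        """Toy connection-manager-like object using the stop-event pattern."""

        def __init__(self) -> None:
            self._ready_event = asyncio.Event()
            self._done_event = asyncio.Event()
            self._stop_event = asyncio.Event()
            self._task: asyncio.Task[None] | None = None

        async def start(self) -> None:
            self._task = asyncio.create_task(self._run())
            await self._ready_event.wait()

        async def _run(self) -> None:
            try:
                # ... establish the "connection" here ...
                self._ready_event.set()
                await self._stop_event.wait()  # keep it open until asked to stop
            finally:
                # ... close the "connection" here ...
                self._done_event.set()

        async def stop(self) -> None:
            if self._task and not self._task.done():
                self._stop_event.set()  # signal instead of cancel()
                await self._task
            await self._done_event.wait()


    async def main() -> None:
        w = Worker()
        await w.start()
        await w.stop()


    asyncio.run(main())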
mcp_use/task_managers/sse.py
CHANGED
@@ -7,6 +7,7 @@ that ensures proper task isolation and resource cleanup.
 
 from typing import Any
 
+import httpx
 from mcp.client.sse import sse_client
 
 from ..logging import logger

@@ -27,6 +28,7 @@ class SseConnectionManager(ConnectionManager[tuple[Any, Any]]):
         headers: dict[str, str] | None = None,
         timeout: float = 5,
         sse_read_timeout: float = 60 * 5,
+        auth: httpx.Auth | None = None,
     ):
         """Initialize a new SSE connection manager.
 

@@ -35,12 +37,14 @@ class SseConnectionManager(ConnectionManager[tuple[Any, Any]]):
             headers: Optional HTTP headers
             timeout: Timeout for HTTP operations in seconds
             sse_read_timeout: Timeout for SSE read operations in seconds
+            auth: Optional httpx.Auth instance for authentication
         """
         super().__init__()
         self.url = url
         self.headers = headers or {}
         self.timeout = timeout
         self.sse_read_timeout = sse_read_timeout
+        self.auth = auth
         self._sse_ctx = None
 
     async def _establish_connection(self) -> tuple[Any, Any]:

@@ -58,6 +62,7 @@ class SseConnectionManager(ConnectionManager[tuple[Any, Any]]):
             headers=self.headers,
             timeout=self.timeout,
             sse_read_timeout=self.sse_read_timeout,
+            auth=self.auth,
         )
 
         # Enter the context manager
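A sketch of the new `auth` parameter in use (not part of the diff). This release also adds `mcp_use/auth/bearer.py`, but its API is not shown in this diff, so the example passes one of httpx's built-in `Auth` implementations; the URL and credentials are placeholders.

    import httpx

    from mcp_use.task_managers.sse import SseConnectionManager

    # Any httpx.Auth instance is forwarded to sse_client via the new `auth` argument.
    manager = SseConnectionManager(
        url="https://example.com/sse",
        headers={"X-Client": "demo"},
        auth=httpx.BasicAuth("user", "secret"),  # or a custom httpx.Auth subclass
    )
    # manager.stop() (diffed above) later shuts the connection down by signaling
    # the stop event instead of cancelling the task.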