mcp-use 1.3.11__py3-none-any.whl → 1.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-use might be problematic. Click here for more details.
- mcp_use/__init__.py +1 -1
- mcp_use/adapters/.deprecated +0 -0
- mcp_use/adapters/__init__.py +18 -7
- mcp_use/adapters/base.py +12 -185
- mcp_use/adapters/langchain_adapter.py +12 -264
- mcp_use/agents/adapters/__init__.py +10 -0
- mcp_use/agents/adapters/base.py +193 -0
- mcp_use/agents/adapters/langchain_adapter.py +228 -0
- mcp_use/agents/base.py +1 -1
- mcp_use/agents/managers/__init__.py +19 -0
- mcp_use/agents/managers/base.py +36 -0
- mcp_use/agents/managers/server_manager.py +131 -0
- mcp_use/agents/managers/tools/__init__.py +15 -0
- mcp_use/agents/managers/tools/base_tool.py +19 -0
- mcp_use/agents/managers/tools/connect_server.py +69 -0
- mcp_use/agents/managers/tools/disconnect_server.py +43 -0
- mcp_use/agents/managers/tools/get_active_server.py +29 -0
- mcp_use/agents/managers/tools/list_servers_tool.py +53 -0
- mcp_use/agents/managers/tools/search_tools.py +328 -0
- mcp_use/agents/mcpagent.py +88 -47
- mcp_use/agents/remote.py +168 -129
- mcp_use/auth/.deprecated +0 -0
- mcp_use/auth/__init__.py +19 -4
- mcp_use/auth/bearer.py +11 -12
- mcp_use/auth/oauth.py +11 -620
- mcp_use/auth/oauth_callback.py +16 -207
- mcp_use/client/__init__.py +1 -0
- mcp_use/client/auth/__init__.py +6 -0
- mcp_use/client/auth/bearer.py +23 -0
- mcp_use/client/auth/oauth.py +629 -0
- mcp_use/client/auth/oauth_callback.py +214 -0
- mcp_use/client/client.py +356 -0
- mcp_use/client/config.py +106 -0
- mcp_use/client/connectors/__init__.py +20 -0
- mcp_use/client/connectors/base.py +470 -0
- mcp_use/client/connectors/http.py +304 -0
- mcp_use/client/connectors/sandbox.py +332 -0
- mcp_use/client/connectors/stdio.py +109 -0
- mcp_use/client/connectors/utils.py +13 -0
- mcp_use/client/connectors/websocket.py +257 -0
- mcp_use/client/exceptions.py +31 -0
- mcp_use/client/middleware/__init__.py +50 -0
- mcp_use/client/middleware/logging.py +31 -0
- mcp_use/client/middleware/metrics.py +314 -0
- mcp_use/client/middleware/middleware.py +266 -0
- mcp_use/client/session.py +162 -0
- mcp_use/client/task_managers/__init__.py +20 -0
- mcp_use/client/task_managers/base.py +145 -0
- mcp_use/client/task_managers/sse.py +84 -0
- mcp_use/client/task_managers/stdio.py +69 -0
- mcp_use/client/task_managers/streamable_http.py +86 -0
- mcp_use/client/task_managers/websocket.py +68 -0
- mcp_use/client.py +12 -320
- mcp_use/config.py +20 -92
- mcp_use/connectors/.deprecated +0 -0
- mcp_use/connectors/__init__.py +46 -20
- mcp_use/connectors/base.py +12 -447
- mcp_use/connectors/http.py +13 -288
- mcp_use/connectors/sandbox.py +13 -297
- mcp_use/connectors/stdio.py +13 -96
- mcp_use/connectors/utils.py +15 -8
- mcp_use/connectors/websocket.py +13 -252
- mcp_use/exceptions.py +33 -18
- mcp_use/managers/.deprecated +0 -0
- mcp_use/managers/__init__.py +56 -17
- mcp_use/managers/base.py +13 -31
- mcp_use/managers/server_manager.py +13 -119
- mcp_use/managers/tools/__init__.py +45 -15
- mcp_use/managers/tools/base_tool.py +5 -16
- mcp_use/managers/tools/connect_server.py +5 -67
- mcp_use/managers/tools/disconnect_server.py +5 -41
- mcp_use/managers/tools/get_active_server.py +5 -26
- mcp_use/managers/tools/list_servers_tool.py +5 -51
- mcp_use/managers/tools/search_tools.py +17 -321
- mcp_use/middleware/.deprecated +0 -0
- mcp_use/middleware/__init__.py +89 -0
- mcp_use/middleware/logging.py +19 -0
- mcp_use/middleware/metrics.py +41 -0
- mcp_use/middleware/middleware.py +55 -0
- mcp_use/session.py +13 -149
- mcp_use/task_managers/.deprecated +0 -0
- mcp_use/task_managers/__init__.py +48 -20
- mcp_use/task_managers/base.py +13 -140
- mcp_use/task_managers/sse.py +13 -79
- mcp_use/task_managers/stdio.py +13 -64
- mcp_use/task_managers/streamable_http.py +15 -81
- mcp_use/task_managers/websocket.py +13 -63
- mcp_use/telemetry/events.py +58 -0
- mcp_use/telemetry/telemetry.py +71 -1
- mcp_use/types/.deprecated +0 -0
- mcp_use/types/sandbox.py +13 -18
- {mcp_use-1.3.11.dist-info → mcp_use-1.3.13.dist-info}/METADATA +66 -40
- mcp_use-1.3.13.dist-info/RECORD +109 -0
- mcp_use-1.3.11.dist-info/RECORD +0 -60
- mcp_use-1.3.11.dist-info/licenses/LICENSE +0 -21
- /mcp_use/{observability → agents/observability}/__init__.py +0 -0
- /mcp_use/{observability → agents/observability}/callbacks_manager.py +0 -0
- /mcp_use/{observability → agents/observability}/laminar.py +0 -0
- /mcp_use/{observability → agents/observability}/langfuse.py +0 -0
- {mcp_use-1.3.11.dist-info → mcp_use-1.3.13.dist-info}/WHEEL +0 -0
- {mcp_use-1.3.11.dist-info → mcp_use-1.3.13.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics middleware for MCP contexts.
|
|
3
|
+
|
|
4
|
+
Classes for collecting comprehensive metrics about MCP context patterns,
|
|
5
|
+
performance, and errors with simple instantiation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import time
|
|
10
|
+
from collections import Counter, defaultdict
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from mcp_use.client.middleware.middleware import Middleware, MiddlewareContext, NextFunctionT
|
|
14
|
+
|
|
15
|
+
# Constants for performance thresholds and bounded history sizes
SLOW_THRESHOLD_MS = 1000  # 1 second; successful contexts slower than this are kept as "slow" samples
FAST_THRESHOLD_MS = 50  # 50ms; successful contexts faster than this are kept as "fast" samples
MAX_SLOW_CONTEXTS = 100  # cap on retained slow-context samples (oldest dropped first)
MAX_FAST_CONTEXTS = 100  # cap on retained fast-context samples (oldest dropped first)
MAX_ERROR_TIMESTAMPS = 1000  # cap on retained error timestamps; also caps the windowed error counts
MAX_RECENT_ERRORS = 50  # cap on retained detailed error records
SECONDS_PER_HOUR = 3600
SECONDS_PER_MINUTE = 60
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MetricsMiddleware(Middleware):
    """Collects basic metrics about MCP contexts including counts, durations, and errors.

    Durations are measured in seconds from ``context.timestamp`` (the time the context
    was created) to the moment the downstream chain finishes.
    """

    def __init__(self):
        self.metrics = {
            "total_contexts": 0,
            "total_errors": 0,
            "method_counts": defaultdict(int),
            "method_durations": defaultdict(list),
            "active_contexts": 0,
            "start_time": time.time(),
        }
        self.lock = asyncio.Lock()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Count the request, run the rest of the chain, and record its duration.

        Raises:
            Whatever ``call_next`` raises; ``Exception`` subclasses are also counted
            in ``total_errors`` before being re-raised.
        """
        async with self.lock:
            self.metrics["total_contexts"] += 1
            self.metrics["active_contexts"] += 1
            self.metrics["method_counts"][context.method] += 1

        failed = False
        try:
            return await call_next(context)
        except Exception:
            failed = True
            raise
        finally:
            # ``finally`` (not just ``except Exception``) so that cancellation, which raises
            # asyncio.CancelledError (a BaseException), still releases the active slot.
            duration = time.time() - context.timestamp
            async with self.lock:
                if failed:
                    self.metrics["total_errors"] += 1
                self.metrics["active_contexts"] -= 1
                self._record_duration(context.method, duration)

    def _record_duration(self, method: str, duration: float) -> None:
        """Append ``duration`` (seconds) to ``method``'s history. Caller must hold ``self.lock``."""
        self.metrics["method_durations"][method].append(duration)

    def get_metrics(self) -> dict[str, Any]:
        """Get current metrics snapshot.

        Returns:
            Every raw entry of ``self.metrics`` plus derived values: ``uptime_seconds``,
            ``contexts_per_second``, ``error_rate`` and per-method average/min/max
            durations (seconds). Mutable containers are copied (keeping their
            ``defaultdict`` type) so later requests do not alter the returned snapshot.
        """
        uptime = time.time() - self.metrics["start_time"]
        total = self.metrics["total_contexts"]
        durations_by_method = self.metrics["method_durations"]

        return {
            **self.metrics,
            # Overwrite the live containers with copies; dict keys keep their original position.
            "method_counts": defaultdict(int, self.metrics["method_counts"]),
            "method_durations": defaultdict(
                list, {method: list(durations) for method, durations in durations_by_method.items()}
            ),
            "uptime_seconds": uptime,
            "contexts_per_second": total / uptime if uptime > 0 else 0,
            "error_rate": self.metrics["total_errors"] / total if total > 0 else 0,
            "method_avg_duration": {
                method: sum(durations) / len(durations) if durations else 0
                for method, durations in durations_by_method.items()
            },
            "method_min_duration": {
                method: min(durations) if durations else 0 for method, durations in durations_by_method.items()
            },
            "method_max_duration": {
                method: max(durations) if durations else 0 for method, durations in durations_by_method.items()
            },
        }
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class PerformanceMetricsMiddleware(Middleware):
    """Advanced performance metrics including percentiles, throughput, and performance trends.

    Every context's latency (milliseconds) is recorded per method and per connector.
    Successful contexts are additionally counted per hour-since-epoch and, when they
    cross the configured thresholds, kept in bounded "slow" / "fast" sample lists.
    """

    def __init__(self):
        self.performance_data = {
            "context_times": defaultdict(list),
            "hourly_counts": defaultdict(int),
            "connector_performance": defaultdict(list),
            "slow_contexts": [],  # contexts over threshold
            "fast_contexts": [],  # Fastest contexts
            "slow_threshold_ms": SLOW_THRESHOLD_MS,
            "fast_threshold_ms": FAST_THRESHOLD_MS,
        }
        self.lock = asyncio.Lock()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Time the downstream chain and record the latency; exceptions are re-raised."""
        # perf_counter is monotonic, so durations cannot go negative on wall-clock adjustments.
        start = time.perf_counter()
        try:
            result = await call_next(context)
        except Exception:
            # Still track duration even for failed contexts
            duration_ms = (time.perf_counter() - start) * 1000
            async with self.lock:
                self._record_latency(context, duration_ms)
            raise

        duration_ms = (time.perf_counter() - start) * 1000
        async with self.lock:
            self._record_latency(context, duration_ms)

            # Track hourly patterns (bucket key = whole hours since the epoch)
            hour = int(time.time() // SECONDS_PER_HOUR)
            self.performance_data["hourly_counts"][hour] += 1

            # Identify slow/fast contexts
            if duration_ms > self.performance_data["slow_threshold_ms"]:
                self._append_sample("slow_contexts", context, duration_ms, MAX_SLOW_CONTEXTS)
            if duration_ms < self.performance_data["fast_threshold_ms"]:
                self._append_sample("fast_contexts", context, duration_ms, MAX_FAST_CONTEXTS)

        return result

    def _record_latency(self, context: MiddlewareContext[Any], duration_ms: float) -> None:
        """Append a latency sample to the per-method and per-connector series (caller holds the lock)."""
        self.performance_data["context_times"][context.method].append(duration_ms)
        self.performance_data["connector_performance"][context.connection_id].append(duration_ms)

    def _append_sample(self, key: str, context: MiddlewareContext[Any], duration_ms: float, limit: int) -> None:
        """Record a slow/fast sample under ``key``, keeping only the newest ``limit`` entries."""
        samples = self.performance_data[key]
        samples.append(
            {
                "method": context.method,
                "connector": context.connection_id,
                "duration_ms": duration_ms,
                "timestamp": context.timestamp,
            }
        )
        if len(samples) > limit:
            del samples[: len(samples) - limit]

    def get_performance_metrics(self) -> dict[str, Any]:
        """Get detailed performance statistics.

        Returns:
            Per-method and per-connector count/avg/min/max and p50/p90/p95/p99 latencies
            (milliseconds), the 10 most recent slow and fast samples, the hourly
            distribution of successful contexts, and the configured thresholds.
        """

        def calculate_percentiles(values):
            if not values:
                return {"p50": 0, "p90": 0, "p95": 0, "p99": 0}
            ordered = sorted(values)
            n = len(ordered)

            def nearest_rank(pct: int) -> float:
                # Nearest-rank definition: the ceil(pct/100 * n)-th smallest value.
                # Integer ceiling division avoids float rounding at exact boundaries.
                rank = -(-n * pct // 100)
                return ordered[max(rank, 1) - 1]

            return {
                "p50": nearest_rank(50),
                "p90": nearest_rank(90),
                "p95": nearest_rank(95),
                "p99": nearest_rank(99),
            }

        def summarize(series):
            return {
                key: {
                    "count": len(times),
                    "avg_ms": sum(times) / len(times),
                    "min_ms": min(times),
                    "max_ms": max(times),
                    **calculate_percentiles(times),
                }
                for key, times in series.items()
                if times
            }

        return {
            "method_performance": summarize(self.performance_data["context_times"]),
            "connector_performance": summarize(self.performance_data["connector_performance"]),
            "slow_contexts": self.performance_data["slow_contexts"][-10:],  # Last 10 slow contexts
            "fast_contexts": self.performance_data["fast_contexts"][-10:],  # Last 10 fast contexts
            "hourly_distribution": dict(self.performance_data["hourly_counts"]),
            "thresholds": {
                "slow_ms": self.performance_data["slow_threshold_ms"],
                "fast_ms": self.performance_data["fast_threshold_ms"],
            },
        }
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ErrorTrackingMiddleware(Middleware):
    """Error tracking and analysis middleware for detailed error analytics.

    Successful requests pass straight through. Any ``Exception`` raised further down the
    chain is tallied by type, by method and by connector, stored in bounded history
    lists, and then re-raised unchanged.
    """

    def __init__(self):
        self.error_data = {
            "error_counts": Counter(),
            "error_by_method": defaultdict(Counter),
            "error_by_connector": defaultdict(Counter),
            "recent_errors": [],
            "error_timestamps": [],
        }
        self.lock = asyncio.Lock()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        try:
            return await call_next(context)
        except Exception as exc:
            async with self.lock:
                self._track_failure(context, exc)
            raise  # propagate the original exception to the caller

    def _track_failure(self, context: MiddlewareContext[Any], exc: Exception) -> None:
        """Record one failure in every tally and bounded history list (caller holds the lock)."""
        data = self.error_data
        kind = type(exc).__name__
        message = str(exc)

        data["error_counts"][kind] += 1
        data["error_by_method"][context.method][kind] += 1
        data["error_by_connector"][context.connection_id][kind] += 1
        data["error_timestamps"].append(time.time())

        data["recent_errors"].append(
            {
                "timestamp": context.timestamp,
                "method": context.method,
                "connector": context.connection_id,
                "error_type": kind,
                "error_message": message,
                "context_id": context.id,
            }
        )

        # Trim both histories to their caps, keeping the newest entries.
        if len(data["recent_errors"]) > MAX_RECENT_ERRORS:
            data["recent_errors"] = data["recent_errors"][-MAX_RECENT_ERRORS:]
        if len(data["error_timestamps"]) > MAX_ERROR_TIMESTAMPS:
            data["error_timestamps"] = data["error_timestamps"][-MAX_ERROR_TIMESTAMPS:]

    def get_error_analytics(self) -> dict[str, Any]:
        """Get detailed error analytics.

        ``error_rate_last_hour`` / ``error_rate_last_5min`` are raw error counts within
        those windows, computed from the retained (capped) timestamp history.
        """
        now = time.time()
        stamps = self.error_data["error_timestamps"]
        in_last_hour = sum(1 for t in stamps if now - t < SECONDS_PER_HOUR)
        in_last_five_minutes = sum(1 for t in stamps if now - t < 5 * SECONDS_PER_MINUTE)

        counts = self.error_data["error_counts"]
        by_method = self.error_data["error_by_method"]
        by_connector = self.error_data["error_by_connector"]

        return {
            "total_errors": sum(counts.values()),
            "error_types": dict(counts),
            "errors_by_method": {name: dict(tally) for name, tally in by_method.items()},
            "errors_by_connector": {name: dict(tally) for name, tally in by_connector.items()},
            "recent_errors": self.error_data["recent_errors"][-10:],  # Last 10 errors
            "error_rate_last_hour": in_last_hour,
            "error_rate_last_5min": in_last_five_minutes,
            "most_common_errors": counts.most_common(5),
        }
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class CombinedAnalyticsMiddleware(Middleware):
    """Comprehensive middleware combining metrics, performance, and error tracking.

    Requests flow Metrics -> Errors -> Performance -> the real call, so the basic
    metrics observe the full time spent in the other two trackers.
    """

    def __init__(self):
        self.metrics_mw = MetricsMiddleware()
        self.perf_mw = PerformanceMetricsMiddleware()
        self.error_mw = ErrorTrackingMiddleware()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        # Build the chain from the inside out: performance wraps the original call_next,
        # error tracking wraps performance, and basic metrics wrap everything.
        async def through_performance(ctx):
            return await self.perf_mw.on_request(ctx, call_next)

        async def through_errors(ctx):
            return await self.error_mw.on_request(ctx, through_performance)

        return await self.metrics_mw.on_request(context, through_errors)

    def get_combined_analytics(self) -> dict[str, Any]:
        """Return one report merging all three trackers plus a generation timestamp."""
        report: dict[str, Any] = {}
        report["metrics"] = self.metrics_mw.get_metrics()
        report["performance"] = self.perf_mw.get_performance_metrics()
        report["errors"] = self.error_mw.get_error_analytics()
        report["generated_at"] = time.time()
        return report
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core middleware system for MCP requests.
|
|
3
|
+
|
|
4
|
+
This module provides a robust and extensible middleware architecture:
|
|
5
|
+
- A typed MiddlewareContext to carry request data.
|
|
6
|
+
- A Middleware base class with a dispatcher that routes to strongly-typed hooks.
|
|
7
|
+
- A MiddlewareManager to build and execute the processing chain.
|
|
8
|
+
- A CallbackClientSession that acts as an adapter, creating the initial context
|
|
9
|
+
without requiring changes to upstream callers like HttpConnector.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import time
|
|
13
|
+
import uuid
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from functools import partial
|
|
17
|
+
from typing import Any, Generic, Protocol, TypeVar
|
|
18
|
+
|
|
19
|
+
from mcp import ClientSession
|
|
20
|
+
from mcp.types import (
|
|
21
|
+
CallToolRequestParams,
|
|
22
|
+
CallToolResult,
|
|
23
|
+
GetPromptRequestParams,
|
|
24
|
+
GetPromptResult,
|
|
25
|
+
InitializeRequestParams,
|
|
26
|
+
InitializeResult,
|
|
27
|
+
JSONRPCResponse,
|
|
28
|
+
ListPromptsRequest,
|
|
29
|
+
ListPromptsResult,
|
|
30
|
+
ListResourcesRequest,
|
|
31
|
+
ListResourcesResult,
|
|
32
|
+
ListToolsRequest,
|
|
33
|
+
ListToolsResult,
|
|
34
|
+
ReadResourceRequestParams,
|
|
35
|
+
ReadResourceResult,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from mcp_use.telemetry.telemetry import telemetry
|
|
39
|
+
|
|
40
|
+
# Generic TypeVars for context and results
T = TypeVar("T")  # type of MiddlewareContext.params (and the context accepted by NextFunctionT)
R = TypeVar("R", covariant=True)  # result type returned by a NextFunctionT
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class MiddlewareContext(Generic[T]):
    """Unified, typed context for all middleware operations.

    One instance describes a single intercepted request and is passed unchanged
    through every middleware in the chain.
    """

    id: str  # unique request id (CallbackClientSession uses a uuid4 string)
    method: str  # The JSON-RPC method name, e.g., "tools/call"
    params: T  # The typed parameters for the method; CallbackClientSession passes None for initialize/list calls
    connection_id: str  # identifier of the connector that issued the request
    timestamp: float  # creation time from time.time(); some middlewares measure durations from it
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form extra data; not read in this module
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class MCPResponseContext:
    """Extended context for MCP responses with middleware metadata.

    ``MiddlewareManager.process_request`` returns one of these per request: on success
    ``result`` holds the raw call result and ``error`` is None; on failure ``error``
    holds the raised exception.
    """

    request_id: str  # id of the MiddlewareContext this response answers
    result: Any  # raw result of the wrapped call (e.g. CallToolResult); None on failure
    error: Exception | None  # exception raised by the middleware chain, if any
    duration: float  # seconds spent running the chain
    metadata: dict[str, Any]  # free-form extra data
    jsonrpc_response: JSONRPCResponse | None = None

    @classmethod
    def create(
        cls,
        request_id: str,
        result: Any = None,
        error: Exception | None = None,
        duration: float = 0.0,
    ) -> "MCPResponseContext":
        """Build a response context with empty metadata.

        Args:
            request_id: Id of the originating request context.
            result: Raw result of the call, if it succeeded.
            error: Exception raised by the call, if it failed.
            duration: Elapsed seconds; defaults to 0.0 so callers may fill it in later.
        """
        return cls(request_id=request_id, result=result, error=error, duration=duration, metadata={})
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Protocol definition for middleware
class NextFunctionT(Protocol[T, R]):
    """Protocol for the `call_next` function passed to middleware.

    Any async callable taking a single ``MiddlewareContext[T]`` and returning ``R``
    satisfies it structurally (e.g. a ``functools.partial`` over a middleware hook).
    """

    async def __call__(self, context: MiddlewareContext[T]) -> R: ...
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class Middleware:
    """Base class for middlewares with hooks.

    Calling an instance runs ``on_request`` first. When the context's method is one of
    the MCP methods in ``_HOOK_NAMES``, ``on_request`` hands off to the matching
    method-specific hook, which then calls the downstream ``call_next``. Every default
    hook is a pure pass-through, so subclasses override only what they need.

    NOTE: CallbackClientSession passes ``None`` as ``context.params`` for initialize and
    list calls, despite the parameter annotations on those hooks.
    """

    # JSON-RPC method name -> name of the hook method that intercepts it.
    _HOOK_NAMES = {
        "initialize": "on_initialize",
        "tools/call": "on_call_tool",
        "tools/list": "on_list_tools",
        "resources/list": "on_list_resources",
        "resources/read": "on_read_resource",
        "prompts/list": "on_list_prompts",
        "prompts/get": "on_get_prompt",
    }

    async def __call__(self, context: MiddlewareContext[T], call_next: NextFunctionT[T, Any]) -> Any:
        """Main entry point: build this middleware's handler chain and run it on ``context``."""
        entry = await self._dispatch_handler(context, call_next)
        return await entry(context)

    async def _dispatch_handler(
        self, context: MiddlewareContext[Any], call_next: NextFunctionT[Any, Any]
    ) -> NextFunctionT[Any, Any]:
        """Return ``on_request`` wrapped around the method-specific hook (if any) and ``call_next``."""
        hook_name = self._HOOK_NAMES.get(context.method)
        if hook_name is None:
            downstream = call_next
        else:
            downstream = partial(getattr(self, hook_name), call_next=call_next)

        # Every intercepted call is treated as a request, so on_request is the outermost layer.
        return partial(self.on_request, call_next=downstream)

    # Default hook implementations: each one simply forwards to the next handler.
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Runs for every intercepted request, before any method-specific hook."""
        return await call_next(context)

    async def on_initialize(
        self, context: MiddlewareContext[InitializeRequestParams], call_next: NextFunctionT
    ) -> InitializeResult:
        """Hook for ``initialize``."""
        return await call_next(context)

    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult:
        """Hook for ``tools/call``."""
        return await call_next(context)

    async def on_read_resource(
        self, context: MiddlewareContext[ReadResourceRequestParams], call_next: NextFunctionT
    ) -> ReadResourceResult:
        """Hook for ``resources/read``."""
        return await call_next(context)

    async def on_get_prompt(
        self, context: MiddlewareContext[GetPromptRequestParams], call_next: NextFunctionT
    ) -> GetPromptResult:
        """Hook for ``prompts/get``."""
        return await call_next(context)

    async def on_list_tools(
        self, context: MiddlewareContext[ListToolsRequest], call_next: NextFunctionT
    ) -> ListToolsResult:
        """Hook for ``tools/list``."""
        return await call_next(context)

    async def on_list_resources(
        self, context: MiddlewareContext[ListResourcesRequest], call_next: NextFunctionT
    ) -> ListResourcesResult:
        """Hook for ``resources/list``."""
        return await call_next(context)

    async def on_list_prompts(
        self, context: MiddlewareContext[ListPromptsRequest], call_next: NextFunctionT
    ) -> ListPromptsResult:
        """Hook for ``prompts/list``."""
        return await call_next(context)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class MiddlewareManager:
    """Manages middleware callbacks for MCP requests.

    Middlewares run in registration order: the first one added is the outermost layer
    and sees the request first.
    """

    def __init__(self):
        self.middlewares: list[Middleware] = []

    @telemetry("middleware_add")
    def add_middleware(self, callback: Middleware) -> None:
        """Add a middleware callback to the end (innermost position) of the chain."""
        self.middlewares.append(callback)

    @telemetry("middleware_process_request")
    async def process_request(self, context: MiddlewareContext, original_call: Callable) -> MCPResponseContext:
        """
        Runs the full middleware chain, captures timing and errors,
        and returns a structured MCPResponseContext.

        Args:
            context: The request context passed to every middleware.
            original_call: Zero-argument callable returning an awaitable of the raw result.

        Returns:
            An MCPResponseContext whose ``result`` is set on success, or whose ``error``
            holds the ``Exception`` raised anywhere in the chain. Errors are captured,
            not raised; ``duration`` is measured the same way in both cases.
        """

        # NOTE: the terminal step ignores the context it receives, so changes a middleware
        # makes to ``context.params`` do not reach ``original_call``.
        async def execute_call(_: MiddlewareContext) -> Any:
            return await original_call()

        # Wrap from the inside out so self.middlewares[0] ends up outermost.
        call_chain = execute_call
        for middleware in reversed(self.middlewares):
            call_chain = partial(middleware, call_next=call_chain)

        # One monotonic start point so success and failure durations are comparable.
        start_time = time.perf_counter()
        try:
            # The result of the chain is the raw result (e.g., CallToolResult)
            raw_result = await call_chain(context)
        except Exception as error:
            response = MCPResponseContext.create(request_id=context.id, error=error)
        else:
            response = MCPResponseContext.create(request_id=context.id, result=raw_result)

        response.duration = time.perf_counter() - start_time
        return response
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class CallbackClientSession:
    """ClientSession wrapper that uses callback-based middleware.

    The typed MCP methods below are routed through ``middleware_manager``; any other
    attribute access is delegated to the wrapped ``ClientSession``.
    """

    def __init__(self, client_session: ClientSession, connector_id: str, middleware_manager: MiddlewareManager):
        self._client_session = client_session
        self.connector_id = connector_id
        self.middleware_manager = middleware_manager

    async def _intercept_call(self, method_name: str, params: Any, original_call: Callable) -> Any:
        """
        Creates the context, runs it through the manager, and unwraps the final response.

        Raises:
            The exception captured by the middleware chain, if the call failed.
        """
        context = MiddlewareContext(
            id=str(uuid.uuid4()),
            method=method_name,
            params=params,
            connection_id=self.connector_id,
            timestamp=time.time(),
        )

        # This now returns a rich MCPResponseContext
        response_context = await self.middleware_manager.process_request(context, original_call)

        # If there is an error, re-raise it so callers see the original exception.
        # Compare against None rather than relying on the exception object's truthiness.
        if response_context.error is not None:
            raise response_context.error

        return response_context.result

    # Wrap all MCP methods with specific params
    async def initialize(self, *args, **kwargs) -> InitializeResult:
        return await self._intercept_call("initialize", None, lambda: self._client_session.initialize(*args, **kwargs))

    # List requests usually don't have parameters
    async def list_tools(self, *args, **kwargs) -> ListToolsResult:
        return await self._intercept_call("tools/list", None, lambda: self._client_session.list_tools(*args, **kwargs))

    async def call_tool(self, name: str, arguments: dict[str, Any] | None = None, *args, **kwargs) -> CallToolResult:
        # ``arguments`` defaults to None to match ClientSession.call_tool; passing None
        # through explicitly is equivalent to omitting it.
        params = CallToolRequestParams(name=name, arguments=arguments)
        return await self._intercept_call(
            "tools/call", params, lambda: self._client_session.call_tool(name, arguments, *args, **kwargs)
        )

    async def list_resources(self, *args, **kwargs) -> ListResourcesResult:
        return await self._intercept_call(
            "resources/list", None, lambda: self._client_session.list_resources(*args, **kwargs)
        )

    async def read_resource(self, uri: str, *args, **kwargs) -> ReadResourceResult:
        params = ReadResourceRequestParams(uri=uri)
        return await self._intercept_call(
            "resources/read", params, lambda: self._client_session.read_resource(uri, *args, **kwargs)
        )

    async def list_prompts(self, *args, **kwargs) -> ListPromptsResult:
        return await self._intercept_call(
            "prompts/list", None, lambda: self._client_session.list_prompts(*args, **kwargs)
        )

    async def get_prompt(self, name: str, *args, **kwargs) -> GetPromptResult:
        # ClientSession.get_prompt takes (name, arguments=None). Pick up ``arguments`` whether
        # it was passed positionally or by keyword for the middleware view, and keep any
        # unrelated kwargs out of the params model; everything is still forwarded as given.
        arguments = args[0] if args else kwargs.get("arguments")
        params = GetPromptRequestParams(name=name, arguments=arguments)
        return await self._intercept_call(
            "prompts/get", params, lambda: self._client_session.get_prompt(name, *args, **kwargs)
        )

    def __getattr__(self, name: str) -> Any:
        """Delegate other attributes to the wrapped session."""
        # Without this guard, a missing ``_client_session`` (e.g. an instance created without
        # __init__, as during copying/unpickling) would recurse here until RecursionError.
        if name == "_client_session":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self._client_session, name)
|