mcp-use 1.3.12__py3-none-any.whl → 1.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-use might be problematic. Click here for more details.

Files changed (101) hide show
  1. mcp_use/__init__.py +1 -1
  2. mcp_use/adapters/.deprecated +0 -0
  3. mcp_use/adapters/__init__.py +18 -7
  4. mcp_use/adapters/base.py +12 -185
  5. mcp_use/adapters/langchain_adapter.py +12 -219
  6. mcp_use/agents/adapters/__init__.py +10 -0
  7. mcp_use/agents/adapters/base.py +193 -0
  8. mcp_use/agents/adapters/langchain_adapter.py +228 -0
  9. mcp_use/agents/base.py +1 -1
  10. mcp_use/agents/managers/__init__.py +19 -0
  11. mcp_use/agents/managers/base.py +36 -0
  12. mcp_use/agents/managers/server_manager.py +131 -0
  13. mcp_use/agents/managers/tools/__init__.py +15 -0
  14. mcp_use/agents/managers/tools/base_tool.py +19 -0
  15. mcp_use/agents/managers/tools/connect_server.py +69 -0
  16. mcp_use/agents/managers/tools/disconnect_server.py +43 -0
  17. mcp_use/agents/managers/tools/get_active_server.py +29 -0
  18. mcp_use/agents/managers/tools/list_servers_tool.py +53 -0
  19. mcp_use/agents/managers/tools/search_tools.py +328 -0
  20. mcp_use/agents/mcpagent.py +16 -14
  21. mcp_use/agents/remote.py +14 -1
  22. mcp_use/auth/.deprecated +0 -0
  23. mcp_use/auth/__init__.py +19 -4
  24. mcp_use/auth/bearer.py +11 -12
  25. mcp_use/auth/oauth.py +11 -620
  26. mcp_use/auth/oauth_callback.py +16 -207
  27. mcp_use/client/__init__.py +1 -0
  28. mcp_use/client/auth/__init__.py +6 -0
  29. mcp_use/client/auth/bearer.py +23 -0
  30. mcp_use/client/auth/oauth.py +629 -0
  31. mcp_use/client/auth/oauth_callback.py +214 -0
  32. mcp_use/client/client.py +356 -0
  33. mcp_use/client/config.py +106 -0
  34. mcp_use/client/connectors/__init__.py +20 -0
  35. mcp_use/client/connectors/base.py +470 -0
  36. mcp_use/client/connectors/http.py +304 -0
  37. mcp_use/client/connectors/sandbox.py +332 -0
  38. mcp_use/client/connectors/stdio.py +109 -0
  39. mcp_use/client/connectors/utils.py +13 -0
  40. mcp_use/client/connectors/websocket.py +257 -0
  41. mcp_use/client/exceptions.py +31 -0
  42. mcp_use/client/middleware/__init__.py +50 -0
  43. mcp_use/client/middleware/logging.py +31 -0
  44. mcp_use/client/middleware/metrics.py +314 -0
  45. mcp_use/client/middleware/middleware.py +266 -0
  46. mcp_use/client/session.py +162 -0
  47. mcp_use/client/task_managers/__init__.py +20 -0
  48. mcp_use/client/task_managers/base.py +145 -0
  49. mcp_use/client/task_managers/sse.py +84 -0
  50. mcp_use/client/task_managers/stdio.py +69 -0
  51. mcp_use/client/task_managers/streamable_http.py +86 -0
  52. mcp_use/client/task_managers/websocket.py +68 -0
  53. mcp_use/client.py +12 -344
  54. mcp_use/config.py +20 -97
  55. mcp_use/connectors/.deprecated +0 -0
  56. mcp_use/connectors/__init__.py +46 -20
  57. mcp_use/connectors/base.py +12 -455
  58. mcp_use/connectors/http.py +13 -300
  59. mcp_use/connectors/sandbox.py +13 -306
  60. mcp_use/connectors/stdio.py +13 -104
  61. mcp_use/connectors/utils.py +15 -8
  62. mcp_use/connectors/websocket.py +13 -252
  63. mcp_use/exceptions.py +33 -18
  64. mcp_use/managers/.deprecated +0 -0
  65. mcp_use/managers/__init__.py +56 -17
  66. mcp_use/managers/base.py +13 -31
  67. mcp_use/managers/server_manager.py +13 -119
  68. mcp_use/managers/tools/__init__.py +45 -15
  69. mcp_use/managers/tools/base_tool.py +5 -16
  70. mcp_use/managers/tools/connect_server.py +5 -67
  71. mcp_use/managers/tools/disconnect_server.py +5 -41
  72. mcp_use/managers/tools/get_active_server.py +5 -26
  73. mcp_use/managers/tools/list_servers_tool.py +5 -51
  74. mcp_use/managers/tools/search_tools.py +17 -321
  75. mcp_use/middleware/.deprecated +0 -0
  76. mcp_use/middleware/__init__.py +89 -50
  77. mcp_use/middleware/logging.py +14 -26
  78. mcp_use/middleware/metrics.py +30 -303
  79. mcp_use/middleware/middleware.py +39 -246
  80. mcp_use/session.py +13 -149
  81. mcp_use/task_managers/.deprecated +0 -0
  82. mcp_use/task_managers/__init__.py +48 -20
  83. mcp_use/task_managers/base.py +13 -140
  84. mcp_use/task_managers/sse.py +13 -79
  85. mcp_use/task_managers/stdio.py +13 -64
  86. mcp_use/task_managers/streamable_http.py +15 -81
  87. mcp_use/task_managers/websocket.py +13 -63
  88. mcp_use/telemetry/events.py +58 -0
  89. mcp_use/telemetry/telemetry.py +71 -1
  90. mcp_use/types/.deprecated +0 -0
  91. mcp_use/types/sandbox.py +13 -18
  92. {mcp_use-1.3.12.dist-info → mcp_use-1.3.13.dist-info}/METADATA +59 -34
  93. mcp_use-1.3.13.dist-info/RECORD +109 -0
  94. mcp_use-1.3.12.dist-info/RECORD +0 -64
  95. mcp_use-1.3.12.dist-info/licenses/LICENSE +0 -21
  96. /mcp_use/{observability → agents/observability}/__init__.py +0 -0
  97. /mcp_use/{observability → agents/observability}/callbacks_manager.py +0 -0
  98. /mcp_use/{observability → agents/observability}/laminar.py +0 -0
  99. /mcp_use/{observability → agents/observability}/langfuse.py +0 -0
  100. {mcp_use-1.3.12.dist-info → mcp_use-1.3.13.dist-info}/WHEEL +0 -0
  101. {mcp_use-1.3.12.dist-info → mcp_use-1.3.13.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,314 @@
1
+ """
2
+ Metrics middleware for MCP contexts.
3
+
4
+ Classes for collecting comprehensive metrics about MCP context patterns,
5
+ performance, and errors with simple instantiation.
6
+ """
7
+
8
+ import asyncio
9
+ import time
10
+ from collections import Counter, defaultdict
11
+ from typing import Any
12
+
13
+ from mcp_use.client.middleware.middleware import Middleware, MiddlewareContext, NextFunctionT
14
+
15
# Latency cut-offs, in milliseconds, used to classify a context as slow or fast.
SLOW_THRESHOLD_MS = 1_000  # one second
FAST_THRESHOLD_MS = 50

# Upper bounds on how many entries each retained history list may hold.
MAX_SLOW_CONTEXTS = 100
MAX_FAST_CONTEXTS = 100
MAX_ERROR_TIMESTAMPS = 1_000
MAX_RECENT_ERRORS = 50

# Time-unit conversions.
SECONDS_PER_MINUTE = 60
SECONDS_PER_HOUR = 60 * SECONDS_PER_MINUTE
24
+
25
+
26
class MetricsMiddleware(Middleware):
    """Collect basic metrics about MCP contexts: counts, durations, and errors.

    All counters live in the public ``metrics`` dict. Every mutation performed
    by this class happens while ``self.lock`` is held.
    """

    def __init__(self):
        self.metrics = {
            "total_contexts": 0,
            "total_errors": 0,
            "method_counts": defaultdict(int),
            "method_durations": defaultdict(list),
            "active_contexts": 0,
            "start_time": time.time(),
        }
        self.lock = asyncio.Lock()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Count the context, run the rest of the chain, and record its outcome.

        Args:
            context: The context being processed; ``method`` and ``timestamp``
                are read from it.
            call_next: Continuation that runs the remainder of the chain.

        Returns:
            Whatever ``call_next`` returns.

        Raises:
            Any exception raised by ``call_next`` is re-raised unchanged after
            it has been counted in ``total_errors``.
        """
        async with self.lock:
            self.metrics["total_contexts"] += 1
            self.metrics["active_contexts"] += 1
            self.metrics["method_counts"][context.method] += 1

        failed = False
        try:
            return await call_next(context)
        except Exception:
            failed = True
            raise
        finally:
            # Runs for success, failure AND task cancellation.
            # asyncio.CancelledError derives from BaseException, so an
            # ``except Exception`` handler alone would leave "active_contexts"
            # permanently incremented for every cancelled request.
            duration = time.time() - context.timestamp
            async with self.lock:
                if failed:
                    self.metrics["total_errors"] += 1
                self.metrics["active_contexts"] -= 1
                self._record_duration(context.method, duration)

    def _record_duration(self, method: str, duration: float) -> None:
        """Append ``duration`` to the samples kept for ``method``.

        This helper takes no lock itself; ``on_request`` calls it while
        holding ``self.lock``.
        """
        self.metrics["method_durations"][method].append(duration)

    def get_metrics(self) -> dict[str, Any]:
        """Return the raw counters plus derived rates and per-method statistics.

        The raw entries are copied shallowly from ``self.metrics``, so the
        ``method_counts`` and ``method_durations`` values are the live
        containers rather than copies.
        """
        uptime = time.time() - self.metrics["start_time"]
        total = self.metrics["total_contexts"]
        durations_by_method = self.metrics["method_durations"]

        return {
            **self.metrics,
            "uptime_seconds": uptime,
            "contexts_per_second": total / uptime if uptime > 0 else 0,
            "error_rate": self.metrics["total_errors"] / total if total > 0 else 0,
            "method_avg_duration": {
                method: sum(durations) / len(durations) if durations else 0
                for method, durations in durations_by_method.items()
            },
            "method_min_duration": {
                method: min(durations) if durations else 0 for method, durations in durations_by_method.items()
            },
            "method_max_duration": {
                method: max(durations) if durations else 0 for method, durations in durations_by_method.items()
            },
        }
91
+
92
+
93
class PerformanceMetricsMiddleware(Middleware):
    """Track latency per method and per connection, plus slow/fast outliers."""

    def __init__(self):
        self.lock = asyncio.Lock()
        self.performance_data = {
            "context_times": defaultdict(list),
            "hourly_counts": defaultdict(int),
            "connector_performance": defaultdict(list),
            "slow_contexts": [],  # contexts slower than the slow threshold
            "fast_contexts": [],  # contexts faster than the fast threshold
            "slow_threshold_ms": SLOW_THRESHOLD_MS,
            "fast_threshold_ms": FAST_THRESHOLD_MS,
        }

    def _store_timing(self, context: MiddlewareContext[Any], elapsed_ms: float) -> None:
        """File one latency sample under the context's method and connection."""
        self.performance_data["context_times"][context.method].append(elapsed_ms)
        self.performance_data["connector_performance"][context.connection_id].append(elapsed_ms)

    def _remember(self, bucket: str, context: MiddlewareContext[Any], elapsed_ms: float, limit: int) -> None:
        """Append an outlier record to ``bucket`` and keep only the last ``limit``."""
        history = self.performance_data[bucket]
        history.append(
            {
                "method": context.method,
                "connector": context.connection_id,
                "duration_ms": elapsed_ms,
                "timestamp": context.timestamp,
            }
        )
        if len(history) > limit:
            self.performance_data[bucket] = history[-limit:]

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Time the downstream call and record the sample.

        Successful calls also update the hourly counter and the slow/fast
        histories. Calls that raise ``Exception`` only contribute a latency
        sample, and the exception is re-raised.
        """
        began = time.time()

        try:
            outcome = await call_next(context)
            elapsed_ms = (time.time() - began) * 1000

            async with self.lock:
                self._store_timing(context, elapsed_ms)

                # Bucket successful calls by whole hours since the epoch.
                hour_bucket = int(time.time() // SECONDS_PER_HOUR)
                self.performance_data["hourly_counts"][hour_bucket] += 1

                if elapsed_ms > self.performance_data["slow_threshold_ms"]:
                    self._remember("slow_contexts", context, elapsed_ms, MAX_SLOW_CONTEXTS)

                if elapsed_ms < self.performance_data["fast_threshold_ms"]:
                    self._remember("fast_contexts", context, elapsed_ms, MAX_FAST_CONTEXTS)

            return outcome

        except Exception:
            # Failed contexts still contribute a latency sample.
            elapsed_ms = (time.time() - began) * 1000
            async with self.lock:
                self._store_timing(context, elapsed_ms)
            raise

    def get_performance_metrics(self) -> dict[str, Any]:
        """Summarise the collected samples.

        Returns per-method and per-connection count/avg/min/max/percentiles,
        the ten most recent slow and fast records, the hourly distribution,
        and the thresholds in effect.
        """
        quantiles = {"p50": 0.5, "p90": 0.9, "p95": 0.95, "p99": 0.99}

        def summarize(samples):
            # Only called with a non-empty list of samples.
            ordered = sorted(samples)
            size = len(ordered)
            summary = {
                "count": len(samples),
                "avg_ms": sum(samples) / len(samples),
                "min_ms": min(samples),
                "max_ms": max(samples),
            }
            for label, fraction in quantiles.items():
                summary[label] = ordered[int(size * fraction)]
            return summary

        data = self.performance_data
        return {
            "method_performance": {
                method: summarize(samples) for method, samples in data["context_times"].items() if samples
            },
            "connector_performance": {
                connector: summarize(samples)
                for connector, samples in data["connector_performance"].items()
                if samples
            },
            "slow_contexts": data["slow_contexts"][-10:],  # last 10 slow contexts
            "fast_contexts": data["fast_contexts"][-10:],  # last 10 fast contexts
            "hourly_distribution": dict(data["hourly_counts"]),
            "thresholds": {
                "slow_ms": data["slow_threshold_ms"],
                "fast_ms": data["fast_threshold_ms"],
            },
        }
213
+
214
+
215
class ErrorTrackingMiddleware(Middleware):
    """Record every exception raised downstream, grouped by type, method, and connection."""

    def __init__(self):
        self.error_data = {
            "error_counts": Counter(),
            "error_by_method": defaultdict(Counter),
            "error_by_connector": defaultdict(Counter),
            "recent_errors": [],
            "error_timestamps": [],
        }
        self.lock = asyncio.Lock()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Run the chain; on ``Exception`` log the failure, then re-raise it."""
        try:
            return await call_next(context)
        except Exception as exc:
            async with self.lock:
                self._log_failure(context, exc)
            raise  # the error always propagates to the caller

    def _log_failure(self, context: MiddlewareContext[Any], exc: Exception) -> None:
        """Update every error tally for one failed context.

        ``on_request`` calls this while holding ``self.lock``.
        """
        store = self.error_data
        kind = type(exc).__name__
        message = str(exc)

        # Aggregate counters.
        store["error_counts"][kind] += 1
        store["error_by_method"][context.method][kind] += 1
        store["error_by_connector"][context.connection_id][kind] += 1
        store["error_timestamps"].append(time.time())

        # Detailed record kept for later inspection.
        store["recent_errors"].append(
            {
                "timestamp": context.timestamp,
                "method": context.method,
                "connector": context.connection_id,
                "error_type": kind,
                "error_message": message,
                "context_id": context.id,
            }
        )

        # Cap both histories at their configured maximum length.
        if len(store["recent_errors"]) > MAX_RECENT_ERRORS:
            store["recent_errors"] = store["recent_errors"][-MAX_RECENT_ERRORS:]
        if len(store["error_timestamps"]) > MAX_ERROR_TIMESTAMPS:
            store["error_timestamps"] = store["error_timestamps"][-MAX_ERROR_TIMESTAMPS:]

    def get_error_analytics(self) -> dict[str, Any]:
        """Summarise recorded errors.

        The two ``error_rate_*`` entries are plain counts of retained error
        timestamps that fall inside the last hour / last five minutes.
        """
        now = time.time()
        stamps = self.error_data["error_timestamps"]
        counts = self.error_data["error_counts"]

        def errors_within(window_seconds):
            return sum(1 for stamp in stamps if now - stamp < window_seconds)

        last_hour = errors_within(SECONDS_PER_HOUR)
        last_five_minutes = errors_within(5 * SECONDS_PER_MINUTE)

        return {
            "total_errors": sum(counts.values()),
            "error_types": dict(counts),
            "errors_by_method": {
                method: dict(tally) for method, tally in self.error_data["error_by_method"].items()
            },
            "errors_by_connector": {
                connector: dict(tally) for connector, tally in self.error_data["error_by_connector"].items()
            },
            "recent_errors": self.error_data["recent_errors"][-10:],  # last 10 errors
            "error_rate_last_hour": last_hour,
            "error_rate_last_5min": last_five_minutes,
            "most_common_errors": counts.most_common(5),
        }
286
+
287
+
288
class CombinedAnalyticsMiddleware(Middleware):
    """Run metrics, error tracking, and performance tracking as one middleware."""

    def __init__(self):
        self.metrics_mw = MetricsMiddleware()
        self.perf_mw = PerformanceMetricsMiddleware()
        self.error_mw = ErrorTrackingMiddleware()

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Nest the three collectors: Metrics -> Errors -> Performance -> ``call_next``."""
        # Innermost layer: performance tracking wraps the real continuation.
        timed = lambda ctx: self.perf_mw.on_request(ctx, call_next)  # noqa: E731
        # Middle layer: error tracking wraps the performance layer.
        guarded = lambda ctx: self.error_mw.on_request(ctx, timed)  # noqa: E731
        # Outermost layer: basic metrics see everything below them.
        return await self.metrics_mw.on_request(context, guarded)

    def get_combined_analytics(self) -> dict[str, Any]:
        """Return the three collectors' reports plus a generation timestamp."""
        report = {
            "metrics": self.metrics_mw.get_metrics(),
            "performance": self.perf_mw.get_performance_metrics(),
            "errors": self.error_mw.get_error_analytics(),
        }
        report["generated_at"] = time.time()
        return report
@@ -0,0 +1,266 @@
1
+ """
2
+ Core middleware system for MCP requests.
3
+
4
+ This module provides a robust and extensible middleware architecture:
5
+ - A typed MiddlewareContext to carry request data.
6
+ - A Middleware base class with a dispatcher that routes to strongly-typed hooks.
7
+ - A MiddlewareManager to build and execute the processing chain.
8
+ - A CallbackClientSession that acts as an adapter, creating the initial context
9
+ without requiring changes to upstream callers like HttpConnector.
10
+ """
11
+
12
+ import time
13
+ import uuid
14
+ from collections.abc import Callable
15
+ from dataclasses import dataclass, field
16
+ from functools import partial
17
+ from typing import Any, Generic, Protocol, TypeVar
18
+
19
+ from mcp import ClientSession
20
+ from mcp.types import (
21
+ CallToolRequestParams,
22
+ CallToolResult,
23
+ GetPromptRequestParams,
24
+ GetPromptResult,
25
+ InitializeRequestParams,
26
+ InitializeResult,
27
+ JSONRPCResponse,
28
+ ListPromptsRequest,
29
+ ListPromptsResult,
30
+ ListResourcesRequest,
31
+ ListResourcesResult,
32
+ ListToolsRequest,
33
+ ListToolsResult,
34
+ ReadResourceRequestParams,
35
+ ReadResourceResult,
36
+ )
37
+
38
+ from mcp_use.telemetry.telemetry import telemetry
39
+
40
# Generic type variables: T for a context's params, R (covariant) for a result.
T = TypeVar("T")
R = TypeVar("R", covariant=True)
43
+
44
+
45
@dataclass
class MiddlewareContext(Generic[T]):
    """One intercepted MCP operation, passed unchanged through the middleware chain."""

    # Identifier of this context.
    id: str
    # JSON-RPC method name, e.g. "tools/call".
    method: str
    # Parameters for the method, typed per method.
    params: T
    # Identifier of the connection the call is made on.
    connection_id: str
    # Time value associated with the context, as a float.
    timestamp: float
    # Free-form extras; each instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
55
+
56
+
57
@dataclass
class MCPResponseContext:
    """Outcome of one MCP request together with middleware bookkeeping."""

    # Identifier of the request this response belongs to.
    request_id: str
    # Value produced by the call.
    result: Any
    # Exception describing a failure, or None when there is none.
    error: Exception | None
    # Elapsed time as a float; ``create`` initialises it to 0.0.
    duration: float
    # Free-form extras attached to the response.
    metadata: dict[str, Any]
    # Optional JSON-RPC response object; absent by default.
    jsonrpc_response: JSONRPCResponse | None = None

    @classmethod
    def create(
        cls, request_id: str, result: Any = None, error: Exception | None = None
    ) -> "MCPResponseContext":
        """Build a response with zero duration and a fresh, empty metadata dict.

        Args:
            request_id: Identifier of the originating request.
            result: Successful result, if any.
            error: Captured exception, if any.

        Returns:
            A new ``MCPResponseContext``.
        """
        return cls(request_id=request_id, result=result, error=error, duration=0.0, metadata={})
71
+
72
+
73
# Structural type of the continuation handed to every middleware hook.
class NextFunctionT(Protocol[T, R]):
    """Awaitable ``call_next`` callable: takes a ``MiddlewareContext[T]`` and yields ``R``."""

    async def __call__(self, context: MiddlewareContext[T]) -> R:
        ...
78
+
79
+
80
class Middleware:
    """Base middleware: override ``on_request`` and/or a method-specific hook.

    Calling an instance runs ``on_request`` first, then the hook matching
    ``context.method`` (if one exists), then the supplied ``call_next``.
    Every default hook simply forwards to ``call_next``.
    """

    async def __call__(self, context: MiddlewareContext[T], call_next: NextFunctionT[T, Any]) -> Any:
        """Assemble this middleware's handler chain and run it on ``context``."""
        pipeline = await self._dispatch_handler(context, call_next)
        return await pipeline(context)

    async def _dispatch_handler(
        self, context: MiddlewareContext[Any], call_next: NextFunctionT[Any, Any]
    ) -> NextFunctionT[Any, Any]:
        """Return ``on_request`` wrapped around the method-specific hook, if any."""
        typed_hooks = {
            "initialize": self.on_initialize,
            "tools/call": self.on_call_tool,
            "tools/list": self.on_list_tools,
            "resources/list": self.on_list_resources,
            "resources/read": self.on_read_resource,
            "prompts/list": self.on_list_prompts,
            "prompts/get": self.on_get_prompt,
        }

        inner = call_next
        typed_hook = typed_hooks.get(context.method)
        if typed_hook:
            inner = partial(typed_hook, call_next=inner)

        # Every intercepted call is treated as a request, so ``on_request``
        # always forms the outermost layer.
        return partial(self.on_request, call_next=inner)

    # ---- Default hooks: each one just forwards to ``call_next`` ----

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Hook run for every context, regardless of method."""
        return await call_next(context)

    async def on_initialize(
        self, context: MiddlewareContext[InitializeRequestParams], call_next: NextFunctionT
    ) -> InitializeResult:
        """Hook for the "initialize" method."""
        return await call_next(context)

    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult:
        """Hook for the "tools/call" method."""
        return await call_next(context)

    async def on_read_resource(
        self, context: MiddlewareContext[ReadResourceRequestParams], call_next: NextFunctionT
    ) -> ReadResourceResult:
        """Hook for the "resources/read" method."""
        return await call_next(context)

    async def on_get_prompt(
        self, context: MiddlewareContext[GetPromptRequestParams], call_next: NextFunctionT
    ) -> GetPromptResult:
        """Hook for the "prompts/get" method."""
        return await call_next(context)

    async def on_list_tools(
        self, context: MiddlewareContext[ListToolsRequest], call_next: NextFunctionT
    ) -> ListToolsResult:
        """Hook for the "tools/list" method."""
        return await call_next(context)

    async def on_list_resources(
        self, context: MiddlewareContext[ListResourcesRequest], call_next: NextFunctionT
    ) -> ListResourcesResult:
        """Hook for the "resources/list" method."""
        return await call_next(context)

    async def on_list_prompts(
        self, context: MiddlewareContext[ListPromptsRequest], call_next: NextFunctionT
    ) -> ListPromptsResult:
        """Hook for the "prompts/list" method."""
        return await call_next(context)
150
+
151
+
152
class MiddlewareManager:
    """Holds registered middlewares and runs requests through them in order."""

    def __init__(self):
        self.middlewares: list[Middleware] = []

    @telemetry("middleware_add")
    def add_middleware(self, callback: Middleware) -> None:
        """Register ``callback`` at the end of the middleware list."""
        self.middlewares.append(callback)

    @telemetry("middleware_process_request")
    async def process_request(self, context: MiddlewareContext, original_call: Callable) -> MCPResponseContext:
        """Run ``original_call`` behind the full middleware chain.

        Middlewares execute in registration order: the first one added is the
        outermost layer. Any ``Exception`` raised by the chain or by
        ``original_call`` is captured in the returned object, not propagated.

        Args:
            context: Context handed to every middleware.
            original_call: Zero-argument awaitable callable performing the
                real operation.

        Returns:
            An ``MCPResponseContext`` carrying either ``result`` or ``error``,
            with ``duration`` set to the elapsed seconds of this call.
        """
        # One monotonic start point for both outcomes, so success and failure
        # durations are measured the same way and are unaffected by
        # wall-clock adjustments.
        started = time.perf_counter()

        try:
            # Innermost step ignores the context and performs the real call.
            async def execute_call(_: MiddlewareContext) -> Any:
                return await original_call()

            # Wrap from the inside out so the first registered middleware
            # ends up outermost.
            call_chain = execute_call
            for middleware in reversed(self.middlewares):
                call_chain = partial(middleware, call_next=call_chain)

            # The result of the chain is the raw result (e.g., CallToolResult).
            raw_result = await call_chain(context)
            response = MCPResponseContext.create(request_id=context.id, result=raw_result)
        except Exception as error:
            response = MCPResponseContext.create(request_id=context.id, error=error)

        response.duration = time.perf_counter() - started
        return response
197
+
198
+
199
class CallbackClientSession:
    """ClientSession wrapper that routes MCP calls through a MiddlewareManager.

    The methods defined here build a ``MiddlewareContext`` and pass the real
    call through the middleware chain. Any other attribute is delegated to
    the wrapped session.
    """

    def __init__(self, client_session: ClientSession, connector_id: str, middleware_manager: MiddlewareManager):
        self._client_session = client_session
        self.connector_id = connector_id
        self.middleware_manager = middleware_manager

    async def _intercept_call(self, method_name: str, params: Any, original_call: Callable) -> Any:
        """Create the context, run it through the manager, and unwrap the response.

        Args:
            method_name: JSON-RPC method name stored on the context.
            params: Parameters stored on the context (may be None).
            original_call: Zero-argument callable returning the awaitable
                that performs the real session call.

        Returns:
            The ``result`` of the response produced by the manager.

        Raises:
            Whatever exception the manager captured in ``response.error``.
        """
        context = MiddlewareContext(
            id=str(uuid.uuid4()),
            method=method_name,
            params=params,
            connection_id=self.connector_id,
            timestamp=time.time(),
        )

        # The manager returns an MCPResponseContext.
        response_context = await self.middleware_manager.process_request(context, original_call)

        # Re-raise a captured error. Compare against None explicitly: an
        # exception type that defines __len__/__bool__ can be falsy, and a
        # truthiness check would then swallow it.
        if response_context.error is not None:
            raise response_context.error

        return response_context.result

    # Wrap all MCP methods with specific params
    async def initialize(self, *args, **kwargs) -> InitializeResult:
        """Intercepted ``initialize``; no params are recorded on the context."""
        return await self._intercept_call("initialize", None, lambda: self._client_session.initialize(*args, **kwargs))

    # List requests usually don't have parameters
    async def list_tools(self, *args, **kwargs) -> ListToolsResult:
        """Intercepted ``tools/list``."""
        return await self._intercept_call("tools/list", None, lambda: self._client_session.list_tools(*args, **kwargs))

    async def call_tool(
        self, name: str, arguments: dict[str, Any] | None = None, *args, **kwargs
    ) -> CallToolResult:
        """Intercepted ``tools/call``.

        ``arguments`` is optional, so tools that take no arguments can be
        called with just the tool name.
        """
        params = CallToolRequestParams(name=name, arguments=arguments)
        return await self._intercept_call(
            "tools/call", params, lambda: self._client_session.call_tool(name, arguments, *args, **kwargs)
        )

    async def list_resources(self, *args, **kwargs) -> ListResourcesResult:
        """Intercepted ``resources/list``."""
        return await self._intercept_call(
            "resources/list", None, lambda: self._client_session.list_resources(*args, **kwargs)
        )

    async def read_resource(self, uri: str, *args, **kwargs) -> ReadResourceResult:
        """Intercepted ``resources/read``."""
        params = ReadResourceRequestParams(uri=uri)
        return await self._intercept_call(
            "resources/read", params, lambda: self._client_session.read_resource(uri, *args, **kwargs)
        )

    async def list_prompts(self, *args, **kwargs) -> ListPromptsResult:
        """Intercepted ``prompts/list``."""
        return await self._intercept_call(
            "prompts/list", None, lambda: self._client_session.list_prompts(*args, **kwargs)
        )

    async def get_prompt(self, name: str, *args, **kwargs) -> GetPromptResult:
        """Intercepted ``prompts/get``."""
        # NOTE(review): params are built from ``name`` and ``**kwargs`` only.
        # Values passed positionally via ``*args`` reach the session call but
        # are not reflected in the context params. Confirm this is intended.
        params = GetPromptRequestParams(name=name, **kwargs)
        return await self._intercept_call(
            "prompts/get", params, lambda: self._client_session.get_prompt(name, *args, **kwargs)
        )

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped session.

        Raises:
            AttributeError: If the wrapped session is not set yet, or if it
                lacks ``name``.
        """
        # __getattr__ only runs after normal lookup fails. If the missing name
        # is ``_client_session`` itself (an instance created without __init__,
        # e.g. by copy or pickle), delegating would recurse without bound.
        if name == "_client_session":
            raise AttributeError(name)
        return getattr(self._client_session, name)