chuk-tool-processor 0.6.13__py3-none-any.whl → 0.9.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/core/__init__.py +31 -0
- chuk_tool_processor/core/exceptions.py +218 -12
- chuk_tool_processor/core/processor.py +38 -7
- chuk_tool_processor/execution/strategies/__init__.py +6 -0
- chuk_tool_processor/execution/strategies/subprocess_strategy.py +2 -1
- chuk_tool_processor/execution/wrappers/__init__.py +42 -0
- chuk_tool_processor/execution/wrappers/caching.py +48 -13
- chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
- chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
- chuk_tool_processor/execution/wrappers/retry.py +93 -53
- chuk_tool_processor/logging/metrics.py +2 -2
- chuk_tool_processor/mcp/mcp_tool.py +5 -5
- chuk_tool_processor/mcp/setup_mcp_http_streamable.py +44 -2
- chuk_tool_processor/mcp/setup_mcp_sse.py +44 -2
- chuk_tool_processor/mcp/setup_mcp_stdio.py +2 -0
- chuk_tool_processor/mcp/stream_manager.py +130 -75
- chuk_tool_processor/mcp/transport/__init__.py +10 -0
- chuk_tool_processor/mcp/transport/http_streamable_transport.py +193 -108
- chuk_tool_processor/mcp/transport/models.py +100 -0
- chuk_tool_processor/mcp/transport/sse_transport.py +155 -59
- chuk_tool_processor/mcp/transport/stdio_transport.py +58 -10
- chuk_tool_processor/models/__init__.py +20 -0
- chuk_tool_processor/models/tool_call.py +34 -1
- chuk_tool_processor/models/tool_spec.py +350 -0
- chuk_tool_processor/models/validated_tool.py +22 -2
- chuk_tool_processor/observability/__init__.py +30 -0
- chuk_tool_processor/observability/metrics.py +312 -0
- chuk_tool_processor/observability/setup.py +105 -0
- chuk_tool_processor/observability/tracing.py +345 -0
- chuk_tool_processor/plugins/discovery.py +1 -1
- chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/RECORD +34 -27
- chuk_tool_processor-0.6.13.dist-info/METADATA +0 -698
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
|
@@ -16,12 +16,15 @@ from chuk_mcp.protocol.messages import ( # type: ignore[import-untyped]
|
|
|
16
16
|
send_tools_call,
|
|
17
17
|
send_tools_list,
|
|
18
18
|
)
|
|
19
|
+
from chuk_mcp.transports.http.parameters import StreamableHTTPParameters # type: ignore[import-untyped]
|
|
19
20
|
|
|
20
21
|
# Import chuk-mcp HTTP transport components
|
|
21
|
-
from chuk_mcp.transports.http import
|
|
22
|
-
|
|
22
|
+
from chuk_mcp.transports.http.transport import (
|
|
23
|
+
StreamableHTTPTransport as ChukHTTPTransport, # type: ignore[import-untyped]
|
|
24
|
+
)
|
|
23
25
|
|
|
24
26
|
from .base_transport import MCPBaseTransport
|
|
27
|
+
from .models import TimeoutConfig, TransportMetrics
|
|
25
28
|
|
|
26
29
|
logger = logging.getLogger(__name__)
|
|
27
30
|
|
|
@@ -38,11 +41,13 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
38
41
|
self,
|
|
39
42
|
url: str,
|
|
40
43
|
api_key: str | None = None,
|
|
41
|
-
headers: dict[str, str] | None = None,
|
|
44
|
+
headers: dict[str, str] | None = None,
|
|
42
45
|
connection_timeout: float = 30.0,
|
|
43
46
|
default_timeout: float = 30.0,
|
|
44
47
|
session_id: str | None = None,
|
|
45
48
|
enable_metrics: bool = True,
|
|
49
|
+
oauth_refresh_callback: Any | None = None,
|
|
50
|
+
timeout_config: TimeoutConfig | None = None,
|
|
46
51
|
):
|
|
47
52
|
"""
|
|
48
53
|
Initialize HTTP Streamable transport with enhanced configuration.
|
|
@@ -50,11 +55,13 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
50
55
|
Args:
|
|
51
56
|
url: HTTP server URL (should end with /mcp)
|
|
52
57
|
api_key: Optional API key for authentication
|
|
53
|
-
headers: Optional custom headers
|
|
54
|
-
connection_timeout: Timeout for initial connection
|
|
55
|
-
default_timeout: Default timeout for operations
|
|
58
|
+
headers: Optional custom headers
|
|
59
|
+
connection_timeout: Timeout for initial connection (overrides timeout_config.connect)
|
|
60
|
+
default_timeout: Default timeout for operations (overrides timeout_config.operation)
|
|
56
61
|
session_id: Optional session ID for stateful connections
|
|
57
62
|
enable_metrics: Whether to track performance metrics
|
|
63
|
+
oauth_refresh_callback: Optional async callback to refresh OAuth tokens
|
|
64
|
+
timeout_config: Optional timeout configuration model with connect/operation/quick/shutdown
|
|
58
65
|
"""
|
|
59
66
|
# Ensure URL points to the /mcp endpoint
|
|
60
67
|
if not url.endswith("/mcp"):
|
|
@@ -63,11 +70,18 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
63
70
|
self.url = url
|
|
64
71
|
|
|
65
72
|
self.api_key = api_key
|
|
66
|
-
self.configured_headers = headers or {}
|
|
67
|
-
self.connection_timeout = connection_timeout
|
|
68
|
-
self.default_timeout = default_timeout
|
|
73
|
+
self.configured_headers = headers or {}
|
|
69
74
|
self.session_id = session_id
|
|
70
75
|
self.enable_metrics = enable_metrics
|
|
76
|
+
self.oauth_refresh_callback = oauth_refresh_callback
|
|
77
|
+
|
|
78
|
+
# Use timeout config or create from individual parameters
|
|
79
|
+
if timeout_config is None:
|
|
80
|
+
timeout_config = TimeoutConfig(connect=connection_timeout, operation=default_timeout)
|
|
81
|
+
|
|
82
|
+
self.timeout_config = timeout_config
|
|
83
|
+
self.connection_timeout = timeout_config.connect
|
|
84
|
+
self.default_timeout = timeout_config.operation
|
|
71
85
|
|
|
72
86
|
logger.debug("HTTP Streamable transport initialized with URL: %s", self.url)
|
|
73
87
|
if self.api_key:
|
|
@@ -78,7 +92,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
78
92
|
logger.debug("Session ID configured: %s", self.session_id)
|
|
79
93
|
|
|
80
94
|
# State tracking (enhanced like SSE)
|
|
81
|
-
self.
|
|
95
|
+
self._http_transport = None
|
|
82
96
|
self._read_stream = None
|
|
83
97
|
self._write_stream = None
|
|
84
98
|
self._initialized = False
|
|
@@ -88,20 +102,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
88
102
|
self._consecutive_failures = 0
|
|
89
103
|
self._max_consecutive_failures = 3
|
|
90
104
|
|
|
91
|
-
# Performance metrics (enhanced like SSE)
|
|
92
|
-
self._metrics =
|
|
93
|
-
"total_calls": 0,
|
|
94
|
-
"successful_calls": 0,
|
|
95
|
-
"failed_calls": 0,
|
|
96
|
-
"total_time": 0.0,
|
|
97
|
-
"avg_response_time": 0.0,
|
|
98
|
-
"last_ping_time": None,
|
|
99
|
-
"initialization_time": None,
|
|
100
|
-
"connection_resets": 0,
|
|
101
|
-
"stream_errors": 0,
|
|
102
|
-
"connection_errors": 0, # NEW
|
|
103
|
-
"recovery_attempts": 0, # NEW
|
|
104
|
-
}
|
|
105
|
+
# Performance metrics (enhanced like SSE) - use Pydantic model
|
|
106
|
+
self._metrics = TransportMetrics() if enable_metrics else None
|
|
105
107
|
|
|
106
108
|
def _get_headers(self) -> dict[str, str]:
|
|
107
109
|
"""Get headers with authentication and custom headers (like SSE)."""
|
|
@@ -115,8 +117,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
115
117
|
if self.configured_headers:
|
|
116
118
|
headers.update(self.configured_headers)
|
|
117
119
|
|
|
118
|
-
# Add API key as Bearer token if provided
|
|
119
|
-
|
|
120
|
+
# Add API key as Bearer token if provided and no Authorization header exists
|
|
121
|
+
# This prevents clobbering OAuth tokens from configured_headers
|
|
122
|
+
if self.api_key and "Authorization" not in headers:
|
|
120
123
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
121
124
|
|
|
122
125
|
# Add session ID if provided
|
|
@@ -130,7 +133,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
130
133
|
try:
|
|
131
134
|
import httpx
|
|
132
135
|
|
|
133
|
-
async with httpx.AsyncClient(timeout=
|
|
136
|
+
async with httpx.AsyncClient(timeout=self.timeout_config.quick) as client:
|
|
134
137
|
# Test basic connectivity to base URL
|
|
135
138
|
base_url = self.url.replace("/mcp", "")
|
|
136
139
|
response = await client.get(f"{base_url}/health", headers=self._get_headers())
|
|
@@ -159,33 +162,38 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
159
162
|
headers = self._get_headers()
|
|
160
163
|
logger.debug("Using headers: %s", list(headers.keys()))
|
|
161
164
|
|
|
162
|
-
# Create StreamableHTTPParameters with
|
|
165
|
+
# Create StreamableHTTPParameters with minimal configuration
|
|
166
|
+
# NOTE: Keep params minimal - extra params can break message routing
|
|
163
167
|
http_params = StreamableHTTPParameters(
|
|
164
168
|
url=self.url,
|
|
165
169
|
timeout=self.default_timeout,
|
|
166
170
|
headers=headers,
|
|
167
|
-
bearer_token=None, # Don't duplicate auth - it's in headers
|
|
168
|
-
session_id=self.session_id,
|
|
169
171
|
enable_streaming=True,
|
|
170
|
-
max_concurrent_requests=5,
|
|
171
|
-
max_retries=2,
|
|
172
|
-
retry_delay=1.0,
|
|
173
|
-
user_agent="chuk-tool-processor/1.0.0",
|
|
174
172
|
)
|
|
175
173
|
|
|
176
|
-
# Create and
|
|
177
|
-
self.
|
|
174
|
+
# Create and store transport (will be managed via async with in parent scope)
|
|
175
|
+
self._http_transport = ChukHTTPTransport(http_params)
|
|
178
176
|
|
|
177
|
+
# IMPORTANT: Must use async with for proper stream setup
|
|
179
178
|
logger.debug("Establishing HTTP connection...")
|
|
180
|
-
self.
|
|
181
|
-
self.
|
|
179
|
+
self._http_context_entered = await asyncio.wait_for(
|
|
180
|
+
self._http_transport.__aenter__(), timeout=self.connection_timeout
|
|
182
181
|
)
|
|
183
182
|
|
|
183
|
+
# Get streams after context entered
|
|
184
|
+
self._read_stream, self._write_stream = await self._http_transport.get_streams()
|
|
185
|
+
|
|
186
|
+
# Give the transport's message handler task time to start
|
|
187
|
+
await asyncio.sleep(0.1)
|
|
188
|
+
|
|
184
189
|
# Enhanced MCP initialize sequence
|
|
185
190
|
logger.debug("Sending MCP initialize request...")
|
|
186
191
|
init_start = time.time()
|
|
187
192
|
|
|
188
|
-
await asyncio.wait_for(
|
|
193
|
+
await asyncio.wait_for(
|
|
194
|
+
send_initialize(self._read_stream, self._write_stream, timeout=self.default_timeout),
|
|
195
|
+
timeout=self.default_timeout,
|
|
196
|
+
)
|
|
189
197
|
|
|
190
198
|
init_time = time.time() - init_start
|
|
191
199
|
logger.debug("MCP initialize completed in %.3fs", init_time)
|
|
@@ -193,9 +201,11 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
193
201
|
# Verify connection with ping (enhanced like SSE)
|
|
194
202
|
logger.debug("Verifying connection with ping...")
|
|
195
203
|
ping_start = time.time()
|
|
204
|
+
# Use connect timeout for initial ping - some servers (like Notion) are slow
|
|
205
|
+
ping_timeout = self.timeout_config.connect
|
|
196
206
|
ping_success = await asyncio.wait_for(
|
|
197
|
-
send_ping(self._read_stream, self._write_stream),
|
|
198
|
-
timeout=
|
|
207
|
+
send_ping(self._read_stream, self._write_stream, timeout=ping_timeout),
|
|
208
|
+
timeout=ping_timeout,
|
|
199
209
|
)
|
|
200
210
|
ping_time = time.time() - ping_start
|
|
201
211
|
|
|
@@ -205,9 +215,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
205
215
|
self._consecutive_failures = 0
|
|
206
216
|
|
|
207
217
|
total_init_time = time.time() - start_time
|
|
208
|
-
if self.enable_metrics:
|
|
209
|
-
self._metrics
|
|
210
|
-
self._metrics
|
|
218
|
+
if self.enable_metrics and self._metrics:
|
|
219
|
+
self._metrics.initialization_time = total_init_time
|
|
220
|
+
self._metrics.last_ping_time = ping_time
|
|
211
221
|
|
|
212
222
|
logger.debug(
|
|
213
223
|
"HTTP Streamable transport initialized successfully in %.3fs (ping: %.3fs)",
|
|
@@ -216,31 +226,31 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
216
226
|
)
|
|
217
227
|
return True
|
|
218
228
|
else:
|
|
219
|
-
logger.
|
|
229
|
+
logger.debug("HTTP connection established but ping failed")
|
|
220
230
|
# Still consider it initialized since connection was established
|
|
221
231
|
self._initialized = True
|
|
222
232
|
self._consecutive_failures = 1 # Mark one failure
|
|
223
|
-
if self.enable_metrics:
|
|
224
|
-
self._metrics
|
|
233
|
+
if self.enable_metrics and self._metrics:
|
|
234
|
+
self._metrics.initialization_time = time.time() - start_time
|
|
225
235
|
return True
|
|
226
236
|
|
|
227
237
|
except TimeoutError:
|
|
228
238
|
logger.error("HTTP Streamable initialization timed out after %ss", self.connection_timeout)
|
|
229
239
|
await self._cleanup()
|
|
230
|
-
if self.enable_metrics:
|
|
231
|
-
self._metrics
|
|
232
|
-
|
|
240
|
+
if self.enable_metrics and self._metrics:
|
|
241
|
+
self._metrics.connection_errors += 1
|
|
242
|
+
raise # Re-raise for OAuth error detection in mcp-cli
|
|
233
243
|
except Exception as e:
|
|
234
244
|
logger.error("Error initializing HTTP Streamable transport: %s", e, exc_info=True)
|
|
235
245
|
await self._cleanup()
|
|
236
|
-
if self.enable_metrics:
|
|
237
|
-
self._metrics
|
|
238
|
-
|
|
246
|
+
if self.enable_metrics and self._metrics:
|
|
247
|
+
self._metrics.connection_errors += 1
|
|
248
|
+
raise # Re-raise for OAuth error detection in mcp-cli
|
|
239
249
|
|
|
240
250
|
async def _attempt_recovery(self) -> bool:
|
|
241
251
|
"""Attempt to recover from connection issues (NEW - like SSE resilience)."""
|
|
242
|
-
if self.enable_metrics:
|
|
243
|
-
self._metrics
|
|
252
|
+
if self.enable_metrics and self._metrics:
|
|
253
|
+
self._metrics.recovery_attempts += 1
|
|
244
254
|
|
|
245
255
|
logger.debug("Attempting HTTP connection recovery...")
|
|
246
256
|
|
|
@@ -260,21 +270,21 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
260
270
|
return
|
|
261
271
|
|
|
262
272
|
# Enhanced metrics logging (like SSE)
|
|
263
|
-
if self.enable_metrics and self._metrics
|
|
264
|
-
success_rate = self._metrics
|
|
273
|
+
if self.enable_metrics and self._metrics and self._metrics.total_calls > 0:
|
|
274
|
+
success_rate = self._metrics.successful_calls / self._metrics.total_calls * 100
|
|
265
275
|
logger.debug(
|
|
266
276
|
"HTTP Streamable transport closing - Calls: %d, Success: %.1f%%, "
|
|
267
277
|
"Avg time: %.3fs, Recoveries: %d, Errors: %d",
|
|
268
|
-
self._metrics
|
|
278
|
+
self._metrics.total_calls,
|
|
269
279
|
success_rate,
|
|
270
|
-
self._metrics
|
|
271
|
-
self._metrics
|
|
272
|
-
self._metrics
|
|
280
|
+
self._metrics.avg_response_time,
|
|
281
|
+
self._metrics.recovery_attempts,
|
|
282
|
+
self._metrics.connection_errors,
|
|
273
283
|
)
|
|
274
284
|
|
|
275
285
|
try:
|
|
276
|
-
if self.
|
|
277
|
-
await self.
|
|
286
|
+
if self._http_transport is not None:
|
|
287
|
+
await self._http_transport.__aexit__(None, None, None)
|
|
278
288
|
logger.debug("HTTP Streamable context closed")
|
|
279
289
|
|
|
280
290
|
except Exception as e:
|
|
@@ -284,7 +294,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
284
294
|
|
|
285
295
|
async def _cleanup(self) -> None:
|
|
286
296
|
"""Enhanced cleanup with state reset."""
|
|
287
|
-
self.
|
|
297
|
+
self._http_transport = None
|
|
288
298
|
self._read_stream = None
|
|
289
299
|
self._write_stream = None
|
|
290
300
|
self._initialized = False
|
|
@@ -292,13 +302,14 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
292
302
|
async def send_ping(self) -> bool:
|
|
293
303
|
"""Enhanced ping with health monitoring (like SSE)."""
|
|
294
304
|
if not self._initialized or not self._read_stream:
|
|
295
|
-
logger.
|
|
305
|
+
logger.debug("Cannot send ping: transport not initialized")
|
|
296
306
|
return False
|
|
297
307
|
|
|
298
308
|
start_time = time.time()
|
|
299
309
|
try:
|
|
300
310
|
result = await asyncio.wait_for(
|
|
301
|
-
send_ping(self._read_stream, self._write_stream
|
|
311
|
+
send_ping(self._read_stream, self._write_stream, timeout=self.default_timeout),
|
|
312
|
+
timeout=self.default_timeout,
|
|
302
313
|
)
|
|
303
314
|
|
|
304
315
|
success = bool(result)
|
|
@@ -309,9 +320,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
309
320
|
else:
|
|
310
321
|
self._consecutive_failures += 1
|
|
311
322
|
|
|
312
|
-
if self.enable_metrics:
|
|
323
|
+
if self.enable_metrics and self._metrics:
|
|
313
324
|
ping_time = time.time() - start_time
|
|
314
|
-
self._metrics
|
|
325
|
+
self._metrics.last_ping_time = ping_time
|
|
315
326
|
logger.debug("HTTP Streamable ping completed in %.3fs: %s", ping_time, success)
|
|
316
327
|
|
|
317
328
|
return success
|
|
@@ -322,8 +333,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
322
333
|
except Exception as e:
|
|
323
334
|
logger.error("HTTP Streamable ping failed: %s", e)
|
|
324
335
|
self._consecutive_failures += 1
|
|
325
|
-
if self.enable_metrics:
|
|
326
|
-
self._metrics
|
|
336
|
+
if self.enable_metrics and self._metrics:
|
|
337
|
+
self._metrics.stream_errors += 1
|
|
327
338
|
return False
|
|
328
339
|
|
|
329
340
|
def is_connected(self) -> bool:
|
|
@@ -341,18 +352,29 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
341
352
|
async def get_tools(self) -> list[dict[str, Any]]:
|
|
342
353
|
"""Enhanced tools retrieval with error handling."""
|
|
343
354
|
if not self._initialized:
|
|
344
|
-
logger.
|
|
355
|
+
logger.debug("Cannot get tools: transport not initialized")
|
|
345
356
|
return []
|
|
346
357
|
|
|
347
358
|
start_time = time.time()
|
|
348
359
|
try:
|
|
349
360
|
tools_response = await asyncio.wait_for(
|
|
350
|
-
send_tools_list(self._read_stream, self._write_stream
|
|
361
|
+
send_tools_list(self._read_stream, self._write_stream, timeout=self.default_timeout),
|
|
362
|
+
timeout=self.default_timeout,
|
|
351
363
|
)
|
|
352
364
|
|
|
353
|
-
# Normalize response
|
|
354
|
-
if
|
|
365
|
+
# Normalize response - handle multiple formats including Pydantic models
|
|
366
|
+
# 1. Check if it's a Pydantic model with tools attribute (e.g., ListToolsResult from chuk_mcp)
|
|
367
|
+
if hasattr(tools_response, "tools"):
|
|
368
|
+
tools = tools_response.tools
|
|
369
|
+
# Convert Pydantic Tool models to dicts if needed
|
|
370
|
+
if tools and len(tools) > 0 and hasattr(tools[0], "model_dump"):
|
|
371
|
+
tools = [t.model_dump() for t in tools]
|
|
372
|
+
elif tools and len(tools) > 0 and hasattr(tools[0], "dict"):
|
|
373
|
+
tools = [t.dict() for t in tools]
|
|
374
|
+
# 2. Check if it's a dict with "tools" key
|
|
375
|
+
elif isinstance(tools_response, dict):
|
|
355
376
|
tools = tools_response.get("tools", [])
|
|
377
|
+
# 3. Check if it's already a list
|
|
356
378
|
elif isinstance(tools_response, list):
|
|
357
379
|
tools = tools_response
|
|
358
380
|
else:
|
|
@@ -375,8 +397,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
375
397
|
except Exception as e:
|
|
376
398
|
logger.error("Error getting tools: %s", e)
|
|
377
399
|
self._consecutive_failures += 1
|
|
378
|
-
if self.enable_metrics:
|
|
379
|
-
self._metrics
|
|
400
|
+
if self.enable_metrics and self._metrics:
|
|
401
|
+
self._metrics.stream_errors += 1
|
|
380
402
|
return []
|
|
381
403
|
|
|
382
404
|
async def call_tool(
|
|
@@ -389,8 +411,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
389
411
|
tool_timeout = timeout or self.default_timeout
|
|
390
412
|
start_time = time.time()
|
|
391
413
|
|
|
392
|
-
if self.enable_metrics:
|
|
393
|
-
self._metrics
|
|
414
|
+
if self.enable_metrics and self._metrics:
|
|
415
|
+
self._metrics.total_calls += 1
|
|
394
416
|
|
|
395
417
|
try:
|
|
396
418
|
logger.debug("Calling tool '%s' with timeout %ss", tool_name, tool_timeout)
|
|
@@ -410,9 +432,44 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
410
432
|
response_time = time.time() - start_time
|
|
411
433
|
result = self._normalize_mcp_response(raw_response)
|
|
412
434
|
|
|
435
|
+
# NEW: Check for OAuth errors and attempt refresh if callback is available
|
|
436
|
+
if result.get("isError", False) and self._is_oauth_error(result.get("error", "")):
|
|
437
|
+
logger.warning("OAuth error detected: %s", result.get("error"))
|
|
438
|
+
|
|
439
|
+
if self.oauth_refresh_callback:
|
|
440
|
+
logger.debug("Attempting OAuth token refresh...")
|
|
441
|
+
try:
|
|
442
|
+
# Call the refresh callback
|
|
443
|
+
new_headers = await self.oauth_refresh_callback()
|
|
444
|
+
|
|
445
|
+
if new_headers and "Authorization" in new_headers:
|
|
446
|
+
# Update configured headers with new token
|
|
447
|
+
self.configured_headers.update(new_headers)
|
|
448
|
+
logger.debug("OAuth token refreshed, reconnecting...")
|
|
449
|
+
|
|
450
|
+
# Reconnect with new token
|
|
451
|
+
if await self._attempt_recovery():
|
|
452
|
+
logger.debug("Retrying tool call after token refresh...")
|
|
453
|
+
# Retry the tool call once with new token
|
|
454
|
+
raw_response = await asyncio.wait_for(
|
|
455
|
+
send_tools_call(self._read_stream, self._write_stream, tool_name, arguments),
|
|
456
|
+
timeout=tool_timeout,
|
|
457
|
+
)
|
|
458
|
+
result = self._normalize_mcp_response(raw_response)
|
|
459
|
+
logger.debug("Tool call retry completed")
|
|
460
|
+
else:
|
|
461
|
+
logger.error("Failed to reconnect after token refresh")
|
|
462
|
+
else:
|
|
463
|
+
logger.warning("Token refresh did not return valid Authorization header")
|
|
464
|
+
except Exception as refresh_error:
|
|
465
|
+
logger.error("OAuth token refresh failed: %s", refresh_error)
|
|
466
|
+
else:
|
|
467
|
+
logger.warning("OAuth error detected but no refresh callback configured")
|
|
468
|
+
|
|
413
469
|
# Reset failure count on success
|
|
414
|
-
|
|
415
|
-
|
|
470
|
+
if not result.get("isError", False):
|
|
471
|
+
self._consecutive_failures = 0
|
|
472
|
+
self._last_successful_ping = time.time() # Update health timestamp
|
|
416
473
|
|
|
417
474
|
if self.enable_metrics:
|
|
418
475
|
self._update_metrics(response_time, not result.get("isError", False))
|
|
@@ -438,17 +495,17 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
438
495
|
except Exception as e:
|
|
439
496
|
response_time = time.time() - start_time
|
|
440
497
|
self._consecutive_failures += 1
|
|
441
|
-
if self.enable_metrics:
|
|
498
|
+
if self.enable_metrics and self._metrics:
|
|
442
499
|
self._update_metrics(response_time, False)
|
|
443
|
-
self._metrics
|
|
500
|
+
self._metrics.stream_errors += 1
|
|
444
501
|
|
|
445
502
|
# Enhanced connection error detection
|
|
446
503
|
error_str = str(e).lower()
|
|
447
504
|
if any(indicator in error_str for indicator in ["connection", "disconnected", "broken pipe", "eof"]):
|
|
448
505
|
logger.warning("Connection error detected: %s", e)
|
|
449
506
|
self._initialized = False
|
|
450
|
-
if self.enable_metrics:
|
|
451
|
-
self._metrics
|
|
507
|
+
if self.enable_metrics and self._metrics:
|
|
508
|
+
self._metrics.connection_errors += 1
|
|
452
509
|
|
|
453
510
|
error_msg = f"Tool execution failed: {str(e)}"
|
|
454
511
|
logger.error("Tool '%s' error: %s", tool_name, error_msg)
|
|
@@ -456,14 +513,41 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
456
513
|
|
|
457
514
|
def _update_metrics(self, response_time: float, success: bool) -> None:
    """
    Record the outcome of one call on the transport's metrics object.

    Args:
        response_time: Elapsed time of the call, as measured by the caller.
        success: Whether the call completed without error.

    Does nothing when no metrics object is attached (``self._metrics`` is
    falsy, e.g. ``None`` because metrics were disabled at construction).
    """
    tracker = self._metrics
    if tracker:
        # All counter/average bookkeeping lives on the metrics model itself.
        tracker.update_call_metrics(response_time, success)
|
|
520
|
+
|
|
521
|
+
def _is_oauth_error(self, error_msg: str) -> bool:
    """
    Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.

    Checks (case-insensitively, by substring) for:
    - RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
    - OAuth 2.1 token refresh errors (invalid_grant)
    - MCP spec OAuth validation failures (401/403 responses)

    Args:
        error_msg: The error payload taken from a normalized tool response.
            Normally a string; any other non-empty value (dict, list,
            exception object, ...) is converted with ``str()`` before
            matching instead of raising ``AttributeError``.

    Returns:
        True when the text contains one of the known OAuth indicators,
        False for empty/``None`` input or when nothing matches.
    """
    if not error_msg:
        return False

    # The caller passes result.get("error", ""), which is not guaranteed to
    # be a plain string. Coerce so a structured error payload is still
    # inspected rather than blowing up on .lower().
    error_text = error_msg if isinstance(error_msg, str) else str(error_msg)
    error_lower = error_text.lower()

    oauth_indicators = (
        # RFC 6750 Section 3.1 - Standard Bearer token errors
        "invalid_token",  # Token expired, revoked, malformed, or invalid
        "insufficient_scope",  # Request requires higher privileges (403 Forbidden)
        # OAuth 2.1 token refresh errors
        "invalid_grant",  # Refresh token errors
        # MCP spec - OAuth validation failures (401 Unauthorized)
        "oauth validation",
        "unauthorized",
        # Common OAuth error descriptions
        "expired token",
        "token expired",
        "authentication failed",
        "invalid access token",
    )

    return any(indicator in error_lower for indicator in oauth_indicators)
|
|
467
551
|
|
|
468
552
|
async def list_resources(self) -> dict[str, Any]:
|
|
469
553
|
"""Enhanced resource listing with error handling."""
|
|
@@ -544,7 +628,10 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
544
628
|
|
|
545
629
|
def get_metrics(self) -> dict[str, Any]:
|
|
546
630
|
"""Enhanced metrics with health information."""
|
|
547
|
-
|
|
631
|
+
if not self._metrics:
|
|
632
|
+
return {}
|
|
633
|
+
|
|
634
|
+
metrics = self._metrics.to_dict()
|
|
548
635
|
metrics.update(
|
|
549
636
|
{
|
|
550
637
|
"is_connected": self.is_connected(),
|
|
@@ -557,22 +644,20 @@ class HTTPStreamableTransport(MCPBaseTransport):
|
|
|
557
644
|
|
|
558
645
|
def reset_metrics(self) -> None:
    """
    Reset call statistics while keeping connection-history values.

    Replaces ``self._metrics`` with a fresh ``TransportMetrics`` whose
    counters start from their defaults, carrying over only
    ``initialization_time``, ``last_ping_time`` and ``connection_resets``
    from the previous instance. Does nothing when no metrics object is
    attached.
    """
    previous = self._metrics
    if not previous:
        return

    # Build the replacement from the old instance's historical fields;
    # every other field falls back to the model's default.
    self._metrics = TransportMetrics(
        initialization_time=previous.initialization_time,
        last_ping_time=previous.last_ping_time,
        connection_resets=previous.connection_resets,
    )
|
|
576
661
|
|
|
577
662
|
def get_streams(self) -> list[tuple]:
|
|
578
663
|
"""Enhanced streams access with connection check."""
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# chuk_tool_processor/mcp/transport/models.py
|
|
2
|
+
"""
|
|
3
|
+
Pydantic models for MCP transport configuration and metrics.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TimeoutConfig(BaseModel):
    """
    Unified timeout configuration for all MCP operations.

    Just 4 simple, logical timeout categories:
    - connect: Connection establishment, initialization, session discovery (30s default)
    - operation: Normal operations like tool calls, listing resources (30s default)
    - quick: Fast health checks and pings (5s default)
    - shutdown: Cleanup and shutdown operations (2s default)
    """

    # NOTE: class docstring text is kept as-is on purpose; pydantic is
    # understood to surface it as the model's JSON-schema description.

    # Used while establishing the connection and running the initial handshake.
    connect: float = Field(
        description="Timeout for connection establishment, initialization, and session discovery",
        default=30.0,
    )

    # Used for regular request/response operations.
    operation: float = Field(
        description="Timeout for normal operations (tool calls, listing tools/resources/prompts)",
        default=30.0,
    )

    # Used for lightweight health probes.
    quick: float = Field(
        description="Timeout for quick health checks and pings",
        default=5.0,
    )

    # Used when tearing the transport down.
    shutdown: float = Field(
        description="Timeout for shutdown and cleanup operations",
        default=2.0,
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TransportMetrics(BaseModel):
    """Performance and connection metrics for transports."""

    # Re-validate on every attribute assignment (including ``+=`` updates).
    model_config = {"validate_assignment": True}

    # --- call counters and timing -------------------------------------
    total_calls: int = Field(description="Total number of calls made", default=0)
    successful_calls: int = Field(description="Number of successful calls", default=0)
    failed_calls: int = Field(description="Number of failed calls", default=0)
    total_time: float = Field(description="Total time spent on calls", default=0.0)
    avg_response_time: float = Field(description="Average response time", default=0.0)

    # --- connection lifecycle -----------------------------------------
    last_ping_time: float | None = Field(description="Time taken for last ping", default=None)
    initialization_time: float | None = Field(description="Time taken for initialization", default=None)
    connection_resets: int = Field(description="Number of connection resets", default=0)

    # --- error / recovery counters ------------------------------------
    stream_errors: int = Field(description="Number of stream errors", default=0)
    connection_errors: int = Field(description="Number of connection errors", default=0)
    recovery_attempts: int = Field(description="Number of recovery attempts", default=0)
    session_discoveries: int = Field(description="Number of session discoveries (SSE)", default=0)

    def to_dict(self) -> dict[str, Any]:
        """Return every metric as a plain dictionary (via ``model_dump``)."""
        return self.model_dump()

    def update_call_metrics(self, response_time: float, success: bool) -> None:
        """
        Fold the result of one call into the running statistics.

        Args:
            response_time: Duration of the call; added to ``total_time``.
            success: Selects which of ``successful_calls`` / ``failed_calls``
                is incremented.

        ``total_calls`` is NOT incremented here; the average is recomputed
        from whatever value it currently holds, so the caller is expected to
        have counted the call already.
        """
        if not success:
            self.failed_calls += 1
        else:
            self.successful_calls += 1

        self.total_time += response_time

        call_count = self.total_calls
        if call_count > 0:
            # Guarded so a caller that never bumped total_calls cannot
            # trigger a division by zero.
            self.avg_response_time = self.total_time / call_count
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ServerInfo(BaseModel):
    """Information about a server in StreamManager."""

    # All four fields are required: none of them declares a default.
    id: int = Field(description="Server ID")
    name: str = Field(description="Server name")
    tools: int = Field(description="Number of tools available")
    status: str = Field(description="Server status (Up/Down)")

    def to_dict(self) -> dict[str, Any]:
        """Return the server record as a plain dictionary (via ``model_dump``)."""
        as_mapping = self.model_dump()
        return as_mapping
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class HeadersConfig(BaseModel):
    """Configuration for HTTP headers."""

    # Custom headers keyed by field name; starts empty, one dict per instance.
    headers: dict[str, str] = Field(default_factory=dict, description="Custom HTTP headers")

    def get_headers(self) -> dict[str, str]:
        """Return a shallow copy of the headers so callers cannot mutate the model."""
        return self.headers.copy()

    def update_headers(self, new_headers: dict[str, str]) -> None:
        """
        Merge ``new_headers`` into the stored headers in place.

        Existing keys with the exact same spelling are overwritten.
        """
        self.headers.update(new_headers)

    def has_authorization(self) -> bool:
        """
        Check if an Authorization header is present.

        HTTP field names are case-insensitive, so ``authorization`` and
        ``AUTHORIZATION`` count as present just like ``Authorization``.

        Returns:
            True when any stored header name equals "authorization"
            ignoring case, otherwise False.
        """
        return any(name.lower() == "authorization" for name in self.headers)

    def to_dict(self) -> dict[str, Any]:
        """Return the configuration as a plain dictionary (via ``model_dump``)."""
        return self.model_dump()
|