chuk-tool-processor 0.6.13__py3-none-any.whl → 0.9.7__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the registry's advisory page for this release for more details.

Files changed (35)
  1. chuk_tool_processor/core/__init__.py +31 -0
  2. chuk_tool_processor/core/exceptions.py +218 -12
  3. chuk_tool_processor/core/processor.py +38 -7
  4. chuk_tool_processor/execution/strategies/__init__.py +6 -0
  5. chuk_tool_processor/execution/strategies/subprocess_strategy.py +2 -1
  6. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  7. chuk_tool_processor/execution/wrappers/caching.py +48 -13
  8. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  9. chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
  10. chuk_tool_processor/execution/wrappers/retry.py +93 -53
  11. chuk_tool_processor/logging/metrics.py +2 -2
  12. chuk_tool_processor/mcp/mcp_tool.py +5 -5
  13. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +44 -2
  14. chuk_tool_processor/mcp/setup_mcp_sse.py +44 -2
  15. chuk_tool_processor/mcp/setup_mcp_stdio.py +2 -0
  16. chuk_tool_processor/mcp/stream_manager.py +130 -75
  17. chuk_tool_processor/mcp/transport/__init__.py +10 -0
  18. chuk_tool_processor/mcp/transport/http_streamable_transport.py +193 -108
  19. chuk_tool_processor/mcp/transport/models.py +100 -0
  20. chuk_tool_processor/mcp/transport/sse_transport.py +155 -59
  21. chuk_tool_processor/mcp/transport/stdio_transport.py +58 -10
  22. chuk_tool_processor/models/__init__.py +20 -0
  23. chuk_tool_processor/models/tool_call.py +34 -1
  24. chuk_tool_processor/models/tool_spec.py +350 -0
  25. chuk_tool_processor/models/validated_tool.py +22 -2
  26. chuk_tool_processor/observability/__init__.py +30 -0
  27. chuk_tool_processor/observability/metrics.py +312 -0
  28. chuk_tool_processor/observability/setup.py +105 -0
  29. chuk_tool_processor/observability/tracing.py +345 -0
  30. chuk_tool_processor/plugins/discovery.py +1 -1
  31. chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
  32. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/RECORD +34 -27
  33. chuk_tool_processor-0.6.13.dist-info/METADATA +0 -698
  34. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
  35. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
@@ -16,12 +16,15 @@ from chuk_mcp.protocol.messages import ( # type: ignore[import-untyped]
16
16
  send_tools_call,
17
17
  send_tools_list,
18
18
  )
19
+ from chuk_mcp.transports.http.parameters import StreamableHTTPParameters # type: ignore[import-untyped]
19
20
 
20
21
  # Import chuk-mcp HTTP transport components
21
- from chuk_mcp.transports.http import http_client # type: ignore[import-untyped]
22
- from chuk_mcp.transports.http.parameters import StreamableHTTPParameters # type: ignore[import-untyped]
22
+ from chuk_mcp.transports.http.transport import (
23
+ StreamableHTTPTransport as ChukHTTPTransport, # type: ignore[import-untyped]
24
+ )
23
25
 
24
26
  from .base_transport import MCPBaseTransport
27
+ from .models import TimeoutConfig, TransportMetrics
25
28
 
26
29
  logger = logging.getLogger(__name__)
27
30
 
@@ -38,11 +41,13 @@ class HTTPStreamableTransport(MCPBaseTransport):
38
41
  self,
39
42
  url: str,
40
43
  api_key: str | None = None,
41
- headers: dict[str, str] | None = None, # NEW: Headers support
44
+ headers: dict[str, str] | None = None,
42
45
  connection_timeout: float = 30.0,
43
46
  default_timeout: float = 30.0,
44
47
  session_id: str | None = None,
45
48
  enable_metrics: bool = True,
49
+ oauth_refresh_callback: Any | None = None,
50
+ timeout_config: TimeoutConfig | None = None,
46
51
  ):
47
52
  """
48
53
  Initialize HTTP Streamable transport with enhanced configuration.
@@ -50,11 +55,13 @@ class HTTPStreamableTransport(MCPBaseTransport):
50
55
  Args:
51
56
  url: HTTP server URL (should end with /mcp)
52
57
  api_key: Optional API key for authentication
53
- headers: Optional custom headers (NEW)
54
- connection_timeout: Timeout for initial connection
55
- default_timeout: Default timeout for operations
58
+ headers: Optional custom headers
59
+ connection_timeout: Timeout for initial connection (overrides timeout_config.connect)
60
+ default_timeout: Default timeout for operations (overrides timeout_config.operation)
56
61
  session_id: Optional session ID for stateful connections
57
62
  enable_metrics: Whether to track performance metrics
63
+ oauth_refresh_callback: Optional async callback to refresh OAuth tokens
64
+ timeout_config: Optional timeout configuration model with connect/operation/quick/shutdown
58
65
  """
59
66
  # Ensure URL points to the /mcp endpoint
60
67
  if not url.endswith("/mcp"):
@@ -63,11 +70,18 @@ class HTTPStreamableTransport(MCPBaseTransport):
63
70
  self.url = url
64
71
 
65
72
  self.api_key = api_key
66
- self.configured_headers = headers or {} # NEW: Store configured headers
67
- self.connection_timeout = connection_timeout
68
- self.default_timeout = default_timeout
73
+ self.configured_headers = headers or {}
69
74
  self.session_id = session_id
70
75
  self.enable_metrics = enable_metrics
76
+ self.oauth_refresh_callback = oauth_refresh_callback
77
+
78
+ # Use timeout config or create from individual parameters
79
+ if timeout_config is None:
80
+ timeout_config = TimeoutConfig(connect=connection_timeout, operation=default_timeout)
81
+
82
+ self.timeout_config = timeout_config
83
+ self.connection_timeout = timeout_config.connect
84
+ self.default_timeout = timeout_config.operation
71
85
 
72
86
  logger.debug("HTTP Streamable transport initialized with URL: %s", self.url)
73
87
  if self.api_key:
@@ -78,7 +92,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
78
92
  logger.debug("Session ID configured: %s", self.session_id)
79
93
 
80
94
  # State tracking (enhanced like SSE)
81
- self._http_context = None
95
+ self._http_transport = None
82
96
  self._read_stream = None
83
97
  self._write_stream = None
84
98
  self._initialized = False
@@ -88,20 +102,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
88
102
  self._consecutive_failures = 0
89
103
  self._max_consecutive_failures = 3
90
104
 
91
- # Performance metrics (enhanced like SSE)
92
- self._metrics = {
93
- "total_calls": 0,
94
- "successful_calls": 0,
95
- "failed_calls": 0,
96
- "total_time": 0.0,
97
- "avg_response_time": 0.0,
98
- "last_ping_time": None,
99
- "initialization_time": None,
100
- "connection_resets": 0,
101
- "stream_errors": 0,
102
- "connection_errors": 0, # NEW
103
- "recovery_attempts": 0, # NEW
104
- }
105
+ # Performance metrics (enhanced like SSE) - use Pydantic model
106
+ self._metrics = TransportMetrics() if enable_metrics else None
105
107
 
106
108
  def _get_headers(self) -> dict[str, str]:
107
109
  """Get headers with authentication and custom headers (like SSE)."""
@@ -115,8 +117,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
115
117
  if self.configured_headers:
116
118
  headers.update(self.configured_headers)
117
119
 
118
- # Add API key as Bearer token if provided
119
- if self.api_key:
120
+ # Add API key as Bearer token if provided and no Authorization header exists
121
+ # This prevents clobbering OAuth tokens from configured_headers
122
+ if self.api_key and "Authorization" not in headers:
120
123
  headers["Authorization"] = f"Bearer {self.api_key}"
121
124
 
122
125
  # Add session ID if provided
@@ -130,7 +133,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
130
133
  try:
131
134
  import httpx
132
135
 
133
- async with httpx.AsyncClient(timeout=5.0) as client:
136
+ async with httpx.AsyncClient(timeout=self.timeout_config.quick) as client:
134
137
  # Test basic connectivity to base URL
135
138
  base_url = self.url.replace("/mcp", "")
136
139
  response = await client.get(f"{base_url}/health", headers=self._get_headers())
@@ -159,33 +162,38 @@ class HTTPStreamableTransport(MCPBaseTransport):
159
162
  headers = self._get_headers()
160
163
  logger.debug("Using headers: %s", list(headers.keys()))
161
164
 
162
- # Create StreamableHTTPParameters with proper configuration
165
+ # Create StreamableHTTPParameters with minimal configuration
166
+ # NOTE: Keep params minimal - extra params can break message routing
163
167
  http_params = StreamableHTTPParameters(
164
168
  url=self.url,
165
169
  timeout=self.default_timeout,
166
170
  headers=headers,
167
- bearer_token=None, # Don't duplicate auth - it's in headers
168
- session_id=self.session_id,
169
171
  enable_streaming=True,
170
- max_concurrent_requests=5,
171
- max_retries=2,
172
- retry_delay=1.0,
173
- user_agent="chuk-tool-processor/1.0.0",
174
172
  )
175
173
 
176
- # Create and enter the HTTP context
177
- self._http_context = http_client(http_params)
174
+ # Create and store transport (will be managed via async with in parent scope)
175
+ self._http_transport = ChukHTTPTransport(http_params)
178
176
 
177
+ # IMPORTANT: Must use async with for proper stream setup
179
178
  logger.debug("Establishing HTTP connection...")
180
- self._read_stream, self._write_stream = await asyncio.wait_for(
181
- self._http_context.__aenter__(), timeout=self.connection_timeout
179
+ self._http_context_entered = await asyncio.wait_for(
180
+ self._http_transport.__aenter__(), timeout=self.connection_timeout
182
181
  )
183
182
 
183
+ # Get streams after context entered
184
+ self._read_stream, self._write_stream = await self._http_transport.get_streams()
185
+
186
+ # Give the transport's message handler task time to start
187
+ await asyncio.sleep(0.1)
188
+
184
189
  # Enhanced MCP initialize sequence
185
190
  logger.debug("Sending MCP initialize request...")
186
191
  init_start = time.time()
187
192
 
188
- await asyncio.wait_for(send_initialize(self._read_stream, self._write_stream), timeout=self.default_timeout)
193
+ await asyncio.wait_for(
194
+ send_initialize(self._read_stream, self._write_stream, timeout=self.default_timeout),
195
+ timeout=self.default_timeout,
196
+ )
189
197
 
190
198
  init_time = time.time() - init_start
191
199
  logger.debug("MCP initialize completed in %.3fs", init_time)
@@ -193,9 +201,11 @@ class HTTPStreamableTransport(MCPBaseTransport):
193
201
  # Verify connection with ping (enhanced like SSE)
194
202
  logger.debug("Verifying connection with ping...")
195
203
  ping_start = time.time()
204
+ # Use connect timeout for initial ping - some servers (like Notion) are slow
205
+ ping_timeout = self.timeout_config.connect
196
206
  ping_success = await asyncio.wait_for(
197
- send_ping(self._read_stream, self._write_stream),
198
- timeout=10.0, # Longer timeout for initial ping
207
+ send_ping(self._read_stream, self._write_stream, timeout=ping_timeout),
208
+ timeout=ping_timeout,
199
209
  )
200
210
  ping_time = time.time() - ping_start
201
211
 
@@ -205,9 +215,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
205
215
  self._consecutive_failures = 0
206
216
 
207
217
  total_init_time = time.time() - start_time
208
- if self.enable_metrics:
209
- self._metrics["initialization_time"] = total_init_time
210
- self._metrics["last_ping_time"] = ping_time
218
+ if self.enable_metrics and self._metrics:
219
+ self._metrics.initialization_time = total_init_time
220
+ self._metrics.last_ping_time = ping_time
211
221
 
212
222
  logger.debug(
213
223
  "HTTP Streamable transport initialized successfully in %.3fs (ping: %.3fs)",
@@ -216,31 +226,31 @@ class HTTPStreamableTransport(MCPBaseTransport):
216
226
  )
217
227
  return True
218
228
  else:
219
- logger.warning("HTTP connection established but ping failed")
229
+ logger.debug("HTTP connection established but ping failed")
220
230
  # Still consider it initialized since connection was established
221
231
  self._initialized = True
222
232
  self._consecutive_failures = 1 # Mark one failure
223
- if self.enable_metrics:
224
- self._metrics["initialization_time"] = time.time() - start_time
233
+ if self.enable_metrics and self._metrics:
234
+ self._metrics.initialization_time = time.time() - start_time
225
235
  return True
226
236
 
227
237
  except TimeoutError:
228
238
  logger.error("HTTP Streamable initialization timed out after %ss", self.connection_timeout)
229
239
  await self._cleanup()
230
- if self.enable_metrics:
231
- self._metrics["connection_errors"] += 1
232
- return False
240
+ if self.enable_metrics and self._metrics:
241
+ self._metrics.connection_errors += 1
242
+ raise # Re-raise for OAuth error detection in mcp-cli
233
243
  except Exception as e:
234
244
  logger.error("Error initializing HTTP Streamable transport: %s", e, exc_info=True)
235
245
  await self._cleanup()
236
- if self.enable_metrics:
237
- self._metrics["connection_errors"] += 1
238
- return False
246
+ if self.enable_metrics and self._metrics:
247
+ self._metrics.connection_errors += 1
248
+ raise # Re-raise for OAuth error detection in mcp-cli
239
249
 
240
250
  async def _attempt_recovery(self) -> bool:
241
251
  """Attempt to recover from connection issues (NEW - like SSE resilience)."""
242
- if self.enable_metrics:
243
- self._metrics["recovery_attempts"] += 1
252
+ if self.enable_metrics and self._metrics:
253
+ self._metrics.recovery_attempts += 1
244
254
 
245
255
  logger.debug("Attempting HTTP connection recovery...")
246
256
 
@@ -260,21 +270,21 @@ class HTTPStreamableTransport(MCPBaseTransport):
260
270
  return
261
271
 
262
272
  # Enhanced metrics logging (like SSE)
263
- if self.enable_metrics and self._metrics["total_calls"] > 0:
264
- success_rate = self._metrics["successful_calls"] / self._metrics["total_calls"] * 100
273
+ if self.enable_metrics and self._metrics and self._metrics.total_calls > 0:
274
+ success_rate = self._metrics.successful_calls / self._metrics.total_calls * 100
265
275
  logger.debug(
266
276
  "HTTP Streamable transport closing - Calls: %d, Success: %.1f%%, "
267
277
  "Avg time: %.3fs, Recoveries: %d, Errors: %d",
268
- self._metrics["total_calls"],
278
+ self._metrics.total_calls,
269
279
  success_rate,
270
- self._metrics["avg_response_time"],
271
- self._metrics["recovery_attempts"],
272
- self._metrics["connection_errors"],
280
+ self._metrics.avg_response_time,
281
+ self._metrics.recovery_attempts,
282
+ self._metrics.connection_errors,
273
283
  )
274
284
 
275
285
  try:
276
- if self._http_context is not None:
277
- await self._http_context.__aexit__(None, None, None)
286
+ if self._http_transport is not None:
287
+ await self._http_transport.__aexit__(None, None, None)
278
288
  logger.debug("HTTP Streamable context closed")
279
289
 
280
290
  except Exception as e:
@@ -284,7 +294,7 @@ class HTTPStreamableTransport(MCPBaseTransport):
284
294
 
285
295
  async def _cleanup(self) -> None:
286
296
  """Enhanced cleanup with state reset."""
287
- self._http_context = None
297
+ self._http_transport = None
288
298
  self._read_stream = None
289
299
  self._write_stream = None
290
300
  self._initialized = False
@@ -292,13 +302,14 @@ class HTTPStreamableTransport(MCPBaseTransport):
292
302
  async def send_ping(self) -> bool:
293
303
  """Enhanced ping with health monitoring (like SSE)."""
294
304
  if not self._initialized or not self._read_stream:
295
- logger.error("Cannot send ping: transport not initialized")
305
+ logger.debug("Cannot send ping: transport not initialized")
296
306
  return False
297
307
 
298
308
  start_time = time.time()
299
309
  try:
300
310
  result = await asyncio.wait_for(
301
- send_ping(self._read_stream, self._write_stream), timeout=self.default_timeout
311
+ send_ping(self._read_stream, self._write_stream, timeout=self.default_timeout),
312
+ timeout=self.default_timeout,
302
313
  )
303
314
 
304
315
  success = bool(result)
@@ -309,9 +320,9 @@ class HTTPStreamableTransport(MCPBaseTransport):
309
320
  else:
310
321
  self._consecutive_failures += 1
311
322
 
312
- if self.enable_metrics:
323
+ if self.enable_metrics and self._metrics:
313
324
  ping_time = time.time() - start_time
314
- self._metrics["last_ping_time"] = ping_time
325
+ self._metrics.last_ping_time = ping_time
315
326
  logger.debug("HTTP Streamable ping completed in %.3fs: %s", ping_time, success)
316
327
 
317
328
  return success
@@ -322,8 +333,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
322
333
  except Exception as e:
323
334
  logger.error("HTTP Streamable ping failed: %s", e)
324
335
  self._consecutive_failures += 1
325
- if self.enable_metrics:
326
- self._metrics["stream_errors"] += 1
336
+ if self.enable_metrics and self._metrics:
337
+ self._metrics.stream_errors += 1
327
338
  return False
328
339
 
329
340
  def is_connected(self) -> bool:
@@ -341,18 +352,29 @@ class HTTPStreamableTransport(MCPBaseTransport):
341
352
  async def get_tools(self) -> list[dict[str, Any]]:
342
353
  """Enhanced tools retrieval with error handling."""
343
354
  if not self._initialized:
344
- logger.error("Cannot get tools: transport not initialized")
355
+ logger.debug("Cannot get tools: transport not initialized")
345
356
  return []
346
357
 
347
358
  start_time = time.time()
348
359
  try:
349
360
  tools_response = await asyncio.wait_for(
350
- send_tools_list(self._read_stream, self._write_stream), timeout=self.default_timeout
361
+ send_tools_list(self._read_stream, self._write_stream, timeout=self.default_timeout),
362
+ timeout=self.default_timeout,
351
363
  )
352
364
 
353
- # Normalize response
354
- if isinstance(tools_response, dict):
365
+ # Normalize response - handle multiple formats including Pydantic models
366
+ # 1. Check if it's a Pydantic model with tools attribute (e.g., ListToolsResult from chuk_mcp)
367
+ if hasattr(tools_response, "tools"):
368
+ tools = tools_response.tools
369
+ # Convert Pydantic Tool models to dicts if needed
370
+ if tools and len(tools) > 0 and hasattr(tools[0], "model_dump"):
371
+ tools = [t.model_dump() for t in tools]
372
+ elif tools and len(tools) > 0 and hasattr(tools[0], "dict"):
373
+ tools = [t.dict() for t in tools]
374
+ # 2. Check if it's a dict with "tools" key
375
+ elif isinstance(tools_response, dict):
355
376
  tools = tools_response.get("tools", [])
377
+ # 3. Check if it's already a list
356
378
  elif isinstance(tools_response, list):
357
379
  tools = tools_response
358
380
  else:
@@ -375,8 +397,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
375
397
  except Exception as e:
376
398
  logger.error("Error getting tools: %s", e)
377
399
  self._consecutive_failures += 1
378
- if self.enable_metrics:
379
- self._metrics["stream_errors"] += 1
400
+ if self.enable_metrics and self._metrics:
401
+ self._metrics.stream_errors += 1
380
402
  return []
381
403
 
382
404
  async def call_tool(
@@ -389,8 +411,8 @@ class HTTPStreamableTransport(MCPBaseTransport):
389
411
  tool_timeout = timeout or self.default_timeout
390
412
  start_time = time.time()
391
413
 
392
- if self.enable_metrics:
393
- self._metrics["total_calls"] += 1
414
+ if self.enable_metrics and self._metrics:
415
+ self._metrics.total_calls += 1
394
416
 
395
417
  try:
396
418
  logger.debug("Calling tool '%s' with timeout %ss", tool_name, tool_timeout)
@@ -410,9 +432,44 @@ class HTTPStreamableTransport(MCPBaseTransport):
410
432
  response_time = time.time() - start_time
411
433
  result = self._normalize_mcp_response(raw_response)
412
434
 
435
+ # NEW: Check for OAuth errors and attempt refresh if callback is available
436
+ if result.get("isError", False) and self._is_oauth_error(result.get("error", "")):
437
+ logger.warning("OAuth error detected: %s", result.get("error"))
438
+
439
+ if self.oauth_refresh_callback:
440
+ logger.debug("Attempting OAuth token refresh...")
441
+ try:
442
+ # Call the refresh callback
443
+ new_headers = await self.oauth_refresh_callback()
444
+
445
+ if new_headers and "Authorization" in new_headers:
446
+ # Update configured headers with new token
447
+ self.configured_headers.update(new_headers)
448
+ logger.debug("OAuth token refreshed, reconnecting...")
449
+
450
+ # Reconnect with new token
451
+ if await self._attempt_recovery():
452
+ logger.debug("Retrying tool call after token refresh...")
453
+ # Retry the tool call once with new token
454
+ raw_response = await asyncio.wait_for(
455
+ send_tools_call(self._read_stream, self._write_stream, tool_name, arguments),
456
+ timeout=tool_timeout,
457
+ )
458
+ result = self._normalize_mcp_response(raw_response)
459
+ logger.debug("Tool call retry completed")
460
+ else:
461
+ logger.error("Failed to reconnect after token refresh")
462
+ else:
463
+ logger.warning("Token refresh did not return valid Authorization header")
464
+ except Exception as refresh_error:
465
+ logger.error("OAuth token refresh failed: %s", refresh_error)
466
+ else:
467
+ logger.warning("OAuth error detected but no refresh callback configured")
468
+
413
469
  # Reset failure count on success
414
- self._consecutive_failures = 0
415
- self._last_successful_ping = time.time() # Update health timestamp
470
+ if not result.get("isError", False):
471
+ self._consecutive_failures = 0
472
+ self._last_successful_ping = time.time() # Update health timestamp
416
473
 
417
474
  if self.enable_metrics:
418
475
  self._update_metrics(response_time, not result.get("isError", False))
@@ -438,17 +495,17 @@ class HTTPStreamableTransport(MCPBaseTransport):
438
495
  except Exception as e:
439
496
  response_time = time.time() - start_time
440
497
  self._consecutive_failures += 1
441
- if self.enable_metrics:
498
+ if self.enable_metrics and self._metrics:
442
499
  self._update_metrics(response_time, False)
443
- self._metrics["stream_errors"] += 1
500
+ self._metrics.stream_errors += 1
444
501
 
445
502
  # Enhanced connection error detection
446
503
  error_str = str(e).lower()
447
504
  if any(indicator in error_str for indicator in ["connection", "disconnected", "broken pipe", "eof"]):
448
505
  logger.warning("Connection error detected: %s", e)
449
506
  self._initialized = False
450
- if self.enable_metrics:
451
- self._metrics["connection_errors"] += 1
507
+ if self.enable_metrics and self._metrics:
508
+ self._metrics.connection_errors += 1
452
509
 
453
510
  error_msg = f"Tool execution failed: {str(e)}"
454
511
  logger.error("Tool '%s' error: %s", tool_name, error_msg)
@@ -456,14 +513,41 @@ class HTTPStreamableTransport(MCPBaseTransport):
456
513
 
457
514
  def _update_metrics(self, response_time: float, success: bool) -> None:
458
515
  """Enhanced metrics tracking (like SSE)."""
459
- if success:
460
- self._metrics["successful_calls"] += 1
461
- else:
462
- self._metrics["failed_calls"] += 1
516
+ if not self._metrics:
517
+ return
518
+
519
+ self._metrics.update_call_metrics(response_time, success)
520
+
521
+ def _is_oauth_error(self, error_msg: str) -> bool:
522
+ """
523
+ Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.
463
524
 
464
- self._metrics["total_time"] += response_time
465
- if self._metrics["total_calls"] > 0:
466
- self._metrics["avg_response_time"] = self._metrics["total_time"] / self._metrics["total_calls"]
525
+ Checks for:
526
+ - RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
527
+ - OAuth 2.1 token refresh errors (invalid_grant)
528
+ - MCP spec OAuth validation failures (401/403 responses)
529
+ """
530
+ if not error_msg:
531
+ return False
532
+
533
+ error_lower = error_msg.lower()
534
+ oauth_indicators = [
535
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
536
+ "invalid_token", # Token expired, revoked, malformed, or invalid
537
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
538
+ # OAuth 2.1 token refresh errors
539
+ "invalid_grant", # Refresh token errors
540
+ # MCP spec - OAuth validation failures (401 Unauthorized)
541
+ "oauth validation",
542
+ "unauthorized",
543
+ # Common OAuth error descriptions
544
+ "expired token",
545
+ "token expired",
546
+ "authentication failed",
547
+ "invalid access token",
548
+ ]
549
+
550
+ return any(indicator in error_lower for indicator in oauth_indicators)
467
551
 
468
552
  async def list_resources(self) -> dict[str, Any]:
469
553
  """Enhanced resource listing with error handling."""
@@ -544,7 +628,10 @@ class HTTPStreamableTransport(MCPBaseTransport):
544
628
 
545
629
  def get_metrics(self) -> dict[str, Any]:
546
630
  """Enhanced metrics with health information."""
547
- metrics = self._metrics.copy()
631
+ if not self._metrics:
632
+ return {}
633
+
634
+ metrics = self._metrics.to_dict()
548
635
  metrics.update(
549
636
  {
550
637
  "is_connected": self.is_connected(),
@@ -557,22 +644,20 @@ class HTTPStreamableTransport(MCPBaseTransport):
557
644
 
558
645
  def reset_metrics(self) -> None:
559
646
  """Enhanced metrics reset preserving health state."""
560
- preserved_init_time = self._metrics.get("initialization_time")
561
- preserved_last_ping = self._metrics.get("last_ping_time")
562
-
563
- self._metrics = {
564
- "total_calls": 0,
565
- "successful_calls": 0,
566
- "failed_calls": 0,
567
- "total_time": 0.0,
568
- "avg_response_time": 0.0,
569
- "last_ping_time": preserved_last_ping,
570
- "initialization_time": preserved_init_time,
571
- "connection_resets": self._metrics.get("connection_resets", 0),
572
- "stream_errors": 0,
573
- "connection_errors": 0,
574
- "recovery_attempts": 0,
575
- }
647
+ if not self._metrics:
648
+ return
649
+
650
+ # Preserve important historical values
651
+ preserved_init_time = self._metrics.initialization_time
652
+ preserved_last_ping = self._metrics.last_ping_time
653
+ preserved_resets = self._metrics.connection_resets
654
+
655
+ # Create new metrics instance with preserved values
656
+ self._metrics = TransportMetrics(
657
+ initialization_time=preserved_init_time,
658
+ last_ping_time=preserved_last_ping,
659
+ connection_resets=preserved_resets,
660
+ )
576
661
 
577
662
  def get_streams(self) -> list[tuple]:
578
663
  """Enhanced streams access with connection check."""
@@ -0,0 +1,100 @@
1
+ # chuk_tool_processor/mcp/transport/models.py
2
+ """
3
+ Pydantic models for MCP transport configuration and metrics.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+
13
+ class TimeoutConfig(BaseModel):
14
+ """
15
+ Unified timeout configuration for all MCP operations.
16
+
17
+ Just 4 simple, logical timeout categories:
18
+ - connect: Connection establishment, initialization, session discovery (30s default)
19
+ - operation: Normal operations like tool calls, listing resources (30s default)
20
+ - quick: Fast health checks and pings (5s default)
21
+ - shutdown: Cleanup and shutdown operations (2s default)
22
+ """
23
+
24
+ connect: float = Field(
25
+ default=30.0, description="Timeout for connection establishment, initialization, and session discovery"
26
+ )
27
+ operation: float = Field(
28
+ default=30.0, description="Timeout for normal operations (tool calls, listing tools/resources/prompts)"
29
+ )
30
+ quick: float = Field(default=5.0, description="Timeout for quick health checks and pings")
31
+ shutdown: float = Field(default=2.0, description="Timeout for shutdown and cleanup operations")
32
+
33
+
34
+ class TransportMetrics(BaseModel):
35
+ """Performance and connection metrics for transports."""
36
+
37
+ model_config = {"validate_assignment": True}
38
+
39
+ total_calls: int = Field(default=0, description="Total number of calls made")
40
+ successful_calls: int = Field(default=0, description="Number of successful calls")
41
+ failed_calls: int = Field(default=0, description="Number of failed calls")
42
+ total_time: float = Field(default=0.0, description="Total time spent on calls")
43
+ avg_response_time: float = Field(default=0.0, description="Average response time")
44
+ last_ping_time: float | None = Field(default=None, description="Time taken for last ping")
45
+ initialization_time: float | None = Field(default=None, description="Time taken for initialization")
46
+ connection_resets: int = Field(default=0, description="Number of connection resets")
47
+ stream_errors: int = Field(default=0, description="Number of stream errors")
48
+ connection_errors: int = Field(default=0, description="Number of connection errors")
49
+ recovery_attempts: int = Field(default=0, description="Number of recovery attempts")
50
+ session_discoveries: int = Field(default=0, description="Number of session discoveries (SSE)")
51
+
52
+ def to_dict(self) -> dict[str, Any]:
53
+ """Convert to dictionary format."""
54
+ return self.model_dump()
55
+
56
+ def update_call_metrics(self, response_time: float, success: bool) -> None:
57
+ """Update metrics after a call."""
58
+ if success:
59
+ self.successful_calls += 1
60
+ else:
61
+ self.failed_calls += 1
62
+
63
+ self.total_time += response_time
64
+ if self.total_calls > 0:
65
+ self.avg_response_time = self.total_time / self.total_calls
66
+
67
+
68
+ class ServerInfo(BaseModel):
69
+ """Information about a server in StreamManager."""
70
+
71
+ id: int = Field(description="Server ID")
72
+ name: str = Field(description="Server name")
73
+ tools: int = Field(description="Number of tools available")
74
+ status: str = Field(description="Server status (Up/Down)")
75
+
76
+ def to_dict(self) -> dict[str, Any]:
77
+ """Convert to dictionary format."""
78
+ return self.model_dump()
79
+
80
+
81
+ class HeadersConfig(BaseModel):
82
+ """Configuration for HTTP headers."""
83
+
84
+ headers: dict[str, str] = Field(default_factory=dict, description="Custom HTTP headers")
85
+
86
+ def get_headers(self) -> dict[str, str]:
87
+ """Get headers as dict."""
88
+ return self.headers.copy()
89
+
90
+ def update_headers(self, new_headers: dict[str, str]) -> None:
91
+ """Update headers with new values."""
92
+ self.headers.update(new_headers)
93
+
94
+ def has_authorization(self) -> bool:
95
+ """Check if Authorization header is present."""
96
+ return "Authorization" in self.headers
97
+
98
+ def to_dict(self) -> dict[str, Any]:
99
+ """Convert to dictionary format."""
100
+ return self.model_dump()