chuk-tool-processor 0.6.13__py3-none-any.whl → 0.9.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the registry's advisory page for this release for more details.

Files changed (35)
  1. chuk_tool_processor/core/__init__.py +31 -0
  2. chuk_tool_processor/core/exceptions.py +218 -12
  3. chuk_tool_processor/core/processor.py +38 -7
  4. chuk_tool_processor/execution/strategies/__init__.py +6 -0
  5. chuk_tool_processor/execution/strategies/subprocess_strategy.py +2 -1
  6. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  7. chuk_tool_processor/execution/wrappers/caching.py +48 -13
  8. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  9. chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
  10. chuk_tool_processor/execution/wrappers/retry.py +93 -53
  11. chuk_tool_processor/logging/metrics.py +2 -2
  12. chuk_tool_processor/mcp/mcp_tool.py +5 -5
  13. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +44 -2
  14. chuk_tool_processor/mcp/setup_mcp_sse.py +44 -2
  15. chuk_tool_processor/mcp/setup_mcp_stdio.py +2 -0
  16. chuk_tool_processor/mcp/stream_manager.py +130 -75
  17. chuk_tool_processor/mcp/transport/__init__.py +10 -0
  18. chuk_tool_processor/mcp/transport/http_streamable_transport.py +193 -108
  19. chuk_tool_processor/mcp/transport/models.py +100 -0
  20. chuk_tool_processor/mcp/transport/sse_transport.py +155 -59
  21. chuk_tool_processor/mcp/transport/stdio_transport.py +58 -10
  22. chuk_tool_processor/models/__init__.py +20 -0
  23. chuk_tool_processor/models/tool_call.py +34 -1
  24. chuk_tool_processor/models/tool_spec.py +350 -0
  25. chuk_tool_processor/models/validated_tool.py +22 -2
  26. chuk_tool_processor/observability/__init__.py +30 -0
  27. chuk_tool_processor/observability/metrics.py +312 -0
  28. chuk_tool_processor/observability/setup.py +105 -0
  29. chuk_tool_processor/observability/tracing.py +345 -0
  30. chuk_tool_processor/plugins/discovery.py +1 -1
  31. chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
  32. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/RECORD +34 -27
  33. chuk_tool_processor-0.6.13.dist-info/METADATA +0 -698
  34. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
  35. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
@@ -34,15 +34,17 @@ async def setup_mcp_http_streamable(
34
34
  server_names: dict[int, str] | None = None,
35
35
  connection_timeout: float = 30.0,
36
36
  default_timeout: float = 30.0,
37
+ initialization_timeout: float = 60.0,
37
38
  max_concurrency: int | None = None,
38
39
  enable_caching: bool = True,
39
40
  cache_ttl: int = 300,
40
41
  enable_rate_limiting: bool = False,
41
42
  global_rate_limit: int | None = None,
42
43
  tool_rate_limits: dict[str, tuple] | None = None,
43
- enable_retries: bool = True,
44
- max_retries: int = 3,
44
+ enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
45
+ max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
45
46
  namespace: str = "http",
47
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
46
48
  ) -> tuple[ToolProcessor, StreamManager]:
47
49
  """
48
50
  Initialize HTTP Streamable transport MCP + a :class:`ToolProcessor`.
@@ -58,6 +60,7 @@ async def setup_mcp_http_streamable(
58
60
  server_names: Optional mapping of server indices to names
59
61
  connection_timeout: Timeout for initial HTTP connection setup
60
62
  default_timeout: Default timeout for tool execution
63
+ initialization_timeout: Timeout for complete initialization (default 60s, increase to 120s+ for slow servers like Notion)
61
64
  max_concurrency: Maximum concurrent operations
62
65
  enable_caching: Whether to enable response caching
63
66
  cache_ttl: Cache time-to-live in seconds
@@ -67,6 +70,7 @@ async def setup_mcp_http_streamable(
67
70
  enable_retries: Whether to enable automatic retries
68
71
  max_retries: Maximum retry attempts
69
72
  namespace: Namespace for registered tools
73
+ oauth_refresh_callback: Optional async callback to refresh OAuth tokens (NEW)
70
74
 
71
75
  Returns:
72
76
  Tuple of (ToolProcessor, StreamManager)
@@ -90,12 +94,49 @@ async def setup_mcp_http_streamable(
90
94
  server_names=server_names,
91
95
  connection_timeout=connection_timeout,
92
96
  default_timeout=default_timeout,
97
+ initialization_timeout=initialization_timeout,
98
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
93
99
  )
94
100
 
95
101
  # 2️⃣ pull the remote tool list and register each one locally
96
102
  registered = await register_mcp_tools(stream_manager, namespace=namespace)
97
103
 
98
104
  # 3️⃣ build a processor instance configured to your taste
105
+ # IMPORTANT: Retries are enabled but OAuth errors are excluded
106
+ # OAuth refresh happens at transport level with automatic retry
107
+
108
+ # Import RetryConfig to configure OAuth error exclusion
109
+ from chuk_tool_processor.execution.wrappers.retry import RetryConfig
110
+
111
+ # Define OAuth error patterns that should NOT be retried at this level
112
+ # These will be handled by the transport layer's OAuth refresh mechanism
113
+ # Based on RFC 6750 (Bearer Token Usage) and MCP OAuth spec
114
+ oauth_error_patterns = [
115
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
116
+ "invalid_token", # Token expired, revoked, malformed, or invalid
117
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
118
+ # OAuth 2.1 token refresh errors
119
+ "invalid_grant", # Refresh token errors
120
+ # MCP spec - OAuth validation failures (401 Unauthorized)
121
+ "oauth validation",
122
+ "unauthorized",
123
+ # Common OAuth error descriptions
124
+ "expired token",
125
+ "token expired",
126
+ "authentication failed",
127
+ "invalid access token",
128
+ ]
129
+
130
+ # Create retry config that skips OAuth errors
131
+ retry_config = (
132
+ RetryConfig(
133
+ max_retries=max_retries,
134
+ skip_retry_on_error_substrings=oauth_error_patterns,
135
+ )
136
+ if enable_retries
137
+ else None
138
+ )
139
+
99
140
  processor = ToolProcessor(
100
141
  default_timeout=default_timeout,
101
142
  max_concurrency=max_concurrency,
@@ -106,6 +147,7 @@ async def setup_mcp_http_streamable(
106
147
  tool_rate_limits=tool_rate_limits,
107
148
  enable_retries=enable_retries,
108
149
  max_retries=max_retries,
150
+ retry_config=retry_config, # NEW: Pass OAuth-aware retry config
109
151
  )
110
152
 
111
153
  logger.debug(
@@ -30,15 +30,17 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
30
30
  server_names: dict[int, str] | None = None,
31
31
  connection_timeout: float = 30.0, # 🔧 INCREASED DEFAULT: was 10.0
32
32
  default_timeout: float = 30.0, # 🔧 INCREASED DEFAULT: was 10.0
33
+ initialization_timeout: float = 60.0,
33
34
  max_concurrency: int | None = None,
34
35
  enable_caching: bool = True,
35
36
  cache_ttl: int = 300,
36
37
  enable_rate_limiting: bool = False,
37
38
  global_rate_limit: int | None = None,
38
39
  tool_rate_limits: dict[str, tuple] | None = None,
39
- enable_retries: bool = True,
40
- max_retries: int = 3,
40
+ enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
41
+ max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
41
42
  namespace: str = "sse",
43
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
42
44
  ) -> tuple[ToolProcessor, StreamManager]:
43
45
  """
44
46
  Initialise SSE-transport MCP + a :class:`ToolProcessor`.
@@ -50,6 +52,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
50
52
  server_names: Optional mapping of server indices to names
51
53
  connection_timeout: Timeout for initial SSE connection setup
52
54
  default_timeout: Default timeout for tool execution
55
+ initialization_timeout: Timeout for complete initialization (default 60s, increase for slow servers)
53
56
  max_concurrency: Maximum concurrent operations
54
57
  enable_caching: Whether to enable response caching
55
58
  cache_ttl: Cache time-to-live in seconds
@@ -59,6 +62,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
59
62
  enable_retries: Whether to enable automatic retries
60
63
  max_retries: Maximum retry attempts
61
64
  namespace: Namespace for registered tools
65
+ oauth_refresh_callback: Optional async callback to refresh OAuth tokens (NEW)
62
66
 
63
67
  Returns:
64
68
  Tuple of (ToolProcessor, StreamManager)
@@ -69,12 +73,49 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
69
73
  server_names=server_names,
70
74
  connection_timeout=connection_timeout, # 🔧 ADD THIS LINE
71
75
  default_timeout=default_timeout, # 🔧 ADD THIS LINE
76
+ initialization_timeout=initialization_timeout,
77
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
72
78
  )
73
79
 
74
80
  # 2️⃣ pull the remote tool list and register each one locally
75
81
  registered = await register_mcp_tools(stream_manager, namespace=namespace)
76
82
 
77
83
  # 3️⃣ build a processor instance configured to your taste
84
+ # IMPORTANT: Retries are enabled but OAuth errors are excluded
85
+ # OAuth refresh happens at transport level with automatic retry
86
+
87
+ # Import RetryConfig to configure OAuth error exclusion
88
+ from chuk_tool_processor.execution.wrappers.retry import RetryConfig
89
+
90
+ # Define OAuth error patterns that should NOT be retried at this level
91
+ # These will be handled by the transport layer's OAuth refresh mechanism
92
+ # Based on RFC 6750 (Bearer Token Usage) and MCP OAuth spec
93
+ oauth_error_patterns = [
94
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
95
+ "invalid_token", # Token expired, revoked, malformed, or invalid
96
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
97
+ # OAuth 2.1 token refresh errors
98
+ "invalid_grant", # Refresh token errors
99
+ # MCP spec - OAuth validation failures (401 Unauthorized)
100
+ "oauth validation",
101
+ "unauthorized",
102
+ # Common OAuth error descriptions
103
+ "expired token",
104
+ "token expired",
105
+ "authentication failed",
106
+ "invalid access token",
107
+ ]
108
+
109
+ # Create retry config that skips OAuth errors
110
+ retry_config = (
111
+ RetryConfig(
112
+ max_retries=max_retries,
113
+ skip_retry_on_error_substrings=oauth_error_patterns,
114
+ )
115
+ if enable_retries
116
+ else None
117
+ )
118
+
78
119
  processor = ToolProcessor(
79
120
  default_timeout=default_timeout,
80
121
  max_concurrency=max_concurrency,
@@ -85,6 +126,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
85
126
  tool_rate_limits=tool_rate_limits,
86
127
  enable_retries=enable_retries,
87
128
  max_retries=max_retries,
129
+ retry_config=retry_config, # NEW: Pass OAuth-aware retry config
88
130
  )
89
131
 
90
132
  logger.debug(
@@ -30,6 +30,7 @@ async def setup_mcp_stdio( # noqa: C901 - long but just a config facade
30
30
  servers: list[str],
31
31
  server_names: dict[int, str] | None = None,
32
32
  default_timeout: float = 10.0,
33
+ initialization_timeout: float = 60.0,
33
34
  max_concurrency: int | None = None,
34
35
  enable_caching: bool = True,
35
36
  cache_ttl: int = 300,
@@ -53,6 +54,7 @@ async def setup_mcp_stdio( # noqa: C901 - long but just a config facade
53
54
  server_names=server_names,
54
55
  transport_type="stdio",
55
56
  default_timeout=default_timeout, # 🔧 ADD THIS LINE
57
+ initialization_timeout=initialization_timeout,
56
58
  )
57
59
 
58
60
  # 2️⃣ pull the remote tool list and register each one locally
@@ -21,6 +21,7 @@ from chuk_tool_processor.mcp.transport import (
21
21
  MCPBaseTransport,
22
22
  SSETransport,
23
23
  StdioTransport,
24
+ TimeoutConfig,
24
25
  )
25
26
 
26
27
  logger = get_logger("chuk_tool_processor.mcp.stream_manager")
@@ -38,7 +39,7 @@ class StreamManager:
38
39
  - HTTP Streamable (modern replacement for SSE, spec 2025-03-26) with graceful headers handling
39
40
  """
40
41
 
41
- def __init__(self) -> None:
42
+ def __init__(self, timeout_config: TimeoutConfig | None = None) -> None:
42
43
  self.transports: dict[str, MCPBaseTransport] = {}
43
44
  self.server_info: list[dict[str, Any]] = []
44
45
  self.tool_to_server_map: dict[str, str] = {}
@@ -46,7 +47,7 @@ class StreamManager:
46
47
  self.all_tools: list[dict[str, Any]] = []
47
48
  self._lock = asyncio.Lock()
48
49
  self._closed = False # Track if we've been closed
49
- self._shutdown_timeout = 2.0 # Maximum time to spend on shutdown
50
+ self.timeout_config = timeout_config or TimeoutConfig()
50
51
 
51
52
  # ------------------------------------------------------------------ #
52
53
  # factory helpers with enhanced error handling #
@@ -62,16 +63,16 @@ class StreamManager:
62
63
  initialization_timeout: float = 60.0, # NEW: Timeout for entire initialization
63
64
  ) -> StreamManager:
64
65
  """Create StreamManager with timeout protection."""
65
- try:
66
- inst = cls()
67
- await asyncio.wait_for(
68
- inst.initialize(config_file, servers, server_names, transport_type, default_timeout=default_timeout),
69
- timeout=initialization_timeout,
70
- )
71
- return inst
72
- except TimeoutError:
73
- logger.error("StreamManager initialization timed out after %ss", initialization_timeout)
74
- raise RuntimeError(f"StreamManager initialization timed out after {initialization_timeout}s")
66
+ inst = cls()
67
+ await inst.initialize(
68
+ config_file,
69
+ servers,
70
+ server_names,
71
+ transport_type,
72
+ default_timeout=default_timeout,
73
+ initialization_timeout=initialization_timeout,
74
+ )
75
+ return inst
75
76
 
76
77
  @classmethod
77
78
  async def create_with_sse(
@@ -81,20 +82,19 @@ class StreamManager:
81
82
  connection_timeout: float = 10.0,
82
83
  default_timeout: float = 30.0,
83
84
  initialization_timeout: float = 60.0, # NEW
85
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
84
86
  ) -> StreamManager:
85
87
  """Create StreamManager with SSE transport and timeout protection."""
86
- try:
87
- inst = cls()
88
- await asyncio.wait_for(
89
- inst.initialize_with_sse(
90
- servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
91
- ),
92
- timeout=initialization_timeout,
93
- )
94
- return inst
95
- except TimeoutError:
96
- logger.error("SSE StreamManager initialization timed out after %ss", initialization_timeout)
97
- raise RuntimeError(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
88
+ inst = cls()
89
+ await inst.initialize_with_sse(
90
+ servers,
91
+ server_names,
92
+ connection_timeout=connection_timeout,
93
+ default_timeout=default_timeout,
94
+ initialization_timeout=initialization_timeout,
95
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
96
+ )
97
+ return inst
98
98
 
99
99
  @classmethod
100
100
  async def create_with_http_streamable(
@@ -104,22 +104,19 @@ class StreamManager:
104
104
  connection_timeout: float = 30.0,
105
105
  default_timeout: float = 30.0,
106
106
  initialization_timeout: float = 60.0, # NEW
107
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
107
108
  ) -> StreamManager:
108
109
  """Create StreamManager with HTTP Streamable transport and timeout protection."""
109
- try:
110
- inst = cls()
111
- await asyncio.wait_for(
112
- inst.initialize_with_http_streamable(
113
- servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
114
- ),
115
- timeout=initialization_timeout,
116
- )
117
- return inst
118
- except TimeoutError:
119
- logger.error("HTTP Streamable StreamManager initialization timed out after %ss", initialization_timeout)
120
- raise RuntimeError(
121
- f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s"
122
- )
110
+ inst = cls()
111
+ await inst.initialize_with_http_streamable(
112
+ servers,
113
+ server_names,
114
+ connection_timeout=connection_timeout,
115
+ default_timeout=default_timeout,
116
+ initialization_timeout=initialization_timeout,
117
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
118
+ )
119
+ return inst
123
120
 
124
121
  # ------------------------------------------------------------------ #
125
122
  # NEW: Context manager support for automatic cleanup #
@@ -167,6 +164,7 @@ class StreamManager:
167
164
  server_names: dict[int, str] | None = None,
168
165
  transport_type: str = "stdio",
169
166
  default_timeout: float = 30.0,
167
+ initialization_timeout: float = 60.0,
170
168
  ) -> None:
171
169
  """Initialize with graceful headers handling for all transport types."""
172
170
  if self._closed:
@@ -178,13 +176,24 @@ class StreamManager:
178
176
  for idx, server_name in enumerate(servers):
179
177
  try:
180
178
  if transport_type == "stdio":
181
- params = await load_config(config_file, server_name)
182
- transport: MCPBaseTransport = StdioTransport(params)
179
+ params, server_timeout = await load_config(config_file, server_name)
180
+ # Use per-server timeout if specified, otherwise use global default
181
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
182
+ logger.info(
183
+ f"Server '{server_name}' using timeout: {effective_timeout}s (per-server: {server_timeout}, default: {default_timeout})"
184
+ )
185
+ # Use initialization_timeout for connection_timeout since subprocess
186
+ # launch can take time (e.g., uvx downloading packages)
187
+ transport: MCPBaseTransport = StdioTransport(
188
+ params, connection_timeout=initialization_timeout, default_timeout=effective_timeout
189
+ )
183
190
  elif transport_type == "sse":
184
- logger.warning(
191
+ logger.debug(
185
192
  "Using SSE transport in initialize() - consider using initialize_with_sse() instead"
186
193
  )
187
- params = await load_config(config_file, server_name)
194
+ params, server_timeout = await load_config(config_file, server_name)
195
+ # Use per-server timeout if specified, otherwise use global default
196
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
188
197
 
189
198
  if isinstance(params, dict) and "url" in params:
190
199
  sse_url = params["url"]
@@ -194,20 +203,22 @@ class StreamManager:
194
203
  sse_url = "http://localhost:8000"
195
204
  api_key = None
196
205
  headers = {}
197
- logger.warning("No URL configured for SSE transport, using default: %s", sse_url)
206
+ logger.debug("No URL configured for SSE transport, using default: %s", sse_url)
198
207
 
199
208
  # Build SSE transport with optional headers
200
- transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": default_timeout}
209
+ transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": effective_timeout}
201
210
  if headers:
202
211
  transport_params["headers"] = headers
203
212
 
204
213
  transport = SSETransport(**transport_params)
205
214
 
206
215
  elif transport_type == "http_streamable":
207
- logger.warning(
216
+ logger.debug(
208
217
  "Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
209
218
  )
210
- params = await load_config(config_file, server_name)
219
+ params, server_timeout = await load_config(config_file, server_name)
220
+ # Use per-server timeout if specified, otherwise use global default
221
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
211
222
 
212
223
  if isinstance(params, dict) and "url" in params:
213
224
  http_url = params["url"]
@@ -219,15 +230,13 @@ class StreamManager:
219
230
  api_key = None
220
231
  headers = {}
221
232
  session_id = None
222
- logger.warning(
223
- "No URL configured for HTTP Streamable transport, using default: %s", http_url
224
- )
233
+ logger.debug("No URL configured for HTTP Streamable transport, using default: %s", http_url)
225
234
 
226
235
  # Build HTTP transport (headers not supported yet)
227
236
  transport_params = {
228
237
  "url": http_url,
229
238
  "api_key": api_key,
230
- "default_timeout": default_timeout,
239
+ "default_timeout": effective_timeout,
231
240
  "session_id": session_id,
232
241
  }
233
242
  # Note: headers not added until HTTPStreamableTransport supports them
@@ -241,15 +250,23 @@ class StreamManager:
241
250
  continue
242
251
 
243
252
  # Initialize with timeout protection
244
- if not await asyncio.wait_for(transport.initialize(), timeout=default_timeout):
245
- logger.error("Failed to init %s", server_name)
253
+ try:
254
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
255
+ logger.warning("Failed to init %s", server_name)
256
+ continue
257
+ except TimeoutError:
258
+ logger.error("Timeout initialising %s (timeout=%ss)", server_name, initialization_timeout)
246
259
  continue
247
260
 
248
261
  self.transports[server_name] = transport
249
262
 
250
- # Ping and get tools with timeout protection
251
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
252
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
263
+ # Ping and get tools with timeout protection (use longer timeouts for slow servers)
264
+ status = (
265
+ "Up"
266
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
267
+ else "Down"
268
+ )
269
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
253
270
 
254
271
  for t in tools:
255
272
  name = t.get("name")
@@ -283,6 +300,8 @@ class StreamManager:
283
300
  server_names: dict[int, str] | None = None,
284
301
  connection_timeout: float = 10.0,
285
302
  default_timeout: float = 30.0,
303
+ initialization_timeout: float = 60.0,
304
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
286
305
  ) -> None:
287
306
  """Initialize with SSE transport with optional headers support."""
288
307
  if self._closed:
@@ -311,15 +330,29 @@ class StreamManager:
311
330
  logger.debug("SSE %s: Using configured headers: %s", name, list(headers.keys()))
312
331
  transport_params["headers"] = headers
313
332
 
333
+ # Add OAuth refresh callback if provided (NEW)
334
+ if oauth_refresh_callback:
335
+ transport_params["oauth_refresh_callback"] = oauth_refresh_callback
336
+ logger.debug("SSE %s: OAuth refresh callback configured", name)
337
+
314
338
  transport = SSETransport(**transport_params)
315
339
 
316
- if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
317
- logger.error("Failed to init SSE %s", name)
340
+ try:
341
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
342
+ logger.warning("Failed to init SSE %s", name)
343
+ continue
344
+ except TimeoutError:
345
+ logger.error("Timeout initialising SSE %s (timeout=%ss)", name, initialization_timeout)
318
346
  continue
319
347
 
320
348
  self.transports[name] = transport
321
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
322
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
349
+ # Use longer timeouts for slow servers (ping can take time after initialization)
350
+ status = (
351
+ "Up"
352
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
353
+ else "Down"
354
+ )
355
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
323
356
 
324
357
  for t in tools:
325
358
  tname = t.get("name")
@@ -346,11 +379,15 @@ class StreamManager:
346
379
  server_names: dict[int, str] | None = None,
347
380
  connection_timeout: float = 30.0,
348
381
  default_timeout: float = 30.0,
382
+ initialization_timeout: float = 60.0,
383
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
349
384
  ) -> None:
350
385
  """Initialize with HTTP Streamable transport with graceful headers handling."""
351
386
  if self._closed:
352
387
  raise RuntimeError("Cannot initialize a closed StreamManager")
353
388
 
389
+ logger.debug(f"initialize_with_http_streamable: initialization_timeout={initialization_timeout}")
390
+
354
391
  async with self._lock:
355
392
  self.server_names = server_names or {}
356
393
 
@@ -369,22 +406,39 @@ class StreamManager:
369
406
  "session_id": cfg.get("session_id"),
370
407
  }
371
408
 
372
- # Handle headers if provided (for future HTTPStreamableTransport support)
409
+ # Handle headers if provided
373
410
  headers = cfg.get("headers", {})
374
411
  if headers:
375
- logger.debug("HTTP Streamable %s: Headers provided but not yet supported in transport", name)
376
- # TODO: Add headers support when HTTPStreamableTransport is updated
377
- # transport_params['headers'] = headers
412
+ transport_params["headers"] = headers
413
+ logger.debug("HTTP Streamable %s: Custom headers configured: %s", name, list(headers.keys()))
414
+
415
+ # Add OAuth refresh callback if provided (NEW)
416
+ if oauth_refresh_callback:
417
+ transport_params["oauth_refresh_callback"] = oauth_refresh_callback
418
+ logger.debug("HTTP Streamable %s: OAuth refresh callback configured", name)
378
419
 
379
420
  transport = HTTPStreamableTransport(**transport_params)
380
421
 
381
- if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
382
- logger.error("Failed to init HTTP Streamable %s", name)
422
+ logger.debug(f"Calling transport.initialize() for {name} with timeout={initialization_timeout}s")
423
+ try:
424
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
425
+ logger.warning("Failed to init HTTP Streamable %s", name)
426
+ continue
427
+ except TimeoutError:
428
+ logger.error(
429
+ "Timeout initialising HTTP Streamable %s (timeout=%ss)", name, initialization_timeout
430
+ )
383
431
  continue
432
+ logger.debug(f"Successfully initialized {name}")
384
433
 
385
434
  self.transports[name] = transport
386
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
387
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
435
+ # Use longer timeouts for slow servers (ping can take time after initialization)
436
+ status = (
437
+ "Up"
438
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
439
+ else "Down"
440
+ )
441
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
388
442
 
389
443
  for t in tools:
390
444
  tname = t.get("name")
@@ -430,7 +484,7 @@ class StreamManager:
430
484
  transport = self.transports[server_name]
431
485
 
432
486
  try:
433
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
487
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
434
488
  logger.debug("Found %d tools for server %s", len(tools), server_name)
435
489
  return tools
436
490
  except TimeoutError:
@@ -449,7 +503,7 @@ class StreamManager:
449
503
 
450
504
  async def _ping_one(name: str, tr: MCPBaseTransport):
451
505
  try:
452
- ok = await asyncio.wait_for(tr.send_ping(), timeout=5.0)
506
+ ok = await asyncio.wait_for(tr.send_ping(), timeout=self.timeout_config.quick)
453
507
  except Exception:
454
508
  ok = False
455
509
  return {"server": name, "ok": ok}
@@ -464,7 +518,7 @@ class StreamManager:
464
518
 
465
519
  async def _one(name: str, tr: MCPBaseTransport):
466
520
  try:
467
- res = await asyncio.wait_for(tr.list_resources(), timeout=10.0)
521
+ res = await asyncio.wait_for(tr.list_resources(), timeout=self.timeout_config.operation)
468
522
  resources = res.get("resources", []) if isinstance(res, dict) else res
469
523
  for item in resources:
470
524
  item = dict(item)
@@ -484,7 +538,7 @@ class StreamManager:
484
538
 
485
539
  async def _one(name: str, tr: MCPBaseTransport):
486
540
  try:
487
- res = await asyncio.wait_for(tr.list_prompts(), timeout=10.0)
541
+ res = await asyncio.wait_for(tr.list_prompts(), timeout=self.timeout_config.operation)
488
542
  prompts = res.get("prompts", []) if isinstance(res, dict) else res
489
543
  for item in prompts:
490
544
  item = dict(item)
@@ -611,7 +665,7 @@ class StreamManager:
611
665
  try:
612
666
  results = await asyncio.wait_for(
613
667
  asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
614
- timeout=self._shutdown_timeout,
668
+ timeout=self.timeout_config.shutdown,
615
669
  )
616
670
 
617
671
  # Process results
@@ -634,7 +688,8 @@ class StreamManager:
634
688
  # Brief wait for cancellations to complete
635
689
  with contextlib.suppress(TimeoutError):
636
690
  await asyncio.wait_for(
637
- asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True), timeout=0.5
691
+ asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
692
+ timeout=self.timeout_config.shutdown,
638
693
  )
639
694
 
640
695
  async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
@@ -643,7 +698,7 @@ class StreamManager:
643
698
  try:
644
699
  await asyncio.wait_for(
645
700
  self._close_single_transport(name, transport),
646
- timeout=0.5, # Short timeout per transport
701
+ timeout=self.timeout_config.shutdown,
647
702
  )
648
703
  logger.debug("Closed transport: %s", name)
649
704
  close_results.append((name, True, None))
@@ -735,7 +790,7 @@ class StreamManager:
735
790
 
736
791
  for name, transport in self.transports.items():
737
792
  try:
738
- ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=5.0)
793
+ ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.quick)
739
794
  health_info["transports"][name] = {
740
795
  "status": "healthy" if ping_ok else "unhealthy",
741
796
  "ping_success": ping_ok,
@@ -11,6 +11,12 @@ All transports now follow the same interface and provide consistent behavior:
11
11
 
12
12
  from .base_transport import MCPBaseTransport
13
13
  from .http_streamable_transport import HTTPStreamableTransport
14
+ from .models import (
15
+ HeadersConfig,
16
+ ServerInfo,
17
+ TimeoutConfig,
18
+ TransportMetrics,
19
+ )
14
20
  from .sse_transport import SSETransport
15
21
  from .stdio_transport import StdioTransport
16
22
 
@@ -19,4 +25,8 @@ __all__ = [
19
25
  "StdioTransport",
20
26
  "SSETransport",
21
27
  "HTTPStreamableTransport",
28
+ "TimeoutConfig",
29
+ "TransportMetrics",
30
+ "ServerInfo",
31
+ "HeadersConfig",
22
32
  ]