chuk-tool-processor 0.6.4__py3-none-any.whl → 0.9.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; consult the registry's advisory page for details.

Files changed (66)
  1. chuk_tool_processor/core/__init__.py +32 -1
  2. chuk_tool_processor/core/exceptions.py +225 -13
  3. chuk_tool_processor/core/processor.py +135 -104
  4. chuk_tool_processor/execution/strategies/__init__.py +6 -0
  5. chuk_tool_processor/execution/strategies/inprocess_strategy.py +142 -150
  6. chuk_tool_processor/execution/strategies/subprocess_strategy.py +202 -206
  7. chuk_tool_processor/execution/tool_executor.py +82 -84
  8. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  9. chuk_tool_processor/execution/wrappers/caching.py +150 -116
  10. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  11. chuk_tool_processor/execution/wrappers/rate_limiting.py +76 -43
  12. chuk_tool_processor/execution/wrappers/retry.py +116 -78
  13. chuk_tool_processor/logging/__init__.py +23 -17
  14. chuk_tool_processor/logging/context.py +40 -45
  15. chuk_tool_processor/logging/formatter.py +22 -21
  16. chuk_tool_processor/logging/helpers.py +28 -42
  17. chuk_tool_processor/logging/metrics.py +13 -15
  18. chuk_tool_processor/mcp/__init__.py +8 -12
  19. chuk_tool_processor/mcp/mcp_tool.py +158 -114
  20. chuk_tool_processor/mcp/register_mcp_tools.py +22 -22
  21. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +57 -17
  22. chuk_tool_processor/mcp/setup_mcp_sse.py +57 -17
  23. chuk_tool_processor/mcp/setup_mcp_stdio.py +11 -11
  24. chuk_tool_processor/mcp/stream_manager.py +333 -276
  25. chuk_tool_processor/mcp/transport/__init__.py +22 -29
  26. chuk_tool_processor/mcp/transport/base_transport.py +180 -44
  27. chuk_tool_processor/mcp/transport/http_streamable_transport.py +505 -325
  28. chuk_tool_processor/mcp/transport/models.py +100 -0
  29. chuk_tool_processor/mcp/transport/sse_transport.py +607 -276
  30. chuk_tool_processor/mcp/transport/stdio_transport.py +597 -116
  31. chuk_tool_processor/models/__init__.py +21 -1
  32. chuk_tool_processor/models/execution_strategy.py +16 -21
  33. chuk_tool_processor/models/streaming_tool.py +28 -25
  34. chuk_tool_processor/models/tool_call.py +49 -31
  35. chuk_tool_processor/models/tool_export_mixin.py +22 -8
  36. chuk_tool_processor/models/tool_result.py +40 -77
  37. chuk_tool_processor/models/tool_spec.py +350 -0
  38. chuk_tool_processor/models/validated_tool.py +36 -18
  39. chuk_tool_processor/observability/__init__.py +30 -0
  40. chuk_tool_processor/observability/metrics.py +312 -0
  41. chuk_tool_processor/observability/setup.py +105 -0
  42. chuk_tool_processor/observability/tracing.py +345 -0
  43. chuk_tool_processor/plugins/__init__.py +1 -1
  44. chuk_tool_processor/plugins/discovery.py +11 -11
  45. chuk_tool_processor/plugins/parsers/__init__.py +1 -1
  46. chuk_tool_processor/plugins/parsers/base.py +1 -2
  47. chuk_tool_processor/plugins/parsers/function_call_tool.py +13 -8
  48. chuk_tool_processor/plugins/parsers/json_tool.py +4 -3
  49. chuk_tool_processor/plugins/parsers/openai_tool.py +12 -7
  50. chuk_tool_processor/plugins/parsers/xml_tool.py +4 -4
  51. chuk_tool_processor/registry/__init__.py +12 -12
  52. chuk_tool_processor/registry/auto_register.py +22 -30
  53. chuk_tool_processor/registry/decorators.py +127 -129
  54. chuk_tool_processor/registry/interface.py +26 -23
  55. chuk_tool_processor/registry/metadata.py +27 -22
  56. chuk_tool_processor/registry/provider.py +17 -18
  57. chuk_tool_processor/registry/providers/__init__.py +16 -19
  58. chuk_tool_processor/registry/providers/memory.py +18 -25
  59. chuk_tool_processor/registry/tool_export.py +42 -51
  60. chuk_tool_processor/utils/validation.py +15 -16
  61. chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
  62. chuk_tool_processor-0.9.7.dist-info/RECORD +67 -0
  63. chuk_tool_processor-0.6.4.dist-info/METADATA +0 -697
  64. chuk_tool_processor-0.6.4.dist-info/RECORD +0 -60
  65. {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
  66. {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,28 @@
1
1
  # chuk_tool_processor/mcp/stream_manager.py
2
2
  """
3
- StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling
3
+ StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling and headers support
4
4
  """
5
+
5
6
  from __future__ import annotations
6
7
 
7
8
  import asyncio
8
- from typing import Any, Dict, List, Optional, Tuple
9
+ import contextlib
9
10
  from contextlib import asynccontextmanager
11
+ from typing import Any
10
12
 
11
13
  # --------------------------------------------------------------------------- #
12
14
  # CHUK imports #
13
15
  # --------------------------------------------------------------------------- #
14
- from chuk_mcp.config import load_config
16
+ from chuk_mcp.config import load_config # type: ignore[import-untyped]
17
+
18
+ from chuk_tool_processor.logging import get_logger
15
19
  from chuk_tool_processor.mcp.transport import (
20
+ HTTPStreamableTransport,
16
21
  MCPBaseTransport,
17
- StdioTransport,
18
22
  SSETransport,
19
- HTTPStreamableTransport,
23
+ StdioTransport,
24
+ TimeoutConfig,
20
25
  )
21
- from chuk_tool_processor.logging import get_logger
22
26
 
23
27
  logger = get_logger("chuk_tool_processor.mcp.stream_manager")
24
28
 
@@ -26,24 +30,24 @@ logger = get_logger("chuk_tool_processor.mcp.stream_manager")
26
30
  class StreamManager:
27
31
  """
28
32
  Manager for MCP server streams with support for multiple transport types.
29
-
30
- Enhanced with robust shutdown handling to prevent event loop closure issues.
31
-
33
+
34
+ Enhanced with robust shutdown handling and proper headers support.
35
+
32
36
  Updated to support the latest transports:
33
37
  - STDIO (process-based)
34
- - SSE (Server-Sent Events)
35
- - HTTP Streamable (modern replacement for SSE, spec 2025-03-26)
38
+ - SSE (Server-Sent Events) with headers support
39
+ - HTTP Streamable (modern replacement for SSE, spec 2025-03-26) with graceful headers handling
36
40
  """
37
41
 
38
- def __init__(self) -> None:
39
- self.transports: Dict[str, MCPBaseTransport] = {}
40
- self.server_info: List[Dict[str, Any]] = []
41
- self.tool_to_server_map: Dict[str, str] = {}
42
- self.server_names: Dict[int, str] = {}
43
- self.all_tools: List[Dict[str, Any]] = []
42
+ def __init__(self, timeout_config: TimeoutConfig | None = None) -> None:
43
+ self.transports: dict[str, MCPBaseTransport] = {}
44
+ self.server_info: list[dict[str, Any]] = []
45
+ self.tool_to_server_map: dict[str, str] = {}
46
+ self.server_names: dict[int, str] = {}
47
+ self.all_tools: list[dict[str, Any]] = []
44
48
  self._lock = asyncio.Lock()
45
49
  self._closed = False # Track if we've been closed
46
- self._shutdown_timeout = 2.0 # Maximum time to spend on shutdown
50
+ self.timeout_config = timeout_config or TimeoutConfig()
47
51
 
48
52
  # ------------------------------------------------------------------ #
49
53
  # factory helpers with enhanced error handling #
@@ -52,81 +56,67 @@ class StreamManager:
52
56
  async def create(
53
57
  cls,
54
58
  config_file: str,
55
- servers: List[str],
56
- server_names: Optional[Dict[int, str]] = None,
59
+ servers: list[str],
60
+ server_names: dict[int, str] | None = None,
57
61
  transport_type: str = "stdio",
58
62
  default_timeout: float = 30.0,
59
63
  initialization_timeout: float = 60.0, # NEW: Timeout for entire initialization
60
- ) -> "StreamManager":
64
+ ) -> StreamManager:
61
65
  """Create StreamManager with timeout protection."""
62
- try:
63
- inst = cls()
64
- await asyncio.wait_for(
65
- inst.initialize(
66
- config_file,
67
- servers,
68
- server_names,
69
- transport_type,
70
- default_timeout=default_timeout
71
- ),
72
- timeout=initialization_timeout
73
- )
74
- return inst
75
- except asyncio.TimeoutError:
76
- logger.error(f"StreamManager initialization timed out after {initialization_timeout}s")
77
- raise RuntimeError(f"StreamManager initialization timed out after {initialization_timeout}s")
66
+ inst = cls()
67
+ await inst.initialize(
68
+ config_file,
69
+ servers,
70
+ server_names,
71
+ transport_type,
72
+ default_timeout=default_timeout,
73
+ initialization_timeout=initialization_timeout,
74
+ )
75
+ return inst
78
76
 
79
77
  @classmethod
80
78
  async def create_with_sse(
81
79
  cls,
82
- servers: List[Dict[str, str]],
83
- server_names: Optional[Dict[int, str]] = None,
80
+ servers: list[dict[str, str]],
81
+ server_names: dict[int, str] | None = None,
84
82
  connection_timeout: float = 10.0,
85
83
  default_timeout: float = 30.0,
86
84
  initialization_timeout: float = 60.0, # NEW
87
- ) -> "StreamManager":
85
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
86
+ ) -> StreamManager:
88
87
  """Create StreamManager with SSE transport and timeout protection."""
89
- try:
90
- inst = cls()
91
- await asyncio.wait_for(
92
- inst.initialize_with_sse(
93
- servers,
94
- server_names,
95
- connection_timeout=connection_timeout,
96
- default_timeout=default_timeout
97
- ),
98
- timeout=initialization_timeout
99
- )
100
- return inst
101
- except asyncio.TimeoutError:
102
- logger.error(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
103
- raise RuntimeError(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
88
+ inst = cls()
89
+ await inst.initialize_with_sse(
90
+ servers,
91
+ server_names,
92
+ connection_timeout=connection_timeout,
93
+ default_timeout=default_timeout,
94
+ initialization_timeout=initialization_timeout,
95
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
96
+ )
97
+ return inst
104
98
 
105
99
  @classmethod
106
100
  async def create_with_http_streamable(
107
101
  cls,
108
- servers: List[Dict[str, str]],
109
- server_names: Optional[Dict[int, str]] = None,
102
+ servers: list[dict[str, str]],
103
+ server_names: dict[int, str] | None = None,
110
104
  connection_timeout: float = 30.0,
111
105
  default_timeout: float = 30.0,
112
106
  initialization_timeout: float = 60.0, # NEW
113
- ) -> "StreamManager":
107
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
108
+ ) -> StreamManager:
114
109
  """Create StreamManager with HTTP Streamable transport and timeout protection."""
115
- try:
116
- inst = cls()
117
- await asyncio.wait_for(
118
- inst.initialize_with_http_streamable(
119
- servers,
120
- server_names,
121
- connection_timeout=connection_timeout,
122
- default_timeout=default_timeout
123
- ),
124
- timeout=initialization_timeout
125
- )
126
- return inst
127
- except asyncio.TimeoutError:
128
- logger.error(f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s")
129
- raise RuntimeError(f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s")
110
+ inst = cls()
111
+ await inst.initialize_with_http_streamable(
112
+ servers,
113
+ server_names,
114
+ connection_timeout=connection_timeout,
115
+ default_timeout=default_timeout,
116
+ initialization_timeout=initialization_timeout,
117
+ oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
118
+ )
119
+ return inst
130
120
 
131
121
  # ------------------------------------------------------------------ #
132
122
  # NEW: Context manager support for automatic cleanup #
@@ -144,8 +134,8 @@ class StreamManager:
144
134
  async def create_managed(
145
135
  cls,
146
136
  config_file: str,
147
- servers: List[str],
148
- server_names: Optional[Dict[int, str]] = None,
137
+ servers: list[str],
138
+ server_names: dict[int, str] | None = None,
149
139
  transport_type: str = "stdio",
150
140
  default_timeout: float = 30.0,
151
141
  ):
@@ -170,73 +160,113 @@ class StreamManager:
170
160
  async def initialize(
171
161
  self,
172
162
  config_file: str,
173
- servers: List[str],
174
- server_names: Optional[Dict[int, str]] = None,
163
+ servers: list[str],
164
+ server_names: dict[int, str] | None = None,
175
165
  transport_type: str = "stdio",
176
166
  default_timeout: float = 30.0,
167
+ initialization_timeout: float = 60.0,
177
168
  ) -> None:
169
+ """Initialize with graceful headers handling for all transport types."""
178
170
  if self._closed:
179
171
  raise RuntimeError("Cannot initialize a closed StreamManager")
180
-
172
+
181
173
  async with self._lock:
182
174
  self.server_names = server_names or {}
183
175
 
184
176
  for idx, server_name in enumerate(servers):
185
177
  try:
186
178
  if transport_type == "stdio":
187
- params = await load_config(config_file, server_name)
188
- transport: MCPBaseTransport = StdioTransport(params)
179
+ params, server_timeout = await load_config(config_file, server_name)
180
+ # Use per-server timeout if specified, otherwise use global default
181
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
182
+ logger.info(
183
+ f"Server '{server_name}' using timeout: {effective_timeout}s (per-server: {server_timeout}, default: {default_timeout})"
184
+ )
185
+ # Use initialization_timeout for connection_timeout since subprocess
186
+ # launch can take time (e.g., uvx downloading packages)
187
+ transport: MCPBaseTransport = StdioTransport(
188
+ params, connection_timeout=initialization_timeout, default_timeout=effective_timeout
189
+ )
189
190
  elif transport_type == "sse":
190
- logger.warning("Using SSE transport in initialize() - consider using initialize_with_sse() instead")
191
- params = await load_config(config_file, server_name)
192
-
193
- if isinstance(params, dict) and 'url' in params:
194
- sse_url = params['url']
195
- api_key = params.get('api_key')
191
+ logger.debug(
192
+ "Using SSE transport in initialize() - consider using initialize_with_sse() instead"
193
+ )
194
+ params, server_timeout = await load_config(config_file, server_name)
195
+ # Use per-server timeout if specified, otherwise use global default
196
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
197
+
198
+ if isinstance(params, dict) and "url" in params:
199
+ sse_url = params["url"]
200
+ api_key = params.get("api_key")
201
+ headers = params.get("headers", {})
196
202
  else:
197
203
  sse_url = "http://localhost:8000"
198
204
  api_key = None
199
- logger.warning(f"No URL configured for SSE transport, using default: {sse_url}")
200
-
201
- transport = SSETransport(
202
- sse_url,
203
- api_key,
204
- default_timeout=default_timeout
205
- )
205
+ headers = {}
206
+ logger.debug("No URL configured for SSE transport, using default: %s", sse_url)
207
+
208
+ # Build SSE transport with optional headers
209
+ transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": effective_timeout}
210
+ if headers:
211
+ transport_params["headers"] = headers
212
+
213
+ transport = SSETransport(**transport_params)
214
+
206
215
  elif transport_type == "http_streamable":
207
- logger.warning("Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead")
208
- params = await load_config(config_file, server_name)
209
-
210
- if isinstance(params, dict) and 'url' in params:
211
- http_url = params['url']
212
- api_key = params.get('api_key')
213
- session_id = params.get('session_id')
216
+ logger.debug(
217
+ "Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
218
+ )
219
+ params, server_timeout = await load_config(config_file, server_name)
220
+ # Use per-server timeout if specified, otherwise use global default
221
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
222
+
223
+ if isinstance(params, dict) and "url" in params:
224
+ http_url = params["url"]
225
+ api_key = params.get("api_key")
226
+ headers = params.get("headers", {})
227
+ session_id = params.get("session_id")
214
228
  else:
215
229
  http_url = "http://localhost:8000"
216
230
  api_key = None
231
+ headers = {}
217
232
  session_id = None
218
- logger.warning(f"No URL configured for HTTP Streamable transport, using default: {http_url}")
219
-
220
- transport = HTTPStreamableTransport(
221
- http_url,
222
- api_key,
223
- default_timeout=default_timeout,
224
- session_id=session_id
225
- )
233
+ logger.debug("No URL configured for HTTP Streamable transport, using default: %s", http_url)
234
+
235
+ # Build HTTP transport (headers not supported yet)
236
+ transport_params = {
237
+ "url": http_url,
238
+ "api_key": api_key,
239
+ "default_timeout": effective_timeout,
240
+ "session_id": session_id,
241
+ }
242
+ # Note: headers not added until HTTPStreamableTransport supports them
243
+ if headers:
244
+ logger.debug("Headers provided but not supported in HTTPStreamableTransport yet")
245
+
246
+ transport = HTTPStreamableTransport(**transport_params)
247
+
226
248
  else:
227
249
  logger.error("Unsupported transport type: %s", transport_type)
228
250
  continue
229
251
 
230
252
  # Initialize with timeout protection
231
- if not await asyncio.wait_for(transport.initialize(), timeout=default_timeout):
232
- logger.error("Failed to init %s", server_name)
253
+ try:
254
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
255
+ logger.warning("Failed to init %s", server_name)
256
+ continue
257
+ except TimeoutError:
258
+ logger.error("Timeout initialising %s (timeout=%ss)", server_name, initialization_timeout)
233
259
  continue
234
260
 
235
261
  self.transports[server_name] = transport
236
262
 
237
- # Ping and get tools with timeout protection
238
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
239
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
263
+ # Ping and get tools with timeout protection (use longer timeouts for slow servers)
264
+ status = (
265
+ "Up"
266
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
267
+ else "Down"
268
+ )
269
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
240
270
 
241
271
  for t in tools:
242
272
  name = t.get("name")
@@ -252,13 +282,13 @@ class StreamManager:
252
282
  "status": status,
253
283
  }
254
284
  )
255
- logger.info("Initialised %s - %d tool(s)", server_name, len(tools))
256
- except asyncio.TimeoutError:
285
+ logger.debug("Initialised %s - %d tool(s)", server_name, len(tools))
286
+ except TimeoutError:
257
287
  logger.error("Timeout initialising %s", server_name)
258
288
  except Exception as exc:
259
289
  logger.error("Error initialising %s: %s", server_name, exc)
260
290
 
261
- logger.info(
291
+ logger.debug(
262
292
  "StreamManager ready - %d server(s), %d tool(s)",
263
293
  len(self.transports),
264
294
  len(self.all_tools),
@@ -266,14 +296,17 @@ class StreamManager:
266
296
 
267
297
  async def initialize_with_sse(
268
298
  self,
269
- servers: List[Dict[str, str]],
270
- server_names: Optional[Dict[int, str]] = None,
299
+ servers: list[dict[str, str]],
300
+ server_names: dict[int, str] | None = None,
271
301
  connection_timeout: float = 10.0,
272
302
  default_timeout: float = 30.0,
303
+ initialization_timeout: float = 60.0,
304
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
273
305
  ) -> None:
306
+ """Initialize with SSE transport with optional headers support."""
274
307
  if self._closed:
275
308
  raise RuntimeError("Cannot initialize a closed StreamManager")
276
-
309
+
277
310
  async with self._lock:
278
311
  self.server_names = server_names or {}
279
312
 
@@ -283,20 +316,43 @@ class StreamManager:
283
316
  logger.error("Bad server config: %s", cfg)
284
317
  continue
285
318
  try:
286
- transport = SSETransport(
287
- url,
288
- cfg.get("api_key"),
289
- connection_timeout=connection_timeout,
290
- default_timeout=default_timeout
291
- )
292
-
293
- if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
294
- logger.error("Failed to init SSE %s", name)
319
+ # Build SSE transport parameters with optional headers
320
+ transport_params = {
321
+ "url": url,
322
+ "api_key": cfg.get("api_key"),
323
+ "connection_timeout": connection_timeout,
324
+ "default_timeout": default_timeout,
325
+ }
326
+
327
+ # Add headers if provided
328
+ headers = cfg.get("headers", {})
329
+ if headers:
330
+ logger.debug("SSE %s: Using configured headers: %s", name, list(headers.keys()))
331
+ transport_params["headers"] = headers
332
+
333
+ # Add OAuth refresh callback if provided (NEW)
334
+ if oauth_refresh_callback:
335
+ transport_params["oauth_refresh_callback"] = oauth_refresh_callback
336
+ logger.debug("SSE %s: OAuth refresh callback configured", name)
337
+
338
+ transport = SSETransport(**transport_params)
339
+
340
+ try:
341
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
342
+ logger.warning("Failed to init SSE %s", name)
343
+ continue
344
+ except TimeoutError:
345
+ logger.error("Timeout initialising SSE %s (timeout=%ss)", name, initialization_timeout)
295
346
  continue
296
347
 
297
348
  self.transports[name] = transport
298
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
299
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
349
+ # Use longer timeouts for slow servers (ping can take time after initialization)
350
+ status = (
351
+ "Up"
352
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
353
+ else "Down"
354
+ )
355
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
300
356
 
301
357
  for t in tools:
302
358
  tname = t.get("name")
@@ -304,16 +360,14 @@ class StreamManager:
304
360
  self.tool_to_server_map[tname] = name
305
361
  self.all_tools.extend(tools)
306
362
 
307
- self.server_info.append(
308
- {"id": idx, "name": name, "tools": len(tools), "status": status}
309
- )
310
- logger.info("Initialised SSE %s - %d tool(s)", name, len(tools))
311
- except asyncio.TimeoutError:
363
+ self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
364
+ logger.debug("Initialised SSE %s - %d tool(s)", name, len(tools))
365
+ except TimeoutError:
312
366
  logger.error("Timeout initialising SSE %s", name)
313
367
  except Exception as exc:
314
368
  logger.error("Error initialising SSE %s: %s", name, exc)
315
369
 
316
- logger.info(
370
+ logger.debug(
317
371
  "StreamManager ready - %d SSE server(s), %d tool(s)",
318
372
  len(self.transports),
319
373
  len(self.all_tools),
@@ -321,15 +375,19 @@ class StreamManager:
321
375
 
322
376
  async def initialize_with_http_streamable(
323
377
  self,
324
- servers: List[Dict[str, str]],
325
- server_names: Optional[Dict[int, str]] = None,
378
+ servers: list[dict[str, str]],
379
+ server_names: dict[int, str] | None = None,
326
380
  connection_timeout: float = 30.0,
327
381
  default_timeout: float = 30.0,
382
+ initialization_timeout: float = 60.0,
383
+ oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
328
384
  ) -> None:
329
- """Initialize with HTTP Streamable transport (modern MCP spec 2025-03-26)."""
385
+ """Initialize with HTTP Streamable transport with graceful headers handling."""
330
386
  if self._closed:
331
387
  raise RuntimeError("Cannot initialize a closed StreamManager")
332
-
388
+
389
+ logger.debug(f"initialize_with_http_streamable: initialization_timeout={initialization_timeout}")
390
+
333
391
  async with self._lock:
334
392
  self.server_names = server_names or {}
335
393
 
@@ -339,21 +397,48 @@ class StreamManager:
339
397
  logger.error("Bad server config: %s", cfg)
340
398
  continue
341
399
  try:
342
- transport = HTTPStreamableTransport(
343
- url,
344
- cfg.get("api_key"),
345
- connection_timeout=connection_timeout,
346
- default_timeout=default_timeout,
347
- session_id=cfg.get("session_id")
348
- )
349
-
350
- if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
351
- logger.error("Failed to init HTTP Streamable %s", name)
400
+ # Build HTTP Streamable transport parameters
401
+ transport_params = {
402
+ "url": url,
403
+ "api_key": cfg.get("api_key"),
404
+ "connection_timeout": connection_timeout,
405
+ "default_timeout": default_timeout,
406
+ "session_id": cfg.get("session_id"),
407
+ }
408
+
409
+ # Handle headers if provided
410
+ headers = cfg.get("headers", {})
411
+ if headers:
412
+ transport_params["headers"] = headers
413
+ logger.debug("HTTP Streamable %s: Custom headers configured: %s", name, list(headers.keys()))
414
+
415
+ # Add OAuth refresh callback if provided (NEW)
416
+ if oauth_refresh_callback:
417
+ transport_params["oauth_refresh_callback"] = oauth_refresh_callback
418
+ logger.debug("HTTP Streamable %s: OAuth refresh callback configured", name)
419
+
420
+ transport = HTTPStreamableTransport(**transport_params)
421
+
422
+ logger.debug(f"Calling transport.initialize() for {name} with timeout={initialization_timeout}s")
423
+ try:
424
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
425
+ logger.warning("Failed to init HTTP Streamable %s", name)
426
+ continue
427
+ except TimeoutError:
428
+ logger.error(
429
+ "Timeout initialising HTTP Streamable %s (timeout=%ss)", name, initialization_timeout
430
+ )
352
431
  continue
432
+ logger.debug(f"Successfully initialized {name}")
353
433
 
354
434
  self.transports[name] = transport
355
- status = "Up" if await asyncio.wait_for(transport.send_ping(), timeout=5.0) else "Down"
356
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
435
+ # Use longer timeouts for slow servers (ping can take time after initialization)
436
+ status = (
437
+ "Up"
438
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
439
+ else "Down"
440
+ )
441
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
357
442
 
358
443
  for t in tools:
359
444
  tname = t.get("name")
@@ -361,16 +446,14 @@ class StreamManager:
361
446
  self.tool_to_server_map[tname] = name
362
447
  self.all_tools.extend(tools)
363
448
 
364
- self.server_info.append(
365
- {"id": idx, "name": name, "tools": len(tools), "status": status}
366
- )
367
- logger.info("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
368
- except asyncio.TimeoutError:
449
+ self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
450
+ logger.debug("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
451
+ except TimeoutError:
369
452
  logger.error("Timeout initialising HTTP Streamable %s", name)
370
453
  except Exception as exc:
371
454
  logger.error("Error initialising HTTP Streamable %s: %s", name, exc)
372
455
 
373
- logger.info(
456
+ logger.debug(
374
457
  "StreamManager ready - %d HTTP Streamable server(s), %d tool(s)",
375
458
  len(self.transports),
376
459
  len(self.all_tools),
@@ -379,66 +462,64 @@ class StreamManager:
379
462
  # ------------------------------------------------------------------ #
380
463
  # queries #
381
464
  # ------------------------------------------------------------------ #
382
- def get_all_tools(self) -> List[Dict[str, Any]]:
465
+ def get_all_tools(self) -> list[dict[str, Any]]:
383
466
  return self.all_tools
384
467
 
385
- def get_server_for_tool(self, tool_name: str) -> Optional[str]:
468
+ def get_server_for_tool(self, tool_name: str) -> str | None:
386
469
  return self.tool_to_server_map.get(tool_name)
387
470
 
388
- def get_server_info(self) -> List[Dict[str, Any]]:
471
+ def get_server_info(self) -> list[dict[str, Any]]:
389
472
  return self.server_info
390
-
391
- async def list_tools(self, server_name: str) -> List[Dict[str, Any]]:
473
+
474
+ async def list_tools(self, server_name: str) -> list[dict[str, Any]]:
392
475
  """List all tools available from a specific server."""
393
476
  if self._closed:
394
477
  logger.warning("Cannot list tools: StreamManager is closed")
395
478
  return []
396
-
479
+
397
480
  if server_name not in self.transports:
398
- logger.error(f"Server '{server_name}' not found in transports")
481
+ logger.error("Server '%s' not found in transports", server_name)
399
482
  return []
400
-
483
+
401
484
  transport = self.transports[server_name]
402
-
485
+
403
486
  try:
404
- tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
405
- logger.debug(f"Found {len(tools)} tools for server {server_name}")
487
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
488
+ logger.debug("Found %d tools for server %s", len(tools), server_name)
406
489
  return tools
407
- except asyncio.TimeoutError:
408
- logger.error(f"Timeout listing tools for server {server_name}")
490
+ except TimeoutError:
491
+ logger.error("Timeout listing tools for server %s", server_name)
409
492
  return []
410
493
  except Exception as e:
411
- logger.error(f"Error listing tools for server {server_name}: {e}")
494
+ logger.error("Error listing tools for server %s: %s", server_name, e)
412
495
  return []
413
496
 
414
497
  # ------------------------------------------------------------------ #
415
498
  # EXTRA HELPERS - ping / resources / prompts #
416
499
  # ------------------------------------------------------------------ #
417
- async def ping_servers(self) -> List[Dict[str, Any]]:
500
+ async def ping_servers(self) -> list[dict[str, Any]]:
418
501
  if self._closed:
419
502
  return []
420
-
503
+
421
504
  async def _ping_one(name: str, tr: MCPBaseTransport):
422
505
  try:
423
- ok = await asyncio.wait_for(tr.send_ping(), timeout=5.0)
506
+ ok = await asyncio.wait_for(tr.send_ping(), timeout=self.timeout_config.quick)
424
507
  except Exception:
425
508
  ok = False
426
509
  return {"server": name, "ok": ok}
427
510
 
428
511
  return await asyncio.gather(*(_ping_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
429
512
 
430
- async def list_resources(self) -> List[Dict[str, Any]]:
513
+ async def list_resources(self) -> list[dict[str, Any]]:
431
514
  if self._closed:
432
515
  return []
433
-
434
- out: List[Dict[str, Any]] = []
516
+
517
+ out: list[dict[str, Any]] = []
435
518
 
436
519
  async def _one(name: str, tr: MCPBaseTransport):
437
520
  try:
438
- res = await asyncio.wait_for(tr.list_resources(), timeout=10.0)
439
- resources = (
440
- res.get("resources", []) if isinstance(res, dict) else res
441
- )
521
+ res = await asyncio.wait_for(tr.list_resources(), timeout=self.timeout_config.operation)
522
+ resources = res.get("resources", []) if isinstance(res, dict) else res
442
523
  for item in resources:
443
524
  item = dict(item)
444
525
  item["server"] = name
@@ -449,15 +530,15 @@ class StreamManager:
449
530
  await asyncio.gather(*(_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
450
531
  return out
451
532
 
452
- async def list_prompts(self) -> List[Dict[str, Any]]:
533
+ async def list_prompts(self) -> list[dict[str, Any]]:
453
534
  if self._closed:
454
535
  return []
455
-
456
- out: List[Dict[str, Any]] = []
536
+
537
+ out: list[dict[str, Any]] = []
457
538
 
458
539
  async def _one(name: str, tr: MCPBaseTransport):
459
540
  try:
460
- res = await asyncio.wait_for(tr.list_prompts(), timeout=10.0)
541
+ res = await asyncio.wait_for(tr.list_prompts(), timeout=self.timeout_config.operation)
461
542
  prompts = res.get("prompts", []) if isinstance(res, dict) else res
462
543
  for item in prompts:
463
544
  item = dict(item)
@@ -475,45 +556,40 @@ class StreamManager:
475
556
  async def call_tool(
476
557
  self,
477
558
  tool_name: str,
478
- arguments: Dict[str, Any],
479
- server_name: Optional[str] = None,
480
- timeout: Optional[float] = None,
481
- ) -> Dict[str, Any]:
559
+ arguments: dict[str, Any],
560
+ server_name: str | None = None,
561
+ timeout: float | None = None,
562
+ ) -> dict[str, Any]:
482
563
  """Call a tool on the appropriate server with timeout support."""
483
564
  if self._closed:
484
565
  return {
485
566
  "isError": True,
486
567
  "error": "StreamManager is closed",
487
568
  }
488
-
569
+
489
570
  server_name = server_name or self.get_server_for_tool(tool_name)
490
571
  if not server_name or server_name not in self.transports:
491
572
  return {
492
573
  "isError": True,
493
574
  "error": f"No server found for tool: {tool_name}",
494
575
  }
495
-
576
+
496
577
  transport = self.transports[server_name]
497
-
578
+
498
579
  if timeout is not None:
499
580
  logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
500
581
  try:
501
- if hasattr(transport, 'call_tool'):
582
+ if hasattr(transport, "call_tool"):
502
583
  import inspect
584
+
503
585
  sig = inspect.signature(transport.call_tool)
504
- if 'timeout' in sig.parameters:
586
+ if "timeout" in sig.parameters:
505
587
  return await transport.call_tool(tool_name, arguments, timeout=timeout)
506
588
  else:
507
- return await asyncio.wait_for(
508
- transport.call_tool(tool_name, arguments),
509
- timeout=timeout
510
- )
589
+ return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
511
590
  else:
512
- return await asyncio.wait_for(
513
- transport.call_tool(tool_name, arguments),
514
- timeout=timeout
515
- )
516
- except asyncio.TimeoutError:
591
+ return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
592
+ except TimeoutError:
517
593
  logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
518
594
  return {
519
595
  "isError": True,
@@ -521,28 +597,28 @@ class StreamManager:
521
597
  }
522
598
  else:
523
599
  return await transport.call_tool(tool_name, arguments)
524
-
600
+
525
601
  # ------------------------------------------------------------------ #
526
602
  # ENHANCED shutdown with robust error handling #
527
603
  # ------------------------------------------------------------------ #
528
604
  async def close(self) -> None:
529
605
  """
530
606
  Close all transports safely with enhanced error handling.
531
-
607
+
532
608
  ENHANCED: Uses asyncio.shield() to protect critical cleanup and
533
609
  provides multiple fallback strategies for different failure modes.
534
610
  """
535
611
  if self._closed:
536
612
  logger.debug("StreamManager already closed")
537
613
  return
538
-
614
+
539
615
  if not self.transports:
540
616
  logger.debug("No transports to close")
541
617
  self._closed = True
542
618
  return
543
-
544
- logger.debug(f"Closing {len(self.transports)} transports...")
545
-
619
+
620
+ logger.debug("Closing %d transports...", len(self.transports))
621
+
546
622
  try:
547
623
  # Use shield to protect the cleanup operation from cancellation
548
624
  await asyncio.shield(self._do_close_all_transports())
@@ -551,7 +627,7 @@ class StreamManager:
551
627
  logger.debug("Close operation cancelled, performing synchronous cleanup")
552
628
  self._sync_cleanup()
553
629
  except Exception as e:
554
- logger.debug(f"Error during close: {e}")
630
+ logger.debug("Error during close: %s", e)
555
631
  self._sync_cleanup()
556
632
  finally:
557
633
  self._closed = True
@@ -560,99 +636,91 @@ class StreamManager:
560
636
  """Protected cleanup implementation with multiple strategies."""
561
637
  close_results = []
562
638
  transport_items = list(self.transports.items())
563
-
639
+
564
640
  # Strategy 1: Try concurrent close with timeout
565
641
  try:
566
642
  await self._concurrent_close(transport_items, close_results)
567
643
  except Exception as e:
568
- logger.debug(f"Concurrent close failed: {e}, falling back to sequential close")
644
+ logger.debug("Concurrent close failed: %s, falling back to sequential close", e)
569
645
  # Strategy 2: Fall back to sequential close
570
646
  await self._sequential_close(transport_items, close_results)
571
-
647
+
572
648
  # Always clean up state
573
649
  self._cleanup_state()
574
-
650
+
575
651
  # Log summary
576
652
  if close_results:
577
653
  successful_closes = sum(1 for _, success, _ in close_results if success)
578
- logger.debug(f"Transport cleanup: {successful_closes}/{len(close_results)} closed successfully")
654
+ logger.debug("Transport cleanup: %d/%d closed successfully", successful_closes, len(close_results))
579
655
 
580
- async def _concurrent_close(self, transport_items: List[Tuple[str, MCPBaseTransport]], close_results: List) -> None:
656
+ async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
581
657
  """Try to close all transports concurrently."""
582
658
  close_tasks = []
583
659
  for name, transport in transport_items:
584
- task = asyncio.create_task(
585
- self._close_single_transport(name, transport),
586
- name=f"close_{name}"
587
- )
660
+ task = asyncio.create_task(self._close_single_transport(name, transport), name=f"close_{name}")
588
661
  close_tasks.append((name, task))
589
-
662
+
590
663
  # Wait for all tasks with a reasonable timeout
591
664
  if close_tasks:
592
665
  try:
593
666
  results = await asyncio.wait_for(
594
- asyncio.gather(
595
- *[task for _, task in close_tasks],
596
- return_exceptions=True
597
- ),
598
- timeout=self._shutdown_timeout
667
+ asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
668
+ timeout=self.timeout_config.shutdown,
599
669
  )
600
-
670
+
601
671
  # Process results
602
672
  for i, (name, _) in enumerate(close_tasks):
603
673
  result = results[i] if i < len(results) else None
604
674
  if isinstance(result, Exception):
605
- logger.debug(f"Transport {name} close failed: {result}")
675
+ logger.debug("Transport %s close failed: %s", name, result)
606
676
  close_results.append((name, False, str(result)))
607
677
  else:
608
- logger.debug(f"Transport {name} closed successfully")
678
+ logger.debug("Transport %s closed successfully", name)
609
679
  close_results.append((name, True, None))
610
-
611
- except asyncio.TimeoutError:
680
+
681
+ except TimeoutError:
612
682
  # Cancel any remaining tasks
613
683
  for name, task in close_tasks:
614
684
  if not task.done():
615
685
  task.cancel()
616
686
  close_results.append((name, False, "timeout"))
617
-
687
+
618
688
  # Brief wait for cancellations to complete
619
- try:
689
+ with contextlib.suppress(TimeoutError):
620
690
  await asyncio.wait_for(
621
691
  asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
622
- timeout=0.5
692
+ timeout=self.timeout_config.shutdown,
623
693
  )
624
- except asyncio.TimeoutError:
625
- pass # Some tasks may not cancel cleanly
626
694
 
627
- async def _sequential_close(self, transport_items: List[Tuple[str, MCPBaseTransport]], close_results: List) -> None:
695
+ async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
628
696
  """Close transports one by one as fallback."""
629
697
  for name, transport in transport_items:
630
698
  try:
631
699
  await asyncio.wait_for(
632
700
  self._close_single_transport(name, transport),
633
- timeout=0.5 # Short timeout per transport
701
+ timeout=self.timeout_config.shutdown,
634
702
  )
635
- logger.debug(f"Closed transport: {name}")
703
+ logger.debug("Closed transport: %s", name)
636
704
  close_results.append((name, True, None))
637
- except asyncio.TimeoutError:
638
- logger.debug(f"Transport {name} close timed out (normal during shutdown)")
705
+ except TimeoutError:
706
+ logger.debug("Transport %s close timed out (normal during shutdown)", name)
639
707
  close_results.append((name, False, "timeout"))
640
708
  except asyncio.CancelledError:
641
- logger.debug(f"Transport {name} close cancelled during event loop shutdown")
709
+ logger.debug("Transport %s close cancelled during event loop shutdown", name)
642
710
  close_results.append((name, False, "cancelled"))
643
711
  except Exception as e:
644
- logger.debug(f"Error closing transport {name}: {e}")
712
+ logger.debug("Error closing transport %s: %s", name, e)
645
713
  close_results.append((name, False, str(e)))
646
714
 
647
715
  async def _close_single_transport(self, name: str, transport: MCPBaseTransport) -> None:
648
716
  """Close a single transport with error handling."""
649
717
  try:
650
- if hasattr(transport, 'close') and callable(transport.close):
718
+ if hasattr(transport, "close") and callable(transport.close):
651
719
  await transport.close()
652
720
  else:
653
- logger.debug(f"Transport {name} has no close method")
721
+ logger.debug("Transport %s has no close method", name)
654
722
  except Exception as e:
655
- logger.debug(f"Error closing transport {name}: {e}")
723
+ logger.debug("Error closing transport %s: %s", name, e)
656
724
  raise
657
725
 
658
726
  def _sync_cleanup(self) -> None:
@@ -660,9 +728,9 @@ class StreamManager:
660
728
  try:
661
729
  transport_count = len(self.transports)
662
730
  self._cleanup_state()
663
- logger.debug(f"Synchronous cleanup completed for {transport_count} transports")
731
+ logger.debug("Synchronous cleanup completed for %d transports", transport_count)
664
732
  except Exception as e:
665
- logger.debug(f"Error during synchronous cleanup: {e}")
733
+ logger.debug("Error during synchronous cleanup: %s", e)
666
734
 
667
735
  def _cleanup_state(self) -> None:
668
736
  """Clean up internal state synchronously."""
@@ -673,17 +741,17 @@ class StreamManager:
673
741
  self.all_tools.clear()
674
742
  self.server_names.clear()
675
743
  except Exception as e:
676
- logger.debug(f"Error during state cleanup: {e}")
744
+ logger.debug("Error during state cleanup: %s", e)
677
745
 
678
746
  # ------------------------------------------------------------------ #
679
747
  # backwards-compat: streams helper #
680
748
  # ------------------------------------------------------------------ #
681
- def get_streams(self) -> List[Tuple[Any, Any]]:
749
+ def get_streams(self) -> list[tuple[Any, Any]]:
682
750
  """Return a list of (read_stream, write_stream) tuples for all transports."""
683
751
  if self._closed:
684
752
  return []
685
-
686
- pairs: List[Tuple[Any, Any]] = []
753
+
754
+ pairs: list[tuple[Any, Any]] = []
687
755
 
688
756
  for tr in self.transports.values():
689
757
  if hasattr(tr, "get_streams") and callable(tr.get_streams):
@@ -698,7 +766,7 @@ class StreamManager:
698
766
  return pairs
699
767
 
700
768
  @property
701
- def streams(self) -> List[Tuple[Any, Any]]:
769
+ def streams(self) -> list[tuple[Any, Any]]:
702
770
  """Convenience alias for get_streams()."""
703
771
  return self.get_streams()
704
772
 
@@ -713,34 +781,23 @@ class StreamManager:
713
781
  """Get the number of active transports."""
714
782
  return len(self.transports)
715
783
 
716
- async def health_check(self) -> Dict[str, Any]:
784
+ async def health_check(self) -> dict[str, Any]:
717
785
  """Perform a health check on all transports."""
718
786
  if self._closed:
719
787
  return {"status": "closed", "transports": {}}
720
-
721
- health_info = {
722
- "status": "active",
723
- "transport_count": len(self.transports),
724
- "transports": {}
725
- }
726
-
788
+
789
+ health_info = {"status": "active", "transport_count": len(self.transports), "transports": {}}
790
+
727
791
  for name, transport in self.transports.items():
728
792
  try:
729
- ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=5.0)
793
+ ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.quick)
730
794
  health_info["transports"][name] = {
731
795
  "status": "healthy" if ping_ok else "unhealthy",
732
- "ping_success": ping_ok
733
- }
734
- except asyncio.TimeoutError:
735
- health_info["transports"][name] = {
736
- "status": "timeout",
737
- "ping_success": False
796
+ "ping_success": ping_ok,
738
797
  }
798
+ except TimeoutError:
799
+ health_info["transports"][name] = {"status": "timeout", "ping_success": False}
739
800
  except Exception as e:
740
- health_info["transports"][name] = {
741
- "status": "error",
742
- "ping_success": False,
743
- "error": str(e)
744
- }
745
-
746
- return health_info
801
+ health_info["transports"][name] = {"status": "error", "ping_success": False, "error": str(e)}
802
+
803
+ return health_info