chuk-tool-processor 0.6.11__py3-none-any.whl → 0.6.13__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic. Click here for more details.

Files changed (56)
  1. chuk_tool_processor/core/__init__.py +1 -1
  2. chuk_tool_processor/core/exceptions.py +10 -4
  3. chuk_tool_processor/core/processor.py +97 -97
  4. chuk_tool_processor/execution/strategies/inprocess_strategy.py +142 -150
  5. chuk_tool_processor/execution/strategies/subprocess_strategy.py +200 -205
  6. chuk_tool_processor/execution/tool_executor.py +82 -84
  7. chuk_tool_processor/execution/wrappers/caching.py +102 -103
  8. chuk_tool_processor/execution/wrappers/rate_limiting.py +45 -42
  9. chuk_tool_processor/execution/wrappers/retry.py +23 -25
  10. chuk_tool_processor/logging/__init__.py +23 -17
  11. chuk_tool_processor/logging/context.py +40 -45
  12. chuk_tool_processor/logging/formatter.py +22 -21
  13. chuk_tool_processor/logging/helpers.py +24 -38
  14. chuk_tool_processor/logging/metrics.py +11 -13
  15. chuk_tool_processor/mcp/__init__.py +8 -12
  16. chuk_tool_processor/mcp/mcp_tool.py +153 -109
  17. chuk_tool_processor/mcp/register_mcp_tools.py +17 -17
  18. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +11 -13
  19. chuk_tool_processor/mcp/setup_mcp_sse.py +11 -13
  20. chuk_tool_processor/mcp/setup_mcp_stdio.py +7 -9
  21. chuk_tool_processor/mcp/stream_manager.py +168 -204
  22. chuk_tool_processor/mcp/transport/__init__.py +4 -4
  23. chuk_tool_processor/mcp/transport/base_transport.py +43 -58
  24. chuk_tool_processor/mcp/transport/http_streamable_transport.py +145 -163
  25. chuk_tool_processor/mcp/transport/sse_transport.py +266 -252
  26. chuk_tool_processor/mcp/transport/stdio_transport.py +171 -189
  27. chuk_tool_processor/models/__init__.py +1 -1
  28. chuk_tool_processor/models/execution_strategy.py +16 -21
  29. chuk_tool_processor/models/streaming_tool.py +28 -25
  30. chuk_tool_processor/models/tool_call.py +19 -34
  31. chuk_tool_processor/models/tool_export_mixin.py +22 -8
  32. chuk_tool_processor/models/tool_result.py +40 -77
  33. chuk_tool_processor/models/validated_tool.py +14 -16
  34. chuk_tool_processor/plugins/__init__.py +1 -1
  35. chuk_tool_processor/plugins/discovery.py +10 -10
  36. chuk_tool_processor/plugins/parsers/__init__.py +1 -1
  37. chuk_tool_processor/plugins/parsers/base.py +1 -2
  38. chuk_tool_processor/plugins/parsers/function_call_tool.py +13 -8
  39. chuk_tool_processor/plugins/parsers/json_tool.py +4 -3
  40. chuk_tool_processor/plugins/parsers/openai_tool.py +12 -7
  41. chuk_tool_processor/plugins/parsers/xml_tool.py +4 -4
  42. chuk_tool_processor/registry/__init__.py +12 -12
  43. chuk_tool_processor/registry/auto_register.py +22 -30
  44. chuk_tool_processor/registry/decorators.py +127 -129
  45. chuk_tool_processor/registry/interface.py +26 -23
  46. chuk_tool_processor/registry/metadata.py +27 -22
  47. chuk_tool_processor/registry/provider.py +17 -18
  48. chuk_tool_processor/registry/providers/__init__.py +16 -19
  49. chuk_tool_processor/registry/providers/memory.py +18 -25
  50. chuk_tool_processor/registry/tool_export.py +42 -51
  51. chuk_tool_processor/utils/validation.py +15 -16
  52. {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/METADATA +1 -1
  53. chuk_tool_processor-0.6.13.dist-info/RECORD +60 -0
  54. chuk_tool_processor-0.6.11.dist-info/RECORD +0 -60
  55. {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/WHEEL +0 -0
  56. {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/top_level.txt +0 -0
chuk_tool_processor/mcp/stream_manager.py
@@ -2,23 +2,26 @@
2
2
  """
3
3
  StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling and headers support
4
4
  """
5
+
5
6
  from __future__ import annotations
6
7
 
7
8
  import asyncio
8
- from typing import Any, Dict, List, Optional, Tuple
9
+ import contextlib
9
10
  from contextlib import asynccontextmanager
11
+ from typing import Any
10
12
 
11
13
  # --------------------------------------------------------------------------- #
12
14
  # CHUK imports #
13
15
  # --------------------------------------------------------------------------- #
14
- from chuk_mcp.config import load_config
16
+ from chuk_mcp.config import load_config # type: ignore[import-untyped]
17
+
18
+ from chuk_tool_processor.logging import get_logger
15
19
  from chuk_tool_processor.mcp.transport import (
20
+ HTTPStreamableTransport,
16
21
  MCPBaseTransport,
17
- StdioTransport,
18
22
  SSETransport,
19
- HTTPStreamableTransport,
23
+ StdioTransport,
20
24
  )
21
- from chuk_tool_processor.logging import get_logger
22
25
 
23
26
  logger = get_logger("chuk_tool_processor.mcp.stream_manager")
24
27
 
@@ -26,9 +29,9 @@ logger = get_logger("chuk_tool_processor.mcp.stream_manager")
26
29
  class StreamManager:
27
30
  """
28
31
  Manager for MCP server streams with support for multiple transport types.
29
-
32
+
30
33
  Enhanced with robust shutdown handling and proper headers support.
31
-
34
+
32
35
  Updated to support the latest transports:
33
36
  - STDIO (process-based)
34
37
  - SSE (Server-Sent Events) with headers support
@@ -36,11 +39,11 @@ class StreamManager:
36
39
  """
37
40
 
38
41
  def __init__(self) -> None:
39
- self.transports: Dict[str, MCPBaseTransport] = {}
40
- self.server_info: List[Dict[str, Any]] = []
41
- self.tool_to_server_map: Dict[str, str] = {}
42
- self.server_names: Dict[int, str] = {}
43
- self.all_tools: List[Dict[str, Any]] = []
42
+ self.transports: dict[str, MCPBaseTransport] = {}
43
+ self.server_info: list[dict[str, Any]] = []
44
+ self.tool_to_server_map: dict[str, str] = {}
45
+ self.server_names: dict[int, str] = {}
46
+ self.all_tools: list[dict[str, Any]] = []
44
47
  self._lock = asyncio.Lock()
45
48
  self._closed = False # Track if we've been closed
46
49
  self._shutdown_timeout = 2.0 # Maximum time to spend on shutdown
@@ -52,81 +55,71 @@ class StreamManager:
52
55
  async def create(
53
56
  cls,
54
57
  config_file: str,
55
- servers: List[str],
56
- server_names: Optional[Dict[int, str]] = None,
58
+ servers: list[str],
59
+ server_names: dict[int, str] | None = None,
57
60
  transport_type: str = "stdio",
58
61
  default_timeout: float = 30.0,
59
62
  initialization_timeout: float = 60.0, # NEW: Timeout for entire initialization
60
- ) -> "StreamManager":
63
+ ) -> StreamManager:
61
64
  """Create StreamManager with timeout protection."""
62
65
  try:
63
66
  inst = cls()
64
67
  await asyncio.wait_for(
65
- inst.initialize(
66
- config_file,
67
- servers,
68
- server_names,
69
- transport_type,
70
- default_timeout=default_timeout
71
- ),
72
- timeout=initialization_timeout
68
+ inst.initialize(config_file, servers, server_names, transport_type, default_timeout=default_timeout),
69
+ timeout=initialization_timeout,
73
70
  )
74
71
  return inst
75
- except asyncio.TimeoutError:
72
+ except TimeoutError:
76
73
  logger.error("StreamManager initialization timed out after %ss", initialization_timeout)
77
74
  raise RuntimeError(f"StreamManager initialization timed out after {initialization_timeout}s")
78
75
 
79
76
  @classmethod
80
77
  async def create_with_sse(
81
78
  cls,
82
- servers: List[Dict[str, str]],
83
- server_names: Optional[Dict[int, str]] = None,
79
+ servers: list[dict[str, str]],
80
+ server_names: dict[int, str] | None = None,
84
81
  connection_timeout: float = 10.0,
85
82
  default_timeout: float = 30.0,
86
83
  initialization_timeout: float = 60.0, # NEW
87
- ) -> "StreamManager":
84
+ ) -> StreamManager:
88
85
  """Create StreamManager with SSE transport and timeout protection."""
89
86
  try:
90
87
  inst = cls()
91
88
  await asyncio.wait_for(
92
89
  inst.initialize_with_sse(
93
- servers,
94
- server_names,
95
- connection_timeout=connection_timeout,
96
- default_timeout=default_timeout
90
+ servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
97
91
  ),
98
- timeout=initialization_timeout
92
+ timeout=initialization_timeout,
99
93
  )
100
94
  return inst
101
- except asyncio.TimeoutError:
95
+ except TimeoutError:
102
96
  logger.error("SSE StreamManager initialization timed out after %ss", initialization_timeout)
103
97
  raise RuntimeError(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
104
98
 
105
99
  @classmethod
106
100
  async def create_with_http_streamable(
107
101
  cls,
108
- servers: List[Dict[str, str]],
109
- server_names: Optional[Dict[int, str]] = None,
102
+ servers: list[dict[str, str]],
103
+ server_names: dict[int, str] | None = None,
110
104
  connection_timeout: float = 30.0,
111
105
  default_timeout: float = 30.0,
112
106
  initialization_timeout: float = 60.0, # NEW
113
- ) -> "StreamManager":
107
+ ) -> StreamManager:
114
108
  """Create StreamManager with HTTP Streamable transport and timeout protection."""
115
109
  try:
116
110
  inst = cls()
117
111
  await asyncio.wait_for(
118
112
  inst.initialize_with_http_streamable(
119
- servers,
120
- server_names,
121
- connection_timeout=connection_timeout,
122
- default_timeout=default_timeout
113
+ servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
123
114
  ),
124
- timeout=initialization_timeout
115
+ timeout=initialization_timeout,
125
116
  )
126
117
  return inst
127
- except asyncio.TimeoutError:
118
+ except TimeoutError:
128
119
  logger.error("HTTP Streamable StreamManager initialization timed out after %ss", initialization_timeout)
129
- raise RuntimeError(f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s")
120
+ raise RuntimeError(
121
+ f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s"
122
+ )
130
123
 
131
124
  # ------------------------------------------------------------------ #
132
125
  # NEW: Context manager support for automatic cleanup #
@@ -144,8 +137,8 @@ class StreamManager:
144
137
  async def create_managed(
145
138
  cls,
146
139
  config_file: str,
147
- servers: List[str],
148
- server_names: Optional[Dict[int, str]] = None,
140
+ servers: list[str],
141
+ server_names: dict[int, str] | None = None,
149
142
  transport_type: str = "stdio",
150
143
  default_timeout: float = 30.0,
151
144
  ):
@@ -170,15 +163,15 @@ class StreamManager:
170
163
  async def initialize(
171
164
  self,
172
165
  config_file: str,
173
- servers: List[str],
174
- server_names: Optional[Dict[int, str]] = None,
166
+ servers: list[str],
167
+ server_names: dict[int, str] | None = None,
175
168
  transport_type: str = "stdio",
176
169
  default_timeout: float = 30.0,
177
170
  ) -> None:
178
171
  """Initialize with graceful headers handling for all transport types."""
179
172
  if self._closed:
180
173
  raise RuntimeError("Cannot initialize a closed StreamManager")
181
-
174
+
182
175
  async with self._lock:
183
176
  self.server_names = server_names or {}
184
177
 
@@ -188,59 +181,61 @@ class StreamManager:
188
181
  params = await load_config(config_file, server_name)
189
182
  transport: MCPBaseTransport = StdioTransport(params)
190
183
  elif transport_type == "sse":
191
- logger.warning("Using SSE transport in initialize() - consider using initialize_with_sse() instead")
184
+ logger.warning(
185
+ "Using SSE transport in initialize() - consider using initialize_with_sse() instead"
186
+ )
192
187
  params = await load_config(config_file, server_name)
193
-
194
- if isinstance(params, dict) and 'url' in params:
195
- sse_url = params['url']
196
- api_key = params.get('api_key')
197
- headers = params.get('headers', {})
188
+
189
+ if isinstance(params, dict) and "url" in params:
190
+ sse_url = params["url"]
191
+ api_key = params.get("api_key")
192
+ headers = params.get("headers", {})
198
193
  else:
199
194
  sse_url = "http://localhost:8000"
200
195
  api_key = None
201
196
  headers = {}
202
197
  logger.warning("No URL configured for SSE transport, using default: %s", sse_url)
203
-
198
+
204
199
  # Build SSE transport with optional headers
205
- transport_params = {
206
- 'url': sse_url,
207
- 'api_key': api_key,
208
- 'default_timeout': default_timeout
209
- }
200
+ transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": default_timeout}
210
201
  if headers:
211
- transport_params['headers'] = headers
212
-
202
+ transport_params["headers"] = headers
203
+
213
204
  transport = SSETransport(**transport_params)
214
-
205
+
215
206
  elif transport_type == "http_streamable":
216
- logger.warning("Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead")
207
+ logger.warning(
208
+ "Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
209
+ )
217
210
  params = await load_config(config_file, server_name)
218
-
219
- if isinstance(params, dict) and 'url' in params:
220
- http_url = params['url']
221
- api_key = params.get('api_key')
222
- headers = params.get('headers', {})
223
- session_id = params.get('session_id')
211
+
212
+ if isinstance(params, dict) and "url" in params:
213
+ http_url = params["url"]
214
+ api_key = params.get("api_key")
215
+ headers = params.get("headers", {})
216
+ session_id = params.get("session_id")
224
217
  else:
225
218
  http_url = "http://localhost:8000"
226
219
  api_key = None
227
220
  headers = {}
228
221
  session_id = None
229
- logger.warning("No URL configured for HTTP Streamable transport, using default: %s", http_url)
230
-
222
+ logger.warning(
223
+ "No URL configured for HTTP Streamable transport, using default: %s", http_url
224
+ )
225
+
231
226
  # Build HTTP transport (headers not supported yet)
232
227
  transport_params = {
233
- 'url': http_url,
234
- 'api_key': api_key,
235
- 'default_timeout': default_timeout,
236
- 'session_id': session_id
228
+ "url": http_url,
229
+ "api_key": api_key,
230
+ "default_timeout": default_timeout,
231
+ "session_id": session_id,
237
232
  }
238
233
  # Note: headers not added until HTTPStreamableTransport supports them
239
234
  if headers:
240
235
  logger.debug("Headers provided but not supported in HTTPStreamableTransport yet")
241
-
236
+
242
237
  transport = HTTPStreamableTransport(**transport_params)
243
-
238
+
244
239
  else:
245
240
  logger.error("Unsupported transport type: %s", transport_type)
246
241
  continue
@@ -271,7 +266,7 @@ class StreamManager:
271
266
  }
272
267
  )
273
268
  logger.debug("Initialised %s - %d tool(s)", server_name, len(tools))
274
- except asyncio.TimeoutError:
269
+ except TimeoutError:
275
270
  logger.error("Timeout initialising %s", server_name)
276
271
  except Exception as exc:
277
272
  logger.error("Error initialising %s: %s", server_name, exc)
@@ -284,15 +279,15 @@ class StreamManager:
284
279
 
285
280
  async def initialize_with_sse(
286
281
  self,
287
- servers: List[Dict[str, str]],
288
- server_names: Optional[Dict[int, str]] = None,
282
+ servers: list[dict[str, str]],
283
+ server_names: dict[int, str] | None = None,
289
284
  connection_timeout: float = 10.0,
290
285
  default_timeout: float = 30.0,
291
286
  ) -> None:
292
287
  """Initialize with SSE transport with optional headers support."""
293
288
  if self._closed:
294
289
  raise RuntimeError("Cannot initialize a closed StreamManager")
295
-
290
+
296
291
  async with self._lock:
297
292
  self.server_names = server_names or {}
298
293
 
@@ -304,20 +299,20 @@ class StreamManager:
304
299
  try:
305
300
  # Build SSE transport parameters with optional headers
306
301
  transport_params = {
307
- 'url': url,
308
- 'api_key': cfg.get("api_key"),
309
- 'connection_timeout': connection_timeout,
310
- 'default_timeout': default_timeout
302
+ "url": url,
303
+ "api_key": cfg.get("api_key"),
304
+ "connection_timeout": connection_timeout,
305
+ "default_timeout": default_timeout,
311
306
  }
312
-
307
+
313
308
  # Add headers if provided
314
309
  headers = cfg.get("headers", {})
315
310
  if headers:
316
311
  logger.debug("SSE %s: Using configured headers: %s", name, list(headers.keys()))
317
- transport_params['headers'] = headers
318
-
312
+ transport_params["headers"] = headers
313
+
319
314
  transport = SSETransport(**transport_params)
320
-
315
+
321
316
  if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
322
317
  logger.error("Failed to init SSE %s", name)
323
318
  continue
@@ -332,11 +327,9 @@ class StreamManager:
332
327
  self.tool_to_server_map[tname] = name
333
328
  self.all_tools.extend(tools)
334
329
 
335
- self.server_info.append(
336
- {"id": idx, "name": name, "tools": len(tools), "status": status}
337
- )
330
+ self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
338
331
  logger.debug("Initialised SSE %s - %d tool(s)", name, len(tools))
339
- except asyncio.TimeoutError:
332
+ except TimeoutError:
340
333
  logger.error("Timeout initialising SSE %s", name)
341
334
  except Exception as exc:
342
335
  logger.error("Error initialising SSE %s: %s", name, exc)
@@ -349,15 +342,15 @@ class StreamManager:
349
342
 
350
343
  async def initialize_with_http_streamable(
351
344
  self,
352
- servers: List[Dict[str, str]],
353
- server_names: Optional[Dict[int, str]] = None,
345
+ servers: list[dict[str, str]],
346
+ server_names: dict[int, str] | None = None,
354
347
  connection_timeout: float = 30.0,
355
348
  default_timeout: float = 30.0,
356
349
  ) -> None:
357
350
  """Initialize with HTTP Streamable transport with graceful headers handling."""
358
351
  if self._closed:
359
352
  raise RuntimeError("Cannot initialize a closed StreamManager")
360
-
353
+
361
354
  async with self._lock:
362
355
  self.server_names = server_names or {}
363
356
 
@@ -369,22 +362,22 @@ class StreamManager:
369
362
  try:
370
363
  # Build HTTP Streamable transport parameters
371
364
  transport_params = {
372
- 'url': url,
373
- 'api_key': cfg.get("api_key"),
374
- 'connection_timeout': connection_timeout,
375
- 'default_timeout': default_timeout,
376
- 'session_id': cfg.get("session_id")
365
+ "url": url,
366
+ "api_key": cfg.get("api_key"),
367
+ "connection_timeout": connection_timeout,
368
+ "default_timeout": default_timeout,
369
+ "session_id": cfg.get("session_id"),
377
370
  }
378
-
371
+
379
372
  # Handle headers if provided (for future HTTPStreamableTransport support)
380
373
  headers = cfg.get("headers", {})
381
374
  if headers:
382
375
  logger.debug("HTTP Streamable %s: Headers provided but not yet supported in transport", name)
383
376
  # TODO: Add headers support when HTTPStreamableTransport is updated
384
377
  # transport_params['headers'] = headers
385
-
378
+
386
379
  transport = HTTPStreamableTransport(**transport_params)
387
-
380
+
388
381
  if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
389
382
  logger.error("Failed to init HTTP Streamable %s", name)
390
383
  continue
@@ -399,11 +392,9 @@ class StreamManager:
399
392
  self.tool_to_server_map[tname] = name
400
393
  self.all_tools.extend(tools)
401
394
 
402
- self.server_info.append(
403
- {"id": idx, "name": name, "tools": len(tools), "status": status}
404
- )
395
+ self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
405
396
  logger.debug("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
406
- except asyncio.TimeoutError:
397
+ except TimeoutError:
407
398
  logger.error("Timeout initialising HTTP Streamable %s", name)
408
399
  except Exception as exc:
409
400
  logger.error("Error initialising HTTP Streamable %s: %s", name, exc)
@@ -417,32 +408,32 @@ class StreamManager:
417
408
  # ------------------------------------------------------------------ #
418
409
  # queries #
419
410
  # ------------------------------------------------------------------ #
420
- def get_all_tools(self) -> List[Dict[str, Any]]:
411
+ def get_all_tools(self) -> list[dict[str, Any]]:
421
412
  return self.all_tools
422
413
 
423
- def get_server_for_tool(self, tool_name: str) -> Optional[str]:
414
+ def get_server_for_tool(self, tool_name: str) -> str | None:
424
415
  return self.tool_to_server_map.get(tool_name)
425
416
 
426
- def get_server_info(self) -> List[Dict[str, Any]]:
417
+ def get_server_info(self) -> list[dict[str, Any]]:
427
418
  return self.server_info
428
-
429
- async def list_tools(self, server_name: str) -> List[Dict[str, Any]]:
419
+
420
+ async def list_tools(self, server_name: str) -> list[dict[str, Any]]:
430
421
  """List all tools available from a specific server."""
431
422
  if self._closed:
432
423
  logger.warning("Cannot list tools: StreamManager is closed")
433
424
  return []
434
-
425
+
435
426
  if server_name not in self.transports:
436
427
  logger.error("Server '%s' not found in transports", server_name)
437
428
  return []
438
-
429
+
439
430
  transport = self.transports[server_name]
440
-
431
+
441
432
  try:
442
433
  tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
443
434
  logger.debug("Found %d tools for server %s", len(tools), server_name)
444
435
  return tools
445
- except asyncio.TimeoutError:
436
+ except TimeoutError:
446
437
  logger.error("Timeout listing tools for server %s", server_name)
447
438
  return []
448
439
  except Exception as e:
@@ -452,10 +443,10 @@ class StreamManager:
452
443
  # ------------------------------------------------------------------ #
453
444
  # EXTRA HELPERS - ping / resources / prompts #
454
445
  # ------------------------------------------------------------------ #
455
- async def ping_servers(self) -> List[Dict[str, Any]]:
446
+ async def ping_servers(self) -> list[dict[str, Any]]:
456
447
  if self._closed:
457
448
  return []
458
-
449
+
459
450
  async def _ping_one(name: str, tr: MCPBaseTransport):
460
451
  try:
461
452
  ok = await asyncio.wait_for(tr.send_ping(), timeout=5.0)
@@ -465,18 +456,16 @@ class StreamManager:
465
456
 
466
457
  return await asyncio.gather(*(_ping_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
467
458
 
468
- async def list_resources(self) -> List[Dict[str, Any]]:
459
+ async def list_resources(self) -> list[dict[str, Any]]:
469
460
  if self._closed:
470
461
  return []
471
-
472
- out: List[Dict[str, Any]] = []
462
+
463
+ out: list[dict[str, Any]] = []
473
464
 
474
465
  async def _one(name: str, tr: MCPBaseTransport):
475
466
  try:
476
467
  res = await asyncio.wait_for(tr.list_resources(), timeout=10.0)
477
- resources = (
478
- res.get("resources", []) if isinstance(res, dict) else res
479
- )
468
+ resources = res.get("resources", []) if isinstance(res, dict) else res
480
469
  for item in resources:
481
470
  item = dict(item)
482
471
  item["server"] = name
@@ -487,11 +476,11 @@ class StreamManager:
487
476
  await asyncio.gather(*(_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
488
477
  return out
489
478
 
490
- async def list_prompts(self) -> List[Dict[str, Any]]:
479
+ async def list_prompts(self) -> list[dict[str, Any]]:
491
480
  if self._closed:
492
481
  return []
493
-
494
- out: List[Dict[str, Any]] = []
482
+
483
+ out: list[dict[str, Any]] = []
495
484
 
496
485
  async def _one(name: str, tr: MCPBaseTransport):
497
486
  try:
@@ -513,45 +502,40 @@ class StreamManager:
513
502
  async def call_tool(
514
503
  self,
515
504
  tool_name: str,
516
- arguments: Dict[str, Any],
517
- server_name: Optional[str] = None,
518
- timeout: Optional[float] = None,
519
- ) -> Dict[str, Any]:
505
+ arguments: dict[str, Any],
506
+ server_name: str | None = None,
507
+ timeout: float | None = None,
508
+ ) -> dict[str, Any]:
520
509
  """Call a tool on the appropriate server with timeout support."""
521
510
  if self._closed:
522
511
  return {
523
512
  "isError": True,
524
513
  "error": "StreamManager is closed",
525
514
  }
526
-
515
+
527
516
  server_name = server_name or self.get_server_for_tool(tool_name)
528
517
  if not server_name or server_name not in self.transports:
529
518
  return {
530
519
  "isError": True,
531
520
  "error": f"No server found for tool: {tool_name}",
532
521
  }
533
-
522
+
534
523
  transport = self.transports[server_name]
535
-
524
+
536
525
  if timeout is not None:
537
526
  logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
538
527
  try:
539
- if hasattr(transport, 'call_tool'):
528
+ if hasattr(transport, "call_tool"):
540
529
  import inspect
530
+
541
531
  sig = inspect.signature(transport.call_tool)
542
- if 'timeout' in sig.parameters:
532
+ if "timeout" in sig.parameters:
543
533
  return await transport.call_tool(tool_name, arguments, timeout=timeout)
544
534
  else:
545
- return await asyncio.wait_for(
546
- transport.call_tool(tool_name, arguments),
547
- timeout=timeout
548
- )
535
+ return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
549
536
  else:
550
- return await asyncio.wait_for(
551
- transport.call_tool(tool_name, arguments),
552
- timeout=timeout
553
- )
554
- except asyncio.TimeoutError:
537
+ return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
538
+ except TimeoutError:
555
539
  logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
556
540
  return {
557
541
  "isError": True,
@@ -559,28 +543,28 @@ class StreamManager:
559
543
  }
560
544
  else:
561
545
  return await transport.call_tool(tool_name, arguments)
562
-
546
+
563
547
  # ------------------------------------------------------------------ #
564
548
  # ENHANCED shutdown with robust error handling #
565
549
  # ------------------------------------------------------------------ #
566
550
  async def close(self) -> None:
567
551
  """
568
552
  Close all transports safely with enhanced error handling.
569
-
553
+
570
554
  ENHANCED: Uses asyncio.shield() to protect critical cleanup and
571
555
  provides multiple fallback strategies for different failure modes.
572
556
  """
573
557
  if self._closed:
574
558
  logger.debug("StreamManager already closed")
575
559
  return
576
-
560
+
577
561
  if not self.transports:
578
562
  logger.debug("No transports to close")
579
563
  self._closed = True
580
564
  return
581
-
565
+
582
566
  logger.debug("Closing %d transports...", len(self.transports))
583
-
567
+
584
568
  try:
585
569
  # Use shield to protect the cleanup operation from cancellation
586
570
  await asyncio.shield(self._do_close_all_transports())
@@ -598,7 +582,7 @@ class StreamManager:
598
582
  """Protected cleanup implementation with multiple strategies."""
599
583
  close_results = []
600
584
  transport_items = list(self.transports.items())
601
-
585
+
602
586
  # Strategy 1: Try concurrent close with timeout
603
587
  try:
604
588
  await self._concurrent_close(transport_items, close_results)
@@ -606,36 +590,30 @@ class StreamManager:
606
590
  logger.debug("Concurrent close failed: %s, falling back to sequential close", e)
607
591
  # Strategy 2: Fall back to sequential close
608
592
  await self._sequential_close(transport_items, close_results)
609
-
593
+
610
594
  # Always clean up state
611
595
  self._cleanup_state()
612
-
596
+
613
597
  # Log summary
614
598
  if close_results:
615
599
  successful_closes = sum(1 for _, success, _ in close_results if success)
616
600
  logger.debug("Transport cleanup: %d/%d closed successfully", successful_closes, len(close_results))
617
601
 
618
- async def _concurrent_close(self, transport_items: List[Tuple[str, MCPBaseTransport]], close_results: List) -> None:
602
+ async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
619
603
  """Try to close all transports concurrently."""
620
604
  close_tasks = []
621
605
  for name, transport in transport_items:
622
- task = asyncio.create_task(
623
- self._close_single_transport(name, transport),
624
- name=f"close_{name}"
625
- )
606
+ task = asyncio.create_task(self._close_single_transport(name, transport), name=f"close_{name}")
626
607
  close_tasks.append((name, task))
627
-
608
+
628
609
  # Wait for all tasks with a reasonable timeout
629
610
  if close_tasks:
630
611
  try:
631
612
  results = await asyncio.wait_for(
632
- asyncio.gather(
633
- *[task for _, task in close_tasks],
634
- return_exceptions=True
635
- ),
636
- timeout=self._shutdown_timeout
613
+ asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
614
+ timeout=self._shutdown_timeout,
637
615
  )
638
-
616
+
639
617
  # Process results
640
618
  for i, (name, _) in enumerate(close_tasks):
641
619
  result = results[i] if i < len(results) else None
@@ -645,34 +623,31 @@ class StreamManager:
645
623
  else:
646
624
  logger.debug("Transport %s closed successfully", name)
647
625
  close_results.append((name, True, None))
648
-
649
- except asyncio.TimeoutError:
626
+
627
+ except TimeoutError:
650
628
  # Cancel any remaining tasks
651
629
  for name, task in close_tasks:
652
630
  if not task.done():
653
631
  task.cancel()
654
632
  close_results.append((name, False, "timeout"))
655
-
633
+
656
634
  # Brief wait for cancellations to complete
657
- try:
635
+ with contextlib.suppress(TimeoutError):
658
636
  await asyncio.wait_for(
659
- asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
660
- timeout=0.5
637
+ asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True), timeout=0.5
661
638
  )
662
- except asyncio.TimeoutError:
663
- pass # Some tasks may not cancel cleanly
664
639
 
665
- async def _sequential_close(self, transport_items: List[Tuple[str, MCPBaseTransport]], close_results: List) -> None:
640
+ async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
666
641
  """Close transports one by one as fallback."""
667
642
  for name, transport in transport_items:
668
643
  try:
669
644
  await asyncio.wait_for(
670
645
  self._close_single_transport(name, transport),
671
- timeout=0.5 # Short timeout per transport
646
+ timeout=0.5, # Short timeout per transport
672
647
  )
673
648
  logger.debug("Closed transport: %s", name)
674
649
  close_results.append((name, True, None))
675
- except asyncio.TimeoutError:
650
+ except TimeoutError:
676
651
  logger.debug("Transport %s close timed out (normal during shutdown)", name)
677
652
  close_results.append((name, False, "timeout"))
678
653
  except asyncio.CancelledError:
@@ -685,7 +660,7 @@ class StreamManager:
685
660
  async def _close_single_transport(self, name: str, transport: MCPBaseTransport) -> None:
686
661
  """Close a single transport with error handling."""
687
662
  try:
688
- if hasattr(transport, 'close') and callable(transport.close):
663
+ if hasattr(transport, "close") and callable(transport.close):
689
664
  await transport.close()
690
665
  else:
691
666
  logger.debug("Transport %s has no close method", name)
@@ -716,12 +691,12 @@ class StreamManager:
716
691
  # ------------------------------------------------------------------ #
717
692
  # backwards-compat: streams helper #
718
693
  # ------------------------------------------------------------------ #
719
- def get_streams(self) -> List[Tuple[Any, Any]]:
694
+ def get_streams(self) -> list[tuple[Any, Any]]:
720
695
  """Return a list of (read_stream, write_stream) tuples for all transports."""
721
696
  if self._closed:
722
697
  return []
723
-
724
- pairs: List[Tuple[Any, Any]] = []
698
+
699
+ pairs: list[tuple[Any, Any]] = []
725
700
 
726
701
  for tr in self.transports.values():
727
702
  if hasattr(tr, "get_streams") and callable(tr.get_streams):
@@ -736,7 +711,7 @@ class StreamManager:
736
711
  return pairs
737
712
 
738
713
  @property
739
- def streams(self) -> List[Tuple[Any, Any]]:
714
+ def streams(self) -> list[tuple[Any, Any]]:
740
715
  """Convenience alias for get_streams()."""
741
716
  return self.get_streams()
742
717
 
@@ -751,34 +726,23 @@ class StreamManager:
751
726
  """Get the number of active transports."""
752
727
  return len(self.transports)
753
728
 
754
- async def health_check(self) -> Dict[str, Any]:
729
+ async def health_check(self) -> dict[str, Any]:
755
730
  """Perform a health check on all transports."""
756
731
  if self._closed:
757
732
  return {"status": "closed", "transports": {}}
758
-
759
- health_info = {
760
- "status": "active",
761
- "transport_count": len(self.transports),
762
- "transports": {}
763
- }
764
-
733
+
734
+ health_info = {"status": "active", "transport_count": len(self.transports), "transports": {}}
735
+
765
736
  for name, transport in self.transports.items():
766
737
  try:
767
738
  ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=5.0)
768
739
  health_info["transports"][name] = {
769
740
  "status": "healthy" if ping_ok else "unhealthy",
770
- "ping_success": ping_ok
771
- }
772
- except asyncio.TimeoutError:
773
- health_info["transports"][name] = {
774
- "status": "timeout",
775
- "ping_success": False
741
+ "ping_success": ping_ok,
776
742
  }
743
+ except TimeoutError:
744
+ health_info["transports"][name] = {"status": "timeout", "ping_success": False}
777
745
  except Exception as e:
778
- health_info["transports"][name] = {
779
- "status": "error",
780
- "ping_success": False,
781
- "error": str(e)
782
- }
783
-
784
- return health_info
746
+ health_info["transports"][name] = {"status": "error", "ping_success": False, "error": str(e)}
747
+
748
+ return health_info