chuk-tool-processor 0.5.1 → 0.5.4 (py3-none-any wheels)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor may be problematic. See the registry's advisory page for details (the original link was not preserved in this copy).

@@ -1,7 +1,6 @@
1
- #!/usr/bin/env python
2
1
  # chuk_tool_processor/mcp/stream_manager.py
3
2
  """
4
- StreamManager for CHUK Tool Processor.
3
+ StreamManager for CHUK Tool Processor - Updated with HTTP Streamable support
5
4
  """
6
5
  from __future__ import annotations
7
6
 
@@ -16,6 +15,7 @@ from chuk_tool_processor.mcp.transport import (
16
15
  MCPBaseTransport,
17
16
  StdioTransport,
18
17
  SSETransport,
18
+ HTTPStreamableTransport,
19
19
  )
20
20
  from chuk_tool_processor.logging import get_logger
21
21
 
@@ -25,11 +25,13 @@ logger = get_logger("chuk_tool_processor.mcp.stream_manager")
25
25
  class StreamManager:
26
26
  """
27
27
  Manager for MCP server streams with support for multiple transport types.
28
+
29
+ Updated to support the latest transports:
30
+ - STDIO (process-based)
31
+ - SSE (Server-Sent Events)
32
+ - HTTP Streamable (modern replacement for SSE, spec 2025-03-26)
28
33
  """
29
34
 
30
- # ------------------------------------------------------------------ #
31
- # construction #
32
- # ------------------------------------------------------------------ #
33
35
  def __init__(self) -> None:
34
36
  self.transports: Dict[str, MCPBaseTransport] = {}
35
37
  self.server_info: List[Dict[str, Any]] = []
@@ -37,7 +39,6 @@ class StreamManager:
37
39
  self.server_names: Dict[int, str] = {}
38
40
  self.all_tools: List[Dict[str, Any]] = []
39
41
  self._lock = asyncio.Lock()
40
- self._close_tasks: List[asyncio.Task] = [] # Track cleanup tasks
41
42
 
42
43
  # ------------------------------------------------------------------ #
43
44
  # factory helpers #
@@ -49,7 +50,7 @@ class StreamManager:
49
50
  servers: List[str],
50
51
  server_names: Optional[Dict[int, str]] = None,
51
52
  transport_type: str = "stdio",
52
- default_timeout: float = 30.0, # ADD: For consistency
53
+ default_timeout: float = 30.0,
53
54
  ) -> "StreamManager":
54
55
  inst = cls()
55
56
  await inst.initialize(
@@ -57,7 +58,7 @@ class StreamManager:
57
58
  servers,
58
59
  server_names,
59
60
  transport_type,
60
- default_timeout=default_timeout # PASS THROUGH
61
+ default_timeout=default_timeout
61
62
  )
62
63
  return inst
63
64
 
@@ -66,20 +67,38 @@ class StreamManager:
66
67
  cls,
67
68
  servers: List[Dict[str, str]],
68
69
  server_names: Optional[Dict[int, str]] = None,
69
- connection_timeout: float = 10.0, # ADD: For SSE connection setup
70
- default_timeout: float = 30.0, # ADD: For tool execution
70
+ connection_timeout: float = 10.0,
71
+ default_timeout: float = 30.0,
71
72
  ) -> "StreamManager":
72
73
  inst = cls()
73
74
  await inst.initialize_with_sse(
74
75
  servers,
75
76
  server_names,
76
- connection_timeout=connection_timeout, # PASS THROUGH
77
- default_timeout=default_timeout # PASS THROUGH
77
+ connection_timeout=connection_timeout,
78
+ default_timeout=default_timeout
79
+ )
80
+ return inst
81
+
82
+ @classmethod
83
+ async def create_with_http_streamable(
84
+ cls,
85
+ servers: List[Dict[str, str]],
86
+ server_names: Optional[Dict[int, str]] = None,
87
+ connection_timeout: float = 30.0,
88
+ default_timeout: float = 30.0,
89
+ ) -> "StreamManager":
90
+ """Create StreamManager with HTTP Streamable transport."""
91
+ inst = cls()
92
+ await inst.initialize_with_http_streamable(
93
+ servers,
94
+ server_names,
95
+ connection_timeout=connection_timeout,
96
+ default_timeout=default_timeout
78
97
  )
79
98
  return inst
80
99
 
81
100
  # ------------------------------------------------------------------ #
82
- # initialisation - stdio / sse #
101
+ # initialisation - stdio / sse / http_streamable #
83
102
  # ------------------------------------------------------------------ #
84
103
  async def initialize(
85
104
  self,
@@ -87,7 +106,7 @@ class StreamManager:
87
106
  servers: List[str],
88
107
  server_names: Optional[Dict[int, str]] = None,
89
108
  transport_type: str = "stdio",
90
- default_timeout: float = 30.0, # ADD: For consistency
109
+ default_timeout: float = 30.0,
91
110
  ) -> None:
92
111
  async with self._lock:
93
112
  self.server_names = server_names or {}
@@ -98,11 +117,9 @@ class StreamManager:
98
117
  params = await load_config(config_file, server_name)
99
118
  transport: MCPBaseTransport = StdioTransport(params)
100
119
  elif transport_type == "sse":
101
- # WARNING: For SSE transport, prefer using create_with_sse() instead
102
- # This is a fallback for backward compatibility
103
120
  logger.warning("Using SSE transport in initialize() - consider using initialize_with_sse() instead")
121
+ params = await load_config(config_file, server_name)
104
122
 
105
- # Try to extract URL from params or use localhost as fallback
106
123
  if isinstance(params, dict) and 'url' in params:
107
124
  sse_url = params['url']
108
125
  api_key = params.get('api_key')
@@ -116,6 +133,26 @@ class StreamManager:
116
133
  api_key,
117
134
  default_timeout=default_timeout
118
135
  )
136
+ elif transport_type == "http_streamable":
137
+ logger.warning("Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead")
138
+ params = await load_config(config_file, server_name)
139
+
140
+ if isinstance(params, dict) and 'url' in params:
141
+ http_url = params['url']
142
+ api_key = params.get('api_key')
143
+ session_id = params.get('session_id')
144
+ else:
145
+ http_url = "http://localhost:8000"
146
+ api_key = None
147
+ session_id = None
148
+ logger.warning(f"No URL configured for HTTP Streamable transport, using default: {http_url}")
149
+
150
+ transport = HTTPStreamableTransport(
151
+ http_url,
152
+ api_key,
153
+ default_timeout=default_timeout,
154
+ session_id=session_id
155
+ )
119
156
  else:
120
157
  logger.error("Unsupported transport type: %s", transport_type)
121
158
  continue
@@ -124,10 +161,8 @@ class StreamManager:
124
161
  logger.error("Failed to init %s", server_name)
125
162
  continue
126
163
 
127
- # store transport
128
164
  self.transports[server_name] = transport
129
165
 
130
- # ping + gather tools
131
166
  status = "Up" if await transport.send_ping() else "Down"
132
167
  tools = await transport.get_tools()
133
168
 
@@ -146,7 +181,7 @@ class StreamManager:
146
181
  }
147
182
  )
148
183
  logger.info("Initialised %s - %d tool(s)", server_name, len(tools))
149
- except Exception as exc: # noqa: BLE001
184
+ except Exception as exc:
150
185
  logger.error("Error initialising %s: %s", server_name, exc)
151
186
 
152
187
  logger.info(
@@ -159,8 +194,8 @@ class StreamManager:
159
194
  self,
160
195
  servers: List[Dict[str, str]],
161
196
  server_names: Optional[Dict[int, str]] = None,
162
- connection_timeout: float = 10.0, # ADD: For SSE connection setup
163
- default_timeout: float = 30.0, # ADD: For tool execution
197
+ connection_timeout: float = 10.0,
198
+ default_timeout: float = 30.0,
164
199
  ) -> None:
165
200
  async with self._lock:
166
201
  self.server_names = server_names or {}
@@ -171,12 +206,11 @@ class StreamManager:
171
206
  logger.error("Bad server config: %s", cfg)
172
207
  continue
173
208
  try:
174
- # FIXED: Pass timeout parameters to SSETransport
175
209
  transport = SSETransport(
176
210
  url,
177
211
  cfg.get("api_key"),
178
- connection_timeout=connection_timeout, # ADD THIS
179
- default_timeout=default_timeout # ADD THIS
212
+ connection_timeout=connection_timeout,
213
+ default_timeout=default_timeout
180
214
  )
181
215
 
182
216
  if not await transport.initialize():
@@ -197,7 +231,7 @@ class StreamManager:
197
231
  {"id": idx, "name": name, "tools": len(tools), "status": status}
198
232
  )
199
233
  logger.info("Initialised SSE %s - %d tool(s)", name, len(tools))
200
- except Exception as exc: # noqa: BLE001
234
+ except Exception as exc:
201
235
  logger.error("Error initialising SSE %s: %s", name, exc)
202
236
 
203
237
  logger.info(
@@ -206,6 +240,58 @@ class StreamManager:
206
240
  len(self.all_tools),
207
241
  )
208
242
 
243
+ async def initialize_with_http_streamable(
244
+ self,
245
+ servers: List[Dict[str, str]],
246
+ server_names: Optional[Dict[int, str]] = None,
247
+ connection_timeout: float = 30.0,
248
+ default_timeout: float = 30.0,
249
+ ) -> None:
250
+ """Initialize with HTTP Streamable transport (modern MCP spec 2025-03-26)."""
251
+ async with self._lock:
252
+ self.server_names = server_names or {}
253
+
254
+ for idx, cfg in enumerate(servers):
255
+ name, url = cfg.get("name"), cfg.get("url")
256
+ if not (name and url):
257
+ logger.error("Bad server config: %s", cfg)
258
+ continue
259
+ try:
260
+ transport = HTTPStreamableTransport(
261
+ url,
262
+ cfg.get("api_key"),
263
+ connection_timeout=connection_timeout,
264
+ default_timeout=default_timeout,
265
+ session_id=cfg.get("session_id")
266
+ )
267
+
268
+ if not await transport.initialize():
269
+ logger.error("Failed to init HTTP Streamable %s", name)
270
+ continue
271
+
272
+ self.transports[name] = transport
273
+ status = "Up" if await transport.send_ping() else "Down"
274
+ tools = await transport.get_tools()
275
+
276
+ for t in tools:
277
+ tname = t.get("name")
278
+ if tname:
279
+ self.tool_to_server_map[tname] = name
280
+ self.all_tools.extend(tools)
281
+
282
+ self.server_info.append(
283
+ {"id": idx, "name": name, "tools": len(tools), "status": status}
284
+ )
285
+ logger.info("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
286
+ except Exception as exc:
287
+ logger.error("Error initialising HTTP Streamable %s: %s", name, exc)
288
+
289
+ logger.info(
290
+ "StreamManager ready - %d HTTP Streamable server(s), %d tool(s)",
291
+ len(self.transports),
292
+ len(self.all_tools),
293
+ )
294
+
209
295
  # ------------------------------------------------------------------ #
210
296
  # queries #
211
297
  # ------------------------------------------------------------------ #
@@ -219,26 +305,14 @@ class StreamManager:
219
305
  return self.server_info
220
306
 
221
307
  async def list_tools(self, server_name: str) -> List[Dict[str, Any]]:
222
- """
223
- List all tools available from a specific server.
224
-
225
- This method is required by ProxyServerManager for proper tool discovery.
226
-
227
- Args:
228
- server_name: Name of the server to query
229
-
230
- Returns:
231
- List of tool definitions from the server
232
- """
308
+ """List all tools available from a specific server."""
233
309
  if server_name not in self.transports:
234
310
  logger.error(f"Server '{server_name}' not found in transports")
235
311
  return []
236
312
 
237
- # Get the transport for this server
238
313
  transport = self.transports[server_name]
239
314
 
240
315
  try:
241
- # Call the get_tools method on the transport
242
316
  tools = await transport.get_tools()
243
317
  logger.debug(f"Found {len(tools)} tools for server {server_name}")
244
318
  return tools
@@ -253,7 +327,7 @@ class StreamManager:
253
327
  async def _ping_one(name: str, tr: MCPBaseTransport):
254
328
  try:
255
329
  ok = await tr.send_ping()
256
- except Exception: # pragma: no cover
330
+ except Exception:
257
331
  ok = False
258
332
  return {"server": name, "ok": ok}
259
333
 
@@ -263,11 +337,8 @@ class StreamManager:
263
337
  out: List[Dict[str, Any]] = []
264
338
 
265
339
  async def _one(name: str, tr: MCPBaseTransport):
266
- if not hasattr(tr, "list_resources"):
267
- return
268
340
  try:
269
- res = await tr.list_resources() # type: ignore[attr-defined]
270
- # accept either {"resources": [...]} **or** a plain list
341
+ res = await tr.list_resources()
271
342
  resources = (
272
343
  res.get("resources", []) if isinstance(res, dict) else res
273
344
  )
@@ -285,10 +356,8 @@ class StreamManager:
285
356
  out: List[Dict[str, Any]] = []
286
357
 
287
358
  async def _one(name: str, tr: MCPBaseTransport):
288
- if not hasattr(tr, "list_prompts"):
289
- return
290
359
  try:
291
- res = await tr.list_prompts() # type: ignore[attr-defined]
360
+ res = await tr.list_prompts()
292
361
  prompts = res.get("prompts", []) if isinstance(res, dict) else res
293
362
  for item in prompts:
294
363
  item = dict(item)
@@ -308,23 +377,11 @@ class StreamManager:
308
377
  tool_name: str,
309
378
  arguments: Dict[str, Any],
310
379
  server_name: Optional[str] = None,
311
- timeout: Optional[float] = None, # Timeout parameter already exists
380
+ timeout: Optional[float] = None,
312
381
  ) -> Dict[str, Any]:
313
- """
314
- Call a tool on the appropriate server with timeout support.
315
-
316
- Args:
317
- tool_name: Name of the tool to call
318
- arguments: Arguments to pass to the tool
319
- server_name: Optional server name (auto-detected if not provided)
320
- timeout: Optional timeout for the call
321
-
322
- Returns:
323
- Dictionary containing the tool result or error
324
- """
382
+ """Call a tool on the appropriate server with timeout support."""
325
383
  server_name = server_name or self.get_server_for_tool(tool_name)
326
384
  if not server_name or server_name not in self.transports:
327
- # wording kept exactly for unit-test expectation
328
385
  return {
329
386
  "isError": True,
330
387
  "error": f"No server found for tool: {tool_name}",
@@ -332,25 +389,20 @@ class StreamManager:
332
389
 
333
390
  transport = self.transports[server_name]
334
391
 
335
- # Apply timeout if specified
336
392
  if timeout is not None:
337
393
  logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
338
394
  try:
339
- # ENHANCED: Pass timeout to transport.call_tool if it supports it
340
395
  if hasattr(transport, 'call_tool'):
341
396
  import inspect
342
397
  sig = inspect.signature(transport.call_tool)
343
398
  if 'timeout' in sig.parameters:
344
- # Transport supports timeout parameter - pass it through
345
399
  return await transport.call_tool(tool_name, arguments, timeout=timeout)
346
400
  else:
347
- # Transport doesn't support timeout - use asyncio.wait_for wrapper
348
401
  return await asyncio.wait_for(
349
402
  transport.call_tool(tool_name, arguments),
350
403
  timeout=timeout
351
404
  )
352
405
  else:
353
- # Fallback to asyncio.wait_for
354
406
  return await asyncio.wait_for(
355
407
  transport.call_tool(tool_name, arguments),
356
408
  timeout=timeout
@@ -362,94 +414,69 @@ class StreamManager:
362
414
  "error": f"Tool call timed out after {timeout}s",
363
415
  }
364
416
  else:
365
- # No timeout specified, call directly
366
417
  return await transport.call_tool(tool_name, arguments)
367
418
 
368
419
  # ------------------------------------------------------------------ #
369
- # shutdown - PROPERLY FIXED VERSION #
420
+ # shutdown - FIXED VERSION to prevent cancel scope errors #
370
421
  # ------------------------------------------------------------------ #
371
422
  async def close(self) -> None:
372
- """
373
- Properly close all transports with graceful handling of cancellation.
374
- """
423
+ """Close all transports safely without cancel scope errors."""
375
424
  if not self.transports:
425
+ logger.debug("No transports to close")
376
426
  return
377
427
 
378
- # Cancel any existing close tasks
379
- for task in self._close_tasks:
380
- if not task.done():
381
- task.cancel()
382
- self._close_tasks.clear()
428
+ logger.debug(f"Closing {len(self.transports)} transports...")
383
429
 
384
- # Create close tasks for all transports
385
- close_tasks = []
386
- for name, transport in list(self.transports.items()):
387
- try:
388
- task = asyncio.create_task(
389
- self._close_transport(name, transport),
390
- name=f"close_{name}"
391
- )
392
- close_tasks.append(task)
393
- self._close_tasks.append(task)
394
- except Exception as e:
395
- logger.debug(f"Error creating close task for {name}: {e}")
430
+ # Strategy: Close transports sequentially with short timeouts
431
+ close_results = []
432
+ transport_items = list(self.transports.items())
396
433
 
397
- # Wait for all close tasks with a timeout
398
- if close_tasks:
434
+ for name, transport in transport_items:
399
435
  try:
400
- # Give transports a reasonable time to close gracefully
401
- await asyncio.wait_for(
402
- asyncio.gather(*close_tasks, return_exceptions=True),
403
- timeout=2.0
404
- )
405
- except asyncio.TimeoutError:
406
- # Cancel any still-running tasks
407
- for task in close_tasks:
408
- if not task.done():
409
- task.cancel()
410
- # Brief wait for cancellation to take effect
411
- await asyncio.gather(*close_tasks, return_exceptions=True)
412
- except asyncio.CancelledError:
413
- # This is expected during event loop shutdown
414
- logger.debug("Close operation cancelled during shutdown")
436
+ try:
437
+ await asyncio.wait_for(transport.close(), timeout=0.2)
438
+ logger.debug(f"Closed transport: {name}")
439
+ close_results.append((name, True, None))
440
+ except asyncio.TimeoutError:
441
+ logger.debug(f"Transport {name} close timed out (normal during shutdown)")
442
+ close_results.append((name, False, "timeout"))
443
+ except asyncio.CancelledError:
444
+ logger.debug(f"Transport {name} close cancelled during event loop shutdown")
445
+ close_results.append((name, False, "cancelled"))
446
+
415
447
  except Exception as e:
416
- logger.debug(f"Unexpected error during close: {e}")
448
+ logger.debug(f"Error closing transport {name}: {e}")
449
+ close_results.append((name, False, str(e)))
417
450
 
418
451
  # Clean up state
419
452
  self._cleanup_state()
420
-
421
- async def _close_transport(self, name: str, transport: MCPBaseTransport) -> None:
422
- """Close a single transport with error handling."""
453
+
454
+ # Log summary
455
+ successful_closes = sum(1 for _, success, _ in close_results if success)
456
+ if close_results:
457
+ logger.debug(f"Transport cleanup: {successful_closes}/{len(close_results)} closed successfully")
458
+
459
+ def _cleanup_state(self) -> None:
460
+ """Clean up internal state synchronously."""
423
461
  try:
424
- await transport.close()
425
- logger.debug(f"Closed transport: {name}")
426
- except asyncio.CancelledError:
427
- # Re-raise cancellation
428
- raise
462
+ self.transports.clear()
463
+ self.server_info.clear()
464
+ self.tool_to_server_map.clear()
465
+ self.all_tools.clear()
466
+ self.server_names.clear()
429
467
  except Exception as e:
430
- logger.debug(f"Error closing transport {name}: {e}")
431
-
432
- def _cleanup_state(self) -> None:
433
- """Clean up internal state (synchronous)."""
434
- self.transports.clear()
435
- self.server_info.clear()
436
- self.tool_to_server_map.clear()
437
- self.all_tools.clear()
438
- self._close_tasks.clear()
468
+ logger.debug(f"Error during state cleanup: {e}")
439
469
 
440
470
  # ------------------------------------------------------------------ #
441
471
  # backwards-compat: streams helper #
442
472
  # ------------------------------------------------------------------ #
443
473
  def get_streams(self) -> List[Tuple[Any, Any]]:
444
- """
445
- Return a list of ``(read_stream, write_stream)`` tuples for **all**
446
- transports. Older CLI commands rely on this helper.
447
- """
474
+ """Return a list of (read_stream, write_stream) tuples for all transports."""
448
475
  pairs: List[Tuple[Any, Any]] = []
449
476
 
450
477
  for tr in self.transports.values():
451
478
  if hasattr(tr, "get_streams") and callable(tr.get_streams):
452
- pairs.extend(tr.get_streams()) # type: ignore[arg-type]
479
+ pairs.extend(tr.get_streams())
453
480
  continue
454
481
 
455
482
  rd = getattr(tr, "read_stream", None)
@@ -459,7 +486,7 @@ class StreamManager:
459
486
 
460
487
  return pairs
461
488
 
462
- # convenience alias
463
489
  @property
464
- def streams(self) -> List[Tuple[Any, Any]]: # pragma: no cover
490
+ def streams(self) -> List[Tuple[Any, Any]]:
491
+ """Convenience alias for get_streams()."""
465
492
  return self.get_streams()
@@ -6,9 +6,11 @@ MCP transport implementations.
6
6
  from .base_transport import MCPBaseTransport
7
7
  from .stdio_transport import StdioTransport
8
8
  from .sse_transport import SSETransport
9
+ from .http_streamable_transport import HTTPStreamableTransport
9
10
 
10
11
  __all__ = [
11
12
  "MCPBaseTransport",
12
- "StdioTransport",
13
- "SSETransport"
13
+ "StdioTransport",
14
+ "SSETransport",
15
+ "HTTPStreamableTransport"
14
16
  ]