chuk-tool-processor 0.3__py3-none-any.whl → 0.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic. Click here for more details.

@@ -1,10 +1,12 @@
1
1
  #!/usr/bin/env python
2
2
  # chuk_tool_processor/execution/tool_executor.py
3
3
  """
4
- Modified ToolExecutor with true streaming support and duplicate prevention.
4
+ Modified ToolExecutor with true streaming support and proper timeout handling.
5
5
 
6
6
  This version accesses streaming tools' stream_execute method directly
7
7
  to enable true item-by-item streaming behavior, while preventing duplicates.
8
+
9
+ FIXED: Proper timeout precedence - respects strategy's default_timeout when available.
8
10
  """
9
11
  import asyncio
10
12
  from datetime import datetime, timezone
@@ -25,12 +27,14 @@ class ToolExecutor:
25
27
 
26
28
  This class provides a unified interface for executing tools using different
27
29
  execution strategies, with special support for streaming tools.
30
+
31
+ FIXED: Proper timeout handling that respects strategy's default_timeout.
28
32
  """
29
33
 
30
34
  def __init__(
31
35
  self,
32
36
  registry: Optional[ToolRegistryInterface] = None,
33
- default_timeout: float = 10.0,
37
+ default_timeout: Optional[float] = None, # Made optional to allow strategy precedence
34
38
  strategy: Optional[ExecutionStrategy] = None,
35
39
  strategy_kwargs: Optional[Dict[str, Any]] = None,
36
40
  ) -> None:
@@ -39,12 +43,12 @@ class ToolExecutor:
39
43
 
40
44
  Args:
41
45
  registry: Tool registry to use for tool lookups
42
- default_timeout: Default timeout for tool execution
46
+ default_timeout: Default timeout for tool execution (optional)
47
+ If None, will use strategy's default_timeout if available
43
48
  strategy: Optional execution strategy (default: InProcessStrategy)
44
49
  strategy_kwargs: Additional arguments for the strategy constructor
45
50
  """
46
51
  self.registry = registry
47
- self.default_timeout = default_timeout
48
52
 
49
53
  # Create strategy if not provided
50
54
  if strategy is None:
@@ -55,13 +59,31 @@ class ToolExecutor:
55
59
  raise ValueError("Registry must be provided if strategy is not")
56
60
 
57
61
  strategy_kwargs = strategy_kwargs or {}
62
+
63
+ # If no default_timeout specified, use a reasonable default for the strategy
64
+ strategy_timeout = default_timeout if default_timeout is not None else 30.0
65
+
58
66
  strategy = _inprocess_mod.InProcessStrategy(
59
67
  registry,
60
- default_timeout=default_timeout,
68
+ default_timeout=strategy_timeout,
61
69
  **strategy_kwargs,
62
70
  )
63
71
 
64
72
  self.strategy = strategy
73
+
74
+ # Set default timeout with proper precedence:
75
+ # 1. Explicit default_timeout parameter
76
+ # 2. Strategy's default_timeout (if available and not None)
77
+ # 3. Fallback to 30.0 seconds
78
+ if default_timeout is not None:
79
+ self.default_timeout = default_timeout
80
+ logger.debug(f"Using explicit default_timeout: {self.default_timeout}s")
81
+ elif hasattr(strategy, 'default_timeout') and strategy.default_timeout is not None:
82
+ self.default_timeout = strategy.default_timeout
83
+ logger.debug(f"Using strategy's default_timeout: {self.default_timeout}s")
84
+ else:
85
+ self.default_timeout = 30.0 # Conservative fallback
86
+ logger.debug(f"Using fallback default_timeout: {self.default_timeout}s")
65
87
 
66
88
  @property
67
89
  def supports_streaming(self) -> bool:
@@ -79,7 +101,7 @@ class ToolExecutor:
79
101
 
80
102
  Args:
81
103
  calls: List of tool calls to execute
82
- timeout: Optional timeout for execution (overrides default_timeout)
104
+ timeout: Optional timeout for execution (overrides all defaults)
83
105
  use_cache: Whether to use cached results (for caching wrappers)
84
106
 
85
107
  Returns:
@@ -88,10 +110,13 @@ class ToolExecutor:
88
110
  if not calls:
89
111
  return []
90
112
 
91
- # Use the provided timeout or fall back to default
113
+ # Timeout precedence:
114
+ # 1. Explicit timeout parameter (highest priority)
115
+ # 2. Executor's default_timeout (which already considers strategy's timeout)
92
116
  effective_timeout = timeout if timeout is not None else self.default_timeout
93
117
 
94
- logger.debug(f"Executing {len(calls)} tool calls with timeout {effective_timeout}s")
118
+ logger.debug(f"Executing {len(calls)} tool calls with timeout {effective_timeout}s "
119
+ f"(explicit: {timeout is not None})")
95
120
 
96
121
  # Delegate to the strategy
97
122
  return await self.strategy.run(calls, timeout=effective_timeout)
@@ -118,9 +143,12 @@ class ToolExecutor:
118
143
  if not calls:
119
144
  return
120
145
 
121
- # Use the provided timeout or fall back to default
146
+ # Use the same timeout precedence as execute()
122
147
  effective_timeout = timeout if timeout is not None else self.default_timeout
123
148
 
149
+ logger.debug(f"Stream executing {len(calls)} tool calls with timeout {effective_timeout}s "
150
+ f"(explicit: {timeout is not None})")
151
+
124
152
  # There are two possible ways to handle streaming:
125
153
  # 1. Use the strategy's stream_run if available
126
154
  # 2. Use direct streaming for streaming tools
@@ -232,6 +260,8 @@ class ToolExecutor:
232
260
  machine = "direct-stream"
233
261
  pid = 0
234
262
 
263
+ logger.debug(f"Direct streaming {call.tool} with timeout {timeout}s")
264
+
235
265
  # Create streaming task with timeout
236
266
  async def stream_with_timeout():
237
267
  try:
@@ -265,11 +295,16 @@ class ToolExecutor:
265
295
  try:
266
296
  if timeout:
267
297
  await asyncio.wait_for(stream_with_timeout(), timeout)
298
+ logger.debug(f"Direct streaming {call.tool} completed within {timeout}s")
268
299
  else:
269
300
  await stream_with_timeout()
301
+ logger.debug(f"Direct streaming {call.tool} completed (no timeout)")
270
302
  except asyncio.TimeoutError:
271
303
  # Handle timeout
272
304
  end_time = datetime.now(timezone.utc)
305
+ actual_duration = (end_time - start_time).total_seconds()
306
+ logger.debug(f"Direct streaming {call.tool} timed out after {actual_duration:.3f}s (limit: {timeout}s)")
307
+
273
308
  timeout_result = ToolResult(
274
309
  tool=call.tool,
275
310
  result=None,
@@ -283,6 +318,8 @@ class ToolExecutor:
283
318
  except Exception as e:
284
319
  # Handle other errors
285
320
  end_time = datetime.now(timezone.utc)
321
+ logger.exception(f"Error in direct streaming {call.tool}: {e}")
322
+
286
323
  error_result = ToolResult(
287
324
  tool=call.tool,
288
325
  result=None,
@@ -300,5 +337,6 @@ class ToolExecutor:
300
337
 
301
338
  This should be called during application shutdown to ensure proper cleanup.
302
339
  """
340
+ logger.debug("Shutting down ToolExecutor")
303
341
  if hasattr(self.strategy, "shutdown") and callable(self.strategy.shutdown):
304
342
  await self.strategy.shutdown()
@@ -47,9 +47,16 @@ class StreamManager:
47
47
  servers: List[str],
48
48
  server_names: Optional[Dict[int, str]] = None,
49
49
  transport_type: str = "stdio",
50
+ default_timeout: float = 30.0, # ADD: For consistency
50
51
  ) -> "StreamManager":
51
52
  inst = cls()
52
- await inst.initialize(config_file, servers, server_names, transport_type)
53
+ await inst.initialize(
54
+ config_file,
55
+ servers,
56
+ server_names,
57
+ transport_type,
58
+ default_timeout=default_timeout # PASS THROUGH
59
+ )
53
60
  return inst
54
61
 
55
62
  @classmethod
@@ -57,9 +64,16 @@ class StreamManager:
57
64
  cls,
58
65
  servers: List[Dict[str, str]],
59
66
  server_names: Optional[Dict[int, str]] = None,
67
+ connection_timeout: float = 10.0, # ADD: For SSE connection setup
68
+ default_timeout: float = 30.0, # ADD: For tool execution
60
69
  ) -> "StreamManager":
61
70
  inst = cls()
62
- await inst.initialize_with_sse(servers, server_names)
71
+ await inst.initialize_with_sse(
72
+ servers,
73
+ server_names,
74
+ connection_timeout=connection_timeout, # PASS THROUGH
75
+ default_timeout=default_timeout # PASS THROUGH
76
+ )
63
77
  return inst
64
78
 
65
79
  # ------------------------------------------------------------------ #
@@ -71,6 +85,7 @@ class StreamManager:
71
85
  servers: List[str],
72
86
  server_names: Optional[Dict[int, str]] = None,
73
87
  transport_type: str = "stdio",
88
+ default_timeout: float = 30.0, # ADD: For consistency
74
89
  ) -> None:
75
90
  async with self._lock:
76
91
  self.server_names = server_names or {}
@@ -81,7 +96,24 @@ class StreamManager:
81
96
  params = await load_config(config_file, server_name)
82
97
  transport: MCPBaseTransport = StdioTransport(params)
83
98
  elif transport_type == "sse":
84
- transport = SSETransport("http://localhost:8000")
99
+ # WARNING: For SSE transport, prefer using create_with_sse() instead
100
+ # This is a fallback for backward compatibility
101
+ logger.warning("Using SSE transport in initialize() - consider using initialize_with_sse() instead")
102
+
103
+ # Try to extract URL from params or use localhost as fallback
104
+ if isinstance(params, dict) and 'url' in params:
105
+ sse_url = params['url']
106
+ api_key = params.get('api_key')
107
+ else:
108
+ sse_url = "http://localhost:8000"
109
+ api_key = None
110
+ logger.warning(f"No URL configured for SSE transport, using default: {sse_url}")
111
+
112
+ transport = SSETransport(
113
+ sse_url,
114
+ api_key,
115
+ default_timeout=default_timeout
116
+ )
85
117
  else:
86
118
  logger.error("Unsupported transport type: %s", transport_type)
87
119
  continue
@@ -125,6 +157,8 @@ class StreamManager:
125
157
  self,
126
158
  servers: List[Dict[str, str]],
127
159
  server_names: Optional[Dict[int, str]] = None,
160
+ connection_timeout: float = 10.0, # ADD: For SSE connection setup
161
+ default_timeout: float = 30.0, # ADD: For tool execution
128
162
  ) -> None:
129
163
  async with self._lock:
130
164
  self.server_names = server_names or {}
@@ -135,7 +169,14 @@ class StreamManager:
135
169
  logger.error("Bad server config: %s", cfg)
136
170
  continue
137
171
  try:
138
- transport = SSETransport(url, cfg.get("api_key"))
172
+ # FIXED: Pass timeout parameters to SSETransport
173
+ transport = SSETransport(
174
+ url,
175
+ cfg.get("api_key"),
176
+ connection_timeout=connection_timeout, # ADD THIS
177
+ default_timeout=default_timeout # ADD THIS
178
+ )
179
+
139
180
  if not await transport.initialize():
140
181
  logger.error("Failed to init SSE %s", name)
141
182
  continue
@@ -265,7 +306,7 @@ class StreamManager:
265
306
  tool_name: str,
266
307
  arguments: Dict[str, Any],
267
308
  server_name: Optional[str] = None,
268
- timeout: Optional[float] = None, # Add timeout parameter
309
+ timeout: Optional[float] = None, # Timeout parameter already exists
269
310
  ) -> Dict[str, Any]:
270
311
  """
271
312
  Call a tool on the appropriate server with timeout support.
@@ -293,10 +334,25 @@ class StreamManager:
293
334
  if timeout is not None:
294
335
  logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
295
336
  try:
296
- return await asyncio.wait_for(
297
- transport.call_tool(tool_name, arguments),
298
- timeout=timeout
299
- )
337
+ # ENHANCED: Pass timeout to transport.call_tool if it supports it
338
+ if hasattr(transport, 'call_tool'):
339
+ import inspect
340
+ sig = inspect.signature(transport.call_tool)
341
+ if 'timeout' in sig.parameters:
342
+ # Transport supports timeout parameter - pass it through
343
+ return await transport.call_tool(tool_name, arguments, timeout=timeout)
344
+ else:
345
+ # Transport doesn't support timeout - use asyncio.wait_for wrapper
346
+ return await asyncio.wait_for(
347
+ transport.call_tool(tool_name, arguments),
348
+ timeout=timeout
349
+ )
350
+ else:
351
+ # Fallback to asyncio.wait_for
352
+ return await asyncio.wait_for(
353
+ transport.call_tool(tool_name, arguments),
354
+ timeout=timeout
355
+ )
300
356
  except asyncio.TimeoutError:
301
357
  logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
302
358
  return {
@@ -350,4 +406,4 @@ class StreamManager:
350
406
  # convenience alias
351
407
  @property
352
408
  def streams(self) -> List[Tuple[Any, Any]]: # pragma: no cover
353
- return self.get_streams()
409
+ return self.get_streams()
@@ -8,6 +8,8 @@ This transport:
8
8
  3. Sends MCP initialize handshake FIRST
9
9
  4. Only then proceeds with tools/list and tool calls
10
10
  5. Handles async responses via SSE message events
11
+
12
+ FIXED: All hardcoded timeouts are now configurable parameters.
11
13
  """
12
14
  from __future__ import annotations
13
15
 
@@ -24,7 +26,8 @@ from .base_transport import MCPBaseTransport
24
26
  # --------------------------------------------------------------------------- #
25
27
  # Helpers #
26
28
  # --------------------------------------------------------------------------- #
27
- DEFAULT_TIMEOUT = 30.0 # Longer timeout for real servers
29
+ DEFAULT_TIMEOUT = 30.0 # Default timeout for tool calls
30
+ DEFAULT_CONNECTION_TIMEOUT = 10.0 # Default timeout for connection setup
28
31
  HEADERS_JSON: Dict[str, str] = {"accept": "application/json"}
29
32
 
30
33
 
@@ -47,9 +50,26 @@ class SSETransport(MCPBaseTransport):
47
50
  5. Waits for async responses via SSE message events
48
51
  """
49
52
 
50
- def __init__(self, url: str, api_key: Optional[str] = None) -> None:
53
+ def __init__(
54
+ self,
55
+ url: str,
56
+ api_key: Optional[str] = None,
57
+ connection_timeout: float = DEFAULT_CONNECTION_TIMEOUT,
58
+ default_timeout: float = DEFAULT_TIMEOUT
59
+ ) -> None:
60
+ """
61
+ Initialize SSE Transport with configurable timeouts.
62
+
63
+ Args:
64
+ url: Base URL for the MCP server
65
+ api_key: Optional API key for authentication
66
+ connection_timeout: Timeout for connection setup (default: 10.0s)
67
+ default_timeout: Default timeout for tool calls (default: 30.0s)
68
+ """
51
69
  self.base_url = url.rstrip("/")
52
70
  self.api_key = api_key
71
+ self.connection_timeout = connection_timeout
72
+ self.default_timeout = default_timeout
53
73
 
54
74
  # NEW: Auto-detect bearer token from environment if not provided
55
75
  if not self.api_key:
@@ -92,7 +112,7 @@ class SSETransport(MCPBaseTransport):
92
112
 
93
113
  self._client = httpx.AsyncClient(
94
114
  headers=headers,
95
- timeout=DEFAULT_TIMEOUT,
115
+ timeout=self.default_timeout, # Use configurable timeout
96
116
  )
97
117
  self.session = self._client
98
118
 
@@ -100,8 +120,8 @@ class SSETransport(MCPBaseTransport):
100
120
  self._sse_task = asyncio.create_task(self._handle_sse_connection())
101
121
 
102
122
  try:
103
- # Wait for endpoint event (up to 10 seconds)
104
- await asyncio.wait_for(self._connected.wait(), timeout=10.0)
123
+ # FIXED: Use configurable connection timeout instead of hardcoded 10.0
124
+ await asyncio.wait_for(self._connected.wait(), timeout=self.connection_timeout)
105
125
 
106
126
  # NEW: Send MCP initialize handshake
107
127
  if await self._initialize_mcp_session():
@@ -285,7 +305,8 @@ class SSETransport(MCPBaseTransport):
285
305
  if not self._initialized.is_set():
286
306
  print("⏳ Waiting for MCP initialization...")
287
307
  try:
288
- await asyncio.wait_for(self._initialized.wait(), timeout=10.0)
308
+ # FIXED: Use configurable connection timeout instead of hardcoded 10.0
309
+ await asyncio.wait_for(self._initialized.wait(), timeout=self.connection_timeout)
289
310
  except asyncio.TimeoutError:
290
311
  print("❌ Timeout waiting for MCP initialization")
291
312
  return []
@@ -311,8 +332,23 @@ class SSETransport(MCPBaseTransport):
311
332
 
312
333
  return []
313
334
 
314
- async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
315
- """Execute a tool call using the MCP protocol."""
335
+ async def call_tool(
336
+ self,
337
+ tool_name: str,
338
+ arguments: Dict[str, Any],
339
+ timeout: Optional[float] = None
340
+ ) -> Dict[str, Any]:
341
+ """
342
+ Execute a tool call using the MCP protocol.
343
+
344
+ Args:
345
+ tool_name: Name of the tool to call
346
+ arguments: Arguments to pass to the tool
347
+ timeout: Optional timeout for this specific call
348
+
349
+ Returns:
350
+ Dictionary containing the tool result or error
351
+ """
316
352
  # NEW: Ensure initialization before tool calls
317
353
  if not self._initialized.is_set():
318
354
  return {"isError": True, "error": "MCP session not initialized"}
@@ -331,7 +367,9 @@ class SSETransport(MCPBaseTransport):
331
367
  }
332
368
  }
333
369
 
334
- response = await self._send_message(message)
370
+ # Use provided timeout or fall back to default
371
+ effective_timeout = timeout if timeout is not None else self.default_timeout
372
+ response = await self._send_message(message, timeout=effective_timeout)
335
373
 
336
374
  # Process MCP response
337
375
  if "error" in response:
@@ -363,8 +401,21 @@ class SSETransport(MCPBaseTransport):
363
401
  except Exception as e:
364
402
  return {"isError": True, "error": str(e)}
365
403
 
366
- async def _send_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
367
- """Send a JSON-RPC message to the server and wait for async response."""
404
+ async def _send_message(
405
+ self,
406
+ message: Dict[str, Any],
407
+ timeout: Optional[float] = None
408
+ ) -> Dict[str, Any]:
409
+ """
410
+ Send a JSON-RPC message to the server and wait for async response.
411
+
412
+ Args:
413
+ message: JSON-RPC message to send
414
+ timeout: Optional timeout for this specific message
415
+
416
+ Returns:
417
+ Response message from the server
418
+ """
368
419
  if not self._client or not self._message_url:
369
420
  raise RuntimeError("Transport not properly initialized")
370
421
 
@@ -372,6 +423,9 @@ class SSETransport(MCPBaseTransport):
372
423
  if not message_id:
373
424
  raise ValueError("Message must have an ID")
374
425
 
426
+ # Use provided timeout or fall back to default
427
+ effective_timeout = timeout if timeout is not None else self.default_timeout
428
+
375
429
  # Create a future for this request
376
430
  future = asyncio.Future()
377
431
  async with self._message_lock:
@@ -391,7 +445,8 @@ class SSETransport(MCPBaseTransport):
391
445
  if response.status_code == 202:
392
446
  # Server accepted - wait for async response via SSE
393
447
  try:
394
- response_message = await asyncio.wait_for(future, timeout=30.0)
448
+ # FIXED: Use effective_timeout instead of hardcoded 30.0
449
+ response_message = await asyncio.wait_for(future, timeout=effective_timeout)
395
450
  return response_message
396
451
  except asyncio.TimeoutError:
397
452
  raise RuntimeError(f"Timeout waiting for response to message {message_id}")