chuk-tool-processor: chuk_tool_processor-0.2-py3-none-any.whl → chuk_tool_processor-0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor has been flagged as potentially problematic; see the package's registry page for more details.
- chuk_tool_processor/execution/tool_executor.py +47 -9
- chuk_tool_processor/mcp/mcp_tool.py +46 -9
- chuk_tool_processor/mcp/setup_mcp_sse.py +21 -1
- chuk_tool_processor/mcp/stream_manager.py +95 -7
- chuk_tool_processor/mcp/transport/sse_transport.py +86 -13
- chuk_tool_processor-0.4.dist-info/METADATA +831 -0
- {chuk_tool_processor-0.2.dist-info → chuk_tool_processor-0.4.dist-info}/RECORD +9 -9
- chuk_tool_processor-0.2.dist-info/METADATA +0 -401
- {chuk_tool_processor-0.2.dist-info → chuk_tool_processor-0.4.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.2.dist-info → chuk_tool_processor-0.4.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
#!/usr/bin/env python
|
|
2
2
|
# chuk_tool_processor/execution/tool_executor.py
|
|
3
3
|
"""
|
|
4
|
-
Modified ToolExecutor with true streaming support and
|
|
4
|
+
Modified ToolExecutor with true streaming support and proper timeout handling.
|
|
5
5
|
|
|
6
6
|
This version accesses streaming tools' stream_execute method directly
|
|
7
7
|
to enable true item-by-item streaming behavior, while preventing duplicates.
|
|
8
|
+
|
|
9
|
+
FIXED: Proper timeout precedence - respects strategy's default_timeout when available.
|
|
8
10
|
"""
|
|
9
11
|
import asyncio
|
|
10
12
|
from datetime import datetime, timezone
|
|
@@ -25,12 +27,14 @@ class ToolExecutor:
|
|
|
25
27
|
|
|
26
28
|
This class provides a unified interface for executing tools using different
|
|
27
29
|
execution strategies, with special support for streaming tools.
|
|
30
|
+
|
|
31
|
+
FIXED: Proper timeout handling that respects strategy's default_timeout.
|
|
28
32
|
"""
|
|
29
33
|
|
|
30
34
|
def __init__(
|
|
31
35
|
self,
|
|
32
36
|
registry: Optional[ToolRegistryInterface] = None,
|
|
33
|
-
default_timeout: float =
|
|
37
|
+
default_timeout: Optional[float] = None, # Made optional to allow strategy precedence
|
|
34
38
|
strategy: Optional[ExecutionStrategy] = None,
|
|
35
39
|
strategy_kwargs: Optional[Dict[str, Any]] = None,
|
|
36
40
|
) -> None:
|
|
@@ -39,12 +43,12 @@ class ToolExecutor:
|
|
|
39
43
|
|
|
40
44
|
Args:
|
|
41
45
|
registry: Tool registry to use for tool lookups
|
|
42
|
-
default_timeout: Default timeout for tool execution
|
|
46
|
+
default_timeout: Default timeout for tool execution (optional)
|
|
47
|
+
If None, will use strategy's default_timeout if available
|
|
43
48
|
strategy: Optional execution strategy (default: InProcessStrategy)
|
|
44
49
|
strategy_kwargs: Additional arguments for the strategy constructor
|
|
45
50
|
"""
|
|
46
51
|
self.registry = registry
|
|
47
|
-
self.default_timeout = default_timeout
|
|
48
52
|
|
|
49
53
|
# Create strategy if not provided
|
|
50
54
|
if strategy is None:
|
|
@@ -55,13 +59,31 @@ class ToolExecutor:
|
|
|
55
59
|
raise ValueError("Registry must be provided if strategy is not")
|
|
56
60
|
|
|
57
61
|
strategy_kwargs = strategy_kwargs or {}
|
|
62
|
+
|
|
63
|
+
# If no default_timeout specified, use a reasonable default for the strategy
|
|
64
|
+
strategy_timeout = default_timeout if default_timeout is not None else 30.0
|
|
65
|
+
|
|
58
66
|
strategy = _inprocess_mod.InProcessStrategy(
|
|
59
67
|
registry,
|
|
60
|
-
default_timeout=
|
|
68
|
+
default_timeout=strategy_timeout,
|
|
61
69
|
**strategy_kwargs,
|
|
62
70
|
)
|
|
63
71
|
|
|
64
72
|
self.strategy = strategy
|
|
73
|
+
|
|
74
|
+
# Set default timeout with proper precedence:
|
|
75
|
+
# 1. Explicit default_timeout parameter
|
|
76
|
+
# 2. Strategy's default_timeout (if available and not None)
|
|
77
|
+
# 3. Fallback to 30.0 seconds
|
|
78
|
+
if default_timeout is not None:
|
|
79
|
+
self.default_timeout = default_timeout
|
|
80
|
+
logger.debug(f"Using explicit default_timeout: {self.default_timeout}s")
|
|
81
|
+
elif hasattr(strategy, 'default_timeout') and strategy.default_timeout is not None:
|
|
82
|
+
self.default_timeout = strategy.default_timeout
|
|
83
|
+
logger.debug(f"Using strategy's default_timeout: {self.default_timeout}s")
|
|
84
|
+
else:
|
|
85
|
+
self.default_timeout = 30.0 # Conservative fallback
|
|
86
|
+
logger.debug(f"Using fallback default_timeout: {self.default_timeout}s")
|
|
65
87
|
|
|
66
88
|
@property
|
|
67
89
|
def supports_streaming(self) -> bool:
|
|
@@ -79,7 +101,7 @@ class ToolExecutor:
|
|
|
79
101
|
|
|
80
102
|
Args:
|
|
81
103
|
calls: List of tool calls to execute
|
|
82
|
-
timeout: Optional timeout for execution (overrides
|
|
104
|
+
timeout: Optional timeout for execution (overrides all defaults)
|
|
83
105
|
use_cache: Whether to use cached results (for caching wrappers)
|
|
84
106
|
|
|
85
107
|
Returns:
|
|
@@ -88,10 +110,13 @@ class ToolExecutor:
|
|
|
88
110
|
if not calls:
|
|
89
111
|
return []
|
|
90
112
|
|
|
91
|
-
#
|
|
113
|
+
# Timeout precedence:
|
|
114
|
+
# 1. Explicit timeout parameter (highest priority)
|
|
115
|
+
# 2. Executor's default_timeout (which already considers strategy's timeout)
|
|
92
116
|
effective_timeout = timeout if timeout is not None else self.default_timeout
|
|
93
117
|
|
|
94
|
-
logger.debug(f"Executing {len(calls)} tool calls with timeout {effective_timeout}s"
|
|
118
|
+
logger.debug(f"Executing {len(calls)} tool calls with timeout {effective_timeout}s "
|
|
119
|
+
f"(explicit: {timeout is not None})")
|
|
95
120
|
|
|
96
121
|
# Delegate to the strategy
|
|
97
122
|
return await self.strategy.run(calls, timeout=effective_timeout)
|
|
@@ -118,9 +143,12 @@ class ToolExecutor:
|
|
|
118
143
|
if not calls:
|
|
119
144
|
return
|
|
120
145
|
|
|
121
|
-
# Use the
|
|
146
|
+
# Use the same timeout precedence as execute()
|
|
122
147
|
effective_timeout = timeout if timeout is not None else self.default_timeout
|
|
123
148
|
|
|
149
|
+
logger.debug(f"Stream executing {len(calls)} tool calls with timeout {effective_timeout}s "
|
|
150
|
+
f"(explicit: {timeout is not None})")
|
|
151
|
+
|
|
124
152
|
# There are two possible ways to handle streaming:
|
|
125
153
|
# 1. Use the strategy's stream_run if available
|
|
126
154
|
# 2. Use direct streaming for streaming tools
|
|
@@ -232,6 +260,8 @@ class ToolExecutor:
|
|
|
232
260
|
machine = "direct-stream"
|
|
233
261
|
pid = 0
|
|
234
262
|
|
|
263
|
+
logger.debug(f"Direct streaming {call.tool} with timeout {timeout}s")
|
|
264
|
+
|
|
235
265
|
# Create streaming task with timeout
|
|
236
266
|
async def stream_with_timeout():
|
|
237
267
|
try:
|
|
@@ -265,11 +295,16 @@ class ToolExecutor:
|
|
|
265
295
|
try:
|
|
266
296
|
if timeout:
|
|
267
297
|
await asyncio.wait_for(stream_with_timeout(), timeout)
|
|
298
|
+
logger.debug(f"Direct streaming {call.tool} completed within {timeout}s")
|
|
268
299
|
else:
|
|
269
300
|
await stream_with_timeout()
|
|
301
|
+
logger.debug(f"Direct streaming {call.tool} completed (no timeout)")
|
|
270
302
|
except asyncio.TimeoutError:
|
|
271
303
|
# Handle timeout
|
|
272
304
|
end_time = datetime.now(timezone.utc)
|
|
305
|
+
actual_duration = (end_time - start_time).total_seconds()
|
|
306
|
+
logger.debug(f"Direct streaming {call.tool} timed out after {actual_duration:.3f}s (limit: {timeout}s)")
|
|
307
|
+
|
|
273
308
|
timeout_result = ToolResult(
|
|
274
309
|
tool=call.tool,
|
|
275
310
|
result=None,
|
|
@@ -283,6 +318,8 @@ class ToolExecutor:
|
|
|
283
318
|
except Exception as e:
|
|
284
319
|
# Handle other errors
|
|
285
320
|
end_time = datetime.now(timezone.utc)
|
|
321
|
+
logger.exception(f"Error in direct streaming {call.tool}: {e}")
|
|
322
|
+
|
|
286
323
|
error_result = ToolResult(
|
|
287
324
|
tool=call.tool,
|
|
288
325
|
result=None,
|
|
@@ -300,5 +337,6 @@ class ToolExecutor:
|
|
|
300
337
|
|
|
301
338
|
This should be called during application shutdown to ensure proper cleanup.
|
|
302
339
|
"""
|
|
340
|
+
logger.debug("Shutting down ToolExecutor")
|
|
303
341
|
if hasattr(self.strategy, "shutdown") and callable(self.strategy.shutdown):
|
|
304
342
|
await self.strategy.shutdown()
|
|
@@ -36,9 +36,11 @@ class MCPTool:
|
|
|
36
36
|
servers: Optional[List[str]] = None,
|
|
37
37
|
server_names: Optional[Dict[int, str]] = None,
|
|
38
38
|
namespace: str = "stdio",
|
|
39
|
+
default_timeout: Optional[float] = None, # Add default timeout support
|
|
39
40
|
) -> None:
|
|
40
41
|
self.tool_name = tool_name
|
|
41
42
|
self._sm: Optional[StreamManager] = stream_manager
|
|
43
|
+
self.default_timeout = default_timeout or 30.0 # Default to 30s if not specified
|
|
42
44
|
|
|
43
45
|
# Boot-strap parameters (only needed if _sm is None)
|
|
44
46
|
self._cfg_file = cfg_file
|
|
@@ -78,21 +80,56 @@ class MCPTool:
|
|
|
78
80
|
return self._sm # type: ignore[return-value]
|
|
79
81
|
|
|
80
82
|
# ------------------------------------------------------------------ #
|
|
81
|
-
async def execute(self, **kwargs: Any) -> Any:
|
|
83
|
+
async def execute(self, timeout: Optional[float] = None, **kwargs: Any) -> Any:
|
|
82
84
|
"""
|
|
83
|
-
Forward the call to the remote MCP tool.
|
|
85
|
+
Forward the call to the remote MCP tool with timeout support.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
timeout: Optional timeout for this specific call. If not provided,
|
|
89
|
+
uses the instance's default_timeout.
|
|
90
|
+
**kwargs: Arguments to pass to the MCP tool.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
The result from the MCP tool call.
|
|
84
94
|
|
|
85
95
|
Raises
|
|
86
96
|
------
|
|
87
97
|
RuntimeError
|
|
88
98
|
If the server returns an error payload.
|
|
99
|
+
asyncio.TimeoutError
|
|
100
|
+
If the call times out.
|
|
89
101
|
"""
|
|
90
102
|
sm = await self._ensure_stream_manager()
|
|
91
|
-
|
|
103
|
+
|
|
104
|
+
# Use provided timeout, fall back to instance default, then global default
|
|
105
|
+
effective_timeout = timeout if timeout is not None else self.default_timeout
|
|
106
|
+
|
|
107
|
+
logger.debug("Calling MCP tool '%s' with timeout: %ss", self.tool_name, effective_timeout)
|
|
108
|
+
|
|
109
|
+
try:
|
|
110
|
+
# Pass timeout directly to StreamManager instead of wrapping with wait_for
|
|
111
|
+
result = await sm.call_tool(
|
|
112
|
+
tool_name=self.tool_name,
|
|
113
|
+
arguments=kwargs,
|
|
114
|
+
timeout=effective_timeout
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
if result.get("isError"):
|
|
118
|
+
err = result.get("error", "Unknown error")
|
|
119
|
+
logger.error("Remote MCP error from '%s': %s", self.tool_name, err)
|
|
120
|
+
raise RuntimeError(err)
|
|
121
|
+
|
|
122
|
+
return result.get("content")
|
|
123
|
+
|
|
124
|
+
except asyncio.TimeoutError:
|
|
125
|
+
logger.warning("MCP tool '%s' timed out after %ss", self.tool_name, effective_timeout)
|
|
126
|
+
raise
|
|
127
|
+
except Exception as e:
|
|
128
|
+
logger.error("Error calling MCP tool '%s': %s", self.tool_name, e)
|
|
129
|
+
raise
|
|
92
130
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
return result.get("content")
|
|
131
|
+
# ------------------------------------------------------------------ #
|
|
132
|
+
# Legacy method name support
|
|
133
|
+
async def _aexecute(self, timeout: Optional[float] = None, **kwargs: Any) -> Any:
|
|
134
|
+
"""Legacy alias for execute() method."""
|
|
135
|
+
return await self.execute(timeout=timeout, **kwargs)
|
|
@@ -14,6 +14,7 @@ Utility that wires up:
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import os
|
|
17
18
|
from typing import Dict, List, Optional, Tuple
|
|
18
19
|
|
|
19
20
|
from chuk_tool_processor.core.processor import ToolProcessor
|
|
@@ -47,7 +48,26 @@ async def setup_mcp_sse( # noqa: C901 – long, but just a config wrapper
|
|
|
47
48
|
and return a ready-to-go :class:`ToolProcessor`.
|
|
48
49
|
|
|
49
50
|
Everything is **async-native** – call with ``await``.
|
|
51
|
+
|
|
52
|
+
NEW: Automatically detects and adds bearer token from MCP_BEARER_TOKEN
|
|
53
|
+
environment variable if not explicitly provided in server config.
|
|
50
54
|
"""
|
|
55
|
+
|
|
56
|
+
# NEW: Auto-detect and add bearer token to servers if available
|
|
57
|
+
bearer_token = os.getenv("MCP_BEARER_TOKEN")
|
|
58
|
+
if bearer_token:
|
|
59
|
+
logger.info("Found MCP_BEARER_TOKEN environment variable, adding to server configs")
|
|
60
|
+
|
|
61
|
+
# Add api_key to servers that don't already have it
|
|
62
|
+
enhanced_servers = []
|
|
63
|
+
for server in servers:
|
|
64
|
+
enhanced_server = dict(server) # Make a copy
|
|
65
|
+
if "api_key" not in enhanced_server and bearer_token:
|
|
66
|
+
enhanced_server["api_key"] = bearer_token
|
|
67
|
+
logger.info("Added bearer token to server: %s", enhanced_server.get("name", "unnamed"))
|
|
68
|
+
enhanced_servers.append(enhanced_server)
|
|
69
|
+
servers = enhanced_servers
|
|
70
|
+
|
|
51
71
|
# 1️⃣ connect to the remote MCP servers
|
|
52
72
|
stream_manager = await StreamManager.create_with_sse(
|
|
53
73
|
servers=servers,
|
|
@@ -76,4 +96,4 @@ async def setup_mcp_sse( # noqa: C901 – long, but just a config wrapper
|
|
|
76
96
|
"" if len(registered) == 1 else "s",
|
|
77
97
|
namespace,
|
|
78
98
|
)
|
|
79
|
-
return processor, stream_manager
|
|
99
|
+
return processor, stream_manager
|
|
@@ -47,9 +47,16 @@ class StreamManager:
|
|
|
47
47
|
servers: List[str],
|
|
48
48
|
server_names: Optional[Dict[int, str]] = None,
|
|
49
49
|
transport_type: str = "stdio",
|
|
50
|
+
default_timeout: float = 30.0, # ADD: For consistency
|
|
50
51
|
) -> "StreamManager":
|
|
51
52
|
inst = cls()
|
|
52
|
-
await inst.initialize(
|
|
53
|
+
await inst.initialize(
|
|
54
|
+
config_file,
|
|
55
|
+
servers,
|
|
56
|
+
server_names,
|
|
57
|
+
transport_type,
|
|
58
|
+
default_timeout=default_timeout # PASS THROUGH
|
|
59
|
+
)
|
|
53
60
|
return inst
|
|
54
61
|
|
|
55
62
|
@classmethod
|
|
@@ -57,9 +64,16 @@ class StreamManager:
|
|
|
57
64
|
cls,
|
|
58
65
|
servers: List[Dict[str, str]],
|
|
59
66
|
server_names: Optional[Dict[int, str]] = None,
|
|
67
|
+
connection_timeout: float = 10.0, # ADD: For SSE connection setup
|
|
68
|
+
default_timeout: float = 30.0, # ADD: For tool execution
|
|
60
69
|
) -> "StreamManager":
|
|
61
70
|
inst = cls()
|
|
62
|
-
await inst.initialize_with_sse(
|
|
71
|
+
await inst.initialize_with_sse(
|
|
72
|
+
servers,
|
|
73
|
+
server_names,
|
|
74
|
+
connection_timeout=connection_timeout, # PASS THROUGH
|
|
75
|
+
default_timeout=default_timeout # PASS THROUGH
|
|
76
|
+
)
|
|
63
77
|
return inst
|
|
64
78
|
|
|
65
79
|
# ------------------------------------------------------------------ #
|
|
@@ -71,6 +85,7 @@ class StreamManager:
|
|
|
71
85
|
servers: List[str],
|
|
72
86
|
server_names: Optional[Dict[int, str]] = None,
|
|
73
87
|
transport_type: str = "stdio",
|
|
88
|
+
default_timeout: float = 30.0, # ADD: For consistency
|
|
74
89
|
) -> None:
|
|
75
90
|
async with self._lock:
|
|
76
91
|
self.server_names = server_names or {}
|
|
@@ -81,7 +96,24 @@ class StreamManager:
|
|
|
81
96
|
params = await load_config(config_file, server_name)
|
|
82
97
|
transport: MCPBaseTransport = StdioTransport(params)
|
|
83
98
|
elif transport_type == "sse":
|
|
84
|
-
transport
|
|
99
|
+
# WARNING: For SSE transport, prefer using create_with_sse() instead
|
|
100
|
+
# This is a fallback for backward compatibility
|
|
101
|
+
logger.warning("Using SSE transport in initialize() - consider using initialize_with_sse() instead")
|
|
102
|
+
|
|
103
|
+
# Try to extract URL from params or use localhost as fallback
|
|
104
|
+
if isinstance(params, dict) and 'url' in params:
|
|
105
|
+
sse_url = params['url']
|
|
106
|
+
api_key = params.get('api_key')
|
|
107
|
+
else:
|
|
108
|
+
sse_url = "http://localhost:8000"
|
|
109
|
+
api_key = None
|
|
110
|
+
logger.warning(f"No URL configured for SSE transport, using default: {sse_url}")
|
|
111
|
+
|
|
112
|
+
transport = SSETransport(
|
|
113
|
+
sse_url,
|
|
114
|
+
api_key,
|
|
115
|
+
default_timeout=default_timeout
|
|
116
|
+
)
|
|
85
117
|
else:
|
|
86
118
|
logger.error("Unsupported transport type: %s", transport_type)
|
|
87
119
|
continue
|
|
@@ -125,6 +157,8 @@ class StreamManager:
|
|
|
125
157
|
self,
|
|
126
158
|
servers: List[Dict[str, str]],
|
|
127
159
|
server_names: Optional[Dict[int, str]] = None,
|
|
160
|
+
connection_timeout: float = 10.0, # ADD: For SSE connection setup
|
|
161
|
+
default_timeout: float = 30.0, # ADD: For tool execution
|
|
128
162
|
) -> None:
|
|
129
163
|
async with self._lock:
|
|
130
164
|
self.server_names = server_names or {}
|
|
@@ -135,7 +169,14 @@ class StreamManager:
|
|
|
135
169
|
logger.error("Bad server config: %s", cfg)
|
|
136
170
|
continue
|
|
137
171
|
try:
|
|
138
|
-
|
|
172
|
+
# FIXED: Pass timeout parameters to SSETransport
|
|
173
|
+
transport = SSETransport(
|
|
174
|
+
url,
|
|
175
|
+
cfg.get("api_key"),
|
|
176
|
+
connection_timeout=connection_timeout, # ADD THIS
|
|
177
|
+
default_timeout=default_timeout # ADD THIS
|
|
178
|
+
)
|
|
179
|
+
|
|
139
180
|
if not await transport.initialize():
|
|
140
181
|
logger.error("Failed to init SSE %s", name)
|
|
141
182
|
continue
|
|
@@ -265,7 +306,20 @@ class StreamManager:
|
|
|
265
306
|
tool_name: str,
|
|
266
307
|
arguments: Dict[str, Any],
|
|
267
308
|
server_name: Optional[str] = None,
|
|
309
|
+
timeout: Optional[float] = None, # Timeout parameter already exists
|
|
268
310
|
) -> Dict[str, Any]:
|
|
311
|
+
"""
|
|
312
|
+
Call a tool on the appropriate server with timeout support.
|
|
313
|
+
|
|
314
|
+
Args:
|
|
315
|
+
tool_name: Name of the tool to call
|
|
316
|
+
arguments: Arguments to pass to the tool
|
|
317
|
+
server_name: Optional server name (auto-detected if not provided)
|
|
318
|
+
timeout: Optional timeout for the call
|
|
319
|
+
|
|
320
|
+
Returns:
|
|
321
|
+
Dictionary containing the tool result or error
|
|
322
|
+
"""
|
|
269
323
|
server_name = server_name or self.get_server_for_tool(tool_name)
|
|
270
324
|
if not server_name or server_name not in self.transports:
|
|
271
325
|
# wording kept exactly for unit-test expectation
|
|
@@ -273,8 +327,42 @@ class StreamManager:
|
|
|
273
327
|
"isError": True,
|
|
274
328
|
"error": f"No server found for tool: {tool_name}",
|
|
275
329
|
}
|
|
276
|
-
|
|
277
|
-
|
|
330
|
+
|
|
331
|
+
transport = self.transports[server_name]
|
|
332
|
+
|
|
333
|
+
# Apply timeout if specified
|
|
334
|
+
if timeout is not None:
|
|
335
|
+
logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
|
|
336
|
+
try:
|
|
337
|
+
# ENHANCED: Pass timeout to transport.call_tool if it supports it
|
|
338
|
+
if hasattr(transport, 'call_tool'):
|
|
339
|
+
import inspect
|
|
340
|
+
sig = inspect.signature(transport.call_tool)
|
|
341
|
+
if 'timeout' in sig.parameters:
|
|
342
|
+
# Transport supports timeout parameter - pass it through
|
|
343
|
+
return await transport.call_tool(tool_name, arguments, timeout=timeout)
|
|
344
|
+
else:
|
|
345
|
+
# Transport doesn't support timeout - use asyncio.wait_for wrapper
|
|
346
|
+
return await asyncio.wait_for(
|
|
347
|
+
transport.call_tool(tool_name, arguments),
|
|
348
|
+
timeout=timeout
|
|
349
|
+
)
|
|
350
|
+
else:
|
|
351
|
+
# Fallback to asyncio.wait_for
|
|
352
|
+
return await asyncio.wait_for(
|
|
353
|
+
transport.call_tool(tool_name, arguments),
|
|
354
|
+
timeout=timeout
|
|
355
|
+
)
|
|
356
|
+
except asyncio.TimeoutError:
|
|
357
|
+
logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
|
|
358
|
+
return {
|
|
359
|
+
"isError": True,
|
|
360
|
+
"error": f"Tool call timed out after {timeout}s",
|
|
361
|
+
}
|
|
362
|
+
else:
|
|
363
|
+
# No timeout specified, call directly
|
|
364
|
+
return await transport.call_tool(tool_name, arguments)
|
|
365
|
+
|
|
278
366
|
# ------------------------------------------------------------------ #
|
|
279
367
|
# shutdown #
|
|
280
368
|
# ------------------------------------------------------------------ #
|
|
@@ -318,4 +406,4 @@ class StreamManager:
|
|
|
318
406
|
# convenience alias
|
|
319
407
|
@property
|
|
320
408
|
def streams(self) -> List[Tuple[Any, Any]]: # pragma: no cover
|
|
321
|
-
return self.get_streams()
|
|
409
|
+
return self.get_streams()
|
|
@@ -8,12 +8,15 @@ This transport:
|
|
|
8
8
|
3. Sends MCP initialize handshake FIRST
|
|
9
9
|
4. Only then proceeds with tools/list and tool calls
|
|
10
10
|
5. Handles async responses via SSE message events
|
|
11
|
+
|
|
12
|
+
FIXED: All hardcoded timeouts are now configurable parameters.
|
|
11
13
|
"""
|
|
12
14
|
from __future__ import annotations
|
|
13
15
|
|
|
14
16
|
import asyncio
|
|
15
17
|
import contextlib
|
|
16
18
|
import json
|
|
19
|
+
import os
|
|
17
20
|
from typing import Any, Dict, List, Optional
|
|
18
21
|
|
|
19
22
|
import httpx
|
|
@@ -23,7 +26,8 @@ from .base_transport import MCPBaseTransport
|
|
|
23
26
|
# --------------------------------------------------------------------------- #
|
|
24
27
|
# Helpers #
|
|
25
28
|
# --------------------------------------------------------------------------- #
|
|
26
|
-
DEFAULT_TIMEOUT = 30.0 #
|
|
29
|
+
DEFAULT_TIMEOUT = 30.0 # Default timeout for tool calls
|
|
30
|
+
DEFAULT_CONNECTION_TIMEOUT = 10.0 # Default timeout for connection setup
|
|
27
31
|
HEADERS_JSON: Dict[str, str] = {"accept": "application/json"}
|
|
28
32
|
|
|
29
33
|
|
|
@@ -46,9 +50,33 @@ class SSETransport(MCPBaseTransport):
|
|
|
46
50
|
5. Waits for async responses via SSE message events
|
|
47
51
|
"""
|
|
48
52
|
|
|
49
|
-
def __init__(
|
|
53
|
+
def __init__(
|
|
54
|
+
self,
|
|
55
|
+
url: str,
|
|
56
|
+
api_key: Optional[str] = None,
|
|
57
|
+
connection_timeout: float = DEFAULT_CONNECTION_TIMEOUT,
|
|
58
|
+
default_timeout: float = DEFAULT_TIMEOUT
|
|
59
|
+
) -> None:
|
|
60
|
+
"""
|
|
61
|
+
Initialize SSE Transport with configurable timeouts.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
url: Base URL for the MCP server
|
|
65
|
+
api_key: Optional API key for authentication
|
|
66
|
+
connection_timeout: Timeout for connection setup (default: 10.0s)
|
|
67
|
+
default_timeout: Default timeout for tool calls (default: 30.0s)
|
|
68
|
+
"""
|
|
50
69
|
self.base_url = url.rstrip("/")
|
|
51
70
|
self.api_key = api_key
|
|
71
|
+
self.connection_timeout = connection_timeout
|
|
72
|
+
self.default_timeout = default_timeout
|
|
73
|
+
|
|
74
|
+
# NEW: Auto-detect bearer token from environment if not provided
|
|
75
|
+
if not self.api_key:
|
|
76
|
+
bearer_token = os.getenv("MCP_BEARER_TOKEN")
|
|
77
|
+
if bearer_token:
|
|
78
|
+
self.api_key = bearer_token
|
|
79
|
+
print(f"🔑 Using bearer token from MCP_BEARER_TOKEN environment variable")
|
|
52
80
|
|
|
53
81
|
# httpx client (None until initialise)
|
|
54
82
|
self._client: httpx.AsyncClient | None = None
|
|
@@ -75,11 +103,16 @@ class SSETransport(MCPBaseTransport):
|
|
|
75
103
|
|
|
76
104
|
headers = {}
|
|
77
105
|
if self.api_key:
|
|
78
|
-
|
|
106
|
+
# NEW: Handle both "Bearer token" and just "token" formats
|
|
107
|
+
if self.api_key.startswith("Bearer "):
|
|
108
|
+
headers["Authorization"] = self.api_key
|
|
109
|
+
else:
|
|
110
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
111
|
+
print(f"🔑 Added Authorization header to httpx client")
|
|
79
112
|
|
|
80
113
|
self._client = httpx.AsyncClient(
|
|
81
114
|
headers=headers,
|
|
82
|
-
timeout=
|
|
115
|
+
timeout=self.default_timeout, # Use configurable timeout
|
|
83
116
|
)
|
|
84
117
|
self.session = self._client
|
|
85
118
|
|
|
@@ -87,8 +120,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
87
120
|
self._sse_task = asyncio.create_task(self._handle_sse_connection())
|
|
88
121
|
|
|
89
122
|
try:
|
|
90
|
-
#
|
|
91
|
-
await asyncio.wait_for(self._connected.wait(), timeout=
|
|
123
|
+
# FIXED: Use configurable connection timeout instead of hardcoded 10.0
|
|
124
|
+
await asyncio.wait_for(self._connected.wait(), timeout=self.connection_timeout)
|
|
92
125
|
|
|
93
126
|
# NEW: Send MCP initialize handshake
|
|
94
127
|
if await self._initialize_mcp_session():
|
|
@@ -223,6 +256,11 @@ class SSETransport(MCPBaseTransport):
|
|
|
223
256
|
|
|
224
257
|
if event_type == "endpoint":
|
|
225
258
|
# Got the endpoint URL for messages - construct full URL
|
|
259
|
+
# NEW: Handle URLs that need trailing slash fix
|
|
260
|
+
if "/messages?" in data and "/messages/?" not in data:
|
|
261
|
+
data = data.replace("/messages?", "/messages/?", 1)
|
|
262
|
+
print(f"🔧 Fixed URL redirect: added trailing slash")
|
|
263
|
+
|
|
226
264
|
self._message_url = f"{self.base_url}{data}"
|
|
227
265
|
|
|
228
266
|
# Extract session_id if present
|
|
@@ -267,7 +305,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
267
305
|
if not self._initialized.is_set():
|
|
268
306
|
print("⏳ Waiting for MCP initialization...")
|
|
269
307
|
try:
|
|
270
|
-
|
|
308
|
+
# FIXED: Use configurable connection timeout instead of hardcoded 10.0
|
|
309
|
+
await asyncio.wait_for(self._initialized.wait(), timeout=self.connection_timeout)
|
|
271
310
|
except asyncio.TimeoutError:
|
|
272
311
|
print("❌ Timeout waiting for MCP initialization")
|
|
273
312
|
return []
|
|
@@ -293,8 +332,23 @@ class SSETransport(MCPBaseTransport):
|
|
|
293
332
|
|
|
294
333
|
return []
|
|
295
334
|
|
|
296
|
-
async def call_tool(
|
|
297
|
-
|
|
335
|
+
async def call_tool(
|
|
336
|
+
self,
|
|
337
|
+
tool_name: str,
|
|
338
|
+
arguments: Dict[str, Any],
|
|
339
|
+
timeout: Optional[float] = None
|
|
340
|
+
) -> Dict[str, Any]:
|
|
341
|
+
"""
|
|
342
|
+
Execute a tool call using the MCP protocol.
|
|
343
|
+
|
|
344
|
+
Args:
|
|
345
|
+
tool_name: Name of the tool to call
|
|
346
|
+
arguments: Arguments to pass to the tool
|
|
347
|
+
timeout: Optional timeout for this specific call
|
|
348
|
+
|
|
349
|
+
Returns:
|
|
350
|
+
Dictionary containing the tool result or error
|
|
351
|
+
"""
|
|
298
352
|
# NEW: Ensure initialization before tool calls
|
|
299
353
|
if not self._initialized.is_set():
|
|
300
354
|
return {"isError": True, "error": "MCP session not initialized"}
|
|
@@ -313,7 +367,9 @@ class SSETransport(MCPBaseTransport):
|
|
|
313
367
|
}
|
|
314
368
|
}
|
|
315
369
|
|
|
316
|
-
|
|
370
|
+
# Use provided timeout or fall back to default
|
|
371
|
+
effective_timeout = timeout if timeout is not None else self.default_timeout
|
|
372
|
+
response = await self._send_message(message, timeout=effective_timeout)
|
|
317
373
|
|
|
318
374
|
# Process MCP response
|
|
319
375
|
if "error" in response:
|
|
@@ -345,8 +401,21 @@ class SSETransport(MCPBaseTransport):
|
|
|
345
401
|
except Exception as e:
|
|
346
402
|
return {"isError": True, "error": str(e)}
|
|
347
403
|
|
|
348
|
-
async def _send_message(
|
|
349
|
-
|
|
404
|
+
async def _send_message(
|
|
405
|
+
self,
|
|
406
|
+
message: Dict[str, Any],
|
|
407
|
+
timeout: Optional[float] = None
|
|
408
|
+
) -> Dict[str, Any]:
|
|
409
|
+
"""
|
|
410
|
+
Send a JSON-RPC message to the server and wait for async response.
|
|
411
|
+
|
|
412
|
+
Args:
|
|
413
|
+
message: JSON-RPC message to send
|
|
414
|
+
timeout: Optional timeout for this specific message
|
|
415
|
+
|
|
416
|
+
Returns:
|
|
417
|
+
Response message from the server
|
|
418
|
+
"""
|
|
350
419
|
if not self._client or not self._message_url:
|
|
351
420
|
raise RuntimeError("Transport not properly initialized")
|
|
352
421
|
|
|
@@ -354,6 +423,9 @@ class SSETransport(MCPBaseTransport):
|
|
|
354
423
|
if not message_id:
|
|
355
424
|
raise ValueError("Message must have an ID")
|
|
356
425
|
|
|
426
|
+
# Use provided timeout or fall back to default
|
|
427
|
+
effective_timeout = timeout if timeout is not None else self.default_timeout
|
|
428
|
+
|
|
357
429
|
# Create a future for this request
|
|
358
430
|
future = asyncio.Future()
|
|
359
431
|
async with self._message_lock:
|
|
@@ -373,7 +445,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
373
445
|
if response.status_code == 202:
|
|
374
446
|
# Server accepted - wait for async response via SSE
|
|
375
447
|
try:
|
|
376
|
-
|
|
448
|
+
# FIXED: Use effective_timeout instead of hardcoded 30.0
|
|
449
|
+
response_message = await asyncio.wait_for(future, timeout=effective_timeout)
|
|
377
450
|
return response_message
|
|
378
451
|
except asyncio.TimeoutError:
|
|
379
452
|
raise RuntimeError(f"Timeout waiting for response to message {message_id}")
|