chuk-tool-processor 0.6.11__py3-none-any.whl → 0.6.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/core/__init__.py +1 -1
- chuk_tool_processor/core/exceptions.py +10 -4
- chuk_tool_processor/core/processor.py +97 -97
- chuk_tool_processor/execution/strategies/inprocess_strategy.py +142 -150
- chuk_tool_processor/execution/strategies/subprocess_strategy.py +200 -205
- chuk_tool_processor/execution/tool_executor.py +82 -84
- chuk_tool_processor/execution/wrappers/caching.py +102 -103
- chuk_tool_processor/execution/wrappers/rate_limiting.py +45 -42
- chuk_tool_processor/execution/wrappers/retry.py +23 -25
- chuk_tool_processor/logging/__init__.py +23 -17
- chuk_tool_processor/logging/context.py +40 -45
- chuk_tool_processor/logging/formatter.py +22 -21
- chuk_tool_processor/logging/helpers.py +24 -38
- chuk_tool_processor/logging/metrics.py +11 -13
- chuk_tool_processor/mcp/__init__.py +8 -12
- chuk_tool_processor/mcp/mcp_tool.py +153 -109
- chuk_tool_processor/mcp/register_mcp_tools.py +17 -17
- chuk_tool_processor/mcp/setup_mcp_http_streamable.py +11 -13
- chuk_tool_processor/mcp/setup_mcp_sse.py +11 -13
- chuk_tool_processor/mcp/setup_mcp_stdio.py +7 -9
- chuk_tool_processor/mcp/stream_manager.py +168 -204
- chuk_tool_processor/mcp/transport/__init__.py +4 -4
- chuk_tool_processor/mcp/transport/base_transport.py +43 -58
- chuk_tool_processor/mcp/transport/http_streamable_transport.py +145 -163
- chuk_tool_processor/mcp/transport/sse_transport.py +266 -252
- chuk_tool_processor/mcp/transport/stdio_transport.py +171 -189
- chuk_tool_processor/models/__init__.py +1 -1
- chuk_tool_processor/models/execution_strategy.py +16 -21
- chuk_tool_processor/models/streaming_tool.py +28 -25
- chuk_tool_processor/models/tool_call.py +19 -34
- chuk_tool_processor/models/tool_export_mixin.py +22 -8
- chuk_tool_processor/models/tool_result.py +40 -77
- chuk_tool_processor/models/validated_tool.py +14 -16
- chuk_tool_processor/plugins/__init__.py +1 -1
- chuk_tool_processor/plugins/discovery.py +10 -10
- chuk_tool_processor/plugins/parsers/__init__.py +1 -1
- chuk_tool_processor/plugins/parsers/base.py +1 -2
- chuk_tool_processor/plugins/parsers/function_call_tool.py +13 -8
- chuk_tool_processor/plugins/parsers/json_tool.py +4 -3
- chuk_tool_processor/plugins/parsers/openai_tool.py +12 -7
- chuk_tool_processor/plugins/parsers/xml_tool.py +4 -4
- chuk_tool_processor/registry/__init__.py +12 -12
- chuk_tool_processor/registry/auto_register.py +22 -30
- chuk_tool_processor/registry/decorators.py +127 -129
- chuk_tool_processor/registry/interface.py +26 -23
- chuk_tool_processor/registry/metadata.py +27 -22
- chuk_tool_processor/registry/provider.py +17 -18
- chuk_tool_processor/registry/providers/__init__.py +16 -19
- chuk_tool_processor/registry/providers/memory.py +18 -25
- chuk_tool_processor/registry/tool_export.py +42 -51
- chuk_tool_processor/utils/validation.py +15 -16
- {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/METADATA +1 -1
- chuk_tool_processor-0.6.13.dist-info/RECORD +60 -0
- chuk_tool_processor-0.6.11.dist-info/RECORD +0 -60
- {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.6.11.dist-info → chuk_tool_processor-0.6.13.dist-info}/top_level.txt +0 -0
|
@@ -2,23 +2,26 @@
|
|
|
2
2
|
"""
|
|
3
3
|
StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling and headers support
|
|
4
4
|
"""
|
|
5
|
+
|
|
5
6
|
from __future__ import annotations
|
|
6
7
|
|
|
7
8
|
import asyncio
|
|
8
|
-
|
|
9
|
+
import contextlib
|
|
9
10
|
from contextlib import asynccontextmanager
|
|
11
|
+
from typing import Any
|
|
10
12
|
|
|
11
13
|
# --------------------------------------------------------------------------- #
|
|
12
14
|
# CHUK imports #
|
|
13
15
|
# --------------------------------------------------------------------------- #
|
|
14
|
-
from chuk_mcp.config import load_config
|
|
16
|
+
from chuk_mcp.config import load_config # type: ignore[import-untyped]
|
|
17
|
+
|
|
18
|
+
from chuk_tool_processor.logging import get_logger
|
|
15
19
|
from chuk_tool_processor.mcp.transport import (
|
|
20
|
+
HTTPStreamableTransport,
|
|
16
21
|
MCPBaseTransport,
|
|
17
|
-
StdioTransport,
|
|
18
22
|
SSETransport,
|
|
19
|
-
|
|
23
|
+
StdioTransport,
|
|
20
24
|
)
|
|
21
|
-
from chuk_tool_processor.logging import get_logger
|
|
22
25
|
|
|
23
26
|
logger = get_logger("chuk_tool_processor.mcp.stream_manager")
|
|
24
27
|
|
|
@@ -26,9 +29,9 @@ logger = get_logger("chuk_tool_processor.mcp.stream_manager")
|
|
|
26
29
|
class StreamManager:
|
|
27
30
|
"""
|
|
28
31
|
Manager for MCP server streams with support for multiple transport types.
|
|
29
|
-
|
|
32
|
+
|
|
30
33
|
Enhanced with robust shutdown handling and proper headers support.
|
|
31
|
-
|
|
34
|
+
|
|
32
35
|
Updated to support the latest transports:
|
|
33
36
|
- STDIO (process-based)
|
|
34
37
|
- SSE (Server-Sent Events) with headers support
|
|
@@ -36,11 +39,11 @@ class StreamManager:
|
|
|
36
39
|
"""
|
|
37
40
|
|
|
38
41
|
def __init__(self) -> None:
|
|
39
|
-
self.transports:
|
|
40
|
-
self.server_info:
|
|
41
|
-
self.tool_to_server_map:
|
|
42
|
-
self.server_names:
|
|
43
|
-
self.all_tools:
|
|
42
|
+
self.transports: dict[str, MCPBaseTransport] = {}
|
|
43
|
+
self.server_info: list[dict[str, Any]] = []
|
|
44
|
+
self.tool_to_server_map: dict[str, str] = {}
|
|
45
|
+
self.server_names: dict[int, str] = {}
|
|
46
|
+
self.all_tools: list[dict[str, Any]] = []
|
|
44
47
|
self._lock = asyncio.Lock()
|
|
45
48
|
self._closed = False # Track if we've been closed
|
|
46
49
|
self._shutdown_timeout = 2.0 # Maximum time to spend on shutdown
|
|
@@ -52,81 +55,71 @@ class StreamManager:
|
|
|
52
55
|
async def create(
|
|
53
56
|
cls,
|
|
54
57
|
config_file: str,
|
|
55
|
-
servers:
|
|
56
|
-
server_names:
|
|
58
|
+
servers: list[str],
|
|
59
|
+
server_names: dict[int, str] | None = None,
|
|
57
60
|
transport_type: str = "stdio",
|
|
58
61
|
default_timeout: float = 30.0,
|
|
59
62
|
initialization_timeout: float = 60.0, # NEW: Timeout for entire initialization
|
|
60
|
-
) ->
|
|
63
|
+
) -> StreamManager:
|
|
61
64
|
"""Create StreamManager with timeout protection."""
|
|
62
65
|
try:
|
|
63
66
|
inst = cls()
|
|
64
67
|
await asyncio.wait_for(
|
|
65
|
-
inst.initialize(
|
|
66
|
-
|
|
67
|
-
servers,
|
|
68
|
-
server_names,
|
|
69
|
-
transport_type,
|
|
70
|
-
default_timeout=default_timeout
|
|
71
|
-
),
|
|
72
|
-
timeout=initialization_timeout
|
|
68
|
+
inst.initialize(config_file, servers, server_names, transport_type, default_timeout=default_timeout),
|
|
69
|
+
timeout=initialization_timeout,
|
|
73
70
|
)
|
|
74
71
|
return inst
|
|
75
|
-
except
|
|
72
|
+
except TimeoutError:
|
|
76
73
|
logger.error("StreamManager initialization timed out after %ss", initialization_timeout)
|
|
77
74
|
raise RuntimeError(f"StreamManager initialization timed out after {initialization_timeout}s")
|
|
78
75
|
|
|
79
76
|
@classmethod
|
|
80
77
|
async def create_with_sse(
|
|
81
78
|
cls,
|
|
82
|
-
servers:
|
|
83
|
-
server_names:
|
|
79
|
+
servers: list[dict[str, str]],
|
|
80
|
+
server_names: dict[int, str] | None = None,
|
|
84
81
|
connection_timeout: float = 10.0,
|
|
85
82
|
default_timeout: float = 30.0,
|
|
86
83
|
initialization_timeout: float = 60.0, # NEW
|
|
87
|
-
) ->
|
|
84
|
+
) -> StreamManager:
|
|
88
85
|
"""Create StreamManager with SSE transport and timeout protection."""
|
|
89
86
|
try:
|
|
90
87
|
inst = cls()
|
|
91
88
|
await asyncio.wait_for(
|
|
92
89
|
inst.initialize_with_sse(
|
|
93
|
-
servers,
|
|
94
|
-
server_names,
|
|
95
|
-
connection_timeout=connection_timeout,
|
|
96
|
-
default_timeout=default_timeout
|
|
90
|
+
servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
|
|
97
91
|
),
|
|
98
|
-
timeout=initialization_timeout
|
|
92
|
+
timeout=initialization_timeout,
|
|
99
93
|
)
|
|
100
94
|
return inst
|
|
101
|
-
except
|
|
95
|
+
except TimeoutError:
|
|
102
96
|
logger.error("SSE StreamManager initialization timed out after %ss", initialization_timeout)
|
|
103
97
|
raise RuntimeError(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
|
|
104
98
|
|
|
105
99
|
@classmethod
|
|
106
100
|
async def create_with_http_streamable(
|
|
107
101
|
cls,
|
|
108
|
-
servers:
|
|
109
|
-
server_names:
|
|
102
|
+
servers: list[dict[str, str]],
|
|
103
|
+
server_names: dict[int, str] | None = None,
|
|
110
104
|
connection_timeout: float = 30.0,
|
|
111
105
|
default_timeout: float = 30.0,
|
|
112
106
|
initialization_timeout: float = 60.0, # NEW
|
|
113
|
-
) ->
|
|
107
|
+
) -> StreamManager:
|
|
114
108
|
"""Create StreamManager with HTTP Streamable transport and timeout protection."""
|
|
115
109
|
try:
|
|
116
110
|
inst = cls()
|
|
117
111
|
await asyncio.wait_for(
|
|
118
112
|
inst.initialize_with_http_streamable(
|
|
119
|
-
servers,
|
|
120
|
-
server_names,
|
|
121
|
-
connection_timeout=connection_timeout,
|
|
122
|
-
default_timeout=default_timeout
|
|
113
|
+
servers, server_names, connection_timeout=connection_timeout, default_timeout=default_timeout
|
|
123
114
|
),
|
|
124
|
-
timeout=initialization_timeout
|
|
115
|
+
timeout=initialization_timeout,
|
|
125
116
|
)
|
|
126
117
|
return inst
|
|
127
|
-
except
|
|
118
|
+
except TimeoutError:
|
|
128
119
|
logger.error("HTTP Streamable StreamManager initialization timed out after %ss", initialization_timeout)
|
|
129
|
-
raise RuntimeError(
|
|
120
|
+
raise RuntimeError(
|
|
121
|
+
f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s"
|
|
122
|
+
)
|
|
130
123
|
|
|
131
124
|
# ------------------------------------------------------------------ #
|
|
132
125
|
# NEW: Context manager support for automatic cleanup #
|
|
@@ -144,8 +137,8 @@ class StreamManager:
|
|
|
144
137
|
async def create_managed(
|
|
145
138
|
cls,
|
|
146
139
|
config_file: str,
|
|
147
|
-
servers:
|
|
148
|
-
server_names:
|
|
140
|
+
servers: list[str],
|
|
141
|
+
server_names: dict[int, str] | None = None,
|
|
149
142
|
transport_type: str = "stdio",
|
|
150
143
|
default_timeout: float = 30.0,
|
|
151
144
|
):
|
|
@@ -170,15 +163,15 @@ class StreamManager:
|
|
|
170
163
|
async def initialize(
|
|
171
164
|
self,
|
|
172
165
|
config_file: str,
|
|
173
|
-
servers:
|
|
174
|
-
server_names:
|
|
166
|
+
servers: list[str],
|
|
167
|
+
server_names: dict[int, str] | None = None,
|
|
175
168
|
transport_type: str = "stdio",
|
|
176
169
|
default_timeout: float = 30.0,
|
|
177
170
|
) -> None:
|
|
178
171
|
"""Initialize with graceful headers handling for all transport types."""
|
|
179
172
|
if self._closed:
|
|
180
173
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
181
|
-
|
|
174
|
+
|
|
182
175
|
async with self._lock:
|
|
183
176
|
self.server_names = server_names or {}
|
|
184
177
|
|
|
@@ -188,59 +181,61 @@ class StreamManager:
|
|
|
188
181
|
params = await load_config(config_file, server_name)
|
|
189
182
|
transport: MCPBaseTransport = StdioTransport(params)
|
|
190
183
|
elif transport_type == "sse":
|
|
191
|
-
logger.warning(
|
|
184
|
+
logger.warning(
|
|
185
|
+
"Using SSE transport in initialize() - consider using initialize_with_sse() instead"
|
|
186
|
+
)
|
|
192
187
|
params = await load_config(config_file, server_name)
|
|
193
|
-
|
|
194
|
-
if isinstance(params, dict) and
|
|
195
|
-
sse_url = params[
|
|
196
|
-
api_key = params.get(
|
|
197
|
-
headers = params.get(
|
|
188
|
+
|
|
189
|
+
if isinstance(params, dict) and "url" in params:
|
|
190
|
+
sse_url = params["url"]
|
|
191
|
+
api_key = params.get("api_key")
|
|
192
|
+
headers = params.get("headers", {})
|
|
198
193
|
else:
|
|
199
194
|
sse_url = "http://localhost:8000"
|
|
200
195
|
api_key = None
|
|
201
196
|
headers = {}
|
|
202
197
|
logger.warning("No URL configured for SSE transport, using default: %s", sse_url)
|
|
203
|
-
|
|
198
|
+
|
|
204
199
|
# Build SSE transport with optional headers
|
|
205
|
-
transport_params = {
|
|
206
|
-
'url': sse_url,
|
|
207
|
-
'api_key': api_key,
|
|
208
|
-
'default_timeout': default_timeout
|
|
209
|
-
}
|
|
200
|
+
transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": default_timeout}
|
|
210
201
|
if headers:
|
|
211
|
-
transport_params[
|
|
212
|
-
|
|
202
|
+
transport_params["headers"] = headers
|
|
203
|
+
|
|
213
204
|
transport = SSETransport(**transport_params)
|
|
214
|
-
|
|
205
|
+
|
|
215
206
|
elif transport_type == "http_streamable":
|
|
216
|
-
logger.warning(
|
|
207
|
+
logger.warning(
|
|
208
|
+
"Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
|
|
209
|
+
)
|
|
217
210
|
params = await load_config(config_file, server_name)
|
|
218
|
-
|
|
219
|
-
if isinstance(params, dict) and
|
|
220
|
-
http_url = params[
|
|
221
|
-
api_key = params.get(
|
|
222
|
-
headers = params.get(
|
|
223
|
-
session_id = params.get(
|
|
211
|
+
|
|
212
|
+
if isinstance(params, dict) and "url" in params:
|
|
213
|
+
http_url = params["url"]
|
|
214
|
+
api_key = params.get("api_key")
|
|
215
|
+
headers = params.get("headers", {})
|
|
216
|
+
session_id = params.get("session_id")
|
|
224
217
|
else:
|
|
225
218
|
http_url = "http://localhost:8000"
|
|
226
219
|
api_key = None
|
|
227
220
|
headers = {}
|
|
228
221
|
session_id = None
|
|
229
|
-
logger.warning(
|
|
230
|
-
|
|
222
|
+
logger.warning(
|
|
223
|
+
"No URL configured for HTTP Streamable transport, using default: %s", http_url
|
|
224
|
+
)
|
|
225
|
+
|
|
231
226
|
# Build HTTP transport (headers not supported yet)
|
|
232
227
|
transport_params = {
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
228
|
+
"url": http_url,
|
|
229
|
+
"api_key": api_key,
|
|
230
|
+
"default_timeout": default_timeout,
|
|
231
|
+
"session_id": session_id,
|
|
237
232
|
}
|
|
238
233
|
# Note: headers not added until HTTPStreamableTransport supports them
|
|
239
234
|
if headers:
|
|
240
235
|
logger.debug("Headers provided but not supported in HTTPStreamableTransport yet")
|
|
241
|
-
|
|
236
|
+
|
|
242
237
|
transport = HTTPStreamableTransport(**transport_params)
|
|
243
|
-
|
|
238
|
+
|
|
244
239
|
else:
|
|
245
240
|
logger.error("Unsupported transport type: %s", transport_type)
|
|
246
241
|
continue
|
|
@@ -271,7 +266,7 @@ class StreamManager:
|
|
|
271
266
|
}
|
|
272
267
|
)
|
|
273
268
|
logger.debug("Initialised %s - %d tool(s)", server_name, len(tools))
|
|
274
|
-
except
|
|
269
|
+
except TimeoutError:
|
|
275
270
|
logger.error("Timeout initialising %s", server_name)
|
|
276
271
|
except Exception as exc:
|
|
277
272
|
logger.error("Error initialising %s: %s", server_name, exc)
|
|
@@ -284,15 +279,15 @@ class StreamManager:
|
|
|
284
279
|
|
|
285
280
|
async def initialize_with_sse(
|
|
286
281
|
self,
|
|
287
|
-
servers:
|
|
288
|
-
server_names:
|
|
282
|
+
servers: list[dict[str, str]],
|
|
283
|
+
server_names: dict[int, str] | None = None,
|
|
289
284
|
connection_timeout: float = 10.0,
|
|
290
285
|
default_timeout: float = 30.0,
|
|
291
286
|
) -> None:
|
|
292
287
|
"""Initialize with SSE transport with optional headers support."""
|
|
293
288
|
if self._closed:
|
|
294
289
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
295
|
-
|
|
290
|
+
|
|
296
291
|
async with self._lock:
|
|
297
292
|
self.server_names = server_names or {}
|
|
298
293
|
|
|
@@ -304,20 +299,20 @@ class StreamManager:
|
|
|
304
299
|
try:
|
|
305
300
|
# Build SSE transport parameters with optional headers
|
|
306
301
|
transport_params = {
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
302
|
+
"url": url,
|
|
303
|
+
"api_key": cfg.get("api_key"),
|
|
304
|
+
"connection_timeout": connection_timeout,
|
|
305
|
+
"default_timeout": default_timeout,
|
|
311
306
|
}
|
|
312
|
-
|
|
307
|
+
|
|
313
308
|
# Add headers if provided
|
|
314
309
|
headers = cfg.get("headers", {})
|
|
315
310
|
if headers:
|
|
316
311
|
logger.debug("SSE %s: Using configured headers: %s", name, list(headers.keys()))
|
|
317
|
-
transport_params[
|
|
318
|
-
|
|
312
|
+
transport_params["headers"] = headers
|
|
313
|
+
|
|
319
314
|
transport = SSETransport(**transport_params)
|
|
320
|
-
|
|
315
|
+
|
|
321
316
|
if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
|
|
322
317
|
logger.error("Failed to init SSE %s", name)
|
|
323
318
|
continue
|
|
@@ -332,11 +327,9 @@ class StreamManager:
|
|
|
332
327
|
self.tool_to_server_map[tname] = name
|
|
333
328
|
self.all_tools.extend(tools)
|
|
334
329
|
|
|
335
|
-
self.server_info.append(
|
|
336
|
-
{"id": idx, "name": name, "tools": len(tools), "status": status}
|
|
337
|
-
)
|
|
330
|
+
self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
|
|
338
331
|
logger.debug("Initialised SSE %s - %d tool(s)", name, len(tools))
|
|
339
|
-
except
|
|
332
|
+
except TimeoutError:
|
|
340
333
|
logger.error("Timeout initialising SSE %s", name)
|
|
341
334
|
except Exception as exc:
|
|
342
335
|
logger.error("Error initialising SSE %s: %s", name, exc)
|
|
@@ -349,15 +342,15 @@ class StreamManager:
|
|
|
349
342
|
|
|
350
343
|
async def initialize_with_http_streamable(
|
|
351
344
|
self,
|
|
352
|
-
servers:
|
|
353
|
-
server_names:
|
|
345
|
+
servers: list[dict[str, str]],
|
|
346
|
+
server_names: dict[int, str] | None = None,
|
|
354
347
|
connection_timeout: float = 30.0,
|
|
355
348
|
default_timeout: float = 30.0,
|
|
356
349
|
) -> None:
|
|
357
350
|
"""Initialize with HTTP Streamable transport with graceful headers handling."""
|
|
358
351
|
if self._closed:
|
|
359
352
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
360
|
-
|
|
353
|
+
|
|
361
354
|
async with self._lock:
|
|
362
355
|
self.server_names = server_names or {}
|
|
363
356
|
|
|
@@ -369,22 +362,22 @@ class StreamManager:
|
|
|
369
362
|
try:
|
|
370
363
|
# Build HTTP Streamable transport parameters
|
|
371
364
|
transport_params = {
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
365
|
+
"url": url,
|
|
366
|
+
"api_key": cfg.get("api_key"),
|
|
367
|
+
"connection_timeout": connection_timeout,
|
|
368
|
+
"default_timeout": default_timeout,
|
|
369
|
+
"session_id": cfg.get("session_id"),
|
|
377
370
|
}
|
|
378
|
-
|
|
371
|
+
|
|
379
372
|
# Handle headers if provided (for future HTTPStreamableTransport support)
|
|
380
373
|
headers = cfg.get("headers", {})
|
|
381
374
|
if headers:
|
|
382
375
|
logger.debug("HTTP Streamable %s: Headers provided but not yet supported in transport", name)
|
|
383
376
|
# TODO: Add headers support when HTTPStreamableTransport is updated
|
|
384
377
|
# transport_params['headers'] = headers
|
|
385
|
-
|
|
378
|
+
|
|
386
379
|
transport = HTTPStreamableTransport(**transport_params)
|
|
387
|
-
|
|
380
|
+
|
|
388
381
|
if not await asyncio.wait_for(transport.initialize(), timeout=connection_timeout):
|
|
389
382
|
logger.error("Failed to init HTTP Streamable %s", name)
|
|
390
383
|
continue
|
|
@@ -399,11 +392,9 @@ class StreamManager:
|
|
|
399
392
|
self.tool_to_server_map[tname] = name
|
|
400
393
|
self.all_tools.extend(tools)
|
|
401
394
|
|
|
402
|
-
self.server_info.append(
|
|
403
|
-
{"id": idx, "name": name, "tools": len(tools), "status": status}
|
|
404
|
-
)
|
|
395
|
+
self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
|
|
405
396
|
logger.debug("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
|
|
406
|
-
except
|
|
397
|
+
except TimeoutError:
|
|
407
398
|
logger.error("Timeout initialising HTTP Streamable %s", name)
|
|
408
399
|
except Exception as exc:
|
|
409
400
|
logger.error("Error initialising HTTP Streamable %s: %s", name, exc)
|
|
@@ -417,32 +408,32 @@ class StreamManager:
|
|
|
417
408
|
# ------------------------------------------------------------------ #
|
|
418
409
|
# queries #
|
|
419
410
|
# ------------------------------------------------------------------ #
|
|
420
|
-
def get_all_tools(self) ->
|
|
411
|
+
def get_all_tools(self) -> list[dict[str, Any]]:
|
|
421
412
|
return self.all_tools
|
|
422
413
|
|
|
423
|
-
def get_server_for_tool(self, tool_name: str) ->
|
|
414
|
+
def get_server_for_tool(self, tool_name: str) -> str | None:
|
|
424
415
|
return self.tool_to_server_map.get(tool_name)
|
|
425
416
|
|
|
426
|
-
def get_server_info(self) ->
|
|
417
|
+
def get_server_info(self) -> list[dict[str, Any]]:
|
|
427
418
|
return self.server_info
|
|
428
|
-
|
|
429
|
-
async def list_tools(self, server_name: str) ->
|
|
419
|
+
|
|
420
|
+
async def list_tools(self, server_name: str) -> list[dict[str, Any]]:
|
|
430
421
|
"""List all tools available from a specific server."""
|
|
431
422
|
if self._closed:
|
|
432
423
|
logger.warning("Cannot list tools: StreamManager is closed")
|
|
433
424
|
return []
|
|
434
|
-
|
|
425
|
+
|
|
435
426
|
if server_name not in self.transports:
|
|
436
427
|
logger.error("Server '%s' not found in transports", server_name)
|
|
437
428
|
return []
|
|
438
|
-
|
|
429
|
+
|
|
439
430
|
transport = self.transports[server_name]
|
|
440
|
-
|
|
431
|
+
|
|
441
432
|
try:
|
|
442
433
|
tools = await asyncio.wait_for(transport.get_tools(), timeout=10.0)
|
|
443
434
|
logger.debug("Found %d tools for server %s", len(tools), server_name)
|
|
444
435
|
return tools
|
|
445
|
-
except
|
|
436
|
+
except TimeoutError:
|
|
446
437
|
logger.error("Timeout listing tools for server %s", server_name)
|
|
447
438
|
return []
|
|
448
439
|
except Exception as e:
|
|
@@ -452,10 +443,10 @@ class StreamManager:
|
|
|
452
443
|
# ------------------------------------------------------------------ #
|
|
453
444
|
# EXTRA HELPERS - ping / resources / prompts #
|
|
454
445
|
# ------------------------------------------------------------------ #
|
|
455
|
-
async def ping_servers(self) ->
|
|
446
|
+
async def ping_servers(self) -> list[dict[str, Any]]:
|
|
456
447
|
if self._closed:
|
|
457
448
|
return []
|
|
458
|
-
|
|
449
|
+
|
|
459
450
|
async def _ping_one(name: str, tr: MCPBaseTransport):
|
|
460
451
|
try:
|
|
461
452
|
ok = await asyncio.wait_for(tr.send_ping(), timeout=5.0)
|
|
@@ -465,18 +456,16 @@ class StreamManager:
|
|
|
465
456
|
|
|
466
457
|
return await asyncio.gather(*(_ping_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
|
|
467
458
|
|
|
468
|
-
async def list_resources(self) ->
|
|
459
|
+
async def list_resources(self) -> list[dict[str, Any]]:
|
|
469
460
|
if self._closed:
|
|
470
461
|
return []
|
|
471
|
-
|
|
472
|
-
out:
|
|
462
|
+
|
|
463
|
+
out: list[dict[str, Any]] = []
|
|
473
464
|
|
|
474
465
|
async def _one(name: str, tr: MCPBaseTransport):
|
|
475
466
|
try:
|
|
476
467
|
res = await asyncio.wait_for(tr.list_resources(), timeout=10.0)
|
|
477
|
-
resources = (
|
|
478
|
-
res.get("resources", []) if isinstance(res, dict) else res
|
|
479
|
-
)
|
|
468
|
+
resources = res.get("resources", []) if isinstance(res, dict) else res
|
|
480
469
|
for item in resources:
|
|
481
470
|
item = dict(item)
|
|
482
471
|
item["server"] = name
|
|
@@ -487,11 +476,11 @@ class StreamManager:
|
|
|
487
476
|
await asyncio.gather(*(_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
|
|
488
477
|
return out
|
|
489
478
|
|
|
490
|
-
async def list_prompts(self) ->
|
|
479
|
+
async def list_prompts(self) -> list[dict[str, Any]]:
|
|
491
480
|
if self._closed:
|
|
492
481
|
return []
|
|
493
|
-
|
|
494
|
-
out:
|
|
482
|
+
|
|
483
|
+
out: list[dict[str, Any]] = []
|
|
495
484
|
|
|
496
485
|
async def _one(name: str, tr: MCPBaseTransport):
|
|
497
486
|
try:
|
|
@@ -513,45 +502,40 @@ class StreamManager:
|
|
|
513
502
|
async def call_tool(
|
|
514
503
|
self,
|
|
515
504
|
tool_name: str,
|
|
516
|
-
arguments:
|
|
517
|
-
server_name:
|
|
518
|
-
timeout:
|
|
519
|
-
) ->
|
|
505
|
+
arguments: dict[str, Any],
|
|
506
|
+
server_name: str | None = None,
|
|
507
|
+
timeout: float | None = None,
|
|
508
|
+
) -> dict[str, Any]:
|
|
520
509
|
"""Call a tool on the appropriate server with timeout support."""
|
|
521
510
|
if self._closed:
|
|
522
511
|
return {
|
|
523
512
|
"isError": True,
|
|
524
513
|
"error": "StreamManager is closed",
|
|
525
514
|
}
|
|
526
|
-
|
|
515
|
+
|
|
527
516
|
server_name = server_name or self.get_server_for_tool(tool_name)
|
|
528
517
|
if not server_name or server_name not in self.transports:
|
|
529
518
|
return {
|
|
530
519
|
"isError": True,
|
|
531
520
|
"error": f"No server found for tool: {tool_name}",
|
|
532
521
|
}
|
|
533
|
-
|
|
522
|
+
|
|
534
523
|
transport = self.transports[server_name]
|
|
535
|
-
|
|
524
|
+
|
|
536
525
|
if timeout is not None:
|
|
537
526
|
logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
|
|
538
527
|
try:
|
|
539
|
-
if hasattr(transport,
|
|
528
|
+
if hasattr(transport, "call_tool"):
|
|
540
529
|
import inspect
|
|
530
|
+
|
|
541
531
|
sig = inspect.signature(transport.call_tool)
|
|
542
|
-
if
|
|
532
|
+
if "timeout" in sig.parameters:
|
|
543
533
|
return await transport.call_tool(tool_name, arguments, timeout=timeout)
|
|
544
534
|
else:
|
|
545
|
-
return await asyncio.wait_for(
|
|
546
|
-
transport.call_tool(tool_name, arguments),
|
|
547
|
-
timeout=timeout
|
|
548
|
-
)
|
|
535
|
+
return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
|
|
549
536
|
else:
|
|
550
|
-
return await asyncio.wait_for(
|
|
551
|
-
|
|
552
|
-
timeout=timeout
|
|
553
|
-
)
|
|
554
|
-
except asyncio.TimeoutError:
|
|
537
|
+
return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
|
|
538
|
+
except TimeoutError:
|
|
555
539
|
logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
|
|
556
540
|
return {
|
|
557
541
|
"isError": True,
|
|
@@ -559,28 +543,28 @@ class StreamManager:
|
|
|
559
543
|
}
|
|
560
544
|
else:
|
|
561
545
|
return await transport.call_tool(tool_name, arguments)
|
|
562
|
-
|
|
546
|
+
|
|
563
547
|
# ------------------------------------------------------------------ #
|
|
564
548
|
# ENHANCED shutdown with robust error handling #
|
|
565
549
|
# ------------------------------------------------------------------ #
|
|
566
550
|
async def close(self) -> None:
|
|
567
551
|
"""
|
|
568
552
|
Close all transports safely with enhanced error handling.
|
|
569
|
-
|
|
553
|
+
|
|
570
554
|
ENHANCED: Uses asyncio.shield() to protect critical cleanup and
|
|
571
555
|
provides multiple fallback strategies for different failure modes.
|
|
572
556
|
"""
|
|
573
557
|
if self._closed:
|
|
574
558
|
logger.debug("StreamManager already closed")
|
|
575
559
|
return
|
|
576
|
-
|
|
560
|
+
|
|
577
561
|
if not self.transports:
|
|
578
562
|
logger.debug("No transports to close")
|
|
579
563
|
self._closed = True
|
|
580
564
|
return
|
|
581
|
-
|
|
565
|
+
|
|
582
566
|
logger.debug("Closing %d transports...", len(self.transports))
|
|
583
|
-
|
|
567
|
+
|
|
584
568
|
try:
|
|
585
569
|
# Use shield to protect the cleanup operation from cancellation
|
|
586
570
|
await asyncio.shield(self._do_close_all_transports())
|
|
@@ -598,7 +582,7 @@ class StreamManager:
|
|
|
598
582
|
"""Protected cleanup implementation with multiple strategies."""
|
|
599
583
|
close_results = []
|
|
600
584
|
transport_items = list(self.transports.items())
|
|
601
|
-
|
|
585
|
+
|
|
602
586
|
# Strategy 1: Try concurrent close with timeout
|
|
603
587
|
try:
|
|
604
588
|
await self._concurrent_close(transport_items, close_results)
|
|
@@ -606,36 +590,30 @@ class StreamManager:
|
|
|
606
590
|
logger.debug("Concurrent close failed: %s, falling back to sequential close", e)
|
|
607
591
|
# Strategy 2: Fall back to sequential close
|
|
608
592
|
await self._sequential_close(transport_items, close_results)
|
|
609
|
-
|
|
593
|
+
|
|
610
594
|
# Always clean up state
|
|
611
595
|
self._cleanup_state()
|
|
612
|
-
|
|
596
|
+
|
|
613
597
|
# Log summary
|
|
614
598
|
if close_results:
|
|
615
599
|
successful_closes = sum(1 for _, success, _ in close_results if success)
|
|
616
600
|
logger.debug("Transport cleanup: %d/%d closed successfully", successful_closes, len(close_results))
|
|
617
601
|
|
|
618
|
-
async def _concurrent_close(self, transport_items:
|
|
602
|
+
async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
|
|
619
603
|
"""Try to close all transports concurrently."""
|
|
620
604
|
close_tasks = []
|
|
621
605
|
for name, transport in transport_items:
|
|
622
|
-
task = asyncio.create_task(
|
|
623
|
-
self._close_single_transport(name, transport),
|
|
624
|
-
name=f"close_{name}"
|
|
625
|
-
)
|
|
606
|
+
task = asyncio.create_task(self._close_single_transport(name, transport), name=f"close_{name}")
|
|
626
607
|
close_tasks.append((name, task))
|
|
627
|
-
|
|
608
|
+
|
|
628
609
|
# Wait for all tasks with a reasonable timeout
|
|
629
610
|
if close_tasks:
|
|
630
611
|
try:
|
|
631
612
|
results = await asyncio.wait_for(
|
|
632
|
-
asyncio.gather(
|
|
633
|
-
|
|
634
|
-
return_exceptions=True
|
|
635
|
-
),
|
|
636
|
-
timeout=self._shutdown_timeout
|
|
613
|
+
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
|
|
614
|
+
timeout=self._shutdown_timeout,
|
|
637
615
|
)
|
|
638
|
-
|
|
616
|
+
|
|
639
617
|
# Process results
|
|
640
618
|
for i, (name, _) in enumerate(close_tasks):
|
|
641
619
|
result = results[i] if i < len(results) else None
|
|
@@ -645,34 +623,31 @@ class StreamManager:
|
|
|
645
623
|
else:
|
|
646
624
|
logger.debug("Transport %s closed successfully", name)
|
|
647
625
|
close_results.append((name, True, None))
|
|
648
|
-
|
|
649
|
-
except
|
|
626
|
+
|
|
627
|
+
except TimeoutError:
|
|
650
628
|
# Cancel any remaining tasks
|
|
651
629
|
for name, task in close_tasks:
|
|
652
630
|
if not task.done():
|
|
653
631
|
task.cancel()
|
|
654
632
|
close_results.append((name, False, "timeout"))
|
|
655
|
-
|
|
633
|
+
|
|
656
634
|
# Brief wait for cancellations to complete
|
|
657
|
-
|
|
635
|
+
with contextlib.suppress(TimeoutError):
|
|
658
636
|
await asyncio.wait_for(
|
|
659
|
-
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
|
|
660
|
-
timeout=0.5
|
|
637
|
+
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True), timeout=0.5
|
|
661
638
|
)
|
|
662
|
-
except asyncio.TimeoutError:
|
|
663
|
-
pass # Some tasks may not cancel cleanly
|
|
664
639
|
|
|
665
|
-
async def _sequential_close(self, transport_items:
|
|
640
|
+
async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
|
|
666
641
|
"""Close transports one by one as fallback."""
|
|
667
642
|
for name, transport in transport_items:
|
|
668
643
|
try:
|
|
669
644
|
await asyncio.wait_for(
|
|
670
645
|
self._close_single_transport(name, transport),
|
|
671
|
-
timeout=0.5 # Short timeout per transport
|
|
646
|
+
timeout=0.5, # Short timeout per transport
|
|
672
647
|
)
|
|
673
648
|
logger.debug("Closed transport: %s", name)
|
|
674
649
|
close_results.append((name, True, None))
|
|
675
|
-
except
|
|
650
|
+
except TimeoutError:
|
|
676
651
|
logger.debug("Transport %s close timed out (normal during shutdown)", name)
|
|
677
652
|
close_results.append((name, False, "timeout"))
|
|
678
653
|
except asyncio.CancelledError:
|
|
@@ -685,7 +660,7 @@ class StreamManager:
|
|
|
685
660
|
async def _close_single_transport(self, name: str, transport: MCPBaseTransport) -> None:
|
|
686
661
|
"""Close a single transport with error handling."""
|
|
687
662
|
try:
|
|
688
|
-
if hasattr(transport,
|
|
663
|
+
if hasattr(transport, "close") and callable(transport.close):
|
|
689
664
|
await transport.close()
|
|
690
665
|
else:
|
|
691
666
|
logger.debug("Transport %s has no close method", name)
|
|
@@ -716,12 +691,12 @@ class StreamManager:
|
|
|
716
691
|
# ------------------------------------------------------------------ #
|
|
717
692
|
# backwards-compat: streams helper #
|
|
718
693
|
# ------------------------------------------------------------------ #
|
|
719
|
-
def get_streams(self) ->
|
|
694
|
+
def get_streams(self) -> list[tuple[Any, Any]]:
|
|
720
695
|
"""Return a list of (read_stream, write_stream) tuples for all transports."""
|
|
721
696
|
if self._closed:
|
|
722
697
|
return []
|
|
723
|
-
|
|
724
|
-
pairs:
|
|
698
|
+
|
|
699
|
+
pairs: list[tuple[Any, Any]] = []
|
|
725
700
|
|
|
726
701
|
for tr in self.transports.values():
|
|
727
702
|
if hasattr(tr, "get_streams") and callable(tr.get_streams):
|
|
@@ -736,7 +711,7 @@ class StreamManager:
|
|
|
736
711
|
return pairs
|
|
737
712
|
|
|
738
713
|
@property
|
|
739
|
-
def streams(self) ->
|
|
714
|
+
def streams(self) -> list[tuple[Any, Any]]:
|
|
740
715
|
"""Convenience alias for get_streams()."""
|
|
741
716
|
return self.get_streams()
|
|
742
717
|
|
|
@@ -751,34 +726,23 @@ class StreamManager:
|
|
|
751
726
|
"""Get the number of active transports."""
|
|
752
727
|
return len(self.transports)
|
|
753
728
|
|
|
754
|
-
async def health_check(self) ->
|
|
729
|
+
async def health_check(self) -> dict[str, Any]:
|
|
755
730
|
"""Perform a health check on all transports."""
|
|
756
731
|
if self._closed:
|
|
757
732
|
return {"status": "closed", "transports": {}}
|
|
758
|
-
|
|
759
|
-
health_info = {
|
|
760
|
-
|
|
761
|
-
"transport_count": len(self.transports),
|
|
762
|
-
"transports": {}
|
|
763
|
-
}
|
|
764
|
-
|
|
733
|
+
|
|
734
|
+
health_info = {"status": "active", "transport_count": len(self.transports), "transports": {}}
|
|
735
|
+
|
|
765
736
|
for name, transport in self.transports.items():
|
|
766
737
|
try:
|
|
767
738
|
ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=5.0)
|
|
768
739
|
health_info["transports"][name] = {
|
|
769
740
|
"status": "healthy" if ping_ok else "unhealthy",
|
|
770
|
-
"ping_success": ping_ok
|
|
771
|
-
}
|
|
772
|
-
except asyncio.TimeoutError:
|
|
773
|
-
health_info["transports"][name] = {
|
|
774
|
-
"status": "timeout",
|
|
775
|
-
"ping_success": False
|
|
741
|
+
"ping_success": ping_ok,
|
|
776
742
|
}
|
|
743
|
+
except TimeoutError:
|
|
744
|
+
health_info["transports"][name] = {"status": "timeout", "ping_success": False}
|
|
777
745
|
except Exception as e:
|
|
778
|
-
health_info["transports"][name] = {
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
"error": str(e)
|
|
782
|
-
}
|
|
783
|
-
|
|
784
|
-
return health_info
|
|
746
|
+
health_info["transports"][name] = {"status": "error", "ping_success": False, "error": str(e)}
|
|
747
|
+
|
|
748
|
+
return health_info
|