chuk-tool-processor 0.6.4__py3-none-any.whl → 0.9.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/core/__init__.py +32 -1
- chuk_tool_processor/core/exceptions.py +225 -13
- chuk_tool_processor/core/processor.py +135 -104
- chuk_tool_processor/execution/strategies/__init__.py +6 -0
- chuk_tool_processor/execution/strategies/inprocess_strategy.py +142 -150
- chuk_tool_processor/execution/strategies/subprocess_strategy.py +202 -206
- chuk_tool_processor/execution/tool_executor.py +82 -84
- chuk_tool_processor/execution/wrappers/__init__.py +42 -0
- chuk_tool_processor/execution/wrappers/caching.py +150 -116
- chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
- chuk_tool_processor/execution/wrappers/rate_limiting.py +76 -43
- chuk_tool_processor/execution/wrappers/retry.py +116 -78
- chuk_tool_processor/logging/__init__.py +23 -17
- chuk_tool_processor/logging/context.py +40 -45
- chuk_tool_processor/logging/formatter.py +22 -21
- chuk_tool_processor/logging/helpers.py +28 -42
- chuk_tool_processor/logging/metrics.py +13 -15
- chuk_tool_processor/mcp/__init__.py +8 -12
- chuk_tool_processor/mcp/mcp_tool.py +158 -114
- chuk_tool_processor/mcp/register_mcp_tools.py +22 -22
- chuk_tool_processor/mcp/setup_mcp_http_streamable.py +57 -17
- chuk_tool_processor/mcp/setup_mcp_sse.py +57 -17
- chuk_tool_processor/mcp/setup_mcp_stdio.py +11 -11
- chuk_tool_processor/mcp/stream_manager.py +333 -276
- chuk_tool_processor/mcp/transport/__init__.py +22 -29
- chuk_tool_processor/mcp/transport/base_transport.py +180 -44
- chuk_tool_processor/mcp/transport/http_streamable_transport.py +505 -325
- chuk_tool_processor/mcp/transport/models.py +100 -0
- chuk_tool_processor/mcp/transport/sse_transport.py +607 -276
- chuk_tool_processor/mcp/transport/stdio_transport.py +597 -116
- chuk_tool_processor/models/__init__.py +21 -1
- chuk_tool_processor/models/execution_strategy.py +16 -21
- chuk_tool_processor/models/streaming_tool.py +28 -25
- chuk_tool_processor/models/tool_call.py +49 -31
- chuk_tool_processor/models/tool_export_mixin.py +22 -8
- chuk_tool_processor/models/tool_result.py +40 -77
- chuk_tool_processor/models/tool_spec.py +350 -0
- chuk_tool_processor/models/validated_tool.py +36 -18
- chuk_tool_processor/observability/__init__.py +30 -0
- chuk_tool_processor/observability/metrics.py +312 -0
- chuk_tool_processor/observability/setup.py +105 -0
- chuk_tool_processor/observability/tracing.py +345 -0
- chuk_tool_processor/plugins/__init__.py +1 -1
- chuk_tool_processor/plugins/discovery.py +11 -11
- chuk_tool_processor/plugins/parsers/__init__.py +1 -1
- chuk_tool_processor/plugins/parsers/base.py +1 -2
- chuk_tool_processor/plugins/parsers/function_call_tool.py +13 -8
- chuk_tool_processor/plugins/parsers/json_tool.py +4 -3
- chuk_tool_processor/plugins/parsers/openai_tool.py +12 -7
- chuk_tool_processor/plugins/parsers/xml_tool.py +4 -4
- chuk_tool_processor/registry/__init__.py +12 -12
- chuk_tool_processor/registry/auto_register.py +22 -30
- chuk_tool_processor/registry/decorators.py +127 -129
- chuk_tool_processor/registry/interface.py +26 -23
- chuk_tool_processor/registry/metadata.py +27 -22
- chuk_tool_processor/registry/provider.py +17 -18
- chuk_tool_processor/registry/providers/__init__.py +16 -19
- chuk_tool_processor/registry/providers/memory.py +18 -25
- chuk_tool_processor/registry/tool_export.py +42 -51
- chuk_tool_processor/utils/validation.py +15 -16
- chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
- chuk_tool_processor-0.9.7.dist-info/RECORD +67 -0
- chuk_tool_processor-0.6.4.dist-info/METADATA +0 -697
- chuk_tool_processor-0.6.4.dist-info/RECORD +0 -60
- {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
|
@@ -1,24 +1,28 @@
|
|
|
1
1
|
# chuk_tool_processor/mcp/stream_manager.py
|
|
2
2
|
"""
|
|
3
|
-
StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling
|
|
3
|
+
StreamManager for CHUK Tool Processor - Enhanced with robust shutdown handling and headers support
|
|
4
4
|
"""
|
|
5
|
+
|
|
5
6
|
from __future__ import annotations
|
|
6
7
|
|
|
7
8
|
import asyncio
|
|
8
|
-
|
|
9
|
+
import contextlib
|
|
9
10
|
from contextlib import asynccontextmanager
|
|
11
|
+
from typing import Any
|
|
10
12
|
|
|
11
13
|
# --------------------------------------------------------------------------- #
|
|
12
14
|
# CHUK imports #
|
|
13
15
|
# --------------------------------------------------------------------------- #
|
|
14
|
-
from chuk_mcp.config import load_config
|
|
16
|
+
from chuk_mcp.config import load_config # type: ignore[import-untyped]
|
|
17
|
+
|
|
18
|
+
from chuk_tool_processor.logging import get_logger
|
|
15
19
|
from chuk_tool_processor.mcp.transport import (
|
|
20
|
+
HTTPStreamableTransport,
|
|
16
21
|
MCPBaseTransport,
|
|
17
|
-
StdioTransport,
|
|
18
22
|
SSETransport,
|
|
19
|
-
|
|
23
|
+
StdioTransport,
|
|
24
|
+
TimeoutConfig,
|
|
20
25
|
)
|
|
21
|
-
from chuk_tool_processor.logging import get_logger
|
|
22
26
|
|
|
23
27
|
logger = get_logger("chuk_tool_processor.mcp.stream_manager")
|
|
24
28
|
|
|
@@ -26,24 +30,24 @@ logger = get_logger("chuk_tool_processor.mcp.stream_manager")
|
|
|
26
30
|
class StreamManager:
|
|
27
31
|
"""
|
|
28
32
|
Manager for MCP server streams with support for multiple transport types.
|
|
29
|
-
|
|
30
|
-
Enhanced with robust shutdown handling
|
|
31
|
-
|
|
33
|
+
|
|
34
|
+
Enhanced with robust shutdown handling and proper headers support.
|
|
35
|
+
|
|
32
36
|
Updated to support the latest transports:
|
|
33
37
|
- STDIO (process-based)
|
|
34
|
-
- SSE (Server-Sent Events)
|
|
35
|
-
- HTTP Streamable (modern replacement for SSE, spec 2025-03-26)
|
|
38
|
+
- SSE (Server-Sent Events) with headers support
|
|
39
|
+
- HTTP Streamable (modern replacement for SSE, spec 2025-03-26) with graceful headers handling
|
|
36
40
|
"""
|
|
37
41
|
|
|
38
|
-
def __init__(self) -> None:
|
|
39
|
-
self.transports:
|
|
40
|
-
self.server_info:
|
|
41
|
-
self.tool_to_server_map:
|
|
42
|
-
self.server_names:
|
|
43
|
-
self.all_tools:
|
|
42
|
+
def __init__(self, timeout_config: TimeoutConfig | None = None) -> None:
|
|
43
|
+
self.transports: dict[str, MCPBaseTransport] = {}
|
|
44
|
+
self.server_info: list[dict[str, Any]] = []
|
|
45
|
+
self.tool_to_server_map: dict[str, str] = {}
|
|
46
|
+
self.server_names: dict[int, str] = {}
|
|
47
|
+
self.all_tools: list[dict[str, Any]] = []
|
|
44
48
|
self._lock = asyncio.Lock()
|
|
45
49
|
self._closed = False # Track if we've been closed
|
|
46
|
-
self.
|
|
50
|
+
self.timeout_config = timeout_config or TimeoutConfig()
|
|
47
51
|
|
|
48
52
|
# ------------------------------------------------------------------ #
|
|
49
53
|
# factory helpers with enhanced error handling #
|
|
@@ -52,81 +56,67 @@ class StreamManager:
|
|
|
52
56
|
async def create(
|
|
53
57
|
cls,
|
|
54
58
|
config_file: str,
|
|
55
|
-
servers:
|
|
56
|
-
server_names:
|
|
59
|
+
servers: list[str],
|
|
60
|
+
server_names: dict[int, str] | None = None,
|
|
57
61
|
transport_type: str = "stdio",
|
|
58
62
|
default_timeout: float = 30.0,
|
|
59
63
|
initialization_timeout: float = 60.0, # NEW: Timeout for entire initialization
|
|
60
|
-
) ->
|
|
64
|
+
) -> StreamManager:
|
|
61
65
|
"""Create StreamManager with timeout protection."""
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
timeout=initialization_timeout
|
|
73
|
-
)
|
|
74
|
-
return inst
|
|
75
|
-
except asyncio.TimeoutError:
|
|
76
|
-
logger.error(f"StreamManager initialization timed out after {initialization_timeout}s")
|
|
77
|
-
raise RuntimeError(f"StreamManager initialization timed out after {initialization_timeout}s")
|
|
66
|
+
inst = cls()
|
|
67
|
+
await inst.initialize(
|
|
68
|
+
config_file,
|
|
69
|
+
servers,
|
|
70
|
+
server_names,
|
|
71
|
+
transport_type,
|
|
72
|
+
default_timeout=default_timeout,
|
|
73
|
+
initialization_timeout=initialization_timeout,
|
|
74
|
+
)
|
|
75
|
+
return inst
|
|
78
76
|
|
|
79
77
|
@classmethod
|
|
80
78
|
async def create_with_sse(
|
|
81
79
|
cls,
|
|
82
|
-
servers:
|
|
83
|
-
server_names:
|
|
80
|
+
servers: list[dict[str, str]],
|
|
81
|
+
server_names: dict[int, str] | None = None,
|
|
84
82
|
connection_timeout: float = 10.0,
|
|
85
83
|
default_timeout: float = 30.0,
|
|
86
84
|
initialization_timeout: float = 60.0, # NEW
|
|
87
|
-
|
|
85
|
+
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
86
|
+
) -> StreamManager:
|
|
88
87
|
"""Create StreamManager with SSE transport and timeout protection."""
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
)
|
|
100
|
-
return inst
|
|
101
|
-
except asyncio.TimeoutError:
|
|
102
|
-
logger.error(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
|
|
103
|
-
raise RuntimeError(f"SSE StreamManager initialization timed out after {initialization_timeout}s")
|
|
88
|
+
inst = cls()
|
|
89
|
+
await inst.initialize_with_sse(
|
|
90
|
+
servers,
|
|
91
|
+
server_names,
|
|
92
|
+
connection_timeout=connection_timeout,
|
|
93
|
+
default_timeout=default_timeout,
|
|
94
|
+
initialization_timeout=initialization_timeout,
|
|
95
|
+
oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
|
|
96
|
+
)
|
|
97
|
+
return inst
|
|
104
98
|
|
|
105
99
|
@classmethod
|
|
106
100
|
async def create_with_http_streamable(
|
|
107
101
|
cls,
|
|
108
|
-
servers:
|
|
109
|
-
server_names:
|
|
102
|
+
servers: list[dict[str, str]],
|
|
103
|
+
server_names: dict[int, str] | None = None,
|
|
110
104
|
connection_timeout: float = 30.0,
|
|
111
105
|
default_timeout: float = 30.0,
|
|
112
106
|
initialization_timeout: float = 60.0, # NEW
|
|
113
|
-
|
|
107
|
+
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
108
|
+
) -> StreamManager:
|
|
114
109
|
"""Create StreamManager with HTTP Streamable transport and timeout protection."""
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
)
|
|
126
|
-
return inst
|
|
127
|
-
except asyncio.TimeoutError:
|
|
128
|
-
logger.error(f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s")
|
|
129
|
-
raise RuntimeError(f"HTTP Streamable StreamManager initialization timed out after {initialization_timeout}s")
|
|
110
|
+
inst = cls()
|
|
111
|
+
await inst.initialize_with_http_streamable(
|
|
112
|
+
servers,
|
|
113
|
+
server_names,
|
|
114
|
+
connection_timeout=connection_timeout,
|
|
115
|
+
default_timeout=default_timeout,
|
|
116
|
+
initialization_timeout=initialization_timeout,
|
|
117
|
+
oauth_refresh_callback=oauth_refresh_callback, # NEW: Pass OAuth callback
|
|
118
|
+
)
|
|
119
|
+
return inst
|
|
130
120
|
|
|
131
121
|
# ------------------------------------------------------------------ #
|
|
132
122
|
# NEW: Context manager support for automatic cleanup #
|
|
@@ -144,8 +134,8 @@ class StreamManager:
|
|
|
144
134
|
async def create_managed(
|
|
145
135
|
cls,
|
|
146
136
|
config_file: str,
|
|
147
|
-
servers:
|
|
148
|
-
server_names:
|
|
137
|
+
servers: list[str],
|
|
138
|
+
server_names: dict[int, str] | None = None,
|
|
149
139
|
transport_type: str = "stdio",
|
|
150
140
|
default_timeout: float = 30.0,
|
|
151
141
|
):
|
|
@@ -170,73 +160,113 @@ class StreamManager:
|
|
|
170
160
|
async def initialize(
|
|
171
161
|
self,
|
|
172
162
|
config_file: str,
|
|
173
|
-
servers:
|
|
174
|
-
server_names:
|
|
163
|
+
servers: list[str],
|
|
164
|
+
server_names: dict[int, str] | None = None,
|
|
175
165
|
transport_type: str = "stdio",
|
|
176
166
|
default_timeout: float = 30.0,
|
|
167
|
+
initialization_timeout: float = 60.0,
|
|
177
168
|
) -> None:
|
|
169
|
+
"""Initialize with graceful headers handling for all transport types."""
|
|
178
170
|
if self._closed:
|
|
179
171
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
180
|
-
|
|
172
|
+
|
|
181
173
|
async with self._lock:
|
|
182
174
|
self.server_names = server_names or {}
|
|
183
175
|
|
|
184
176
|
for idx, server_name in enumerate(servers):
|
|
185
177
|
try:
|
|
186
178
|
if transport_type == "stdio":
|
|
187
|
-
params = await load_config(config_file, server_name)
|
|
188
|
-
|
|
179
|
+
params, server_timeout = await load_config(config_file, server_name)
|
|
180
|
+
# Use per-server timeout if specified, otherwise use global default
|
|
181
|
+
effective_timeout = server_timeout if server_timeout is not None else default_timeout
|
|
182
|
+
logger.info(
|
|
183
|
+
f"Server '{server_name}' using timeout: {effective_timeout}s (per-server: {server_timeout}, default: {default_timeout})"
|
|
184
|
+
)
|
|
185
|
+
# Use initialization_timeout for connection_timeout since subprocess
|
|
186
|
+
# launch can take time (e.g., uvx downloading packages)
|
|
187
|
+
transport: MCPBaseTransport = StdioTransport(
|
|
188
|
+
params, connection_timeout=initialization_timeout, default_timeout=effective_timeout
|
|
189
|
+
)
|
|
189
190
|
elif transport_type == "sse":
|
|
190
|
-
logger.
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
191
|
+
logger.debug(
|
|
192
|
+
"Using SSE transport in initialize() - consider using initialize_with_sse() instead"
|
|
193
|
+
)
|
|
194
|
+
params, server_timeout = await load_config(config_file, server_name)
|
|
195
|
+
# Use per-server timeout if specified, otherwise use global default
|
|
196
|
+
effective_timeout = server_timeout if server_timeout is not None else default_timeout
|
|
197
|
+
|
|
198
|
+
if isinstance(params, dict) and "url" in params:
|
|
199
|
+
sse_url = params["url"]
|
|
200
|
+
api_key = params.get("api_key")
|
|
201
|
+
headers = params.get("headers", {})
|
|
196
202
|
else:
|
|
197
203
|
sse_url = "http://localhost:8000"
|
|
198
204
|
api_key = None
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
205
|
+
headers = {}
|
|
206
|
+
logger.debug("No URL configured for SSE transport, using default: %s", sse_url)
|
|
207
|
+
|
|
208
|
+
# Build SSE transport with optional headers
|
|
209
|
+
transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": effective_timeout}
|
|
210
|
+
if headers:
|
|
211
|
+
transport_params["headers"] = headers
|
|
212
|
+
|
|
213
|
+
transport = SSETransport(**transport_params)
|
|
214
|
+
|
|
206
215
|
elif transport_type == "http_streamable":
|
|
207
|
-
logger.
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
216
|
+
logger.debug(
|
|
217
|
+
"Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
|
|
218
|
+
)
|
|
219
|
+
params, server_timeout = await load_config(config_file, server_name)
|
|
220
|
+
# Use per-server timeout if specified, otherwise use global default
|
|
221
|
+
effective_timeout = server_timeout if server_timeout is not None else default_timeout
|
|
222
|
+
|
|
223
|
+
if isinstance(params, dict) and "url" in params:
|
|
224
|
+
http_url = params["url"]
|
|
225
|
+
api_key = params.get("api_key")
|
|
226
|
+
headers = params.get("headers", {})
|
|
227
|
+
session_id = params.get("session_id")
|
|
214
228
|
else:
|
|
215
229
|
http_url = "http://localhost:8000"
|
|
216
230
|
api_key = None
|
|
231
|
+
headers = {}
|
|
217
232
|
session_id = None
|
|
218
|
-
logger.
|
|
219
|
-
|
|
220
|
-
transport
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
233
|
+
logger.debug("No URL configured for HTTP Streamable transport, using default: %s", http_url)
|
|
234
|
+
|
|
235
|
+
# Build HTTP transport (headers not supported yet)
|
|
236
|
+
transport_params = {
|
|
237
|
+
"url": http_url,
|
|
238
|
+
"api_key": api_key,
|
|
239
|
+
"default_timeout": effective_timeout,
|
|
240
|
+
"session_id": session_id,
|
|
241
|
+
}
|
|
242
|
+
# Note: headers not added until HTTPStreamableTransport supports them
|
|
243
|
+
if headers:
|
|
244
|
+
logger.debug("Headers provided but not supported in HTTPStreamableTransport yet")
|
|
245
|
+
|
|
246
|
+
transport = HTTPStreamableTransport(**transport_params)
|
|
247
|
+
|
|
226
248
|
else:
|
|
227
249
|
logger.error("Unsupported transport type: %s", transport_type)
|
|
228
250
|
continue
|
|
229
251
|
|
|
230
252
|
# Initialize with timeout protection
|
|
231
|
-
|
|
232
|
-
|
|
253
|
+
try:
|
|
254
|
+
if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
|
|
255
|
+
logger.warning("Failed to init %s", server_name)
|
|
256
|
+
continue
|
|
257
|
+
except TimeoutError:
|
|
258
|
+
logger.error("Timeout initialising %s (timeout=%ss)", server_name, initialization_timeout)
|
|
233
259
|
continue
|
|
234
260
|
|
|
235
261
|
self.transports[server_name] = transport
|
|
236
262
|
|
|
237
|
-
# Ping and get tools with timeout protection
|
|
238
|
-
status =
|
|
239
|
-
|
|
263
|
+
# Ping and get tools with timeout protection (use longer timeouts for slow servers)
|
|
264
|
+
status = (
|
|
265
|
+
"Up"
|
|
266
|
+
if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
|
|
267
|
+
else "Down"
|
|
268
|
+
)
|
|
269
|
+
tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
|
|
240
270
|
|
|
241
271
|
for t in tools:
|
|
242
272
|
name = t.get("name")
|
|
@@ -252,13 +282,13 @@ class StreamManager:
|
|
|
252
282
|
"status": status,
|
|
253
283
|
}
|
|
254
284
|
)
|
|
255
|
-
logger.
|
|
256
|
-
except
|
|
285
|
+
logger.debug("Initialised %s - %d tool(s)", server_name, len(tools))
|
|
286
|
+
except TimeoutError:
|
|
257
287
|
logger.error("Timeout initialising %s", server_name)
|
|
258
288
|
except Exception as exc:
|
|
259
289
|
logger.error("Error initialising %s: %s", server_name, exc)
|
|
260
290
|
|
|
261
|
-
logger.
|
|
291
|
+
logger.debug(
|
|
262
292
|
"StreamManager ready - %d server(s), %d tool(s)",
|
|
263
293
|
len(self.transports),
|
|
264
294
|
len(self.all_tools),
|
|
@@ -266,14 +296,17 @@ class StreamManager:
|
|
|
266
296
|
|
|
267
297
|
async def initialize_with_sse(
|
|
268
298
|
self,
|
|
269
|
-
servers:
|
|
270
|
-
server_names:
|
|
299
|
+
servers: list[dict[str, str]],
|
|
300
|
+
server_names: dict[int, str] | None = None,
|
|
271
301
|
connection_timeout: float = 10.0,
|
|
272
302
|
default_timeout: float = 30.0,
|
|
303
|
+
initialization_timeout: float = 60.0,
|
|
304
|
+
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
273
305
|
) -> None:
|
|
306
|
+
"""Initialize with SSE transport with optional headers support."""
|
|
274
307
|
if self._closed:
|
|
275
308
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
276
|
-
|
|
309
|
+
|
|
277
310
|
async with self._lock:
|
|
278
311
|
self.server_names = server_names or {}
|
|
279
312
|
|
|
@@ -283,20 +316,43 @@ class StreamManager:
|
|
|
283
316
|
logger.error("Bad server config: %s", cfg)
|
|
284
317
|
continue
|
|
285
318
|
try:
|
|
286
|
-
transport
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
319
|
+
# Build SSE transport parameters with optional headers
|
|
320
|
+
transport_params = {
|
|
321
|
+
"url": url,
|
|
322
|
+
"api_key": cfg.get("api_key"),
|
|
323
|
+
"connection_timeout": connection_timeout,
|
|
324
|
+
"default_timeout": default_timeout,
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
# Add headers if provided
|
|
328
|
+
headers = cfg.get("headers", {})
|
|
329
|
+
if headers:
|
|
330
|
+
logger.debug("SSE %s: Using configured headers: %s", name, list(headers.keys()))
|
|
331
|
+
transport_params["headers"] = headers
|
|
332
|
+
|
|
333
|
+
# Add OAuth refresh callback if provided (NEW)
|
|
334
|
+
if oauth_refresh_callback:
|
|
335
|
+
transport_params["oauth_refresh_callback"] = oauth_refresh_callback
|
|
336
|
+
logger.debug("SSE %s: OAuth refresh callback configured", name)
|
|
337
|
+
|
|
338
|
+
transport = SSETransport(**transport_params)
|
|
339
|
+
|
|
340
|
+
try:
|
|
341
|
+
if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
|
|
342
|
+
logger.warning("Failed to init SSE %s", name)
|
|
343
|
+
continue
|
|
344
|
+
except TimeoutError:
|
|
345
|
+
logger.error("Timeout initialising SSE %s (timeout=%ss)", name, initialization_timeout)
|
|
295
346
|
continue
|
|
296
347
|
|
|
297
348
|
self.transports[name] = transport
|
|
298
|
-
|
|
299
|
-
|
|
349
|
+
# Use longer timeouts for slow servers (ping can take time after initialization)
|
|
350
|
+
status = (
|
|
351
|
+
"Up"
|
|
352
|
+
if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
|
|
353
|
+
else "Down"
|
|
354
|
+
)
|
|
355
|
+
tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
|
|
300
356
|
|
|
301
357
|
for t in tools:
|
|
302
358
|
tname = t.get("name")
|
|
@@ -304,16 +360,14 @@ class StreamManager:
|
|
|
304
360
|
self.tool_to_server_map[tname] = name
|
|
305
361
|
self.all_tools.extend(tools)
|
|
306
362
|
|
|
307
|
-
self.server_info.append(
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
logger.info("Initialised SSE %s - %d tool(s)", name, len(tools))
|
|
311
|
-
except asyncio.TimeoutError:
|
|
363
|
+
self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
|
|
364
|
+
logger.debug("Initialised SSE %s - %d tool(s)", name, len(tools))
|
|
365
|
+
except TimeoutError:
|
|
312
366
|
logger.error("Timeout initialising SSE %s", name)
|
|
313
367
|
except Exception as exc:
|
|
314
368
|
logger.error("Error initialising SSE %s: %s", name, exc)
|
|
315
369
|
|
|
316
|
-
logger.
|
|
370
|
+
logger.debug(
|
|
317
371
|
"StreamManager ready - %d SSE server(s), %d tool(s)",
|
|
318
372
|
len(self.transports),
|
|
319
373
|
len(self.all_tools),
|
|
@@ -321,15 +375,19 @@ class StreamManager:
|
|
|
321
375
|
|
|
322
376
|
async def initialize_with_http_streamable(
|
|
323
377
|
self,
|
|
324
|
-
servers:
|
|
325
|
-
server_names:
|
|
378
|
+
servers: list[dict[str, str]],
|
|
379
|
+
server_names: dict[int, str] | None = None,
|
|
326
380
|
connection_timeout: float = 30.0,
|
|
327
381
|
default_timeout: float = 30.0,
|
|
382
|
+
initialization_timeout: float = 60.0,
|
|
383
|
+
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
328
384
|
) -> None:
|
|
329
|
-
"""Initialize with HTTP Streamable transport
|
|
385
|
+
"""Initialize with HTTP Streamable transport with graceful headers handling."""
|
|
330
386
|
if self._closed:
|
|
331
387
|
raise RuntimeError("Cannot initialize a closed StreamManager")
|
|
332
|
-
|
|
388
|
+
|
|
389
|
+
logger.debug(f"initialize_with_http_streamable: initialization_timeout={initialization_timeout}")
|
|
390
|
+
|
|
333
391
|
async with self._lock:
|
|
334
392
|
self.server_names = server_names or {}
|
|
335
393
|
|
|
@@ -339,21 +397,48 @@ class StreamManager:
|
|
|
339
397
|
logger.error("Bad server config: %s", cfg)
|
|
340
398
|
continue
|
|
341
399
|
try:
|
|
342
|
-
transport
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
400
|
+
# Build HTTP Streamable transport parameters
|
|
401
|
+
transport_params = {
|
|
402
|
+
"url": url,
|
|
403
|
+
"api_key": cfg.get("api_key"),
|
|
404
|
+
"connection_timeout": connection_timeout,
|
|
405
|
+
"default_timeout": default_timeout,
|
|
406
|
+
"session_id": cfg.get("session_id"),
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
# Handle headers if provided
|
|
410
|
+
headers = cfg.get("headers", {})
|
|
411
|
+
if headers:
|
|
412
|
+
transport_params["headers"] = headers
|
|
413
|
+
logger.debug("HTTP Streamable %s: Custom headers configured: %s", name, list(headers.keys()))
|
|
414
|
+
|
|
415
|
+
# Add OAuth refresh callback if provided (NEW)
|
|
416
|
+
if oauth_refresh_callback:
|
|
417
|
+
transport_params["oauth_refresh_callback"] = oauth_refresh_callback
|
|
418
|
+
logger.debug("HTTP Streamable %s: OAuth refresh callback configured", name)
|
|
419
|
+
|
|
420
|
+
transport = HTTPStreamableTransport(**transport_params)
|
|
421
|
+
|
|
422
|
+
logger.debug(f"Calling transport.initialize() for {name} with timeout={initialization_timeout}s")
|
|
423
|
+
try:
|
|
424
|
+
if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
|
|
425
|
+
logger.warning("Failed to init HTTP Streamable %s", name)
|
|
426
|
+
continue
|
|
427
|
+
except TimeoutError:
|
|
428
|
+
logger.error(
|
|
429
|
+
"Timeout initialising HTTP Streamable %s (timeout=%ss)", name, initialization_timeout
|
|
430
|
+
)
|
|
352
431
|
continue
|
|
432
|
+
logger.debug(f"Successfully initialized {name}")
|
|
353
433
|
|
|
354
434
|
self.transports[name] = transport
|
|
355
|
-
|
|
356
|
-
|
|
435
|
+
# Use longer timeouts for slow servers (ping can take time after initialization)
|
|
436
|
+
status = (
|
|
437
|
+
"Up"
|
|
438
|
+
if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
|
|
439
|
+
else "Down"
|
|
440
|
+
)
|
|
441
|
+
tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
|
|
357
442
|
|
|
358
443
|
for t in tools:
|
|
359
444
|
tname = t.get("name")
|
|
@@ -361,16 +446,14 @@ class StreamManager:
|
|
|
361
446
|
self.tool_to_server_map[tname] = name
|
|
362
447
|
self.all_tools.extend(tools)
|
|
363
448
|
|
|
364
|
-
self.server_info.append(
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
logger.info("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
|
|
368
|
-
except asyncio.TimeoutError:
|
|
449
|
+
self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
|
|
450
|
+
logger.debug("Initialised HTTP Streamable %s - %d tool(s)", name, len(tools))
|
|
451
|
+
except TimeoutError:
|
|
369
452
|
logger.error("Timeout initialising HTTP Streamable %s", name)
|
|
370
453
|
except Exception as exc:
|
|
371
454
|
logger.error("Error initialising HTTP Streamable %s: %s", name, exc)
|
|
372
455
|
|
|
373
|
-
logger.
|
|
456
|
+
logger.debug(
|
|
374
457
|
"StreamManager ready - %d HTTP Streamable server(s), %d tool(s)",
|
|
375
458
|
len(self.transports),
|
|
376
459
|
len(self.all_tools),
|
|
@@ -379,66 +462,64 @@ class StreamManager:
|
|
|
379
462
|
# ------------------------------------------------------------------ #
|
|
380
463
|
# queries #
|
|
381
464
|
# ------------------------------------------------------------------ #
|
|
382
|
-
def get_all_tools(self) ->
|
|
465
|
+
def get_all_tools(self) -> list[dict[str, Any]]:
|
|
383
466
|
return self.all_tools
|
|
384
467
|
|
|
385
|
-
def get_server_for_tool(self, tool_name: str) ->
|
|
468
|
+
def get_server_for_tool(self, tool_name: str) -> str | None:
|
|
386
469
|
return self.tool_to_server_map.get(tool_name)
|
|
387
470
|
|
|
388
|
-
def get_server_info(self) ->
|
|
471
|
+
def get_server_info(self) -> list[dict[str, Any]]:
|
|
389
472
|
return self.server_info
|
|
390
|
-
|
|
391
|
-
async def list_tools(self, server_name: str) ->
|
|
473
|
+
|
|
474
|
+
async def list_tools(self, server_name: str) -> list[dict[str, Any]]:
|
|
392
475
|
"""List all tools available from a specific server."""
|
|
393
476
|
if self._closed:
|
|
394
477
|
logger.warning("Cannot list tools: StreamManager is closed")
|
|
395
478
|
return []
|
|
396
|
-
|
|
479
|
+
|
|
397
480
|
if server_name not in self.transports:
|
|
398
|
-
logger.error(
|
|
481
|
+
logger.error("Server '%s' not found in transports", server_name)
|
|
399
482
|
return []
|
|
400
|
-
|
|
483
|
+
|
|
401
484
|
transport = self.transports[server_name]
|
|
402
|
-
|
|
485
|
+
|
|
403
486
|
try:
|
|
404
|
-
tools = await asyncio.wait_for(transport.get_tools(), timeout=
|
|
405
|
-
logger.debug(
|
|
487
|
+
tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
|
|
488
|
+
logger.debug("Found %d tools for server %s", len(tools), server_name)
|
|
406
489
|
return tools
|
|
407
|
-
except
|
|
408
|
-
logger.error(
|
|
490
|
+
except TimeoutError:
|
|
491
|
+
logger.error("Timeout listing tools for server %s", server_name)
|
|
409
492
|
return []
|
|
410
493
|
except Exception as e:
|
|
411
|
-
logger.error(
|
|
494
|
+
logger.error("Error listing tools for server %s: %s", server_name, e)
|
|
412
495
|
return []
|
|
413
496
|
|
|
414
497
|
# ------------------------------------------------------------------ #
|
|
415
498
|
# EXTRA HELPERS - ping / resources / prompts #
|
|
416
499
|
# ------------------------------------------------------------------ #
|
|
417
|
-
async def ping_servers(self) ->
|
|
500
|
+
async def ping_servers(self) -> list[dict[str, Any]]:
|
|
418
501
|
if self._closed:
|
|
419
502
|
return []
|
|
420
|
-
|
|
503
|
+
|
|
421
504
|
async def _ping_one(name: str, tr: MCPBaseTransport):
|
|
422
505
|
try:
|
|
423
|
-
ok = await asyncio.wait_for(tr.send_ping(), timeout=
|
|
506
|
+
ok = await asyncio.wait_for(tr.send_ping(), timeout=self.timeout_config.quick)
|
|
424
507
|
except Exception:
|
|
425
508
|
ok = False
|
|
426
509
|
return {"server": name, "ok": ok}
|
|
427
510
|
|
|
428
511
|
return await asyncio.gather(*(_ping_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
|
|
429
512
|
|
|
430
|
-
async def list_resources(self) ->
|
|
513
|
+
async def list_resources(self) -> list[dict[str, Any]]:
|
|
431
514
|
if self._closed:
|
|
432
515
|
return []
|
|
433
|
-
|
|
434
|
-
out:
|
|
516
|
+
|
|
517
|
+
out: list[dict[str, Any]] = []
|
|
435
518
|
|
|
436
519
|
async def _one(name: str, tr: MCPBaseTransport):
|
|
437
520
|
try:
|
|
438
|
-
res = await asyncio.wait_for(tr.list_resources(), timeout=
|
|
439
|
-
resources = (
|
|
440
|
-
res.get("resources", []) if isinstance(res, dict) else res
|
|
441
|
-
)
|
|
521
|
+
res = await asyncio.wait_for(tr.list_resources(), timeout=self.timeout_config.operation)
|
|
522
|
+
resources = res.get("resources", []) if isinstance(res, dict) else res
|
|
442
523
|
for item in resources:
|
|
443
524
|
item = dict(item)
|
|
444
525
|
item["server"] = name
|
|
@@ -449,15 +530,15 @@ class StreamManager:
|
|
|
449
530
|
await asyncio.gather(*(_one(n, t) for n, t in self.transports.items()), return_exceptions=True)
|
|
450
531
|
return out
|
|
451
532
|
|
|
452
|
-
async def list_prompts(self) ->
|
|
533
|
+
async def list_prompts(self) -> list[dict[str, Any]]:
|
|
453
534
|
if self._closed:
|
|
454
535
|
return []
|
|
455
|
-
|
|
456
|
-
out:
|
|
536
|
+
|
|
537
|
+
out: list[dict[str, Any]] = []
|
|
457
538
|
|
|
458
539
|
async def _one(name: str, tr: MCPBaseTransport):
|
|
459
540
|
try:
|
|
460
|
-
res = await asyncio.wait_for(tr.list_prompts(), timeout=
|
|
541
|
+
res = await asyncio.wait_for(tr.list_prompts(), timeout=self.timeout_config.operation)
|
|
461
542
|
prompts = res.get("prompts", []) if isinstance(res, dict) else res
|
|
462
543
|
for item in prompts:
|
|
463
544
|
item = dict(item)
|
|
@@ -475,45 +556,40 @@ class StreamManager:
|
|
|
475
556
|
async def call_tool(
|
|
476
557
|
self,
|
|
477
558
|
tool_name: str,
|
|
478
|
-
arguments:
|
|
479
|
-
server_name:
|
|
480
|
-
timeout:
|
|
481
|
-
) ->
|
|
559
|
+
arguments: dict[str, Any],
|
|
560
|
+
server_name: str | None = None,
|
|
561
|
+
timeout: float | None = None,
|
|
562
|
+
) -> dict[str, Any]:
|
|
482
563
|
"""Call a tool on the appropriate server with timeout support."""
|
|
483
564
|
if self._closed:
|
|
484
565
|
return {
|
|
485
566
|
"isError": True,
|
|
486
567
|
"error": "StreamManager is closed",
|
|
487
568
|
}
|
|
488
|
-
|
|
569
|
+
|
|
489
570
|
server_name = server_name or self.get_server_for_tool(tool_name)
|
|
490
571
|
if not server_name or server_name not in self.transports:
|
|
491
572
|
return {
|
|
492
573
|
"isError": True,
|
|
493
574
|
"error": f"No server found for tool: {tool_name}",
|
|
494
575
|
}
|
|
495
|
-
|
|
576
|
+
|
|
496
577
|
transport = self.transports[server_name]
|
|
497
|
-
|
|
578
|
+
|
|
498
579
|
if timeout is not None:
|
|
499
580
|
logger.debug("Calling tool '%s' with %ss timeout", tool_name, timeout)
|
|
500
581
|
try:
|
|
501
|
-
if hasattr(transport,
|
|
582
|
+
if hasattr(transport, "call_tool"):
|
|
502
583
|
import inspect
|
|
584
|
+
|
|
503
585
|
sig = inspect.signature(transport.call_tool)
|
|
504
|
-
if
|
|
586
|
+
if "timeout" in sig.parameters:
|
|
505
587
|
return await transport.call_tool(tool_name, arguments, timeout=timeout)
|
|
506
588
|
else:
|
|
507
|
-
return await asyncio.wait_for(
|
|
508
|
-
transport.call_tool(tool_name, arguments),
|
|
509
|
-
timeout=timeout
|
|
510
|
-
)
|
|
589
|
+
return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
|
|
511
590
|
else:
|
|
512
|
-
return await asyncio.wait_for(
|
|
513
|
-
|
|
514
|
-
timeout=timeout
|
|
515
|
-
)
|
|
516
|
-
except asyncio.TimeoutError:
|
|
591
|
+
return await asyncio.wait_for(transport.call_tool(tool_name, arguments), timeout=timeout)
|
|
592
|
+
except TimeoutError:
|
|
517
593
|
logger.warning("Tool '%s' timed out after %ss", tool_name, timeout)
|
|
518
594
|
return {
|
|
519
595
|
"isError": True,
|
|
@@ -521,28 +597,28 @@ class StreamManager:
|
|
|
521
597
|
}
|
|
522
598
|
else:
|
|
523
599
|
return await transport.call_tool(tool_name, arguments)
|
|
524
|
-
|
|
600
|
+
|
|
525
601
|
# ------------------------------------------------------------------ #
|
|
526
602
|
# ENHANCED shutdown with robust error handling #
|
|
527
603
|
# ------------------------------------------------------------------ #
|
|
528
604
|
async def close(self) -> None:
|
|
529
605
|
"""
|
|
530
606
|
Close all transports safely with enhanced error handling.
|
|
531
|
-
|
|
607
|
+
|
|
532
608
|
ENHANCED: Uses asyncio.shield() to protect critical cleanup and
|
|
533
609
|
provides multiple fallback strategies for different failure modes.
|
|
534
610
|
"""
|
|
535
611
|
if self._closed:
|
|
536
612
|
logger.debug("StreamManager already closed")
|
|
537
613
|
return
|
|
538
|
-
|
|
614
|
+
|
|
539
615
|
if not self.transports:
|
|
540
616
|
logger.debug("No transports to close")
|
|
541
617
|
self._closed = True
|
|
542
618
|
return
|
|
543
|
-
|
|
544
|
-
logger.debug(
|
|
545
|
-
|
|
619
|
+
|
|
620
|
+
logger.debug("Closing %d transports...", len(self.transports))
|
|
621
|
+
|
|
546
622
|
try:
|
|
547
623
|
# Use shield to protect the cleanup operation from cancellation
|
|
548
624
|
await asyncio.shield(self._do_close_all_transports())
|
|
@@ -551,7 +627,7 @@ class StreamManager:
|
|
|
551
627
|
logger.debug("Close operation cancelled, performing synchronous cleanup")
|
|
552
628
|
self._sync_cleanup()
|
|
553
629
|
except Exception as e:
|
|
554
|
-
logger.debug(
|
|
630
|
+
logger.debug("Error during close: %s", e)
|
|
555
631
|
self._sync_cleanup()
|
|
556
632
|
finally:
|
|
557
633
|
self._closed = True
|
|
@@ -560,99 +636,91 @@ class StreamManager:
|
|
|
560
636
|
"""Protected cleanup implementation with multiple strategies."""
|
|
561
637
|
close_results = []
|
|
562
638
|
transport_items = list(self.transports.items())
|
|
563
|
-
|
|
639
|
+
|
|
564
640
|
# Strategy 1: Try concurrent close with timeout
|
|
565
641
|
try:
|
|
566
642
|
await self._concurrent_close(transport_items, close_results)
|
|
567
643
|
except Exception as e:
|
|
568
|
-
logger.debug(
|
|
644
|
+
logger.debug("Concurrent close failed: %s, falling back to sequential close", e)
|
|
569
645
|
# Strategy 2: Fall back to sequential close
|
|
570
646
|
await self._sequential_close(transport_items, close_results)
|
|
571
|
-
|
|
647
|
+
|
|
572
648
|
# Always clean up state
|
|
573
649
|
self._cleanup_state()
|
|
574
|
-
|
|
650
|
+
|
|
575
651
|
# Log summary
|
|
576
652
|
if close_results:
|
|
577
653
|
successful_closes = sum(1 for _, success, _ in close_results if success)
|
|
578
|
-
logger.debug(
|
|
654
|
+
logger.debug("Transport cleanup: %d/%d closed successfully", successful_closes, len(close_results))
|
|
579
655
|
|
|
580
|
-
async def _concurrent_close(self, transport_items:
|
|
656
|
+
async def _concurrent_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
|
|
581
657
|
"""Try to close all transports concurrently."""
|
|
582
658
|
close_tasks = []
|
|
583
659
|
for name, transport in transport_items:
|
|
584
|
-
task = asyncio.create_task(
|
|
585
|
-
self._close_single_transport(name, transport),
|
|
586
|
-
name=f"close_{name}"
|
|
587
|
-
)
|
|
660
|
+
task = asyncio.create_task(self._close_single_transport(name, transport), name=f"close_{name}")
|
|
588
661
|
close_tasks.append((name, task))
|
|
589
|
-
|
|
662
|
+
|
|
590
663
|
# Wait for all tasks with a reasonable timeout
|
|
591
664
|
if close_tasks:
|
|
592
665
|
try:
|
|
593
666
|
results = await asyncio.wait_for(
|
|
594
|
-
asyncio.gather(
|
|
595
|
-
|
|
596
|
-
return_exceptions=True
|
|
597
|
-
),
|
|
598
|
-
timeout=self._shutdown_timeout
|
|
667
|
+
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
|
|
668
|
+
timeout=self.timeout_config.shutdown,
|
|
599
669
|
)
|
|
600
|
-
|
|
670
|
+
|
|
601
671
|
# Process results
|
|
602
672
|
for i, (name, _) in enumerate(close_tasks):
|
|
603
673
|
result = results[i] if i < len(results) else None
|
|
604
674
|
if isinstance(result, Exception):
|
|
605
|
-
logger.debug(
|
|
675
|
+
logger.debug("Transport %s close failed: %s", name, result)
|
|
606
676
|
close_results.append((name, False, str(result)))
|
|
607
677
|
else:
|
|
608
|
-
logger.debug(
|
|
678
|
+
logger.debug("Transport %s closed successfully", name)
|
|
609
679
|
close_results.append((name, True, None))
|
|
610
|
-
|
|
611
|
-
except
|
|
680
|
+
|
|
681
|
+
except TimeoutError:
|
|
612
682
|
# Cancel any remaining tasks
|
|
613
683
|
for name, task in close_tasks:
|
|
614
684
|
if not task.done():
|
|
615
685
|
task.cancel()
|
|
616
686
|
close_results.append((name, False, "timeout"))
|
|
617
|
-
|
|
687
|
+
|
|
618
688
|
# Brief wait for cancellations to complete
|
|
619
|
-
|
|
689
|
+
with contextlib.suppress(TimeoutError):
|
|
620
690
|
await asyncio.wait_for(
|
|
621
691
|
asyncio.gather(*[task for _, task in close_tasks], return_exceptions=True),
|
|
622
|
-
timeout=
|
|
692
|
+
timeout=self.timeout_config.shutdown,
|
|
623
693
|
)
|
|
624
|
-
except asyncio.TimeoutError:
|
|
625
|
-
pass # Some tasks may not cancel cleanly
|
|
626
694
|
|
|
627
|
-
async def _sequential_close(self, transport_items:
|
|
695
|
+
async def _sequential_close(self, transport_items: list[tuple[str, MCPBaseTransport]], close_results: list) -> None:
|
|
628
696
|
"""Close transports one by one as fallback."""
|
|
629
697
|
for name, transport in transport_items:
|
|
630
698
|
try:
|
|
631
699
|
await asyncio.wait_for(
|
|
632
700
|
self._close_single_transport(name, transport),
|
|
633
|
-
timeout=
|
|
701
|
+
timeout=self.timeout_config.shutdown,
|
|
634
702
|
)
|
|
635
|
-
logger.debug(
|
|
703
|
+
logger.debug("Closed transport: %s", name)
|
|
636
704
|
close_results.append((name, True, None))
|
|
637
|
-
except
|
|
638
|
-
logger.debug(
|
|
705
|
+
except TimeoutError:
|
|
706
|
+
logger.debug("Transport %s close timed out (normal during shutdown)", name)
|
|
639
707
|
close_results.append((name, False, "timeout"))
|
|
640
708
|
except asyncio.CancelledError:
|
|
641
|
-
logger.debug(
|
|
709
|
+
logger.debug("Transport %s close cancelled during event loop shutdown", name)
|
|
642
710
|
close_results.append((name, False, "cancelled"))
|
|
643
711
|
except Exception as e:
|
|
644
|
-
logger.debug(
|
|
712
|
+
logger.debug("Error closing transport %s: %s", name, e)
|
|
645
713
|
close_results.append((name, False, str(e)))
|
|
646
714
|
|
|
647
715
|
async def _close_single_transport(self, name: str, transport: MCPBaseTransport) -> None:
|
|
648
716
|
"""Close a single transport with error handling."""
|
|
649
717
|
try:
|
|
650
|
-
if hasattr(transport,
|
|
718
|
+
if hasattr(transport, "close") and callable(transport.close):
|
|
651
719
|
await transport.close()
|
|
652
720
|
else:
|
|
653
|
-
logger.debug(
|
|
721
|
+
logger.debug("Transport %s has no close method", name)
|
|
654
722
|
except Exception as e:
|
|
655
|
-
logger.debug(
|
|
723
|
+
logger.debug("Error closing transport %s: %s", name, e)
|
|
656
724
|
raise
|
|
657
725
|
|
|
658
726
|
def _sync_cleanup(self) -> None:
|
|
@@ -660,9 +728,9 @@ class StreamManager:
|
|
|
660
728
|
try:
|
|
661
729
|
transport_count = len(self.transports)
|
|
662
730
|
self._cleanup_state()
|
|
663
|
-
logger.debug(
|
|
731
|
+
logger.debug("Synchronous cleanup completed for %d transports", transport_count)
|
|
664
732
|
except Exception as e:
|
|
665
|
-
logger.debug(
|
|
733
|
+
logger.debug("Error during synchronous cleanup: %s", e)
|
|
666
734
|
|
|
667
735
|
def _cleanup_state(self) -> None:
|
|
668
736
|
"""Clean up internal state synchronously."""
|
|
@@ -673,17 +741,17 @@ class StreamManager:
|
|
|
673
741
|
self.all_tools.clear()
|
|
674
742
|
self.server_names.clear()
|
|
675
743
|
except Exception as e:
|
|
676
|
-
logger.debug(
|
|
744
|
+
logger.debug("Error during state cleanup: %s", e)
|
|
677
745
|
|
|
678
746
|
# ------------------------------------------------------------------ #
|
|
679
747
|
# backwards-compat: streams helper #
|
|
680
748
|
# ------------------------------------------------------------------ #
|
|
681
|
-
def get_streams(self) ->
|
|
749
|
+
def get_streams(self) -> list[tuple[Any, Any]]:
|
|
682
750
|
"""Return a list of (read_stream, write_stream) tuples for all transports."""
|
|
683
751
|
if self._closed:
|
|
684
752
|
return []
|
|
685
|
-
|
|
686
|
-
pairs:
|
|
753
|
+
|
|
754
|
+
pairs: list[tuple[Any, Any]] = []
|
|
687
755
|
|
|
688
756
|
for tr in self.transports.values():
|
|
689
757
|
if hasattr(tr, "get_streams") and callable(tr.get_streams):
|
|
@@ -698,7 +766,7 @@ class StreamManager:
|
|
|
698
766
|
return pairs
|
|
699
767
|
|
|
700
768
|
@property
|
|
701
|
-
def streams(self) ->
|
|
769
|
+
def streams(self) -> list[tuple[Any, Any]]:
|
|
702
770
|
"""Convenience alias for get_streams()."""
|
|
703
771
|
return self.get_streams()
|
|
704
772
|
|
|
@@ -713,34 +781,23 @@ class StreamManager:
|
|
|
713
781
|
"""Get the number of active transports."""
|
|
714
782
|
return len(self.transports)
|
|
715
783
|
|
|
716
|
-
async def health_check(self) ->
|
|
784
|
+
async def health_check(self) -> dict[str, Any]:
|
|
717
785
|
"""Perform a health check on all transports."""
|
|
718
786
|
if self._closed:
|
|
719
787
|
return {"status": "closed", "transports": {}}
|
|
720
|
-
|
|
721
|
-
health_info = {
|
|
722
|
-
|
|
723
|
-
"transport_count": len(self.transports),
|
|
724
|
-
"transports": {}
|
|
725
|
-
}
|
|
726
|
-
|
|
788
|
+
|
|
789
|
+
health_info = {"status": "active", "transport_count": len(self.transports), "transports": {}}
|
|
790
|
+
|
|
727
791
|
for name, transport in self.transports.items():
|
|
728
792
|
try:
|
|
729
|
-
ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=
|
|
793
|
+
ping_ok = await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.quick)
|
|
730
794
|
health_info["transports"][name] = {
|
|
731
795
|
"status": "healthy" if ping_ok else "unhealthy",
|
|
732
|
-
"ping_success": ping_ok
|
|
733
|
-
}
|
|
734
|
-
except asyncio.TimeoutError:
|
|
735
|
-
health_info["transports"][name] = {
|
|
736
|
-
"status": "timeout",
|
|
737
|
-
"ping_success": False
|
|
796
|
+
"ping_success": ping_ok,
|
|
738
797
|
}
|
|
798
|
+
except TimeoutError:
|
|
799
|
+
health_info["transports"][name] = {"status": "timeout", "ping_success": False}
|
|
739
800
|
except Exception as e:
|
|
740
|
-
health_info["transports"][name] = {
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
"error": str(e)
|
|
744
|
-
}
|
|
745
|
-
|
|
746
|
-
return health_info
|
|
801
|
+
health_info["transports"][name] = {"status": "error", "ping_success": False, "error": str(e)}
|
|
802
|
+
|
|
803
|
+
return health_info
|