fast-agent-mcp 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of fast-agent-mcp might be problematic.
- fast_agent/agents/llm_agent.py +24 -0
- fast_agent/agents/mcp_agent.py +7 -1
- fast_agent/core/direct_factory.py +20 -8
- fast_agent/llm/provider/anthropic/llm_anthropic.py +107 -62
- fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +4 -3
- fast_agent/llm/provider/google/google_converter.py +8 -41
- fast_agent/llm/provider/openai/llm_openai.py +3 -3
- fast_agent/mcp/mcp_agent_client_session.py +45 -2
- fast_agent/mcp/mcp_aggregator.py +314 -33
- fast_agent/mcp/mcp_connection_manager.py +86 -10
- fast_agent/mcp/stdio_tracking_simple.py +59 -0
- fast_agent/mcp/streamable_http_tracking.py +309 -0
- fast_agent/mcp/transport_tracking.py +600 -0
- fast_agent/resources/examples/data-analysis/analysis.py +7 -3
- fast_agent/ui/console_display.py +22 -1
- fast_agent/ui/elicitation_style.py +7 -7
- fast_agent/ui/enhanced_prompt.py +21 -1
- fast_agent/ui/interactive_prompt.py +5 -0
- fast_agent/ui/mcp_display.py +708 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.10.dist-info}/METADATA +5 -5
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.10.dist-info}/RECORD +24 -20
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.10.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.10.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.10.dist-info}/licenses/LICENSE +0 -0
fast_agent/mcp/mcp_connection_manager.py

@@ -21,10 +21,9 @@ from mcp.client.sse import sse_client
 from mcp.client.stdio import (
     StdioServerParameters,
     get_default_environment,
-    stdio_client,
 )
-from mcp.client.streamable_http import GetSessionIdCallback
-from mcp.types import JSONRPCMessage, ServerCapabilities
+from mcp.client.streamable_http import GetSessionIdCallback
+from mcp.types import Implementation, JSONRPCMessage, ServerCapabilities
 
 from fast_agent.config import MCPServerSettings
 from fast_agent.context_dependent import ContextDependent
@@ -34,6 +33,9 @@ from fast_agent.event_progress import ProgressAction
 from fast_agent.mcp.logger_textio import get_stderr_handler
 from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
 from fast_agent.mcp.oauth_client import build_oauth_provider
+from fast_agent.mcp.stdio_tracking_simple import tracking_stdio_client
+from fast_agent.mcp.streamable_http_tracking import tracking_streamablehttp_client
+from fast_agent.mcp.transport_tracking import TransportChannelMetrics
 
 if TYPE_CHECKING:
     from fast_agent.context import Context
@@ -107,6 +109,14 @@ class ServerConnection:
 
         # Server instructions from initialization
         self.server_instructions: str | None = None
+        self.server_capabilities: ServerCapabilities | None = None
+        self.server_implementation: Implementation | None = None
+        self.client_capabilities: dict | None = None
+        self.server_instructions_available: bool = False
+        self.server_instructions_enabled: bool = server_config.include_instructions if server_config else True
+        self.session_id: str | None = None
+        self._get_session_id_cb: GetSessionIdCallback | None = None
+        self.transport_metrics: TransportChannelMetrics | None = None
 
     def is_healthy(self) -> bool:
         """Check if the server connection is healthy and ready to use."""
@@ -138,15 +148,32 @@ class ServerConnection:
         result = await self.session.initialize()
 
         self.server_capabilities = result.capabilities
+        # InitializeResult exposes server info via `serverInfo`; keep fallback for older fields
+        implementation = getattr(result, "serverInfo", None)
+        if implementation is None:
+            implementation = getattr(result, "implementation", None)
+        self.server_implementation = implementation
+
+        raw_instructions = getattr(result, "instructions", None)
+        self.server_instructions_available = bool(raw_instructions)
 
         # Store instructions if provided by the server and enabled in config
         if self.server_config.include_instructions:
-            self.server_instructions =
+            self.server_instructions = raw_instructions
             if self.server_instructions:
-                logger.debug(
+                logger.debug(
+                    f"{self.server_name}: Received server instructions",
+                    data={"instructions": self.server_instructions},
+                )
         else:
             self.server_instructions = None
-
+            if self.server_instructions_available:
+                logger.debug(
+                    f"{self.server_name}: Server instructions disabled by configuration",
+                    data={"instructions": raw_instructions},
+                )
+            else:
+                logger.debug(f"{self.server_name}: No server instructions provided")
 
         # If there's an init hook, run it
 
@@ -175,10 +202,15 @@ class ServerConnection:
         )
 
         session = self._client_session_factory(
-            read_stream,
+            read_stream,
+            send_stream,
+            read_timeout,
+            server_config=self.server_config,
+            transport_metrics=self.transport_metrics,
         )
 
         self.session = session
+        self.client_capabilities = getattr(session, "client_capabilities", None)
 
         return session
 
@@ -192,11 +224,30 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
     try:
         transport_context = server_conn._transport_context_factory()
 
-        async with transport_context as (read_stream, write_stream,
+        async with transport_context as (read_stream, write_stream, get_session_id_cb):
+            server_conn._get_session_id_cb = get_session_id_cb
+
+            if get_session_id_cb is not None:
+                try:
+                    server_conn.session_id = get_session_id_cb()
+                except Exception:
+                    logger.debug(f"{server_name}: Unable to retrieve session id from transport")
+            elif server_conn.server_config.transport == "stdio":
+                server_conn.session_id = "local"
+
             server_conn.create_session(read_stream, write_stream)
 
             async with server_conn.session:
                 await server_conn.initialize_session()
+
+                if get_session_id_cb is not None:
+                    try:
+                        server_conn.session_id = get_session_id_cb() or server_conn.session_id
+                    except Exception:
+                        logger.debug(f"{server_name}: Unable to refresh session id after init")
+                elif server_conn.server_config.transport == "stdio":
+                    server_conn.session_id = "local"
+
                 await server_conn.wait_for_shutdown_request()
 
     except HTTPStatusError as http_exc:
@@ -353,6 +404,8 @@ class MCPConnectionManager(ContextDependent):
 
         logger.debug(f"{server_name}: Found server configuration=", data=config.model_dump())
 
+        transport_metrics = TransportChannelMetrics() if config.transport in ("http", "stdio") else None
+
        def transport_context_factory():
            if config.transport == "stdio":
                if not config.command:
@@ -369,7 +422,11 @@ class MCPConnectionManager(ContextDependent):
                error_handler = get_stderr_handler(server_name)
                # Explicitly ensure we're using our custom logger for stderr
                logger.debug(f"{server_name}: Creating stdio client with custom error handler")
-
+
+                channel_hook = transport_metrics.record_event if transport_metrics else None
+                return _add_none_to_context(
+                    tracking_stdio_client(server_params, channel_hook=channel_hook, errlog=error_handler)
+                )
            elif config.transport == "sse":
                if not config.url:
                    raise ValueError(
@@ -401,7 +458,23 @@ class MCPConnectionManager(ContextDependent):
                if oauth_auth is not None:
                    headers.pop("Authorization", None)
                    headers.pop("X-HF-Authorization", None)
-
+                channel_hook = None
+                if transport_metrics is not None:
+                    def channel_hook(event):
+                        try:
+                            transport_metrics.record_event(event)
+                        except Exception:  # pragma: no cover - defensive guard
+                            logger.debug(
+                                "%s: transport metrics hook failed", server_name,
+                                exc_info=True,
+                            )
+
+                return tracking_streamablehttp_client(
+                    config.url,
+                    headers,
+                    auth=oauth_auth,
+                    channel_hook=channel_hook,
+                )
            else:
                raise ValueError(f"Unsupported transport: {config.transport}")
 
@@ -412,6 +485,9 @@ class MCPConnectionManager(ContextDependent):
            client_session_factory=client_session_factory,
        )
 
+        if transport_metrics is not None:
+            server_conn.transport_metrics = transport_metrics
+
        async with self._lock:
            # Check if already running
            if server_name in self.running_servers:
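The +600-line fast_agent/mcp/transport_tracking.py added in this release is not shown in this view, so here is a minimal, hypothetical sketch of the interface the hunks above rely on (the ChannelEvent fields and TransportChannelMetrics.record_event). Only the names actually referenced above are taken from the diff; the field types, Literal values, and counter logic are assumptions, not the package's implementation.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field
from typing import Literal

from mcp.types import JSONRPCMessage

# Channel names and event types observed in the hunks above; the real module may define more.
ChannelName = Literal["stdio", "get", "post-json", "post-sse", "resumption"]
ChannelEventType = Literal["connect", "disconnect", "message", "keepalive", "error"]


@dataclass
class ChannelEvent:
    """One transport-level event, as constructed by the tracking clients in this release."""

    channel: ChannelName
    event_type: ChannelEventType
    message: JSONRPCMessage | None = None
    raw_event: str | None = None
    detail: str | None = None
    status_code: int | None = None


@dataclass
class TransportChannelMetrics:
    """Hypothetical aggregator: record_event is the hook wired into the transport factories above."""

    counts: Counter = field(default_factory=Counter)

    def record_event(self, event: ChannelEvent) -> None:
        # Count events per (channel, event_type) pair.
        self.counts[(event.channel, event.event_type)] += 1
```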
fast_agent/mcp/stdio_tracking_simple.py (new file, +59 lines)

from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Callable

from mcp.client.stdio import StdioServerParameters, stdio_client

from fast_agent.mcp.transport_tracking import ChannelEvent

if TYPE_CHECKING:
    from anyio.abc import ObjectReceiveStream, ObjectSendStream
    from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)

ChannelHook = Callable[[ChannelEvent], None]


@asynccontextmanager
async def tracking_stdio_client(
    server_params: StdioServerParameters,
    *,
    channel_hook: ChannelHook | None = None,
    errlog: Callable[[str], None] | None = None,
) -> AsyncGenerator[
    tuple[ObjectReceiveStream[SessionMessage | Exception], ObjectSendStream[SessionMessage]], None
]:
    """Context manager for stdio client with basic connection tracking."""

    def emit_channel_event(event_type: str, detail: str | None = None) -> None:
        if channel_hook is None:
            return
        try:
            channel_hook(
                ChannelEvent(
                    channel="stdio",
                    event_type=event_type,  # type: ignore[arg-type]
                    detail=detail,
                )
            )
        except Exception:  # pragma: no cover - hook errors must not break transport
            logger.exception("Channel hook raised an exception")

    try:
        # Emit connection event
        emit_channel_event("connect")

        # Use the original stdio_client without stream interception
        async with stdio_client(server_params, errlog=errlog) as (read_stream, write_stream):
            yield read_stream, write_stream

    except Exception as exc:
        # Emit error event
        emit_channel_event("error", detail=str(exc))
        raise
    finally:
        # Emit disconnection event
        emit_channel_event("disconnect")
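For trying the new stdio wrapper in isolation, a minimal usage sketch (not part of the package) could look like the following; the server command and the logging hook are placeholder assumptions.

```python
import anyio
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters

from fast_agent.mcp.stdio_tracking_simple import tracking_stdio_client
from fast_agent.mcp.transport_tracking import ChannelEvent


def log_event(event: ChannelEvent) -> None:
    # Placeholder hook: just print each stdio channel event.
    print(f"stdio event: {event.event_type} {event.detail or ''}")


async def main() -> None:
    # Placeholder command; substitute your own MCP stdio server.
    params = StdioServerParameters(command="uv", args=["run", "my-mcp-server"])
    async with tracking_stdio_client(params, channel_hook=log_event) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()


anyio.run(main)
```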
fast_agent/mcp/streamable_http_tracking.py (new file, +309 lines)

from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable

import anyio
import httpx
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp.client.streamable_http import (
    RequestContext,
    RequestId,
    StreamableHTTPTransport,
    StreamWriter,
)
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName

if TYPE_CHECKING:
    from datetime import timedelta

    from anyio.abc import ObjectReceiveStream, ObjectSendStream

logger = logging.getLogger(__name__)

ChannelHook = Callable[[ChannelEvent], None]


class ChannelTrackingStreamableHTTPTransport(StreamableHTTPTransport):
    """Streamable HTTP transport that emits channel events before dispatching."""

    def __init__(
        self,
        url: str,
        *,
        headers: dict[str, str] | None = None,
        timeout: float | timedelta = 30,
        sse_read_timeout: float | timedelta = 60 * 5,
        auth: httpx.Auth | None = None,
        channel_hook: ChannelHook | None = None,
    ) -> None:
        super().__init__(
            url,
            headers=headers,
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
            auth=auth,
        )
        self._channel_hook = channel_hook

    def _emit_channel_event(
        self,
        channel: ChannelName,
        event_type: str,
        *,
        message: JSONRPCMessage | None = None,
        raw_event: str | None = None,
        detail: str | None = None,
        status_code: int | None = None,
    ) -> None:
        if self._channel_hook is None:
            return
        try:
            self._channel_hook(
                ChannelEvent(
                    channel=channel,
                    event_type=event_type,  # type: ignore[arg-type]
                    message=message,
                    raw_event=raw_event,
                    detail=detail,
                    status_code=status_code,
                )
            )
        except Exception:  # pragma: no cover - hook errors must not break transport
            logger.exception("Channel hook raised an exception")

    async def _handle_json_response(  # type: ignore[override]
        self,
        response: httpx.Response,
        read_stream_writer: StreamWriter,
        is_initialization: bool = False,
    ) -> None:
        try:
            content = await response.aread()
            message = JSONRPCMessage.model_validate_json(content)

            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            self._emit_channel_event("post-json", "message", message=message)
            await read_stream_writer.send(SessionMessage(message))
        except Exception as exc:  # pragma: no cover - propagate to session
            logger.exception("Error parsing JSON response")
            await read_stream_writer.send(exc)

    async def _handle_sse_event_with_channel(
        self,
        channel: ChannelName,
        sse: ServerSentEvent,
        read_stream_writer: StreamWriter,
        original_request_id: RequestId | None = None,
        resumption_callback: Callable[[str], Awaitable[None]] | None = None,
        is_initialization: bool = False,
    ) -> bool:
        if sse.event != "message":
            # Treat non-message events (e.g. ping) as keepalive notifications
            self._emit_channel_event(channel, "keepalive", raw_event=sse.event or "keepalive")
            return False

        try:
            message = JSONRPCMessage.model_validate_json(sse.data)
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            if original_request_id is not None and isinstance(
                message.root, (JSONRPCResponse, JSONRPCError)
            ):
                message.root.id = original_request_id

            self._emit_channel_event(channel, "message", message=message)
            await read_stream_writer.send(SessionMessage(message))

            if sse.id and resumption_callback:
                await resumption_callback(sse.id)

            return isinstance(message.root, (JSONRPCResponse, JSONRPCError))
        except Exception as exc:  # pragma: no cover - propagate to session
            logger.exception("Error parsing SSE message")
            await read_stream_writer.send(exc)
            return False

    async def handle_get_stream(  # type: ignore[override]
        self,
        client: httpx.AsyncClient,
        read_stream_writer: StreamWriter,
    ) -> None:
        if not self.session_id:
            return

        headers = self._prepare_request_headers(self.request_headers)
        connected = False
        try:
            async with aconnect_sse(
                client,
                "GET",
                self.url,
                headers=headers,
                timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
            ) as event_source:
                event_source.response.raise_for_status()
                self._emit_channel_event("get", "connect")
                connected = True
                async for sse in event_source.aiter_sse():
                    await self._handle_sse_event_with_channel(
                        "get",
                        sse,
                        read_stream_writer,
                    )
        except Exception as exc:  # pragma: no cover - non fatal stream errors
            logger.debug("GET stream error (non-fatal): %s", exc)
            status_code = None
            detail = str(exc)
            if isinstance(exc, httpx.HTTPStatusError):
                if exc.response is not None:
                    status_code = exc.response.status_code
                    reason = exc.response.reason_phrase or ""
                    if not reason:
                        try:
                            reason = (exc.response.text or "").strip()
                        except Exception:
                            reason = ""
                    detail = f"HTTP {status_code}: {reason or 'response'}"
            else:
                status_code = exc.response.status_code if hasattr(exc, "response") else None
            self._emit_channel_event("get", "error", detail=detail, status_code=status_code)
        finally:
            if connected:
                self._emit_channel_event("get", "disconnect")

    async def _handle_resumption_request(  # type: ignore[override]
        self,
        ctx: RequestContext,
    ) -> None:
        headers = self._prepare_request_headers(ctx.headers)
        if ctx.metadata and ctx.metadata.resumption_token:
            headers["last-event-id"] = ctx.metadata.resumption_token
        else:  # pragma: no cover - defensive
            raise ValueError("Resumption request requires a resumption token")

        original_request_id: RequestId | None = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):
            original_request_id = ctx.session_message.message.root.id

        async with aconnect_sse(
            ctx.client,
            "GET",
            self.url,
            headers=headers,
            timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
        ) as event_source:
            event_source.response.raise_for_status()
            async for sse in event_source.aiter_sse():
                is_complete = await self._handle_sse_event_with_channel(
                    "resumption",
                    sse,
                    ctx.read_stream_writer,
                    original_request_id,
                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                )
                if is_complete:
                    await event_source.response.aclose()
                    break

    async def _handle_sse_response(  # type: ignore[override]
        self,
        response: httpx.Response,
        ctx: RequestContext,
        is_initialization: bool = False,
    ) -> None:
        try:
            event_source = EventSource(response)
            async for sse in event_source.aiter_sse():
                is_complete = await self._handle_sse_event_with_channel(
                    "post-sse",
                    sse,
                    ctx.read_stream_writer,
                    resumption_callback=(
                        ctx.metadata.on_resumption_token_update if ctx.metadata else None
                    ),
                    is_initialization=is_initialization,
                )
                if is_complete:
                    await response.aclose()
                    break
        except Exception as exc:  # pragma: no cover - propagate to session
            logger.exception("Error reading SSE stream")
            await ctx.read_stream_writer.send(exc)


@asynccontextmanager
async def tracking_streamablehttp_client(
    url: str,
    headers: dict[str, str] | None = None,
    *,
    timeout: float | timedelta = 30,
    sse_read_timeout: float | timedelta = 60 * 5,
    terminate_on_close: bool = True,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
    channel_hook: ChannelHook | None = None,
) -> AsyncGenerator[
    tuple[
        ObjectReceiveStream[SessionMessage | Exception],
        ObjectSendStream[SessionMessage],
        Callable[[], str | None],
    ],
    None,
]:
    """Context manager mirroring streamablehttp_client with channel tracking."""

    transport = ChannelTrackingStreamableHTTPTransport(
        url,
        headers=headers,
        timeout=timeout,
        sse_read_timeout=sse_read_timeout,
        auth=auth,
        channel_hook=channel_hook,
    )

    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](
        0
    )
    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

    async with anyio.create_task_group() as tg:
        try:
            async with httpx_client_factory(
                headers=transport.request_headers,
                timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
                auth=transport.auth,
            ) as client:

                def start_get_stream() -> None:
                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

                tg.start_soon(
                    transport.post_writer,
                    client,
                    write_stream_reader,
                    read_stream_writer,
                    write_stream,
                    start_get_stream,
                    tg,
                )

                try:
                    yield read_stream, write_stream, transport.get_session_id
                finally:
                    if transport.session_id and terminate_on_close:
                        await transport.terminate_session(client)
                    tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await read_stream.aclose()
            await write_stream_reader.aclose()
            await write_stream.aclose()