fast-agent-mcp 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of fast-agent-mcp has been flagged as a potentially problematic release.

@@ -0,0 +1,309 @@
+ from __future__ import annotations
+
+ import logging
+ from contextlib import asynccontextmanager
+ from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable
+
+ import anyio
+ import httpx
+ from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
+ from mcp.client.streamable_http import (
+     RequestContext,
+     RequestId,
+     StreamableHTTPTransport,
+     StreamWriter,
+ )
+ from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
+ from mcp.shared.message import SessionMessage
+ from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
+
+ from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName
+
+ if TYPE_CHECKING:
+     from datetime import timedelta
+
+     from anyio.abc import ObjectReceiveStream, ObjectSendStream
+
+ logger = logging.getLogger(__name__)
+
+ ChannelHook = Callable[[ChannelEvent], None]
+
+
+ class ChannelTrackingStreamableHTTPTransport(StreamableHTTPTransport):
+     """Streamable HTTP transport that emits channel events before dispatching."""
+
+     def __init__(
+         self,
+         url: str,
+         *,
+         headers: dict[str, str] | None = None,
+         timeout: float | timedelta = 30,
+         sse_read_timeout: float | timedelta = 60 * 5,
+         auth: httpx.Auth | None = None,
+         channel_hook: ChannelHook | None = None,
+     ) -> None:
+         super().__init__(
+             url,
+             headers=headers,
+             timeout=timeout,
+             sse_read_timeout=sse_read_timeout,
+             auth=auth,
+         )
+         self._channel_hook = channel_hook
+
+     def _emit_channel_event(
+         self,
+         channel: ChannelName,
+         event_type: str,
+         *,
+         message: JSONRPCMessage | None = None,
+         raw_event: str | None = None,
+         detail: str | None = None,
+         status_code: int | None = None,
+     ) -> None:
+         if self._channel_hook is None:
+             return
+         try:
+             self._channel_hook(
+                 ChannelEvent(
+                     channel=channel,
+                     event_type=event_type,  # type: ignore[arg-type]
+                     message=message,
+                     raw_event=raw_event,
+                     detail=detail,
+                     status_code=status_code,
+                 )
+             )
+         except Exception:  # pragma: no cover - hook errors must not break transport
+             logger.exception("Channel hook raised an exception")
+
+     async def _handle_json_response(  # type: ignore[override]
+         self,
+         response: httpx.Response,
+         read_stream_writer: StreamWriter,
+         is_initialization: bool = False,
+     ) -> None:
+         try:
+             content = await response.aread()
+             message = JSONRPCMessage.model_validate_json(content)
+
+             if is_initialization:
+                 self._maybe_extract_protocol_version_from_message(message)
+
+             self._emit_channel_event("post-json", "message", message=message)
+             await read_stream_writer.send(SessionMessage(message))
+         except Exception as exc:  # pragma: no cover - propagate to session
+             logger.exception("Error parsing JSON response")
+             await read_stream_writer.send(exc)
+
+     async def _handle_sse_event_with_channel(
+         self,
+         channel: ChannelName,
+         sse: ServerSentEvent,
+         read_stream_writer: StreamWriter,
+         original_request_id: RequestId | None = None,
+         resumption_callback: Callable[[str], Awaitable[None]] | None = None,
+         is_initialization: bool = False,
+     ) -> bool:
+         if sse.event != "message":
+             # Treat non-message events (e.g. ping) as keepalive notifications
+             self._emit_channel_event(channel, "keepalive", raw_event=sse.event or "keepalive")
+             return False
+
+         try:
+             message = JSONRPCMessage.model_validate_json(sse.data)
+             if is_initialization:
+                 self._maybe_extract_protocol_version_from_message(message)
+
+             if original_request_id is not None and isinstance(
+                 message.root, (JSONRPCResponse, JSONRPCError)
+             ):
+                 message.root.id = original_request_id
+
+             self._emit_channel_event(channel, "message", message=message)
+             await read_stream_writer.send(SessionMessage(message))
+
+             if sse.id and resumption_callback:
+                 await resumption_callback(sse.id)
+
+             return isinstance(message.root, (JSONRPCResponse, JSONRPCError))
+         except Exception as exc:  # pragma: no cover - propagate to session
+             logger.exception("Error parsing SSE message")
+             await read_stream_writer.send(exc)
+             return False
+
+     async def handle_get_stream(  # type: ignore[override]
+         self,
+         client: httpx.AsyncClient,
+         read_stream_writer: StreamWriter,
+     ) -> None:
+         if not self.session_id:
+             return
+
+         headers = self._prepare_request_headers(self.request_headers)
+         connected = False
+         try:
+             async with aconnect_sse(
+                 client,
+                 "GET",
+                 self.url,
+                 headers=headers,
+                 timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
+             ) as event_source:
+                 event_source.response.raise_for_status()
+                 self._emit_channel_event("get", "connect")
+                 connected = True
+                 async for sse in event_source.aiter_sse():
+                     await self._handle_sse_event_with_channel(
+                         "get",
+                         sse,
+                         read_stream_writer,
+                     )
+         except Exception as exc:  # pragma: no cover - non fatal stream errors
+             logger.debug("GET stream error (non-fatal): %s", exc)
+             status_code = None
+             detail = str(exc)
+             if isinstance(exc, httpx.HTTPStatusError):
+                 if exc.response is not None:
+                     status_code = exc.response.status_code
+                     reason = exc.response.reason_phrase or ""
+                     if not reason:
+                         try:
+                             reason = (exc.response.text or "").strip()
+                         except Exception:
+                             reason = ""
+                     detail = f"HTTP {status_code}: {reason or 'response'}"
+                 else:
+                     status_code = exc.response.status_code if hasattr(exc, "response") else None
+             self._emit_channel_event("get", "error", detail=detail, status_code=status_code)
+         finally:
+             if connected:
+                 self._emit_channel_event("get", "disconnect")
+
+     async def _handle_resumption_request(  # type: ignore[override]
+         self,
+         ctx: RequestContext,
+     ) -> None:
+         headers = self._prepare_request_headers(ctx.headers)
+         if ctx.metadata and ctx.metadata.resumption_token:
+             headers["last-event-id"] = ctx.metadata.resumption_token
+         else:  # pragma: no cover - defensive
+             raise ValueError("Resumption request requires a resumption token")
+
+         original_request_id: RequestId | None = None
+         if isinstance(ctx.session_message.message.root, JSONRPCRequest):
+             original_request_id = ctx.session_message.message.root.id
+
+         async with aconnect_sse(
+             ctx.client,
+             "GET",
+             self.url,
+             headers=headers,
+             timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
+         ) as event_source:
+             event_source.response.raise_for_status()
+             async for sse in event_source.aiter_sse():
+                 is_complete = await self._handle_sse_event_with_channel(
+                     "resumption",
+                     sse,
+                     ctx.read_stream_writer,
+                     original_request_id,
+                     ctx.metadata.on_resumption_token_update if ctx.metadata else None,
+                 )
+                 if is_complete:
+                     await event_source.response.aclose()
+                     break
+
+     async def _handle_sse_response(  # type: ignore[override]
+         self,
+         response: httpx.Response,
+         ctx: RequestContext,
+         is_initialization: bool = False,
+     ) -> None:
+         try:
+             event_source = EventSource(response)
+             async for sse in event_source.aiter_sse():
+                 is_complete = await self._handle_sse_event_with_channel(
+                     "post-sse",
+                     sse,
+                     ctx.read_stream_writer,
+                     resumption_callback=(
+                         ctx.metadata.on_resumption_token_update if ctx.metadata else None
+                     ),
+                     is_initialization=is_initialization,
+                 )
+                 if is_complete:
+                     await response.aclose()
+                     break
+         except Exception as exc:  # pragma: no cover - propagate to session
+             logger.exception("Error reading SSE stream")
+             await ctx.read_stream_writer.send(exc)
+
+
+ @asynccontextmanager
+ async def tracking_streamablehttp_client(
+     url: str,
+     headers: dict[str, str] | None = None,
+     *,
+     timeout: float | timedelta = 30,
+     sse_read_timeout: float | timedelta = 60 * 5,
+     terminate_on_close: bool = True,
+     httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
+     auth: httpx.Auth | None = None,
+     channel_hook: ChannelHook | None = None,
+ ) -> AsyncGenerator[
+     tuple[
+         ObjectReceiveStream[SessionMessage | Exception],
+         ObjectSendStream[SessionMessage],
+         Callable[[], str | None],
+     ],
+     None,
+ ]:
+     """Context manager mirroring streamablehttp_client with channel tracking."""
+
+     transport = ChannelTrackingStreamableHTTPTransport(
+         url,
+         headers=headers,
+         timeout=timeout,
+         sse_read_timeout=sse_read_timeout,
+         auth=auth,
+         channel_hook=channel_hook,
+     )
+
+     read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](
+         0
+     )
+     write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
+
+     async with anyio.create_task_group() as tg:
+         try:
+             async with httpx_client_factory(
+                 headers=transport.request_headers,
+                 timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
+                 auth=transport.auth,
+             ) as client:
+
+                 def start_get_stream() -> None:
+                     tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
+
+                 tg.start_soon(
+                     transport.post_writer,
+                     client,
+                     write_stream_reader,
+                     read_stream_writer,
+                     write_stream,
+                     start_get_stream,
+                     tg,
+                 )
+
+                 try:
+                     yield read_stream, write_stream, transport.get_session_id
+                 finally:
+                     if transport.session_id and terminate_on_close:
+                         await transport.terminate_session(client)
+                     tg.cancel_scope.cancel()
+         finally:
+             await read_stream_writer.aclose()
+             await read_stream.aclose()
+             await write_stream_reader.aclose()
+             await write_stream.aclose()
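
For orientation, below is a minimal usage sketch (not part of the diff) showing how the new tracking_streamablehttp_client context manager could be wired into an mcp.ClientSession with a hook that logs every channel event. The import path for the new module, the server URL, and the print-based hook are illustrative assumptions; only ChannelEvent and tracking_streamablehttp_client come from the code above. Because _emit_channel_event catches hook exceptions, a misbehaving hook cannot break the transport.

# Illustrative sketch only - not part of the released package. The second
# import path is a placeholder; the diff does not show where the new file
# lives inside fast_agent. The URL and the hook body are assumptions.
import anyio
from mcp import ClientSession

from fast_agent.mcp.transport_tracking import ChannelEvent
from fast_agent.mcp.tracking_streamable_http import tracking_streamablehttp_client  # hypothetical path


def log_channel_event(event: ChannelEvent) -> None:
    # The transport calls the hook synchronously before dispatching each message;
    # exceptions raised here are caught and logged by _emit_channel_event.
    print(f"[{event.channel}] {event.event_type}", event.detail or "")


async def main() -> None:
    async with tracking_streamablehttp_client(
        "http://localhost:8000/mcp",  # assumed server endpoint
        channel_hook=log_channel_event,
    ) as (read_stream, write_stream, get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            print("session id:", get_session_id())


anyio.run(main)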