fast-agent-mcp 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fast-agent-mcp might be problematic; the original page linked to further details, but that link is not preserved in this copy.

@@ -1,4 +1,7 @@
1
1
  from asyncio import Lock, gather
2
+ from collections import Counter
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime, timezone
2
5
  from typing import (
3
6
  TYPE_CHECKING,
4
7
  Any,
@@ -17,11 +20,12 @@ from mcp.types import (
17
20
  CallToolResult,
18
21
  ListToolsResult,
19
22
  Prompt,
23
+ ServerCapabilities,
20
24
  TextContent,
21
25
  Tool,
22
26
  )
23
27
  from opentelemetry import trace
24
- from pydantic import AnyUrl, BaseModel, ConfigDict
28
+ from pydantic import AnyUrl, BaseModel, ConfigDict, Field
25
29
 
26
30
  from fast_agent.context_dependent import ContextDependent
27
31
  from fast_agent.core.logging.logger import get_logger
@@ -30,6 +34,7 @@ from fast_agent.mcp.common import SEP, create_namespaced_name, is_namespaced_nam
30
34
  from fast_agent.mcp.gen_client import gen_client
31
35
  from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
32
36
  from fast_agent.mcp.mcp_connection_manager import MCPConnectionManager
37
+ from fast_agent.mcp.transport_tracking import TransportSnapshot
33
38
 
34
39
  if TYPE_CHECKING:
35
40
  from fast_agent.context import Context
@@ -52,6 +57,49 @@ class NamespacedTool(BaseModel):
52
57
  namespaced_tool_name: str
53
58
 
54
59
 
60
+ @dataclass
61
+ class ServerStats:
62
+ call_counts: Counter = field(default_factory=Counter)
63
+ last_call_at: datetime | None = None
64
+ last_error_at: datetime | None = None
65
+
66
+ def record(self, operation_type: str, success: bool) -> None:
67
+ self.call_counts[operation_type] += 1
68
+ now = datetime.now(timezone.utc)
69
+ self.last_call_at = now
70
+ if not success:
71
+ self.last_error_at = now
72
+
73
+
74
class ServerStatus(BaseModel):
    """Point-in-time status report for one configured MCP server.

    Built by ``MCPAggregator.collect_server_status``. Apart from
    ``server_name`` every field is optional and stays ``None`` when the
    information could not be obtained (no live connection, no config entry,
    or an error while collecting it).
    """

    # Permit field types pydantic has no built-in validator for
    # (presumably needed for TransportSnapshot — confirm).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    server_name: str

    # Server identity: from the initialize handshake, or from a configured
    # ``implementation`` override when the live value is unavailable.
    implementation_name: str | None = None
    implementation_version: str | None = None
    server_capabilities: ServerCapabilities | None = None

    # Client-side capabilities / identity as reported by the session.
    client_capabilities: Mapping[str, Any] | None = None
    client_info_name: str | None = None
    client_info_version: str | None = None

    # Connection state.
    transport: str | None = None
    is_connected: bool | None = None

    # Runtime activity; staleness_seconds is the time from last_call_at to
    # the moment the status was collected.
    last_call_at: datetime | None = None
    last_error_at: datetime | None = None
    staleness_seconds: float | None = None
    # operation type -> number of calls
    call_counts: Dict[str, int] = Field(default_factory=dict)
    error_message: str | None = None

    # Server instructions: offered by the server / enabled by config /
    # actually retained on the connection.
    instructions_available: bool | None = None
    instructions_enabled: bool | None = None
    instructions_included: bool | None = None

    # Client feature configuration (sampling_mode is "configured", "auto"
    # or "off" when set by collect_server_status).
    roots_configured: bool | None = None
    roots_count: int | None = None
    elicitation_mode: str | None = None
    sampling_mode: str | None = None
    spoofing_enabled: bool | None = None

    session_id: str | None = None
    transport_channels: TransportSnapshot | None = None
101
+
102
+
55
103
  class MCPAggregator(ContextDependent):
56
104
  """
57
105
  Aggregates multiple MCP servers. When a developer calls, e.g. call_tool(...),
@@ -140,6 +188,10 @@ class MCPAggregator(ContextDependent):
140
188
  # Lock for refreshing tools from a server
141
189
  self._refresh_lock = Lock()
142
190
 
191
+ # Track runtime stats per server
192
+ self._server_stats: Dict[str, ServerStats] = {}
193
+ self._stats_lock = Lock()
194
+
143
195
  def _create_progress_callback(self, server_name: str, tool_name: str) -> "ProgressFnT":
144
196
  """Create a progress callback function for tool execution."""
145
197
 
@@ -461,6 +513,50 @@ class MCPAggregator(ContextDependent):
461
513
  for server_name in self.server_names:
462
514
  await self._refresh_server_tools(server_name)
463
515
 
516
async def _record_server_call(
    self, server_name: str, operation_type: str, success: bool
) -> None:
    """Update the runtime stats for ``server_name`` and emit stdio activity.

    Args:
        server_name: Server the operation was executed against.
        operation_type: Label the call is counted under.
        success: Whether the operation succeeded.
    """
    async with self._stats_lock:
        server_stats = self._server_stats.get(server_name)
        if server_stats is None:
            server_stats = self._server_stats[server_name] = ServerStats()
        server_stats.record(operation_type, success)

    # Done outside the lock: stdio servers additionally get a synthetic
    # transport event so their activity timeline reflects MCP operations.
    await self._notify_stdio_transport_activity(server_name, operation_type, success)
525
+
526
async def _notify_stdio_transport_activity(
    self, server_name: str, operation_type: str, success: bool
) -> None:
    """Record a synthetic ``message`` event on a stdio server's transport metrics.

    The stdio transport is not intercepted message-by-message, so each MCP
    operation recorded by the aggregator is mirrored here to give the stdio
    channel an activity timeline. Nothing happens when there is no persistent
    connection manager, the server is not running, its transport is not
    ``"stdio"``, or the connection has no ``transport_metrics``.

    Any failure is logged at debug level and swallowed: transport tracking
    must never break the operation being recorded.

    Args:
        server_name: Server the operation was executed against.
        operation_type: Operation label, included in the event detail.
        success: Outcome of the operation, included in the event detail.
    """
    manager = self._persistent_connection_manager
    if not manager:
        return

    try:
        server_conn = manager.running_servers.get(server_name)
        if not server_conn:
            return

        server_config = getattr(server_conn, "server_config", None)
        if not server_config or server_config.transport != "stdio":
            return

        transport_metrics = getattr(server_conn, "transport_metrics", None)
        if not transport_metrics:
            return

        # Local import kept from the original. NOTE(review): this module already
        # imports TransportSnapshot from the same module at top level, so the
        # local import is likely not needed to avoid a cycle — confirm and hoist.
        from fast_agent.mcp.transport_tracking import ChannelEvent

        outcome = "success" if success else "error"
        transport_metrics.record_event(
            ChannelEvent(
                channel="stdio",
                event_type="message",
                detail=f"{operation_type} ({outcome})",
            )
        )
    except Exception as exc:
        # Match this module's logger convention (f-string message + data=)
        # rather than printf-style positional arguments.
        logger.debug(
            f"Failed to notify stdio transport activity for {server_name}",
            data={"error": str(exc)},
        )
559
+
464
560
  async def get_server_instructions(self) -> Dict[str, tuple[str, List[str]]]:
465
561
  """
466
562
  Get instructions from all connected servers along with their tool names.
@@ -492,6 +588,174 @@ class MCPAggregator(ContextDependent):
492
588
 
493
589
  return instructions
494
590
 
591
async def collect_server_status(self) -> Dict[str, ServerStatus]:
    """Return aggregated status information for each configured server.

    Loads servers first if the aggregator is not initialized. For every name
    in ``self.server_names`` the status combines:

    * runtime stats recorded by ``_record_server_call`` (call counts, last
      call / last error timestamps, staleness in seconds);
    * live connection details when connection persistence is enabled
      (implementation, capabilities, health, session id, transport metrics
      snapshot). These are fetched through ``manager.get_server`` with a
      session factory. NOTE(review): that presumably starts the server if it
      is not already running — confirm this is intended for status queries;
    * configuration details from the connection's ``server_config``, falling
      back to the context's server registry.

    Failures while gathering live details are logged at debug level and leave
    the affected fields as ``None``.

    Returns:
        Mapping of server name to its ``ServerStatus``.
    """
    if not self.initialized:
        await self.load_servers()

    now = datetime.now(timezone.utc)
    status_map: Dict[str, ServerStatus] = {}

    for server_name in self.server_names:
        # Runtime stats (absent if no call was ever recorded for this server).
        stats = self._server_stats.get(server_name)
        last_call = stats.last_call_at if stats else None
        last_error = stats.last_error_at if stats else None
        staleness = (now - last_call).total_seconds() if last_call else None
        call_counts = dict(stats.call_counts) if stats else {}

        implementation_name = None
        implementation_version = None
        capabilities: ServerCapabilities | None = None
        client_capabilities: Mapping[str, Any] | None = None
        client_info_name = None
        client_info_version = None
        is_connected = None
        error_message = None
        instructions_available = None
        instructions_enabled = None
        instructions_included = None
        roots_configured = None
        roots_count = None
        elicitation_mode = None
        spoofing_enabled = None
        server_cfg = None
        session_id = None
        server_conn = None
        transport: str | None = None
        transport_snapshot: TransportSnapshot | None = None

        manager = getattr(self, "_persistent_connection_manager", None)
        if self.connection_persistence and manager is not None:
            try:
                server_conn = await manager.get_server(
                    server_name,
                    client_session_factory=self._create_session_factory(server_name),
                )
                implementation = getattr(server_conn, "server_implementation", None)
                if implementation:
                    implementation_name = getattr(implementation, "name", None)
                    implementation_version = getattr(implementation, "version", None)
                capabilities = getattr(server_conn, "server_capabilities", None)
                client_capabilities = getattr(server_conn, "client_capabilities", None)
                session = server_conn.session
                client_info = getattr(session, "client_info", None) if session else None
                if client_info:
                    client_info_name = getattr(client_info, "name", None)
                    client_info_version = getattr(client_info, "version", None)
                is_connected = server_conn.is_healthy()
                error_message = getattr(server_conn, "_error_message", None)
                instructions_available = getattr(
                    server_conn, "server_instructions_available", None
                )
                instructions_enabled = getattr(server_conn, "server_instructions_enabled", None)
                instructions_included = bool(getattr(server_conn, "server_instructions", None))
                server_cfg = getattr(server_conn, "server_config", None)
                if session:
                    elicitation_mode = getattr(
                        session, "effective_elicitation_mode", elicitation_mode
                    )
                session_id = getattr(server_conn, "session_id", None)
                if not session_id:
                    session_id = self._status_session_id(server_conn, session_id)
                transport_snapshot = self._status_transport_snapshot(server_name, server_conn)
            except Exception as exc:
                logger.debug(
                    f"Failed to collect status for server '{server_name}'",
                    data={"error": str(exc)},
                )

        if server_cfg is None and self.context and getattr(self.context, "server_registry", None):
            try:
                server_cfg = self.context.server_registry.get_server_config(server_name)
            except Exception:
                server_cfg = None

        if server_cfg is not None:
            # Config values only fill gaps left by the live connection. getattr is
            # used throughout so a sparse config cannot abort status collection.
            if instructions_enabled is None:
                instructions_enabled = getattr(server_cfg, "include_instructions", None)
            roots = getattr(server_cfg, "roots", None)
            roots_configured = bool(roots)
            roots_count = len(roots) if roots else 0
            transport = getattr(server_cfg, "transport", transport)
            elicitation = getattr(server_cfg, "elicitation", None)
            if elicitation:
                elicitation_mode = getattr(elicitation, "mode", None)
            spoofed = getattr(server_cfg, "implementation", None)
            spoofing_enabled = bool(spoofed)
            if implementation_name is None and spoofed:
                implementation_name = getattr(spoofed, "name", None)
                implementation_version = getattr(spoofed, "version", None)
            if session_id is None:
                if transport == "stdio":
                    session_id = "local"
                elif server_conn:
                    session_id = self._status_session_id(server_conn, None)

        if server_cfg is not None and getattr(server_cfg, "sampling", None) is not None:
            sampling_mode = "configured"
        else:
            sampling_mode = self._status_default_sampling_mode()

        status_map[server_name] = ServerStatus(
            server_name=server_name,
            implementation_name=implementation_name,
            implementation_version=implementation_version,
            server_capabilities=capabilities,
            client_capabilities=client_capabilities,
            client_info_name=client_info_name,
            client_info_version=client_info_version,
            transport=transport,
            is_connected=is_connected,
            last_call_at=last_call,
            last_error_at=last_error,
            staleness_seconds=staleness,
            call_counts=call_counts,
            error_message=error_message,
            instructions_available=instructions_available,
            instructions_enabled=instructions_enabled,
            instructions_included=instructions_included,
            roots_configured=roots_configured,
            roots_count=roots_count,
            elicitation_mode=elicitation_mode,
            sampling_mode=sampling_mode,
            spoofing_enabled=spoofing_enabled,
            session_id=session_id,
            transport_channels=transport_snapshot,
        )

    return status_map


@staticmethod
def _status_session_id(server_conn: Any, fallback: str | None) -> str | None:
    """Poll the connection's transport session-id callback.

    Returns ``fallback`` when the connection has no callback, and ``None`` when
    the callback raises.
    """
    callback = getattr(server_conn, "_get_session_id_cb", None)
    if not callback:
        return fallback
    try:
        return callback()
    except Exception:
        return None


def _status_transport_snapshot(
    self, server_name: str, server_conn: Any
) -> TransportSnapshot | None:
    """Snapshot ``server_conn.transport_metrics``; ``None`` if absent or failing."""
    metrics = getattr(server_conn, "transport_metrics", None)
    if metrics is None:
        return None
    try:
        return metrics.snapshot()
    except Exception as exc:
        # f-string + data= matches how this module's logger is called elsewhere.
        logger.debug(
            f"Failed to snapshot transport metrics for server '{server_name}'",
            data={"error": str(exc)},
        )
        return None


def _status_default_sampling_mode(self) -> str:
    """Return ``"auto"`` or ``"off"`` from ``context.config.auto_sampling`` (default on)."""
    auto_sampling = True
    if self.context and getattr(self.context, "config", None):
        auto_sampling = getattr(self.context.config, "auto_sampling", True)
    return "auto" if auto_sampling else "off"
758
+
495
759
  async def _execute_on_server(
496
760
  self,
497
761
  server_name: str,
@@ -554,13 +818,17 @@ class MCPAggregator(ContextDependent):
554
818
  # Re-raise the original exception to propagate it
555
819
  raise e
556
820
 
821
+ success_flag: bool | None = None
822
+ result: R | None = None
823
+
557
824
  # Try initial execution
558
825
  try:
559
826
  if self.connection_persistence:
560
827
  server_connection = await self._persistent_connection_manager.get_server(
561
828
  server_name, client_session_factory=self._create_session_factory(server_name)
562
829
  )
563
- return await try_execute(server_connection.session)
830
+ result = await try_execute(server_connection.session)
831
+ success_flag = True
564
832
  else:
565
833
  logger.debug(
566
834
  f"Creating temporary connection to server: {server_name}",
@@ -582,7 +850,7 @@ class MCPAggregator(ContextDependent):
582
850
  "agent_name": self.agent_name,
583
851
  },
584
852
  )
585
- return result
853
+ success_flag = True
586
854
  except ConnectionError:
587
855
  # Server offline - attempt reconnection
588
856
  from fast_agent.ui import console
@@ -613,7 +881,7 @@ class MCPAggregator(ContextDependent):
613
881
 
614
882
  # Success!
615
883
  console.console.print(f"[dim green]MCP server {server_name} online[/dim green]")
616
- return result
884
+ success_flag = True
617
885
 
618
886
  except Exception:
619
887
  # Reconnection failed
@@ -621,10 +889,19 @@ class MCPAggregator(ContextDependent):
621
889
  f"[dim red]MCP server {server_name} offline - failed to reconnect[/dim red]"
622
890
  )
623
891
  error_msg = f"MCP server {server_name} offline - failed to reconnect"
892
+ success_flag = False
624
893
  if error_factory:
625
- return error_factory(error_msg)
894
+ result = error_factory(error_msg)
626
895
  else:
627
896
  raise Exception(error_msg)
897
+ except Exception:
898
+ success_flag = False
899
+ raise
900
+ finally:
901
+ if success_flag is not None:
902
+ await self._record_server_call(server_name, operation_type, success_flag)
903
+
904
+ return result
628
905
 
629
906
  async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]:
630
907
  """
@@ -21,10 +21,9 @@ from mcp.client.sse import sse_client
21
21
  from mcp.client.stdio import (
22
22
  StdioServerParameters,
23
23
  get_default_environment,
24
- stdio_client,
25
24
  )
26
- from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
27
- from mcp.types import JSONRPCMessage, ServerCapabilities
25
+ from mcp.client.streamable_http import GetSessionIdCallback
26
+ from mcp.types import Implementation, JSONRPCMessage, ServerCapabilities
28
27
 
29
28
  from fast_agent.config import MCPServerSettings
30
29
  from fast_agent.context_dependent import ContextDependent
@@ -34,6 +33,9 @@ from fast_agent.event_progress import ProgressAction
34
33
  from fast_agent.mcp.logger_textio import get_stderr_handler
35
34
  from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
36
35
  from fast_agent.mcp.oauth_client import build_oauth_provider
36
+ from fast_agent.mcp.stdio_tracking_simple import tracking_stdio_client
37
+ from fast_agent.mcp.streamable_http_tracking import tracking_streamablehttp_client
38
+ from fast_agent.mcp.transport_tracking import TransportChannelMetrics
37
39
 
38
40
  if TYPE_CHECKING:
39
41
  from fast_agent.context import Context
@@ -107,6 +109,14 @@ class ServerConnection:
107
109
 
108
110
  # Server instructions from initialization
109
111
  self.server_instructions: str | None = None
112
+ self.server_capabilities: ServerCapabilities | None = None
113
+ self.server_implementation: Implementation | None = None
114
+ self.client_capabilities: dict | None = None
115
+ self.server_instructions_available: bool = False
116
+ self.server_instructions_enabled: bool = server_config.include_instructions if server_config else True
117
+ self.session_id: str | None = None
118
+ self._get_session_id_cb: GetSessionIdCallback | None = None
119
+ self.transport_metrics: TransportChannelMetrics | None = None
110
120
 
111
121
  def is_healthy(self) -> bool:
112
122
  """Check if the server connection is healthy and ready to use."""
@@ -138,15 +148,32 @@ class ServerConnection:
138
148
  result = await self.session.initialize()
139
149
 
140
150
  self.server_capabilities = result.capabilities
151
+ # InitializeResult exposes server info via `serverInfo`; keep fallback for older fields
152
+ implementation = getattr(result, "serverInfo", None)
153
+ if implementation is None:
154
+ implementation = getattr(result, "implementation", None)
155
+ self.server_implementation = implementation
156
+
157
+ raw_instructions = getattr(result, "instructions", None)
158
+ self.server_instructions_available = bool(raw_instructions)
141
159
 
142
160
  # Store instructions if provided by the server and enabled in config
143
161
  if self.server_config.include_instructions:
144
- self.server_instructions = getattr(result, 'instructions', None)
162
+ self.server_instructions = raw_instructions
145
163
  if self.server_instructions:
146
- logger.debug(f"{self.server_name}: Received server instructions", data={"instructions": self.server_instructions})
164
+ logger.debug(
165
+ f"{self.server_name}: Received server instructions",
166
+ data={"instructions": self.server_instructions},
167
+ )
147
168
  else:
148
169
  self.server_instructions = None
149
- logger.debug(f"{self.server_name}: Server instructions disabled by configuration")
170
+ if self.server_instructions_available:
171
+ logger.debug(
172
+ f"{self.server_name}: Server instructions disabled by configuration",
173
+ data={"instructions": raw_instructions},
174
+ )
175
+ else:
176
+ logger.debug(f"{self.server_name}: No server instructions provided")
150
177
 
151
178
  # If there's an init hook, run it
152
179
 
@@ -175,10 +202,15 @@ class ServerConnection:
175
202
  )
176
203
 
177
204
  session = self._client_session_factory(
178
- read_stream, send_stream, read_timeout, server_config=self.server_config
205
+ read_stream,
206
+ send_stream,
207
+ read_timeout,
208
+ server_config=self.server_config,
209
+ transport_metrics=self.transport_metrics,
179
210
  )
180
211
 
181
212
  self.session = session
213
+ self.client_capabilities = getattr(session, "client_capabilities", None)
182
214
 
183
215
  return session
184
216
 
@@ -192,11 +224,30 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
192
224
  try:
193
225
  transport_context = server_conn._transport_context_factory()
194
226
 
195
- async with transport_context as (read_stream, write_stream, _):
227
+ async with transport_context as (read_stream, write_stream, get_session_id_cb):
228
+ server_conn._get_session_id_cb = get_session_id_cb
229
+
230
+ if get_session_id_cb is not None:
231
+ try:
232
+ server_conn.session_id = get_session_id_cb()
233
+ except Exception:
234
+ logger.debug(f"{server_name}: Unable to retrieve session id from transport")
235
+ elif server_conn.server_config.transport == "stdio":
236
+ server_conn.session_id = "local"
237
+
196
238
  server_conn.create_session(read_stream, write_stream)
197
239
 
198
240
  async with server_conn.session:
199
241
  await server_conn.initialize_session()
242
+
243
+ if get_session_id_cb is not None:
244
+ try:
245
+ server_conn.session_id = get_session_id_cb() or server_conn.session_id
246
+ except Exception:
247
+ logger.debug(f"{server_name}: Unable to refresh session id after init")
248
+ elif server_conn.server_config.transport == "stdio":
249
+ server_conn.session_id = "local"
250
+
200
251
  await server_conn.wait_for_shutdown_request()
201
252
 
202
253
  except HTTPStatusError as http_exc:
@@ -353,6 +404,8 @@ class MCPConnectionManager(ContextDependent):
353
404
 
354
405
  logger.debug(f"{server_name}: Found server configuration=", data=config.model_dump())
355
406
 
407
+ transport_metrics = TransportChannelMetrics() if config.transport in ("http", "stdio") else None
408
+
356
409
  def transport_context_factory():
357
410
  if config.transport == "stdio":
358
411
  if not config.command:
@@ -369,7 +422,11 @@ class MCPConnectionManager(ContextDependent):
369
422
  error_handler = get_stderr_handler(server_name)
370
423
  # Explicitly ensure we're using our custom logger for stderr
371
424
  logger.debug(f"{server_name}: Creating stdio client with custom error handler")
372
- return _add_none_to_context(stdio_client(server_params, errlog=error_handler))
425
+
426
+ channel_hook = transport_metrics.record_event if transport_metrics else None
427
+ return _add_none_to_context(
428
+ tracking_stdio_client(server_params, channel_hook=channel_hook, errlog=error_handler)
429
+ )
373
430
  elif config.transport == "sse":
374
431
  if not config.url:
375
432
  raise ValueError(
@@ -401,7 +458,23 @@ class MCPConnectionManager(ContextDependent):
401
458
  if oauth_auth is not None:
402
459
  headers.pop("Authorization", None)
403
460
  headers.pop("X-HF-Authorization", None)
404
- return streamablehttp_client(config.url, headers, auth=oauth_auth)
461
+ channel_hook = None
462
+ if transport_metrics is not None:
463
def channel_hook(event):
    """Forward a transport ``event`` to the enclosing ``transport_metrics``.

    Recording failures are logged at debug level and swallowed so that a
    metrics problem never interrupts the streamable-HTTP transport this hook
    is attached to.
    """
    try:
        transport_metrics.record_event(event)
    except Exception as exc:  # pragma: no cover - defensive guard
        # f-string + data= matches how this module's logger is called elsewhere.
        logger.debug(
            f"{server_name}: transport metrics hook failed",
            data={"error": str(exc)},
        )
471
+
472
+ return tracking_streamablehttp_client(
473
+ config.url,
474
+ headers,
475
+ auth=oauth_auth,
476
+ channel_hook=channel_hook,
477
+ )
405
478
  else:
406
479
  raise ValueError(f"Unsupported transport: {config.transport}")
407
480
 
@@ -412,6 +485,9 @@ class MCPConnectionManager(ContextDependent):
412
485
  client_session_factory=client_session_factory,
413
486
  )
414
487
 
488
+ if transport_metrics is not None:
489
+ server_conn.transport_metrics = transport_metrics
490
+
415
491
  async with self._lock:
416
492
  # Check if already running
417
493
  if server_name in self.running_servers:
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from contextlib import asynccontextmanager
5
from typing import TYPE_CHECKING, AsyncGenerator, Callable, TextIO
6
+
7
+ from mcp.client.stdio import StdioServerParameters, stdio_client
8
+
9
+ from fast_agent.mcp.transport_tracking import ChannelEvent
10
+
11
+ if TYPE_CHECKING:
12
+ from anyio.abc import ObjectReceiveStream, ObjectSendStream
13
+ from mcp.shared.message import SessionMessage
14
+
15
logger: logging.Logger = logging.getLogger(__name__)

# Callback type for observers of stdio channel events; it is invoked with one
# ChannelEvent per connect / error / disconnect notification.
ChannelHook = Callable[[ChannelEvent], None]
18
+
19
+
20
@asynccontextmanager
async def tracking_stdio_client(
    server_params: StdioServerParameters,
    *,
    channel_hook: ChannelHook | None = None,
    errlog: TextIO | None = None,
) -> AsyncGenerator[
    tuple[ObjectReceiveStream[SessionMessage | Exception], ObjectSendStream[SessionMessage]], None
]:
    """Wrap ``stdio_client`` and report coarse lifecycle events.

    The streams are passed through untouched (no per-message interception).
    When ``channel_hook`` is given it receives a ``ChannelEvent`` on the
    ``"stdio"`` channel for:

    * ``connect`` - emitted before the stdio transport is opened;
    * ``error`` - any ``Exception`` escaping the context, including ones raised
      inside the caller's ``async with`` body (detail is ``str(exc)``);
    * ``disconnect`` - always, on exit.

    Hook failures are logged and never propagate.

    Args:
        server_params: Parameters forwarded to ``stdio_client``.
        channel_hook: Optional observer for channel events.
        errlog: Text stream receiving the server process's stderr, forwarded
            to ``stdio_client``. When ``None`` it is not passed at all, so
            ``stdio_client`` keeps its own default instead of receiving ``None``.

    Yields:
        ``(read_stream, write_stream)`` as produced by ``stdio_client``.
    """

    def emit_channel_event(event_type: str, detail: str | None = None) -> None:
        if channel_hook is None:
            return
        try:
            channel_hook(
                ChannelEvent(
                    channel="stdio",
                    event_type=event_type,  # type: ignore[arg-type]
                    detail=detail,
                )
            )
        except Exception:  # pragma: no cover - hook errors must not break transport
            logger.exception("Channel hook raised an exception")

    # Only forward errlog when supplied; an explicit errlog=None would replace
    # stdio_client's default stderr target.
    client_kwargs = {} if errlog is None else {"errlog": errlog}

    try:
        emit_channel_event("connect")

        # Use the original stdio_client without stream interception
        async with stdio_client(server_params, **client_kwargs) as (read_stream, write_stream):
            yield read_stream, write_stream

    except Exception as exc:
        emit_channel_event("error", detail=str(exc))
        raise
    finally:
        emit_channel_event("disconnect")