fast-agent-mcp 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fast-agent-mcp might be problematic. Click here for more details.
- fast_agent/agents/llm_agent.py +24 -0
- fast_agent/agents/mcp_agent.py +7 -1
- fast_agent/core/direct_factory.py +20 -8
- fast_agent/llm/provider/anthropic/llm_anthropic.py +107 -62
- fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +4 -3
- fast_agent/llm/provider/google/google_converter.py +8 -41
- fast_agent/llm/provider/openai/llm_openai.py +3 -3
- fast_agent/mcp/mcp_agent_client_session.py +45 -2
- fast_agent/mcp/mcp_aggregator.py +282 -5
- fast_agent/mcp/mcp_connection_manager.py +86 -10
- fast_agent/mcp/stdio_tracking_simple.py +59 -0
- fast_agent/mcp/streamable_http_tracking.py +309 -0
- fast_agent/mcp/transport_tracking.py +598 -0
- fast_agent/resources/examples/data-analysis/analysis.py +7 -3
- fast_agent/ui/console_display.py +22 -1
- fast_agent/ui/enhanced_prompt.py +21 -1
- fast_agent/ui/interactive_prompt.py +5 -0
- fast_agent/ui/mcp_display.py +636 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.9.dist-info}/METADATA +5 -5
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.9.dist-info}/RECORD +23 -19
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.9.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.9.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.3.8.dist-info → fast_agent_mcp-0.3.9.dist-info}/licenses/LICENSE +0 -0
fast_agent/mcp/mcp_aggregator.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
from asyncio import Lock, gather
|
|
2
|
+
from collections import Counter
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime, timezone
|
|
2
5
|
from typing import (
|
|
3
6
|
TYPE_CHECKING,
|
|
4
7
|
Any,
|
|
@@ -17,11 +20,12 @@ from mcp.types import (
|
|
|
17
20
|
CallToolResult,
|
|
18
21
|
ListToolsResult,
|
|
19
22
|
Prompt,
|
|
23
|
+
ServerCapabilities,
|
|
20
24
|
TextContent,
|
|
21
25
|
Tool,
|
|
22
26
|
)
|
|
23
27
|
from opentelemetry import trace
|
|
24
|
-
from pydantic import AnyUrl, BaseModel, ConfigDict
|
|
28
|
+
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
|
25
29
|
|
|
26
30
|
from fast_agent.context_dependent import ContextDependent
|
|
27
31
|
from fast_agent.core.logging.logger import get_logger
|
|
@@ -30,6 +34,7 @@ from fast_agent.mcp.common import SEP, create_namespaced_name, is_namespaced_nam
|
|
|
30
34
|
from fast_agent.mcp.gen_client import gen_client
|
|
31
35
|
from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
|
|
32
36
|
from fast_agent.mcp.mcp_connection_manager import MCPConnectionManager
|
|
37
|
+
from fast_agent.mcp.transport_tracking import TransportSnapshot
|
|
33
38
|
|
|
34
39
|
if TYPE_CHECKING:
|
|
35
40
|
from fast_agent.context import Context
|
|
@@ -52,6 +57,49 @@ class NamespacedTool(BaseModel):
|
|
|
52
57
|
namespaced_tool_name: str
|
|
53
58
|
|
|
54
59
|
|
|
60
|
+
@dataclass
|
|
61
|
+
class ServerStats:
|
|
62
|
+
call_counts: Counter = field(default_factory=Counter)
|
|
63
|
+
last_call_at: datetime | None = None
|
|
64
|
+
last_error_at: datetime | None = None
|
|
65
|
+
|
|
66
|
+
def record(self, operation_type: str, success: bool) -> None:
|
|
67
|
+
self.call_counts[operation_type] += 1
|
|
68
|
+
now = datetime.now(timezone.utc)
|
|
69
|
+
self.last_call_at = now
|
|
70
|
+
if not success:
|
|
71
|
+
self.last_error_at = now
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ServerStatus(BaseModel):
    """Point-in-time status snapshot for a single MCP server.

    Built per server by ``MCPAggregator.collect_server_status``. Apart from
    ``server_name``, every field defaults to ``None`` (unknown) or empty, because
    each value is only filled in when a live connection or the server
    configuration could supply it.
    """

    # Permit field types that pydantic has no built-in schema for
    # (presumably TransportSnapshot — confirm against transport_tracking).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # -- identity reported by (or configured for) the server
    server_name: str
    implementation_name: str | None = None
    implementation_version: str | None = None
    # -- capabilities and client identity for the session
    server_capabilities: ServerCapabilities | None = None
    client_capabilities: Mapping[str, Any] | None = None
    client_info_name: str | None = None
    client_info_version: str | None = None
    # -- connection state and recorded activity
    transport: str | None = None
    is_connected: bool | None = None
    last_call_at: datetime | None = None
    last_error_at: datetime | None = None
    staleness_seconds: float | None = None  # seconds elapsed since last_call_at
    call_counts: Dict[str, int] = Field(default_factory=dict)  # calls per operation type
    error_message: str | None = None
    # -- server instructions: offered by the server / enabled by config / kept
    instructions_available: bool | None = None
    instructions_enabled: bool | None = None
    instructions_included: bool | None = None
    # -- client-side feature settings
    roots_configured: bool | None = None
    roots_count: int | None = None
    elicitation_mode: str | None = None
    sampling_mode: str | None = None
    spoofing_enabled: bool | None = None  # True when an implementation override is configured
    # -- transport session details
    session_id: str | None = None
    transport_channels: TransportSnapshot | None = None
|
|
101
|
+
|
|
102
|
+
|
|
55
103
|
class MCPAggregator(ContextDependent):
|
|
56
104
|
"""
|
|
57
105
|
Aggregates multiple MCP servers. When a developer calls, e.g. call_tool(...),
|
|
@@ -140,6 +188,10 @@ class MCPAggregator(ContextDependent):
|
|
|
140
188
|
# Lock for refreshing tools from a server
|
|
141
189
|
self._refresh_lock = Lock()
|
|
142
190
|
|
|
191
|
+
# Track runtime stats per server
|
|
192
|
+
self._server_stats: Dict[str, ServerStats] = {}
|
|
193
|
+
self._stats_lock = Lock()
|
|
194
|
+
|
|
143
195
|
def _create_progress_callback(self, server_name: str, tool_name: str) -> "ProgressFnT":
|
|
144
196
|
"""Create a progress callback function for tool execution."""
|
|
145
197
|
|
|
@@ -461,6 +513,50 @@ class MCPAggregator(ContextDependent):
|
|
|
461
513
|
for server_name in self.server_names:
|
|
462
514
|
await self._refresh_server_tools(server_name)
|
|
463
515
|
|
|
516
|
+
async def _record_server_call(
    self, server_name: str, operation_type: str, success: bool
) -> None:
    """Update the per-server call statistics, then mirror the call as stdio activity.

    The ``ServerStats`` entry for ``server_name`` is created on first use and
    mutated while ``self._stats_lock`` is held. The stdio activity notification
    runs after the lock has been released.
    """
    async with self._stats_lock:
        if server_name not in self._server_stats:
            self._server_stats[server_name] = ServerStats()
        self._server_stats[server_name].record(operation_type, success)

    # stdio servers additionally get a synthetic transport event for the activity timeline
    await self._notify_stdio_transport_activity(server_name, operation_type, success)
|
|
525
|
+
|
|
526
|
+
async def _notify_stdio_transport_activity(
|
|
527
|
+
self, server_name: str, operation_type: str, success: bool
|
|
528
|
+
) -> None:
|
|
529
|
+
"""Notify transport metrics of activity for stdio servers to create activity timeline."""
|
|
530
|
+
if not self._persistent_connection_manager:
|
|
531
|
+
return
|
|
532
|
+
|
|
533
|
+
try:
|
|
534
|
+
# Get the server connection and check if it's stdio transport
|
|
535
|
+
server_conn = self._persistent_connection_manager.running_servers.get(server_name)
|
|
536
|
+
if not server_conn:
|
|
537
|
+
return
|
|
538
|
+
|
|
539
|
+
server_config = getattr(server_conn, "server_config", None)
|
|
540
|
+
if not server_config or server_config.transport != "stdio":
|
|
541
|
+
return
|
|
542
|
+
|
|
543
|
+
# Get transport metrics and emit synthetic message event
|
|
544
|
+
transport_metrics = getattr(server_conn, "transport_metrics", None)
|
|
545
|
+
if transport_metrics:
|
|
546
|
+
# Import here to avoid circular imports
|
|
547
|
+
from fast_agent.mcp.transport_tracking import ChannelEvent
|
|
548
|
+
|
|
549
|
+
# Create a synthetic message event to represent the MCP operation
|
|
550
|
+
event = ChannelEvent(
|
|
551
|
+
channel="stdio",
|
|
552
|
+
event_type="message",
|
|
553
|
+
detail=f"{operation_type} ({'success' if success else 'error'})"
|
|
554
|
+
)
|
|
555
|
+
transport_metrics.record_event(event)
|
|
556
|
+
except Exception:
|
|
557
|
+
# Don't let transport tracking errors break normal operation
|
|
558
|
+
logger.debug("Failed to notify stdio transport activity for %s", server_name, exc_info=True)
|
|
559
|
+
|
|
464
560
|
async def get_server_instructions(self) -> Dict[str, tuple[str, List[str]]]:
|
|
465
561
|
"""
|
|
466
562
|
Get instructions from all connected servers along with their tool names.
|
|
@@ -492,6 +588,174 @@ class MCPAggregator(ContextDependent):
|
|
|
492
588
|
|
|
493
589
|
return instructions
|
|
494
590
|
|
|
591
|
+
async def collect_server_status(self) -> Dict[str, ServerStatus]:
    """Return aggregated status information for each configured server.

    Loads the servers first if the aggregator is not yet initialized. For every
    name in ``self.server_names`` the returned ``ServerStatus`` combines:

    * runtime statistics recorded via ``_record_server_call``: call counts, last
      call / last error timestamps, and seconds elapsed since the last call;
    * live connection details, only when ``connection_persistence`` is enabled and
      a persistent connection manager exists. These come from ``manager.get_server``:
      implementation info, capabilities, client info, health, instruction flags,
      session id and a transport-metrics snapshot;
    * configuration-derived settings, taken from the connection's config or else
      from the context's server registry: instructions, roots, elicitation,
      sampling, and any configured ``implementation`` override ("spoofing").

    A failure while inspecting one server is logged at debug level and leaves the
    affected fields at ``None``. Collection continues with the remaining servers.

    Returns:
        Mapping of server name to its ``ServerStatus``.
    """
    if not self.initialized:
        await self.load_servers()

    now = datetime.now(timezone.utc)
    status_map: Dict[str, ServerStatus] = {}
    # Loop-invariant; the attribute may be missing or None, hence getattr.
    manager = getattr(self, "_persistent_connection_manager", None)

    for server_name in self.server_names:
        # --- runtime statistics -------------------------------------------------
        stats = self._server_stats.get(server_name)
        last_call = stats.last_call_at if stats else None
        last_error = stats.last_error_at if stats else None
        staleness = (now - last_call).total_seconds() if last_call else None
        call_counts = dict(stats.call_counts) if stats else {}

        # Every field below stays None unless one of the sources provides it.
        implementation_name = None
        implementation_version = None
        capabilities: ServerCapabilities | None = None
        client_capabilities: Mapping[str, Any] | None = None
        client_info_name = None
        client_info_version = None
        is_connected = None
        error_message = None
        instructions_available = None
        instructions_enabled = None
        instructions_included = None
        roots_configured = None
        roots_count = None
        elicitation_mode = None
        sampling_mode = None
        spoofing_enabled = None
        server_cfg = None
        sampling_cfg = None
        session_id = None
        server_conn = None
        transport: str | None = None
        transport_snapshot: TransportSnapshot | None = None

        # --- live connection details --------------------------------------------
        if self.connection_persistence and manager is not None:
            try:
                server_conn = await manager.get_server(
                    server_name,
                    client_session_factory=self._create_session_factory(server_name),
                )
                implementation = getattr(server_conn, "server_implementation", None)
                if implementation:
                    implementation_name = getattr(implementation, "name", None)
                    implementation_version = getattr(implementation, "version", None)
                capabilities = getattr(server_conn, "server_capabilities", None)
                client_capabilities = getattr(server_conn, "client_capabilities", None)
                session = server_conn.session
                client_info = getattr(session, "client_info", None) if session else None
                if client_info:
                    client_info_name = getattr(client_info, "name", None)
                    client_info_version = getattr(client_info, "version", None)
                is_connected = server_conn.is_healthy()
                error_message = getattr(server_conn, "_error_message", None)
                instructions_available = getattr(
                    server_conn, "server_instructions_available", None
                )
                instructions_enabled = getattr(server_conn, "server_instructions_enabled", None)
                instructions_included = bool(getattr(server_conn, "server_instructions", None))
                server_cfg = getattr(server_conn, "server_config", None)
                if session:
                    elicitation_mode = getattr(
                        session, "effective_elicitation_mode", elicitation_mode
                    )
                session_id = getattr(server_conn, "session_id", None)
                if not session_id and getattr(server_conn, "_get_session_id_cb", None):
                    try:
                        session_id = server_conn._get_session_id_cb()  # type: ignore[attr-defined]
                    except Exception:
                        session_id = None
                metrics = getattr(server_conn, "transport_metrics", None)
                if metrics is not None:
                    try:
                        transport_snapshot = metrics.snapshot()
                    except Exception as snapshot_exc:
                        # Same f-string + structured ``data=`` style as the logger call
                        # below; a "%s" placeholder is not interpolated by this logger.
                        logger.debug(
                            f"Failed to snapshot transport metrics for server '{server_name}'",
                            data={"error": str(snapshot_exc)},
                        )
            except Exception as exc:
                logger.debug(
                    f"Failed to collect status for server '{server_name}'",
                    data={"error": str(exc)},
                )

        # --- configuration (connection config, else server registry) -----------
        if server_cfg is None and self.context and getattr(self.context, "server_registry", None):
            try:
                server_cfg = self.context.server_registry.get_server_config(server_name)
            except Exception:
                server_cfg = None

        if server_cfg is not None:
            # The live connection's value wins; config is only the fallback.
            if instructions_enabled is None:
                instructions_enabled = server_cfg.include_instructions
            roots = getattr(server_cfg, "roots", None)
            roots_configured = bool(roots)
            roots_count = len(roots) if roots else 0
            transport = getattr(server_cfg, "transport", transport)
            elicitation = getattr(server_cfg, "elicitation", None)
            if elicitation:
                elicitation_mode = getattr(elicitation, "mode", None)
            sampling_cfg = getattr(server_cfg, "sampling", None)
            configured_impl = getattr(server_cfg, "implementation", None)
            spoofing_enabled = bool(configured_impl)
            if implementation_name is None and configured_impl:
                implementation_name = configured_impl.name
                implementation_version = getattr(configured_impl, "version", None)
            if session_id is None:
                if server_cfg.transport == "stdio":
                    session_id = "local"
                elif server_conn and getattr(server_conn, "_get_session_id_cb", None):
                    try:
                        session_id = server_conn._get_session_id_cb()  # type: ignore[attr-defined]
                    except Exception:
                        session_id = None

        # --- sampling: explicit per-server config, else the global auto flag ----
        if sampling_cfg is not None:
            sampling_mode = "configured"
        else:
            auto_sampling = True
            if self.context and getattr(self.context, "config", None):
                auto_sampling = getattr(self.context.config, "auto_sampling", True)
            sampling_mode = "auto" if auto_sampling else "off"

        status_map[server_name] = ServerStatus(
            server_name=server_name,
            implementation_name=implementation_name,
            implementation_version=implementation_version,
            server_capabilities=capabilities,
            client_capabilities=client_capabilities,
            client_info_name=client_info_name,
            client_info_version=client_info_version,
            transport=transport,
            is_connected=is_connected,
            last_call_at=last_call,
            last_error_at=last_error,
            staleness_seconds=staleness,
            call_counts=call_counts,
            error_message=error_message,
            instructions_available=instructions_available,
            instructions_enabled=instructions_enabled,
            instructions_included=instructions_included,
            roots_configured=roots_configured,
            roots_count=roots_count,
            elicitation_mode=elicitation_mode,
            sampling_mode=sampling_mode,
            spoofing_enabled=spoofing_enabled,
            session_id=session_id,
            transport_channels=transport_snapshot,
        )

    return status_map
|
|
758
|
+
|
|
495
759
|
async def _execute_on_server(
|
|
496
760
|
self,
|
|
497
761
|
server_name: str,
|
|
@@ -554,13 +818,17 @@ class MCPAggregator(ContextDependent):
|
|
|
554
818
|
# Re-raise the original exception to propagate it
|
|
555
819
|
raise e
|
|
556
820
|
|
|
821
|
+
success_flag: bool | None = None
|
|
822
|
+
result: R | None = None
|
|
823
|
+
|
|
557
824
|
# Try initial execution
|
|
558
825
|
try:
|
|
559
826
|
if self.connection_persistence:
|
|
560
827
|
server_connection = await self._persistent_connection_manager.get_server(
|
|
561
828
|
server_name, client_session_factory=self._create_session_factory(server_name)
|
|
562
829
|
)
|
|
563
|
-
|
|
830
|
+
result = await try_execute(server_connection.session)
|
|
831
|
+
success_flag = True
|
|
564
832
|
else:
|
|
565
833
|
logger.debug(
|
|
566
834
|
f"Creating temporary connection to server: {server_name}",
|
|
@@ -582,7 +850,7 @@ class MCPAggregator(ContextDependent):
|
|
|
582
850
|
"agent_name": self.agent_name,
|
|
583
851
|
},
|
|
584
852
|
)
|
|
585
|
-
|
|
853
|
+
success_flag = True
|
|
586
854
|
except ConnectionError:
|
|
587
855
|
# Server offline - attempt reconnection
|
|
588
856
|
from fast_agent.ui import console
|
|
@@ -613,7 +881,7 @@ class MCPAggregator(ContextDependent):
|
|
|
613
881
|
|
|
614
882
|
# Success!
|
|
615
883
|
console.console.print(f"[dim green]MCP server {server_name} online[/dim green]")
|
|
616
|
-
|
|
884
|
+
success_flag = True
|
|
617
885
|
|
|
618
886
|
except Exception:
|
|
619
887
|
# Reconnection failed
|
|
@@ -621,10 +889,19 @@ class MCPAggregator(ContextDependent):
|
|
|
621
889
|
f"[dim red]MCP server {server_name} offline - failed to reconnect[/dim red]"
|
|
622
890
|
)
|
|
623
891
|
error_msg = f"MCP server {server_name} offline - failed to reconnect"
|
|
892
|
+
success_flag = False
|
|
624
893
|
if error_factory:
|
|
625
|
-
|
|
894
|
+
result = error_factory(error_msg)
|
|
626
895
|
else:
|
|
627
896
|
raise Exception(error_msg)
|
|
897
|
+
except Exception:
|
|
898
|
+
success_flag = False
|
|
899
|
+
raise
|
|
900
|
+
finally:
|
|
901
|
+
if success_flag is not None:
|
|
902
|
+
await self._record_server_call(server_name, operation_type, success_flag)
|
|
903
|
+
|
|
904
|
+
return result
|
|
628
905
|
|
|
629
906
|
async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]:
|
|
630
907
|
"""
|
|
@@ -21,10 +21,9 @@ from mcp.client.sse import sse_client
|
|
|
21
21
|
from mcp.client.stdio import (
|
|
22
22
|
StdioServerParameters,
|
|
23
23
|
get_default_environment,
|
|
24
|
-
stdio_client,
|
|
25
24
|
)
|
|
26
|
-
from mcp.client.streamable_http import GetSessionIdCallback
|
|
27
|
-
from mcp.types import JSONRPCMessage, ServerCapabilities
|
|
25
|
+
from mcp.client.streamable_http import GetSessionIdCallback
|
|
26
|
+
from mcp.types import Implementation, JSONRPCMessage, ServerCapabilities
|
|
28
27
|
|
|
29
28
|
from fast_agent.config import MCPServerSettings
|
|
30
29
|
from fast_agent.context_dependent import ContextDependent
|
|
@@ -34,6 +33,9 @@ from fast_agent.event_progress import ProgressAction
|
|
|
34
33
|
from fast_agent.mcp.logger_textio import get_stderr_handler
|
|
35
34
|
from fast_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
|
|
36
35
|
from fast_agent.mcp.oauth_client import build_oauth_provider
|
|
36
|
+
from fast_agent.mcp.stdio_tracking_simple import tracking_stdio_client
|
|
37
|
+
from fast_agent.mcp.streamable_http_tracking import tracking_streamablehttp_client
|
|
38
|
+
from fast_agent.mcp.transport_tracking import TransportChannelMetrics
|
|
37
39
|
|
|
38
40
|
if TYPE_CHECKING:
|
|
39
41
|
from fast_agent.context import Context
|
|
@@ -107,6 +109,14 @@ class ServerConnection:
|
|
|
107
109
|
|
|
108
110
|
# Server instructions from initialization
|
|
109
111
|
self.server_instructions: str | None = None
|
|
112
|
+
self.server_capabilities: ServerCapabilities | None = None
|
|
113
|
+
self.server_implementation: Implementation | None = None
|
|
114
|
+
self.client_capabilities: dict | None = None
|
|
115
|
+
self.server_instructions_available: bool = False
|
|
116
|
+
self.server_instructions_enabled: bool = server_config.include_instructions if server_config else True
|
|
117
|
+
self.session_id: str | None = None
|
|
118
|
+
self._get_session_id_cb: GetSessionIdCallback | None = None
|
|
119
|
+
self.transport_metrics: TransportChannelMetrics | None = None
|
|
110
120
|
|
|
111
121
|
def is_healthy(self) -> bool:
|
|
112
122
|
"""Check if the server connection is healthy and ready to use."""
|
|
@@ -138,15 +148,32 @@ class ServerConnection:
|
|
|
138
148
|
result = await self.session.initialize()
|
|
139
149
|
|
|
140
150
|
self.server_capabilities = result.capabilities
|
|
151
|
+
# InitializeResult exposes server info via `serverInfo`; keep fallback for older fields
|
|
152
|
+
implementation = getattr(result, "serverInfo", None)
|
|
153
|
+
if implementation is None:
|
|
154
|
+
implementation = getattr(result, "implementation", None)
|
|
155
|
+
self.server_implementation = implementation
|
|
156
|
+
|
|
157
|
+
raw_instructions = getattr(result, "instructions", None)
|
|
158
|
+
self.server_instructions_available = bool(raw_instructions)
|
|
141
159
|
|
|
142
160
|
# Store instructions if provided by the server and enabled in config
|
|
143
161
|
if self.server_config.include_instructions:
|
|
144
|
-
self.server_instructions =
|
|
162
|
+
self.server_instructions = raw_instructions
|
|
145
163
|
if self.server_instructions:
|
|
146
|
-
logger.debug(
|
|
164
|
+
logger.debug(
|
|
165
|
+
f"{self.server_name}: Received server instructions",
|
|
166
|
+
data={"instructions": self.server_instructions},
|
|
167
|
+
)
|
|
147
168
|
else:
|
|
148
169
|
self.server_instructions = None
|
|
149
|
-
|
|
170
|
+
if self.server_instructions_available:
|
|
171
|
+
logger.debug(
|
|
172
|
+
f"{self.server_name}: Server instructions disabled by configuration",
|
|
173
|
+
data={"instructions": raw_instructions},
|
|
174
|
+
)
|
|
175
|
+
else:
|
|
176
|
+
logger.debug(f"{self.server_name}: No server instructions provided")
|
|
150
177
|
|
|
151
178
|
# If there's an init hook, run it
|
|
152
179
|
|
|
@@ -175,10 +202,15 @@ class ServerConnection:
|
|
|
175
202
|
)
|
|
176
203
|
|
|
177
204
|
session = self._client_session_factory(
|
|
178
|
-
read_stream,
|
|
205
|
+
read_stream,
|
|
206
|
+
send_stream,
|
|
207
|
+
read_timeout,
|
|
208
|
+
server_config=self.server_config,
|
|
209
|
+
transport_metrics=self.transport_metrics,
|
|
179
210
|
)
|
|
180
211
|
|
|
181
212
|
self.session = session
|
|
213
|
+
self.client_capabilities = getattr(session, "client_capabilities", None)
|
|
182
214
|
|
|
183
215
|
return session
|
|
184
216
|
|
|
@@ -192,11 +224,30 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
|
|
|
192
224
|
try:
|
|
193
225
|
transport_context = server_conn._transport_context_factory()
|
|
194
226
|
|
|
195
|
-
async with transport_context as (read_stream, write_stream,
|
|
227
|
+
async with transport_context as (read_stream, write_stream, get_session_id_cb):
|
|
228
|
+
server_conn._get_session_id_cb = get_session_id_cb
|
|
229
|
+
|
|
230
|
+
if get_session_id_cb is not None:
|
|
231
|
+
try:
|
|
232
|
+
server_conn.session_id = get_session_id_cb()
|
|
233
|
+
except Exception:
|
|
234
|
+
logger.debug(f"{server_name}: Unable to retrieve session id from transport")
|
|
235
|
+
elif server_conn.server_config.transport == "stdio":
|
|
236
|
+
server_conn.session_id = "local"
|
|
237
|
+
|
|
196
238
|
server_conn.create_session(read_stream, write_stream)
|
|
197
239
|
|
|
198
240
|
async with server_conn.session:
|
|
199
241
|
await server_conn.initialize_session()
|
|
242
|
+
|
|
243
|
+
if get_session_id_cb is not None:
|
|
244
|
+
try:
|
|
245
|
+
server_conn.session_id = get_session_id_cb() or server_conn.session_id
|
|
246
|
+
except Exception:
|
|
247
|
+
logger.debug(f"{server_name}: Unable to refresh session id after init")
|
|
248
|
+
elif server_conn.server_config.transport == "stdio":
|
|
249
|
+
server_conn.session_id = "local"
|
|
250
|
+
|
|
200
251
|
await server_conn.wait_for_shutdown_request()
|
|
201
252
|
|
|
202
253
|
except HTTPStatusError as http_exc:
|
|
@@ -353,6 +404,8 @@ class MCPConnectionManager(ContextDependent):
|
|
|
353
404
|
|
|
354
405
|
logger.debug(f"{server_name}: Found server configuration=", data=config.model_dump())
|
|
355
406
|
|
|
407
|
+
transport_metrics = TransportChannelMetrics() if config.transport in ("http", "stdio") else None
|
|
408
|
+
|
|
356
409
|
def transport_context_factory():
|
|
357
410
|
if config.transport == "stdio":
|
|
358
411
|
if not config.command:
|
|
@@ -369,7 +422,11 @@ class MCPConnectionManager(ContextDependent):
|
|
|
369
422
|
error_handler = get_stderr_handler(server_name)
|
|
370
423
|
# Explicitly ensure we're using our custom logger for stderr
|
|
371
424
|
logger.debug(f"{server_name}: Creating stdio client with custom error handler")
|
|
372
|
-
|
|
425
|
+
|
|
426
|
+
channel_hook = transport_metrics.record_event if transport_metrics else None
|
|
427
|
+
return _add_none_to_context(
|
|
428
|
+
tracking_stdio_client(server_params, channel_hook=channel_hook, errlog=error_handler)
|
|
429
|
+
)
|
|
373
430
|
elif config.transport == "sse":
|
|
374
431
|
if not config.url:
|
|
375
432
|
raise ValueError(
|
|
@@ -401,7 +458,23 @@ class MCPConnectionManager(ContextDependent):
|
|
|
401
458
|
if oauth_auth is not None:
|
|
402
459
|
headers.pop("Authorization", None)
|
|
403
460
|
headers.pop("X-HF-Authorization", None)
|
|
404
|
-
|
|
461
|
+
channel_hook = None
|
|
462
|
+
if transport_metrics is not None:
|
|
463
|
+
def channel_hook(event):
|
|
464
|
+
try:
|
|
465
|
+
transport_metrics.record_event(event)
|
|
466
|
+
except Exception: # pragma: no cover - defensive guard
|
|
467
|
+
logger.debug(
|
|
468
|
+
"%s: transport metrics hook failed", server_name,
|
|
469
|
+
exc_info=True,
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
return tracking_streamablehttp_client(
|
|
473
|
+
config.url,
|
|
474
|
+
headers,
|
|
475
|
+
auth=oauth_auth,
|
|
476
|
+
channel_hook=channel_hook,
|
|
477
|
+
)
|
|
405
478
|
else:
|
|
406
479
|
raise ValueError(f"Unsupported transport: {config.transport}")
|
|
407
480
|
|
|
@@ -412,6 +485,9 @@ class MCPConnectionManager(ContextDependent):
|
|
|
412
485
|
client_session_factory=client_session_factory,
|
|
413
486
|
)
|
|
414
487
|
|
|
488
|
+
if transport_metrics is not None:
|
|
489
|
+
server_conn.transport_metrics = transport_metrics
|
|
490
|
+
|
|
415
491
|
async with self._lock:
|
|
416
492
|
# Check if already running
|
|
417
493
|
if server_name in self.running_servers:
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from typing import TYPE_CHECKING, AsyncGenerator, Callable
|
|
6
|
+
|
|
7
|
+
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
8
|
+
|
|
9
|
+
from fast_agent.mcp.transport_tracking import ChannelEvent
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
|
13
|
+
from mcp.shared.message import SessionMessage
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
ChannelHook = Callable[[ChannelEvent], None]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@asynccontextmanager
|
|
21
|
+
async def tracking_stdio_client(
|
|
22
|
+
server_params: StdioServerParameters,
|
|
23
|
+
*,
|
|
24
|
+
channel_hook: ChannelHook | None = None,
|
|
25
|
+
errlog: Callable[[str], None] | None = None,
|
|
26
|
+
) -> AsyncGenerator[
|
|
27
|
+
tuple[ObjectReceiveStream[SessionMessage | Exception], ObjectSendStream[SessionMessage]], None
|
|
28
|
+
]:
|
|
29
|
+
"""Context manager for stdio client with basic connection tracking."""
|
|
30
|
+
|
|
31
|
+
def emit_channel_event(event_type: str, detail: str | None = None) -> None:
|
|
32
|
+
if channel_hook is None:
|
|
33
|
+
return
|
|
34
|
+
try:
|
|
35
|
+
channel_hook(
|
|
36
|
+
ChannelEvent(
|
|
37
|
+
channel="stdio",
|
|
38
|
+
event_type=event_type, # type: ignore[arg-type]
|
|
39
|
+
detail=detail,
|
|
40
|
+
)
|
|
41
|
+
)
|
|
42
|
+
except Exception: # pragma: no cover - hook errors must not break transport
|
|
43
|
+
logger.exception("Channel hook raised an exception")
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
# Emit connection event
|
|
47
|
+
emit_channel_event("connect")
|
|
48
|
+
|
|
49
|
+
# Use the original stdio_client without stream interception
|
|
50
|
+
async with stdio_client(server_params, errlog=errlog) as (read_stream, write_stream):
|
|
51
|
+
yield read_stream, write_stream
|
|
52
|
+
|
|
53
|
+
except Exception as exc:
|
|
54
|
+
# Emit error event
|
|
55
|
+
emit_channel_event("error", detail=str(exc))
|
|
56
|
+
raise
|
|
57
|
+
finally:
|
|
58
|
+
# Emit disconnection event
|
|
59
|
+
emit_channel_event("disconnect")
|