fast-agent-mcp 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.21.dist-info → fast_agent_mcp-0.2.23.dist-info}/METADATA +10 -8
- {fast_agent_mcp-0.2.21.dist-info → fast_agent_mcp-0.2.23.dist-info}/RECORD +25 -22
- mcp_agent/agents/workflow/orchestrator_agent.py +2 -2
- mcp_agent/cli/commands/go.py +136 -33
- mcp_agent/cli/commands/url_parser.py +185 -0
- mcp_agent/config.py +16 -1
- mcp_agent/core/fastagent.py +2 -2
- mcp_agent/core/request_params.py +11 -7
- mcp_agent/event_progress.py +1 -1
- mcp_agent/llm/augmented_llm.py +3 -9
- mcp_agent/llm/model_factory.py +8 -0
- mcp_agent/llm/provider_types.py +1 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +1 -0
- mcp_agent/llm/providers/augmented_llm_openai.py +14 -1
- mcp_agent/llm/providers/augmented_llm_tensorzero.py +442 -0
- mcp_agent/llm/providers/multipart_converter_tensorzero.py +200 -0
- mcp_agent/mcp/mcp_connection_manager.py +78 -10
- mcp_agent/mcp/prompts/prompt_server.py +12 -4
- mcp_agent/mcp_server/agent_server.py +13 -10
- mcp_agent/mcp_server_registry.py +51 -9
- mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +2 -2
- mcp_agent/ui/console_display.py +7 -6
- {fast_agent_mcp-0.2.21.dist-info → fast_agent_mcp-0.2.23.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.21.dist-info → fast_agent_mcp-0.2.23.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.21.dist-info → fast_agent_mcp-0.2.23.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,200 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
3
|
+
|
4
|
+
from mcp.types import (
|
5
|
+
CallToolResult,
|
6
|
+
EmbeddedResource,
|
7
|
+
ImageContent,
|
8
|
+
TextContent,
|
9
|
+
)
|
10
|
+
|
11
|
+
from mcp_agent.logging.logger import get_logger
|
12
|
+
from mcp_agent.mcp.helpers.content_helpers import (
|
13
|
+
get_text,
|
14
|
+
)
|
15
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
16
|
+
|
17
|
+
_logger = get_logger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class TensorZeroConverter:
    """Converts MCP message types to/from TensorZero (T0) API format.

    Tool results are correlated with the tool call that produced them through
    temporary attributes set on ``CallToolResult`` objects outside this class
    (``_t0_tool_use_id_temp``, ``_t0_tool_name_temp`` and, optionally,
    ``_t0_is_error_temp``). This converter reads them when building a
    ``tool_result`` block and then removes them from the result object.
    """

    # Image MIME types passed through unchanged; anything else is sent as PNG.
    _SUPPORTED_IMAGE_MIME_TYPES = ("image/jpeg", "image/png", "image/webp")

    # Temporary correlation attributes consumed (and deleted) by this converter.
    _TEMP_ATTRS = ("_t0_tool_use_id_temp", "_t0_tool_name_temp", "_t0_is_error_temp")

    @staticmethod
    def _convert_content_part(
        part: Union[TextContent, ImageContent, EmbeddedResource],
    ) -> Optional[Dict[str, Any]]:
        """Convert a single MCP content part to a T0 content block dictionary.

        Args:
            part: The MCP content part to convert.

        Returns:
            ``{"type": "text", "text": ...}`` for text, or
            ``{"type": "image", "mime_type": ..., "data": ...}`` for a base64
            image. Returns ``None`` if the part is unsupported or is missing
            required fields; the reason is logged.
        """
        if isinstance(part, TextContent):
            text = get_text(part)
            return {"type": "text", "text": text} if text is not None else None

        if isinstance(part, ImageContent):
            data = getattr(part, "data", None)
            mime_type = getattr(part, "mimeType", None)
            # Base64 images need both non-empty data and a non-empty mimeType.
            if not (data and mime_type):
                _logger.warning(
                    f"Skipping ImageContent: Missing required base64 fields (mimeType/data), or mimeType is empty: {part}"
                )
                return None

            _logger.debug(
                f"Converting ImageContent as base64 for T0 native: mime={mime_type}, data_len={len(data) if isinstance(data, str) else 'N/A'}"
            )
            if mime_type not in TensorZeroConverter._SUPPORTED_IMAGE_MIME_TYPES:
                _logger.warning(
                    f"Unsupported mimeType '{mime_type}' for T0 base64 image, defaulting to image/png."
                )
                mime_type = "image/png"

            # T0 uses "mime_type" (not "media_type") and takes the base64 data as-is.
            return {"type": "image", "mime_type": mime_type, "data": data}

        if isinstance(part, EmbeddedResource):
            _logger.warning(f"Skipping EmbeddedResource, T0 conversion not implemented: {part}")
            return None

        _logger.error(f"Unsupported content part type for T0 conversion: {type(part)}")
        return None

    @staticmethod
    def _get_text_from_call_tool_result(result: CallToolResult) -> str:
        """Join the non-empty text of every content part of *result* with newlines."""
        texts = (get_text(part) for part in result.content or [])
        return "\n".join(text for text in texts if text)

    @staticmethod
    def _parse_tool_result_content(content: str) -> Any:
        """Return *content* decoded as JSON, or the raw string if it is not valid JSON."""
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            return content
        except Exception as e:  # e.g. RecursionError on pathologically nested input
            _logger.error(f"Unexpected error processing tool result content: {e}")
            return str(content)

    @staticmethod
    def _clear_temp_attrs(result: CallToolResult) -> None:
        """Remove all temporary T0 correlation attributes from *result*, if present."""
        for attr in TensorZeroConverter._TEMP_ATTRS:
            try:
                delattr(result, attr)
            except AttributeError:
                pass  # attribute was never set; nothing to clean up

    @staticmethod
    def _build_tool_result_block(result: CallToolResult) -> Optional[Dict[str, Any]]:
        """Build a T0 ``tool_result`` block from *result* and consume its temp attributes.

        Returns:
            The block dict, or ``None`` (leaving *result* untouched) when the
            ``_t0_tool_use_id_temp``/``_t0_tool_name_temp`` attributes are
            missing or empty.
        """
        tool_use_id = getattr(result, "_t0_tool_use_id_temp", None)
        tool_name = getattr(result, "_t0_tool_name_temp", None)
        if not (tool_use_id and tool_name):
            return None

        content = TensorZeroConverter._get_text_from_call_tool_result(result)
        block = {
            "type": "tool_result",
            "id": tool_use_id,
            "name": tool_name,
            # T0 expects the result value directly (parsed JSON when possible).
            "result": TensorZeroConverter._parse_tool_result_content(content),
        }
        TensorZeroConverter._clear_temp_attrs(result)
        return block

    @staticmethod
    def convert_tool_results_to_t0_user_message(
        results: List[CallToolResult],
    ) -> Optional[Dict[str, Any]]:
        """Format a list of CallToolResults as T0 ``tool_result`` blocks in a user message.

        Results lacking the temporary id/name attributes are skipped with a
        warning.

        Returns:
            ``{"role": "user", "content": [...]}``, or ``None`` if no result
            could be converted.
        """
        t0_tool_result_blocks = []
        for result in results:
            block = TensorZeroConverter._build_tool_result_block(result)
            if block is None:
                _logger.warning(
                    f"Could not find id/name temp attributes for CallToolResult: {result}"
                )
                continue
            t0_tool_result_blocks.append(block)

        if not t0_tool_result_blocks:
            return None

        return {"role": "user", "content": t0_tool_result_blocks}

    @staticmethod
    def convert_mcp_to_t0_message(msg: PromptMessageMultipart) -> Optional[Dict[str, Any]]:
        """Convert a single PromptMessageMultipart to a T0 API message dictionary.

        Handles text, image, and embedded CallToolResult content. System
        messages are skipped.

        Returns:
            ``{"role": ..., "content": [...]}``, or ``None`` for system
            messages and messages with no convertible content.
        """
        if msg.role == "system":
            return None

        t0_content_blocks = []
        contains_tool_result = False

        for part in msg.content:
            if isinstance(part, CallToolResult):
                # Check this first: _convert_content_part does not know about
                # CallToolResult and would log a spurious "unsupported" error.
                contains_tool_result = True
                block = TensorZeroConverter._build_tool_result_block(part)
                if block is None:
                    _logger.warning(
                        f"Found embedded CallToolResult without required temp attributes: {part}"
                    )
                    continue
                t0_content_blocks.append(block)
                continue

            # _convert_content_part logs the reason for any skipped/unsupported part.
            converted_block = TensorZeroConverter._convert_content_part(part)
            if converted_block:
                t0_content_blocks.append(converted_block)

        if not t0_content_blocks:
            return None

        # T0 only accepts "user"/"assistant"; other roles are mapped to "user".
        valid_role = msg.role if msg.role in ["user", "assistant"] else "user"
        if contains_tool_result and all(
            block.get("type") == "tool_result" for block in t0_content_blocks
        ):
            # Messages consisting solely of tool results must be sent as "user".
            final_role = "user"
            if valid_role != final_role:
                _logger.debug(f"Overriding role to '{final_role}' for tool result message.")
        else:
            final_role = valid_role
            if valid_role != msg.role:
                _logger.warning(f"Mapping message role '{msg.role}' to '{valid_role}' for T0.")

        return {"role": final_role, "content": t0_content_blocks}

    # Conversion *from* T0 format back to MCP types (e.g. adapting T0
    # responses) currently lives in the LLM class, not here.
|
@@ -15,6 +15,7 @@ from typing import (
|
|
15
15
|
|
16
16
|
from anyio import Event, Lock, create_task_group
|
17
17
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
18
|
+
from httpx import HTTPStatusError
|
18
19
|
from mcp import ClientSession
|
19
20
|
from mcp.client.sse import sse_client
|
20
21
|
from mcp.client.stdio import (
|
@@ -22,6 +23,7 @@ from mcp.client.stdio import (
|
|
22
23
|
get_default_environment,
|
23
24
|
stdio_client,
|
24
25
|
)
|
26
|
+
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
25
27
|
from mcp.types import JSONRPCMessage, ServerCapabilities
|
26
28
|
|
27
29
|
from mcp_agent.config import MCPServerSettings
|
@@ -39,6 +41,27 @@ if TYPE_CHECKING:
|
|
39
41
|
logger = get_logger(__name__)
|
40
42
|
|
41
43
|
|
44
|
+
class StreamingContextAdapter:
    """Adapt an async context manager yielding ``(read, write)`` to yield ``(read, write, None)``.

    This lets two-stream transports (as wrapped here for ``stdio_client`` and
    ``sse_client``) be unpacked with the same three-value pattern used for
    ``streamablehttp_client``, whose third value is a session-id callback.
    """

    def __init__(self, context_manager):
        # The wrapped async context manager; entered lazily in __aenter__.
        self.context_manager = context_manager
        # Raw value produced by the wrapped context manager's __aenter__.
        self.cm_instance = None

    async def __aenter__(self):
        """Enter the wrapped context and return ``(read_stream, write_stream, None)``.

        Raises:
            TypeError / ValueError: if the wrapped context does not yield
                exactly two values. The wrapped context is exited first so
                the underlying transport is not leaked.
        """
        self.cm_instance = await self.context_manager.__aenter__()
        try:
            read_stream, write_stream = self.cm_instance
        except (TypeError, ValueError) as exc:
            # We already entered the inner context; since our own __aexit__
            # will never run, close it here before propagating the error.
            await self.context_manager.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        return read_stream, write_stream, None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped context, propagating its suppression decision."""
        return await self.context_manager.__aexit__(exc_type, exc_val, exc_tb)
|
58
|
+
|
59
|
+
|
60
|
+
def _add_none_to_context(context_manager):
    """Wrap a two-value async context manager so it yields a third ``None`` element.

    Args:
        context_manager: An async context manager whose ``__aenter__``
            produces ``(read_stream, write_stream)``.

    Returns:
        A ``StreamingContextAdapter`` around *context_manager*.
    """
    adapter = StreamingContextAdapter(context_manager)
    return adapter
|
63
|
+
|
64
|
+
|
42
65
|
class ServerConnection:
|
43
66
|
"""
|
44
67
|
Represents a long-lived MCP server connection, including:
|
@@ -56,6 +79,7 @@ class ServerConnection:
|
|
56
79
|
tuple[
|
57
80
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
58
81
|
MemoryObjectSendStream[JSONRPCMessage],
|
82
|
+
GetSessionIdCallback | None,
|
59
83
|
],
|
60
84
|
None,
|
61
85
|
],
|
@@ -161,15 +185,27 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
|
|
161
185
|
try:
|
162
186
|
transport_context = server_conn._transport_context_factory()
|
163
187
|
|
164
|
-
async with transport_context as (read_stream, write_stream):
|
165
|
-
# try:
|
188
|
+
async with transport_context as (read_stream, write_stream, _):
|
166
189
|
server_conn.create_session(read_stream, write_stream)
|
167
190
|
|
168
191
|
async with server_conn.session:
|
169
192
|
await server_conn.initialize_session()
|
170
|
-
|
171
193
|
await server_conn.wait_for_shutdown_request()
|
172
194
|
|
195
|
+
except HTTPStatusError as http_exc:
|
196
|
+
logger.error(
|
197
|
+
f"{server_name}: Lifecycle task encountered HTTP error: {http_exc}",
|
198
|
+
exc_info=True,
|
199
|
+
data={
|
200
|
+
"progress_action": ProgressAction.FATAL_ERROR,
|
201
|
+
"server_name": server_name,
|
202
|
+
},
|
203
|
+
)
|
204
|
+
server_conn._error_occurred = True
|
205
|
+
server_conn._error_message = f"HTTP Error: {http_exc.response.status_code} {http_exc.response.reason_phrase} for URL: {http_exc.request.url}"
|
206
|
+
server_conn._initialized_event.set()
|
207
|
+
# No raise - let get_server handle it with a friendly message
|
208
|
+
|
173
209
|
except Exception as exc:
|
174
210
|
logger.error(
|
175
211
|
f"{server_name}: Lifecycle task encountered an error: {exc}",
|
@@ -180,7 +216,27 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None:
|
|
180
216
|
},
|
181
217
|
)
|
182
218
|
server_conn._error_occurred = True
|
183
|
-
|
219
|
+
|
220
|
+
if "ExceptionGroup" in type(exc).__name__ and hasattr(exc, "exceptions"):
|
221
|
+
# Handle ExceptionGroup better by extracting the actual errors
|
222
|
+
error_messages = []
|
223
|
+
for subexc in exc.exceptions:
|
224
|
+
if isinstance(subexc, HTTPStatusError):
|
225
|
+
# Special handling for HTTP errors to make them more user-friendly
|
226
|
+
error_messages.append(
|
227
|
+
f"HTTP Error: {subexc.response.status_code} {subexc.response.reason_phrase} for URL: {subexc.request.url}"
|
228
|
+
)
|
229
|
+
else:
|
230
|
+
error_messages.append(f"Error: {type(subexc).__name__}: {subexc}")
|
231
|
+
if hasattr(subexc, "__cause__") and subexc.__cause__:
|
232
|
+
error_messages.append(
|
233
|
+
f"Caused by: {type(subexc.__cause__).__name__}: {subexc.__cause__}"
|
234
|
+
)
|
235
|
+
server_conn._error_message = error_messages
|
236
|
+
else:
|
237
|
+
# For regular exceptions, keep the traceback but format it more cleanly
|
238
|
+
server_conn._error_message = traceback.format_exception(exc)
|
239
|
+
|
184
240
|
# If there's an error, we should also set the event so that
|
185
241
|
# 'get_server' won't hang
|
186
242
|
server_conn._initialized_event.set()
|
@@ -270,13 +326,17 @@ class MCPConnectionManager(ContextDependent):
|
|
270
326
|
error_handler = get_stderr_handler(server_name)
|
271
327
|
# Explicitly ensure we're using our custom logger for stderr
|
272
328
|
logger.debug(f"{server_name}: Creating stdio client with custom error handler")
|
273
|
-
return stdio_client(server_params, errlog=error_handler)
|
329
|
+
return _add_none_to_context(stdio_client(server_params, errlog=error_handler))
|
274
330
|
elif config.transport == "sse":
|
275
|
-
return
|
276
|
-
|
277
|
-
|
278
|
-
|
331
|
+
return _add_none_to_context(
|
332
|
+
sse_client(
|
333
|
+
config.url,
|
334
|
+
config.headers,
|
335
|
+
sse_read_timeout=config.read_transport_sse_timeout_seconds,
|
336
|
+
)
|
279
337
|
)
|
338
|
+
elif config.transport == "http":
|
339
|
+
return streamablehttp_client(config.url, config.headers)
|
280
340
|
else:
|
281
341
|
raise ValueError(f"Unsupported transport: {config.transport}")
|
282
342
|
|
@@ -333,9 +393,17 @@ class MCPConnectionManager(ContextDependent):
|
|
333
393
|
# Check if the server is healthy after initialization
|
334
394
|
if not server_conn.is_healthy():
|
335
395
|
error_msg = server_conn._error_message or "Unknown error"
|
396
|
+
|
397
|
+
# Format the error message for better display
|
398
|
+
if isinstance(error_msg, list):
|
399
|
+
# Join the list with newlines for better readability
|
400
|
+
formatted_error = "\n".join(error_msg)
|
401
|
+
else:
|
402
|
+
formatted_error = str(error_msg)
|
403
|
+
|
336
404
|
raise ServerInitializationError(
|
337
405
|
f"MCP Server: '{server_name}': Failed to initialize - see details. Check fastagent.config.yaml?",
|
338
|
-
|
406
|
+
formatted_error,
|
339
407
|
)
|
340
408
|
|
341
409
|
return server_conn
|
@@ -335,7 +335,7 @@ def parse_args():
|
|
335
335
|
parser.add_argument(
|
336
336
|
"--transport",
|
337
337
|
type=str,
|
338
|
-
choices=["stdio", "sse"],
|
338
|
+
choices=["stdio", "sse", "http"],
|
339
339
|
default="stdio",
|
340
340
|
help="Transport to use (default: stdio)",
|
341
341
|
)
|
@@ -502,14 +502,22 @@ async def async_main() -> int:
|
|
502
502
|
return await test_prompt(args.test, config)
|
503
503
|
|
504
504
|
# Start the server with the specified transport
|
505
|
-
if config.transport == "
|
506
|
-
await mcp.run_stdio_async()
|
507
|
-
else: # sse
|
505
|
+
if config.transport == "sse": # sse
|
508
506
|
# Set the host and port in settings before running the server
|
509
507
|
mcp.settings.host = config.host
|
510
508
|
mcp.settings.port = config.port
|
511
509
|
logger.info(f"Starting SSE server on {config.host}:{config.port}")
|
512
510
|
await mcp.run_sse_async()
|
511
|
+
elif config.transport == "http":
|
512
|
+
mcp.settings.host = config.host
|
513
|
+
mcp.settings.port = config.port
|
514
|
+
logger.info(f"Starting SSE server on {config.host}:{config.port}")
|
515
|
+
await mcp.run_streamable_http_async()
|
516
|
+
elif config.transport == "stdio":
|
517
|
+
await mcp.run_stdio_async()
|
518
|
+
else:
|
519
|
+
logger.error(f"Unknown transport: {config.transport}")
|
520
|
+
return 1
|
513
521
|
return 0
|
514
522
|
|
515
523
|
|
@@ -140,9 +140,9 @@ class AgentMCPServer:
|
|
140
140
|
print("Press Ctrl+C again to force exit.")
|
141
141
|
self._graceful_shutdown_event.set()
|
142
142
|
|
143
|
-
def run(self, transport: str = "
|
143
|
+
def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) -> None:
|
144
144
|
"""Run the MCP server synchronously."""
|
145
|
-
if transport
|
145
|
+
if transport in ["sse", "http"]:
|
146
146
|
self.mcp_server.settings.host = host
|
147
147
|
self.mcp_server.settings.port = port
|
148
148
|
|
@@ -180,12 +180,12 @@ class AgentMCPServer:
|
|
180
180
|
asyncio.run(self._cleanup_stdio())
|
181
181
|
|
182
182
|
async def run_async(
|
183
|
-
self, transport: str = "
|
183
|
+
self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000
|
184
184
|
) -> None:
|
185
185
|
"""Run the MCP server asynchronously with improved shutdown handling."""
|
186
186
|
# Use different handling strategies based on transport type
|
187
|
-
if transport
|
188
|
-
# For SSE, use our enhanced shutdown handling
|
187
|
+
if transport in ["sse", "http"]:
|
188
|
+
# For SSE/HTTP, use our enhanced shutdown handling
|
189
189
|
self._setup_signal_handlers()
|
190
190
|
|
191
191
|
self.mcp_server.settings.host = host
|
@@ -236,9 +236,9 @@ class AgentMCPServer:
|
|
236
236
|
|
237
237
|
async def _run_server_with_shutdown(self, transport: str):
|
238
238
|
"""Run the server with proper shutdown handling."""
|
239
|
-
# This method is
|
240
|
-
if transport
|
241
|
-
raise ValueError("This method should only be used with SSE transport")
|
239
|
+
# This method is used for SSE/HTTP transport
|
240
|
+
if transport not in ["sse", "http"]:
|
241
|
+
raise ValueError("This method should only be used with SSE or HTTP transport")
|
242
242
|
|
243
243
|
# Start a monitor task for shutdown
|
244
244
|
shutdown_monitor = asyncio.create_task(self._monitor_shutdown())
|
@@ -262,8 +262,11 @@ class AgentMCPServer:
|
|
262
262
|
# Replace with our tracking version
|
263
263
|
self.mcp_server._sse_transport.connect_sse = tracked_connect_sse
|
264
264
|
|
265
|
-
# Run the server
|
266
|
-
|
265
|
+
# Run the server based on transport type
|
266
|
+
if transport == "sse":
|
267
|
+
await self.mcp_server.run_sse_async()
|
268
|
+
elif transport == "http":
|
269
|
+
await self.mcp_server.run_streamable_http_async()
|
267
270
|
finally:
|
268
271
|
# Cancel the monitor when the server exits
|
269
272
|
shutdown_monitor.cancel()
|
mcp_agent/mcp_server_registry.py
CHANGED
@@ -18,6 +18,7 @@ from mcp.client.stdio import (
|
|
18
18
|
StdioServerParameters,
|
19
19
|
get_default_environment,
|
20
20
|
)
|
21
|
+
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
|
21
22
|
|
22
23
|
from mcp_agent.config import (
|
23
24
|
MCPServerAuthSettings,
|
@@ -27,7 +28,10 @@ from mcp_agent.config import (
|
|
27
28
|
)
|
28
29
|
from mcp_agent.logging.logger import get_logger
|
29
30
|
from mcp_agent.mcp.logger_textio import get_stderr_handler
|
30
|
-
from mcp_agent.mcp.mcp_connection_manager import
|
31
|
+
from mcp_agent.mcp.mcp_connection_manager import (
|
32
|
+
MCPConnectionManager,
|
33
|
+
_add_none_to_context,
|
34
|
+
)
|
31
35
|
|
32
36
|
logger = get_logger(__name__)
|
33
37
|
|
@@ -93,7 +97,12 @@ class ServerRegistry:
|
|
93
97
|
self,
|
94
98
|
server_name: str,
|
95
99
|
client_session_factory: Callable[
|
96
|
-
[
|
100
|
+
[
|
101
|
+
MemoryObjectReceiveStream,
|
102
|
+
MemoryObjectSendStream,
|
103
|
+
timedelta | None,
|
104
|
+
GetSessionIdCallback | None,
|
105
|
+
],
|
97
106
|
ClientSession,
|
98
107
|
] = ClientSession,
|
99
108
|
) -> AsyncGenerator[ClientSession, None]:
|
@@ -132,14 +141,18 @@ class ServerRegistry:
|
|
132
141
|
)
|
133
142
|
|
134
143
|
# Create a stderr handler that logs to our application logger
|
135
|
-
async with
|
144
|
+
async with _add_none_to_context(
|
145
|
+
stdio_client(server_params, errlog=get_stderr_handler(server_name))
|
146
|
+
) as (
|
136
147
|
read_stream,
|
137
148
|
write_stream,
|
149
|
+
_,
|
138
150
|
):
|
139
151
|
session = client_session_factory(
|
140
152
|
read_stream,
|
141
153
|
write_stream,
|
142
154
|
read_timeout_seconds,
|
155
|
+
None, # No callback for stdio
|
143
156
|
)
|
144
157
|
async with session:
|
145
158
|
logger.info(f"{server_name}: Connected to server using stdio transport.")
|
@@ -153,15 +166,18 @@ class ServerRegistry:
|
|
153
166
|
raise ValueError(f"URL is required for SSE transport: {server_name}")
|
154
167
|
|
155
168
|
# Use sse_client to get the read and write streams
|
156
|
-
async with
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
169
|
+
async with _add_none_to_context(
|
170
|
+
sse_client(
|
171
|
+
config.url,
|
172
|
+
config.headers,
|
173
|
+
sse_read_timeout=config.read_transport_sse_timeout_seconds,
|
174
|
+
)
|
175
|
+
) as (read_stream, write_stream, _):
|
161
176
|
session = client_session_factory(
|
162
177
|
read_stream,
|
163
178
|
write_stream,
|
164
179
|
read_timeout_seconds,
|
180
|
+
None, # No callback for stdio
|
165
181
|
)
|
166
182
|
async with session:
|
167
183
|
logger.info(f"{server_name}: Connected to server using SSE transport.")
|
@@ -169,6 +185,27 @@ class ServerRegistry:
|
|
169
185
|
yield session
|
170
186
|
finally:
|
171
187
|
logger.debug(f"{server_name}: Closed session to server")
|
188
|
+
elif config.transport == "http":
|
189
|
+
if not config.url:
|
190
|
+
raise ValueError(f"URL is required for SSE transport: {server_name}")
|
191
|
+
|
192
|
+
async with streamablehttp_client(config.url, config.headers) as (
|
193
|
+
read_stream,
|
194
|
+
write_stream,
|
195
|
+
_,
|
196
|
+
):
|
197
|
+
session = client_session_factory(
|
198
|
+
read_stream,
|
199
|
+
write_stream,
|
200
|
+
read_timeout_seconds,
|
201
|
+
None, # No callback for stdio
|
202
|
+
)
|
203
|
+
async with session:
|
204
|
+
logger.info(f"{server_name}: Connected to server using HTTP transport.")
|
205
|
+
try:
|
206
|
+
yield session
|
207
|
+
finally:
|
208
|
+
logger.debug(f"{server_name}: Closed session to server")
|
172
209
|
|
173
210
|
# Unsupported transport
|
174
211
|
else:
|
@@ -179,7 +216,12 @@ class ServerRegistry:
|
|
179
216
|
self,
|
180
217
|
server_name: str,
|
181
218
|
client_session_factory: Callable[
|
182
|
-
[
|
219
|
+
[
|
220
|
+
MemoryObjectReceiveStream,
|
221
|
+
MemoryObjectSendStream,
|
222
|
+
timedelta | None,
|
223
|
+
GetSessionIdCallback,
|
224
|
+
],
|
183
225
|
ClientSession,
|
184
226
|
] = ClientSession,
|
185
227
|
init_hook: InitHookCallable = None,
|
mcp_agent/ui/console_display.py
CHANGED
@@ -25,6 +25,7 @@ class ConsoleDisplay:
|
|
25
25
|
config: Configuration object containing display preferences
|
26
26
|
"""
|
27
27
|
self.config = config
|
28
|
+
self._markup = config.logger.enable_markup if config else True
|
28
29
|
|
29
30
|
def show_tool_result(self, result: CallToolResult) -> None:
|
30
31
|
"""Display a tool result in a formatted panel."""
|
@@ -46,7 +47,7 @@ class ConsoleDisplay:
|
|
46
47
|
if len(str(result.content)) > 360:
|
47
48
|
panel.height = 8
|
48
49
|
|
49
|
-
console.console.print(panel)
|
50
|
+
console.console.print(panel, markup=self._markup)
|
50
51
|
console.console.print("\n")
|
51
52
|
|
52
53
|
def show_oai_tool_result(self, result) -> None:
|
@@ -67,7 +68,7 @@ class ConsoleDisplay:
|
|
67
68
|
if len(str(result)) > 360:
|
68
69
|
panel.height = 8
|
69
70
|
|
70
|
-
console.console.print(panel)
|
71
|
+
console.console.print(panel, markup=self._markup)
|
71
72
|
console.console.print("\n")
|
72
73
|
|
73
74
|
def show_tool_call(self, available_tools, tool_name, tool_args) -> None:
|
@@ -92,7 +93,7 @@ class ConsoleDisplay:
|
|
92
93
|
if len(str(tool_args)) > 360:
|
93
94
|
panel.height = 8
|
94
95
|
|
95
|
-
console.console.print(panel)
|
96
|
+
console.console.print(panel, markup=self._markup)
|
96
97
|
console.console.print("\n")
|
97
98
|
|
98
99
|
def _format_tool_list(self, available_tools, selected_tool_name):
|
@@ -172,7 +173,7 @@ class ConsoleDisplay:
|
|
172
173
|
subtitle=display_server_list,
|
173
174
|
subtitle_align="left",
|
174
175
|
)
|
175
|
-
console.console.print(panel)
|
176
|
+
console.console.print(panel, markup=self._markup)
|
176
177
|
console.console.print("\n")
|
177
178
|
|
178
179
|
def show_user_message(
|
@@ -196,7 +197,7 @@ class ConsoleDisplay:
|
|
196
197
|
subtitle=subtitle_text,
|
197
198
|
subtitle_align="left",
|
198
199
|
)
|
199
|
-
console.console.print(panel)
|
200
|
+
console.console.print(panel, markup=self._markup)
|
200
201
|
console.console.print("\n")
|
201
202
|
|
202
203
|
async def show_prompt_loaded(
|
@@ -270,5 +271,5 @@ class ConsoleDisplay:
|
|
270
271
|
subtitle_align="left",
|
271
272
|
)
|
272
273
|
|
273
|
-
console.console.print(panel)
|
274
|
+
console.console.print(panel, markup=self._markup)
|
274
275
|
console.console.print("\n")
|
File without changes
|
File without changes
|
File without changes
|