atlas-chat 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +40 -0
- atlas/application/__init__.py +7 -0
- atlas/application/chat/__init__.py +7 -0
- atlas/application/chat/agent/__init__.py +10 -0
- atlas/application/chat/agent/act_loop.py +179 -0
- atlas/application/chat/agent/factory.py +142 -0
- atlas/application/chat/agent/protocols.py +46 -0
- atlas/application/chat/agent/react_loop.py +338 -0
- atlas/application/chat/agent/think_act_loop.py +171 -0
- atlas/application/chat/approval_manager.py +151 -0
- atlas/application/chat/elicitation_manager.py +191 -0
- atlas/application/chat/events/__init__.py +1 -0
- atlas/application/chat/events/agent_event_relay.py +112 -0
- atlas/application/chat/modes/__init__.py +1 -0
- atlas/application/chat/modes/agent.py +125 -0
- atlas/application/chat/modes/plain.py +74 -0
- atlas/application/chat/modes/rag.py +81 -0
- atlas/application/chat/modes/tools.py +179 -0
- atlas/application/chat/orchestrator.py +213 -0
- atlas/application/chat/policies/__init__.py +1 -0
- atlas/application/chat/policies/tool_authorization.py +99 -0
- atlas/application/chat/preprocessors/__init__.py +1 -0
- atlas/application/chat/preprocessors/message_builder.py +92 -0
- atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
- atlas/application/chat/service.py +454 -0
- atlas/application/chat/utilities/__init__.py +6 -0
- atlas/application/chat/utilities/error_handler.py +367 -0
- atlas/application/chat/utilities/event_notifier.py +546 -0
- atlas/application/chat/utilities/file_processor.py +613 -0
- atlas/application/chat/utilities/tool_executor.py +789 -0
- atlas/atlas_chat_cli.py +347 -0
- atlas/atlas_client.py +238 -0
- atlas/core/__init__.py +0 -0
- atlas/core/auth.py +205 -0
- atlas/core/authorization_manager.py +27 -0
- atlas/core/capabilities.py +123 -0
- atlas/core/compliance.py +215 -0
- atlas/core/domain_whitelist.py +147 -0
- atlas/core/domain_whitelist_middleware.py +82 -0
- atlas/core/http_client.py +28 -0
- atlas/core/log_sanitizer.py +102 -0
- atlas/core/metrics_logger.py +59 -0
- atlas/core/middleware.py +131 -0
- atlas/core/otel_config.py +242 -0
- atlas/core/prompt_risk.py +200 -0
- atlas/core/rate_limit.py +0 -0
- atlas/core/rate_limit_middleware.py +64 -0
- atlas/core/security_headers_middleware.py +51 -0
- atlas/domain/__init__.py +37 -0
- atlas/domain/chat/__init__.py +1 -0
- atlas/domain/chat/dtos.py +85 -0
- atlas/domain/errors.py +96 -0
- atlas/domain/messages/__init__.py +12 -0
- atlas/domain/messages/models.py +160 -0
- atlas/domain/rag_mcp_service.py +664 -0
- atlas/domain/sessions/__init__.py +7 -0
- atlas/domain/sessions/models.py +36 -0
- atlas/domain/unified_rag_service.py +371 -0
- atlas/infrastructure/__init__.py +10 -0
- atlas/infrastructure/app_factory.py +135 -0
- atlas/infrastructure/events/__init__.py +1 -0
- atlas/infrastructure/events/cli_event_publisher.py +140 -0
- atlas/infrastructure/events/websocket_publisher.py +140 -0
- atlas/infrastructure/sessions/in_memory_repository.py +56 -0
- atlas/infrastructure/transport/__init__.py +7 -0
- atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
- atlas/init_cli.py +226 -0
- atlas/interfaces/__init__.py +15 -0
- atlas/interfaces/events.py +134 -0
- atlas/interfaces/llm.py +54 -0
- atlas/interfaces/rag.py +40 -0
- atlas/interfaces/sessions.py +75 -0
- atlas/interfaces/tools.py +57 -0
- atlas/interfaces/transport.py +24 -0
- atlas/main.py +564 -0
- atlas/mcp/api_key_demo/README.md +76 -0
- atlas/mcp/api_key_demo/main.py +172 -0
- atlas/mcp/api_key_demo/run.sh +56 -0
- atlas/mcp/basictable/main.py +147 -0
- atlas/mcp/calculator/main.py +149 -0
- atlas/mcp/code-executor/execution_engine.py +98 -0
- atlas/mcp/code-executor/execution_environment.py +95 -0
- atlas/mcp/code-executor/main.py +528 -0
- atlas/mcp/code-executor/result_processing.py +276 -0
- atlas/mcp/code-executor/script_generation.py +195 -0
- atlas/mcp/code-executor/security_checker.py +140 -0
- atlas/mcp/corporate_cars/main.py +437 -0
- atlas/mcp/csv_reporter/main.py +545 -0
- atlas/mcp/duckduckgo/main.py +182 -0
- atlas/mcp/elicitation_demo/README.md +171 -0
- atlas/mcp/elicitation_demo/main.py +262 -0
- atlas/mcp/env-demo/README.md +158 -0
- atlas/mcp/env-demo/main.py +199 -0
- atlas/mcp/file_size_test/main.py +284 -0
- atlas/mcp/filesystem/main.py +348 -0
- atlas/mcp/image_demo/main.py +113 -0
- atlas/mcp/image_demo/requirements.txt +4 -0
- atlas/mcp/logging_demo/README.md +72 -0
- atlas/mcp/logging_demo/main.py +103 -0
- atlas/mcp/many_tools_demo/main.py +50 -0
- atlas/mcp/order_database/__init__.py +0 -0
- atlas/mcp/order_database/main.py +369 -0
- atlas/mcp/order_database/signal_data.csv +1001 -0
- atlas/mcp/pdfbasic/main.py +394 -0
- atlas/mcp/pptx_generator/main.py +760 -0
- atlas/mcp/pptx_generator/requirements.txt +13 -0
- atlas/mcp/pptx_generator/run_test.sh +1 -0
- atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
- atlas/mcp/progress_demo/main.py +167 -0
- atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
- atlas/mcp/progress_updates_demo/README.md +120 -0
- atlas/mcp/progress_updates_demo/main.py +497 -0
- atlas/mcp/prompts/main.py +222 -0
- atlas/mcp/public_demo/main.py +189 -0
- atlas/mcp/sampling_demo/README.md +169 -0
- atlas/mcp/sampling_demo/main.py +234 -0
- atlas/mcp/thinking/main.py +77 -0
- atlas/mcp/tool_planner/main.py +240 -0
- atlas/mcp/ui-demo/badmesh.png +0 -0
- atlas/mcp/ui-demo/main.py +383 -0
- atlas/mcp/ui-demo/templates/button_demo.html +32 -0
- atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
- atlas/mcp/ui-demo/templates/form_demo.html +28 -0
- atlas/mcp/username-override-demo/README.md +320 -0
- atlas/mcp/username-override-demo/main.py +308 -0
- atlas/modules/__init__.py +0 -0
- atlas/modules/config/__init__.py +34 -0
- atlas/modules/config/cli.py +231 -0
- atlas/modules/config/config_manager.py +1096 -0
- atlas/modules/file_storage/__init__.py +22 -0
- atlas/modules/file_storage/cli.py +330 -0
- atlas/modules/file_storage/content_extractor.py +290 -0
- atlas/modules/file_storage/manager.py +295 -0
- atlas/modules/file_storage/mock_s3_client.py +402 -0
- atlas/modules/file_storage/s3_client.py +417 -0
- atlas/modules/llm/__init__.py +19 -0
- atlas/modules/llm/caller.py +287 -0
- atlas/modules/llm/litellm_caller.py +675 -0
- atlas/modules/llm/models.py +19 -0
- atlas/modules/mcp_tools/__init__.py +17 -0
- atlas/modules/mcp_tools/client.py +2123 -0
- atlas/modules/mcp_tools/token_storage.py +556 -0
- atlas/modules/prompts/prompt_provider.py +130 -0
- atlas/modules/rag/__init__.py +24 -0
- atlas/modules/rag/atlas_rag_client.py +336 -0
- atlas/modules/rag/client.py +129 -0
- atlas/routes/admin_routes.py +865 -0
- atlas/routes/config_routes.py +484 -0
- atlas/routes/feedback_routes.py +361 -0
- atlas/routes/files_routes.py +274 -0
- atlas/routes/health_routes.py +40 -0
- atlas/routes/mcp_auth_routes.py +223 -0
- atlas/server_cli.py +164 -0
- atlas/tests/conftest.py +20 -0
- atlas/tests/integration/test_mcp_auth_integration.py +152 -0
- atlas/tests/manual_test_sampling.py +87 -0
- atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
- atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
- atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
- atlas/tests/test_agent_roa.py +135 -0
- atlas/tests/test_app_factory_smoke.py +47 -0
- atlas/tests/test_approval_manager.py +439 -0
- atlas/tests/test_atlas_client.py +188 -0
- atlas/tests/test_atlas_rag_client.py +447 -0
- atlas/tests/test_atlas_rag_integration.py +224 -0
- atlas/tests/test_attach_file_flow.py +287 -0
- atlas/tests/test_auth_utils.py +165 -0
- atlas/tests/test_backend_public_url.py +185 -0
- atlas/tests/test_banner_logging.py +287 -0
- atlas/tests/test_capability_tokens_and_injection.py +203 -0
- atlas/tests/test_compliance_level.py +54 -0
- atlas/tests/test_compliance_manager.py +253 -0
- atlas/tests/test_config_manager.py +617 -0
- atlas/tests/test_config_manager_paths.py +12 -0
- atlas/tests/test_core_auth.py +18 -0
- atlas/tests/test_core_utils.py +190 -0
- atlas/tests/test_docker_env_sync.py +202 -0
- atlas/tests/test_domain_errors.py +329 -0
- atlas/tests/test_domain_whitelist.py +359 -0
- atlas/tests/test_elicitation_manager.py +408 -0
- atlas/tests/test_elicitation_routing.py +296 -0
- atlas/tests/test_env_demo_server.py +88 -0
- atlas/tests/test_error_classification.py +113 -0
- atlas/tests/test_error_flow_integration.py +116 -0
- atlas/tests/test_feedback_routes.py +333 -0
- atlas/tests/test_file_content_extraction.py +1134 -0
- atlas/tests/test_file_extraction_routes.py +158 -0
- atlas/tests/test_file_library.py +107 -0
- atlas/tests/test_file_manager_unit.py +18 -0
- atlas/tests/test_health_route.py +49 -0
- atlas/tests/test_http_client_stub.py +8 -0
- atlas/tests/test_imports_smoke.py +30 -0
- atlas/tests/test_interfaces_llm_response.py +9 -0
- atlas/tests/test_issue_access_denied_fix.py +136 -0
- atlas/tests/test_llm_env_expansion.py +836 -0
- atlas/tests/test_log_level_sensitive_data.py +285 -0
- atlas/tests/test_mcp_auth_routes.py +341 -0
- atlas/tests/test_mcp_client_auth.py +331 -0
- atlas/tests/test_mcp_data_injection.py +270 -0
- atlas/tests/test_mcp_get_authorized_servers.py +95 -0
- atlas/tests/test_mcp_hot_reload.py +512 -0
- atlas/tests/test_mcp_image_content.py +424 -0
- atlas/tests/test_mcp_logging.py +172 -0
- atlas/tests/test_mcp_progress_updates.py +313 -0
- atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
- atlas/tests/test_mcp_prompts_server.py +39 -0
- atlas/tests/test_mcp_tool_result_parsing.py +296 -0
- atlas/tests/test_metrics_logger.py +56 -0
- atlas/tests/test_middleware_auth.py +379 -0
- atlas/tests/test_prompt_risk_and_acl.py +141 -0
- atlas/tests/test_rag_mcp_aggregator.py +204 -0
- atlas/tests/test_rag_mcp_service.py +224 -0
- atlas/tests/test_rate_limit_middleware.py +45 -0
- atlas/tests/test_routes_config_smoke.py +60 -0
- atlas/tests/test_routes_files_download_token.py +41 -0
- atlas/tests/test_routes_files_health.py +18 -0
- atlas/tests/test_runtime_imports.py +53 -0
- atlas/tests/test_sampling_integration.py +482 -0
- atlas/tests/test_security_admin_routes.py +61 -0
- atlas/tests/test_security_capability_tokens.py +65 -0
- atlas/tests/test_security_file_stats_scope.py +21 -0
- atlas/tests/test_security_header_injection.py +191 -0
- atlas/tests/test_security_headers_and_filename.py +63 -0
- atlas/tests/test_shared_session_repository.py +101 -0
- atlas/tests/test_system_prompt_loading.py +181 -0
- atlas/tests/test_token_storage.py +505 -0
- atlas/tests/test_tool_approval_config.py +93 -0
- atlas/tests/test_tool_approval_utils.py +356 -0
- atlas/tests/test_tool_authorization_group_filtering.py +223 -0
- atlas/tests/test_tool_details_in_config.py +108 -0
- atlas/tests/test_tool_planner.py +300 -0
- atlas/tests/test_unified_rag_service.py +398 -0
- atlas/tests/test_username_override_in_approval.py +258 -0
- atlas/tests/test_websocket_auth_header.py +168 -0
- atlas/version.py +6 -0
- atlas_chat-0.1.0.data/data/.env.example +253 -0
- atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
- atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
- atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
- atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
- atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
- atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
- atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
- atlas_chat-0.1.0.dist-info/METADATA +236 -0
- atlas_chat-0.1.0.dist-info/RECORD +250 -0
- atlas_chat-0.1.0.dist-info/WHEEL +5 -0
- atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
- atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2123 @@
|
|
|
1
|
+
"""FastMCP client for connecting to MCP servers and managing tools."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextvars
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
from contextlib import asynccontextmanager
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional
|
|
13
|
+
|
|
14
|
+
from fastmcp import Client
|
|
15
|
+
from fastmcp.client.transports import StreamableHttpTransport
|
|
16
|
+
|
|
17
|
+
from atlas.core.log_sanitizer import sanitize_for_logging
|
|
18
|
+
from atlas.core.metrics_logger import log_metric
|
|
19
|
+
from atlas.domain.messages.models import ToolCall, ToolResult
|
|
20
|
+
from atlas.modules.config import config_manager
|
|
21
|
+
from atlas.modules.config.config_manager import resolve_env_var
|
|
22
|
+
from atlas.modules.mcp_tools.token_storage import AuthenticationRequiredException
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Type alias for log callback function
|
|
27
|
+
LogCallback = Callable[[str, str, str, Dict[str, Any]], Awaitable[None]]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _ElicitationRoutingContext:
    """Routing record binding one elicitation request to one tool call.

    Holds the target MCP server name, the originating tool call, and the
    async update callback used to deliver the elicitation prompt to the
    correct user's connection.
    """

    def __init__(
        self,
        server_name: str,
        tool_call: ToolCall,
        update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
    ):
        # Short-lived data record; plain attributes, no behavior.
        self.server_name, self.tool_call, self.update_cb = (
            server_name,
            tool_call,
            update_cb,
        )
|
|
40
|
+
|
|
41
|
+
# Context-local override used to route MCP logs to the *current* request/session.
# This prevents cross-user log leakage when MCPToolManager is shared across connections.
# Installed/restored by MCPToolManager._use_log_callback(); read in the log handler.
_ACTIVE_LOG_CALLBACK: contextvars.ContextVar[Optional[LogCallback]] = contextvars.ContextVar(
    "mcp_active_log_callback",
    default=None,
)

# Dictionary-based routing for elicitation so a shared Client can still deliver
# elicitation requests to the correct user's WebSocket.
# Key: (server_name, tool_call_id) tuple to avoid collisions with concurrent tool calls
# Note: Cannot use contextvars.ContextVar because MCP receive loop runs in a different task
# Entries are added/removed by MCPToolManager._use_elicitation_context().
_ELICITATION_ROUTING: Dict[tuple, _ElicitationRoutingContext] = {}

# Dictionary-based routing for sampling requests (similar to elicitation)
# Key: (server_name, tool_call_id) tuple to avoid collisions with concurrent tool calls
# Entries are added/removed by MCPToolManager._use_sampling_context().
_SAMPLING_ROUTING: Dict[tuple, "_SamplingRoutingContext"] = {}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class _SamplingRoutingContext:
    """Context for routing sampling requests to the correct tool execution."""

    def __init__(
        self,
        server_name: str,
        tool_call: ToolCall,
        update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
    ):
        # Pure data record: capture the constructor arguments as attributes.
        for attr_name, attr_value in (
            ("server_name", server_name),
            ("tool_call", tool_call),
            ("update_cb", update_cb),
        ):
            setattr(self, attr_name, attr_value)
|
|
70
|
+
|
|
71
|
+
# Mapping from MCP log levels to Python logging levels.
# MCP uses syslog-style names; Python's logging module is coarser, so several
# MCP levels collapse onto the same Python level. Unknown names fall back to
# logging.INFO at the lookup site (MCP_TO_PYTHON_LOG_LEVEL.get(..., logging.INFO)).
MCP_TO_PYTHON_LOG_LEVEL = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "notice": logging.INFO,
    "warning": logging.WARNING,
    "warn": logging.WARNING,
    "error": logging.ERROR,
    "alert": logging.CRITICAL,
    "critical": logging.CRITICAL,
    "emergency": logging.CRITICAL,
}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class MCPToolManager:
|
|
86
|
+
"""Manager for MCP servers and their tools.
|
|
87
|
+
|
|
88
|
+
Default config path now points to config/overrides (or env override) with legacy fallback.
|
|
89
|
+
|
|
90
|
+
Supports:
|
|
91
|
+
- Hot-reloading configuration from disk via reload_config()
|
|
92
|
+
- Tracking failed server connections for retry
|
|
93
|
+
- Auto-reconnect with exponential backoff (when feature flag is enabled)
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
def __init__(self, config_path: Optional[str] = None, log_callback: Optional[LogCallback] = None):
    """Initialize the manager and load MCP server definitions.

    Args:
        config_path: Explicit path to an mcp.json file. When None, the path is
            resolved from the config manager's overrides directory, with legacy
            fallbacks under backend/configfilesadmin and backend/configfiles.
        log_callback: Default async callback (server_name, level, message,
            extra_data) used to forward MCP server logs when no request-scoped
            callback is active.
    """
    if config_path is None:
        # Use config manager to get config path
        app_settings = config_manager.app_settings
        overrides_root = Path(app_settings.app_config_overrides)

        # If relative, resolve from project root
        if not overrides_root.is_absolute():
            # This module is three levels below the package root, so
            # parent.parent.parent reaches the package directory and one more
            # .parent reaches the project root.
            backend_root = Path(__file__).parent.parent.parent
            project_root = backend_root.parent
            overrides_root = project_root / overrides_root

        candidate = overrides_root / "mcp.json"
        if not candidate.exists():
            # Legacy fallback (relative paths, resolved against the CWD)
            candidate = Path("backend/configfilesadmin/mcp.json")
            if not candidate.exists():
                candidate = Path("backend/configfiles/mcp.json")
        self.config_path = str(candidate)
        # Use default config manager when no path specified
        mcp_config = config_manager.mcp_config
        self.servers_config = {name: server.model_dump() for name, server in mcp_config.servers.items()}
    else:
        # Load config from the specified path
        self.config_path = config_path
        config_file = Path(config_path)
        if config_file.exists():
            from atlas.modules.config.config_manager import MCPConfig
            data = json.loads(config_file.read_text())
            # Convert flat structure to nested structure for Pydantic
            servers_data = {"servers": data}
            mcp_config = MCPConfig(**servers_data)
            self.servers_config = {name: server.model_dump() for name, server in mcp_config.servers.items()}
        else:
            # Best-effort: warn and start with no servers rather than raising.
            logger.warning(f"Custom config path specified but file not found: {config_path}")
            self.servers_config = {}
    # Runtime state; populated later (reload_config's docstring points at
    # initialize_clients()/discover_tools()/discover_prompts() — not in view here).
    self.clients = {}
    self.available_tools = {}
    self.available_prompts = {}

    # Track failed servers for reconnection with backoff
    self._failed_servers: Dict[str, Dict[str, Any]] = {}
    # {server_name: {"last_attempt": timestamp, "attempt_count": int, "error": str}}

    # Reconnect task reference (used by auto-reconnect background task)
    self._reconnect_task: Optional[asyncio.Task] = None
    self._reconnect_running = False

    # Default log callback (used when no request-scoped callback is active).
    # Signature: (server_name, level, message, extra_data) -> None
    self._default_log_callback = log_callback

    # Get configured log level for filtering
    self._min_log_level = self._get_min_log_level()

    # Per-user client cache for servers requiring user-specific authentication
    # Key: (user_email, server_name), Value: FastMCP Client instance
    self._user_clients: Dict[tuple, Client] = {}
    self._user_clients_lock = asyncio.Lock()
|
|
156
|
+
|
|
157
|
+
def _get_min_log_level(self) -> int:
|
|
158
|
+
"""Get the minimum log level from environment or config."""
|
|
159
|
+
try:
|
|
160
|
+
app_settings = config_manager.app_settings
|
|
161
|
+
raw_level_name = getattr(app_settings, "log_level", None)
|
|
162
|
+
if not isinstance(raw_level_name, str):
|
|
163
|
+
raise TypeError("log_level must be a string")
|
|
164
|
+
level_name = raw_level_name.upper()
|
|
165
|
+
except Exception:
|
|
166
|
+
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
167
|
+
|
|
168
|
+
level = getattr(logging, level_name, None)
|
|
169
|
+
return level if isinstance(level, int) else logging.INFO
|
|
170
|
+
|
|
171
|
+
def _create_log_handler(self, server_name: str):
    """Create a log handler for an MCP server.

    This handler forwards MCP server logs to the backend logger and optionally to the UI.
    Logs are filtered based on the configured LOG_LEVEL.

    Args:
        server_name: Name of the MCP server

    Returns:
        An async function that handles LogMessage objects from fastmcp
    """
    async def log_handler(message) -> None:
        """Handle log messages from MCP server."""
        try:
            # Handle both LogMessage objects and dict-like structures
            if hasattr(message, 'level'):
                log_level_str = message.level.lower()
                log_data = message.data if hasattr(message, 'data') else {}
            else:
                # Fallback for dict-like messages
                log_level_str = message.get('level', 'info').lower()
                log_data = message.get('data', {})

            # `data` may be a structured {"msg": ..., "extra": ...} dict or an
            # arbitrary payload; stringify the latter.
            msg = log_data.get('msg', '') if isinstance(log_data, dict) else str(log_data)
            extra = log_data.get('extra', {}) if isinstance(log_data, dict) else {}

            # Convert MCP log level to Python logging level
            python_log_level = MCP_TO_PYTHON_LOG_LEVEL.get(log_level_str, logging.INFO)

            # Filter based on configured minimum log level
            if python_log_level < self._min_log_level:
                return

            # Backend log noise reduction: tool servers can be very chatty at INFO.
            # Keep their INFO messages available at LOG_LEVEL=DEBUG, but avoid flooding
            # app logs at LOG_LEVEL=INFO. Warnings/errors still surface at INFO.
            backend_log_level = python_log_level if python_log_level >= logging.WARNING else logging.DEBUG

            # Log to backend logger with server context; sanitize both the
            # server name and message to avoid log injection.
            logger.log(
                backend_log_level,
                f"[MCP:{sanitize_for_logging(server_name)}] {sanitize_for_logging(msg)}",
                extra={"mcp_server": server_name, "mcp_extra": extra}
            )

            # Forward to the active (request-scoped) callback when present,
            # otherwise fall back to the default callback.
            callback = _ACTIVE_LOG_CALLBACK.get() or self._default_log_callback
            if callback is not None:
                await callback(server_name, log_level_str, msg, extra)

        except Exception as e:
            # A broken log message must never take down the tool call.
            logger.warning(f"Error handling log from MCP server {server_name}: {e}")

    return log_handler
|
|
229
|
+
|
|
230
|
+
def set_log_callback(self, callback: Optional[LogCallback]) -> None:
    """Set or update the log callback for forwarding MCP server logs to UI.

    This is the *default* callback: it is only used when no request-scoped
    callback is active (see _use_log_callback / _ACTIVE_LOG_CALLBACK).

    Args:
        callback: Async function that receives (server_name, level, message, extra_data)
    """
    self._default_log_callback = callback
|
|
237
|
+
|
|
238
|
+
@asynccontextmanager
async def _use_log_callback(self, callback: Optional[LogCallback]) -> AsyncIterator[None]:
    """Temporarily set a request-scoped log callback.

    This is used to bind MCP server logs to the current tool execution so they
    are forwarded only to the correct user's WebSocket connection.

    The previous value is restored on exit via the contextvar token, so
    concurrent tasks with their own callbacks do not interfere.
    """
    token = _ACTIVE_LOG_CALLBACK.set(callback)
    try:
        yield
    finally:
        # Always restore the prior callback, even if the body raises.
        _ACTIVE_LOG_CALLBACK.reset(token)
|
|
250
|
+
|
|
251
|
+
@asynccontextmanager
async def _use_elicitation_context(
    self,
    server_name: str,
    tool_call: ToolCall,
    update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
) -> AsyncIterator[None]:
    """Register elicitation routing for the duration of one tool call.

    Maps a (server_name, tool_call.id) key to a routing record while the
    ``with`` body runs; the entry is always removed on exit. Dictionary-based
    routing is used (not contextvars) because the MCP receive loop runs in a
    different task, and the composite key avoids collisions between
    concurrent tool calls.
    """
    key = (server_name, tool_call.id)
    _ELICITATION_ROUTING[key] = _ElicitationRoutingContext(server_name, tool_call, update_cb)
    try:
        yield
    finally:
        _ELICITATION_ROUTING.pop(key, None)
|
|
270
|
+
|
|
271
|
+
def _create_elicitation_handler(self, server_name: str):
    """Create an elicitation handler for a specific MCP server.

    The returned async handler captures ``server_name`` so that
    dictionary-based routing (via _ELICITATION_ROUTING) can deliver the
    elicitation request to the user/tool-call that is currently executing
    against this server, even though the MCP receive loop runs in a
    different task.

    Args:
        server_name: Name of the MCP server this handler serves.

    Returns:
        An async function with the fastmcp elicitation-handler signature
        (message, response_type, params, context) returning an ElicitResult.
    """
    async def handler(message, response_type, params, _context):
        """Per-server elicitation handler with captured server_name."""
        from fastmcp.client.elicitation import ElicitResult
        from mcp.types import ElicitRequestFormParams

        # Find routing context for this server (keyed by (server_name, tool_call_id))
        routing = None
        for (srv, _tcid), ctx in _ELICITATION_ROUTING.items():
            if srv == server_name:
                routing = ctx
                break
        if routing is None:
            logger.warning(
                f"Elicitation request for server '{server_name}' but no routing context - "
                f"elicitation cancelled. Message: {message[:50]}..."
            )
            return ElicitResult(action="cancel", content=None)
        if routing.update_cb is None:
            logger.warning(
                f"Elicitation request for server '{server_name}', tool '{routing.tool_call.name}' "
                f"but update_cb is None - elicitation cancelled. Message: {message[:50]}..."
            )
            return ElicitResult(action="cancel", content=None)

        # Schema the server wants the user's answer to conform to (if any).
        response_schema: Dict[str, Any] = {}
        if isinstance(params, ElicitRequestFormParams):
            response_schema = params.requestedSchema or {}

        try:
            import uuid

            from atlas.application.chat.elicitation_manager import get_elicitation_manager

            elicitation_id = str(uuid.uuid4())
            elicitation_manager = get_elicitation_manager()

            request = elicitation_manager.create_elicitation_request(
                elicitation_id=elicitation_id,
                tool_call_id=routing.tool_call.id,
                tool_name=routing.tool_call.name,
                message=message,
                response_schema=response_schema,
            )

            logger.debug(f"Sending elicitation_request to frontend for server '{server_name}'")
            await routing.update_cb(
                {
                    "type": "elicitation_request",
                    "elicitation_id": elicitation_id,
                    "tool_call_id": routing.tool_call.id,
                    "tool_name": routing.tool_call.name,
                    "message": message,
                    "response_schema": response_schema,
                }
            )

            # Block until the user answers (or the 5-minute timeout fires);
            # always clean up the pending request afterwards.
            try:
                response = await request.wait_for_response(timeout=300.0)
            finally:
                elicitation_manager.cleanup_request(elicitation_id)

            action = response.get("action", "cancel")
            data = response.get("data")

            if action != "accept":
                return ElicitResult(action=action, content=None)

            # Approval-only elicitation (response_type=None) must return an empty object.
            # Some UIs send placeholder payloads like {'none': ''}; don't forward them.
            if response_type is None:
                return ElicitResult(action="accept", content={})

            if data is None:
                return ElicitResult(action="accept", content=None)

            # FastMCP requires elicitation response content to be a JSON object.
            # (The previous code branched on whether the schema's only property
            # was "value" but produced the identical wrapping either way, so the
            # dead branch has been removed.)
            if not isinstance(data, dict):
                data = {"value": data}

            return ElicitResult(action="accept", content=data)

        except asyncio.TimeoutError:
            logger.warning(f"Elicitation timeout for server '{server_name}'")
            return ElicitResult(action="cancel", content=None)
        except Exception as e:
            logger.error(f"Error handling elicitation for server '{server_name}': {e}", exc_info=True)
            return ElicitResult(action="cancel", content=None)

    return handler
|
|
373
|
+
|
|
374
|
+
@asynccontextmanager
async def _use_sampling_context(
    self,
    server_name: str,
    tool_call: ToolCall,
    update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
) -> AsyncIterator[None]:
    """Register sampling routing for the duration of one tool call.

    Maps a (server_name, tool_call.id) key to a routing record while the
    ``with`` body runs; the entry is always removed on exit. Dictionary-based
    routing is used (not contextvars) because the MCP receive loop runs in a
    different task, and the composite key avoids collisions between
    concurrent tool calls.
    """
    key = (server_name, tool_call.id)
    _SAMPLING_ROUTING[key] = _SamplingRoutingContext(server_name, tool_call, update_cb)
    try:
        yield
    finally:
        _SAMPLING_ROUTING.pop(key, None)
|
|
393
|
+
|
|
394
|
+
def _create_sampling_handler(self, server_name: str):
    """
    Create a sampling handler for a specific MCP server.

    This handler intercepts MCP sampling (createMessage) requests and routes
    them to the backend LLM via LiteLLMCaller, using the routing record
    registered by _use_sampling_context to identify the owning tool call.
    Returns a handler function that captures the server_name for routing.
    """
    async def handler(messages, params=None, context=None):
        """Per-server sampling handler with captured server_name."""
        from mcp.types import CreateMessageResult, SamplingMessage, TextContent

        # Find routing context for this server (keyed by (server_name, tool_call_id))
        routing = None
        for (srv, _tcid), ctx in _SAMPLING_ROUTING.items():
            if srv == server_name:
                routing = ctx
                break
        if routing is None:
            logger.warning(
                f"Sampling request for server '{server_name}' but no routing context - "
                f"sampling cancelled."
            )
            # Unlike elicitation, sampling has no cancel result type here, so fail loudly.
            raise Exception("No routing context for sampling request")

        try:
            # Normalize the incoming messages to plain {"role", "content"} dicts.
            message_dicts = []
            for msg in messages:
                if isinstance(msg, SamplingMessage):
                    text = ""
                    if isinstance(msg.content, TextContent):
                        text = msg.content.text
                    elif isinstance(msg.content, list):
                        # Concatenate only text parts; non-text parts are dropped.
                        for item in msg.content:
                            if isinstance(item, TextContent):
                                text += item.text
                    else:
                        text = str(msg.content)
                    message_dicts.append({
                        "role": msg.role,
                        "content": text
                    })
                elif isinstance(msg, str):
                    # Bare strings are treated as user messages.
                    message_dicts.append({
                        "role": "user",
                        "content": msg
                    })
                else:
                    # Pass through anything already dict-shaped (assumed LLM-ready).
                    message_dicts.append(msg)

            # Optional MCP sampling parameters (camelCase per the MCP spec).
            system_prompt = getattr(params, 'systemPrompt', None) if params else None
            temperature = getattr(params, 'temperature', None) if params else None
            max_tokens = getattr(params, 'maxTokens', 512) if params else 512
            model_preferences_raw = getattr(params, 'modelPreferences', None) if params else None

            # Coerce model preferences to a list of names (or leave as None).
            model_preferences = None
            if model_preferences_raw:
                if isinstance(model_preferences_raw, str):
                    model_preferences = [model_preferences_raw]
                elif isinstance(model_preferences_raw, list):
                    model_preferences = model_preferences_raw

            if system_prompt:
                message_dicts.insert(0, {
                    "role": "system",
                    "content": system_prompt
                })

            logger.info(
                f"Sampling request from server '{server_name}' tool '{routing.tool_call.name}': "
                f"{len(message_dicts)} messages, temperature={temperature}, max_tokens={max_tokens}"
            )

            # Imported lazily — presumably to avoid import cycles at module load; confirm.
            from atlas.modules.config import config_manager
            from atlas.modules.llm.litellm_caller import LiteLLMCaller

            llm_caller = LiteLLMCaller()

            llm_config = config_manager.llm_config
            model_name = None

            # Resolve the first preference that matches either a configured
            # model key or a configured model's underlying model_name.
            if model_preferences:
                for pref in model_preferences:
                    if pref in llm_config.models:
                        model_name = pref
                        break
                    for name, model_config in llm_config.models.items():
                        if model_config.model_name == pref:
                            model_name = name
                            break
                    if model_name:
                        break

            # No usable preference: fall back to the first configured model.
            if not model_name:
                model_name = next(iter(llm_config.models.keys()))

            logger.debug(
                f"Using model '{model_name}' for sampling "
                f"(preferences: {model_preferences})"
            )

            response = await llm_caller.call_plain(
                model_name=model_name,
                messages=message_dicts,
                temperature=temperature,
                max_tokens=max_tokens
            )

            logger.info(
                f"Sampling completed for server '{server_name}': "
                f"response_length={len(response) if response else 0}"
            )

            return CreateMessageResult(
                role="assistant",
                content=TextContent(type="text", text=response),
                model=model_name
            )

        except Exception as e:
            # Log with traceback, then propagate so the MCP server sees the failure.
            logger.error(f"Error handling sampling for server '{server_name}': {e}", exc_info=True)
            raise

    return handler
|
|
517
|
+
|
|
518
|
+
def reload_config(self) -> Dict[str, Any]:
    """Re-read the mcp.json configuration from disk and refresh servers_config.

    Only the configuration mapping is replaced here; call initialize_clients()
    and discover_tools()/discover_prompts() afterward to apply the changes to
    live clients.

    Returns:
        Dict with the previous and new server name lists plus the added,
        removed, and unchanged subsets for comparison.
    """
    before = set(self.servers_config)

    # config_manager re-reads the configuration file from disk.
    reloaded = config_manager.reload_mcp_config()
    self.servers_config = {
        name: server.model_dump() for name, server in reloaded.servers.items()
    }

    after = set(self.servers_config)
    removed_servers = before - after
    added_servers = after - before
    unchanged_servers = before & after

    # Servers dropped from the config no longer need failure tracking.
    for name in removed_servers:
        self._failed_servers.pop(name, None)

    logger.info(
        f"MCP config reloaded: added={list(added_servers)}, "
        f"removed={list(removed_servers)}, unchanged={list(unchanged_servers)}"
    )

    return {
        "previous_servers": list(before),
        "new_servers": list(after),
        "added": list(added_servers),
        "removed": list(removed_servers),
        "unchanged": list(unchanged_servers)
    }
|
|
559
|
+
|
|
560
|
+
def get_failed_servers(self) -> Dict[str, Dict[str, Any]]:
    """Return a snapshot of servers that failed to connect.

    Returns:
        Dict keyed by server name; each value holds the failure details
        (last_attempt time, attempt_count, and error message).
    """
    # Hand back a shallow copy so callers cannot mutate internal tracking.
    return {name: info for name, info in self._failed_servers.items()}
|
|
568
|
+
|
|
569
|
+
def _record_server_failure(self, server_name: str, error: str) -> None:
|
|
570
|
+
"""Record a server connection failure for tracking."""
|
|
571
|
+
if server_name in self._failed_servers:
|
|
572
|
+
self._failed_servers[server_name]["attempt_count"] += 1
|
|
573
|
+
self._failed_servers[server_name]["last_attempt"] = time.time()
|
|
574
|
+
self._failed_servers[server_name]["error"] = error
|
|
575
|
+
else:
|
|
576
|
+
self._failed_servers[server_name] = {
|
|
577
|
+
"last_attempt": time.time(),
|
|
578
|
+
"attempt_count": 1,
|
|
579
|
+
"error": error
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
def _clear_server_failure(self, server_name: str) -> None:
|
|
583
|
+
"""Clear failure tracking for a server after successful connection."""
|
|
584
|
+
self._failed_servers.pop(server_name, None)
|
|
585
|
+
|
|
586
|
+
def _calculate_backoff_delay(self, attempt_count: int) -> float:
    """Compute the exponential-backoff delay for a reconnection attempt.

    The base interval, cap, and growth multiplier come from
    config_manager.app_settings; the delay grows geometrically with the
    attempt count and is clamped to the configured maximum.
    """
    settings = config_manager.app_settings
    uncapped = settings.mcp_reconnect_interval * (
        settings.mcp_reconnect_backoff_multiplier ** (attempt_count - 1)
    )
    return min(uncapped, settings.mcp_reconnect_max_interval)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def _determine_transport_type(self, config: Dict[str, Any]) -> str:
    """Resolve which transport an MCP server configuration should use.

    Resolution order:
      1. an explicit 'transport' field always wins;
      2. a 'command' implies STDIO;
      3. a URL with an http(s) scheme implies SSE (path ending in /sse)
         or HTTP; a scheme-less URL defers to a 'type' of http/sse,
         otherwise HTTP;
      4. fall back to the legacy 'type' field (default: stdio).
    """
    if config.get("transport"):
        logger.debug(f"Using explicit transport: {config['transport']}")
        return config["transport"]

    # A command means a local subprocess, regardless of any URL present.
    if config.get("command"):
        logger.debug("Auto-detected STDIO transport from command")
        return "stdio"

    url = config.get("url")
    if url:
        if not url.startswith(("http://", "https://")):
            # URL without protocol - honor a type field of http/sse, else HTTP.
            declared = config.get("type", "stdio")
            if declared in ["http", "sse"]:
                logger.debug(f"Using type field '{declared}' for URL without protocol: {url}")
                return declared
            logger.debug(f"URL without protocol, defaulting to HTTP: {url}")
            return "http"
        if url.endswith("/sse"):
            logger.debug(f"Auto-detected SSE transport from URL: {url}")
            return "sse"
        logger.debug(f"Auto-detected HTTP transport from URL: {url}")
        return "http"

    # Legacy fallback for configs with neither transport, command, nor URL.
    fallback = config.get("type", "stdio")
    logger.debug(f"Using fallback transport type: {fallback}")
    return fallback
|
|
643
|
+
|
|
644
|
+
async def _initialize_single_client(self, server_name: str, config: Dict[str, Any]) -> Optional[Client]:
    """Create (but do not connect) a FastMCP client for one configured server.

    Supports HTTP/SSE transports (config 'url') and STDIO transports
    (config 'command', or the legacy mcp/<name>/main.py layout). Values in
    'auth_token' and 'env' may reference environment variables as ${VAR};
    they are resolved before use.

    Args:
        server_name: Name of the server as it appears in the MCP config.
        config: That server's configuration dict.

    Returns:
        The constructed Client, or None when the configuration is invalid
        or client construction fails (details are logged, never raised).
    """
    safe_server_name = sanitize_for_logging(server_name)
    # Keep INFO logs concise; config/transport details can be very verbose.
    logger.info("Initializing MCP client for server '%s'", safe_server_name)
    logger.debug("Server config for '%s': %s", safe_server_name, sanitize_for_logging(str(config)))

    # Bug fix: pre-bind so the except-block diagnostics below cannot raise
    # UnboundLocalError (masking the real error) when transport detection
    # itself throws before transport_type is assigned.
    transport_type: Optional[str] = None
    try:
        transport_type = self._determine_transport_type(config)
        logger.debug("Determined transport type for %s: %s", safe_server_name, transport_type)

        if transport_type in ["http", "sse"]:
            # HTTP/SSE MCP server
            url = config.get("url")
            if not url:
                logger.error(f"No URL provided for HTTP/SSE server: {server_name}")
                return None

            # Ensure URL has protocol for FastMCP client
            if not url.startswith(("http://", "https://")):
                url = f"http://{url}"
                logger.debug(f"Added http:// protocol to URL: {url}")

            raw_token = config.get("auth_token")
            try:
                token = resolve_env_var(raw_token)  # Resolve ${ENV_VAR} if present
            except ValueError as e:
                logger.error(f"Failed to resolve auth_token for {server_name}: {e}")
                return None  # Skip this server

            # Create log handler for this server
            log_handler = self._create_log_handler(server_name)

            if transport_type == "sse":
                logger.debug(f"Creating SSE client for {server_name} at {url}")
            else:
                logger.debug(f"Creating HTTP client for {server_name} at {url}")

            # NOTE(review): the SSE and HTTP branches previously built
            # byte-identical Client(...) calls, so they are collapsed here;
            # FastMCP infers the concrete web transport from the URL.
            client = Client(
                url,
                auth=token,
                log_handler=log_handler,
                elicitation_handler=self._create_elicitation_handler(server_name),
                sampling_handler=self._create_sampling_handler(server_name),
            )

            logger.info(f"Created {transport_type.upper()} MCP client for {server_name}")
            return client

        elif transport_type == "stdio":
            # STDIO MCP server
            command = config.get("command")
            logger.debug("STDIO transport command for %s: %s", safe_server_name, command)
            if command:
                # Ensure MCP stdio servers run under the same interpreter as the backend.
                # In dev containers, PATH `python` may not have required deps.
                if command[0] in {"python", "python3"}:
                    command = [sys.executable, *command[1:]]

                # Custom command specified
                cwd = config.get("cwd")
                env = config.get("env")
                logger.debug("Working directory specified for %s: %s", safe_server_name, cwd)

                # Resolve environment variables in env dict
                resolved_env = None
                if env is not None:
                    resolved_env = {}
                    for key, value in env.items():
                        try:
                            resolved_env[key] = resolve_env_var(value)
                            logger.debug(f"Resolved env var {key} for {server_name}")
                        except ValueError as e:
                            logger.error(f"Failed to resolve env var {key} for {server_name}: {e}")
                            return None  # Skip this server if env var resolution fails
                    logger.debug("Environment variables specified for %s: %s", safe_server_name, list(resolved_env.keys()))

                # Create log handler for this server
                log_handler = self._create_log_handler(server_name)

                if cwd:
                    # Convert relative path to absolute path from project root
                    if not os.path.isabs(cwd):
                        # This module sits four directory levels below the
                        # project root, so walk up four dirname() calls.
                        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
                        cwd = os.path.join(project_root, cwd)
                        logger.debug("Converted relative cwd to absolute for %s: %s", safe_server_name, cwd)

                    if not os.path.exists(cwd):
                        logger.error(f"Working directory does not exist: {cwd}")
                        return None

                    logger.debug("Working directory exists for %s: %s", safe_server_name, cwd)
                    logger.debug("Creating STDIO client for %s with command=%s cwd=%s", safe_server_name, command, cwd)
                    from fastmcp.client.transports import StdioTransport
                    transport = StdioTransport(command=command[0], args=command[1:], cwd=cwd, env=resolved_env)
                    client = Client(
                        transport,
                        log_handler=log_handler,
                        elicitation_handler=self._create_elicitation_handler(server_name),
                        sampling_handler=self._create_sampling_handler(server_name),
                    )
                    logger.info(f"Successfully created STDIO MCP client for {server_name} with custom command and cwd")
                    return client
                else:
                    logger.debug("No cwd specified for %s; creating STDIO client with command=%s", safe_server_name, command)
                    from fastmcp.client.transports import StdioTransport
                    transport = StdioTransport(command=command[0], args=command[1:], env=resolved_env)
                    client = Client(
                        transport,
                        log_handler=log_handler,
                        elicitation_handler=self._create_elicitation_handler(server_name),
                        sampling_handler=self._create_sampling_handler(server_name),
                    )
                    logger.info(f"Successfully created STDIO MCP client for {server_name} with custom command")
                    return client
            else:
                # Fallback to old behavior for backward compatibility
                server_path = f"mcp/{server_name}/main.py"
                logger.debug(f"Attempting to initialize {server_name} at path: {server_path}")
                if os.path.exists(server_path):
                    logger.debug(f"Server script exists for {server_name}, creating client...")
                    log_handler = self._create_log_handler(server_name)
                    client = Client(
                        server_path,
                        log_handler=log_handler,
                        elicitation_handler=self._create_elicitation_handler(server_name),
                        sampling_handler=self._create_sampling_handler(server_name),
                    )  # Client auto-detects STDIO transport from .py file
                    logger.info(f"Created MCP client for {server_name}")
                    logger.debug(f"Successfully created client for {server_name}")
                    return client
                else:
                    # Bug fix: dropped exc_info=True — there is no active
                    # exception here, so it only logged "NoneType: None".
                    logger.error(f"MCP server script not found: {server_path}")
                    return None
        else:
            logger.error(f"Unsupported transport type '{transport_type}' for server: {server_name}")
            return None

    except Exception as e:
        # Targeted debugging for MCP startup errors
        error_type = type(e).__name__
        logger.error(f"Error creating client for {server_name}: {error_type}: {e}")

        # Provide specific debugging information based on error type and config
        msg = str(e).lower()
        if "connection" in msg or "refused" in msg:
            if transport_type in ["http", "sse"]:
                logger.error(f"DEBUG: Connection failed for HTTP/SSE server '{server_name}'")
                logger.error(f" → URL: {config.get('url', 'Not specified')}")
                logger.error(f" → Transport: {transport_type}")
                logger.error(" → Check if server is running and accessible")
            else:
                logger.error(f"DEBUG: STDIO connection failed for server '{server_name}'")
                logger.error(f" → Command: {config.get('command', 'Not specified')}")
                logger.error(f" → CWD: {config.get('cwd', 'Not specified')}")
                logger.error(" → Check if command exists and is executable")

        elif "timeout" in msg:
            logger.error(f"DEBUG: Timeout connecting to server '{server_name}'")
            logger.error(" → Server may be slow to start or overloaded")
            logger.error(" → Consider increasing timeout or checking server health")

        elif "permission" in msg or "access" in msg:
            logger.error(f"DEBUG: Permission error for server '{server_name}'")
            if config.get('cwd'):
                logger.error(f" → Check directory permissions: {config.get('cwd')}")
            if config.get('command'):
                logger.error(f" → Check executable permissions: {config.get('command')}")

        elif "module" in msg or "import" in msg:
            logger.error(f"DEBUG: Import/module error for server '{server_name}'")
            logger.error(" → Check if required dependencies are installed")
            logger.error(" → Check Python path and virtual environment")

        elif "json" in msg or "decode" in msg:
            logger.error(f"DEBUG: JSON/protocol error for server '{server_name}'")
            logger.error(" → Server may not be MCP-compatible")
            logger.error(" → Check server output format")

        else:
            # Generic debugging info
            logger.error(f"DEBUG: Generic error for server '{server_name}'")
            logger.error(f" → Config: {config}")
            logger.error(f" → Transport type: {transport_type}")

        # Always show the full traceback in debug mode
        logger.debug(f"Full traceback for {server_name}:", exc_info=True)
        return None
|
|
841
|
+
|
|
842
|
+
async def initialize_clients(self):
    """Initialize FastMCP clients for all configured servers in parallel."""
    logger.info("Starting MCP client initialization for %d servers", len(self.servers_config))
    logger.debug("MCP servers to initialize: %s", list(self.servers_config.keys()))

    # One initialization coroutine per configured server, gathered
    # concurrently; exceptions are returned as values, not raised.
    names = list(self.servers_config.keys())
    outcomes = await asyncio.gather(
        *(self._initialize_single_client(name, self.servers_config[name]) for name in names),
        return_exceptions=True,
    )

    # Record each outcome: exception, None (soft failure), or a live client.
    for name, outcome in zip(names, outcomes):
        if isinstance(outcome, Exception):
            error_msg = f"{type(outcome).__name__}: {outcome}"
            logger.error(f"Exception during client initialization for {name}: {error_msg}", exc_info=True)
            self._record_server_failure(name, error_msg)
        elif outcome is None:
            self._record_server_failure(name, "Initialization returned None")
            logger.warning(f"Failed to initialize client for {name}")
        else:
            self.clients[name] = outcome
            self._clear_server_failure(name)
            logger.info(f"Successfully initialized client for {name}")

    failed_servers = sorted(set(self.servers_config.keys()) - set(self.clients.keys()))
    logger.info(
        "MCP client initialization complete: %d/%d connected (%d failed)",
        len(self.clients),
        len(self.servers_config),
        len(failed_servers),
    )
    logger.debug("MCP clients initialized: %s", list(self.clients.keys()))
    logger.debug("MCP clients failed to initialize: %s", failed_servers)
|
|
880
|
+
|
|
881
|
+
async def reconnect_failed_servers(self, force: bool = False) -> Dict[str, Any]:
    """Attempt to reconnect to servers that previously failed.

    When ``force`` is False (default), this respects exponential backoff and
    only attempts servers whose backoff delay has elapsed. When ``force`` is
    True, backoff delays are ignored and all currently failed servers are
    attempted immediately. The admin `/admin/mcp/reconnect` endpoint uses
    ``force=True`` to provide an on-demand retry button.

    Returns:
        Dict with reconnection results including newly connected, still
        failed, and skipped servers due to backoff.
    """
    # Fast path: nothing tracked as failed means nothing to attempt.
    if not self._failed_servers:
        return {
            "attempted": [],
            "reconnected": [],
            "still_failed": [],
            "skipped_backoff": []
        }

    current_time = time.time()
    attempted = []
    reconnected = []
    still_failed = []
    skipped_backoff = []

    # Iterate over a snapshot (list(...)) because entries are cleared and
    # re-recorded inside the loop body.
    for server_name, failure_info in list(self._failed_servers.items()):
        # Skip if server is no longer in config
        if server_name not in self.servers_config:
            self._clear_server_failure(server_name)
            continue

        # Skip if already connected
        if server_name in self.clients:
            self._clear_server_failure(server_name)
            continue

        # Check backoff delay unless this is a forced reconnect
        backoff_delay = self._calculate_backoff_delay(failure_info["attempt_count"])
        time_since_last = current_time - failure_info["last_attempt"]

        if not force and time_since_last < backoff_delay:
            # Still within the backoff window: report when a retry is due.
            skipped_backoff.append({
                "server": server_name,
                "wait_remaining": backoff_delay - time_since_last,
                "attempt_count": failure_info["attempt_count"]
            })
            continue

        # Attempt reconnection
        attempted.append(server_name)
        config = self.servers_config[server_name]

        try:
            client = await self._initialize_single_client(server_name, config)
            if client is not None:
                self.clients[server_name] = client
                self._clear_server_failure(server_name)
                reconnected.append(server_name)
                logger.info(f"Successfully reconnected to MCP server: {server_name}")

                # Discover tools and prompts for the reconnected server
                await self._discover_and_register_server(server_name, client)
            else:
                # Soft failure: _initialize_single_client logged the cause.
                self._record_server_failure(server_name, "Reconnection returned None")
                still_failed.append(server_name)
        except Exception as e:
            error_msg = f"{type(e).__name__}: {e}"
            self._record_server_failure(server_name, error_msg)
            still_failed.append(server_name)
            logger.warning(f"Failed to reconnect to MCP server {server_name}: {error_msg}")

    return {
        "attempted": attempted,
        "reconnected": reconnected,
        "still_failed": still_failed,
        "skipped_backoff": skipped_backoff
    }
|
|
960
|
+
|
|
961
|
+
async def _discover_and_register_server(self, server_name: str, client: Client) -> None:
    """Discover tools and prompts for a single server and register them."""
    try:
        # Tools first: store the discovery payload and refresh the flat index.
        tool_payload = await self._discover_tools_for_server(server_name, client)
        self.available_tools[server_name] = tool_payload

        if hasattr(self, "_tool_index"):
            for tool in tool_payload.get('tools', []):
                self._tool_index[f"{server_name}_{tool.name}"] = {
                    'server': server_name,
                    'tool': tool
                }

        # Then prompts.
        prompt_payload = await self._discover_prompts_for_server(server_name, client)
        self.available_prompts[server_name] = prompt_payload

        logger.info(
            f"Registered server {server_name}: "
            f"{len(tool_payload.get('tools', []))} tools, "
            f"{len(prompt_payload.get('prompts', []))} prompts"
        )
    except Exception as e:
        logger.error(f"Error discovering tools/prompts for {server_name}: {e}")
|
|
988
|
+
|
|
989
|
+
async def start_auto_reconnect(self) -> None:
    """Start the background auto-reconnect task.

    Spawns a task that periodically retries failed MCP servers with
    exponential backoff. Does nothing when the
    FEATURE_MCP_AUTO_RECONNECT_ENABLED setting is off or when the task is
    already running.
    """
    if not config_manager.app_settings.feature_mcp_auto_reconnect_enabled:
        logger.info("MCP auto-reconnect is disabled (FEATURE_MCP_AUTO_RECONNECT_ENABLED=false)")
        return

    if self._reconnect_running:
        logger.warning("Auto-reconnect task is already running")
        return

    self._reconnect_running = True
    self._reconnect_task = asyncio.create_task(self._auto_reconnect_loop())
    logger.info("Started MCP auto-reconnect background task")
|
|
1007
|
+
|
|
1008
|
+
async def stop_auto_reconnect(self) -> None:
    """Stop the background auto-reconnect task."""
    self._reconnect_running = False
    task = self._reconnect_task
    if task:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected: the loop exits via cancellation.
            pass
        self._reconnect_task = None
        logger.info("Stopped MCP auto-reconnect background task")
|
|
1019
|
+
|
|
1020
|
+
async def _auto_reconnect_loop(self) -> None:
    """Background loop that periodically attempts to reconnect failed servers."""
    poll_interval = config_manager.app_settings.mcp_reconnect_interval

    while self._reconnect_running:
        try:
            await asyncio.sleep(poll_interval)

            # Nothing to do while no servers are marked failed.
            if not self._failed_servers:
                continue

            logger.debug(
                f"Auto-reconnect: checking {len(self._failed_servers)} failed servers"
            )
            outcome = await self.reconnect_failed_servers()

            if outcome["reconnected"]:
                logger.info(
                    f"Auto-reconnect: successfully reconnected {len(outcome['reconnected'])} servers: "
                    f"{outcome['reconnected']}"
                )
            if outcome["still_failed"]:
                logger.debug(
                    f"Auto-reconnect: {len(outcome['still_failed'])} servers still failed"
                )

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in auto-reconnect loop: {e}", exc_info=True)
            await asyncio.sleep(poll_interval)  # Wait before retrying
|
|
1052
|
+
|
|
1053
|
+
async def _discover_tools_for_server(self, server_name: str, client: Client) -> Dict[str, Any]:
    """Discover tools for a single server.

    Opens the client connection, lists the server's tools with a timeout
    taken from app settings, and returns a dict with 'tools' and 'config'
    keys. On any failure the error is logged with targeted debugging hints,
    the failure is recorded for reconnect tracking, and an empty tools list
    is returned instead of raising.
    """
    safe_server_name = sanitize_for_logging(server_name)
    server_config = self.servers_config.get(server_name, {})
    safe_config = sanitize_for_logging(str(server_config))
    discovery_timeout = config_manager.app_settings.mcp_discovery_timeout
    logger.debug("Tool discovery: starting for server '%s'", safe_server_name)
    logger.debug("Server config (sanitized): %s", safe_config)
    try:
        logger.debug("Opening client connection for %s", safe_server_name)
        async with client:
            logger.debug("Client connected for %s; listing tools", safe_server_name)
            tools = await asyncio.wait_for(client.list_tools(), timeout=discovery_timeout)
            logger.debug("Got %d tools from %s: %s", len(tools), safe_server_name, [tool.name for tool in tools])

            # Log detailed tool information
            for i, tool in enumerate(tools):
                logger.debug(
                    " Tool %d: name='%s', description='%s'",
                    i + 1,
                    tool.name,
                    getattr(tool, 'description', 'No description'),
                )

            # Bug fix: use the config snapshot taken at entry (consistent
            # with the failure path below) instead of indexing
            # self.servers_config again, which could raise KeyError if the
            # server was removed by a concurrent config reload.
            server_data = {
                'tools': tools,
                'config': server_config
            }
            logger.debug("Stored %d tools for %s", len(tools), safe_server_name)
            return server_data
    except Exception as e:
        error_type = type(e).__name__
        error_msg = sanitize_for_logging(str(e))
        logger.error(f"TOOL DISCOVERY FAILED for '{safe_server_name}': {error_type}: {error_msg}")

        # Targeted debugging for tool discovery errors
        error_lower = str(e).lower()
        if "connection" in error_lower or "refused" in error_lower:
            logger.error(f"DEBUG: Connection lost during tool discovery for '{safe_server_name}'")
            logger.error(" → Server may have crashed or disconnected")
            logger.error(" → Check server logs for startup errors")
            # Check if this is an HTTPS/SSL issue
            if "ssl" in error_lower or "certificate" in error_lower or "https" in error_lower:
                logger.error(" → SSL/HTTPS error detected")
                logger.error(" → On Windows, ensure SSL certificates are properly configured")
                logger.error(" → Try setting REQUESTS_CA_BUNDLE or SSL_CERT_FILE environment variables")
        elif "timeout" in error_lower:
            logger.error(f"DEBUG: Timeout during tool discovery for '{safe_server_name}'")
            logger.error(" → Server is slow to respond to list_tools() request")
            logger.error(" → Server may be overloaded or hanging")
        elif "json" in error_lower or "decode" in error_lower:
            logger.error(f"DEBUG: Protocol error during tool discovery for '{safe_server_name}'")
            logger.error(" → Server returned invalid MCP response")
            logger.error(" → Check if server implements MCP protocol correctly")
        elif "ssl" in error_lower or "certificate" in error_lower:
            logger.error(f"DEBUG: SSL/Certificate error during tool discovery for '{safe_server_name}'")
            logger.error(f" → URL: {server_config.get('url', 'N/A')}")
            logger.error(" → SSL certificate verification failed")
            logger.error(" → On Windows, this may require installing/updating CA certificates")
            logger.error(" → Check if the server URL uses HTTPS with a self-signed or untrusted certificate")
        else:
            logger.error(f"DEBUG: Generic tool discovery error for '{safe_server_name}'")
            logger.error(f" → Client type: {type(client).__name__}")
            logger.error(f" → Server URL: {server_config.get('url', 'N/A')}")
            logger.error(f" → Transport type: {server_config.get('transport', server_config.get('type', 'N/A'))}")

        # Record failure for status/reconnect purposes
        self._record_server_failure(server_name, f"{error_type}: {error_msg}")

        logger.debug(f"Full tool discovery traceback for {safe_server_name}:", exc_info=True)

        server_data = {
            'tools': [],
            'config': server_config,
        }
        logger.debug(
            "Set empty tools list for failed server '%s' (config_present=%s)",
            safe_server_name,
            server_config is not None,
        )
        return server_data
|
|
1134
|
+
|
|
1135
|
+
async def discover_tools(self):
    """Discover tools from every connected MCP server concurrently.

    Populates ``self.available_tools`` (server name -> {'tools', 'config'})
    and rebuilds ``self._tool_index`` mapping fully-qualified tool names
    ("<server>_<tool>") to their owning server and tool object.
    """
    logger.info("Starting MCP tool discovery for %d connected servers", len(self.clients))
    logger.debug("Tool discovery servers: %s", list(self.clients.keys()))
    self.available_tools = {}

    # Fan out one discovery task per connected client and gather them all,
    # capturing per-server exceptions instead of failing the whole batch.
    names = list(self.clients.keys())
    outcomes = await asyncio.gather(
        *(self._discover_tools_for_server(srv, cli) for srv, cli in self.clients.items()),
        return_exceptions=True,
    )

    for srv, outcome in zip(names, outcomes):
        # A reload may have removed this server's config while discovery ran.
        if srv not in self.servers_config:
            logger.warning(
                f"Skipping tool discovery result for '{srv}' because it is no longer in servers_config"
            )
            continue

        if isinstance(outcome, Exception):
            logger.error(f"Exception during tool discovery for {srv}: {outcome}", exc_info=True)
            self._record_server_failure(srv, f"Exception during tool discovery: {outcome}")
            self.available_tools[srv] = {
                'tools': [],
                'config': self.servers_config.get(srv),
            }
        else:
            # Successful discovery clears any previously recorded failure.
            self._clear_server_failure(srv)
            self.available_tools[srv] = outcome

    total_tools = sum(len(data.get('tools', [])) for data in self.available_tools.values())
    logger.info(
        "MCP tool discovery complete: %d tools across %d servers",
        total_tools,
        len(self.available_tools),
    )
    for srv, data in self.available_tools.items():
        tool_names = [t.name for t in data.get('tools', [])]
        logger.debug("Tool discovery summary: %s: %d tools %s", srv, len(tool_names), tool_names)

    # Rebuild the fully-qualified-name -> {'server', 'tool'} lookup index.
    index = {}
    for srv, data in self.available_tools.items():
        if srv == "canvas":
            index["canvas_canvas"] = {
                'server': 'canvas',
                'tool': None  # pseudo tool
            }
        else:
            for t in data.get('tools', []):
                index[f"{srv}_{t.name}"] = {
                    'server': srv,
                    'tool': t,
                }
    self._tool_index = index
|
|
1198
|
+
|
|
1199
|
+
async def _discover_prompts_for_server(self, server_name: str, client: Client) -> Dict[str, Any]:
    """Discover prompts for a single server. Returns server prompts data.

    Returns:
        Dict of the form {'prompts': [...], 'config': <server config dict>}.
        On any failure the prompt list is empty; connection-level failures are
        additionally recorded via self._record_server_failure.
    """
    safe_server_name = sanitize_for_logging(server_name)
    server_config = self.servers_config.get(server_name, {})
    # Same discovery timeout budget used for tool discovery (app settings).
    discovery_timeout = config_manager.app_settings.mcp_discovery_timeout
    logger.debug(f"Attempting to discover prompts from {safe_server_name}")
    try:
        logger.debug(f"Opening client connection for {safe_server_name}")
        async with client:
            logger.debug(f"Client connected for {safe_server_name}, listing prompts...")
            try:
                # Bound list_prompts() so a hung server cannot stall discovery.
                prompts = await asyncio.wait_for(client.list_prompts(), timeout=discovery_timeout)
                logger.debug(
                    f"Got {len(prompts)} prompts from {safe_server_name}: {[prompt.name for prompt in prompts]}"
                )
                server_data = {
                    'prompts': prompts,
                    'config': server_config,
                }
                logger.info(f"Discovered {len(prompts)} prompts from {safe_server_name}")
                logger.debug(f"Successfully stored prompts for {safe_server_name}")
                return server_data
            except Exception as e:
                # Server might not support prompts or list_prompts() failed store empty list
                # NOTE: deliberately NOT recorded as a server failure — prompt support
                # is optional in MCP, so this is treated as a benign condition.
                logger.debug(
                    f"Server {safe_server_name} does not support prompts or list_prompts() failed: {e}"
                )
                return {
                    'prompts': [],
                    'config': server_config,
                }
    except Exception as e:
        # Connection-level failure (the `async with client` itself failed).
        error_type = type(e).__name__
        error_msg = sanitize_for_logging(str(e))
        logger.error(f"PROMPT DISCOVERY FAILED for '{safe_server_name}': {error_type}: {error_msg}")

        # Targeted debugging for prompt discovery errors
        # (heuristic keyword classification of the exception text; branch order
        # matters — "connection" errors are matched before the SSL-only case).
        error_lower = str(e).lower()
        if "connection" in error_lower or "refused" in error_lower:
            logger.error(f"DEBUG: Connection lost during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server may have crashed or disconnected")
            # Check if this is an HTTPS/SSL issue
            if "ssl" in error_lower or "certificate" in error_lower or "https" in error_lower:
                logger.error(" → SSL/HTTPS error detected")
                logger.error(" → On Windows, ensure SSL certificates are properly configured")
        elif "timeout" in error_lower:
            logger.error(f"DEBUG: Timeout during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server is slow to respond to list_prompts() request")
        elif "json" in error_lower or "decode" in error_lower:
            logger.error(f"DEBUG: Protocol error during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server returned invalid MCP response for prompts")
        elif "ssl" in error_lower or "certificate" in error_lower:
            logger.error(f"DEBUG: SSL/Certificate error during prompt discovery for '{safe_server_name}'")
            logger.error(f" → URL: {server_config.get('url', 'N/A')}")
            logger.error(" → SSL certificate verification failed")
            logger.error(" → On Windows, this may require installing/updating CA certificates")
        else:
            logger.error(f"DEBUG: Generic prompt discovery error for '{safe_server_name}'")

        # Record failure for status/reconnect purposes
        self._record_server_failure(server_name, f"{error_type}: {error_msg}")

        logger.debug(f"Full prompt discovery traceback for {safe_server_name}:", exc_info=True)
        logger.debug(f"Set empty prompts list for failed server {safe_server_name}")
        return {
            'prompts': [],
            'config': server_config,
        }
|
|
1267
|
+
|
|
1268
|
+
async def discover_prompts(self):
    """Discover prompts from every connected MCP server concurrently.

    Populates ``self.available_prompts`` with one entry per configured server
    (server name -> {'prompts', 'config'}); servers that fail get an empty list.
    """
    logger.info("Starting MCP prompt discovery for %d connected servers", len(self.clients))
    logger.debug("Prompt discovery servers: %s", list(self.clients.keys()))
    self.available_prompts = {}

    # One discovery task per client, gathered together; per-server exceptions
    # are captured rather than aborting the whole batch.
    names = list(self.clients.keys())
    outcomes = await asyncio.gather(
        *(self._discover_prompts_for_server(srv, cli) for srv, cli in self.clients.items()),
        return_exceptions=True,
    )

    for srv, outcome in zip(names, outcomes):
        # A reload may have removed this server's config while discovery ran.
        if srv not in self.servers_config:
            logger.warning(
                f"Skipping prompt discovery result for '{srv}' because it is no longer in servers_config"
            )
            continue

        if isinstance(outcome, Exception):
            logger.error(f"Exception during prompt discovery for {srv}: {outcome}", exc_info=True)
            self._record_server_failure(srv, f"Exception during prompt discovery: {outcome}")
            self.available_prompts[srv] = {
                'prompts': [],
                'config': self.servers_config.get(srv),
            }
        else:
            # Successful discovery clears any previously recorded failure.
            self._clear_server_failure(srv)
            self.available_prompts[srv] = outcome

    total_prompts = sum(len(data.get('prompts', [])) for data in self.available_prompts.values())
    logger.info(
        "MCP prompt discovery complete: %d prompts across %d servers",
        total_prompts,
        len(self.available_prompts),
    )
    for srv, data in self.available_prompts.items():
        prompt_names = [p.name for p in data.get('prompts', [])]
        logger.debug("Prompt discovery summary: %s: %d prompts %s", srv, len(prompt_names), prompt_names)
|
|
1315
|
+
|
|
1316
|
+
def get_server_groups(self, server_name: str) -> List[str]:
    """Return the group names required to access *server_name* ([] if unknown)."""
    try:
        server_config = self.servers_config[server_name]
    except KeyError:
        return []
    return server_config.get("groups", [])
|
|
1321
|
+
|
|
1322
|
+
def get_available_servers(self) -> List[str]:
    """Return the names of all configured MCP servers (config order)."""
    return [*self.servers_config]
|
|
1325
|
+
|
|
1326
|
+
def get_tools_for_servers(self, server_names: List[str]) -> Dict[str, Any]:
    """Collect OpenAI function-calling schemas for the selected servers.

    Returns:
        {'tools': [<OpenAI function schema>, ...],
         'mapping': {full_tool_name: {'server', 'tool_name'}, ...}}
    """
    schemas: List[Dict[str, Any]] = []
    mapping: Dict[str, Any] = {}

    for name in server_names:
        if name == "canvas":
            # "canvas" is a frontend pseudo-tool with a fixed, hand-written schema.
            schemas.append({
                "type": "function",
                "function": {
                    "name": "canvas_canvas",
                    "description": "Display final rendered content in a visual canvas panel. Use this for: 1) Complete code (not code discussions), 2) Final reports/documents (not report discussions), 3) Data visualizations, 4) Any polished content that should be viewed separately from the conversation. Put the actual content in the canvas, keep discussions in chat.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "The content to display in the canvas. Can be markdown, code, or plain text."
                            }
                        },
                        "required": ["content"]
                    }
                }
            })
            mapping["canvas_canvas"] = {
                'server': 'canvas',
                'tool_name': 'canvas'
            }
        elif name in self.available_tools:
            # Convert each discovered MCP tool to the OpenAI function format,
            # qualifying its name with the server to keep names globally unique.
            for tool in self.available_tools[name]['tools']:
                qualified = f"{name}_{tool.name}"
                schemas.append({
                    "type": "function",
                    "function": {
                        "name": qualified,
                        "description": tool.description or '',
                        "parameters": tool.inputSchema or {}
                    }
                })
                mapping[qualified] = {
                    'server': name,
                    'tool_name': tool.name
                }

    return {
        'tools': schemas,
        'mapping': mapping
    }
|
|
1380
|
+
|
|
1381
|
+
def _requires_user_auth(self, server_name: str) -> bool:
|
|
1382
|
+
"""Check if a server requires per-user authentication.
|
|
1383
|
+
|
|
1384
|
+
Returns True for servers with auth_type 'oauth', 'jwt', 'bearer', or 'api_key'.
|
|
1385
|
+
These servers need user-specific tokens rather than shared/admin tokens.
|
|
1386
|
+
"""
|
|
1387
|
+
config = self.servers_config.get(server_name, {})
|
|
1388
|
+
auth_type = config.get("auth_type", "none")
|
|
1389
|
+
return auth_type in ("oauth", "jwt", "bearer", "api_key")
|
|
1390
|
+
|
|
1391
|
+
async def _get_user_client(
    self,
    server_name: str,
    user_email: str,
) -> Optional[Client]:
    """Get or create a user-specific client for servers requiring per-user auth.

    Args:
        server_name: Name of the MCP server
        user_email: User's email address

    Returns:
        FastMCP Client configured with user's token, or None if no token available
    """
    # Local import — presumably to avoid a circular import at module load; confirm.
    from atlas.modules.mcp_tools.token_storage import get_token_storage

    token_storage = get_token_storage()
    # Cache key lowercases the email so lookups are case-insensitive per user.
    cache_key = (user_email.lower(), server_name)

    # Check cache first, but validate token is still valid
    async with self._user_clients_lock:
        if cache_key in self._user_clients:
            # Verify the token is still valid before returning cached client
            stored_token = token_storage.get_valid_token(user_email, server_name)
            if stored_token is not None:
                return self._user_clients[cache_key]
            else:
                # Token expired or removed, invalidate cached client
                logger.debug(
                    f"Token expired for user on server '{server_name}', "
                    f"invalidating cached client"
                )
                del self._user_clients[cache_key]

    # Get user's token from storage
    # NOTE(review): token is re-fetched here after the lock is released, so two
    # concurrent callers may each build a client; last write wins in the cache
    # below. Appears benign but confirm clients are cheap/stateless to duplicate.
    logger.debug(f"[AUTH] Looking up token for server='{server_name}'")
    stored_token = token_storage.get_valid_token(user_email, server_name)
    logger.debug(f"[AUTH] Token found: {stored_token is not None}")

    if stored_token is None:
        logger.debug(
            f"[AUTH] No valid token for server '{server_name}' - user needs to authenticate"
        )
        return None

    # Get server config
    config = self.servers_config.get(server_name, {})
    url = config.get("url")

    if not url:
        logger.error(f"No URL configured for server '{server_name}'")
        return None

    # Ensure URL has protocol
    if not url.startswith(("http://", "https://")):
        url = f"http://{url}"

    # Create client with user's token
    try:
        log_handler = self._create_log_handler(server_name)
        auth_type = config.get("auth_type", "bearer")

        # For API key auth, use custom header; for bearer/jwt/oauth, use auth parameter
        if auth_type == "api_key":
            # Use custom header for API key authentication
            auth_header = config.get("auth_header", "X-API-Key")
            logger.debug(
                f"Creating API key client for '{server_name}' with header '{auth_header}'"
            )
            transport = StreamableHttpTransport(
                url,
                headers={auth_header: stored_token.token_value},
            )
            client = Client(
                transport=transport,
                log_handler=log_handler,
                elicitation_handler=self._create_elicitation_handler(server_name),
                sampling_handler=self._create_sampling_handler(server_name),
            )
        else:
            # FastMCP Client accepts auth= as a string (bearer token)
            client = Client(
                url,
                auth=stored_token.token_value,
                log_handler=log_handler,
                elicitation_handler=self._create_elicitation_handler(server_name),
                sampling_handler=self._create_sampling_handler(server_name),
            )

        # Cache the client
        async with self._user_clients_lock:
            self._user_clients[cache_key] = client

        logger.info(
            f"Created user-specific client for server '{server_name}' (auth_type={auth_type})"
        )
        return client

    except Exception as e:
        # Any failure building the client is swallowed into a None return —
        # callers treat None as "needs (re)authentication".
        logger.error(
            f"Failed to create user client for server '{server_name}': {e}"
        )
        return None
|
|
1494
|
+
|
|
1495
|
+
async def _invalidate_user_client(self, user_email: str, server_name: str) -> None:
|
|
1496
|
+
"""Remove a user's cached client (e.g., when token is revoked)."""
|
|
1497
|
+
cache_key = (user_email.lower(), server_name)
|
|
1498
|
+
async with self._user_clients_lock:
|
|
1499
|
+
if cache_key in self._user_clients:
|
|
1500
|
+
del self._user_clients[cache_key]
|
|
1501
|
+
logger.debug(f"Invalidated user client cache for server '{server_name}'")
|
|
1502
|
+
|
|
1503
|
+
async def call_tool(
    self,
    server_name: str,
    tool_name: str,
    arguments: Dict[str, Any],
    *,
    progress_handler: Optional[Any] = None,
    elicitation_handler: Optional[Any] = None,
    user_email: Optional[str] = None,
) -> Any:
    """Call a specific tool on an MCP server.

    Args:
        server_name: Name of the MCP server
        tool_name: Name of the tool to call
        arguments: Tool arguments
        progress_handler: Optional progress callback handler
        elicitation_handler: Optional elicitation callback handler. Prefer the built-in
            elicitation routing (registered at client creation time) for shared clients.
        user_email: User's email for per-user authentication (required for oauth/jwt servers)

    Raises:
        AuthenticationRequiredException: If the server needs a per-user token and
            none is available (or no user context was supplied).
        ValueError: If no shared client exists for *server_name*.
        TimeoutError: If the call exceeds the configured MCP call timeout.
    """
    # Determine which client to use
    client = None

    # Check if this server requires per-user authentication
    if self._requires_user_auth(server_name):
        logger.debug(f"Server '{server_name}' requires user auth, user_email={user_email}")
        if user_email:
            client = await self._get_user_client(server_name, user_email)
            logger.debug(f"_get_user_client for '{server_name}' returned client: {client is not None}")
        if client is None:
            # No usable per-user client: tell the caller how to authenticate.
            # The two previously-duplicated raise sites are consolidated here;
            # only the message differs based on whether a user context existed.
            server_config = self.servers_config.get(server_name, {})
            auth_type = server_config.get("auth_type", "oauth")
            # OAuth servers get a start URL so the frontend can redirect automatically.
            oauth_start_url = (
                f"/api/mcp/auth/{server_name}/oauth/start" if auth_type == "oauth" else None
            )
            message = (
                f"Server '{server_name}' requires authentication."
                if user_email
                else f"Server '{server_name}' requires authentication but no user context."
            )
            raise AuthenticationRequiredException(
                server_name=server_name,
                auth_type=auth_type,
                message=message,
                oauth_start_url=oauth_start_url,
            )
    else:
        # Use shared client for servers without per-user auth
        if server_name not in self.clients:
            raise ValueError(f"No client available for server: {server_name}")
        client = self.clients[server_name]

    call_timeout = config_manager.app_settings.mcp_call_timeout
    try:
        # Set elicitation callback before opening the client context.
        # FastMCP negotiates supported capabilities during session init.
        if elicitation_handler is not None:
            client.set_elicitation_callback(elicitation_handler)

        async with client:
            # Pass progress handler if provided (fastmcp >= 2.3.5)
            kwargs = {}
            if progress_handler is not None:
                kwargs["progress_handler"] = progress_handler

            # Bound the call so a hung server cannot stall the chat turn.
            result = await asyncio.wait_for(
                client.call_tool(tool_name, arguments, **kwargs),
                timeout=call_timeout,
            )
            logger.info(f"Successfully called {sanitize_for_logging(tool_name)} on {sanitize_for_logging(server_name)}")
            return result
    except asyncio.TimeoutError:
        error_msg = f"Tool call '{tool_name}' on server '{server_name}' timed out after {call_timeout}s"
        logger.error(error_msg)
        self._record_server_failure(server_name, error_msg)
        raise TimeoutError(error_msg)
    except Exception as e:
        logger.error(f"Error calling {tool_name} on {server_name}: {e}")
        raise
|
|
1589
|
+
|
|
1590
|
+
async def get_prompt(self, server_name: str, prompt_name: str, arguments: Optional[Dict[str, Any]] = None) -> Any:
    """Get a specific prompt from an MCP server.

    Args:
        server_name: Name of the MCP server.
        prompt_name: Name of the prompt to retrieve.
        arguments: Optional prompt arguments. A falsy value (None or an empty
            dict) omits arguments from the request, matching prior behavior.
            (Annotation fixed: the default None previously contradicted the
            non-optional ``Dict[str, Any]`` annotation.)

    Raises:
        ValueError: If no client exists for *server_name*.
    """
    if server_name not in self.clients:
        raise ValueError(f"No client available for server: {server_name}")

    client = self.clients[server_name]
    try:
        async with client:
            if arguments:
                result = await client.get_prompt(prompt_name, arguments)
            else:
                result = await client.get_prompt(prompt_name)
            logger.info(f"Successfully retrieved prompt {prompt_name} from {server_name}")
            return result
    except Exception as e:
        logger.error(f"Error getting prompt {prompt_name} from {server_name}: {e}")
        raise
|
|
1607
|
+
|
|
1608
|
+
def get_available_prompts_for_servers(self, server_names: List[str]) -> Dict[str, Any]:
    """Map "<server>_<prompt>" keys to prompt metadata for the selected servers."""
    catalog: Dict[str, Any] = {}

    for srv in server_names:
        if srv not in self.available_prompts:
            continue
        for prompt in self.available_prompts[srv]['prompts']:
            catalog[f"{srv}_{prompt.name}"] = {
                'server': srv,
                'name': prompt.name,
                'description': prompt.description or '',
                'arguments': prompt.arguments or {}
            }

    return catalog
|
|
1625
|
+
|
|
1626
|
+
async def get_authorized_servers(self, user_email: str, auth_check_func) -> List[str]:
    """Return the names of enabled servers *user_email* is authorized to use.

    A server with no "groups" requirement is open to everyone; otherwise the
    user must belong to at least one required group, checked via the awaitable
    ``auth_check_func(user_email, group)``.

    Improvement: checks are awaited sequentially and stop at the first matching
    group, instead of awaiting every group's check before evaluating any().
    """
    authorized_servers = []
    for server_name, server_config in self.servers_config.items():
        # Disabled servers are never offered.
        if not server_config.get("enabled", True):
            continue

        required_groups = server_config.get("groups", [])
        if not required_groups:
            authorized_servers.append(server_name)
            continue

        # Short-circuit on the first group the user belongs to.
        for group in required_groups:
            if await auth_check_func(user_email, group):
                authorized_servers.append(server_name)
                break
    return authorized_servers
|
|
1644
|
+
|
|
1645
|
+
def get_available_tools(self) -> List[str]:
    """Return fully-qualified names ("<server>_<tool>") of every discovered tool."""
    names: List[str] = []
    for srv, data in self.available_tools.items():
        if srv == "canvas":
            # Canvas is a pseudo-tool with a single fixed name.
            names.append("canvas_canvas")
            continue
        names.extend(f"{srv}_{tool.name}" for tool in data.get('tools', []))
    return names
|
|
1655
|
+
|
|
1656
|
+
def get_tools_schema(self, tool_names: List[str]) -> List[Dict[str, Any]]:
    """Get OpenAI function-calling schemas for the given fully-qualified tool names.

    Requested names are matched directly against the discovered inventory via a
    prebuilt index (full tool name -> server/tool object) rather than guessing
    the server by splitting on underscores — which broke whenever the original
    per-server tool names themselves contained underscores (e.g. server
    'ui-demo' with tool 'create_form_demo'). Unknown names are skipped, as
    before.

    Improvement: the previous version accumulated unknown names into a
    ``missing`` list that was never read or logged — dead code, now removed.
    """
    if not tool_names:
        return []

    # Build (or reuse) the full-name -> {'server', 'tool'} index for O(1) lookups.
    index = getattr(self, "_tool_index", None)
    if not index:
        index = {}
        for server_name, server_data in self.available_tools.items():
            if server_name == "canvas":
                index["canvas_canvas"] = {
                    'server': 'canvas',
                    'tool': None  # pseudo tool
                }
            else:
                for tool in server_data.get('tools', []):
                    index[f"{server_name}_{tool.name}"] = {
                        'server': server_name,
                        'tool': tool
                    }
        self._tool_index = index

    matched: List[Dict[str, Any]] = []
    for requested in tool_names:
        entry = index.get(requested)
        if not entry:
            # Unknown tool name; silently skipped (preserves prior behavior).
            continue
        if requested == "canvas_canvas":
            # Recreate the canvas schema (duplicate of get_tools_for_servers is
            # intentional, to avoid coupling to its superset return shape).
            matched.append({
                "type": "function",
                "function": {
                    "name": "canvas_canvas",
                    "description": "Display final rendered content in a visual canvas panel. Use this for: 1) Complete code (not code discussions), 2) Final reports/documents (not report discussions), 3) Data visualizations, 4) Any polished content that should be viewed separately from the conversation. Put the actual content in the canvas, keep discussions in chat.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "The content to display in the canvas. Can be markdown, code, or plain text."
                            }
                        },
                        "required": ["content"]
                    }
                }
            })
        else:
            tool = entry['tool']
            matched.append({
                "type": "function",
                "function": {
                    "name": requested,
                    "description": getattr(tool, 'description', '') or '',
                    "parameters": getattr(tool, 'inputSchema', {}) or {}
                }
            })

    return matched
|
|
1734
|
+
|
|
1735
|
+
# ------------------------------------------------------------
|
|
1736
|
+
# Internal helpers
|
|
1737
|
+
# ------------------------------------------------------------
|
|
1738
|
+
def _normalize_mcp_tool_result(self, raw_result: Any) -> Dict[str, Any]:
|
|
1739
|
+
"""Normalize a FastMCP CallToolResult (or similar object) into our contract.
|
|
1740
|
+
|
|
1741
|
+
Returns a dict shaped like:
|
|
1742
|
+
{
|
|
1743
|
+
"results": <payload or string>,
|
|
1744
|
+
"meta_data": {...optional...},
|
|
1745
|
+
"returned_file_names": [...optional...],
|
|
1746
|
+
"returned_file_count": N (if file contents present)
|
|
1747
|
+
}
|
|
1748
|
+
|
|
1749
|
+
Notes:
|
|
1750
|
+
- We never inline base64 file contents here to avoid prompt bloat.
|
|
1751
|
+
- Handles legacy key forms (result, meta-data, metadata).
|
|
1752
|
+
- Falls back to stringifying the raw result if structured extraction fails.
|
|
1753
|
+
"""
|
|
1754
|
+
normalized: Dict[str, Any] = {}
|
|
1755
|
+
structured: Dict[str, Any] = {}
|
|
1756
|
+
|
|
1757
|
+
# Attempt extraction in priority order
|
|
1758
|
+
try:
|
|
1759
|
+
if hasattr(raw_result, "structured_content") and raw_result.structured_content: # type: ignore[attr-defined]
|
|
1760
|
+
structured = raw_result.structured_content # type: ignore[attr-defined]
|
|
1761
|
+
elif hasattr(raw_result, "data") and raw_result.data: # type: ignore[attr-defined]
|
|
1762
|
+
structured = raw_result.data # type: ignore[attr-defined]
|
|
1763
|
+
else:
|
|
1764
|
+
# Fallback: extract text content from content array
|
|
1765
|
+
if hasattr(raw_result, "content"):
|
|
1766
|
+
contents = getattr(raw_result, "content")
|
|
1767
|
+
if contents:
|
|
1768
|
+
# Collect all text from TextContent items
|
|
1769
|
+
text_parts = []
|
|
1770
|
+
for item in contents:
|
|
1771
|
+
if hasattr(item, "type") and getattr(item, "type") == "text":
|
|
1772
|
+
text = getattr(item, "text", None)
|
|
1773
|
+
if text:
|
|
1774
|
+
text_parts.append(text)
|
|
1775
|
+
|
|
1776
|
+
if text_parts:
|
|
1777
|
+
combined_text = "\n".join(text_parts)
|
|
1778
|
+
# Try to parse as JSON if it looks like JSON
|
|
1779
|
+
if combined_text.strip().startswith(("{", "[")):
|
|
1780
|
+
try:
|
|
1781
|
+
logger.info("MCP tool result normalization: using content text JSON fallback for structured extraction")
|
|
1782
|
+
structured = json.loads(combined_text)
|
|
1783
|
+
except Exception: # pragma: no cover - defensive
|
|
1784
|
+
# Not valid JSON, use as plain text result
|
|
1785
|
+
structured = {"results": combined_text}
|
|
1786
|
+
else:
|
|
1787
|
+
# Plain text - use as results directly
|
|
1788
|
+
structured = {"results": combined_text}
|
|
1789
|
+
except Exception as parse_err: # pragma: no cover - defensive
|
|
1790
|
+
logger.debug(f"Non-fatal parse issue extracting structured tool result: {parse_err}")
|
|
1791
|
+
|
|
1792
|
+
if isinstance(structured, dict):
|
|
1793
|
+
# Support both correct and legacy key forms
|
|
1794
|
+
results_payload = structured.get("results") or structured.get("result")
|
|
1795
|
+
meta_payload = (
|
|
1796
|
+
structured.get("meta_data")
|
|
1797
|
+
or structured.get("meta-data")
|
|
1798
|
+
or structured.get("metadata")
|
|
1799
|
+
)
|
|
1800
|
+
returned_file_names = structured.get("returned_file_names")
|
|
1801
|
+
returned_file_contents = structured.get("returned_file_contents")
|
|
1802
|
+
|
|
1803
|
+
if results_payload is not None:
|
|
1804
|
+
normalized["results"] = results_payload
|
|
1805
|
+
if meta_payload is not None:
|
|
1806
|
+
try:
|
|
1807
|
+
# Heuristic to prevent very large meta blobs
|
|
1808
|
+
if len(json.dumps(meta_payload)) < 4000:
|
|
1809
|
+
normalized["meta_data"] = meta_payload
|
|
1810
|
+
else:
|
|
1811
|
+
normalized["meta_data_truncated"] = True
|
|
1812
|
+
except Exception: # pragma: no cover
|
|
1813
|
+
normalized["meta_data_parse_error"] = True
|
|
1814
|
+
if returned_file_names:
|
|
1815
|
+
normalized["returned_file_names"] = returned_file_names
|
|
1816
|
+
if returned_file_contents:
|
|
1817
|
+
normalized["returned_file_count"] = (
|
|
1818
|
+
len(returned_file_contents) if isinstance(returned_file_contents, (list, tuple)) else 1
|
|
1819
|
+
)
|
|
1820
|
+
|
|
1821
|
+
# Phase 5 fallback: if no explicit results key, treat *entire* structured dict (minus large/base64 fields) as results
|
|
1822
|
+
if "results" not in normalized:
|
|
1823
|
+
# Prune potentially huge / sensitive keys before fallback
|
|
1824
|
+
prune_keys = {"returned_file_contents"}
|
|
1825
|
+
pruned = {k: v for k, v in structured.items() if k not in prune_keys}
|
|
1826
|
+
try:
|
|
1827
|
+
serialized = json.dumps(pruned)
|
|
1828
|
+
if len(serialized) <= 8000: # size guard
|
|
1829
|
+
normalized["results"] = pruned
|
|
1830
|
+
else:
|
|
1831
|
+
normalized["results_summary"] = {
|
|
1832
|
+
"keys": list(pruned.keys()),
|
|
1833
|
+
"omitted_due_to_size": len(serialized)
|
|
1834
|
+
}
|
|
1835
|
+
except Exception: # pragma: no cover
|
|
1836
|
+
# Fallback to string repr if serialization fails
|
|
1837
|
+
normalized.setdefault("results", str(pruned))
|
|
1838
|
+
|
|
1839
|
+
if not normalized:
|
|
1840
|
+
normalized = {"results": str(raw_result)}
|
|
1841
|
+
return normalized
|
|
1842
|
+
|
|
1843
|
+
async def execute_tool(
    self,
    tool_call: ToolCall,
    context: Optional[Dict[str, Any]] = None
) -> ToolResult:
    """Execute a single tool call and normalize its result.

    Args:
        tool_call: The tool invocation (id, name, arguments) to execute.
        context: Optional execution context dict. Recognized keys:
            ``update_callback`` — async callable used to stream tool logs,
            progress, elicitation and sampling events to the UI;
            ``user_email`` — forwarded to the tool call and used for metrics.

    Returns:
        A ToolResult whose ``content`` is the JSON-serialized normalized
        result, plus any extracted artifacts, display config and metadata.
        Failures are reported via ``success=False`` / ``error`` rather than
        raised to the caller.
    """
    logger.debug("ToolManager.execute_tool: tool=%s", tool_call.name)

    # Canvas is a pseudo-tool rendered entirely by the frontend; just echo.
    if tool_call.name == "canvas_canvas":
        content = tool_call.arguments.get("content", "")
        return ToolResult(
            tool_call_id=tool_call.id,
            content=f"Canvas content displayed: {content[:100]}..." if len(content) > 100 else f"Canvas content displayed: {content}",
            success=True
        )

    # Use the tool index to resolve server and tool name (avoids parsing
    # issues with dashes/underscores in combined names).
    self._ensure_tool_index()

    tool_entry = self._tool_index.get(tool_call.name)
    if not tool_entry:
        return ToolResult(
            tool_call_id=tool_call.id,
            content=f"Tool not found: {tool_call.name}",
            success=False,
            error=f"Tool not found: {tool_call.name}"
        )

    server_name = tool_entry['server']
    actual_tool_name = tool_entry['tool'].name if tool_entry['tool'] else tool_call.name

    try:
        update_cb = None
        user_email = None
        if isinstance(context, dict):
            update_cb = context.get("update_callback")
            user_email = context.get("user_email")

        if update_cb is None:
            logger.warning(
                f"Executing tool '{tool_call.name}' without update_callback - "
                f"elicitation will not work. Context type: {type(context)}"
            )
        else:
            logger.debug(f"Executing tool '{tool_call.name}' with update_callback present")

        async def _tool_log_callback(
            log_server_name: str,
            level: str,
            message: str,
            extra: Dict[str, Any],
        ) -> None:
            """Forward MCP server log records to the UI (best-effort)."""
            if update_cb is None:
                return
            try:
                # Deferred import to avoid cycles
                from atlas.application.chat.utilities.event_notifier import notify_tool_log
                await notify_tool_log(
                    server_name=log_server_name,
                    tool_name=tool_call.name,
                    tool_call_id=tool_call.id,
                    level=level,
                    message=sanitize_for_logging(message),
                    extra=extra,
                    update_callback=update_cb,
                )
            except Exception:
                logger.debug("Tool log forwarding failed", exc_info=True)

        # Progress handler that forwards to UI if a callback was provided.
        async def _progress_handler(progress: float, total: Optional[float], message: Optional[str]) -> None:
            """Forward tool progress notifications to the UI (best-effort)."""
            try:
                if update_cb is not None:
                    # Deferred import to avoid cycles
                    from atlas.application.chat.utilities.event_notifier import notify_tool_progress
                    await notify_tool_progress(
                        tool_call_id=tool_call.id,
                        tool_name=tool_call.name,
                        progress=progress,
                        total=total,
                        message=message,
                        update_callback=update_cb,
                    )
            except Exception:
                logger.debug("Progress handler forwarding failed", exc_info=True)

        # Single invocation path. Previously this code was duplicated in two
        # branches differing only in whether _use_log_callback was entered;
        # AsyncExitStack lets us make that context conditional instead.
        from contextlib import AsyncExitStack

        async with AsyncExitStack() as stack:
            if update_cb is not None:
                await stack.enter_async_context(self._use_log_callback(_tool_log_callback))
            await stack.enter_async_context(self._use_elicitation_context(server_name, tool_call, update_cb))
            await stack.enter_async_context(self._use_sampling_context(server_name, tool_call, update_cb))
            raw_result = await self.call_tool(
                server_name,
                actual_tool_name,
                tool_call.arguments,
                progress_handler=_progress_handler,
                user_email=user_email,
            )

        normalized_content = self._normalize_mcp_tool_result(raw_result)
        content_str = json.dumps(normalized_content, ensure_ascii=False)

        # Extract v2 MCP response components (artifacts / display / metadata).
        artifacts, display_config, meta_data = self._extract_v2_components(raw_result, tool_call)

        log_metric("tool_call", user_email, tool_name=actual_tool_name)

        return ToolResult(
            tool_call_id=tool_call.id,
            content=content_str,
            success=True,
            artifacts=artifacts,
            display_config=display_config,
            meta_data=meta_data
        )
    except Exception as e:
        logger.error(f"Error executing tool {tool_call.name}: {e}")

        log_metric("tool_error", user_email, tool_name=actual_tool_name)

        return ToolResult(
            tool_call_id=tool_call.id,
            content=f"Error executing tool: {str(e)}",
            success=False,
            error=str(e)
        )

def _ensure_tool_index(self) -> None:
    """Build the flat ``{server}_{tool}`` -> entry index if it is missing.

    Mirrors the index built in get_tools_schema so lookups never have to
    parse combined names containing dashes/underscores.
    """
    if getattr(self, "_tool_index", None):
        return
    index: Dict[str, Any] = {}
    for server_name, server_data in self.available_tools.items():
        if server_name == "canvas":
            index["canvas_canvas"] = {
                'server': 'canvas',
                'tool': None  # pseudo tool handled by the frontend
            }
        else:
            for tool in server_data.get('tools', []):
                index[f"{server_name}_{tool.name}"] = {
                    'server': server_name,
                    'tool': tool
                }
    self._tool_index = index

def _extract_v2_components(self, raw_result: Any, tool_call: ToolCall):
    """Extract v2 MCP response components from a raw tool result.

    Supports plain dict results and FastMCP result objects (checking
    ``structured_content``, then ``data``, then JSON embedded in the first
    text content item). ImageContent entries in the content array are
    converted to base64 artifacts after MIME and base64 validation.

    Args:
        raw_result: Whatever ``call_tool`` returned.
        tool_call: The originating call (used for artifact descriptions).

    Returns:
        Tuple ``(artifacts, display_config, meta_data)``. Extraction is
        best-effort: any error is logged and partial results are returned.
    """
    import base64  # hoisted: previously imported inside the per-image loop

    artifacts: List[Dict[str, Any]] = []
    display_config: Optional[Dict[str, Any]] = None
    meta_data: Optional[Dict[str, Any]] = None

    # Allowlist of safe image MIME types.
    ALLOWED_IMAGE_MIMES = {
        "image/png", "image/jpeg", "image/gif",
        "image/svg+xml", "image/webp", "image/bmp"
    }

    try:
        if isinstance(raw_result, dict):
            structured = raw_result
        else:
            structured = {}
            if hasattr(raw_result, "structured_content") and raw_result.structured_content:  # type: ignore[attr-defined]
                sc = raw_result.structured_content  # type: ignore[attr-defined]
                if isinstance(sc, dict):
                    structured = sc
            elif hasattr(raw_result, "data") and raw_result.data:  # type: ignore[attr-defined]
                dt = raw_result.data  # type: ignore[attr-defined]
                if isinstance(dt, dict):
                    structured = dt
            else:
                # Fallback: parse first textual content if JSON-like.
                # Handles MCP responses that return data only in content[0].text.
                if hasattr(raw_result, "content"):
                    contents = getattr(raw_result, "content")
                    if contents and len(contents) > 0 and hasattr(contents[0], "text"):
                        first_text = getattr(contents[0], "text")
                        if isinstance(first_text, str) and first_text.strip().startswith("{"):
                            try:
                                structured = json.loads(first_text)
                            except Exception:
                                pass

        if isinstance(structured, dict) and structured:
            # Structured artifacts must carry both a name and base64 payload.
            raw_artifacts = structured.get("artifacts")
            if isinstance(raw_artifacts, list):
                for art in raw_artifacts:
                    if isinstance(art, dict) and art.get("name") and art.get("b64"):
                        artifacts.append(art)

            disp = structured.get("display")
            if isinstance(disp, dict):
                display_config = disp

            md = structured.get("meta_data")
            if isinstance(md, dict):
                meta_data = md

        # Extract ImageContent entries from the content array.
        if hasattr(raw_result, "content"):
            contents = getattr(raw_result, "content")
            if isinstance(contents, list):
                image_counter = 0
                for item in contents:
                    if hasattr(item, "type") and getattr(item, "type") == "image":
                        data = getattr(item, "data", None)
                        mime_type = getattr(item, "mimeType", None)

                        # Validate mime type against the allowlist.
                        if mime_type and mime_type not in ALLOWED_IMAGE_MIMES:
                            logger.warning(
                                f"Skipping ImageContent with unsupported mime type: {mime_type}"
                            )
                            continue

                        # Validate base64 payload before exposing it.
                        if data:
                            try:
                                base64.b64decode(data, validate=True)
                            except Exception:
                                logger.warning(
                                    "Skipping ImageContent with invalid base64 data"
                                )
                                continue

                        if data and mime_type:
                            # mcp_image_ prefix avoids collisions with
                            # names of structured artifacts.
                            ext = mime_type.split("/")[-1] if "/" in mime_type else "bin"
                            filename = f"mcp_image_{image_counter}.{ext}"

                            artifacts.append({
                                "name": filename,
                                "b64": data,
                                "mime": mime_type,
                                "viewer": "image",
                                "description": f"Image returned by {tool_call.name}"
                            })
                            # Fix: log the actual generated filename (was a
                            # literal "(unknown)" placeholder).
                            logger.debug(f"Extracted ImageContent as artifact: {filename} ({mime_type})")

                            # Auto-open canvas for the first image when the
                            # tool supplied no display configuration.
                            if not display_config and image_counter == 0:
                                display_config = {
                                    "primary_file": filename,
                                    "open_canvas": True
                                }

                            image_counter += 1
    except Exception:
        logger.warning("Error extracting v2 MCP components from tool result", exc_info=True)

    return artifacts, display_config, meta_data
async def execute_tool_calls(
    self,
    tool_calls: List[ToolCall],
    context: Optional[Dict[str, Any]] = None
) -> List[ToolResult]:
    """Run each tool call in order and collect the results.

    Calls execute strictly sequentially (one awaited at a time), so tools
    that depend on side effects of earlier calls behave deterministically.
    """
    return [await self.execute_tool(call, context) for call in tool_calls]
async def cleanup(self):
    """Release all MCP client resources.

    FastMCP clients are opened as context managers and close themselves,
    so this currently only serves as a lifecycle log hook.
    """
    logger.info("Cleaning up MCP clients")
    # FastMCP clients handle cleanup automatically with context managers