atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2123 @@
1
+ """FastMCP client for connecting to MCP servers and managing tools."""
2
+
3
+ import asyncio
4
+ import contextvars
5
+ import json
6
+ import logging
7
+ import os
8
+ import sys
9
+ import time
10
+ from contextlib import asynccontextmanager
11
+ from pathlib import Path
12
+ from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional
13
+
14
+ from fastmcp import Client
15
+ from fastmcp.client.transports import StreamableHttpTransport
16
+
17
+ from atlas.core.log_sanitizer import sanitize_for_logging
18
+ from atlas.core.metrics_logger import log_metric
19
+ from atlas.domain.messages.models import ToolCall, ToolResult
20
+ from atlas.modules.config import config_manager
21
+ from atlas.modules.config.config_manager import resolve_env_var
22
+ from atlas.modules.mcp_tools.token_storage import AuthenticationRequiredException
23
+
24
logger = logging.getLogger(__name__)

# Async sink used to forward MCP server logs, called as
# (server_name, level, message, extra_data) and awaited.
LogCallback = Callable[
    [str, str, str, Dict[str, Any]],
    Awaitable[None],
]
28
+
29
+
30
class _ElicitationRoutingContext:
    """Routing record that ties an in-flight tool call to its elicitation destination.

    Attributes:
        server_name: MCP server the tool call runs on.
        tool_call: The tool call being executed.
        update_cb: Async callback used to push events (such as
            ``elicitation_request``) to the caller, or ``None``.
    """

    def __init__(
        self,
        server_name: str,
        tool_call: ToolCall,
        update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
    ):
        self.server_name, self.tool_call, self.update_cb = server_name, tool_call, update_cb
40
+
41
# Request-scoped MCP log sink. It is bound per tool execution (see
# MCPToolManager._use_log_callback) so that, when one MCPToolManager is shared by
# many connections, server logs reach the current request rather than leaking
# to other users.
_ACTIVE_LOG_CALLBACK: contextvars.ContextVar[Optional[LogCallback]] = contextvars.ContextVar(
    "mcp_active_log_callback", default=None
)

# Routing tables for server-initiated requests (elicitation, sampling), letting
# a shared Client deliver them to the right user's tool call.
# These are plain module-level dicts, not ContextVars, because the MCP receive
# loop that invokes the handlers runs in a different task.
# Keys are (server_name, tool_call_id) tuples so concurrent tool calls get
# distinct entries; the handlers look entries up by server_name.
_ELICITATION_ROUTING: Dict[tuple, _ElicitationRoutingContext] = dict()
_SAMPLING_ROUTING: Dict[tuple, "_SamplingRoutingContext"] = dict()
57
+
58
+
59
class _SamplingRoutingContext:
    """Routing record that ties an in-flight tool call to MCP sampling requests.

    Attributes:
        server_name: MCP server the tool call runs on.
        tool_call: The tool call being executed.
        update_cb: Optional async callback for pushing events to the caller.
    """

    def __init__(
        self,
        server_name: str,
        tool_call: ToolCall,
        update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
    ):
        self.server_name, self.tool_call, self.update_cb = server_name, tool_call, update_cb
70
+
71
# Translate MCP log level names into stdlib ``logging`` levels. Python has no
# NOTICE/ALERT/EMERGENCY, so those collapse onto the nearest stdlib level;
# "warn" is accepted as an alias of "warning".
MCP_TO_PYTHON_LOG_LEVEL = dict(
    debug=logging.DEBUG,
    info=logging.INFO,
    notice=logging.INFO,
    warning=logging.WARNING,
    warn=logging.WARNING,
    error=logging.ERROR,
    alert=logging.CRITICAL,
    critical=logging.CRITICAL,
    emergency=logging.CRITICAL,
)
83
+
84
+
85
+ class MCPToolManager:
86
+ """Manager for MCP servers and their tools.
87
+
88
+ Default config path now points to config/overrides (or env override) with legacy fallback.
89
+
90
+ Supports:
91
+ - Hot-reloading configuration from disk via reload_config()
92
+ - Tracking failed server connections for retry
93
+ - Auto-reconnect with exponential backoff (when feature flag is enabled)
94
+ """
95
+
96
def __init__(self, config_path: Optional[str] = None, log_callback: Optional[LogCallback] = None):
    """Create the manager and load MCP server definitions.

    Args:
        config_path: Path to an ``mcp.json`` file. When ``None``, server
            definitions come from ``config_manager.mcp_config`` and
            ``self.config_path`` records the first existing file among
            ``<app_config_overrides>/mcp.json``,
            ``backend/configfilesadmin/mcp.json`` and
            ``backend/configfiles/mcp.json`` (the last one is used even if it
            does not exist). When given, the file is read as a flat
            ``{server_name: server_config}`` JSON object and validated with
            ``MCPConfig``. If the file does not exist, a warning is logged and
            no servers are configured.
        log_callback: Default async log sink called with
            ``(server_name, level, message, extra_data)`` when no
            request-scoped callback is active.
    """
    if config_path is None:
        app_settings = config_manager.app_settings
        overrides_root = Path(app_settings.app_config_overrides)

        # Relative override roots are resolved against the directory containing
        # the package: this file is atlas/modules/mcp_tools/client.py, so three
        # parents up is the ``atlas`` package directory.
        if not overrides_root.is_absolute():
            package_root = Path(__file__).parent.parent.parent
            project_root = package_root.parent
            overrides_root = project_root / overrides_root

        candidate = overrides_root / "mcp.json"
        if not candidate.exists():
            # Legacy fallbacks (relative paths, resolved against the CWD).
            candidate = Path("backend/configfilesadmin/mcp.json")
            if not candidate.exists():
                candidate = Path("backend/configfiles/mcp.json")
        self.config_path = str(candidate)
        # Server definitions themselves come from the shared config manager.
        mcp_config = config_manager.mcp_config
        self.servers_config = {name: server.model_dump() for name, server in mcp_config.servers.items()}
    else:
        self.config_path = config_path
        config_file = Path(config_path)
        if config_file.exists():
            from atlas.modules.config.config_manager import MCPConfig

            # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
            # locale encoding, which breaks non-ASCII configs on e.g. Windows.
            data = json.loads(config_file.read_text(encoding="utf-8"))
            # The file is a flat {server_name: config} mapping; MCPConfig
            # expects it nested under "servers".
            mcp_config = MCPConfig(servers=data)
            self.servers_config = {name: server.model_dump() for name, server in mcp_config.servers.items()}
        else:
            logger.warning(f"Custom config path specified but file not found: {config_path}")
            self.servers_config = {}

    self.clients: Dict[str, Any] = {}
    self.available_tools: Dict[str, Any] = {}
    self.available_prompts: Dict[str, Any] = {}

    # Servers whose connection failed, tracked for retry with backoff:
    # {server_name: {"last_attempt": timestamp, "attempt_count": int, "error": str}}
    self._failed_servers: Dict[str, Dict[str, Any]] = {}

    # Background auto-reconnect task and its running flag.
    self._reconnect_task: Optional[asyncio.Task] = None
    self._reconnect_running = False

    # Default log callback, used when no request-scoped callback is active.
    # Signature: (server_name, level, message, extra_data) -> None
    self._default_log_callback = log_callback

    # Minimum level of MCP server logs that are forwarded.
    self._min_log_level = self._get_min_log_level()

    # Per-user clients for servers requiring user-specific authentication,
    # keyed by (user_email, server_name).
    self._user_clients: Dict[tuple, Client] = {}
    self._user_clients_lock = asyncio.Lock()
156
+
157
def _get_min_log_level(self) -> int:
    """Resolve the minimum MCP log level that will be forwarded.

    Uses ``config_manager.app_settings.log_level`` when it is a string.
    Otherwise, including when the settings cannot be read at all, it falls
    back to the ``LOG_LEVEL`` environment variable (default ``"INFO"``).
    Names that do not map to a numeric ``logging`` level resolve to
    ``logging.INFO``.
    """
    try:
        configured = getattr(config_manager.app_settings, "log_level", None)
        if not isinstance(configured, str):
            raise TypeError("log_level must be a string")
        name = configured.upper()
    except Exception:
        name = os.getenv("LOG_LEVEL", "INFO").upper()

    resolved = getattr(logging, name, None)
    if isinstance(resolved, int):
        return resolved
    return logging.INFO
170
+
171
def _create_log_handler(self, server_name: str):
    """Build the fastmcp log-message callback for ``server_name``.

    The returned coroutine does the following:

    * Reads the level and payload from a message exposing ``level``/``data``
      attributes, or from a dict-like message.
    * Drops messages below ``self._min_log_level``.
    * Writes the message to this module's logger, demoting sub-WARNING levels
      to DEBUG so chatty servers do not flood INFO logs.
    * Forwards ``(server_name, level, msg, extra)`` to the callback in
      ``_ACTIVE_LOG_CALLBACK`` if set, otherwise to the default callback.

    Any error while handling a message is logged as a warning and swallowed.

    NOTE(review): the elicitation routing notes that the MCP receive loop runs
    in a different task; if this handler is invoked there too, the
    ``_ACTIVE_LOG_CALLBACK`` value seen here may not be the caller's — confirm.

    Args:
        server_name: Name of the MCP server.

    Returns:
        An async function that handles LogMessage objects from fastmcp.
    """

    async def log_handler(message) -> None:
        """Handle log messages from MCP server."""
        try:
            if hasattr(message, 'level'):
                level_name = message.level.lower()
                payload = getattr(message, 'data', {})
            else:
                level_name = message.get('level', 'info').lower()
                payload = message.get('data', {})

            if isinstance(payload, dict):
                text = payload.get('msg', '')
                details = payload.get('extra', {})
            else:
                text = str(payload)
                details = {}

            numeric_level = MCP_TO_PYTHON_LOG_LEVEL.get(level_name, logging.INFO)
            if numeric_level < self._min_log_level:
                return

            # Tool servers can be chatty: keep their sub-WARNING output at DEBUG
            # so it only shows with LOG_LEVEL=DEBUG; warnings/errors keep their level.
            backend_level = logging.DEBUG if numeric_level < logging.WARNING else numeric_level

            logger.log(
                backend_level,
                f"[MCP:{sanitize_for_logging(server_name)}] {sanitize_for_logging(text)}",
                extra={"mcp_server": server_name, "mcp_extra": details},
            )

            sink = _ACTIVE_LOG_CALLBACK.get() or self._default_log_callback
            if sink is not None:
                await sink(server_name, level_name, text, details)
        except Exception as e:
            logger.warning(f"Error handling log from MCP server {server_name}: {e}")

    return log_handler
229
+
230
def set_log_callback(self, callback: Optional[LogCallback]) -> None:
    """Replace the default (manager-wide) MCP log callback.

    The default callback receives ``(server_name, level, message, extra_data)``
    for forwarded MCP server logs whenever no request-scoped callback is bound.
    Passing ``None`` disables that default forwarding.

    Args:
        callback: Async log sink, or ``None``.
    """
    self._default_log_callback = callback
237
+
238
@asynccontextmanager
async def _use_log_callback(self, callback: Optional[LogCallback]) -> AsyncIterator[None]:
    """Bind ``callback`` as the request-scoped MCP log sink for the ``async with`` body.

    Sets ``_ACTIVE_LOG_CALLBACK`` on entry and restores its previous value on
    exit (even if the body raises). This lets MCP server logs produced during a
    tool execution be forwarded to that caller's connection.
    """
    previous = _ACTIVE_LOG_CALLBACK.set(callback)
    try:
        yield
    finally:
        _ACTIVE_LOG_CALLBACK.reset(previous)
250
+
251
@asynccontextmanager
async def _use_elicitation_context(
    self,
    server_name: str,
    tool_call: ToolCall,
    update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
) -> AsyncIterator[None]:
    """Register ``tool_call`` as an elicitation target for ``server_name``.

    While the ``async with`` body runs, an ``_ElicitationRoutingContext`` is
    stored in ``_ELICITATION_ROUTING`` under ``(server_name, tool_call.id)``.
    It is removed on exit, even on error. A plain dict is used instead of a
    ContextVar because elicitation handlers are invoked from the MCP receive
    loop, which runs in a different task.
    """
    key = (server_name, tool_call.id)
    _ELICITATION_ROUTING[key] = _ElicitationRoutingContext(server_name, tool_call, update_cb)
    try:
        yield
    finally:
        _ELICITATION_ROUTING.pop(key, None)
270
+
271
def _create_elicitation_handler(self, server_name: str):
    """
    Create an elicitation handler bound to one MCP server.

    The returned coroutine finds the active tool call for ``server_name`` in
    ``_ELICITATION_ROUTING``. It sends the request to that tool call's
    ``update_cb`` as an ``elicitation_request`` event, then waits up to 300
    seconds for the user's answer via the elicitation manager. Finally it
    converts the answer into an ``ElicitResult``.

    The following all produce ``action="cancel"`` instead of raising:

    * no routing context;
    * an ambiguous routing (concurrent tool calls with different destinations);
    * a missing ``update_cb``;
    * a timeout;
    * any unexpected error.

    Args:
        server_name: Name of the MCP server this handler serves.

    Returns:
        An async ``handler(message, response_type, params, _context)`` callable.
    """

    async def handler(message, response_type, params, _context):
        """Per-server elicitation handler with captured server_name."""
        from fastmcp.client.elicitation import ElicitResult
        from mcp.types import ElicitRequestFormParams

        # Server-supplied text: sanitize before it reaches the logs.
        preview = sanitize_for_logging(str(message)[:50])

        # Entries are keyed by (server_name, tool_call_id), but the request does
        # not say which tool call issued it, so collect every match.
        candidates = [
            ctx for (srv, _tcid), ctx in _ELICITATION_ROUTING.items() if srv == server_name
        ]
        if not candidates:
            logger.warning(
                f"Elicitation request for server '{server_name}' but no routing context - "
                f"elicitation cancelled. Message: {preview}..."
            )
            return ElicitResult(action="cancel", content=None)

        routing = candidates[0]
        # Several concurrent calls on this server are only safe to route if they
        # all deliver to the same destination; otherwise the prompt (and the
        # answer) could cross into another user's session, so fail closed.
        if any(ctx.update_cb != routing.update_cb for ctx in candidates[1:]):
            logger.warning(
                f"Elicitation request for server '{server_name}' is ambiguous "
                f"({len(candidates)} concurrent tool calls with different destinations) - "
                f"elicitation cancelled. Message: {preview}..."
            )
            return ElicitResult(action="cancel", content=None)

        if routing.update_cb is None:
            logger.warning(
                f"Elicitation request for server '{server_name}', tool '{routing.tool_call.name}' "
                f"but update_cb is None - elicitation cancelled. Message: {preview}..."
            )
            return ElicitResult(action="cancel", content=None)

        response_schema: Dict[str, Any] = {}
        if isinstance(params, ElicitRequestFormParams):
            response_schema = params.requestedSchema or {}

        try:
            import uuid

            from atlas.application.chat.elicitation_manager import get_elicitation_manager

            elicitation_id = str(uuid.uuid4())
            elicitation_manager = get_elicitation_manager()

            request = elicitation_manager.create_elicitation_request(
                elicitation_id=elicitation_id,
                tool_call_id=routing.tool_call.id,
                tool_name=routing.tool_call.name,
                message=message,
                response_schema=response_schema,
            )

            # Release the pending request however we leave, including when
            # delivering it to the client fails (e.g. a closed connection).
            try:
                logger.debug(f"Sending elicitation_request to frontend for server '{server_name}'")
                await routing.update_cb(
                    {
                        "type": "elicitation_request",
                        "elicitation_id": elicitation_id,
                        "tool_call_id": routing.tool_call.id,
                        "tool_name": routing.tool_call.name,
                        "message": message,
                        "response_schema": response_schema,
                    }
                )
                response = await request.wait_for_response(timeout=300.0)
            finally:
                elicitation_manager.cleanup_request(elicitation_id)

            action = response.get("action", "cancel")
            data = response.get("data")

            if action != "accept":
                return ElicitResult(action=action, content=None)

            # Approval-only elicitation (response_type=None) must return an empty object.
            # Some UIs send placeholder payloads like {'none': ''}; don't forward them.
            if response_type is None:
                return ElicitResult(action="accept", content={})

            if data is None:
                return ElicitResult(action="accept", content=None)

            # FastMCP requires elicitation response content to be a JSON object,
            # so scalar answers are wrapped under "value".
            if not isinstance(data, dict):
                data = {"value": data}

            return ElicitResult(action="accept", content=data)

        except asyncio.TimeoutError:
            logger.warning(f"Elicitation timeout for server '{server_name}'")
            return ElicitResult(action="cancel", content=None)
        except Exception as e:
            logger.error(f"Error handling elicitation for server '{server_name}': {e}", exc_info=True)
            return ElicitResult(action="cancel", content=None)

    return handler
373
+
374
@asynccontextmanager
async def _use_sampling_context(
    self,
    server_name: str,
    tool_call: ToolCall,
    update_cb: Optional[Callable[[Dict[str, Any]], Awaitable[None]]],
) -> AsyncIterator[None]:
    """Register ``tool_call`` as a sampling target for ``server_name``.

    While the ``async with`` body runs, a ``_SamplingRoutingContext`` is stored
    in ``_SAMPLING_ROUTING`` under ``(server_name, tool_call.id)``. It is
    removed on exit, even on error. A plain dict is used instead of a
    ContextVar because sampling handlers are invoked from the MCP receive
    loop, which runs in a different task.
    """
    key = (server_name, tool_call.id)
    _SAMPLING_ROUTING[key] = _SamplingRoutingContext(server_name, tool_call, update_cb)
    try:
        yield
    finally:
        _SAMPLING_ROUTING.pop(key, None)
393
+
394
def _create_sampling_handler(self, server_name: str):
    """
    Create a sampling handler for a specific MCP server.

    The returned coroutine answers MCP sampling requests from ``server_name``
    by calling the configured LLM through ``LiteLLMCaller``:

    * A routing context for ``server_name`` must be registered in
      ``_SAMPLING_ROUTING``; otherwise the request is rejected with an exception.
    * Incoming messages (``SamplingMessage``, plain strings, or already-built
      dicts) are flattened to ``{"role", "content"}`` dicts. The optional
      ``systemPrompt`` is prepended as a system message.
    * The model is chosen from the request's model preferences, matched
      against configured model keys or their ``model_name``. Without a match,
      the first configured model is used.
    * The LLM text is returned as a ``CreateMessageResult``.

    Args:
        server_name: Name of the MCP server this handler serves.

    Returns:
        An async ``handler(messages, params=None, context=None)`` callable.
    """

    def _preference_names(raw: Any) -> Optional[List[str]]:
        """Normalize MCP model preferences into candidate model names, or None."""
        if raw is None:
            return None
        if isinstance(raw, str):
            items: List[Any] = [raw]
        elif isinstance(raw, list):
            items = raw
        else:
            # Spec-shaped ModelPreferences object: names live in hints[*].name.
            items = list(getattr(raw, "hints", None) or [])
        names: List[str] = []
        for item in items:
            if isinstance(item, str):
                name = item
            elif isinstance(item, dict):
                name = item.get("name")
            else:
                name = getattr(item, "name", None)
            if isinstance(name, str) and name:
                names.append(name)
        return names or None

    def _select_model(preferences: Optional[List[str]], models: Dict[str, Any]) -> str:
        """Return the configured model key to use, honoring preferences in order."""
        for pref in preferences or []:
            if pref in models:
                return pref
            for key, model_config in models.items():
                if model_config.model_name == pref:
                    return key
        if not models:
            raise RuntimeError(
                f"No LLM models are configured; cannot serve sampling request from server '{server_name}'"
            )
        return next(iter(models))

    async def handler(messages, params=None, context=None):
        """Per-server sampling handler with captured server_name."""
        from mcp.types import CreateMessageResult, SamplingMessage, TextContent

        # Entries are keyed by (server_name, tool_call_id); the request does not
        # identify its tool call, so the first match is used. The context only
        # gates the request and labels log lines; the result is returned to the
        # requesting server either way.
        routing = next(
            (ctx for (srv, _tcid), ctx in _SAMPLING_ROUTING.items() if srv == server_name),
            None,
        )
        if routing is None:
            logger.warning(
                f"Sampling request for server '{server_name}' but no routing context - "
                f"sampling cancelled."
            )
            raise Exception("No routing context for sampling request")

        try:
            message_dicts: List[Dict[str, Any]] = []
            for msg in messages:
                if isinstance(msg, SamplingMessage):
                    text = ""
                    if isinstance(msg.content, TextContent):
                        text = msg.content.text
                    elif isinstance(msg.content, list):
                        for item in msg.content:
                            if isinstance(item, TextContent):
                                text += item.text
                    else:
                        text = str(msg.content)
                    message_dicts.append({"role": msg.role, "content": text})
                elif isinstance(msg, str):
                    message_dicts.append({"role": "user", "content": msg})
                else:
                    # Presumably an already-formatted {"role", "content"} dict.
                    message_dicts.append(msg)

            system_prompt = getattr(params, 'systemPrompt', None) if params else None
            temperature = getattr(params, 'temperature', None) if params else None
            max_tokens = getattr(params, 'maxTokens', 512) if params else 512
            model_preferences = _preference_names(
                getattr(params, 'modelPreferences', None) if params else None
            )

            if system_prompt:
                message_dicts.insert(0, {"role": "system", "content": system_prompt})

            logger.info(
                f"Sampling request from server '{server_name}' tool '{routing.tool_call.name}': "
                f"{len(message_dicts)} messages, temperature={temperature}, max_tokens={max_tokens}"
            )

            from atlas.modules.llm.litellm_caller import LiteLLMCaller

            llm_caller = LiteLLMCaller()
            model_name = _select_model(model_preferences, config_manager.llm_config.models)

            logger.debug(
                f"Using model '{model_name}' for sampling "
                f"(preferences: {model_preferences})"
            )

            response = await llm_caller.call_plain(
                model_name=model_name,
                messages=message_dicts,
                temperature=temperature,
                max_tokens=max_tokens
            )

            logger.info(
                f"Sampling completed for server '{server_name}': "
                f"response_length={len(response) if response else 0}"
            )

            return CreateMessageResult(
                role="assistant",
                content=TextContent(type="text", text=response),
                model=model_name
            )

        except Exception as e:
            logger.error(f"Error handling sampling for server '{server_name}': {e}", exc_info=True)
            raise

    return handler
517
+
518
def reload_config(self) -> Dict[str, Any]:
    """Re-read the MCP server configuration and replace ``servers_config``.

    The new configuration comes from ``config_manager.reload_mcp_config()``;
    each server model is stored as a plain dict via ``model_dump()``. Failure
    tracking is dropped for servers that disappeared from the config. Clients,
    tools and prompts are NOT rebuilt here -- call initialize_clients() and
    discover_tools()/discover_prompts() afterward to apply the changes.

    Returns:
        Dict with ``previous_servers``, ``new_servers``, ``added``, ``removed``
        and ``unchanged`` server-name lists.
    """
    before = set(self.servers_config.keys())

    # config_manager re-reads the configuration from disk.
    reloaded = config_manager.reload_mcp_config()
    self.servers_config = {
        name: server.model_dump()
        for name, server in reloaded.servers.items()
    }
    after = set(self.servers_config.keys())

    removed = before - after
    for gone in removed:
        # Stale failure entries would otherwise linger in status output.
        self._failed_servers.pop(gone, None)

    added = after - before
    unchanged = before & after

    logger.info(
        f"MCP config reloaded: added={list(added)}, "
        f"removed={list(removed)}, unchanged={list(unchanged)}"
    )

    summary = {
        "previous_servers": list(before),
        "new_servers": list(after),
        "added": list(added),
        "removed": list(removed),
        "unchanged": list(unchanged),
    }
    return summary
559
+
560
def get_failed_servers(self) -> Dict[str, Dict[str, Any]]:
    """Get information about servers that failed to connect.

    Returns:
        Dict mapping server name to failure info including last_attempt time,
        attempt_count, and error message. Both the outer dict and every
        per-server dict are copies: failure records are updated in place when
        new failures are recorded, so handing out the inner dicts would let the
        returned snapshot change later and let callers mutate internal state.
    """
    return {name: dict(info) for name, info in self._failed_servers.items()}
568
+
569
def _record_server_failure(self, server_name: str, error: str) -> None:
    """Note one failed connection attempt for ``server_name``.

    A first failure creates ``{"last_attempt", "attempt_count": 1, "error"}``.
    A repeat failure updates that same dict in place: the count goes up by
    one, ``last_attempt`` is re-stamped with ``time.time()`` and ``error`` is
    replaced with the newest message.
    """
    if server_name not in self._failed_servers:
        self._failed_servers[server_name] = {
            "last_attempt": time.time(),
            "attempt_count": 1,
            "error": error,
        }
        return

    record = self._failed_servers[server_name]
    record["attempt_count"] += 1
    record["last_attempt"] = time.time()
    record["error"] = error
581
+
582
def _clear_server_failure(self, server_name: str) -> None:
    """Forget any recorded failure for ``server_name``; no-op if none exists."""
    if server_name in self._failed_servers:
        del self._failed_servers[server_name]
585
+
586
def _calculate_backoff_delay(self, attempt_count: int) -> float:
    """Calculate exponential backoff delay for reconnection attempts.

    delay = mcp_reconnect_interval * mcp_reconnect_backoff_multiplier ** (attempt_count - 1),
    capped at mcp_reconnect_max_interval. All three values come from
    ``config_manager.app_settings``. The result is compared with
    ``time.time()`` differences, so it is in seconds.

    Args:
        attempt_count: Number of failed attempts so far (>= 1 for recorded
            failures; smaller values are treated as 1).

    Returns:
        The delay before the next attempt, never above the configured maximum.
    """
    app_settings = config_manager.app_settings
    base_interval = app_settings.mcp_reconnect_interval
    max_interval = app_settings.mcp_reconnect_max_interval
    multiplier = app_settings.mcp_reconnect_backoff_multiplier

    # Clamp so an attempt_count of 0 cannot produce a delay below the base.
    exponent = max(attempt_count - 1, 0)
    try:
        delay = base_interval * (multiplier ** exponent)
    except OverflowError:
        # Attempts keep counting while a server stays down (one per capped
        # interval), and float pow / int->float conversion overflow after
        # ~1000 attempts with a 2x multiplier. The result would be capped
        # anyway, so return the cap instead of aborting the reconnect pass.
        return max_interval
    return min(delay, max_interval)
598
+
599
+
600
def _determine_transport_type(self, config: Dict[str, Any]) -> str:
    """Decide which transport a server configuration should use.

    Resolution order:
    1. A truthy ``transport`` field is returned verbatim.
    2. Any truthy ``command`` means ``"stdio"``.
    3. A ``url``:
       - starting with http:// or https:// -> ``"sse"`` if it ends in
         ``/sse``, otherwise ``"http"``;
       - without a scheme -> the ``type`` field when it is ``"http"`` or
         ``"sse"``, otherwise ``"http"``.
    4. Otherwise the ``type`` field (backward compatibility), default ``"stdio"``.
    """
    explicit = config.get("transport")
    if explicit:
        logger.debug(f"Using explicit transport: {explicit}")
        return explicit

    if config.get("command"):
        logger.debug("Auto-detected STDIO transport from command")
        return "stdio"

    url = config.get("url")
    declared_type = config.get("type", "stdio")

    if not url:
        logger.debug(f"Using fallback transport type: {declared_type}")
        return declared_type

    if not url.startswith(("http://", "https://")):
        # No scheme: only trust the legacy 'type' field if it names a network transport.
        if declared_type in ("http", "sse"):
            logger.debug(f"Using type field '{declared_type}' for URL without protocol: {url}")
            return declared_type
        logger.debug(f"URL without protocol, defaulting to HTTP: {url}")
        return "http"

    if url.endswith("/sse"):
        logger.debug(f"Auto-detected SSE transport from URL: {url}")
        return "sse"

    logger.debug(f"Auto-detected HTTP transport from URL: {url}")
    return "http"
643
+
644
async def _initialize_single_client(self, server_name: str, config: Dict[str, Any]) -> Optional[Client]:
    """Initialize a single MCP client. Returns None if initialization fails.

    The transport comes from ``_determine_transport_type(config)``:

    * ``http`` / ``sse``: requires ``url`` (``http://`` is prepended when it
      has no scheme). ``auth_token`` is resolved with ``resolve_env_var``.
      When the config sets ``transport`` explicitly, the matching FastMCP
      transport object is built so the explicit choice is honoured. Otherwise
      the URL is passed to ``Client`` unchanged.
    * ``stdio`` with ``command``: ``python``/``python3`` is replaced by
      ``sys.executable``. The optional ``cwd`` (relative paths resolved
      against a fixed base directory) must be an existing directory. Optional
      ``env`` values are resolved with ``resolve_env_var``.
    * ``stdio`` without ``command``: legacy ``mcp/<server_name>/main.py``
      script, relative to the process working directory.

    Discovery later opens the returned client with ``async with client``.

    Args:
        server_name: Key of the server in ``servers_config``.
        config: The server's configuration dict.

    Returns:
        The new ``Client``, or ``None`` when the configuration is unusable or
        client construction raised. ``Exception``s are logged and swallowed.
    """
    safe_server_name = sanitize_for_logging(server_name)
    # Keep INFO logs concise; config/transport details can be very verbose.
    logger.info("Initializing MCP client for server '%s'", safe_server_name)
    logger.debug("Server config for '%s': %s", safe_server_name, sanitize_for_logging(str(config)))

    # Bound before the try so the except handler can always read it, even
    # when _determine_transport_type itself raises.
    transport_type = None

    def client_handlers() -> Dict[str, Any]:
        # Per-server callbacks shared by every transport kind.
        return {
            "log_handler": self._create_log_handler(server_name),
            "elicitation_handler": self._create_elicitation_handler(server_name),
            "sampling_handler": self._create_sampling_handler(server_name),
        }

    try:
        transport_type = self._determine_transport_type(config)
        logger.debug("Determined transport type for %s: %s", safe_server_name, transport_type)

        if transport_type in ("http", "sse"):
            url = config.get("url")
            if not url:
                logger.error(f"No URL provided for HTTP/SSE server: {server_name}")
                return None

            # Ensure URL has protocol for FastMCP client
            if not url.startswith(("http://", "https://")):
                url = f"http://{url}"
                logger.debug(f"Added http:// protocol to URL: {url}")

            try:
                token = resolve_env_var(config.get("auth_token"))  # Resolve ${ENV_VAR} if present
            except ValueError as e:
                logger.error(f"Failed to resolve auth_token for {server_name}: {e}")
                return None  # Skip this server

            if transport_type == "sse":
                logger.debug(f"Creating SSE client for {server_name} at {url}")
            else:
                logger.debug(f"Creating HTTP client for {server_name} at {url}")

            if config.get("transport") in ("http", "sse"):
                # An explicit transport must win even when the URL would lead
                # FastMCP to pick the other one (e.g. "sse" on a URL that does
                # not end in /sse), so build the transport object directly.
                from fastmcp.client.transports import SSETransport, StreamableHttpTransport
                target = SSETransport(url) if transport_type == "sse" else StreamableHttpTransport(url)
            else:
                # Auto-detected transport: hand the URL to Client unchanged.
                target = url

            client = Client(target, auth=token, **client_handlers())
            logger.info(f"Created {transport_type.upper()} MCP client for {server_name}")
            return client

        if transport_type != "stdio":
            logger.error(f"Unsupported transport type '{transport_type}' for server: {server_name}")
            return None

        command = config.get("command")
        logger.debug("STDIO transport command for %s: %s", safe_server_name, command)

        if not command:
            # Fallback to old behavior for backward compatibility
            server_path = f"mcp/{server_name}/main.py"
            logger.debug(f"Attempting to initialize {server_name} at path: {server_path}")
            if not os.path.exists(server_path):
                # Not inside an except block: exc_info would only log "NoneType: None".
                logger.error(f"MCP server script not found: {server_path}")
                return None
            logger.debug(f"Server script exists for {server_name}, creating client...")
            client = Client(server_path, **client_handlers())  # Client auto-detects STDIO transport from .py file
            logger.info(f"Created MCP client for {server_name}")
            logger.debug(f"Successfully created client for {server_name}")
            return client

        # Ensure MCP stdio servers run under the same interpreter as the backend.
        # In dev containers, PATH `python` may not have required deps.
        if command[0] in {"python", "python3"}:
            command = [sys.executable, *command[1:]]

        cwd = config.get("cwd")
        env = config.get("env")
        logger.debug("Working directory specified for %s: %s", safe_server_name, cwd)

        # Resolve environment variables in env dict
        resolved_env = None
        if env is not None:
            resolved_env = {}
            for key, value in env.items():
                try:
                    resolved_env[key] = resolve_env_var(value)
                except ValueError as e:
                    logger.error(f"Failed to resolve env var {key} for {server_name}: {e}")
                    return None  # Skip this server if env var resolution fails
                logger.debug(f"Resolved env var {key} for {server_name}")
            logger.debug("Environment variables specified for %s: %s", safe_server_name, list(resolved_env.keys()))

        from fastmcp.client.transports import StdioTransport

        if cwd:
            if not os.path.isabs(cwd):
                # Relative cwd values are resolved against the directory three
                # levels above this module's own directory (dirname applied four
                # times). Presumably the project root -- confirm for installed layouts.
                project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
                cwd = os.path.join(project_root, cwd)
                logger.debug("Converted relative cwd to absolute for %s: %s", safe_server_name, cwd)

            # isdir rather than exists: a regular file cannot be a working directory.
            if not os.path.isdir(cwd):
                logger.error(f"Working directory does not exist: {cwd}")
                return None

            logger.debug("Working directory exists for %s: %s", safe_server_name, cwd)
            logger.debug("Creating STDIO client for %s with command=%s cwd=%s", safe_server_name, command, cwd)
            transport = StdioTransport(command=command[0], args=command[1:], cwd=cwd, env=resolved_env)
            success_message = f"Successfully created STDIO MCP client for {server_name} with custom command and cwd"
        else:
            logger.debug("No cwd specified for %s; creating STDIO client with command=%s", safe_server_name, command)
            transport = StdioTransport(command=command[0], args=command[1:], env=resolved_env)
            success_message = f"Successfully created STDIO MCP client for {server_name} with custom command"

        client = Client(transport, **client_handlers())
        logger.info(success_message)
        return client

    except Exception as e:
        # Targeted debugging for MCP startup errors
        error_type = type(e).__name__
        logger.error(f"Error creating client for {server_name}: {error_type}: {e}")

        # Match on the exception type name as well as its text: some errors
        # (e.g. TimeoutError) are raised with an empty message.
        error_lower = f"{error_type}: {e}".lower()

        # Provide specific debugging information based on error type and config
        if "connection" in error_lower or "refused" in error_lower:
            if transport_type in ("http", "sse"):
                logger.error(f"DEBUG: Connection failed for HTTP/SSE server '{server_name}'")
                logger.error(f" → URL: {config.get('url', 'Not specified')}")
                logger.error(f" → Transport: {transport_type}")
                logger.error(" → Check if server is running and accessible")
            else:
                logger.error(f"DEBUG: STDIO connection failed for server '{server_name}'")
                logger.error(f" → Command: {config.get('command', 'Not specified')}")
                logger.error(f" → CWD: {config.get('cwd', 'Not specified')}")
                logger.error(" → Check if command exists and is executable")

        elif "timeout" in error_lower:
            logger.error(f"DEBUG: Timeout connecting to server '{server_name}'")
            logger.error(" → Server may be slow to start or overloaded")
            logger.error(" → Consider increasing timeout or checking server health")

        elif "permission" in error_lower or "access" in error_lower:
            logger.error(f"DEBUG: Permission error for server '{server_name}'")
            if config.get('cwd'):
                logger.error(f" → Check directory permissions: {config.get('cwd')}")
            if config.get('command'):
                logger.error(f" → Check executable permissions: {config.get('command')}")

        elif "module" in error_lower or "import" in error_lower:
            logger.error(f"DEBUG: Import/module error for server '{server_name}'")
            logger.error(" → Check if required dependencies are installed")
            logger.error(" → Check Python path and virtual environment")

        elif "json" in error_lower or "decode" in error_lower:
            logger.error(f"DEBUG: JSON/protocol error for server '{server_name}'")
            logger.error(" → Server may not be MCP-compatible")
            logger.error(" → Check server output format")

        else:
            # Generic debugging info. Only key names at ERROR level: the config
            # can contain auth tokens and env secrets. The full config is
            # already logged (sanitized) at DEBUG on entry.
            logger.error(f"DEBUG: Generic error for server '{server_name}'")
            logger.error(f" → Config keys: {sorted(config)}")
            logger.error(f" → Transport type: {transport_type}")

        # Always show the full traceback in debug mode
        logger.debug(f"Full traceback for {server_name}:", exc_info=True)
        return None
841
+
842
async def initialize_clients(self):
    """Initialize FastMCP clients for all configured servers in parallel.

    Every entry of ``servers_config`` goes through ``_initialize_single_client``
    concurrently. Successful clients are stored in ``self.clients`` and their
    failure record is cleared. A server whose initialization raised or
    returned ``None`` gets a failure recorded via ``_record_server_failure``.
    """
    logger.info("Starting MCP client initialization for %d servers", len(self.servers_config))
    logger.debug("MCP servers to initialize: %s", list(self.servers_config.keys()))

    server_names = list(self.servers_config.keys())

    # Run all initialization tasks in parallel
    results = await asyncio.gather(
        *(self._initialize_single_client(name, self.servers_config[name]) for name in server_names),
        return_exceptions=True,
    )

    for server_name, result in zip(server_names, results):
        # BaseException, not Exception: with return_exceptions=True gather also
        # returns e.g. CancelledError, which must never be stored as a client.
        if isinstance(result, BaseException):
            error_msg = f"{type(result).__name__}: {result}"
            # No exception is being handled here, so pass the instance itself
            # (exc_info=True would only log "NoneType: None").
            logger.error(f"Exception during client initialization for {server_name}: {error_msg}", exc_info=result)
            self._record_server_failure(server_name, error_msg)
        elif result is not None:
            self.clients[server_name] = result
            self._clear_server_failure(server_name)
            logger.info(f"Successfully initialized client for {server_name}")
        else:
            self._record_server_failure(server_name, "Initialization returned None")
            logger.warning(f"Failed to initialize client for {server_name}")

    # Count against the current config so stale clients of servers removed by
    # a reload cannot push "connected" above the configured total.
    connected = [name for name in self.servers_config if name in self.clients]
    failed_servers = sorted(set(self.servers_config.keys()) - set(self.clients.keys()))
    logger.info(
        "MCP client initialization complete: %d/%d connected (%d failed)",
        len(connected),
        len(self.servers_config),
        len(failed_servers),
    )
    logger.debug("MCP clients initialized: %s", list(self.clients.keys()))
    logger.debug("MCP clients failed to initialize: %s", failed_servers)
880
+
881
async def reconnect_failed_servers(self, force: bool = False) -> Dict[str, Any]:
    """Attempt to reconnect to servers that previously failed.

    When ``force`` is False (default), this respects exponential backoff and
    only attempts servers whose backoff delay has elapsed. When ``force`` is
    True, backoff delays are ignored and all currently failed servers are
    attempted immediately. The admin `/admin/mcp/reconnect` endpoint uses
    ``force=True`` to provide an on-demand retry button.

    A server only counts as reconnected when tool/prompt discovery through the
    new client records no failure. The client is first opened during
    discovery (``async with client``), and the discovery helpers record
    failures instead of raising. If discovery fails, the new client is
    discarded so the server stays eligible for retry, and its attempt count
    keeps growing for backoff.

    Returns:
        Dict with ``attempted``, ``reconnected`` and ``still_failed`` server
        names, plus ``skipped_backoff`` entries (server, wait_remaining,
        attempt_count) for servers still inside their backoff window.
    """
    if not self._failed_servers:
        return {
            "attempted": [],
            "reconnected": [],
            "still_failed": [],
            "skipped_backoff": []
        }

    current_time = time.time()
    attempted = []
    reconnected = []
    still_failed = []
    skipped_backoff = []

    for server_name, failure_info in list(self._failed_servers.items()):
        # Skip if server is no longer in config
        if server_name not in self.servers_config:
            self._clear_server_failure(server_name)
            continue

        # Skip if already connected
        if server_name in self.clients:
            self._clear_server_failure(server_name)
            continue

        attempt_count = failure_info["attempt_count"]

        # Check backoff delay unless this is a forced reconnect
        backoff_delay = self._calculate_backoff_delay(attempt_count)
        time_since_last = current_time - failure_info["last_attempt"]

        if not force and time_since_last < backoff_delay:
            skipped_backoff.append({
                "server": server_name,
                "wait_remaining": backoff_delay - time_since_last,
                "attempt_count": attempt_count
            })
            continue

        attempted.append(server_name)
        config = self.servers_config[server_name]

        try:
            client = await self._initialize_single_client(server_name, config)
            if client is None:
                self._record_server_failure(server_name, "Reconnection returned None")
                still_failed.append(server_name)
                continue

            self.clients[server_name] = client
            self._clear_server_failure(server_name)

            # Discover tools and prompts for the reconnected server
            await self._discover_and_register_server(server_name, client)

            discovery_failure = self._failed_servers.get(server_name)
            if discovery_failure is None:
                reconnected.append(server_name)
                logger.info(f"Successfully reconnected to MCP server: {server_name}")
            else:
                # Discovery could not reach the server. Drop the client so the
                # "already connected" shortcut above does not wipe this failure
                # on the next pass, and continue the backoff sequence instead
                # of restarting it at 1.
                self.clients.pop(server_name, None)
                discovery_failure["attempt_count"] = attempt_count + 1
                still_failed.append(server_name)
                logger.warning(
                    f"Failed to reconnect to MCP server {server_name}: {discovery_failure.get('error')}"
                )
        except Exception as e:
            error_msg = f"{type(e).__name__}: {e}"
            # The server was not in self.clients before this attempt.
            self.clients.pop(server_name, None)
            self._record_server_failure(server_name, error_msg)
            still_failed.append(server_name)
            logger.warning(f"Failed to reconnect to MCP server {server_name}: {error_msg}")

    return {
        "attempted": attempted,
        "reconnected": reconnected,
        "still_failed": still_failed,
        "skipped_backoff": skipped_backoff
    }
960
+
961
async def _discover_and_register_server(self, server_name: str, client: Client) -> None:
    """Run tool and prompt discovery for one server and store the results.

    Writes ``available_tools[server_name]`` and ``available_prompts[server_name]``.
    When the instance already has a ``_tool_index`` attribute, it also adds one
    ``"<server>_<tool>"`` entry per discovered tool. Any ``Exception`` is
    logged and swallowed.
    """
    try:
        tool_data = await self._discover_tools_for_server(server_name, client)
        self.available_tools[server_name] = tool_data

        # The index only exists once discover_tools() has built it.
        if hasattr(self, "_tool_index"):
            for discovered in tool_data.get('tools', []):
                index_key = f"{server_name}_{discovered.name}"
                self._tool_index[index_key] = {
                    'server': server_name,
                    'tool': discovered,
                }

        prompt_data = await self._discover_prompts_for_server(server_name, client)
        self.available_prompts[server_name] = prompt_data

        tool_total = len(tool_data.get('tools', []))
        prompt_total = len(prompt_data.get('prompts', []))
        logger.info(
            f"Registered server {server_name}: "
            f"{tool_total} tools, "
            f"{prompt_total} prompts"
        )
    except Exception as e:
        logger.error(f"Error discovering tools/prompts for {server_name}: {e}")
988
+
989
async def start_auto_reconnect(self) -> None:
    """Launch the background task that periodically retries failed servers.

    Only logs and returns when the ``feature_mcp_auto_reconnect_enabled`` app
    setting is off (FEATURE_MCP_AUTO_RECONNECT_ENABLED) or when
    ``_reconnect_running`` is already set. Otherwise it sets the flag and
    schedules ``_auto_reconnect_loop`` with ``asyncio.create_task``.
    """
    feature_on = config_manager.app_settings.feature_mcp_auto_reconnect_enabled
    if not feature_on:
        logger.info("MCP auto-reconnect is disabled (FEATURE_MCP_AUTO_RECONNECT_ENABLED=false)")
        return

    if self._reconnect_running:
        logger.warning("Auto-reconnect task is already running")
        return

    self._reconnect_running = True
    loop_coroutine = self._auto_reconnect_loop()
    self._reconnect_task = asyncio.create_task(loop_coroutine)
    logger.info("Started MCP auto-reconnect background task")
1007
+
1008
async def stop_auto_reconnect(self) -> None:
    """Stop the background auto-reconnect task.

    Clears ``_reconnect_running``. If a task exists, it is cancelled and
    awaited (its CancelledError is absorbed) and ``_reconnect_task`` is reset
    to None. The "Stopped" message is logged either way.
    """
    self._reconnect_running = False

    running_task = self._reconnect_task
    if running_task:
        running_task.cancel()
        try:
            await running_task
        except asyncio.CancelledError:
            pass
        self._reconnect_task = None

    logger.info("Stopped MCP auto-reconnect background task")
1019
+
1020
async def _auto_reconnect_loop(self) -> None:
    """Background loop that periodically attempts to reconnect failed servers.

    While ``_reconnect_running`` is set, it sleeps ``mcp_reconnect_interval``
    (read once, at start) and then, if any failures are tracked, runs
    ``reconnect_failed_servers()`` with backoff respected. Cancellation
    during the normal cycle ends the loop. Any other exception is logged and
    followed by one extra sleep before the next cycle.
    """
    interval = config_manager.app_settings.mcp_reconnect_interval

    while self._reconnect_running:
        try:
            await asyncio.sleep(interval)

            pending = len(self._failed_servers)
            if not pending:
                continue

            logger.debug(
                f"Auto-reconnect: checking {pending} failed servers"
            )
            outcome = await self.reconnect_failed_servers()

            recovered = outcome["reconnected"]
            if recovered:
                logger.info(
                    f"Auto-reconnect: successfully reconnected {len(recovered)} servers: "
                    f"{recovered}"
                )

            lingering = outcome["still_failed"]
            if lingering:
                logger.debug(
                    f"Auto-reconnect: {len(lingering)} servers still failed"
                )

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in auto-reconnect loop: {e}", exc_info=True)
            # Extra pause before the next cycle after an unexpected error.
            await asyncio.sleep(interval)
1052
+
1053
async def _discover_tools_for_server(self, server_name: str, client: Client) -> Dict[str, Any]:
    """Discover tools for a single server. Returns server tools data.

    Opens ``client`` (``async with``) and calls ``list_tools()``, bounded by
    the ``mcp_discovery_timeout`` app setting.

    Returns:
        ``{'tools': <list>, 'config': <server config dict>}``. On any
        ``Exception`` the error is logged with targeted hints, recorded via
        ``_record_server_failure``, and an empty tool list is returned instead
        of raising.
    """
    safe_server_name = sanitize_for_logging(server_name)
    # Snapshot the config once: reload_config() can replace servers_config
    # while this coroutine is suspended, and re-indexing it afterwards would
    # raise KeyError for a server that was just removed.
    server_config = self.servers_config.get(server_name, {})
    safe_config = sanitize_for_logging(str(server_config))
    discovery_timeout = config_manager.app_settings.mcp_discovery_timeout
    logger.debug("Tool discovery: starting for server '%s'", safe_server_name)
    logger.debug("Server config (sanitized): %s", safe_config)
    try:
        logger.debug("Opening client connection for %s", safe_server_name)
        async with client:
            logger.debug("Client connected for %s; listing tools", safe_server_name)
            tools = await asyncio.wait_for(client.list_tools(), timeout=discovery_timeout)
            logger.debug("Got %d tools from %s: %s", len(tools), safe_server_name, [tool.name for tool in tools])

            # Log detailed tool information
            for i, tool in enumerate(tools):
                logger.debug(
                    " Tool %d: name='%s', description='%s'",
                    i + 1,
                    tool.name,
                    getattr(tool, 'description', 'No description'),
                )

            server_data = {
                'tools': tools,
                'config': server_config,
            }
            logger.debug("Stored %d tools for %s", len(tools), safe_server_name)
            return server_data
    except Exception as e:
        error_type = type(e).__name__
        error_msg = sanitize_for_logging(str(e))
        logger.error(f"TOOL DISCOVERY FAILED for '{safe_server_name}': {error_type}: {error_msg}")

        # Match on the type name too: asyncio.wait_for raises TimeoutError
        # with an empty message, which would otherwise reach the generic branch.
        error_lower = f"{error_type}: {e}".lower()

        # Targeted debugging for tool discovery errors
        if "connection" in error_lower or "refused" in error_lower:
            logger.error(f"DEBUG: Connection lost during tool discovery for '{safe_server_name}'")
            logger.error(" → Server may have crashed or disconnected")
            logger.error(" → Check server logs for startup errors")
            # Check if this is an HTTPS/SSL issue
            if "ssl" in error_lower or "certificate" in error_lower or "https" in error_lower:
                logger.error(" → SSL/HTTPS error detected")
                logger.error(" → On Windows, ensure SSL certificates are properly configured")
                logger.error(" → Try setting REQUESTS_CA_BUNDLE or SSL_CERT_FILE environment variables")
        elif "timeout" in error_lower:
            logger.error(f"DEBUG: Timeout during tool discovery for '{safe_server_name}'")
            logger.error(" → Server is slow to respond to list_tools() request")
            logger.error(" → Server may be overloaded or hanging")
        elif "json" in error_lower or "decode" in error_lower:
            logger.error(f"DEBUG: Protocol error during tool discovery for '{safe_server_name}'")
            logger.error(" → Server returned invalid MCP response")
            logger.error(" → Check if server implements MCP protocol correctly")
        elif "ssl" in error_lower or "certificate" in error_lower:
            logger.error(f"DEBUG: SSL/Certificate error during tool discovery for '{safe_server_name}'")
            logger.error(f" → URL: {server_config.get('url', 'N/A')}")
            logger.error(" → SSL certificate verification failed")
            logger.error(" → On Windows, this may require installing/updating CA certificates")
            logger.error(" → Check if the server URL uses HTTPS with a self-signed or untrusted certificate")
        else:
            logger.error(f"DEBUG: Generic tool discovery error for '{safe_server_name}'")
            logger.error(f" → Client type: {type(client).__name__}")
            logger.error(f" → Server URL: {server_config.get('url', 'N/A')}")
            logger.error(f" → Transport type: {server_config.get('transport', server_config.get('type', 'N/A'))}")

        # Record failure for status/reconnect purposes
        self._record_server_failure(server_name, f"{error_type}: {error_msg}")

        logger.debug(f"Full tool discovery traceback for {safe_server_name}:", exc_info=True)

        server_data = {
            'tools': [],
            'config': server_config,
        }
        logger.debug(
            "Set empty tools list for failed server '%s' (config_present=%s)",
            safe_server_name,
            server_config is not None,
        )
        return server_data
1134
+
1135
async def discover_tools(self):
    """Discover tools from all MCP servers in parallel.

    Rebuilds ``self.available_tools``, with one entry per connected server
    still present in ``servers_config``. Also rebuilds ``self._tool_index``,
    mapping ``"<server>_<tool>"`` to ``{'server', 'tool'}``; the ``canvas``
    server maps to a single ``canvas_canvas`` pseudo-tool entry.

    ``_discover_tools_for_server`` reports failures by recording them and
    returning an empty tool list rather than raising. A server's failure
    record is therefore cleared only when no new failure was recorded during
    this run.
    """
    logger.info("Starting MCP tool discovery for %d connected servers", len(self.clients))
    logger.debug("Tool discovery servers: %s", list(self.clients.keys()))
    self.available_tools = {}

    def failure_marker(name):
        # Identifies the current failure record (None if there is none), so a
        # failure recorded during this run can be told apart from an older one.
        entry = self._failed_servers.get(name)
        if entry is None:
            return None
        return (entry.get("attempt_count"), entry.get("last_attempt"))

    server_names = list(self.clients.keys())
    markers_before = {name: failure_marker(name) for name in server_names}

    # Run all tool discovery tasks in parallel
    results = await asyncio.gather(
        *(self._discover_tools_for_server(name, self.clients[name]) for name in server_names),
        return_exceptions=True,
    )

    # Process results and store server tools data
    for server_name, result in zip(server_names, results):
        # Skip clients whose config was removed during reload
        if server_name not in self.servers_config:
            logger.warning(
                f"Skipping tool discovery result for '{server_name}' because it is no longer in servers_config"
            )
            continue

        # BaseException: gather(return_exceptions=True) also returns CancelledError.
        if isinstance(result, BaseException):
            # Pass the instance: no exception is being handled at this point.
            logger.error(f"Exception during tool discovery for {server_name}: {result}", exc_info=result)
            # Record failure and set empty tools list for failed server
            self._record_server_failure(server_name, f"Exception during tool discovery: {result}")
            self.available_tools[server_name] = {
                'tools': [],
                'config': self.servers_config.get(server_name),
            }
            continue

        marker_now = failure_marker(server_name)
        if marker_now is None or marker_now == markers_before[server_name]:
            # Discovery succeeded: clear any failure left from earlier attempts.
            self._clear_server_failure(server_name)
        self.available_tools[server_name] = result

    total_tools = sum(len(server_data.get('tools', [])) for server_data in self.available_tools.values())
    logger.info(
        "MCP tool discovery complete: %d tools across %d servers",
        total_tools,
        len(self.available_tools),
    )
    for server_name, server_data in self.available_tools.items():
        tool_names = [tool.name for tool in server_data.get('tools', [])]
        logger.debug("Tool discovery summary: %s: %d tools %s", server_name, len(tool_names), tool_names)

    # Build tool index for quick lookups
    self._tool_index = {}
    for server_name, server_data in self.available_tools.items():
        if server_name == "canvas":
            self._tool_index["canvas_canvas"] = {
                'server': 'canvas',
                'tool': None  # pseudo tool
            }
        else:
            for tool in server_data.get('tools', []):
                full_name = f"{server_name}_{tool.name}"
                self._tool_index[full_name] = {
                    'server': server_name,
                    'tool': tool
                }
1198
+
1199
async def _discover_prompts_for_server(self, server_name: str, client: Client) -> Dict[str, Any]:
    """Discover prompts for a single server. Returns server prompts data.

    Opens ``client`` (``async with``) and calls ``list_prompts()``, bounded by
    the ``mcp_discovery_timeout`` app setting.

    Returns:
        ``{'prompts': <list>, 'config': <server config dict>}``. If
        ``list_prompts()`` itself fails, the server is treated as not
        supporting prompts: the failure is logged at DEBUG only and an empty
        list is returned. If opening the connection fails, the error is logged
        with targeted hints and recorded via ``_record_server_failure``, and an
        empty list is returned. ``Exception``s are never raised.
    """
    safe_server_name = sanitize_for_logging(server_name)
    server_config = self.servers_config.get(server_name, {})
    discovery_timeout = config_manager.app_settings.mcp_discovery_timeout
    logger.debug(f"Attempting to discover prompts from {safe_server_name}")
    try:
        logger.debug(f"Opening client connection for {safe_server_name}")
        async with client:
            logger.debug(f"Client connected for {safe_server_name}, listing prompts...")
            try:
                prompts = await asyncio.wait_for(client.list_prompts(), timeout=discovery_timeout)
                logger.debug(
                    f"Got {len(prompts)} prompts from {safe_server_name}: {[prompt.name for prompt in prompts]}"
                )
                server_data = {
                    'prompts': prompts,
                    'config': server_config,
                }
                logger.info(f"Discovered {len(prompts)} prompts from {safe_server_name}")
                logger.debug(f"Successfully stored prompts for {safe_server_name}")
                return server_data
            except Exception as e:
                # Server might not support prompts, or list_prompts() failed; store an empty list.
                logger.debug(
                    f"Server {safe_server_name} does not support prompts or list_prompts() failed: {e}"
                )
                return {
                    'prompts': [],
                    'config': server_config,
                }
    except Exception as e:
        error_type = type(e).__name__
        error_msg = sanitize_for_logging(str(e))
        logger.error(f"PROMPT DISCOVERY FAILED for '{safe_server_name}': {error_type}: {error_msg}")

        # Match on the type name too: timeout errors often carry an empty
        # message, which would otherwise reach the generic branch.
        error_lower = f"{error_type}: {e}".lower()

        # Targeted debugging for prompt discovery errors
        if "connection" in error_lower or "refused" in error_lower:
            logger.error(f"DEBUG: Connection lost during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server may have crashed or disconnected")
            # Check if this is an HTTPS/SSL issue
            if "ssl" in error_lower or "certificate" in error_lower or "https" in error_lower:
                logger.error(" → SSL/HTTPS error detected")
                logger.error(" → On Windows, ensure SSL certificates are properly configured")
        elif "timeout" in error_lower:
            logger.error(f"DEBUG: Timeout during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server is slow to respond to list_prompts() request")
        elif "json" in error_lower or "decode" in error_lower:
            logger.error(f"DEBUG: Protocol error during prompt discovery for '{safe_server_name}'")
            logger.error(" → Server returned invalid MCP response for prompts")
        elif "ssl" in error_lower or "certificate" in error_lower:
            logger.error(f"DEBUG: SSL/Certificate error during prompt discovery for '{safe_server_name}'")
            logger.error(f" → URL: {server_config.get('url', 'N/A')}")
            logger.error(" → SSL certificate verification failed")
            logger.error(" → On Windows, this may require installing/updating CA certificates")
        else:
            logger.error(f"DEBUG: Generic prompt discovery error for '{safe_server_name}'")

        # Record failure for status/reconnect purposes
        self._record_server_failure(server_name, f"{error_type}: {error_msg}")

        logger.debug(f"Full prompt discovery traceback for {safe_server_name}:", exc_info=True)
        logger.debug(f"Set empty prompts list for failed server {safe_server_name}")
        return {
            'prompts': [],
            'config': server_config,
        }
1267
+
1268
+ async def discover_prompts(self):
1269
+ """Discover prompts from all MCP servers in parallel."""
1270
+ logger.info("Starting MCP prompt discovery for %d connected servers", len(self.clients))
1271
+ logger.debug("Prompt discovery servers: %s", list(self.clients.keys()))
1272
+ self.available_prompts = {}
1273
+
1274
+ # Create tasks for parallel prompt discovery
1275
+ tasks = [
1276
+ self._discover_prompts_for_server(server_name, client)
1277
+ for server_name, client in self.clients.items()
1278
+ ]
1279
+ server_names = list(self.clients.keys())
1280
+
1281
+ # Run all prompt discovery tasks in parallel
1282
+ results = await asyncio.gather(*tasks, return_exceptions=True)
1283
+
1284
+ # Process results and store server prompts data
1285
+ for server_name, result in zip(server_names, results):
1286
+ # Skip clients whose config was removed during reload
1287
+ if server_name not in self.servers_config:
1288
+ logger.warning(
1289
+ f"Skipping prompt discovery result for '{server_name}' because it is no longer in servers_config"
1290
+ )
1291
+ continue
1292
+
1293
+ if isinstance(result, Exception):
1294
+ logger.error(f"Exception during prompt discovery for {server_name}: {result}", exc_info=True)
1295
+ # Record failure and set empty prompts list for failed server
1296
+ self._record_server_failure(server_name, f"Exception during prompt discovery: {result}")
1297
+ self.available_prompts[server_name] = {
1298
+ 'prompts': [],
1299
+ 'config': self.servers_config.get(server_name),
1300
+ }
1301
+ else:
1302
+ # Clear any previous discovery failure on success
1303
+ self._clear_server_failure(server_name)
1304
+ self.available_prompts[server_name] = result
1305
+
1306
+ total_prompts = sum(len(server_data.get('prompts', [])) for server_data in self.available_prompts.values())
1307
+ logger.info(
1308
+ "MCP prompt discovery complete: %d prompts across %d servers",
1309
+ total_prompts,
1310
+ len(self.available_prompts),
1311
+ )
1312
+ for server_name, server_data in self.available_prompts.items():
1313
+ prompt_names = [prompt.name for prompt in server_data.get('prompts', [])]
1314
+ logger.debug("Prompt discovery summary: %s: %d prompts %s", server_name, len(prompt_names), prompt_names)
1315
+
1316
def get_server_groups(self, server_name: str) -> List[str]:
    """Return the groups a user must belong to in order to use ``server_name``.

    Unknown servers require no groups, so an empty list is returned for them.
    """
    if server_name not in self.servers_config:
        return []
    server_config = self.servers_config[server_name]
    return server_config.get("groups", [])
1321
+
1322
def get_available_servers(self) -> List[str]:
    """Return every configured server name, in configuration order."""
    configured = self.servers_config
    return [name for name in configured.keys()]
1325
+
1326
def get_tools_for_servers(self, server_names: List[str]) -> Dict[str, Any]:
    """Build OpenAI function-calling schemas for the tools of the selected servers.

    Args:
        server_names: Servers to include. ``"canvas"`` selects the built-in
            canvas pseudo-tool. Unknown names are ignored. A name listed more
            than once is processed only once, so the schema list never carries
            duplicate function names.

    Returns:
        ``{'tools': [schema, ...], 'mapping': {full_name: {'server': ..., 'tool_name': ...}}}``
        where ``full_name`` is ``f"{server_name}_{tool.name}"``.
    """
    tools_schema: List[Dict[str, Any]] = []
    server_tool_mapping: Dict[str, Dict[str, str]] = {}
    seen_servers: set = set()

    for server_name in server_names:
        # Repeated servers would emit duplicate function names, which LLM APIs reject.
        if server_name in seen_servers:
            continue
        seen_servers.add(server_name)

        # Handle canvas pseudo-tool
        if server_name == "canvas":
            tools_schema.append({
                "type": "function",
                "function": {
                    "name": "canvas_canvas",
                    "description": "Display final rendered content in a visual canvas panel. Use this for: 1) Complete code (not code discussions), 2) Final reports/documents (not report discussions), 3) Data visualizations, 4) Any polished content that should be viewed separately from the conversation. Put the actual content in the canvas, keep discussions in chat.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "The content to display in the canvas. Can be markdown, code, or plain text."
                            }
                        },
                        "required": ["content"]
                    }
                }
            })
            server_tool_mapping["canvas_canvas"] = {
                'server': 'canvas',
                'tool_name': 'canvas'
            }
        elif server_name in self.available_tools:
            # .get() so a server entry without a 'tools' key (e.g. failed discovery)
            # contributes nothing instead of raising KeyError.
            for tool in self.available_tools[server_name].get('tools', []):
                full_name = f"{server_name}_{tool.name}"
                # Convert MCP tool format to OpenAI function calling format
                tools_schema.append({
                    "type": "function",
                    "function": {
                        "name": full_name,
                        "description": tool.description or '',
                        "parameters": tool.inputSchema or {}
                    }
                })
                server_tool_mapping[full_name] = {
                    'server': server_name,
                    'tool_name': tool.name
                }

    return {
        'tools': tools_schema,
        'mapping': server_tool_mapping
    }
1380
+
1381
+ def _requires_user_auth(self, server_name: str) -> bool:
1382
+ """Check if a server requires per-user authentication.
1383
+
1384
+ Returns True for servers with auth_type 'oauth', 'jwt', 'bearer', or 'api_key'.
1385
+ These servers need user-specific tokens rather than shared/admin tokens.
1386
+ """
1387
+ config = self.servers_config.get(server_name, {})
1388
+ auth_type = config.get("auth_type", "none")
1389
+ return auth_type in ("oauth", "jwt", "bearer", "api_key")
1390
+
1391
async def _get_user_client(
    self,
    server_name: str,
    user_email: str,
) -> Optional[Client]:
    """Get or create a user-specific client for servers requiring per-user auth.

    Clients are cached per ``(lower-cased email, server_name)``. A cached client
    is reused only while the user still has a valid stored token *and* that
    token is the one the client was built with. After a refresh or rotation the
    client is rebuilt, so it does not keep sending the superseded token.

    Args:
        server_name: Name of the MCP server
        user_email: User's email address

    Returns:
        FastMCP Client configured with user's token, or None if no valid token
        is stored, the server has no ``url`` configured, or client creation fails.
    """
    from atlas.modules.mcp_tools.token_storage import get_token_storage

    token_storage = get_token_storage()
    cache_key = (user_email.lower(), server_name)

    # hash() of the token each cached client was built with, kept alongside
    # self._user_clients. Created lazily so __init__ does not need to change; a
    # hash is stored rather than the token so no extra copy of the secret is kept.
    fingerprints = getattr(self, "_user_client_token_fingerprints", None)
    if fingerprints is None:
        fingerprints = {}
        self._user_client_token_fingerprints = fingerprints

    # Check cache first, but validate token is still valid and unchanged
    async with self._user_clients_lock:
        if cache_key in self._user_clients:
            stored_token = token_storage.get_valid_token(user_email, server_name)
            if stored_token is not None and fingerprints.get(cache_key) == hash(stored_token.token_value):
                return self._user_clients[cache_key]
            # Token expired, removed, or replaced: drop the stale client. A client
            # cached without a fingerprint is also rebuilt (once).
            logger.debug(
                f"Token expired or changed for user on server '{server_name}', "
                f"invalidating cached client"
            )
            del self._user_clients[cache_key]
            fingerprints.pop(cache_key, None)

    # Get user's token from storage
    logger.debug(f"[AUTH] Looking up token for server='{server_name}'")
    stored_token = token_storage.get_valid_token(user_email, server_name)
    logger.debug(f"[AUTH] Token found: {stored_token is not None}")

    if stored_token is None:
        logger.debug(
            f"[AUTH] No valid token for server '{server_name}' - user needs to authenticate"
        )
        return None

    # Get server config
    config = self.servers_config.get(server_name, {})
    url = config.get("url")

    if not url:
        logger.error(f"No URL configured for server '{server_name}'")
        return None

    # Ensure URL has protocol.
    # NOTE(review): a scheme-less URL falls back to plain http://, so the user's
    # token would be sent unencrypted; consider requiring https for these servers.
    if not url.startswith(("http://", "https://")):
        url = f"http://{url}"

    # Create client with user's token
    try:
        log_handler = self._create_log_handler(server_name)
        auth_type = config.get("auth_type", "bearer")

        # For API key auth, use custom header; for bearer/jwt/oauth, use auth parameter
        if auth_type == "api_key":
            # Use custom header for API key authentication
            auth_header = config.get("auth_header", "X-API-Key")
            logger.debug(
                f"Creating API key client for '{server_name}' with header '{auth_header}'"
            )
            transport = StreamableHttpTransport(
                url,
                headers={auth_header: stored_token.token_value},
            )
            client = Client(
                transport=transport,
                log_handler=log_handler,
                elicitation_handler=self._create_elicitation_handler(server_name),
                sampling_handler=self._create_sampling_handler(server_name),
            )
        else:
            # FastMCP Client accepts auth= as a string (bearer token)
            client = Client(
                url,
                auth=stored_token.token_value,
                log_handler=log_handler,
                elicitation_handler=self._create_elicitation_handler(server_name),
                sampling_handler=self._create_sampling_handler(server_name),
            )

        # Cache the client together with the fingerprint of the token it uses
        async with self._user_clients_lock:
            self._user_clients[cache_key] = client
            fingerprints[cache_key] = hash(stored_token.token_value)

        logger.info(
            f"Created user-specific client for server '{server_name}' (auth_type={auth_type})"
        )
        return client

    except Exception as e:
        logger.error(
            f"Failed to create user client for server '{server_name}': {e}"
        )
        return None
1494
+
1495
+ async def _invalidate_user_client(self, user_email: str, server_name: str) -> None:
1496
+ """Remove a user's cached client (e.g., when token is revoked)."""
1497
+ cache_key = (user_email.lower(), server_name)
1498
+ async with self._user_clients_lock:
1499
+ if cache_key in self._user_clients:
1500
+ del self._user_clients[cache_key]
1501
+ logger.debug(f"Invalidated user client cache for server '{server_name}'")
1502
+
1503
async def call_tool(
    self,
    server_name: str,
    tool_name: str,
    arguments: Dict[str, Any],
    *,
    progress_handler: Optional[Any] = None,
    elicitation_handler: Optional[Any] = None,
    user_email: Optional[str] = None,
) -> Any:
    """Call a specific tool on an MCP server.

    Servers needing per-user credentials (see ``_requires_user_auth``) are
    called through a client built from the user's stored token; all other
    servers use the shared client in ``self.clients``.

    Args:
        server_name: Name of the MCP server
        tool_name: Name of the tool to call
        arguments: Tool arguments
        progress_handler: Optional progress callback handler
        elicitation_handler: Optional elicitation callback handler. Prefer the built-in
            elicitation routing (registered at client creation time) for shared clients.
        user_email: User's email for per-user authentication (required for oauth/jwt servers)

    Returns:
        The raw result of the FastMCP client's ``call_tool``.

    Raises:
        AuthenticationRequiredException: per-user auth is required but no user
            client could be obtained (no user context, no valid token, or
            client creation failed).
        ValueError: no shared client exists for ``server_name``.
        TimeoutError: the call exceeded ``mcp_call_timeout`` seconds; the
            failure is also recorded via ``_record_server_failure``.
    """
    safe_server = sanitize_for_logging(server_name)
    safe_tool = sanitize_for_logging(tool_name)

    if self._requires_user_auth(server_name):
        # Log only whether a user context exists, not the address itself.
        logger.debug(f"Server '{safe_server}' requires user auth, has_user_context={bool(user_email)}")
        client = None
        if user_email:
            client = await self._get_user_client(server_name, user_email)
            logger.debug(f"_get_user_client for '{safe_server}' returned client: {client is not None}")
        if client is None:
            auth_type = self.servers_config.get(server_name, {}).get("auth_type", "oauth")
            if user_email:
                message = f"Server '{server_name}' requires authentication."
            else:
                message = f"Server '{server_name}' requires authentication but no user context."
            raise AuthenticationRequiredException(
                server_name=server_name,
                auth_type=auth_type,
                message=message,
                # OAuth servers get a start URL so the UI can redirect automatically
                oauth_start_url=f"/api/mcp/auth/{server_name}/oauth/start" if auth_type == "oauth" else None,
            )
    else:
        # Use shared client for servers without per-user auth
        if server_name not in self.clients:
            raise ValueError(f"No client available for server: {server_name}")
        client = self.clients[server_name]

    call_timeout = config_manager.app_settings.mcp_call_timeout
    try:
        # Set elicitation callback before opening the client context.
        # FastMCP negotiates supported capabilities during session init.
        # NOTE(review): on a shared client this callback persists after the call.
        if elicitation_handler is not None:
            client.set_elicitation_callback(elicitation_handler)

        async with client:
            # Pass progress handler only when provided (fastmcp >= 2.3.5)
            kwargs = {}
            if progress_handler is not None:
                kwargs["progress_handler"] = progress_handler

            result = await asyncio.wait_for(
                client.call_tool(tool_name, arguments, **kwargs),
                timeout=call_timeout,
            )
            logger.info(f"Successfully called {safe_tool} on {safe_server}")
            return result
    except asyncio.TimeoutError as exc:
        error_msg = f"Tool call '{tool_name}' on server '{server_name}' timed out after {call_timeout}s"
        logger.error(sanitize_for_logging(error_msg))
        self._record_server_failure(server_name, error_msg)
        raise TimeoutError(error_msg) from exc
    except Exception as e:
        logger.error(f"Error calling {safe_tool} on {safe_server}: {e}")
        raise
1589
+
1590
async def get_prompt(
    self,
    server_name: str,
    prompt_name: str,
    arguments: Optional[Dict[str, Any]] = None,
) -> Any:
    """Get a specific prompt from an MCP server via its shared client.

    Args:
        server_name: Name of the MCP server (must have a shared client).
        prompt_name: Name of the prompt to fetch.
        arguments: Optional prompt arguments; not sent when empty or None.

    Returns:
        The FastMCP client's ``get_prompt`` result.

    Raises:
        ValueError: no shared client exists for ``server_name``.
        TimeoutError: the request exceeded ``mcp_call_timeout`` seconds, the
            same limit ``call_tool`` applies.
    """
    if server_name not in self.clients:
        raise ValueError(f"No client available for server: {server_name}")

    client = self.clients[server_name]
    safe_server = sanitize_for_logging(server_name)
    safe_prompt = sanitize_for_logging(prompt_name)
    call_timeout = config_manager.app_settings.mcp_call_timeout
    try:
        async with client:
            if arguments:
                request = client.get_prompt(prompt_name, arguments)
            else:
                request = client.get_prompt(prompt_name)
            # Bound the wait so an unresponsive server cannot hang the caller forever
            result = await asyncio.wait_for(request, timeout=call_timeout)
            logger.info(f"Successfully retrieved prompt {safe_prompt} from {safe_server}")
            return result
    except asyncio.TimeoutError as exc:
        error_msg = f"Getting prompt '{prompt_name}' from server '{server_name}' timed out after {call_timeout}s"
        logger.error(sanitize_for_logging(error_msg))
        raise TimeoutError(error_msg) from exc
    except Exception as e:
        logger.error(f"Error getting prompt {safe_prompt} from {safe_server}: {e}")
        raise
1607
+
1608
def get_available_prompts_for_servers(self, server_names: List[str]) -> Dict[str, Any]:
    """Collect the discovered prompts of the selected servers.

    Args:
        server_names: Servers to include; servers without discovered prompts
            are skipped.

    Returns:
        Mapping of ``f"{server}_{prompt.name}"`` to
        ``{'server', 'name', 'description', 'arguments'}``. ``arguments`` is the
        prompt's MCP argument list, or ``[]`` when the prompt declares none.
    """
    available_prompts: Dict[str, Any] = {}

    for server_name in server_names:
        if server_name not in self.available_prompts:
            continue
        # Tolerate entries without a 'prompts' key rather than raising KeyError
        server_prompts = self.available_prompts[server_name].get('prompts') or []
        for prompt in server_prompts:
            available_prompts[f"{server_name}_{prompt.name}"] = {
                'server': server_name,
                'name': prompt.name,
                'description': prompt.description or '',
                # MCP Prompt.arguments is a list, so the empty fallback is a list too
                'arguments': prompt.arguments or []
            }

    return available_prompts
1625
+
1626
async def get_authorized_servers(self, user_email: str, auth_check_func) -> List[str]:
    """Return the enabled servers ``user_email`` may use, in configuration order.

    A server with no ``groups`` is open to everyone. Otherwise the user needs
    membership in at least one listed group. ``auth_check_func(user_email, group)``
    must be awaitable and return a truthy value for membership.
    """
    authorized_servers: List[str] = []
    for server_name, server_config in self.servers_config.items():
        if not server_config.get("enabled", True):
            continue

        required_groups = server_config.get("groups", [])
        if not required_groups:
            authorized_servers.append(server_name)
            continue

        # any() cannot drive coroutines, so await the checks one by one and stop
        # at the first matching group instead of awaiting every remaining check.
        for group in required_groups:
            if await auth_check_func(user_email, group):
                authorized_servers.append(server_name)
                break
    return authorized_servers
1644
+
1645
def get_available_tools(self) -> List[str]:
    """Return the fully-qualified ``server_tool`` name of every discovered tool.

    The ``canvas`` pseudo-server contributes the single name ``canvas_canvas``.
    """
    names: List[str] = []
    for server_name, server_data in self.available_tools.items():
        if server_name != "canvas":
            names.extend(f"{server_name}_{tool.name}" for tool in server_data.get('tools', []))
        else:
            names.append("canvas_canvas")
    return names
1655
+
1656
def get_tools_schema(self, tool_names: List[str]) -> List[Dict[str, Any]]:
    """Get OpenAI function-calling schemas for the given fully-qualified tool names.

    Names are matched exactly against an index of ``f"{server}_{tool}"`` ->
    ``{'server', 'tool'}`` entries rather than split on ``_``, because both
    server and tool names may contain underscores (e.g. server ``ui-demo`` with
    tool ``create_form_demo`` -> ``ui-demo_create_form_demo``). The index is
    cached on ``self._tool_index``; it is built from ``self.available_tools``
    when absent or empty.

    Args:
        tool_names: Fully-qualified tool names.

    Returns:
        One schema per known name, in request order. Unknown names are
        skipped and reported in a single warning.
    """
    if not tool_names:
        return []

    # Build (or reuse) an index of full tool name -> {'server', 'tool'}
    # so lookups are O(1) without fragile string parsing.
    index = getattr(self, "_tool_index", None)
    if not index:
        index = {}
        for server_name, server_data in self.available_tools.items():
            if server_name == "canvas":
                index["canvas_canvas"] = {
                    'server': 'canvas',
                    'tool': None  # pseudo tool
                }
            else:
                for tool in server_data.get('tools', []):
                    index[f"{server_name}_{tool.name}"] = {
                        'server': server_name,
                        'tool': tool
                    }
        self._tool_index = index

    matched: List[Dict[str, Any]] = []
    missing: List[str] = []
    for requested in tool_names:
        entry = index.get(requested)
        if not entry:
            missing.append(requested)
            continue
        if requested == "canvas_canvas":
            # Recreate the canvas schema (duplicate logic intentional to avoid
            # coupling to get_tools_for_servers, which returns superset data)
            matched.append({
                "type": "function",
                "function": {
                    "name": "canvas_canvas",
                    "description": "Display final rendered content in a visual canvas panel. Use this for: 1) Complete code (not code discussions), 2) Final reports/documents (not report discussions), 3) Data visualizations, 4) Any polished content that should be viewed separately from the conversation. Put the actual content in the canvas, keep discussions in chat.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "The content to display in the canvas. Can be markdown, code, or plain text."
                            }
                        },
                        "required": ["content"]
                    }
                }
            })
        else:
            tool = entry['tool']
            matched.append({
                "type": "function",
                "function": {
                    "name": requested,
                    "description": getattr(tool, 'description', '') or '',
                    "parameters": getattr(tool, 'inputSchema', {}) or {}
                }
            })

    if missing:
        # Previously these were collected and silently dropped; report them so a
        # mismatch between requested and discovered tools is visible.
        logger.warning(
            "Requested tools not found in discovered inventory: %s",
            sanitize_for_logging(", ".join(str(name) for name in missing)),
        )

    return matched
1734
+
1735
+ # ------------------------------------------------------------
1736
+ # Internal helpers
1737
+ # ------------------------------------------------------------
1738
+ def _normalize_mcp_tool_result(self, raw_result: Any) -> Dict[str, Any]:
1739
+ """Normalize a FastMCP CallToolResult (or similar object) into our contract.
1740
+
1741
+ Returns a dict shaped like:
1742
+ {
1743
+ "results": <payload or string>,
1744
+ "meta_data": {...optional...},
1745
+ "returned_file_names": [...optional...],
1746
+ "returned_file_count": N (if file contents present)
1747
+ }
1748
+
1749
+ Notes:
1750
+ - We never inline base64 file contents here to avoid prompt bloat.
1751
+ - Handles legacy key forms (result, meta-data, metadata).
1752
+ - Falls back to stringifying the raw result if structured extraction fails.
1753
+ """
1754
+ normalized: Dict[str, Any] = {}
1755
+ structured: Dict[str, Any] = {}
1756
+
1757
+ # Attempt extraction in priority order
1758
+ try:
1759
+ if hasattr(raw_result, "structured_content") and raw_result.structured_content: # type: ignore[attr-defined]
1760
+ structured = raw_result.structured_content # type: ignore[attr-defined]
1761
+ elif hasattr(raw_result, "data") and raw_result.data: # type: ignore[attr-defined]
1762
+ structured = raw_result.data # type: ignore[attr-defined]
1763
+ else:
1764
+ # Fallback: extract text content from content array
1765
+ if hasattr(raw_result, "content"):
1766
+ contents = getattr(raw_result, "content")
1767
+ if contents:
1768
+ # Collect all text from TextContent items
1769
+ text_parts = []
1770
+ for item in contents:
1771
+ if hasattr(item, "type") and getattr(item, "type") == "text":
1772
+ text = getattr(item, "text", None)
1773
+ if text:
1774
+ text_parts.append(text)
1775
+
1776
+ if text_parts:
1777
+ combined_text = "\n".join(text_parts)
1778
+ # Try to parse as JSON if it looks like JSON
1779
+ if combined_text.strip().startswith(("{", "[")):
1780
+ try:
1781
+ logger.info("MCP tool result normalization: using content text JSON fallback for structured extraction")
1782
+ structured = json.loads(combined_text)
1783
+ except Exception: # pragma: no cover - defensive
1784
+ # Not valid JSON, use as plain text result
1785
+ structured = {"results": combined_text}
1786
+ else:
1787
+ # Plain text - use as results directly
1788
+ structured = {"results": combined_text}
1789
+ except Exception as parse_err: # pragma: no cover - defensive
1790
+ logger.debug(f"Non-fatal parse issue extracting structured tool result: {parse_err}")
1791
+
1792
+ if isinstance(structured, dict):
1793
+ # Support both correct and legacy key forms
1794
+ results_payload = structured.get("results") or structured.get("result")
1795
+ meta_payload = (
1796
+ structured.get("meta_data")
1797
+ or structured.get("meta-data")
1798
+ or structured.get("metadata")
1799
+ )
1800
+ returned_file_names = structured.get("returned_file_names")
1801
+ returned_file_contents = structured.get("returned_file_contents")
1802
+
1803
+ if results_payload is not None:
1804
+ normalized["results"] = results_payload
1805
+ if meta_payload is not None:
1806
+ try:
1807
+ # Heuristic to prevent very large meta blobs
1808
+ if len(json.dumps(meta_payload)) < 4000:
1809
+ normalized["meta_data"] = meta_payload
1810
+ else:
1811
+ normalized["meta_data_truncated"] = True
1812
+ except Exception: # pragma: no cover
1813
+ normalized["meta_data_parse_error"] = True
1814
+ if returned_file_names:
1815
+ normalized["returned_file_names"] = returned_file_names
1816
+ if returned_file_contents:
1817
+ normalized["returned_file_count"] = (
1818
+ len(returned_file_contents) if isinstance(returned_file_contents, (list, tuple)) else 1
1819
+ )
1820
+
1821
+ # Phase 5 fallback: if no explicit results key, treat *entire* structured dict (minus large/base64 fields) as results
1822
+ if "results" not in normalized:
1823
+ # Prune potentially huge / sensitive keys before fallback
1824
+ prune_keys = {"returned_file_contents"}
1825
+ pruned = {k: v for k, v in structured.items() if k not in prune_keys}
1826
+ try:
1827
+ serialized = json.dumps(pruned)
1828
+ if len(serialized) <= 8000: # size guard
1829
+ normalized["results"] = pruned
1830
+ else:
1831
+ normalized["results_summary"] = {
1832
+ "keys": list(pruned.keys()),
1833
+ "omitted_due_to_size": len(serialized)
1834
+ }
1835
+ except Exception: # pragma: no cover
1836
+ # Fallback to string repr if serialization fails
1837
+ normalized.setdefault("results", str(pruned))
1838
+
1839
+ if not normalized:
1840
+ normalized = {"results": str(raw_result)}
1841
+ return normalized
1842
+
1843
+ async def execute_tool(
1844
+ self,
1845
+ tool_call: ToolCall,
1846
+ context: Optional[Dict[str, Any]] = None
1847
+ ) -> ToolResult:
1848
+ """Execute a tool call."""
1849
+ logger.debug("ToolManager.execute_tool: tool=%s", tool_call.name)
1850
+ # Handle canvas pseudo-tool
1851
+ if tool_call.name == "canvas_canvas":
1852
+ # Canvas tool just returns the content - it's handled by frontend
1853
+ content = tool_call.arguments.get("content", "")
1854
+ return ToolResult(
1855
+ tool_call_id=tool_call.id,
1856
+ content=f"Canvas content displayed: {content[:100]}..." if len(content) > 100 else f"Canvas content displayed: {content}",
1857
+ success=True
1858
+ )
1859
+
1860
+ # Use the tool index to get server and tool name (avoids parsing issues with dashes/underscores)
1861
+ if not hasattr(self, "_tool_index") or not getattr(self, "_tool_index"):
1862
+ # Build tool index if not available (same logic as in get_tools_schema)
1863
+ index = {}
1864
+ for server_name, server_data in self.available_tools.items():
1865
+ if server_name == "canvas":
1866
+ index["canvas_canvas"] = {
1867
+ 'server': 'canvas',
1868
+ 'tool': None # pseudo tool
1869
+ }
1870
+ else:
1871
+ for tool in server_data.get('tools', []):
1872
+ full_name = f"{server_name}_{tool.name}"
1873
+ index[full_name] = {
1874
+ 'server': server_name,
1875
+ 'tool': tool
1876
+ }
1877
+ self._tool_index = index
1878
+
1879
+ # Look up the tool in our index
1880
+ tool_entry = self._tool_index.get(tool_call.name)
1881
+ if not tool_entry:
1882
+ return ToolResult(
1883
+ tool_call_id=tool_call.id,
1884
+ content=f"Tool not found: {tool_call.name}",
1885
+ success=False,
1886
+ error=f"Tool not found: {tool_call.name}"
1887
+ )
1888
+
1889
+ server_name = tool_entry['server']
1890
+ actual_tool_name = tool_entry['tool'].name if tool_entry['tool'] else tool_call.name
1891
+
1892
+ try:
1893
+ update_cb = None
1894
+ user_email = None
1895
+ if isinstance(context, dict):
1896
+ update_cb = context.get("update_callback")
1897
+ user_email = context.get("user_email")
1898
+
1899
+ if update_cb is None:
1900
+ logger.warning(
1901
+ f"Executing tool '{tool_call.name}' without update_callback - "
1902
+ f"elicitation will not work. Context type: {type(context)}"
1903
+ )
1904
+ else:
1905
+ logger.debug(f"Executing tool '{tool_call.name}' with update_callback present")
1906
+
1907
+ async def _tool_log_callback(
1908
+ log_server_name: str,
1909
+ level: str,
1910
+ message: str,
1911
+ extra: Dict[str, Any],
1912
+ ) -> None:
1913
+ if update_cb is None:
1914
+ return
1915
+ try:
1916
+ # Deferred import to avoid cycles
1917
+ from atlas.application.chat.utilities.event_notifier import notify_tool_log
1918
+ await notify_tool_log(
1919
+ server_name=log_server_name,
1920
+ tool_name=tool_call.name,
1921
+ tool_call_id=tool_call.id,
1922
+ level=level,
1923
+ message=sanitize_for_logging(message),
1924
+ extra=extra,
1925
+ update_callback=update_cb,
1926
+ )
1927
+ except Exception:
1928
+ logger.debug("Tool log forwarding failed", exc_info=True)
1929
+
1930
+ # Build a progress handler that forwards to UI if provided via context
1931
+ async def _progress_handler(progress: float, total: Optional[float], message: Optional[str]) -> None:
1932
+ try:
1933
+ if update_cb is not None:
1934
+ # Deferred import to avoid cycles
1935
+ from atlas.application.chat.utilities.event_notifier import notify_tool_progress
1936
+ await notify_tool_progress(
1937
+ tool_call_id=tool_call.id,
1938
+ tool_name=tool_call.name,
1939
+ progress=progress,
1940
+ total=total,
1941
+ message=message,
1942
+ update_callback=update_cb,
1943
+ )
1944
+ except Exception:
1945
+ logger.debug("Progress handler forwarding failed", exc_info=True)
1946
+
1947
+ if update_cb is not None:
1948
+ async with self._use_log_callback(_tool_log_callback):
1949
+ async with self._use_elicitation_context(server_name, tool_call, update_cb):
1950
+ async with self._use_sampling_context(server_name, tool_call, update_cb):
1951
+ raw_result = await self.call_tool(
1952
+ server_name,
1953
+ actual_tool_name,
1954
+ tool_call.arguments,
1955
+ progress_handler=_progress_handler,
1956
+ user_email=user_email,
1957
+ )
1958
+ else:
1959
+ async with self._use_elicitation_context(server_name, tool_call, update_cb):
1960
+ async with self._use_sampling_context(server_name, tool_call, update_cb):
1961
+ raw_result = await self.call_tool(
1962
+ server_name,
1963
+ actual_tool_name,
1964
+ tool_call.arguments,
1965
+ progress_handler=_progress_handler,
1966
+ user_email=user_email,
1967
+ )
1968
+ normalized_content = self._normalize_mcp_tool_result(raw_result)
1969
+ content_str = json.dumps(normalized_content, ensure_ascii=False)
1970
+
1971
+ # Extract v2 MCP response components (supports dict or FastMCP result objects)
1972
+ artifacts: List[Dict[str, Any]] = []
1973
+ display_config: Optional[Dict[str, Any]] = None
1974
+ meta_data: Optional[Dict[str, Any]] = None
1975
+
1976
+ try:
1977
+ if isinstance(raw_result, dict):
1978
+ structured = raw_result
1979
+ else:
1980
+ structured = {}
1981
+ if hasattr(raw_result, "structured_content") and raw_result.structured_content: # type: ignore[attr-defined]
1982
+ sc = raw_result.structured_content # type: ignore[attr-defined]
1983
+ if isinstance(sc, dict):
1984
+ structured = sc
1985
+ elif hasattr(raw_result, "data") and raw_result.data: # type: ignore[attr-defined]
1986
+ dt = raw_result.data # type: ignore[attr-defined]
1987
+ if isinstance(dt, dict):
1988
+ structured = dt
1989
+ else:
1990
+ # Fallback: parse first textual content if JSON-like
1991
+ # This handles MCP responses that return data only in content[0].text
1992
+ if hasattr(raw_result, "content"):
1993
+ contents = getattr(raw_result, "content")
1994
+ if contents and len(contents) > 0 and hasattr(contents[0], "text"):
1995
+ first_text = getattr(contents[0], "text")
1996
+ if isinstance(first_text, str) and first_text.strip().startswith("{"):
1997
+ try:
1998
+ structured = json.loads(first_text)
1999
+ except Exception:
2000
+ pass
2001
+
2002
+ if isinstance(structured, dict) and structured:
2003
+ # Extract artifacts
2004
+ raw_artifacts = structured.get("artifacts")
2005
+ if isinstance(raw_artifacts, list):
2006
+ for art in raw_artifacts:
2007
+ if isinstance(art, dict):
2008
+ name = art.get("name")
2009
+ b64 = art.get("b64")
2010
+ if name and b64:
2011
+ artifacts.append(art)
2012
+
2013
+ # Extract display
2014
+ disp = structured.get("display")
2015
+ if isinstance(disp, dict):
2016
+ display_config = disp
2017
+
2018
+ # Extract metadata
2019
+ md = structured.get("meta_data")
2020
+ if isinstance(md, dict):
2021
+ meta_data = md
2022
+
2023
+ # Extract ImageContent from the content array
2024
+ # Allowlist of safe image MIME types
2025
+ ALLOWED_IMAGE_MIMES = {
2026
+ "image/png", "image/jpeg", "image/gif",
2027
+ "image/svg+xml", "image/webp", "image/bmp"
2028
+ }
2029
+
2030
+ if hasattr(raw_result, "content"):
2031
+ contents = getattr(raw_result, "content")
2032
+ if isinstance(contents, list):
2033
+ image_counter = 0
2034
+ for item in contents:
2035
+ # Check if this is an ImageContent object
2036
+ if hasattr(item, "type") and getattr(item, "type") == "image":
2037
+ data = getattr(item, "data", None)
2038
+ mime_type = getattr(item, "mimeType", None)
2039
+
2040
+ # Validate mime type against allowlist
2041
+ if mime_type and mime_type not in ALLOWED_IMAGE_MIMES:
2042
+ logger.warning(
2043
+ f"Skipping ImageContent with unsupported mime type: {mime_type}"
2044
+ )
2045
+ continue
2046
+
2047
+ # Validate base64 data
2048
+ if data:
2049
+ try:
2050
+ import base64
2051
+ base64.b64decode(data, validate=True)
2052
+ except Exception:
2053
+ logger.warning(
2054
+ "Skipping ImageContent with invalid base64 data"
2055
+ )
2056
+ continue
2057
+
2058
+ if data and mime_type:
2059
+ # Generate a filename based on image counter and mime type
2060
+ # Use mcp_image_ prefix to avoid collisions with structured artifacts
2061
+ ext = mime_type.split("/")[-1] if "/" in mime_type else "bin"
2062
+ filename = f"mcp_image_{image_counter}.{ext}"
2063
+
2064
+ # Create artifact in the expected format
2065
+ artifact = {
2066
+ "name": filename,
2067
+ "b64": data,
2068
+ "mime": mime_type,
2069
+ "viewer": "image",
2070
+ "description": f"Image returned by {tool_call.name}"
2071
+ }
2072
+ artifacts.append(artifact)
2073
+ logger.debug(f"Extracted ImageContent as artifact: {filename} ({mime_type})")
2074
+
2075
+ # If no display config exists and this is the first image, auto-open canvas
2076
+ if not display_config and image_counter == 0:
2077
+ display_config = {
2078
+ "primary_file": filename,
2079
+ "open_canvas": True
2080
+ }
2081
+
2082
+ image_counter += 1
2083
+ except Exception:
2084
+ logger.warning("Error extracting v2 MCP components from tool result", exc_info=True)
2085
+
2086
+ log_metric("tool_call", user_email, tool_name=actual_tool_name)
2087
+
2088
+ return ToolResult(
2089
+ tool_call_id=tool_call.id,
2090
+ content=content_str,
2091
+ success=True,
2092
+ artifacts=artifacts,
2093
+ display_config=display_config,
2094
+ meta_data=meta_data
2095
+ )
2096
+ except Exception as e:
2097
+ logger.error(f"Error executing tool {tool_call.name}: {e}")
2098
+
2099
+ log_metric("tool_error", user_email, tool_name=actual_tool_name)
2100
+
2101
+ return ToolResult(
2102
+ tool_call_id=tool_call.id,
2103
+ content=f"Error executing tool: {str(e)}",
2104
+ success=False,
2105
+ error=str(e)
2106
+ )
2107
+
2108
async def execute_tool_calls(
    self,
    tool_calls: List[ToolCall],
    context: Optional[Dict[str, Any]] = None
) -> List[ToolResult]:
    """Run a batch of tool calls one at a time and gather their results.

    Each call is handed to ``self.execute_tool`` together with the shared
    ``context`` and awaited before the next one starts, so the calls run
    strictly in the order given. The returned list is index-aligned with
    ``tool_calls``; an empty input yields an empty list.
    """
    # An async comprehension awaits each element in turn, which preserves
    # sequential, in-order execution (no concurrent fan-out).
    return [
        await self.execute_tool(call, context)
        for call in tool_calls
    ]
2119
+
2120
+ async def cleanup(self):
2121
+ """Cleanup all clients."""
2122
+ logger.info("Cleaning up MCP clients")
2123
+ # FastMCP clients handle cleanup automatically with context managers