atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,664 @@
1
+ """RAG MCP Aggregator Service (Phase 1: Discovery)
2
+
3
+ Aggregates discovery of RAG resources from authorized MCP servers that expose
4
+ the `rag_discover_resources` tool. Returns a flat list of data source IDs for
5
+ backward-compatible UI, with server-qualified IDs to avoid collisions.
6
+
7
+ Later phases add a richer per-server discovery shape (discover_servers), raw search (search_raw), and synthesis (synthesize).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ from atlas.core.compliance import get_compliance_manager
16
+ from atlas.core.log_sanitizer import sanitize_for_logging
17
+ from atlas.core.prompt_risk import calculate_prompt_injection_risk, log_high_risk_event
18
+
19
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
20
+
21
+
22
+ class RAGMCPService:
23
+ """Aggregator for RAG over MCP servers."""
24
+
25
def __init__(self, mcp_manager, config_manager, auth_check_func) -> None:
    """Wire up the collaborators the aggregator delegates to.

    Args:
        mcp_manager: MCP manager whose clients, ``servers_config``,
            ``available_tools`` inventory and ``call_tool`` are used for
            discovery, search and synthesis.
        config_manager: Source of ``rag_mcp_config.servers``, the RAG server
            definitions.
        auth_check_func: Async callable invoked as
            ``await auth_check_func(username, group)``; a truthy result grants
            access to servers restricted to that group.
    """
    self.mcp_manager, self.config_manager, self.auth_check_func = (
        mcp_manager,
        config_manager,
        auth_check_func,
    )
29
+
30
async def _get_authorized_rag_servers(self, username: str, rag_servers: dict) -> List[str]:
    """Return the names of enabled RAG servers that ``username`` may access.

    Authorization is checked directly against the ``rag_mcp_config`` server
    entries rather than ``mcp_manager.servers_config``, which deliberately
    excludes RAG servers to keep them out of the tools panel.

    Args:
        username: User whose group memberships are checked.
        rag_servers: Mapping of server name -> server config object exposing
            ``enabled`` and ``groups`` attributes.

    Returns:
        Server names, in ``rag_servers`` iteration order, that are enabled and
        either have no group restriction or list at least one group for which
        ``auth_check_func(username, group)`` is truthy.
    """
    authorized: List[str] = []
    for server_name, server_config in rag_servers.items():
        if not server_config.enabled:
            continue

        required_groups = server_config.groups or []
        if not required_groups:
            # No group restriction - available to all users.
            authorized.append(server_name)
            continue

        # Membership in any one required group is sufficient, so stop at the
        # first match instead of awaiting a check for every remaining group.
        for group in required_groups:
            if await self.auth_check_func(username, group):
                authorized.append(server_name)
                break

    return authorized
57
+
58
async def discover_data_sources(self, username: str, user_compliance_level: Optional[str] = None) -> List[str]:
    """Discover data sources across authorized MCP RAG servers.

    Phase 1 returns a flat list of strings for backward compatibility, using
    server-qualified IDs ``"{server}:{resource_id}"`` to avoid collisions
    between providers.

    A server contributes only if it is enabled and group-authorized for the
    user, passes the server-level compliance check, and advertises the
    ``rag_discover_resources`` tool. When ``user_compliance_level`` is given,
    each resource must also pass the per-resource compliance check (the same
    Step 3 filter ``discover_servers`` applies), so restricted resources do
    not leak into the flat list.

    Args:
        username: User performing discovery; forwarded to each server's
            ``rag_discover_resources`` tool.
        user_compliance_level: Optional compliance level of the user/session.
            When falsy, no compliance filtering is applied.

    Returns:
        Order-preserving, de-duplicated list of qualified source IDs. An empty
        list is returned when nothing is discoverable or on an unexpected
        error (which is logged).
    """
    # Make sure RAG servers from rag_mcp_config have clients, without leaving
    # them in the general tools configuration.
    try:
        rag_servers = self.config_manager.rag_mcp_config.servers
        clients = getattr(self.mcp_manager, "clients", {})
        missing = [name for name in rag_servers.keys() if name not in clients]
        if missing:
            original = getattr(self.mcp_manager, "servers_config", {})
            # Install a merged copy instead of mutating ``original`` in place:
            # restoring the same object afterwards then fully undoes the
            # temporary extension for anyone else holding a reference to it.
            self.mcp_manager.servers_config = {
                **original,
                **{name: cfg.model_dump() for name, cfg in rag_servers.items()},
            }
            try:
                await self.mcp_manager.initialize_clients()
                await self.mcp_manager.discover_tools()
            finally:
                self.mcp_manager.servers_config = original
    except Exception:
        # Best effort: continue with whatever clients already exist.
        logger.warning(
            "RAG MCP server initialization failed; continuing with existing clients",
            exc_info=True,
        )

    try:
        # servers_config no longer lists RAG servers, so authorize against
        # rag_mcp_config directly.
        rag_servers = self.config_manager.rag_mcp_config.servers
        authorized_servers: List[str] = await self._get_authorized_rag_servers(username, rag_servers)

        if not authorized_servers:
            logger.info("No authorized MCP servers for user %s", sanitize_for_logging(username))
            return []

        compliance_mgr = get_compliance_manager() if user_compliance_level else None

        # --- Compliance Filtering (Step 2): server level ---
        if compliance_mgr:
            filtered_servers = []
            for server in authorized_servers:
                cfg = (self.mcp_manager.available_tools.get(server) or {}).get("config", {})
                server_compliance_level = cfg.get("compliance_level")
                if compliance_mgr.is_accessible(
                    user_level=user_compliance_level, resource_level=server_compliance_level
                ):
                    filtered_servers.append(server)
                else:
                    logger.info(
                        "Skipping RAG server %s due to compliance level mismatch (user: %s, server: %s)",
                        sanitize_for_logging(server),
                        sanitize_for_logging(user_compliance_level),
                        sanitize_for_logging(server_compliance_level),
                    )
            authorized_servers = filtered_servers
            if not authorized_servers:
                logger.info(
                    "No authorized MCP servers remain after compliance filtering for user %s",
                    sanitize_for_logging(username),
                )
                return []

        # Keep only servers that advertise the discovery tool.
        servers_with_discovery: List[str] = []
        for server in authorized_servers:
            tool_list = (self.mcp_manager.available_tools.get(server) or {}).get("tools", [])
            if any(getattr(t, "name", None) == "rag_discover_resources" for t in tool_list):
                servers_with_discovery.append(server)

        if not servers_with_discovery:
            logger.info("No servers implement rag_discover_resources for user %s", sanitize_for_logging(username))
            return []

        # Fan out discovery calls, one server at a time.
        sources: List[str] = []
        for server in servers_with_discovery:
            try:
                raw = await self.mcp_manager.call_tool(
                    server_name=server,
                    tool_name="rag_discover_resources",
                    arguments={"username": username},
                )
                structured = self._extract_structured_result(raw)
                resources = self._extract_resources(structured)
                for r in resources:
                    # Skip malformed entries instead of letting .get() raise and
                    # discard the rest of this server's resources.
                    if not isinstance(r, dict):
                        continue
                    rid = r.get("id") or r.get("name")
                    if not isinstance(rid, str):
                        continue

                    # --- Compliance Filtering (Step 3): resource level ---
                    # Accept both camelCase (MCP standard) and snake_case (RAG mock standard).
                    resource_compliance_level = r.get("complianceLevel") or r.get("compliance_level")
                    if compliance_mgr and not compliance_mgr.is_accessible(
                        user_level=user_compliance_level, resource_level=resource_compliance_level
                    ):
                        logger.info(
                            "Skipping RAG resource %s:%s due to compliance level mismatch (user: %s, resource: %s)",
                            sanitize_for_logging(server),
                            sanitize_for_logging(rid),
                            sanitize_for_logging(user_compliance_level),
                            sanitize_for_logging(resource_compliance_level),
                        )
                        continue

                    # Qualify with the server name to avoid collisions across providers.
                    sources.append(f"{server}:{rid}")
            except Exception as e:
                logger.warning(
                    "Discovery failed on server %s for user %s: %s",
                    sanitize_for_logging(server),
                    sanitize_for_logging(username),
                    e,
                )

        # De-duplicate while preserving first-seen order.
        return list(dict.fromkeys(sources))

    except Exception as e:
        logger.error("Error during RAG MCP discovery: %s", e, exc_info=True)
        return []
168
+
169
async def discover_servers(self, username: str, user_compliance_level: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return a richer per-server discovery structure for the UI (``rag_servers``).

    Shape::

        [
            {
                "server": "docsRag",
                "displayName": "docsRag",
                "icon": <optional>,
                "complianceLevel": <optional server level>,
                "sources": [
                    {
                        "id": "handbook",
                        "name": "Employee Handbook",
                        "authRequired": True,
                        "groups": None,
                        "selected": False,
                        "complianceLevel": None,
                    }
                ],
            }
        ]

    Only enabled, group-authorized servers that advertise
    ``rag_discover_resources`` are listed. When ``user_compliance_level`` is
    given, servers and individual resources whose compliance level is not
    accessible to the user are omitted. A server whose discovery call fails
    is still listed, with an empty ``sources`` list.

    Args:
        username: User performing discovery; forwarded to each server's
            ``rag_discover_resources`` tool.
        user_compliance_level: Optional compliance level of the user/session.
            When falsy, no compliance filtering is applied.

    Returns:
        List of per-server dicts as above. On an unexpected error the servers
        collected so far are returned and the error is logged.
    """
    # Make sure RAG servers from rag_mcp_config have clients, without leaving
    # them in the general tools configuration.
    try:
        rag_cfg_servers = self.config_manager.rag_mcp_config.servers
        clients = getattr(self.mcp_manager, "clients", {})
        missing = [name for name in rag_cfg_servers.keys() if name not in clients]
        if missing:
            original = getattr(self.mcp_manager, "servers_config", {})
            # Install a merged copy instead of mutating ``original`` in place:
            # restoring the same object afterwards then fully undoes the
            # temporary extension for anyone else holding a reference to it.
            self.mcp_manager.servers_config = {
                **original,
                **{name: cfg.model_dump() for name, cfg in rag_cfg_servers.items()},
            }
            try:
                await self.mcp_manager.initialize_clients()
                await self.mcp_manager.discover_tools()
            finally:
                self.mcp_manager.servers_config = original
    except Exception:
        # Best effort: continue with whatever clients already exist.
        logger.warning(
            "RAG MCP server initialization failed; continuing with existing clients",
            exc_info=True,
        )

    rag_servers: List[Dict[str, Any]] = []
    try:
        compliance_mgr = get_compliance_manager() if user_compliance_level else None

        # servers_config no longer lists RAG servers, so authorize against
        # rag_mcp_config directly.
        rag_cfg_servers = self.config_manager.rag_mcp_config.servers
        authorized_servers: List[str] = await self._get_authorized_rag_servers(username, rag_cfg_servers)

        # --- Compliance Filtering (Step 2): server level ---
        if compliance_mgr:
            filtered_servers = []
            for server in authorized_servers:
                cfg = (self.mcp_manager.available_tools.get(server) or {}).get("config", {})
                server_compliance_level = cfg.get("compliance_level")
                if compliance_mgr.is_accessible(
                    user_level=user_compliance_level, resource_level=server_compliance_level
                ):
                    filtered_servers.append(server)
                else:
                    logger.info(
                        "Skipping RAG server %s due to compliance level mismatch (user: %s, server: %s)",
                        sanitize_for_logging(server),
                        sanitize_for_logging(user_compliance_level),
                        sanitize_for_logging(server_compliance_level),
                    )
            authorized_servers = filtered_servers

        for server in authorized_servers:
            server_data = self.mcp_manager.available_tools.get(server)
            tools = (server_data or {}).get("tools", [])
            if not any(getattr(t, "name", None) == "rag_discover_resources" for t in tools):
                continue

            try:
                raw = await self.mcp_manager.call_tool(
                    server_name=server,
                    tool_name="rag_discover_resources",
                    arguments={"username": username},
                )
                structured = self._extract_structured_result(raw)
                resources = self._extract_resources(structured)
            except Exception as e:
                logger.warning("Discovery failed for server %s: %s", sanitize_for_logging(server), e)
                resources = []

            ui_sources: List[Dict[str, Any]] = []
            for r in resources:
                # Skip malformed entries: calling .get() on them would raise
                # here, outside the per-server try, and abort discovery for
                # every remaining server.
                if not isinstance(r, dict):
                    continue
                rid = r.get("id") or r.get("name")
                if not isinstance(rid, str):
                    continue

                # --- Compliance Filtering (Step 3): resource level ---
                # Accept both camelCase (MCP standard) and snake_case (RAG mock standard).
                resource_compliance_level = r.get("complianceLevel") or r.get("compliance_level")
                if compliance_mgr and not compliance_mgr.is_accessible(
                    user_level=user_compliance_level, resource_level=resource_compliance_level
                ):
                    logger.info(
                        "Skipping RAG resource %s:%s due to compliance level mismatch (user: %s, resource: %s)",
                        sanitize_for_logging(server),
                        sanitize_for_logging(rid),
                        sanitize_for_logging(user_compliance_level),
                        sanitize_for_logging(resource_compliance_level),
                    )
                    continue

                groups = r.get("groups")
                ui_sources.append({
                    "id": rid,
                    "name": r.get("name") or rid,
                    # New contract: authRequired expected true; pass-through for legacy servers.
                    "authRequired": bool(r.get("authRequired", True)),
                    # Per-resource groups, only when provided as a list.
                    "groups": list(groups) if isinstance(groups, list) else None,
                    "selected": bool(r.get("defaultSelected", False)),
                    # The resource's own level only; the server's level is
                    # reported separately on the server entry.
                    "complianceLevel": resource_compliance_level or None,
                })

            # Optional config-driven display name, icon and compliance level.
            cfg = (self.mcp_manager.available_tools.get(server) or {}).get("config", {})
            ui_cfg = cfg.get("ui")
            rag_servers.append({
                "server": server,
                "displayName": cfg.get("displayName") or server,
                "icon": ui_cfg.get("icon") if isinstance(ui_cfg, dict) else None,
                "complianceLevel": cfg.get("compliance_level"),
                "sources": ui_sources,
            })
    except Exception as e:
        logger.error("discover_servers error: %s", e, exc_info=True)

    return rag_servers
301
+
302
async def search_raw(
    self,
    username: str,
    query: str,
    sources: List[str],
    top_k: int = 8,
    filters: Optional[Dict[str, Any]] = None,
    ranking: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Call ``rag_get_raw_results`` on each referenced server and merge the hits.

    ``sources`` are server-qualified IDs (``"server:resource_id"``); entries
    that are not strings or contain no ``":"`` are ignored. Resource IDs are
    grouped per server and each server is queried once. Servers that do not
    advertise ``rag_get_raw_results`` are skipped.

    Dict hits are tagged with their ``server`` (unless already set), merged
    across providers, sorted by ``score`` descending (missing or non-numeric
    scores count as 0.0) and truncated to ``top_k`` (``top_k`` of 0 keeps all
    hits). Retrieved snippets are then screened for prompt-injection risk;
    medium/high risk is logged, not filtered.

    Args:
        username: User on whose behalf the search runs; forwarded to servers.
        query: Search query text.
        sources: Server-qualified resource IDs to search.
        top_k: Per-server request size and final merged result limit.
        filters: Optional filters forwarded unchanged to each server.
        ranking: Optional ranking options forwarded unchanged to each server.

    Returns:
        ``{"results": {"hits": [...], "stats": {"total_found", "top_k"}},
        "meta_data": {"providers": {server: {"returned", "error"}}, "top_k"}}``.
        ``returned`` counts exactly the dict hits a provider contributed to
        the merge. A provider whose call fails is recorded with its error
        message instead of raising.
    """
    logger.debug(
        "[MCP-RAG] search_raw called: user=%s, query_preview=%s..., sources=%s, top_k=%d",
        sanitize_for_logging(username),
        sanitize_for_logging(query[:100]) if query else "(empty)",
        sources,
        top_k,
    )

    filters = filters or {}
    ranking = ranking or {}

    # Group server-qualified IDs by server; IDs may themselves contain ":".
    by_server: Dict[str, List[str]] = {}
    for s in sources or []:
        if isinstance(s, str) and ":" in s:
            srv, rid = s.split(":", 1)
            by_server.setdefault(srv, []).append(rid)

    logger.debug("[MCP-RAG] search_raw sources grouped by server: %s", by_server)

    all_hits: List[Dict[str, Any]] = []
    meta: Dict[str, Any] = {"providers": {}, "top_k": top_k}

    for server, rids in by_server.items():
        logger.debug("[MCP-RAG] search_raw processing server=%s, resource_ids=%s", server, rids)
        try:
            # Check tool availability
            server_data = self.mcp_manager.available_tools.get(server) or {}
            tool_list = server_data.get("tools", [])
            if not any(getattr(t, "name", None) == "rag_get_raw_results" for t in tool_list):
                logger.debug("[MCP-RAG] Server %s lacks rag_get_raw_results tool, skipping", server)
                continue

            logger.debug("[MCP-RAG] Calling rag_get_raw_results on server %s", server)
            raw = await self.mcp_manager.call_tool(
                server_name=server,
                tool_name="rag_get_raw_results",
                arguments={
                    "username": username,
                    "query": query,
                    "sources": rids,
                    "top_k": top_k,
                    "filters": filters,
                    "ranking": ranking,
                },
            )
            logger.debug("[MCP-RAG] Server %s raw response type: %s", server, type(raw).__name__)

            payload = self._extract_structured_result(raw) or {}
            results = payload.get("results") or {}
            hits = results.get("hits") or []
            logger.debug("[MCP-RAG] Server %s returned %d hits", server, len(hits))

            # Only dict hits can be annotated and merged; report that same set
            # in the provider stats so "returned" matches what was merged.
            valid_hits = [h for h in hits if isinstance(h, dict)]
            for h in valid_hits:
                # Annotate with server for provenance
                h.setdefault("server", server)
            all_hits.extend(valid_hits)
            meta["providers"][server] = {
                "returned": len(valid_hits),
                "error": None,
            }
        except Exception as e:
            logger.error("[MCP-RAG] Server %s search_raw error: %s", server, e, exc_info=True)
            meta["providers"][server] = {"returned": 0, "error": str(e)}

    # Merge + rerank (simple): sort by score desc if present
    def score_of(h: Dict[str, Any]) -> float:
        try:
            return float(h.get("score", 0.0))
        except Exception:
            return 0.0

    all_hits.sort(key=score_of, reverse=True)
    merged = all_hits[: top_k or len(all_hits)]

    logger.info(
        "[MCP-RAG] search_raw complete: total_hits=%d, merged_count=%d, providers=%s",
        len(all_hits),
        len(merged),
        list(meta["providers"].keys()),
    )

    # Prompt-injection risk check on retrieved snippets (observe + log only).
    # ``merged`` holds only dict hits, so no type check is needed here.
    try:
        for h in merged:
            text = h.get("snippet") or h.get("chunk") or h.get("text") or ""
            if not isinstance(text, str) or not text.strip():
                continue
            pi = calculate_prompt_injection_risk(text, mode="general")
            if pi.get("risk_level") in ("medium", "high"):
                log_high_risk_event(
                    source="rag_chunk",
                    user=username,
                    content=text,
                    score=int(pi.get("score", 0)),
                    risk_level=str(pi.get("risk_level")),
                    triggers=list(pi.get("triggers", [])),
                    extra={
                        "server": h.get("server"),
                        "resourceId": h.get("resourceId"),
                    },
                )
    except Exception:
        logger.debug("Prompt risk check failed (RAG results)", exc_info=True)

    return {
        "results": {
            "hits": merged,
            "stats": {
                "total_found": len(all_hits),
                "top_k": top_k,
            },
        },
        "meta_data": meta,
    }
431
+
432
+ async def synthesize(
433
+ self,
434
+ username: str,
435
+ query: str,
436
+ sources: List[str],
437
+ top_k: Optional[int] = None,
438
+ synthesis_params: Optional[Dict[str, Any]] = None,
439
+ provided_context: Optional[Dict[str, Any]] = None,
440
+ ) -> Dict[str, Any]:
441
+ """Call rag_get_synthesized_results across servers when available.
442
+
443
+ If not available, fall back to raw search and concatenate snippets.
444
+ """
445
+ logger.debug(
446
+ "[MCP-RAG] synthesize called: user=%s, query_preview=%s..., sources=%s, top_k=%s",
447
+ sanitize_for_logging(username),
448
+ sanitize_for_logging(query[:100]) if query else "(empty)",
449
+ sources,
450
+ top_k,
451
+ )
452
+
453
+ synthesis_params = synthesis_params or {}
454
+ provided_context = provided_context or {}
455
+
456
+ by_server: Dict[str, List[str]] = {}
457
+ for s in sources or []:
458
+ if isinstance(s, str) and ":" in s:
459
+ srv, rid = s.split(":", 1)
460
+ by_server.setdefault(srv, []).append(rid)
461
+
462
+ logger.debug("[MCP-RAG] Sources grouped by server: %s", by_server)
463
+
464
+ answers: List[str] = []
465
+ citations: List[Dict[str, Any]] = []
466
+ meta: Dict[str, Any] = {"providers": {}}
467
+ used_fallback = False
468
+
469
+ for server, rids in by_server.items():
470
+ logger.debug(
471
+ "[MCP-RAG] Processing server=%s, resource_ids=%s",
472
+ server,
473
+ rids,
474
+ )
475
+ try:
476
+ server_data = self.mcp_manager.available_tools.get(server) or {}
477
+ tool_list = server_data.get("tools", [])
478
+ tool_names = [getattr(t, "name", None) for t in tool_list]
479
+ logger.debug("[MCP-RAG] Server %s available tools: %s", server, tool_names)
480
+
481
+ has_synth = any(getattr(t, "name", None) == "rag_get_synthesized_results" for t in tool_list)
482
+ if has_synth:
483
+ logger.debug("[MCP-RAG] Server %s has rag_get_synthesized_results, calling...", server)
484
+ raw = await self.mcp_manager.call_tool(
485
+ server_name=server,
486
+ tool_name="rag_get_synthesized_results",
487
+ arguments={
488
+ "username": username,
489
+ "query": query,
490
+ "sources": rids,
491
+ **({"top_k": top_k} if top_k is not None else {}),
492
+ "synthesis_params": synthesis_params,
493
+ "provided_context": provided_context,
494
+ },
495
+ )
496
+ logger.debug("[MCP-RAG] Server %s raw response type: %s", server, type(raw).__name__)
497
+
498
+ payload = self._extract_structured_result(raw) or {}
499
+ logger.debug("[MCP-RAG] Server %s extracted payload keys: %s", server, list(payload.keys()))
500
+ logger.debug("[MCP-RAG] Server %s full payload: %s", server, sanitize_for_logging(str(payload)[:1000]))
501
+
502
+ results = payload.get("results") or {}
503
+ logger.debug("[MCP-RAG] Server %s results type: %s, results keys: %s", server, type(results).__name__, list(results.keys()) if isinstance(results, dict) else "N/A")
504
+ logger.debug("[MCP-RAG] Server %s results content: %s", server, sanitize_for_logging(str(results)[:500]))
505
+
506
+ ans = results.get("answer")
507
+ if isinstance(ans, str) and ans:
508
+ logger.debug(
509
+ "[MCP-RAG] Server %s answer length=%d, preview=%s...",
510
+ server,
511
+ len(ans),
512
+ sanitize_for_logging(ans[:200]),
513
+ )
514
+ answers.append(ans)
515
+ cits = results.get("citations") or []
516
+ if isinstance(cits, list):
517
+ for c in cits:
518
+ if isinstance(c, dict):
519
+ c.setdefault("server", server)
520
+ citations.extend([c for c in cits if isinstance(c, dict)])
521
+ meta["providers"][server] = {"used_synth": True, "error": None}
522
+ logger.info("[MCP-RAG] Server %s synthesis complete: answer_length=%d", server, len(ans) if ans else 0)
523
+ else:
524
+ logger.debug("[MCP-RAG] Server %s lacks rag_get_synthesized_results, using fallback search_raw", server)
525
+ used_fallback = True
526
+ raw_payload = await self.search_raw(username, query, [f"{server}:{rid}" for rid in rids], top_k=top_k or 8)
527
+ # Build a rudimentary answer from snippets
528
+ hits = ((raw_payload.get("results") or {}).get("hits") or [])
529
+ logger.debug("[MCP-RAG] Server %s fallback search returned %d hits", server, len(hits))
530
+ snippet_texts = [h.get("snippet") or h.get("chunk") or "" for h in hits if isinstance(h, dict)]
531
+ if snippet_texts:
532
+ answers.append("\n\n".join(snippet_texts[:3]))
533
+ meta["providers"][server] = {"used_synth": False, "error": None}
534
+ except Exception as e:
535
+ logger.error("[MCP-RAG] Server %s synthesis error: %s", server, e, exc_info=True)
536
+ meta["providers"][server] = {"used_synth": False, "error": str(e)}
537
+
538
+ final_answer = "\n\n---\n\n".join([a for a in answers if a]) if answers else ""
539
+ logger.info(
540
+ "[MCP-RAG] synthesize complete: total_answers=%d, final_answer_length=%d, used_fallback=%s",
541
+ len(answers),
542
+ len(final_answer),
543
+ used_fallback,
544
+ )
545
+
546
+ return {
547
+ "results": {
548
+ "answer": final_answer,
549
+ "citations": citations or None,
550
+ "limits": {"truncated": False} if final_answer else None,
551
+ },
552
+ "meta_data": {**meta, "fallback_used": used_fallback},
553
+ }
554
+
555
+ # --- helpers ---------------------------------------------------------
556
def _extract_structured_result(self, raw: Any) -> Dict[str, Any]:
    """Best-effort extraction of a structured payload from a FastMCP tool result.

    Candidates are tried in order and the first ``dict`` found is returned:

    1. ``raw`` itself, when it is already a ``dict``.
    2. ``raw.structured_content``, when truthy and a ``dict``.
    3. ``raw.data``, when truthy and a ``dict``.
    4. ``raw.content``:
       - If it is a list or tuple, each item is checked in turn. An item's
         ``text`` (attribute, or ``"text"`` key for dict items) is parsed as
         a JSON object. A dict item that carries ``"results"`` or
         ``"meta_data"`` is returned as-is.
       - If it is a non-blank string, the string is parsed as a JSON object.
    5. ``str(raw)`` parsed as a JSON object, when it starts with ``{``.

    Args:
        raw: The value returned by the MCP tool call.

    Returns:
        The extracted dict, or ``{}`` when nothing usable was found.
        Unexpected errors are logged at debug level and swallowed, so this
        helper does not raise from inside its parsing logic.
    """
    import json

    logger.debug("[MCP-RAG] _extract_structured_result: raw type=%s", type(raw).__name__)

    try:
        # Log available attributes for debugging
        if hasattr(raw, "__dict__"):
            logger.debug("[MCP-RAG] _extract_structured_result: raw attributes=%s", list(raw.__dict__.keys()))

        # If raw is already a dict, return it directly
        if isinstance(raw, dict):
            logger.debug("[MCP-RAG] _extract_structured_result: raw is already a dict with keys=%s", list(raw.keys()))
            return raw

        # Preferred attributes from fastmcp. Each attribute is read once.
        # A truthy non-dict value is skipped so later strategies still run.
        structured = getattr(raw, "structured_content", None)
        if structured:
            logger.debug("[MCP-RAG] _extract_structured_result: found structured_content")
            if isinstance(structured, dict):
                return structured

        data = getattr(raw, "data", None)
        if data:
            logger.debug("[MCP-RAG] _extract_structured_result: found data")
            if isinstance(data, dict):
                return data

        contents = getattr(raw, "content", None)
        if contents:
            logger.debug("[MCP-RAG] _extract_structured_result: found content, type=%s, len=%s", type(contents).__name__, len(contents) if hasattr(contents, '__len__') else "N/A")

            # content is typically a sequence of segments with .text
            if isinstance(contents, (list, tuple)):
                # Try all content items, not just the first
                for idx, item in enumerate(contents):
                    logger.debug("[MCP-RAG] _extract_structured_result: content[%d] type=%s", idx, type(item).__name__)

                    # Try .text attribute, then a "text" key for dict items
                    text = getattr(item, "text", None)
                    if text is None and isinstance(item, dict):
                        text = item.get("text")

                    if text:
                        logger.debug("[MCP-RAG] _extract_structured_result: text type=%s, preview=%s", type(text).__name__, sanitize_for_logging(str(text)[:300]))
                        if isinstance(text, str) and text.strip():
                            try:
                                obj = json.loads(text)
                            except (ValueError, RecursionError) as json_err:
                                # JSONDecodeError subclasses ValueError; keep scanning other items.
                                logger.debug("[MCP-RAG] _extract_structured_result: JSON parse failed for content[%d]: %s", idx, json_err)
                            else:
                                if isinstance(obj, dict):
                                    logger.debug("[MCP-RAG] _extract_structured_result: parsed JSON with keys=%s", list(obj.keys()))
                                    return obj

                    # Try if item is itself a dict with results/meta_data
                    if isinstance(item, dict) and ("results" in item or "meta_data" in item):
                        logger.debug("[MCP-RAG] _extract_structured_result: content[%d] is a dict with results/meta_data", idx)
                        return item

            # If content is a single string (not a sequence of segments), try to parse as JSON
            elif isinstance(contents, str) and contents.strip():
                logger.debug("[MCP-RAG] _extract_structured_result: content is a string, trying JSON parse")
                try:
                    obj = json.loads(contents)
                except (ValueError, RecursionError) as json_err:
                    logger.debug("[MCP-RAG] _extract_structured_result: JSON parse of content string failed: %s", json_err)
                else:
                    if isinstance(obj, dict):
                        logger.debug("[MCP-RAG] _extract_structured_result: parsed content string as JSON with keys=%s", list(obj.keys()))
                        return obj

        # Last resort: every object has __str__, so no hasattr guard is needed.
        # Only attempt a parse when the text looks like a JSON object.
        raw_str = str(raw)
        if raw_str.strip().startswith("{"):
            logger.debug("[MCP-RAG] _extract_structured_result: trying __str__ as JSON: %s...", sanitize_for_logging(raw_str[:200]))
            try:
                obj = json.loads(raw_str)
            except (ValueError, RecursionError) as json_err:
                # Previously swallowed silently; log it like the other parse attempts.
                logger.debug("[MCP-RAG] _extract_structured_result: JSON parse of __str__ failed: %s", json_err)
            else:
                if isinstance(obj, dict):
                    logger.debug("[MCP-RAG] _extract_structured_result: parsed __str__ as JSON with keys=%s", list(obj.keys()))
                    return obj

    except Exception as parse_err:  # pragma: no cover - defensive
        logger.debug("Non-fatal: failed to parse structured result: %s", parse_err)

    logger.debug("[MCP-RAG] _extract_structured_result: returning empty dict")
    return {}
644
+
645
+ def _extract_resources(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
646
+ """Extract list of resource dicts from a normalized tool result."""
647
+ if not isinstance(payload, dict):
648
+ return []
649
+ results = payload.get("results") if isinstance(payload.get("results"), dict) else payload
650
+ # Support both {results: {resources: [...]}} and {results: [...]}
651
+ # Also support the RAG mock format: {accessible_data_sources: [...]}
652
+ resources = (
653
+ (results.get("resources") if isinstance(results, dict) else None)
654
+ or payload.get("resources")
655
+ or payload.get("accessible_data_sources") # Added support for RAG mock format
656
+ or []
657
+ )
658
+ if isinstance(resources, list):
659
+ # ensure each entry is a dict
660
+ return [r for r in resources if isinstance(r, dict)]
661
+ return []
662
+
663
+
664
# Names exported by a wildcard import of this module.
__all__ = [
    "RAGMCPService",
]
@@ -0,0 +1,7 @@
1
+ """Domain models for sessions."""
2
+
3
+ from .models import Session
4
+
5
# Public API of the sessions domain package.
__all__ = ["Session"]