atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250) hide show
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,398 @@
1
+ """Unit tests for UnifiedRAGService.
2
+
3
+ Tests the unified RAG service that aggregates HTTP and MCP RAG sources.
4
+ """
5
+
6
+ from unittest.mock import AsyncMock, MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ from atlas.domain.unified_rag_service import UnifiedRAGService
11
+ from atlas.modules.config.config_manager import RAGSourceConfig, RAGSourcesConfig
12
+ from atlas.modules.rag.client import DataSource, RAGResponse
13
+
14
+
15
@pytest.fixture
def mock_config_manager():
    """Provide a MagicMock config manager exposing three RAG sources.

    The ``rag_sources_config`` attribute holds:

    * ``test_http`` - enabled HTTP source restricted to the ``users`` group.
    * ``test_mcp``  - enabled MCP source restricted to the ``admin`` group.
    * ``disabled``  - HTTP source with ``enabled=False``.
    """
    rag_sources = {
        "test_http": RAGSourceConfig(
            type="http",
            display_name="Test HTTP RAG",
            description="Test HTTP RAG source",
            url="http://test-rag.example.com",
            bearer_token="test-token",
            groups=["users"],
            compliance_level="Internal",
            enabled=True,
        ),
        "test_mcp": RAGSourceConfig(
            type="mcp",
            display_name="Test MCP RAG",
            description="Test MCP RAG source",
            command=["python", "test_mcp.py"],
            groups=["admin"],
            compliance_level="SOC2",
            enabled=True,
        ),
        "disabled": RAGSourceConfig(
            type="http",
            display_name="Disabled RAG",
            url="http://disabled.example.com",
            enabled=False,
        ),
    }

    manager = MagicMock()
    manager.rag_sources_config = RAGSourcesConfig(sources=rag_sources)
    return manager
58
+
59
+
60
@pytest.fixture
def mock_auth_check():
    """Provide an async ``(username, group) -> bool`` membership checker.

    Known users:

    * ``test@test.com``  belongs to ``users`` only.
    * ``admin@test.com`` belongs to ``users`` and ``admin``.

    Any other username is a member of no group.
    """
    # Ordered (user, groups) table; checked top to bottom.
    memberships = (
        ("test@test.com", ("users",)),
        ("admin@test.com", ("users", "admin")),
    )

    async def auth_check(username: str, group: str) -> bool:
        for known_user, user_groups in memberships:
            if username == known_user:
                return group in user_groups
        return False

    return auth_check
73
+
74
+
75
@pytest.fixture
def unified_rag_service(mock_config_manager, mock_auth_check):
    """Build the UnifiedRAGService under test, with no MCP manager attached."""
    service = UnifiedRAGService(
        config_manager=mock_config_manager,
        mcp_manager=None,
        auth_check_func=mock_auth_check,
    )
    return service
83
+
84
+
85
class TestUnifiedRAGServiceInit:
    """Constructor behaviour of UnifiedRAGService."""

    def test_init_with_all_params(self, mock_config_manager, mock_auth_check):
        """All supplied collaborators are stored and the HTTP client cache starts empty."""
        mcp_manager = MagicMock()

        svc = UnifiedRAGService(
            config_manager=mock_config_manager,
            mcp_manager=mcp_manager,
            auth_check_func=mock_auth_check,
        )

        assert svc.config_manager == mock_config_manager
        assert svc.auth_check_func == mock_auth_check
        assert svc._http_clients == {}

    def test_init_without_optional_params(self, mock_config_manager):
        """Omitted optional collaborators default to ``None``."""
        svc = UnifiedRAGService(config_manager=mock_config_manager)

        assert svc.mcp_manager is None
        assert svc.auth_check_func is None
106
+
107
+
108
class TestHTTPClientCaching:
    """Caching behaviour of ``UnifiedRAGService._get_http_client``."""

    def test_get_http_client_creates_new_client(self, unified_rag_service, mock_config_manager):
        """A first lookup builds a client and stores it under the source name."""
        http_config = mock_config_manager.rag_sources_config.sources["test_http"]

        def passthrough(v, **kw):
            # Stand-in for resolve_env_var: hand the value back untouched.
            return v

        with patch("atlas.domain.unified_rag_service.resolve_env_var", side_effect=passthrough):
            created = unified_rag_service._get_http_client("test_http", http_config)

        assert created is not None
        cache = unified_rag_service._http_clients
        assert "test_http" in cache
        assert cache["test_http"] == created

    def test_get_http_client_returns_cached_client(self, unified_rag_service, mock_config_manager):
        """A repeated lookup returns the very same client object."""
        http_config = mock_config_manager.rag_sources_config.sources["test_http"]

        def passthrough(v, **kw):
            # Stand-in for resolve_env_var: hand the value back untouched.
            return v

        with patch("atlas.domain.unified_rag_service.resolve_env_var", side_effect=passthrough):
            first = unified_rag_service._get_http_client("test_http", http_config)
            second = unified_rag_service._get_http_client("test_http", http_config)

        assert first is second
131
+
132
+
133
class TestUserAuthorization:
    """Behaviour of ``UnifiedRAGService._is_user_authorized``."""

    @pytest.mark.asyncio
    async def test_is_user_authorized_no_groups(self, unified_rag_service):
        """An empty required-group list authorizes any user."""
        allowed = await unified_rag_service._is_user_authorized("anyone@test.com", [])
        assert allowed is True

    @pytest.mark.asyncio
    async def test_is_user_authorized_user_in_group(self, unified_rag_service):
        """A user who belongs to the required group is authorized."""
        allowed = await unified_rag_service._is_user_authorized("test@test.com", ["users"])
        assert allowed is True

    @pytest.mark.asyncio
    async def test_is_user_authorized_user_not_in_group(self, unified_rag_service):
        """A user outside the required group is rejected."""
        allowed = await unified_rag_service._is_user_authorized("test@test.com", ["admin"])
        assert allowed is False

    @pytest.mark.asyncio
    async def test_is_user_authorized_no_auth_func(self, mock_config_manager):
        """Without an auth check function the service is permissive."""
        svc = UnifiedRAGService(config_manager=mock_config_manager)

        allowed = await svc._is_user_authorized("anyone@test.com", ["admin"])

        # No auth function configured -> everything is allowed by default.
        assert allowed is True
161
+
162
+
163
class TestDiscoverDataSources:
    """Tests for data source discovery.

    Each test patches ``_discover_http_source`` with an ``AsyncMock`` and
    inspects the first positional argument (the source name) of every
    recorded call.
    """

    @staticmethod
    def _discovered_names(mock_discover):
        """Return the first positional argument of every recorded call."""
        return [call.args[0] for call in mock_discover.call_args_list]

    @pytest.mark.asyncio
    async def test_discover_skips_disabled_sources(self, unified_rag_service):
        """Test that disabled sources are skipped during discovery."""
        with patch.object(unified_rag_service, "_discover_http_source", new_callable=AsyncMock) as mock_discover:
            mock_discover.return_value = {"server": "test", "sources": []}

            await unified_rag_service.discover_data_sources("test@test.com")

            # Should not be called for disabled source
            assert "disabled" not in self._discovered_names(mock_discover)

    @pytest.mark.asyncio
    async def test_discover_filters_by_authorization(self, unified_rag_service):
        """Test that sources are filtered by user authorization."""
        with patch.object(unified_rag_service, "_discover_http_source", new_callable=AsyncMock) as mock_discover:
            mock_discover.return_value = {"server": "test", "sources": []}

            # test@test.com is only in "users" group, not "admin"
            await unified_rag_service.discover_data_sources("test@test.com")

            # Should only discover test_http (users group), not test_mcp (admin group)
            discovered = self._discovered_names(mock_discover)
            assert "test_http" in discovered
            # test_mcp requires admin group, which test@test.com doesn't have.
            # The original test stated this in a comment but never asserted it.
            assert "test_mcp" not in discovered

    @pytest.mark.asyncio
    async def test_discover_includes_admin_sources_for_admin(self, unified_rag_service):
        """Test that admin user can see admin-only sources."""
        with patch.object(unified_rag_service, "_discover_http_source", new_callable=AsyncMock) as mock_discover:
            mock_discover.return_value = {"server": "test", "sources": []}

            # admin@test.com is in both "users" and "admin" groups
            await unified_rag_service.discover_data_sources("admin@test.com")

            # Should discover test_http (users group)
            assert "test_http" in self._discovered_names(mock_discover)
            # NOTE(review): only the HTTP discovery path is patched here, so the
            # admin-only "test_mcp" source named in the docstring is not actually
            # verified by this test - confirm and extend once the MCP discovery
            # hook is known.
204
+
205
+
206
class TestDiscoverHTTPSource:
    """Behaviour of ``UnifiedRAGService._discover_http_source``."""

    @staticmethod
    async def _run_discovery(service, source_config, client):
        """Invoke discovery for ``test_http`` with ``client`` substituted as the HTTP client."""
        with patch.object(service, "_get_http_client", return_value=client):
            return await service._discover_http_source(
                "test_http", source_config, "test@test.com"
            )

    @pytest.mark.asyncio
    async def test_discover_http_source_success(self, unified_rag_service, mock_config_manager):
        """Two discovered corpora are reported under the configured server entry."""
        http_config = mock_config_manager.rag_sources_config.sources["test_http"]
        corpora = [
            DataSource(name="corpus1", compliance_level="Internal"),
            DataSource(name="corpus2", compliance_level="Public"),
        ]
        fake_client = AsyncMock()
        fake_client.discover_data_sources.return_value = corpora

        outcome = await self._run_discovery(unified_rag_service, http_config, fake_client)

        assert outcome is not None
        assert outcome["server"] == "test_http"
        assert outcome["displayName"] == "Test HTTP RAG"
        listed = outcome["sources"]
        assert len(listed) == 2
        assert listed[0]["id"] == "corpus1"
        assert listed[1]["id"] == "corpus2"

    @pytest.mark.asyncio
    async def test_discover_http_source_empty(self, unified_rag_service, mock_config_manager):
        """A backend that reports no data sources yields ``None``."""
        http_config = mock_config_manager.rag_sources_config.sources["test_http"]
        fake_client = AsyncMock()
        fake_client.discover_data_sources.return_value = []

        outcome = await self._run_discovery(unified_rag_service, http_config, fake_client)

        assert outcome is None

    @pytest.mark.asyncio
    async def test_discover_http_source_error(self, unified_rag_service, mock_config_manager):
        """A failing backend is swallowed and reported as ``None``."""
        http_config = mock_config_manager.rag_sources_config.sources["test_http"]
        fake_client = AsyncMock()
        fake_client.discover_data_sources.side_effect = Exception("Connection failed")

        outcome = await self._run_discovery(unified_rag_service, http_config, fake_client)

        assert outcome is None
261
+
262
+
263
class TestQueryRAG:
    """Routing behaviour of ``UnifiedRAGService.query_rag``."""

    @pytest.mark.asyncio
    async def test_query_rag_with_qualified_source(self, unified_rag_service, mock_config_manager):
        """A ``server:source_id`` reference is split and forwarded to that server's HTTP client."""
        conversation = [{"role": "user", "content": "test query"}]
        fake_client = AsyncMock()
        fake_client.query_rag.return_value = RAGResponse(
            content="Test response",
            metadata=None,
        )

        with patch.object(unified_rag_service, "_get_http_client", return_value=fake_client):
            response = await unified_rag_service.query_rag(
                username="test@test.com",
                qualified_data_source="test_http:corpus1",
                messages=[{"role": "user", "content": "test query"}],
            )

        assert response.content == "Test response"
        fake_client.query_rag.assert_called_once_with(
            "test@test.com",
            "corpus1",
            conversation,
        )

    @pytest.mark.asyncio
    async def test_query_rag_unknown_server(self, unified_rag_service):
        """An unknown server prefix is rejected with ``ValueError``."""
        request = {
            "username": "test@test.com",
            "qualified_data_source": "unknown_server:corpus1",
            "messages": [],
        }
        with pytest.raises(ValueError, match="RAG source not found"):
            await unified_rag_service.query_rag(**request)

    @pytest.mark.asyncio
    async def test_query_rag_mcp_source_without_service_raises(self, unified_rag_service):
        """An MCP source cannot be queried when no RAGMCPService was supplied."""
        # The shared fixture builds the service without a rag_mcp_service.
        request = {
            "username": "admin@test.com",
            "qualified_data_source": "test_mcp:corpus1",
            "messages": [],
        }
        with pytest.raises(ValueError, match="RAGMCPService not configured"):
            await unified_rag_service.query_rag(**request)

    @pytest.mark.asyncio
    async def test_query_rag_mcp_source_routes_to_mcp_service(self, mock_config_manager, mock_auth_check):
        """Queries against an MCP source are delegated to ``RAGMCPService.synthesize``."""
        # Canned synthesis payload returned by the fake RAGMCPService.
        synthesis_payload = {
            "results": {
                "answer": "Test answer from MCP RAG",
                "citations": [],
            },
            "meta_data": {
                "providers": {
                    "test_mcp": {"used_synth": True, "error": None}
                },
                "fallback_used": False,
            },
        }
        fake_mcp_service = MagicMock()
        fake_mcp_service.synthesize = AsyncMock(return_value=synthesis_payload)

        # Service wired with the fake RAGMCPService.
        svc = UnifiedRAGService(
            config_manager=mock_config_manager,
            mcp_manager=None,
            auth_check_func=mock_auth_check,
            rag_mcp_service=fake_mcp_service,
        )

        conversation = [{"role": "user", "content": "What is the fleet info?"}]
        response = await svc.query_rag(
            username="admin@test.com",
            qualified_data_source="test_mcp:corpus1",
            messages=conversation,
        )

        # The delegate must have been invoked exactly once with these keyword arguments.
        fake_mcp_service.synthesize.assert_called_once_with(
            username="admin@test.com",
            query="What is the fleet info?",
            sources=["test_mcp:corpus1"],
        )

        # Shape of the translated response.
        assert isinstance(response, RAGResponse)
        assert response.content == "Test answer from MCP RAG"
        meta = response.metadata
        assert meta is not None
        assert meta.data_source_name == "test_mcp"
        assert meta.retrieval_method == "mcp_synthesis"
356
+
357
+
358
class TestSourceFiltering:
    """Behaviour of the per-type source accessors."""

    def test_get_http_sources(self, unified_rag_service):
        """Only enabled HTTP sources are returned."""
        http_sources = unified_rag_service.get_http_sources()

        assert "test_http" in http_sources
        assert "test_mcp" not in http_sources
        # The disabled entry is an HTTP source but must still be left out.
        assert "disabled" not in http_sources

    def test_get_mcp_sources(self, unified_rag_service):
        """Only MCP sources are returned."""
        mcp_sources = unified_rag_service.get_mcp_sources()

        assert "test_mcp" in mcp_sources
        assert "test_http" not in mcp_sources
        assert "disabled" not in mcp_sources
376
+
377
+
378
class TestFindServerForSource:
    """Lookup of the owning server for a bare source ID."""

    def test_find_server_returns_none(self, unified_rag_service):
        """``_find_server_for_source`` yields ``None`` (lookup is unimplemented)."""
        owner = unified_rag_service._find_server_for_source("corpus1")
        assert owner is None
385
+
386
+
387
class TestQueryRAGWithoutQualification:
    """Querying RAG with a data source that lacks a ``server:`` prefix."""

    @pytest.mark.asyncio
    async def test_query_rag_without_prefix_raises(self, unified_rag_service):
        """A bare source ID cannot be resolved to a server and raises ``ValueError``."""
        request = {
            "username": "test@test.com",
            "qualified_data_source": "corpus1",  # deliberately unqualified
            "messages": [],
        }
        with pytest.raises(ValueError, match="Could not find server"):
            await unified_rag_service.query_rag(**request)
@@ -0,0 +1,258 @@
1
+ """Tests for username override security in tool approval flow."""
2
+
3
+
4
+ from unittest.mock import Mock
5
+
6
+ from atlas.application.chat.utilities.tool_executor import _filter_args_to_schema, inject_context_into_args
7
+
8
+
9
class TestUsernameOverrideInApproval:
    """Ensure edits made in the approval dialog cannot replace the injected username.

    Each test builds a session context for an authenticated user, feeds
    (possibly tampered) arguments through ``inject_context_into_args`` and,
    where relevant, ``_filter_args_to_schema``, then checks the outcome.
    """

    @staticmethod
    def _schema_backed_manager(tool_name, *param_names):
        """Return a Mock tool manager advertising one tool with string parameters."""
        manager = Mock()
        manager.get_tools_schema.return_value = [{
            "function": {
                "name": tool_name,
                "parameters": {
                    "properties": {
                        param: {"type": "string"} for param in param_names
                    }
                },
            }
        }]
        return manager

    def test_username_override_after_user_edit(self):
        """An edited username is replaced by the session user's email."""
        session = {"user_email": "alice@example.com", "files": {}}
        # The approval dialog came back with a different username.
        tampered = {"username": "malicious@example.com", "data": "test data"}
        tools = self._schema_backed_manager("create_record", "username", "data")

        # Step 1: context is injected again after approval.
        injected = inject_context_into_args(tampered, session, "create_record", tools)
        assert injected["username"] == "alice@example.com"
        assert injected["data"] == "test data"

        # Step 2: the schema filter runs, completing the flow.
        filtered = _filter_args_to_schema(injected, "create_record", tools)
        assert filtered["username"] == "alice@example.com"

    def test_username_override_with_tool_that_doesnt_accept_username(self):
        """No username is added when the tool schema does not declare one."""
        session = {"user_email": "alice@example.com", "files": {}}
        approved = {"query": "test query"}
        # Schema exposes only ``query``.
        tools = self._schema_backed_manager("search", "query")

        injected = inject_context_into_args(approved, session, "search", tools)

        assert "username" not in injected
        assert injected["query"] == "test query"

    def test_username_override_with_no_tool_manager(self):
        """Without a tool manager the username is still injected (fallback)."""
        session = {"user_email": "bob@example.com", "files": {}}
        approved = {"data": "some data"}

        # ``None`` stands in for a missing tool manager.
        injected = inject_context_into_args(approved, session, "some_tool", None)

        assert injected["username"] == "bob@example.com"
        assert injected["data"] == "some data"

    def test_multiple_security_injections_after_edit(self):
        """Username is corrected while file and data arguments are handled as usual."""
        session = {
            "user_email": "secure_user@example.com",
            "files": {"test.pdf": {"key": "file_key_123"}},
        }
        tampered = {
            "username": "hacked@example.com",  # expected to be overridden
            "filename": "test.pdf",  # names a file present in the session
            "data": "edited data",
        }
        tools = self._schema_backed_manager(
            "process_file", "username", "filename", "data"
        )

        injected = inject_context_into_args(tampered, session, "process_file", tools)

        # The authenticated identity wins over the edited one.
        assert injected["username"] == "secure_user@example.com"
        # The original file name is recorded alongside the arguments.
        assert "original_filename" in injected
        assert injected["original_filename"] == "test.pdf"
        # Ordinary user edits are left alone.
        assert injected["data"] == "edited data"

    def test_prevented_impersonation_attack(self):
        """Alice cannot act as admin by editing the username during approval."""
        session = {"user_email": "alice@example.com", "files": {}}
        tampered = {
            "username": "admin@example.com",  # the impersonation attempt
            "action": "delete_all_data",
        }
        tools = self._schema_backed_manager("admin_action", "username", "action")

        # Re-injection followed by schema filtering.
        injected = inject_context_into_args(tampered, session, "admin_action", tools)
        filtered = _filter_args_to_schema(injected, "admin_action", tools)

        assert filtered["username"] == "alice@example.com"  # not admin
        assert filtered["action"] == "delete_all_data"  # ordinary argument untouched

    def test_schema_filtering_preserves_security_injection(self):
        """Schema filtering drops unknown keys yet keeps the injected username."""
        session = {"user_email": "secure@example.com", "files": {}}
        tampered = {
            "username": "hacked@example.com",
            "data": "legitimate data",
            "extra_param": "should_be_removed",  # absent from the schema below
        }
        # Only ``username`` and ``data`` are declared; ``extra_param`` is not.
        tools = self._schema_backed_manager("limited_tool", "username", "data")

        injected = inject_context_into_args(tampered, session, "limited_tool", tools)
        filtered = _filter_args_to_schema(injected, "limited_tool", tools)

        assert filtered["username"] == "secure@example.com"
        assert filtered["data"] == "legitimate data"
        assert "extra_param" not in filtered