atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,331 @@
1
+ """Unit tests for MCP client authentication methods.
2
+
3
+ Tests the per-user authentication functionality in MCPToolManager:
4
+ - _requires_user_auth: Check if server requires user authentication
5
+ - _get_user_client: Get or create user-specific client with token
6
+ - Cache validation and invalidation
7
+
8
+ Updated: 2025-01-23
9
+ """
10
+
11
+ from unittest.mock import MagicMock, patch
12
+
13
+ import pytest
14
+
15
+
16
class TestRequiresUserAuth:
    """Tests for ``MCPToolManager._requires_user_auth``.

    Every case builds a bare manager (``__init__`` is bypassed) whose only
    state is ``servers_config`` and checks whether the named server is
    reported as needing per-user credentials.
    """

    @staticmethod
    def _manager_with(servers_config):
        """Return an uninitialised MCPToolManager holding *servers_config*."""
        from atlas.modules.mcp_tools.client import MCPToolManager

        instance = MCPToolManager.__new__(MCPToolManager)
        instance.servers_config = servers_config
        return instance

    def test_requires_user_auth_for_jwt(self):
        """JWT auth_type should require user auth."""
        mgr = self._manager_with({"test-server": {"auth_type": "jwt"}})
        assert mgr._requires_user_auth("test-server") is True

    def test_requires_user_auth_for_oauth(self):
        """OAuth auth_type should require user auth."""
        mgr = self._manager_with({"test-server": {"auth_type": "oauth"}})
        assert mgr._requires_user_auth("test-server") is True

    def test_requires_user_auth_for_bearer(self):
        """Bearer auth_type should require user auth."""
        mgr = self._manager_with({"test-server": {"auth_type": "bearer"}})
        assert mgr._requires_user_auth("test-server") is True

    def test_requires_user_auth_for_api_key(self):
        """API key auth_type should require user auth."""
        mgr = self._manager_with({"test-server": {"auth_type": "api_key"}})
        assert mgr._requires_user_auth("test-server") is True

    def test_no_user_auth_for_none(self):
        """None auth_type should not require user auth."""
        mgr = self._manager_with({"test-server": {"auth_type": "none"}})
        assert mgr._requires_user_auth("test-server") is False

    def test_no_user_auth_when_missing(self):
        """Missing auth_type should not require user auth (defaults to none)."""
        mgr = self._manager_with({"test-server": {}})
        assert mgr._requires_user_auth("test-server") is False

    def test_no_user_auth_for_unknown_server(self):
        """Unknown server should not require user auth."""
        mgr = self._manager_with({})
        assert mgr._requires_user_auth("unknown-server") is False
81
+
82
+
83
class TestGetUserClient:
    """Tests for ``MCPToolManager._get_user_client``.

    ``get_token_storage``, ``Client`` and ``StreamableHttpTransport`` are
    patched, so no real connection is opened. Assertions inspect the
    arguments those patched constructors received and the manager's
    ``_user_clients`` cache.
    """

    @staticmethod
    def _build_manager(servers_config):
        """Return an uninitialised MCPToolManager prepared for ``_get_user_client``.

        ``__init__`` is bypassed with ``__new__``. Only the attributes set
        here exist on the instance, and the three handler factories are
        stubbed to return None. This is shared by the fixture and by tests
        that need a different ``servers_config``, so the setup is written once.
        """
        import asyncio

        from atlas.modules.mcp_tools.client import MCPToolManager

        mgr = MCPToolManager.__new__(MCPToolManager)
        mgr.servers_config = servers_config
        mgr._user_clients = {}
        mgr._user_clients_lock = asyncio.Lock()
        mgr._create_log_handler = MagicMock(return_value=None)
        mgr._create_elicitation_handler = MagicMock(return_value=None)
        mgr._create_sampling_handler = MagicMock(return_value=None)
        return mgr

    @staticmethod
    def _stub_token_storage(mock_get_storage, token_value):
        """Make the patched ``get_token_storage`` return a storage mock.

        The storage's ``get_valid_token`` returns a token whose
        ``token_value`` is *token_value*, or None when *token_value* is None.
        The storage mock is returned so a test can change that answer later.
        """
        storage = MagicMock()
        if token_value is None:
            storage.get_valid_token.return_value = None
        else:
            token = MagicMock()
            token.token_value = token_value
            storage.get_valid_token.return_value = token
        mock_get_storage.return_value = storage
        return storage

    @staticmethod
    def _stub_constructor(mock_class):
        """Make a patched class return a fresh MagicMock, and return that mock."""
        instance = MagicMock()
        mock_class.return_value = instance
        return instance

    @pytest.fixture
    def manager(self):
        """Manager with one ``api_key`` server (``test-server``) on localhost:8080."""
        return self._build_manager(
            {"test-server": {"auth_type": "api_key", "url": "http://localhost:8080"}}
        )

    @pytest.mark.asyncio
    async def test_returns_none_without_token(self, manager):
        """Should return None when user has no token stored."""
        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage:
            self._stub_token_storage(mock_storage, None)

            result = await manager._get_user_client("test-server", "user@example.com")

        assert result is None

    @pytest.mark.asyncio
    async def test_creates_client_with_api_key_header(self, manager):
        """Should create client with custom header for API key auth type."""
        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage, \
                patch("atlas.modules.mcp_tools.client.Client") as mock_client_class, \
                patch("atlas.modules.mcp_tools.client.StreamableHttpTransport") as mock_transport_class:
            self._stub_token_storage(mock_storage, "test-api-key-123")
            mock_transport = self._stub_constructor(mock_transport_class)
            mock_client = self._stub_constructor(mock_client_class)

            result = await manager._get_user_client("test-server", "user@example.com")

        assert result is mock_client
        # The API key is sent as a header on the HTTP transport.
        mock_transport_class.assert_called_once()
        assert mock_transport_class.call_args.kwargs["headers"] == {"X-API-Key": "test-api-key-123"}
        # The client is built on top of that transport.
        mock_client_class.assert_called_once()
        assert mock_client_class.call_args.kwargs["transport"] is mock_transport

    @pytest.mark.asyncio
    async def test_creates_client_with_bearer_token(self):
        """Should create client with auth parameter for bearer auth type."""
        manager = self._build_manager(
            {"bearer-server": {"auth_type": "bearer", "url": "http://localhost:8080"}}
        )

        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage, \
                patch("atlas.modules.mcp_tools.client.Client") as mock_client_class:
            self._stub_token_storage(mock_storage, "bearer-token-123")
            mock_client = self._stub_constructor(mock_client_class)

            result = await manager._get_user_client("bearer-server", "user@example.com")

        assert result is mock_client
        mock_client_class.assert_called_once()
        # The token is passed as Client's ``auth`` parameter (not via a
        # transport), with the server URL as the first positional argument.
        client_call = mock_client_class.call_args
        assert client_call.args[0] == "http://localhost:8080"
        assert client_call.kwargs["auth"] == "bearer-token-123"

    @pytest.mark.asyncio
    async def test_uses_custom_auth_header_name(self):
        """Should use custom auth_header from config for API key auth."""
        manager = self._build_manager(
            {
                "custom-header-server": {
                    "auth_type": "api_key",
                    "auth_header": "X-Custom-Auth",
                    "url": "http://localhost:8080",
                }
            }
        )

        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage, \
                patch("atlas.modules.mcp_tools.client.Client") as mock_client_class, \
                patch("atlas.modules.mcp_tools.client.StreamableHttpTransport") as mock_transport_class:
            self._stub_token_storage(mock_storage, "custom-key-456")
            self._stub_constructor(mock_transport_class)
            mock_client = self._stub_constructor(mock_client_class)

            result = await manager._get_user_client("custom-header-server", "user@example.com")

        assert result is mock_client
        # The configured header name replaces the X-API-Key default.
        mock_transport_class.assert_called_once()
        assert mock_transport_class.call_args.kwargs["headers"] == {"X-Custom-Auth": "custom-key-456"}

    @pytest.mark.asyncio
    async def test_caches_client(self, manager):
        """Should cache client for subsequent calls."""
        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage, \
                patch("atlas.modules.mcp_tools.client.Client") as mock_client_class, \
                patch("atlas.modules.mcp_tools.client.StreamableHttpTransport") as mock_transport_class:
            self._stub_token_storage(mock_storage, "test-api-key")
            self._stub_constructor(mock_transport_class)
            self._stub_constructor(mock_client_class)

            # The first call creates the client; the second should hit the cache.
            first = await manager._get_user_client("test-server", "user@example.com")
            second = await manager._get_user_client("test-server", "user@example.com")

        assert first is second
        # Client constructor only called once
        assert mock_client_class.call_count == 1

    @pytest.mark.asyncio
    async def test_invalidates_cache_on_expired_token(self, manager):
        """Should invalidate cached client when token expires."""
        with patch("atlas.modules.mcp_tools.token_storage.get_token_storage") as mock_storage, \
                patch("atlas.modules.mcp_tools.client.Client") as mock_client_class, \
                patch("atlas.modules.mcp_tools.client.StreamableHttpTransport") as mock_transport_class:
            storage = self._stub_token_storage(mock_storage, "test-api-key")
            self._stub_constructor(mock_transport_class)
            mock_client = self._stub_constructor(mock_client_class)

            # First call - token valid
            first = await manager._get_user_client("test-server", "user@example.com")
            assert first is mock_client

            # Second call - token expired (storage now returns None)
            storage.get_valid_token.return_value = None
            second = await manager._get_user_client("test-server", "user@example.com")

        # Should return None, and the cache entry should be invalidated
        assert second is None
        assert ("user@example.com", "test-server") not in manager._user_clients
294
+
295
+
296
class TestInvalidateUserClient:
    """Tests for ``MCPToolManager._invalidate_user_client``.

    Cache keys are ``(user_email, server_name)`` tuples, as seeded below.
    """

    @staticmethod
    def _manager_with_cache(entries):
        """Return a bare MCPToolManager whose per-user client cache is *entries*."""
        import asyncio

        from atlas.modules.mcp_tools.client import MCPToolManager

        mgr = MCPToolManager.__new__(MCPToolManager)
        mgr._user_clients = entries
        mgr._user_clients_lock = asyncio.Lock()
        return mgr

    @pytest.mark.asyncio
    async def test_removes_cached_client(self):
        """Should remove client from cache."""
        target_key = ("user@example.com", "test-server")
        bystander_key = ("other@example.com", "test-server")
        mgr = self._manager_with_cache({target_key: MagicMock(), bystander_key: MagicMock()})

        await mgr._invalidate_user_client("user@example.com", "test-server")

        assert target_key not in mgr._user_clients
        # Other user's client should remain
        assert bystander_key in mgr._user_clients

    @pytest.mark.asyncio
    async def test_handles_missing_cache_entry(self):
        """Should not error when cache entry doesn't exist."""
        mgr = self._manager_with_cache({})

        # Should not raise
        await mgr._invalidate_user_client("user@example.com", "test-server")
@@ -0,0 +1,270 @@
1
+ """
2
+ Tests for _mcp_data injection into MCP tool arguments.
3
+
4
+ Validates that tools declaring an _mcp_data parameter in their schema
5
+ receive structured metadata about all available MCP tools, following
6
+ the same pattern as the username injection feature.
7
+ """
8
+
9
+ from unittest.mock import MagicMock
10
+
11
+ from atlas.application.chat.utilities.tool_executor import (
12
+ build_mcp_data,
13
+ inject_context_into_args,
14
+ tool_accepts_mcp_data,
15
+ )
16
+
17
+
18
class FakeTool:
    """Stand-in for an MCP tool exposing ``name``, ``description`` and ``inputSchema``.

    A falsy ``inputSchema`` (None or an empty mapping) is replaced with an
    empty JSON-object schema. Any truthy schema is stored as the same object.
    """

    def __init__(self, name, description="", inputSchema=None):
        self.name = name
        self.description = description
        if inputSchema:
            self.inputSchema = inputSchema
        else:
            self.inputSchema = {"type": "object", "properties": {}}
25
+
26
+
27
def _make_tool_manager(available_tools=None, schema_override=None):
    """Build a MagicMock tool manager backed by *available_tools*.

    *available_tools* maps a server name to a dict with an optional
    ``"tools"`` list; a falsy value becomes an empty dict.
    ``get_tools_schema(tool_names)`` is a MagicMock whose side effect works
    as follows:

    * when *schema_override* is not None, it returns that value verbatim;
    * otherwise it returns one function schema for each tool whose
      ``"{server}_{tool.name}"`` appears in *tool_names*.

    The tool table is read at call time, so later edits to
    ``available_tools`` are visible.
    """
    mock_manager = MagicMock()
    mock_manager.available_tools = available_tools or {}

    def _schema_lookup(tool_names):
        if schema_override is not None:
            return schema_override
        matched = []
        for server, payload in mock_manager.available_tools.items():
            for tool in payload.get("tools", []):
                qualified = f"{server}_{tool.name}"
                if qualified not in tool_names:
                    continue
                function_spec = {
                    "name": qualified,
                    "description": tool.description or "",
                    "parameters": tool.inputSchema or {},
                }
                matched.append({"type": "function", "function": function_spec})
        return matched

    mock_manager.get_tools_schema = MagicMock(side_effect=_schema_lookup)
    return mock_manager
52
+
53
+
54
+ # -- tool_accepts_mcp_data tests --
55
+
56
+
57
class TestToolAcceptsMcpData:
    """Tests for tool_accepts_mcp_data, which detects an ``_mcp_data`` schema property."""

    @staticmethod
    def _single_tool_manager(tool):
        """Manager exposing *tool* as the only tool of server ``demo``."""
        return _make_tool_manager({"demo": {"tools": [tool], "config": {}}})

    def test_returns_true_when_schema_has_mcp_data(self):
        properties = {
            "task": {"type": "string"},
            "_mcp_data": {"type": "object"},
        }
        planner = FakeTool("planner", inputSchema={"type": "object", "properties": properties})
        mgr = self._single_tool_manager(planner)
        assert tool_accepts_mcp_data("demo_planner", mgr) is True

    def test_returns_false_when_schema_lacks_mcp_data(self):
        properties = {"query": {"type": "string"}}
        search = FakeTool("search", inputSchema={"type": "object", "properties": properties})
        mgr = self._single_tool_manager(search)
        assert tool_accepts_mcp_data("demo_search", mgr) is False

    def test_returns_false_with_no_tool_name(self):
        assert tool_accepts_mcp_data("", MagicMock()) is False

    def test_returns_false_with_no_tool_manager(self):
        assert tool_accepts_mcp_data("some_tool", None) is False

    def test_returns_false_when_schema_lookup_fails(self):
        failing_manager = MagicMock()
        failing_manager.get_tools_schema = MagicMock(side_effect=RuntimeError("fail"))
        assert tool_accepts_mcp_data("any_tool", failing_manager) is False
95
+
96
+
97
+ # -- build_mcp_data tests --
98
+
99
+
100
class TestBuildMcpData:
    """Tests for the structure of the payload returned by build_mcp_data."""

    def test_returns_empty_when_no_tools(self):
        payload = build_mcp_data(_make_tool_manager({}))
        assert payload == {"available_servers": []}

    def test_returns_empty_when_no_manager(self):
        assert build_mcp_data(None) == {"available_servers": []}

    def test_skips_canvas_server(self):
        canvas_only = _make_tool_manager({
            "canvas": {"tools": [FakeTool("canvas")], "config": {}},
        })
        assert build_mcp_data(canvas_only)["available_servers"] == []

    def test_includes_server_and_tool_metadata(self):
        search = FakeTool(
            "search",
            description="Search documents",
            inputSchema={"type": "object", "properties": {"query": {"type": "string"}}},
        )
        listing = FakeTool("list", description="List items")
        mgr = _make_tool_manager({
            "myserver": {"tools": [search, listing], "config": {"description": "My Server"}},
        })

        servers = build_mcp_data(mgr)["available_servers"]
        assert len(servers) == 1
        server = servers[0]
        assert server["server_name"] == "myserver"
        assert server["description"] == "My Server"
        assert len(server["tools"]) == 2

        first_tool = server["tools"][0]
        assert first_tool["name"] == "myserver_search"
        assert first_tool["description"] == "Search documents"
        assert "properties" in first_tool["parameters"]

    def test_multiple_servers(self):
        mgr = _make_tool_manager({
            "server_a": {"tools": [FakeTool("t1")], "config": {}},
            "server_b": {"tools": [FakeTool("t2")], "config": {}},
        })
        seen = {entry["server_name"] for entry in build_mcp_data(mgr)["available_servers"]}
        assert "server_a" in seen
        assert "server_b" in seen

    def test_handles_missing_description(self):
        bare = FakeTool("t", description=None, inputSchema=None)
        mgr = _make_tool_manager({"s": {"tools": [bare], "config": {}}})

        server = build_mcp_data(mgr)["available_servers"][0]
        tool_entry = server["tools"][0]
        assert server["description"] == ""
        assert tool_entry["description"] == ""
        # FakeTool substitutes an empty object schema for a None inputSchema.
        assert tool_entry["parameters"] == {"type": "object", "properties": {}}

    def test_uses_short_description_fallback(self):
        mgr = _make_tool_manager({
            "s": {"tools": [FakeTool("t")], "config": {"short_description": "Short desc"}},
        })
        servers = build_mcp_data(mgr)["available_servers"]
        assert servers[0]["description"] == "Short desc"
179
+
180
+
181
+ # -- inject_context_into_args with _mcp_data tests --
182
+
183
+
184
class TestInjectMcpData:
    """Tests for ``_mcp_data`` injection in inject_context_into_args."""

    @staticmethod
    def _planner_tool():
        """Tool ``planner`` whose schema declares both ``task`` and ``_mcp_data``."""
        return FakeTool(
            "planner",
            inputSchema={
                "type": "object",
                "properties": {
                    "task": {"type": "string"},
                    "_mcp_data": {"type": "object"},
                },
            },
        )

    @staticmethod
    def _user_context():
        """Fresh context dict carrying only the caller's email."""
        return {"user_email": "user@test.com"}

    def test_injects_mcp_data_when_tool_accepts_it(self):
        mgr = _make_tool_manager({
            "demo": {"tools": [self._planner_tool()], "config": {"description": "Demo"}},
        })

        injected = inject_context_into_args(
            {"task": "do something"},
            self._user_context(),
            "demo_planner",
            mgr,
        )

        assert "_mcp_data" in injected
        mcp_data = injected["_mcp_data"]
        assert "available_servers" in mcp_data
        assert len(mcp_data["available_servers"]) == 1

    def test_does_not_inject_mcp_data_when_tool_lacks_param(self):
        search = FakeTool(
            "search",
            inputSchema={"type": "object", "properties": {"query": {"type": "string"}}},
        )
        mgr = _make_tool_manager({"demo": {"tools": [search], "config": {}}})

        injected = inject_context_into_args(
            {"query": "hello"},
            self._user_context(),
            "demo_search",
            mgr,
        )

        assert "_mcp_data" not in injected

    def test_mcp_data_reinjected_after_edit(self):
        """Simulates the re-injection path after user edits tool arguments."""
        mgr = _make_tool_manager({"demo": {"tools": [self._planner_tool()], "config": {}}})

        # The user-edited arguments no longer contain _mcp_data.
        edited_args = {"task": "edited task"}
        injected = inject_context_into_args(
            edited_args,
            self._user_context(),
            "demo_planner",
            mgr,
        )

        # _mcp_data should be re-injected while the edit is preserved.
        assert "_mcp_data" in injected
        assert injected["task"] == "edited task"

    def test_mcp_data_not_injected_without_tool_manager(self):
        injected = inject_context_into_args(
            {"task": "test"},
            self._user_context(),
            "some_tool",
            None,
        )
        assert "_mcp_data" not in injected
@@ -0,0 +1,95 @@
1
+ """Test get_authorized_servers with async auth function."""
2
+
3
+ import pytest
4
+
5
+ from atlas.modules.mcp_tools.client import MCPToolManager
6
+
7
+
8
@pytest.mark.asyncio
async def test_get_authorized_servers_with_async_auth():
    """Test that get_authorized_servers properly handles async auth_check_func."""

    # A real MCPToolManager (constructed with ``None``) whose servers_config is
    # overwritten with test fixtures; nothing on the manager itself is mocked.
    mcp_manager = MCPToolManager(None)
    mcp_manager.servers_config = {
        "server1": {
            "enabled": True,
            "groups": ["admin", "users"]
        },
        "server2": {
            "enabled": True,
            "groups": ["admin"]
        },
        "server3": {
            "enabled": True,
            "groups": []  # No groups required
        },
        "server4": {
            "enabled": False,
            "groups": ["admin"]
        },
        # Enabled, but gated by a group the auth check rejects. If the
        # coroutine returned by auth_check_func were used without being
        # awaited, the (truthy) coroutine object would let this server
        # through -- this entry is what makes the test exercise the await.
        "server5": {
            "enabled": True,
            "groups": ["superusers"]
        },
    }

    # Async auth function: only membership in "admin" is granted.
    async def mock_auth_check(user_email: str, group: str) -> bool:
        """Mock auth check that returns True for admin group."""
        return group == "admin"

    # Test with user who has admin access
    authorized = await mcp_manager.get_authorized_servers("admin@test.com", mock_auth_check)

    # Should include server1 (has admin), server2 (has admin), server3 (no groups required)
    # Should NOT include server4 (disabled) or server5 (awaited auth check is False)
    assert set(authorized) == {"server1", "server2", "server3"}
44
+
45
+
46
@pytest.mark.asyncio
async def test_get_authorized_servers_with_multiple_groups():
    """A server listing several groups is authorized when the user holds any of them."""

    # Groups this fake user belongs to -- notably not "admin".
    user_groups = {"users", "developers"}

    async def mock_auth_check(user_email: str, group: str) -> bool:
        return group in user_groups

    mcp_manager = MCPToolManager(None)
    mcp_manager.servers_config = dict(
        server1={"enabled": True, "groups": ["users", "developers"]},
        server2={"enabled": True, "groups": ["admin"]},
    )

    authorized = await mcp_manager.get_authorized_servers("user@test.com", mock_auth_check)

    # server1 is reachable via "users"; server2 requires "admin", which the user lacks.
    assert authorized == ["server1"]
71
+
72
+
73
@pytest.mark.asyncio
async def test_get_authorized_servers_no_access():
    """A user whose auth check rejects every group gets an empty server list."""

    async def mock_auth_check(user_email: str, group: str) -> bool:
        # The user is a member of no group at all.
        return False

    mcp_manager = MCPToolManager(None)
    # Every server is enabled but gated by exactly one required group.
    mcp_manager.servers_config = {
        server_name: {"enabled": True, "groups": [required_group]}
        for server_name, required_group in (
            ("server1", "admin"),
            ("server2", "superusers"),
        )
    }

    authorized = await mcp_manager.get_authorized_servers("user@test.com", mock_auth_check)

    assert authorized == []