atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,329 @@
1
+ """Tests for domain errors module."""
2
+
3
+
4
+ from atlas.domain.errors import (
5
+ AuthenticationError,
6
+ AuthorizationError,
7
+ ConfigurationError,
8
+ DataSourcePermissionError,
9
+ DomainError,
10
+ LLMAuthenticationError,
11
+ LLMConfigurationError,
12
+ LLMError,
13
+ LLMServiceError,
14
+ LLMTimeoutError,
15
+ MessageError,
16
+ PromptOverrideError,
17
+ RateLimitError,
18
+ SessionError,
19
+ SessionNotFoundError,
20
+ ToolAuthorizationError,
21
+ ToolError,
22
+ ValidationError,
23
+ )
24
+
25
+
26
class TestDomainError:
    """Behaviour of the DomainError base exception."""

    def test_domain_error_with_message_only(self):
        """A message-only DomainError exposes the message and has no code."""
        text = "Something went wrong"
        exc = DomainError(text)

        assert str(exc) == text
        assert exc.message == text
        assert exc.code is None

    def test_domain_error_with_message_and_code(self):
        """Message and code passed positionally are both retained."""
        text, err_code = "Something went wrong", "ERR_001"
        exc = DomainError(text, err_code)

        assert str(exc) == text
        assert exc.message == text
        assert exc.code == err_code

    def test_domain_error_inheritance(self):
        """DomainError is an ordinary Exception subclass."""
        assert isinstance(DomainError("test"), Exception)

    def test_domain_error_with_empty_message(self):
        """An empty message round-trips through both .message and str()."""
        exc = DomainError("")
        assert exc.message == ""
        assert str(exc) == ""

    def test_domain_error_with_none_code(self):
        """Passing None explicitly as the code leaves .code as None."""
        assert DomainError("test", None).code is None
63
+
64
+
65
class TestValidationError:
    """Behaviour of ValidationError."""

    def test_validation_error_inheritance(self):
        """ValidationError is both a DomainError and an Exception."""
        exc = ValidationError("Invalid input")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)

    def test_validation_error_with_code(self):
        """Message and code supplied at construction are exposed unchanged."""
        exc = ValidationError("Invalid email format", "VALIDATION_001")
        assert exc.message == "Invalid email format"
        assert exc.code == "VALIDATION_001"
79
+
80
+
81
class TestSessionError:
    """Behaviour of SessionError."""

    def test_session_error_inheritance(self):
        """SessionError is both a DomainError and an Exception."""
        exc = SessionError("Session expired")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
89
+
90
+
91
class TestMessageError:
    """Behaviour of MessageError."""

    def test_message_error_inheritance(self):
        """MessageError is both a DomainError and an Exception."""
        exc = MessageError("Message processing failed")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
99
+
100
+
101
class TestAuthenticationError:
    """Behaviour of AuthenticationError."""

    def test_authentication_error_inheritance(self):
        """AuthenticationError is both a DomainError and an Exception."""
        exc = AuthenticationError("Invalid credentials")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
109
+
110
+
111
class TestAuthorizationError:
    """Behaviour of AuthorizationError."""

    def test_authorization_error_inheritance(self):
        """AuthorizationError is both a DomainError and an Exception."""
        exc = AuthorizationError("Access denied")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
119
+
120
+
121
class TestConfigurationError:
    """Behaviour of ConfigurationError."""

    def test_configuration_error_inheritance(self):
        """ConfigurationError is both a DomainError and an Exception."""
        exc = ConfigurationError("Invalid configuration")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
129
+
130
+
131
class TestLLMError:
    """Behaviour of LLMError."""

    def test_llm_error_inheritance(self):
        """LLMError is both a DomainError and an Exception."""
        exc = LLMError("LLM service failed")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
139
+
140
+
141
class TestLLMServiceError:
    """Behaviour of LLMServiceError."""

    def test_llm_service_error_inheritance(self):
        """LLMServiceError is an LLMError, a DomainError and an Exception."""
        exc = LLMServiceError("Service unavailable")
        for base in (LLMError, DomainError, Exception):
            assert isinstance(exc, base)
150
+
151
+
152
class TestToolError:
    """Behaviour of ToolError."""

    def test_tool_error_inheritance(self):
        """ToolError is both a DomainError and an Exception."""
        exc = ToolError("Tool execution failed")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
160
+
161
+
162
class TestToolAuthorizationError:
    """Behaviour of ToolAuthorizationError."""

    def test_tool_authorization_error_inheritance(self):
        """ToolAuthorizationError is an AuthorizationError, DomainError and Exception."""
        exc = ToolAuthorizationError("Tool access denied")
        for base in (AuthorizationError, DomainError, Exception):
            assert isinstance(exc, base)
171
+
172
+
173
class TestDataSourcePermissionError:
    """Behaviour of DataSourcePermissionError."""

    def test_data_source_permission_error_inheritance(self):
        """DataSourcePermissionError is an AuthorizationError, DomainError and Exception."""
        exc = DataSourcePermissionError("Data source access denied")
        for base in (AuthorizationError, DomainError, Exception):
            assert isinstance(exc, base)
182
+
183
+
184
class TestLLMConfigurationError:
    """Behaviour of LLMConfigurationError."""

    def test_llm_configuration_error_inheritance(self):
        """LLMConfigurationError is a ConfigurationError, DomainError and Exception."""
        exc = LLMConfigurationError("Invalid LLM config")
        for base in (ConfigurationError, DomainError, Exception):
            assert isinstance(exc, base)
193
+
194
+
195
class TestSessionNotFoundError:
    """Behaviour of SessionNotFoundError."""

    def test_session_not_found_error_inheritance(self):
        """SessionNotFoundError is a SessionError, DomainError and Exception."""
        exc = SessionNotFoundError("Session not found")
        for base in (SessionError, DomainError, Exception):
            assert isinstance(exc, base)
204
+
205
+
206
class TestPromptOverrideError:
    """Behaviour of PromptOverrideError."""

    def test_prompt_override_error_inheritance(self):
        """PromptOverrideError is both a DomainError and an Exception."""
        exc = PromptOverrideError("Prompt override failed")
        for base in (DomainError, Exception):
            assert isinstance(exc, base)
214
+
215
+
216
class TestRateLimitError:
    """Behaviour of RateLimitError."""

    def test_rate_limit_error_inheritance(self):
        """RateLimitError is an LLMError, a DomainError and an Exception."""
        exc = RateLimitError("Rate limit exceeded")
        for base in (LLMError, DomainError, Exception):
            assert isinstance(exc, base)
225
+
226
+
227
class TestLLMTimeoutError:
    """Behaviour of LLMTimeoutError."""

    def test_llm_timeout_error_inheritance(self):
        """LLMTimeoutError is an LLMError, a DomainError and an Exception."""
        exc = LLMTimeoutError("Request timed out")
        for base in (LLMError, DomainError, Exception):
            assert isinstance(exc, base)
236
+
237
+
238
class TestLLMAuthenticationError:
    """Behaviour of LLMAuthenticationError."""

    def test_llm_authentication_error_inheritance(self):
        """LLMAuthenticationError is an AuthenticationError, DomainError and Exception."""
        exc = LLMAuthenticationError("LLM authentication failed")
        for base in (AuthenticationError, DomainError, Exception):
            assert isinstance(exc, base)
247
+
248
+
249
class TestErrorHierarchy:
    """Cross-cutting checks over the whole domain error hierarchy."""

    # Every error type imported by this module that is expected to derive
    # from DomainError. Kept in a single place so the checks below cannot
    # drift apart when a new error type is added.
    DERIVED_ERRORS = (
        ValidationError,
        SessionError,
        MessageError,
        AuthenticationError,
        AuthorizationError,
        ConfigurationError,
        LLMError,
        LLMServiceError,
        ToolError,
        ToolAuthorizationError,
        DataSourcePermissionError,
        LLMConfigurationError,
        SessionNotFoundError,
        PromptOverrideError,
        RateLimitError,
        LLMTimeoutError,
        LLMAuthenticationError,
    )
    # DomainError itself followed by all of its expected subclasses.
    ALL_ERRORS = (DomainError, *DERIVED_ERRORS)

    def test_all_errors_inherit_from_domain_error(self):
        """Every derived error type is a DomainError and an Exception."""
        for error_class in self.DERIVED_ERRORS:
            error = error_class("test message")
            # The class name is the failure message, so a break points at the culprit.
            assert isinstance(error, DomainError), error_class.__name__
            assert isinstance(error, Exception), error_class.__name__

    def test_error_message_and_code_preservation(self):
        """Every error type keeps the positional message and code it was given."""
        test_message = "Test error message"
        test_code = "TEST_001"

        for error_class in self.ALL_ERRORS:
            error = error_class(test_message, test_code)
            name = error_class.__name__
            assert error.message == test_message, name
            assert error.code == test_code, name
            assert str(error) == test_message, name

    def test_specific_inheritance_relationships(self):
        """Specialised errors subclass the expected intermediate parent."""
        expected_parents = {
            # LLM-related errors
            LLMServiceError: LLMError,
            RateLimitError: LLMError,
            LLMTimeoutError: LLMError,
            # Authorization-related errors
            ToolAuthorizationError: AuthorizationError,
            DataSourcePermissionError: AuthorizationError,
            # Authentication-related errors
            LLMAuthenticationError: AuthenticationError,
            # Configuration-related errors
            LLMConfigurationError: ConfigurationError,
            # Session-related errors
            SessionNotFoundError: SessionError,
        }

        for child, parent in expected_parents.items():
            assert issubclass(child, parent), (
                f"{child.__name__} should subclass {parent.__name__}"
            )
@@ -0,0 +1,359 @@
1
+ """Tests for domain whitelist middleware."""
2
+
3
+ import json
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+ from fastapi import FastAPI
9
+
10
+ from atlas.core.domain_whitelist import DomainWhitelistManager
11
+ from atlas.core.domain_whitelist_middleware import DomainWhitelistMiddleware
12
+
13
+
14
@pytest.fixture
def temp_config(tmp_path):
    """Write a small domain-whitelist config to a temporary JSON file.

    The config lists three domains (sandia.gov, doe.gov, example.org) and
    enables subdomain matching. The file lives in pytest's per-test
    ``tmp_path`` directory, so pytest manages its lifetime and no manual
    cleanup is needed.

    Returns:
        Path to the written JSON config file.
    """
    config_data = {
        "version": "1.0",
        "description": "Test config",
        "domains": [
            {"domain": "sandia.gov", "description": "Sandia National Labs"},
            {"domain": "doe.gov", "description": "DOE"},
            {"domain": "example.org", "description": "Example"},
        ],
        "subdomain_matching": True,
    }

    config_file = tmp_path / "domain-whitelist.json"
    # Explicit encoding so the file does not depend on the platform default.
    config_file.write_text(json.dumps(config_data), encoding="utf-8")
    return config_file
37
+
38
+
39
+
40
class TestDomainWhitelistManager:
    """Test the domain whitelist manager."""

    @staticmethod
    def _write_config(directory, payload):
        """Write *payload* to a JSON file in *directory* and return its path.

        A ``str`` payload is written verbatim (used for malformed JSON);
        anything else is serialised with ``json.dumps``.
        """
        path = directory / "whitelist.json"
        text = payload if isinstance(payload, str) else json.dumps(payload)
        path.write_text(text, encoding="utf-8")
        return path

    def test_load_config(self, temp_config):
        """Test loading configuration from file."""
        manager = DomainWhitelistManager(config_path=temp_config)
        domains = manager.get_domains()

        assert "sandia.gov" in domains
        assert "doe.gov" in domains
        assert "example.org" in domains
        assert len(domains) == 3

    def test_missing_config_file(self, tmp_path):
        """Test that missing config file allows all domains (fail open)."""
        # A path inside the fresh per-test directory is guaranteed not to exist.
        non_existent_path = tmp_path / "nonexistent_whitelist_config.json"
        manager = DomainWhitelistManager(config_path=non_existent_path)

        # Config should not be loaded
        assert manager.config_loaded is False
        assert len(manager.get_domains()) == 0

        # But should allow all domains (fail open)
        assert manager.is_domain_allowed("user@gmail.com") is True
        assert manager.is_domain_allowed("user@any-domain.com") is True
        assert manager.is_domain_allowed("test@example.org") is True

    def test_invalid_json_config(self, tmp_path):
        """Test that invalid JSON config allows all domains (fail open)."""
        config_path = self._write_config(tmp_path, "{ invalid json content ]}")
        manager = DomainWhitelistManager(config_path=config_path)

        # Config should not be loaded
        assert manager.config_loaded is False
        assert len(manager.get_domains()) == 0

        # Should allow all domains (fail open)
        assert manager.is_domain_allowed("user@gmail.com") is True
        assert manager.is_domain_allowed("test@example.org") is True

    def test_empty_domains_list(self, tmp_path):
        """Test config with empty domains list."""
        config_path = self._write_config(
            tmp_path,
            {
                "version": "1.0",
                "description": "Empty config",
                "domains": [],
                "subdomain_matching": True,
            },
        )
        manager = DomainWhitelistManager(config_path=config_path)

        # Config should be loaded successfully even with empty domains
        assert manager.config_loaded is True
        assert len(manager.get_domains()) == 0

        # Should block all domains when config is valid but empty
        assert manager.is_domain_allowed("user@gmail.com") is False
        assert manager.is_domain_allowed("user@sandia.gov") is False

    def test_domain_matching(self, temp_config):
        """Test domain matching logic."""
        manager = DomainWhitelistManager(config_path=temp_config)

        # Exact matches
        assert manager.is_domain_allowed("user@sandia.gov") is True
        assert manager.is_domain_allowed("user@doe.gov") is True

        # Subdomain matches
        assert manager.is_domain_allowed("user@mail.sandia.gov") is True
        assert manager.is_domain_allowed("user@sub.doe.gov") is True

        # Invalid domains
        assert manager.is_domain_allowed("user@gmail.com") is False
        assert manager.is_domain_allowed("user@sandia.com") is False  # Wrong TLD

    def test_invalid_email(self, temp_config):
        """Test handling of invalid email addresses."""
        manager = DomainWhitelistManager(config_path=temp_config)

        for bad_email in ("notanemail", "", "no-at-sign.com"):
            assert manager.is_domain_allowed(bad_email) is False, repr(bad_email)
+
137
+
138
@pytest.fixture
def create_middleware():
    """Factory fixture to create middleware with custom config.

    The returned callable builds a ``DomainWhitelistMiddleware`` that wraps a
    fresh ``FastAPI`` app, uses ``"/auth"`` as its redirect URL, and reads its
    whitelist from the given config path.

    The instance is allocated with ``__new__`` and initialised by hand. This
    avoids swapping ``DomainWhitelistMiddleware.__init__`` on the class.
    Swapping it without ``try/finally`` would leave the class patched for
    every later test if construction raised.
    """
    from starlette.middleware.base import BaseHTTPMiddleware

    def _create(config_path):
        """Return a middleware instance whose whitelist is read from *config_path*."""
        app = FastAPI()

        middleware = DomainWhitelistMiddleware.__new__(DomainWhitelistMiddleware)
        # Reproduce exactly the state the custom initialiser sets up.
        BaseHTTPMiddleware.__init__(middleware, app)
        middleware.auth_redirect_url = "/auth"
        middleware.whitelist_manager = DomainWhitelistManager(config_path=config_path)

        return middleware

    return _create
160
+
161
+
162
+ class TestDomainWhitelistMiddleware:
163
+ """Test domain whitelist middleware."""
164
+
165
def test_middleware_with_allowed_domain(self, temp_config, create_middleware):
    """A request whose user email is on the whitelist reaches the next handler."""
    import asyncio

    from starlette.requests import Request
    from starlette.responses import Response

    middleware = create_middleware(temp_config)

    async def downstream(_request):
        # Stand-in for the rest of the app: always succeeds.
        return Response("OK", status_code=200)

    async def exercise():
        http_scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/test",
            "query_string": b"",
            "headers": [],
            "state": {},
        }
        req = Request(http_scope)
        req.state.user_email = "test@sandia.gov"

        result = await middleware.dispatch(req, downstream)
        assert result.status_code == 200

    asyncio.run(exercise())
192
+
193
def test_middleware_with_disallowed_domain(self, temp_config, create_middleware):
    """A request whose user email is not whitelisted is rejected with 403."""
    import asyncio

    from starlette.requests import Request
    from starlette.responses import Response

    middleware = create_middleware(temp_config)

    async def downstream(_request):
        # Stand-in for the rest of the app: always succeeds if reached.
        return Response("OK", status_code=200)

    async def exercise():
        http_scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/test",
            "query_string": b"",
            "headers": [],
            "state": {},
        }
        req = Request(http_scope)
        req.state.user_email = "test@gmail.com"

        result = await middleware.dispatch(req, downstream)
        assert result.status_code == 403

    asyncio.run(exercise())
220
+
221
+
222
def test_health_endpoint_bypass(self, temp_config, create_middleware):
    """/api/health is served even though the request carries no user email."""
    import asyncio

    from starlette.requests import Request
    from starlette.responses import Response

    middleware = create_middleware(temp_config)

    async def downstream(_request):
        # Stand-in for the rest of the app: always succeeds if reached.
        return Response("OK", status_code=200)

    async def exercise():
        http_scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/health",
            "query_string": b"",
            "headers": [],
            "state": {},
        }
        # Deliberately no user_email on request.state.
        req = Request(http_scope)

        result = await middleware.dispatch(req, downstream)
        assert result.status_code == 200

    asyncio.run(exercise())
249
+
250
def test_middleware_with_missing_config(self, create_middleware):
    """Test that middleware with missing config allows all domains (fails open).

    The config path is created inside a fresh temporary directory, so it is
    guaranteed not to exist. This avoids depending on a fixed, shared
    ``/tmp`` file name that could happen to exist on a given machine and
    make the test exercise the wrong code path.
    """
    import asyncio
    import tempfile

    from starlette.requests import Request
    from starlette.responses import Response

    with tempfile.TemporaryDirectory() as tmp_dir:
        non_existent_path = Path(tmp_dir) / "nonexistent_whitelist_config.json"
        # Guard the test's own precondition: the config file really is absent.
        assert not non_existent_path.exists()

        middleware = create_middleware(non_existent_path)

        async def call_next(request):
            return Response("OK", status_code=200)

        async def test_request():
            scope = {
                "type": "http",
                "method": "GET",
                "path": "/api/test",
                "query_string": b"",
                "headers": [],
                "state": {},
            }
            request = Request(scope)
            request.state.user_email = "test@gmail.com"

            # Should pass even though config is missing (fail open)
            response = await middleware.dispatch(request, call_next)
            assert response.status_code == 200

        # Run while the temp dir still exists, in case the config is read lazily.
        asyncio.run(test_request())
279
+
280
def test_middleware_with_invalid_config(self, create_middleware):
    """A whitelist config file that is not valid JSON makes the middleware fail open."""
    import asyncio

    from starlette.requests import Request
    from starlette.responses import Response

    # delete=False keeps the file on disk after closing so the middleware can read it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
        handle.write("{ invalid json }")
        temp_path = Path(handle.name)

    try:
        middleware = create_middleware(temp_path)

        async def downstream(request):
            # Stand-in for the wrapped application: always answers 200 "OK".
            return Response("OK", status_code=200)

        async def run_check():
            incoming = Request(
                dict(
                    type="http",
                    method="GET",
                    path="/api/test",
                    query_string=b"",
                    headers=[],
                    state={},
                )
            )
            incoming.state.user_email = "test@anydomain.com"

            # Any domain is accepted because the config could not be parsed.
            result = await middleware.dispatch(incoming, downstream)
            assert result.status_code == 200

        asyncio.run(run_check())
    finally:
        # The file was created with delete=False, so remove it explicitly.
        if temp_path.exists():
            temp_path.unlink()
316
+
317
def test_middleware_with_empty_domains(self, create_middleware):
    """A well-formed config with an empty domain list rejects every user (HTTP 403)."""
    import asyncio

    from starlette.requests import Request
    from starlette.responses import Response

    empty_config = dict(
        version="1.0",
        description="Empty config",
        domains=[],
        subdomain_matching=True,
    )

    # delete=False keeps the file on disk after closing so the middleware can read it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
        json.dump(empty_config, handle)
        temp_path = Path(handle.name)

    try:
        middleware = create_middleware(temp_path)

        async def downstream(request):
            # Stand-in for the wrapped application; reaching it would yield 200, not 403.
            return Response("OK", status_code=200)

        async def run_check():
            incoming = Request(
                dict(
                    type="http",
                    method="GET",
                    path="/api/test",
                    query_string=b"",
                    headers=[],
                    state={},
                )
            )
            incoming.state.user_email = "test@anydomain.com"

            # Unlike a missing or unparsable file, an empty list is a valid
            # config, so nothing is allowed through.
            result = await middleware.dispatch(incoming, downstream)
            assert result.status_code == 403

        asyncio.run(run_check())
    finally:
        # The file was created with delete=False, so remove it explicitly.
        if temp_path.exists():
            temp_path.unlink()