atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
atlas/core/auth.py ADDED
@@ -0,0 +1,205 @@
1
+ """Authentication and authorization module."""
2
+
3
+ import hmac
4
+ import logging
5
+ import re
6
+ from datetime import datetime, timedelta
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import httpx
10
+ import jwt
11
+
12
+ from atlas.modules.config.config_manager import config_manager
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Cache with TTL for ALB public keys: {(kid, region): (key, expiry_time)}
17
+ _alb_key_cache: Dict[Tuple[str, str], Tuple[str, datetime]] = {}
18
+
19
+
20
async def is_user_in_group(user_id: str, group_id: str) -> bool:
    """
    Check if a user is in a specified group.

    If both ``auth_group_check_url`` and ``auth_group_check_api_key`` are set in
    the app settings, membership is resolved by POSTing
    ``{"user_id": ..., "group_id": ...}`` to that URL with a bearer token.
    Otherwise a hard-coded mock mapping intended for local development is used
    (every user is in "users"; a few fixed test accounts have extra groups).

    Args:
        user_id: User email/identifier.
        group_id: Group identifier.

    Returns:
        True only if membership is affirmatively confirmed, False otherwise.
        Any failure talking to the external endpoint yields False.
    """
    app_settings = config_manager.app_settings
    auth_url = app_settings.auth_group_check_url
    api_key = app_settings.auth_group_check_api_key

    if auth_url and api_key:
        # Use the external HTTP endpoint for authorization
        try:
            async with httpx.AsyncClient() as client:
                headers = {"Authorization": f"Bearer {api_key}"}
                payload = {"user_id": user_id, "group_id": group_id}
                response = await client.post(auth_url, json=payload, headers=headers, timeout=5.0)
                response.raise_for_status()
                body = response.json()
        except httpx.RequestError as e:
            logger.error(f"HTTP request to auth endpoint failed: {e}", exc_info=True)
            return False
        except Exception as e:
            logger.error(f"Error during external auth check: {e}", exc_info=True)
            return False

        # Fail closed: the endpoint is expected to answer {"is_member": true|false}.
        # Only a real JSON boolean is trusted. A non-object body, a missing key or a
        # truthy non-bool (e.g. the string "false") must never grant membership.
        is_member = body.get("is_member") if isinstance(body, dict) else None
        if not isinstance(is_member, bool):
            logger.warning(
                "Auth endpoint returned a non-boolean 'is_member' value (%s); denying membership",
                type(is_member).__name__,
            )
            return False
        return is_member
    else:
        # Everybody is in the users group by default
        if (group_id == "users"):
            return True
        # Fallback to mock implementation if no external endpoint is configured
        if (app_settings.debug_mode and
            user_id == app_settings.test_user and
            group_id == app_settings.admin_group):
            return True

        mock_groups = {
            "test@test.com": ["users", "mcp_basic", "admin"],
            "user@example.com": ["users", "mcp_basic"],
            "admin@example.com": ["admin", "users", "mcp_basic", "mcp_advanced"]
        }
        user_groups = mock_groups.get(user_id, [])
        return group_id in user_groups
72
+
73
+
74
def _get_alb_public_key(kid: str, aws_region: str) -> Optional[str]:
    """
    Fetch and cache an AWS ALB public key by key ID.

    The key is fetched from
    ``https://public-keys.auth.elb.<aws_region>.amazonaws.com/<kid>`` and cached
    in ``_alb_key_cache`` with a 1-hour TTL, so a rotated key stops being served
    from cache within that window.

    Args:
        kid: Key ID from the JWT header. Only ASCII letters, digits and hyphens
            are accepted.
        aws_region: AWS region (e.g. 'us-east-1'), matching
            ``<2 lowercase letters>-<lowercase letters>-<digits>``.

    Returns:
        Public key string, or None if the inputs are invalid or the fetch fails.
    """
    # Security: kid and region are interpolated into the request URL and used as
    # the cache key, so they are validated strictly. re.fullmatch is required here:
    # re.match with a trailing '$' also accepts a value ending in "\n". [0-9] is
    # used instead of \d, which also matches non-ASCII digits.
    if not isinstance(kid, str) or not re.fullmatch(r'[a-zA-Z0-9\-]+', kid):
        # repr() keeps attacker-controlled control characters out of the log line.
        logger.error(f"Invalid kid format: {kid!r}")
        return None
    if not isinstance(aws_region, str) or not re.fullmatch(r'[a-z]{2}-[a-z]+-[0-9]+', aws_region):
        logger.error(f"Invalid AWS region format: {aws_region!r}")
        return None

    # Security: TTL-based cache (1 hour) allows key rotation and prevents stale keys
    # if AWS rotates keys or a key is compromised.
    cache_key = (kid, aws_region)
    # Naive UTC timestamps; every cache entry is written below with the same clock.
    now = datetime.utcnow()
    cached = _alb_key_cache.get(cache_key)
    if cached is not None:
        cached_key, expiry = cached
        if now < expiry:
            return cached_key
        # Expired: drop it (tolerating a concurrent removal) and refetch below.
        _alb_key_cache.pop(cache_key, None)

    url = f'https://public-keys.auth.elb.{aws_region}.amazonaws.com/{kid}'
    try:
        response = httpx.get(url, timeout=5.0)
        response.raise_for_status()
        pub_key = response.text

        # Cache with 1-hour TTL
        expiry = now + timedelta(hours=1)
        _alb_key_cache[cache_key] = (pub_key, expiry)

        return pub_key
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error fetching ALB public key from {url}: {e.response.status_code}")
        return None
    except httpx.RequestError as e:
        logger.error(f"Error fetching ALB public key from {url}: {e}")
        return None
126
+
127
+
128
def get_user_from_aws_alb_jwt(encoded_jwt, expected_alb_arn, aws_region):
    """
    Validate an AWS ALB JWT and return the email address it asserts.

    Steps:
    1. Read ``kid`` and ``signer`` from the unverified JWT header.
    2. Require ``signer`` to equal ``expected_alb_arn``.
    3. Fetch the ALB public key for ``kid``.
    4. Verify the ES256 signature and standard claims (such as expiry) with PyJWT.
    5. Return the ``email`` claim if it is a well-formed address.

    Args:
        encoded_jwt (str): The JWT from the x-amzn-oidc-data header.
        expected_alb_arn (str): The ARN of your Application Load Balancer.
        aws_region (str): The AWS region where your ALB is located (e.g., 'us-east-1').

    Returns:
        str: The user's email address, or None if the token is missing or any
        validation step fails. Failures are logged rather than raised.
    """
    if not encoded_jwt:
        return None
    # Fail closed when no expected signer is configured. Otherwise we would rely
    # on hmac.compare_digest raising TypeError for None and that being swallowed.
    if not isinstance(expected_alb_arn, str) or not expected_alb_arn:
        logger.error("Error: Expected ALB ARN is not configured")
        return None
    try:
        # Step 1: Decode the JWT header to get the key ID (kid) and signer using PyJWT
        header = jwt.get_unverified_header(encoded_jwt)
        kid = header.get('kid')
        received_alb_arn = header.get('signer')

        if not kid:
            logger.error("Error: 'kid' not found in JWT header")
            return None

        # Step 2: Validate the signer matches the expected ALB ARN.
        # Security: constant-time comparison. Both sides are compared as UTF-8 bytes
        # because compare_digest rejects non-ASCII str arguments, and the header
        # value is attacker-controlled (it may also be a non-string JSON value).
        if not isinstance(received_alb_arn, str) or not hmac.compare_digest(
            received_alb_arn.encode("utf-8"), expected_alb_arn.encode("utf-8")
        ):
            logger.error(f"Error: Invalid signer ARN. Expected {expected_alb_arn}, got {received_alb_arn!r}")
            return None

        # Step 3: Get the public key from the regional endpoint (with caching)
        pub_key = _get_alb_public_key(kid, aws_region)
        if not pub_key:
            logger.error("Error: Failed to fetch ALB public key")
            return None

        # Step 4: Validate the signature and claims using PyJWT.
        # jwt.decode verifies the signature and standard claims such as expiration.
        # The ALB signs with ES256.
        payload = jwt.decode(
            encoded_jwt,
            pub_key,
            algorithms=['ES256'],
            # Optional: Add audience or issuer validation if needed, though ALB handles most standard claims validation
            options={"verify_aud": False, "verify_iss": False}
        )

        # Step 5: Extract the email address from the payload
        email_address = payload.get('email')
        if not email_address:
            logger.error("Error: 'email' claim not found in JWT payload")
            return None
        # Security: the claim must be a single well-formed address. fullmatch is
        # required because re.match with '$' would also accept a trailing "\n".
        if not isinstance(email_address, str) or not re.fullmatch(
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', email_address
        ):
            logger.error(f"Error: Invalid email format in JWT payload: {email_address!r}")
            return None
        logger.debug("Successfully authenticated user via AWS ALB JWT")
        return email_address

    except jwt.ExpiredSignatureError:
        logger.error("Error: Token has expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.error(f"Error: Invalid token - {e}")
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        return None
199
+
200
+
201
def get_user_from_header(x_email_header: Optional[str]) -> Optional[str]:
    """Extract the user email from an authentication header value.

    Args:
        x_email_header: Raw header value; may be None or padded with whitespace.

    Returns:
        The value with surrounding whitespace removed, or None when the header is
        missing, empty, or contains only whitespace.
    """
    if not x_email_header:
        return None
    email = x_email_header.strip()
    # A whitespace-only header must not produce "" as a user identity.
    return email or None
atlas/core/authorization_manager.py ADDED
@@ -0,0 +1,27 @@
1
+ """Authorization utilities for managing access to resources."""
2
+
3
+ import logging
4
+ from typing import Awaitable, Callable
5
+
6
+ from atlas.modules.config.config_manager import get_app_settings
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ AuthCheckFunc = Callable[[str, str], Awaitable[bool]]
11
+
12
+
13
class AuthorizationManager:
    """Answer "is this user an administrator?" questions.

    The actual group-membership lookup is delegated to an injected async
    callable, so this class does not depend on how groups are resolved.
    """

    def __init__(self, auth_check_func: AuthCheckFunc):
        # Async predicate called as auth_check_func(user, group) -> bool.
        self.auth_check_func = auth_check_func
        # Settings snapshot taken at construction; the admin group name is read from it.
        self.app_settings = get_app_settings()

    async def is_admin(self, user_email: str) -> bool:
        """Return whether ``user_email`` belongs to the configured admin group."""
        admin_group = self.app_settings.admin_group
        membership = await self.auth_check_func(user_email, admin_group)
        return membership
23
+
24
+
25
def create_authorization_manager(auth_check_func: AuthCheckFunc) -> AuthorizationManager:
    """Build an AuthorizationManager around the given membership checker.

    Args:
        auth_check_func: Async callable invoked as ``(user, group)`` that returns
            whether the user is in the group.

    Returns:
        A newly constructed AuthorizationManager.
    """
    manager = AuthorizationManager(auth_check_func)
    return manager
atlas/core/capabilities.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ Capability-token utilities for secure, headless access to resources.
3
+
4
+ Provides short-lived HMAC-signed tokens suitable for embedding in URLs,
5
+ primarily for file downloads by tools that don't carry session cookies.
6
+ """
7
+
8
+ import base64
9
+ import hmac
10
+ import json
11
+ import logging
12
+ import time
13
+ from hashlib import sha256
14
+ from typing import Any, Dict, Optional
15
+
16
+ from atlas.modules.config import config_manager
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def _b64url_encode(data: bytes) -> str:
22
+ return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
23
+
24
+
25
+ def _b64url_decode(data: str) -> bytes:
26
+ padding = "=" * (-len(data) % 4)
27
+ return base64.urlsafe_b64decode((data + padding).encode("ascii"))
28
+
29
+
30
def _get_secret() -> bytes:
    """Return the HMAC key used to sign capability tokens, as UTF-8 bytes.

    Resolution:
    - ``capability_token_secret`` from the app settings, if truthy;
    - otherwise a fixed development secret. This is not safe for production,
      and a warning is logged each time the fallback is used.

    An exception while reading the settings is treated the same as "not set".
    """
    try:
        app_settings = config_manager.app_settings
        has_secret = bool(getattr(app_settings, "capability_token_secret", None))
        if has_secret:
            return app_settings.capability_token_secret.encode("utf-8")
    except Exception:
        # Settings unreadable (e.g. config not initialised yet); use the fallback.
        logger.debug("Capability token secret not available; using fallback dev secret.")

    logger.warning("Using fallback dev capability token secret. Set CAPABILITY_TOKEN_SECRET for security.")
    return b"dev-capability-secret"
47
+
48
+
49
def _get_default_ttl_seconds() -> int:
    """Return the default capability-token lifetime in seconds.

    Uses ``capability_token_ttl_seconds`` from the app settings when it is a
    positive int; otherwise, or if the settings cannot be read, returns 3600.
    """
    default_ttl = 3600
    try:
        ttl = getattr(config_manager.app_settings, "capability_token_ttl_seconds", None)
        usable = isinstance(ttl, int) and ttl > 0
    except Exception:
        logger.debug("Capability token TTL not available; using default TTL.")
        return default_ttl
    return ttl if usable else default_ttl
58
+
59
+
60
def generate_file_token(user_email: str, file_key: str, ttl_seconds: Optional[int] = None) -> str:
    """Create a signed, expiring token that lets ``user_email`` fetch ``file_key``.

    Token layout is ``<payload>.<signature>``:
    - ``payload``: unpadded URL-safe base64 of the compact JSON
      ``{"u": user, "k": key, "e": expiry_epoch_seconds}``;
    - ``signature``: unpadded URL-safe base64 of HMAC-SHA256 over ``payload``.

    Args:
        user_email: Identity the token is issued to.
        file_key: Storage key the token grants access to.
        ttl_seconds: Lifetime in seconds. A falsy value (None or 0) uses the
            configured default.

    Returns:
        The token string.
    """
    issued_at = int(time.time())
    lifetime = ttl_seconds or _get_default_ttl_seconds()
    claims = {"u": user_email, "k": file_key, "e": issued_at + lifetime}
    serialized = json.dumps(claims, separators=(",", ":")).encode("utf-8")
    encoded_payload = _b64url_encode(serialized)
    mac = hmac.new(_get_secret(), encoded_payload.encode("ascii"), sha256)
    return encoded_payload + "." + _b64url_encode(mac.digest())
68
+
69
+
70
def verify_file_token(token: str) -> Optional[Dict[str, Any]]:
    """Check a token produced by ``generate_file_token`` and return its claims.

    Returns None if the token is malformed, its signature does not match, it has
    expired (``e`` earlier than the current epoch second), or the ``u``/``k``
    claims are missing or empty. Any exception during verification also
    yields None.
    """
    try:
        encoded_payload, encoded_sig = token.split(".", 1)
        expected = hmac.new(_get_secret(), encoded_payload.encode("ascii"), sha256).digest()
        if not hmac.compare_digest(expected, _b64url_decode(encoded_sig)):
            return None

        claims = json.loads(_b64url_decode(encoded_payload).decode("utf-8"))
        if int(claims.get("e", 0)) < int(time.time()):
            return None
        # Both the user and the file key must be present and non-empty.
        if claims.get("u") and claims.get("k"):
            return claims
        return None
    except Exception:
        return None
88
+
89
+
90
def create_download_url(file_key: str, user_email: Optional[str]) -> str:
    """Create a download URL for a file key, adding a capability token when a user is known.

    If ``backend_public_url`` is configured, returns an absolute URL that remote
    MCP servers can access. Otherwise returns a relative URL, which only works
    for local/stdio servers.

    The key is percent-encoded in the URL path, so keys containing characters
    such as "?", "#", "%" or spaces cannot truncate the path or leak into the
    query string. The token still signs the raw (unencoded) key.

    Args:
        file_key: S3 key of the file to download.
        user_email: User the token is issued to. When falsy, no token is added.

    Returns:
        Download URL (absolute if ``backend_public_url`` is configured, relative otherwise).
    """
    from urllib.parse import quote

    # Encode only characters that are unsafe in a URL path. "/" and other
    # path-legal characters stay literal, so ordinary keys yield the same URL as before.
    encoded_key = quote(file_key, safe="/@:+=,;!*'()$&~")

    # Build relative path with token
    if user_email:
        token = generate_file_token(user_email, file_key)
        relative_path = f"/api/files/download/{encoded_key}?token={token}"
    else:
        # Fallback: no user context available
        relative_path = f"/api/files/download/{encoded_key}"

    # Check if we should use absolute URLs for remote MCP server access
    try:
        settings = config_manager.app_settings
        backend_public_url = getattr(settings, "backend_public_url", None)
        if backend_public_url:
            # Strip trailing slash from base URL and combine with relative path
            base = backend_public_url.rstrip("/")
            return f"{base}{relative_path}"
    except Exception as e:
        logger.debug(f"Could not check backend_public_url config: {e}")

    # Return relative URL as default
    return relative_path
atlas/core/compliance.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ Compliance level management and validation.
3
+
4
+ Loads compliance level definitions from compliance-levels.json and provides
5
+ validation and allowlist checking.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Dict, List, Optional, Set
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@dataclass
class ComplianceLevel:
    """A single compliance level as declared in the compliance configuration."""

    # Canonical identifier of the level.
    name: str
    # Free-text description of the level.
    description: str
    # Alternative names that resolve to ``name``.
    aliases: List[str]
    # Names of the levels this level may be used together with.
    allowed_with: List[str]
24
+
25
+
26
class ComplianceLevelManager:
    """Manages compliance level definitions and validation.

    Level names may be given canonically or through an alias. Aliases are
    resolved before any comparison, including entries inside a level's
    ``allowed_with`` list. With no configuration loaded, validation and access
    checks are permissive.
    """

    def __init__(self, config_path: Optional[Path] = None):
        """Initialize the compliance level manager.

        Args:
            config_path: Path to compliance-levels.json. If None, a fixed list of
                candidate locations relative to this package is searched and the
                first existing file is used.
        """
        self.levels: Dict[str, ComplianceLevel] = {}
        self.mode: str = "explicit_allowlist"
        self._name_to_canonical: Dict[str, str] = {}  # Maps names and aliases to canonical names

        if config_path is None:
            # Try to find config in standard locations
            backend_root = Path(__file__).parent.parent
            project_root = backend_root.parent

            search_paths = [
                project_root / "config" / "overrides" / "compliance-levels.json",
                project_root / "config" / "defaults" / "compliance-levels.json",
                backend_root / "configfilesadmin" / "compliance-levels.json",
                backend_root / "configfiles" / "compliance-levels.json",
            ]
            config_path = next((path for path in search_paths if path.exists()), None)

        if config_path and config_path.exists():
            self._load_config(config_path)
        else:
            logger.warning("No compliance-levels.json found, using permissive validation")

    def _load_config(self, config_path: Path):
        """Load compliance level configuration from a JSON file.

        The whole file is parsed into local structures before any manager state
        changes. A malformed file (e.g. a level entry without "name") therefore
        leaves the manager as it was; on first load that is the empty,
        permissive state, instead of a half-populated one.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            mode = config.get('mode', 'explicit_allowlist')
            levels: Dict[str, ComplianceLevel] = {}
            name_to_canonical: Dict[str, str] = {}

            for level_data in config.get('levels', []):
                level = ComplianceLevel(
                    name=level_data['name'],
                    description=level_data.get('description', ''),
                    aliases=level_data.get('aliases', []),
                    allowed_with=level_data.get('allowed_with', [level_data['name']])
                )
                levels[level.name] = level

                # Map canonical name to itself
                name_to_canonical[level.name] = level.name

                # Map aliases to canonical name
                for alias in level.aliases:
                    name_to_canonical[alias] = level.name
        except Exception as e:
            logger.error(f"Error loading compliance-levels.json: {e}")
            # Leave existing state untouched (empty on first load => permissive validation)
            return

        # Commit only after the whole file parsed successfully.
        self.mode = mode
        self.levels.update(levels)
        self._name_to_canonical.update(name_to_canonical)

        logger.info(f"Loaded {len(self.levels)} compliance levels from {config_path}")
        logger.debug(f"Compliance levels: {list(self.levels.keys())}")

    def _canonicalize_all(self, names: List[str]) -> Set[str]:
        """Map each name to its canonical form; unknown names are kept unchanged."""
        return {self._name_to_canonical.get(name, name) for name in names}

    def get_canonical_name(self, name: Optional[str]) -> Optional[str]:
        """Get the canonical name for a compliance level (resolves aliases).

        Args:
            name: Compliance level name or alias

        Returns:
            Canonical name, or None if not found
        """
        if not name:
            return None
        return self._name_to_canonical.get(name)

    def validate_compliance_level(self, level_name: Optional[str], context: str = "") -> Optional[str]:
        """Validate a compliance level name.

        Args:
            level_name: The compliance level to validate
            context: Context for logging (e.g., "MCP server 'calculator'")

        Returns:
            Canonical name if valid, None if invalid (with warning logged).
            If no levels are loaded, the name is returned unchanged.
        """
        if not level_name:
            return None

        canonical = self.get_canonical_name(level_name)

        if canonical is None:
            # No compliance config loaded - permissive mode
            if not self.levels:
                return level_name

            # Unknown compliance level
            valid_levels = list(self.levels.keys())
            logger.warning(
                f"Invalid compliance level '{level_name}' {context}. "
                f"Valid levels: {', '.join(valid_levels)}. "
                f"Setting to None."
            )
            return None

        if canonical != level_name:
            logger.debug(f"Resolved alias '{level_name}' to '{canonical}' {context}")

        return canonical

    def is_accessible(self, user_level: Optional[str], resource_level: Optional[str]) -> bool:
        """Check if a resource at resource_level is accessible given user_level.

        In explicit allowlist mode:
        - Each level defines which other levels can be used together
        - For example, HIPAA might allow HIPAA and SOC2, but not Public
        - None (unset) is accessible by all and can access all

        Entries in ``allowed_with`` may be canonical names or aliases.

        Args:
            user_level: User's selected compliance level
            resource_level: Resource's compliance level

        Returns:
            True if resource is accessible, False otherwise
        """
        # If either is None/unset, resource is accessible (backward compatibility)
        if not user_level or not resource_level:
            return True

        # Get canonical names
        user_canonical = self.get_canonical_name(user_level)
        resource_canonical = self.get_canonical_name(resource_level)

        # If we don't have level info, be permissive
        if not user_canonical or not resource_canonical:
            return True

        # Get level object for user
        user_level_obj = self.levels.get(user_canonical)

        if not user_level_obj:
            return True

        # allowed_with may reference levels by alias; resolve before comparing.
        return resource_canonical in self._canonicalize_all(user_level_obj.allowed_with)

    def get_accessible_levels(self, user_level: Optional[str]) -> Set[str]:
        """Get all compliance levels accessible to a user.

        Args:
            user_level: User's selected compliance level

        Returns:
            Set of accessible compliance level names (canonical; names in
            ``allowed_with`` that are not defined levels are returned unchanged)
        """
        if not user_level or not self.levels:
            # Return all levels if no user level or no config
            return set(self.levels.keys()) if self.levels else set()

        user_canonical = self.get_canonical_name(user_level)
        if not user_canonical or user_canonical not in self.levels:
            return set(self.levels.keys())

        user_level_obj = self.levels[user_canonical]

        # Return the allowed_with list for this level, with aliases resolved
        return self._canonicalize_all(user_level_obj.allowed_with)

    def get_all_levels(self) -> List[str]:
        """Get all defined compliance level names (canonical).

        Returns:
            List of compliance level names in definition order
        """
        return list(self.levels.keys())
204
+
205
+
206
# Process-wide singleton, created lazily on first use.
_compliance_manager: Optional[ComplianceLevelManager] = None


def get_compliance_manager() -> ComplianceLevelManager:
    """Return the shared ComplianceLevelManager, constructing it on first call."""
    global _compliance_manager
    if _compliance_manager is not None:
        return _compliance_manager
    _compliance_manager = ComplianceLevelManager()
    return _compliance_manager