atlas-chat 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +40 -0
- atlas/application/__init__.py +7 -0
- atlas/application/chat/__init__.py +7 -0
- atlas/application/chat/agent/__init__.py +10 -0
- atlas/application/chat/agent/act_loop.py +179 -0
- atlas/application/chat/agent/factory.py +142 -0
- atlas/application/chat/agent/protocols.py +46 -0
- atlas/application/chat/agent/react_loop.py +338 -0
- atlas/application/chat/agent/think_act_loop.py +171 -0
- atlas/application/chat/approval_manager.py +151 -0
- atlas/application/chat/elicitation_manager.py +191 -0
- atlas/application/chat/events/__init__.py +1 -0
- atlas/application/chat/events/agent_event_relay.py +112 -0
- atlas/application/chat/modes/__init__.py +1 -0
- atlas/application/chat/modes/agent.py +125 -0
- atlas/application/chat/modes/plain.py +74 -0
- atlas/application/chat/modes/rag.py +81 -0
- atlas/application/chat/modes/tools.py +179 -0
- atlas/application/chat/orchestrator.py +213 -0
- atlas/application/chat/policies/__init__.py +1 -0
- atlas/application/chat/policies/tool_authorization.py +99 -0
- atlas/application/chat/preprocessors/__init__.py +1 -0
- atlas/application/chat/preprocessors/message_builder.py +92 -0
- atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
- atlas/application/chat/service.py +454 -0
- atlas/application/chat/utilities/__init__.py +6 -0
- atlas/application/chat/utilities/error_handler.py +367 -0
- atlas/application/chat/utilities/event_notifier.py +546 -0
- atlas/application/chat/utilities/file_processor.py +613 -0
- atlas/application/chat/utilities/tool_executor.py +789 -0
- atlas/atlas_chat_cli.py +347 -0
- atlas/atlas_client.py +238 -0
- atlas/core/__init__.py +0 -0
- atlas/core/auth.py +205 -0
- atlas/core/authorization_manager.py +27 -0
- atlas/core/capabilities.py +123 -0
- atlas/core/compliance.py +215 -0
- atlas/core/domain_whitelist.py +147 -0
- atlas/core/domain_whitelist_middleware.py +82 -0
- atlas/core/http_client.py +28 -0
- atlas/core/log_sanitizer.py +102 -0
- atlas/core/metrics_logger.py +59 -0
- atlas/core/middleware.py +131 -0
- atlas/core/otel_config.py +242 -0
- atlas/core/prompt_risk.py +200 -0
- atlas/core/rate_limit.py +0 -0
- atlas/core/rate_limit_middleware.py +64 -0
- atlas/core/security_headers_middleware.py +51 -0
- atlas/domain/__init__.py +37 -0
- atlas/domain/chat/__init__.py +1 -0
- atlas/domain/chat/dtos.py +85 -0
- atlas/domain/errors.py +96 -0
- atlas/domain/messages/__init__.py +12 -0
- atlas/domain/messages/models.py +160 -0
- atlas/domain/rag_mcp_service.py +664 -0
- atlas/domain/sessions/__init__.py +7 -0
- atlas/domain/sessions/models.py +36 -0
- atlas/domain/unified_rag_service.py +371 -0
- atlas/infrastructure/__init__.py +10 -0
- atlas/infrastructure/app_factory.py +135 -0
- atlas/infrastructure/events/__init__.py +1 -0
- atlas/infrastructure/events/cli_event_publisher.py +140 -0
- atlas/infrastructure/events/websocket_publisher.py +140 -0
- atlas/infrastructure/sessions/in_memory_repository.py +56 -0
- atlas/infrastructure/transport/__init__.py +7 -0
- atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
- atlas/init_cli.py +226 -0
- atlas/interfaces/__init__.py +15 -0
- atlas/interfaces/events.py +134 -0
- atlas/interfaces/llm.py +54 -0
- atlas/interfaces/rag.py +40 -0
- atlas/interfaces/sessions.py +75 -0
- atlas/interfaces/tools.py +57 -0
- atlas/interfaces/transport.py +24 -0
- atlas/main.py +564 -0
- atlas/mcp/api_key_demo/README.md +76 -0
- atlas/mcp/api_key_demo/main.py +172 -0
- atlas/mcp/api_key_demo/run.sh +56 -0
- atlas/mcp/basictable/main.py +147 -0
- atlas/mcp/calculator/main.py +149 -0
- atlas/mcp/code-executor/execution_engine.py +98 -0
- atlas/mcp/code-executor/execution_environment.py +95 -0
- atlas/mcp/code-executor/main.py +528 -0
- atlas/mcp/code-executor/result_processing.py +276 -0
- atlas/mcp/code-executor/script_generation.py +195 -0
- atlas/mcp/code-executor/security_checker.py +140 -0
- atlas/mcp/corporate_cars/main.py +437 -0
- atlas/mcp/csv_reporter/main.py +545 -0
- atlas/mcp/duckduckgo/main.py +182 -0
- atlas/mcp/elicitation_demo/README.md +171 -0
- atlas/mcp/elicitation_demo/main.py +262 -0
- atlas/mcp/env-demo/README.md +158 -0
- atlas/mcp/env-demo/main.py +199 -0
- atlas/mcp/file_size_test/main.py +284 -0
- atlas/mcp/filesystem/main.py +348 -0
- atlas/mcp/image_demo/main.py +113 -0
- atlas/mcp/image_demo/requirements.txt +4 -0
- atlas/mcp/logging_demo/README.md +72 -0
- atlas/mcp/logging_demo/main.py +103 -0
- atlas/mcp/many_tools_demo/main.py +50 -0
- atlas/mcp/order_database/__init__.py +0 -0
- atlas/mcp/order_database/main.py +369 -0
- atlas/mcp/order_database/signal_data.csv +1001 -0
- atlas/mcp/pdfbasic/main.py +394 -0
- atlas/mcp/pptx_generator/main.py +760 -0
- atlas/mcp/pptx_generator/requirements.txt +13 -0
- atlas/mcp/pptx_generator/run_test.sh +1 -0
- atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
- atlas/mcp/progress_demo/main.py +167 -0
- atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
- atlas/mcp/progress_updates_demo/README.md +120 -0
- atlas/mcp/progress_updates_demo/main.py +497 -0
- atlas/mcp/prompts/main.py +222 -0
- atlas/mcp/public_demo/main.py +189 -0
- atlas/mcp/sampling_demo/README.md +169 -0
- atlas/mcp/sampling_demo/main.py +234 -0
- atlas/mcp/thinking/main.py +77 -0
- atlas/mcp/tool_planner/main.py +240 -0
- atlas/mcp/ui-demo/badmesh.png +0 -0
- atlas/mcp/ui-demo/main.py +383 -0
- atlas/mcp/ui-demo/templates/button_demo.html +32 -0
- atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
- atlas/mcp/ui-demo/templates/form_demo.html +28 -0
- atlas/mcp/username-override-demo/README.md +320 -0
- atlas/mcp/username-override-demo/main.py +308 -0
- atlas/modules/__init__.py +0 -0
- atlas/modules/config/__init__.py +34 -0
- atlas/modules/config/cli.py +231 -0
- atlas/modules/config/config_manager.py +1096 -0
- atlas/modules/file_storage/__init__.py +22 -0
- atlas/modules/file_storage/cli.py +330 -0
- atlas/modules/file_storage/content_extractor.py +290 -0
- atlas/modules/file_storage/manager.py +295 -0
- atlas/modules/file_storage/mock_s3_client.py +402 -0
- atlas/modules/file_storage/s3_client.py +417 -0
- atlas/modules/llm/__init__.py +19 -0
- atlas/modules/llm/caller.py +287 -0
- atlas/modules/llm/litellm_caller.py +675 -0
- atlas/modules/llm/models.py +19 -0
- atlas/modules/mcp_tools/__init__.py +17 -0
- atlas/modules/mcp_tools/client.py +2123 -0
- atlas/modules/mcp_tools/token_storage.py +556 -0
- atlas/modules/prompts/prompt_provider.py +130 -0
- atlas/modules/rag/__init__.py +24 -0
- atlas/modules/rag/atlas_rag_client.py +336 -0
- atlas/modules/rag/client.py +129 -0
- atlas/routes/admin_routes.py +865 -0
- atlas/routes/config_routes.py +484 -0
- atlas/routes/feedback_routes.py +361 -0
- atlas/routes/files_routes.py +274 -0
- atlas/routes/health_routes.py +40 -0
- atlas/routes/mcp_auth_routes.py +223 -0
- atlas/server_cli.py +164 -0
- atlas/tests/conftest.py +20 -0
- atlas/tests/integration/test_mcp_auth_integration.py +152 -0
- atlas/tests/manual_test_sampling.py +87 -0
- atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
- atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
- atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
- atlas/tests/test_agent_roa.py +135 -0
- atlas/tests/test_app_factory_smoke.py +47 -0
- atlas/tests/test_approval_manager.py +439 -0
- atlas/tests/test_atlas_client.py +188 -0
- atlas/tests/test_atlas_rag_client.py +447 -0
- atlas/tests/test_atlas_rag_integration.py +224 -0
- atlas/tests/test_attach_file_flow.py +287 -0
- atlas/tests/test_auth_utils.py +165 -0
- atlas/tests/test_backend_public_url.py +185 -0
- atlas/tests/test_banner_logging.py +287 -0
- atlas/tests/test_capability_tokens_and_injection.py +203 -0
- atlas/tests/test_compliance_level.py +54 -0
- atlas/tests/test_compliance_manager.py +253 -0
- atlas/tests/test_config_manager.py +617 -0
- atlas/tests/test_config_manager_paths.py +12 -0
- atlas/tests/test_core_auth.py +18 -0
- atlas/tests/test_core_utils.py +190 -0
- atlas/tests/test_docker_env_sync.py +202 -0
- atlas/tests/test_domain_errors.py +329 -0
- atlas/tests/test_domain_whitelist.py +359 -0
- atlas/tests/test_elicitation_manager.py +408 -0
- atlas/tests/test_elicitation_routing.py +296 -0
- atlas/tests/test_env_demo_server.py +88 -0
- atlas/tests/test_error_classification.py +113 -0
- atlas/tests/test_error_flow_integration.py +116 -0
- atlas/tests/test_feedback_routes.py +333 -0
- atlas/tests/test_file_content_extraction.py +1134 -0
- atlas/tests/test_file_extraction_routes.py +158 -0
- atlas/tests/test_file_library.py +107 -0
- atlas/tests/test_file_manager_unit.py +18 -0
- atlas/tests/test_health_route.py +49 -0
- atlas/tests/test_http_client_stub.py +8 -0
- atlas/tests/test_imports_smoke.py +30 -0
- atlas/tests/test_interfaces_llm_response.py +9 -0
- atlas/tests/test_issue_access_denied_fix.py +136 -0
- atlas/tests/test_llm_env_expansion.py +836 -0
- atlas/tests/test_log_level_sensitive_data.py +285 -0
- atlas/tests/test_mcp_auth_routes.py +341 -0
- atlas/tests/test_mcp_client_auth.py +331 -0
- atlas/tests/test_mcp_data_injection.py +270 -0
- atlas/tests/test_mcp_get_authorized_servers.py +95 -0
- atlas/tests/test_mcp_hot_reload.py +512 -0
- atlas/tests/test_mcp_image_content.py +424 -0
- atlas/tests/test_mcp_logging.py +172 -0
- atlas/tests/test_mcp_progress_updates.py +313 -0
- atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
- atlas/tests/test_mcp_prompts_server.py +39 -0
- atlas/tests/test_mcp_tool_result_parsing.py +296 -0
- atlas/tests/test_metrics_logger.py +56 -0
- atlas/tests/test_middleware_auth.py +379 -0
- atlas/tests/test_prompt_risk_and_acl.py +141 -0
- atlas/tests/test_rag_mcp_aggregator.py +204 -0
- atlas/tests/test_rag_mcp_service.py +224 -0
- atlas/tests/test_rate_limit_middleware.py +45 -0
- atlas/tests/test_routes_config_smoke.py +60 -0
- atlas/tests/test_routes_files_download_token.py +41 -0
- atlas/tests/test_routes_files_health.py +18 -0
- atlas/tests/test_runtime_imports.py +53 -0
- atlas/tests/test_sampling_integration.py +482 -0
- atlas/tests/test_security_admin_routes.py +61 -0
- atlas/tests/test_security_capability_tokens.py +65 -0
- atlas/tests/test_security_file_stats_scope.py +21 -0
- atlas/tests/test_security_header_injection.py +191 -0
- atlas/tests/test_security_headers_and_filename.py +63 -0
- atlas/tests/test_shared_session_repository.py +101 -0
- atlas/tests/test_system_prompt_loading.py +181 -0
- atlas/tests/test_token_storage.py +505 -0
- atlas/tests/test_tool_approval_config.py +93 -0
- atlas/tests/test_tool_approval_utils.py +356 -0
- atlas/tests/test_tool_authorization_group_filtering.py +223 -0
- atlas/tests/test_tool_details_in_config.py +108 -0
- atlas/tests/test_tool_planner.py +300 -0
- atlas/tests/test_unified_rag_service.py +398 -0
- atlas/tests/test_username_override_in_approval.py +258 -0
- atlas/tests/test_websocket_auth_header.py +168 -0
- atlas/version.py +6 -0
- atlas_chat-0.1.0.data/data/.env.example +253 -0
- atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
- atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
- atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
- atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
- atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
- atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
- atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
- atlas_chat-0.1.0.dist-info/METADATA +236 -0
- atlas_chat-0.1.0.dist-info/RECORD +250 -0
- atlas_chat-0.1.0.dist-info/WHEEL +5 -0
- atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
- atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
atlas/core/auth.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Authentication and authorization module."""
|
|
2
|
+
|
|
3
|
+
import hmac
|
|
4
|
+
import logging
|
|
5
|
+
import re
|
|
6
|
+
from datetime import datetime, timedelta
|
|
7
|
+
from typing import Dict, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
import jwt
|
|
11
|
+
|
|
12
|
+
from atlas.modules.config.config_manager import config_manager
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)

# Process-wide cache of AWS ALB public keys, filled and expired by _get_alb_public_key:
#   {(kid, region): (public_key_text, expiry_time)}
# expiry_time is a naive UTC datetime. Entries are evicted only when looked up after
# they expire, and the dict is accessed without any locking.
_alb_key_cache: Dict[Tuple[str, str], Tuple[str, datetime]] = {}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def is_user_in_group(user_id: str, group_id: str) -> bool:
    """
    Check if a user is in a specified group.

    If both ``auth_group_check_url`` and ``auth_group_check_api_key`` are set in
    the app settings, membership is resolved this way:

    - POST ``{"user_id": ..., "group_id": ...}`` to that URL with a bearer token.
    - Read ``is_member`` from the JSON response.

    Otherwise a hard-coded mock table is used, intended for local development.

    The external check fails closed. Any of these yields False:

    - transport errors or non-2xx responses;
    - unparseable bodies or a body that is not a JSON object;
    - an ``is_member`` value that is not a JSON boolean.

    Args:
        user_id: User email/identifier.
        group_id: Group identifier.

    Returns:
        True if the user is in the group, False otherwise.
    """
    app_settings = config_manager.app_settings
    auth_url = app_settings.auth_group_check_url
    api_key = app_settings.auth_group_check_api_key

    if auth_url and api_key:
        # Use the external HTTP endpoint for authorization.
        try:
            async with httpx.AsyncClient() as client:
                headers = {"Authorization": f"Bearer {api_key}"}
                payload = {"user_id": user_id, "group_id": group_id}
                response = await client.post(auth_url, json=payload, headers=headers, timeout=5.0)
                response.raise_for_status()
                body = response.json()
        except httpx.RequestError as e:
            logger.error(f"HTTP request to auth endpoint failed: {e}", exc_info=True)
            return False
        except Exception as e:
            logger.error(f"Error during external auth check: {e}", exc_info=True)
            return False

        # Expected body: {"is_member": true}. Only a real boolean is honoured.
        # Returning the raw value would let e.g. the string "false" (truthy in
        # Python) grant membership.
        is_member = body.get("is_member", False) if isinstance(body, dict) else None
        if not isinstance(is_member, bool):
            logger.warning(
                f"Auth endpoint returned unexpected is_member value {is_member!r} "
                f"for group '{group_id}'; treating as not a member"
            )
            return False
        return is_member

    # --- Mock fallback: no external endpoint configured ---
    # NOTE(review): this branch also applies to any deployment that leaves the two
    # settings above unset. The hard-coded table below (which includes an "admin"
    # group) then becomes authoritative. Confirm production always configures them.

    # Everybody is in the users group by default.
    if group_id == "users":
        return True

    # In debug mode the configured test user is treated as a member of the admin group.
    if (
        app_settings.debug_mode
        and user_id == app_settings.test_user
        and group_id == app_settings.admin_group
    ):
        return True

    mock_groups = {
        "test@test.com": ["users", "mcp_basic", "admin"],
        "user@example.com": ["users", "mcp_basic"],
        "admin@example.com": ["admin", "users", "mcp_basic", "mcp_advanced"],
    }
    return group_id in mock_groups.get(user_id, [])
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Allowed shapes for values interpolated into the ALB public-key URL.
# fullmatch is used instead of re.match with "^...$" because "$" also matches just
# before a trailing newline. [0-9] is used instead of "\d", which accepts non-ASCII digits.
_ALB_KID_PATTERN = re.compile(r"[A-Za-z0-9-]+")
_AWS_REGION_PATTERN = re.compile(r"[a-z]{2}-[a-z]+-[0-9]+")


def _is_valid_alb_kid(kid: object) -> bool:
    """Return True if *kid* is a non-empty str of ASCII letters, digits and hyphens."""
    return isinstance(kid, str) and _ALB_KID_PATTERN.fullmatch(kid) is not None


def _is_valid_aws_region(aws_region: object) -> bool:
    """Return True if *aws_region* has the '<xx>-<name>-<n>' shape, e.g. 'us-east-1'."""
    return isinstance(aws_region, str) and _AWS_REGION_PATTERN.fullmatch(aws_region) is not None


def _get_alb_public_key(kid: str, aws_region: str) -> Optional[str]:
    """
    Fetch and cache an AWS ALB public key by key ID.

    The key is fetched from
    ``https://public-keys.auth.elb.<aws_region>.amazonaws.com/<kid>``.
    The response text is cached in ``_alb_key_cache`` for one hour, so rotated
    keys are eventually refetched.

    Args:
        kid: Key ID from the JWT header. It is untrusted, because the header has
            not been verified yet.
        aws_region: AWS region (e.g., 'us-east-1').

    Returns:
        Public key text, or None if the inputs are invalid or the fetch fails.
    """
    # Security: kid and region are interpolated into the request URL. They must be
    # strictly validated to prevent URL/host injection and cache poisoning.
    # !r keeps rejected values from injecting extra log lines.
    if not _is_valid_alb_kid(kid):
        logger.error(f"Invalid kid format: {kid!r}")
        return None
    if not _is_valid_aws_region(aws_region):
        logger.error(f"Invalid AWS region format: {aws_region!r}")
        return None

    # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow(), so
    # existing cache entries remain comparable.
    from datetime import timezone
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    # Security: the 1-hour TTL lets rotated or compromised keys age out.
    cache_key = (kid, aws_region)
    cached = _alb_key_cache.get(cache_key)
    if cached is not None:
        cached_key, expiry = cached
        if now < expiry:
            return cached_key
        # Expired: drop it and refetch below.
        _alb_key_cache.pop(cache_key, None)

    # NOTE(review): this is a blocking HTTP call. If it runs on an async request
    # path it stalls the event loop on cache misses; consider an async variant.
    url = f'https://public-keys.auth.elb.{aws_region}.amazonaws.com/{kid}'
    try:
        response = httpx.get(url, timeout=5.0)
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error fetching ALB public key from {url}: {e.response.status_code}")
        return None
    except httpx.RequestError as e:
        logger.error(f"Error fetching ALB public key from {url}: {e}")
        return None

    pub_key = response.text
    _alb_key_cache[cache_key] = (pub_key, now + timedelta(hours=1))
    return pub_key
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Format check for the JWT "email" claim. fullmatch is used because "^...$" with
# re.match also accepts a single trailing newline.
_EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")


def _is_valid_email(value: object) -> bool:
    """Return True if *value* is a str that fully matches the email pattern."""
    return isinstance(value, str) and _EMAIL_PATTERN.fullmatch(value) is not None


def get_user_from_aws_alb_jwt(
    encoded_jwt: Optional[str], expected_alb_arn: str, aws_region: str
) -> Optional[str]:
    """
    Validate an AWS ALB OIDC JWT and return the email address it carries.

    Steps:

    1. Read ``kid`` and ``signer`` from the unverified header.
    2. Require the signer to equal *expected_alb_arn*.
    3. Fetch the ALB public key for ``kid``.
    4. Verify the ES256 signature and standard claims with PyJWT. Audience and
       issuer checks are disabled.
    5. Return the format-checked ``email`` claim.

    Args:
        encoded_jwt: The JWT from the x-amzn-oidc-data header.
        expected_alb_arn: The ARN of your Application Load Balancer.
        aws_region: The AWS region where your ALB is located (e.g., 'us-east-1').

    Returns:
        The user's email address, or None if the token is missing or any step
        fails. Failures are logged, never raised.
    """
    if not encoded_jwt:
        return None
    try:
        # Step 1: Read kid and signer from the (not yet verified) JWT header.
        header = jwt.get_unverified_header(encoded_jwt)
        kid = header.get('kid')
        received_alb_arn = header.get('signer')

        if not kid:
            logger.error("Error: 'kid' not found in JWT header")
            return None

        # Step 2: Validate the signer matches the expected ALB ARN.
        # Security: hmac.compare_digest is a constant-time comparison. It is given
        # bytes because, with str arguments, it raises TypeError on non-ASCII input.
        if (
            not received_alb_arn
            or not isinstance(received_alb_arn, str)
            or not hmac.compare_digest(
                received_alb_arn.encode("utf-8"), expected_alb_arn.encode("utf-8")
            )
        ):
            # !r stops attacker-controlled header text from injecting log lines.
            logger.error(
                f"Error: Invalid signer ARN. Expected {expected_alb_arn}, got {received_alb_arn!r}"
            )
            return None

        # Step 3: Get the public key from the regional endpoint (with caching).
        pub_key = _get_alb_public_key(kid, aws_region)
        if not pub_key:
            logger.error("Error: Failed to fetch ALB public key")
            return None

        # Step 4: Verify the signature and standard claims (such as expiration)
        # with PyJWT. The ALB signs with ES256. Audience/issuer validation is
        # disabled; add it here if needed.
        payload = jwt.decode(
            encoded_jwt,
            pub_key,
            algorithms=['ES256'],
            options={"verify_aud": False, "verify_iss": False},
        )

        # Step 5: Extract and format-check the email claim.
        email_address = payload.get('email')
        if not email_address:
            logger.error("Error: 'email' claim not found in JWT payload")
            return None
        if not _is_valid_email(email_address):
            logger.error(f"Error: Invalid email format in JWT payload: {email_address!r}")
            return None
        logger.debug("Successfully authenticated user via AWS ALB JWT")
        return email_address

    except jwt.ExpiredSignatureError:
        logger.error("Error: Token has expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.error(f"Error: Invalid token - {e}")
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}", exc_info=True)
        return None
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def get_user_from_header(x_email_header: Optional[str]) -> Optional[str]:
    """Return the user identity carried in an authentication header value.

    Args:
        x_email_header: Raw header value, or None when the header is absent.

    Returns:
        None when the value is missing or empty. Otherwise the value with
        surrounding whitespace removed; a whitespace-only value becomes "".
    """
    return x_email_header.strip() if x_email_header else None
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Authorization utilities for managing access to resources."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Awaitable, Callable
|
|
5
|
+
|
|
6
|
+
from atlas.modules.config.config_manager import get_app_settings
|
|
7
|
+
|
|
8
|
+
# NOTE(review): logger is not used anywhere else in this module.
logger = logging.getLogger(__name__)

# Async membership check injected into AuthorizationManager. It is called as
# auth_check_func(user_email, group_name) and the awaited result is a bool.
AuthCheckFunc = Callable[[str, str], Awaitable[bool]]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AuthorizationManager:
    """Decide whether a user holds admin rights.

    Group membership is delegated to an injected async check function. The
    admin group name comes from the application settings captured when the
    manager is created.
    """

    def __init__(self, auth_check_func: AuthCheckFunc):
        """Store the membership check and snapshot the current app settings."""
        self.auth_check_func = auth_check_func
        # Captured once: later changes to the settings object returned by
        # get_app_settings() are not seen unless they mutate this same instance.
        self.app_settings = get_app_settings()

    async def is_admin(self, user_email: str) -> bool:
        """Return the membership check's verdict for *user_email* in the admin group."""
        admin_group = self.app_settings.admin_group
        is_member = await self.auth_check_func(user_email, admin_group)
        return is_member
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_authorization_manager(auth_check_func: AuthCheckFunc) -> AuthorizationManager:
    """Build an AuthorizationManager around *auth_check_func*.

    Args:
        auth_check_func: Async callable ``(user_email, group_id) -> bool`` used
            for membership checks.

    Returns:
        A newly constructed AuthorizationManager.
    """
    manager = AuthorizationManager(auth_check_func=auth_check_func)
    return manager
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Capability-token utilities for secure, headless access to resources.
|
|
3
|
+
|
|
4
|
+
Provides short-lived HMAC-signed tokens suitable for embedding in URLs,
|
|
5
|
+
primarily for file downloads by tools that don't carry session cookies.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import base64
|
|
9
|
+
import hmac
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
from hashlib import sha256
|
|
14
|
+
from typing import Any, Dict, Optional
|
|
15
|
+
|
|
16
|
+
from atlas.modules.config import config_manager
|
|
17
|
+
|
|
18
|
+
# Module logger; used to report fallback secret/TTL use and config lookup failures.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _b64url_encode(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with the trailing '=' padding removed."""
    encoded = base64.urlsafe_b64encode(data).decode("ascii")
    return encoded.rstrip("=")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _b64url_decode(data: str) -> bytes:
    """Decode URL-safe base64 text whose '=' padding may have been stripped."""
    missing = (4 - len(data) % 4) % 4
    padded = data + "=" * missing
    return base64.urlsafe_b64decode(padded.encode("ascii"))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_secret() -> bytes:
    """Return the HMAC key used to sign capability tokens, as UTF-8 bytes.

    The ``capability_token_secret`` app setting is used when it is present and
    non-empty. Otherwise a fixed development secret is returned and a warning is
    logged. This happens when the setting is unset or empty, or when settings
    cannot be read or encoded.

    Because the fallback key is a constant in this module, tokens signed with it
    can be forged by anyone who knows it. Production deployments must configure
    the secret.
    """
    try:
        secret = getattr(config_manager.app_settings, "capability_token_secret", None)
        if secret:
            return secret.encode("utf-8")
    except Exception:
        # Settings not ready yet: fall through to the development secret.
        logger.debug("Capability token secret not available; using fallback dev secret.")

    logger.warning("Using fallback dev capability token secret. Set CAPABILITY_TOKEN_SECRET for security.")
    return b"dev-capability-secret"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _get_default_ttl_seconds() -> int:
    """Return the default capability-token lifetime in seconds.

    Uses the ``capability_token_ttl_seconds`` app setting when it is a positive
    int. Otherwise it falls back to 3600 (one hour); that covers a missing,
    non-positive or non-int value, and settings that cannot be read.
    """
    fallback_ttl = 3600
    try:
        configured = getattr(config_manager.app_settings, "capability_token_ttl_seconds", None)
    except Exception:
        logger.debug("Capability token TTL not available; using default TTL.")
        return fallback_ttl
    if isinstance(configured, int) and configured > 0:
        return configured
    return fallback_ttl
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def generate_file_token(user_email: str, file_key: str, ttl_seconds: Optional[int] = None) -> str:
    """Create a signed, expiring token that lets *user_email* fetch *file_key*.

    Token layout: ``<b64url(claims-json)>.<b64url(HMAC-SHA256(secret, body))>``.
    The claims are ``{"u": user, "k": file key, "e": unix expiry}``.

    Args:
        user_email: Identity the token is issued to.
        file_key: Storage key the token grants access to.
        ttl_seconds: Lifetime in seconds. None or 0 selects the configured default.

    Returns:
        The encoded token string.
    """
    issued_at = int(time.time())
    lifetime = ttl_seconds or _get_default_ttl_seconds()
    claims = {"u": user_email, "k": file_key, "e": issued_at + lifetime}
    compact_json = json.dumps(claims, separators=(",", ":")).encode("utf-8")
    encoded_claims = _b64url_encode(compact_json)
    signature = hmac.new(_get_secret(), encoded_claims.encode("ascii"), sha256).digest()
    return ".".join((encoded_claims, _b64url_encode(signature)))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def verify_file_token(token: str) -> Optional[Dict[str, Any]]:
    """Check a capability token and return its claims when it is valid.

    Checks, in order:

    1. The token has the ``<body>.<signature>`` shape.
    2. The HMAC-SHA256 signature matches (constant-time comparison).
    3. The ``e`` expiry is not in the past.
    4. The ``u`` (user) and ``k`` (file key) claims are present and non-empty.

    Returns:
        The decoded claims dict, or None if any check fails. Malformed input of
        any kind also yields None, because every exception is swallowed.
    """
    try:
        encoded_claims, encoded_sig = token.split(".", 1)
        expected = hmac.new(_get_secret(), encoded_claims.encode("ascii"), sha256).digest()
        if not hmac.compare_digest(expected, _b64url_decode(encoded_sig)):
            return None

        claims = json.loads(_b64url_decode(encoded_claims).decode("utf-8"))
        expired = int(claims.get("e", 0)) < int(time.time())
        missing_required = not (claims.get("u") and claims.get("k"))
        if expired or missing_required:
            return None
        return claims
    except Exception:
        return None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Characters left unescaped when a storage key is placed in the download URL path.
# "/" keeps key prefixes as path segments. The others are RFC 3986 sub-delims plus
# ":" and "@", all legal inside a path segment, so ordinary keys are unchanged.
_DOWNLOAD_PATH_SAFE_CHARS = "/:@!$&'()*+,;="


def _build_download_path(file_key: str, token: Optional[str]) -> str:
    """Return the relative download path for *file_key*, appending ``?token=`` if given.

    The key is percent-encoded so that characters such as space, '?', '#' and '%'
    cannot truncate the path or be misread as a query string or escape sequence.
    """
    from urllib.parse import quote

    path = f"/api/files/download/{quote(file_key, safe=_DOWNLOAD_PATH_SAFE_CHARS)}"
    return f"{path}?token={token}" if token else path


def create_download_url(file_key: str, user_email: Optional[str]) -> str:
    """Create a download URL for a given file key, optionally with a token.

    If BACKEND_PUBLIC_URL is configured, returns an absolute URL that remote MCP
    servers can access. Otherwise, returns a relative URL, which only works for
    local/stdio servers.

    Args:
        file_key: S3 key of the file to download. It is percent-encoded in the URL
            path; the token is still bound to the raw key.
        user_email: User email for token generation. If it is falsy, no token is
            attached.

    Returns:
        Download URL: absolute if BACKEND_PUBLIC_URL is configured, relative otherwise.
    """
    # A token is only possible when we know which user it is for.
    token = generate_file_token(user_email, file_key) if user_email else None
    relative_path = _build_download_path(file_key, token)

    # Check if we should use absolute URLs for remote MCP server access.
    try:
        settings = config_manager.app_settings
        backend_public_url = getattr(settings, "backend_public_url", None)
        if backend_public_url:
            # Strip the trailing slash from the base URL and combine with the relative path.
            base = backend_public_url.rstrip("/")
            return f"{base}{relative_path}"
    except Exception as e:
        logger.debug(f"Could not check backend_public_url config: {e}")

    # Return the relative URL as the default.
    return relative_path
|
atlas/core/compliance.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Compliance level management and validation.
|
|
3
|
+
|
|
4
|
+
Loads compliance level definitions from compliance-levels.json and provides
|
|
5
|
+
validation and allowlist checking.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, List, Optional, Set
|
|
13
|
+
|
|
14
|
+
# Module logger; reports config discovery/loading results and invalid compliance levels.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ComplianceLevel:
    """One compliance level as declared in compliance-levels.json.

    Attributes:
        name: Canonical level name.
        description: Human-readable description (the loader defaults it to "").
        aliases: Alternative names that resolve to ``name``.
        allowed_with: Names of compliance levels that may be used together with
            this one (the loader defaults it to ``[name]``).
    """

    name: str
    description: str
    aliases: List[str]
    allowed_with: List[str]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ComplianceLevelManager:
|
|
27
|
+
"""Manages compliance level definitions and validation."""
|
|
28
|
+
|
|
29
|
+
def __init__(self, config_path: Optional[Path] = None):
|
|
30
|
+
"""Initialize the compliance level manager.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
config_path: Path to compliance-levels.json. If None, uses default location.
|
|
34
|
+
"""
|
|
35
|
+
self.levels: Dict[str, ComplianceLevel] = {}
|
|
36
|
+
self.mode: str = "explicit_allowlist"
|
|
37
|
+
self._name_to_canonical: Dict[str, str] = {} # Maps aliases to canonical names
|
|
38
|
+
|
|
39
|
+
if config_path is None:
|
|
40
|
+
# Try to find config in standard locations
|
|
41
|
+
backend_root = Path(__file__).parent.parent
|
|
42
|
+
project_root = backend_root.parent
|
|
43
|
+
|
|
44
|
+
search_paths = [
|
|
45
|
+
project_root / "config" / "overrides" / "compliance-levels.json",
|
|
46
|
+
project_root / "config" / "defaults" / "compliance-levels.json",
|
|
47
|
+
backend_root / "configfilesadmin" / "compliance-levels.json",
|
|
48
|
+
backend_root / "configfiles" / "compliance-levels.json",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
for path in search_paths:
|
|
52
|
+
if path.exists():
|
|
53
|
+
config_path = path
|
|
54
|
+
break
|
|
55
|
+
|
|
56
|
+
if config_path and config_path.exists():
|
|
57
|
+
self._load_config(config_path)
|
|
58
|
+
else:
|
|
59
|
+
logger.warning("No compliance-levels.json found, using permissive validation")
|
|
60
|
+
|
|
61
|
+
def _load_config(self, config_path: Path):
|
|
62
|
+
"""Load compliance level configuration from JSON file."""
|
|
63
|
+
try:
|
|
64
|
+
with open(config_path, 'r', encoding='utf-8') as f:
|
|
65
|
+
config = json.load(f)
|
|
66
|
+
|
|
67
|
+
self.mode = config.get('mode', 'explicit_allowlist')
|
|
68
|
+
|
|
69
|
+
for level_data in config.get('levels', []):
|
|
70
|
+
level = ComplianceLevel(
|
|
71
|
+
name=level_data['name'],
|
|
72
|
+
description=level_data.get('description', ''),
|
|
73
|
+
aliases=level_data.get('aliases', []),
|
|
74
|
+
allowed_with=level_data.get('allowed_with', [level_data['name']])
|
|
75
|
+
)
|
|
76
|
+
self.levels[level.name] = level
|
|
77
|
+
|
|
78
|
+
# Map canonical name to itself
|
|
79
|
+
self._name_to_canonical[level.name] = level.name
|
|
80
|
+
|
|
81
|
+
# Map aliases to canonical name
|
|
82
|
+
for alias in level.aliases:
|
|
83
|
+
self._name_to_canonical[alias] = level.name
|
|
84
|
+
|
|
85
|
+
logger.info(f"Loaded {len(self.levels)} compliance levels from {config_path}")
|
|
86
|
+
logger.debug(f"Compliance levels: {list(self.levels.keys())}")
|
|
87
|
+
|
|
88
|
+
except Exception as e:
|
|
89
|
+
logger.error(f"Error loading compliance-levels.json: {e}")
|
|
90
|
+
# Continue with empty config for permissive validation
|
|
91
|
+
|
|
92
|
+
def get_canonical_name(self, name: Optional[str]) -> Optional[str]:
|
|
93
|
+
"""Get the canonical name for a compliance level (resolves aliases).
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
name: Compliance level name or alias
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
Canonical name, or None if not found
|
|
100
|
+
"""
|
|
101
|
+
if not name:
|
|
102
|
+
return None
|
|
103
|
+
return self._name_to_canonical.get(name)
|
|
104
|
+
|
|
105
|
+
def validate_compliance_level(self, level_name: Optional[str], context: str = "") -> Optional[str]:
    """Resolve and validate a compliance level name or alias.

    Args:
        level_name: Level name or alias to check. Empty/None short-circuits to None.
        context: Free-text suffix for log messages (e.g., "MCP server 'calculator'").

    Returns:
        The canonical name when *level_name* is a known name or alias.
        If no levels are configured at all, *level_name* is returned unchanged
        (permissive mode). Otherwise None, after logging a warning that lists
        the valid levels.
    """
    if not level_name:
        return None

    resolved = self.get_canonical_name(level_name)

    if resolved is not None:
        # Known name or alias; note when an alias had to be translated
        if resolved != level_name:
            logger.debug(f"Resolved alias '{level_name}' to '{resolved}' {context}")
        return resolved

    # Nothing configured: accept whatever the caller passed
    if not self.levels:
        return level_name

    known_levels = ', '.join(self.levels)
    logger.warning(
        f"Invalid compliance level '{level_name}' {context}. "
        f"Valid levels: {known_levels}. "
        "Setting to None."
    )
    return None
|
|
138
|
+
|
|
139
|
+
def is_accessible(self, user_level: Optional[str], resource_level: Optional[str]) -> bool:
    """Check if a resource at resource_level is accessible given user_level.

    In explicit allowlist mode:
    - Each level defines, via its ``allowed_with`` list, which other levels
      can be used together with it.
    - For example, HIPAA might allow HIPAA and SOC2, but not Public.
    - None (unset) is accessible by all and can access all.

    Names and aliases are interchangeable. Both arguments and every
    ``allowed_with`` entry are resolved to canonical names before comparing.
    A config that lists an alias in ``allowed_with`` therefore behaves the
    same as one that lists the canonical name.

    Args:
        user_level: User's selected compliance level (name or alias).
        resource_level: Resource's compliance level (name or alias).

    Returns:
        True if the resource is accessible, False otherwise. Unset or
        unrecognised levels on either side are treated permissively (True).
    """
    # If either is None/unset, resource is accessible (backward compatibility)
    if not user_level or not resource_level:
        return True

    user_canonical = self.get_canonical_name(user_level)
    resource_canonical = self.get_canonical_name(resource_level)

    # If we don't have level info, be permissive
    if not user_canonical or not resource_canonical:
        return True

    user_level_obj = self.levels.get(user_canonical)
    if not user_level_obj:
        return True

    # allowed_with is taken verbatim from the config and may name aliases.
    # Resolve each entry so it compares correctly with resource_canonical.
    # Entries that are not a known name or alias are kept as written.
    allowed = {
        self.get_canonical_name(entry) or entry
        for entry in user_level_obj.allowed_with
    }
    return resource_canonical in allowed
|
|
174
|
+
|
|
175
|
+
def get_accessible_levels(self, user_level: Optional[str]) -> Set[str]:
    """Get all compliance levels accessible to a user.

    Args:
        user_level: User's selected compliance level (name or alias).

    Returns:
        Set of accessible compliance level names (canonical):

        - Every defined level, when *user_level* is unset or unrecognised.
        - An empty set, when no levels are configured.
        - Otherwise, the user level's ``allowed_with`` entries with aliases
          resolved to canonical names. Entries that are not a known name or
          alias are kept as written.
    """
    # No user level or no config: every defined level (possibly none)
    if not user_level or not self.levels:
        return set(self.levels)

    user_canonical = self.get_canonical_name(user_level)
    if not user_canonical or user_canonical not in self.levels:
        return set(self.levels)

    # allowed_with is taken verbatim from the config and may contain aliases;
    # resolve them so the result really consists of canonical names.
    return {
        self.get_canonical_name(entry) or entry
        for entry in self.levels[user_canonical].allowed_with
    }
|
|
196
|
+
|
|
197
|
+
def get_all_levels(self) -> List[str]:
    """Return every configured canonical level name.

    Returns:
        The keys of ``self.levels`` in insertion (definition) order.
    """
    return [*self.levels]
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# Global instance (created lazily by get_compliance_manager)
_compliance_manager: Optional[ComplianceLevelManager] = None


def get_compliance_manager() -> ComplianceLevelManager:
    """Return the module-level ComplianceLevelManager, creating it on first call."""
    global _compliance_manager
    if _compliance_manager is not None:
        return _compliance_manager
    _compliance_manager = ComplianceLevelManager()
    return _compliance_manager
|