atlas-chat 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +40 -0
- atlas/application/__init__.py +7 -0
- atlas/application/chat/__init__.py +7 -0
- atlas/application/chat/agent/__init__.py +10 -0
- atlas/application/chat/agent/act_loop.py +179 -0
- atlas/application/chat/agent/factory.py +142 -0
- atlas/application/chat/agent/protocols.py +46 -0
- atlas/application/chat/agent/react_loop.py +338 -0
- atlas/application/chat/agent/think_act_loop.py +171 -0
- atlas/application/chat/approval_manager.py +151 -0
- atlas/application/chat/elicitation_manager.py +191 -0
- atlas/application/chat/events/__init__.py +1 -0
- atlas/application/chat/events/agent_event_relay.py +112 -0
- atlas/application/chat/modes/__init__.py +1 -0
- atlas/application/chat/modes/agent.py +125 -0
- atlas/application/chat/modes/plain.py +74 -0
- atlas/application/chat/modes/rag.py +81 -0
- atlas/application/chat/modes/tools.py +179 -0
- atlas/application/chat/orchestrator.py +213 -0
- atlas/application/chat/policies/__init__.py +1 -0
- atlas/application/chat/policies/tool_authorization.py +99 -0
- atlas/application/chat/preprocessors/__init__.py +1 -0
- atlas/application/chat/preprocessors/message_builder.py +92 -0
- atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
- atlas/application/chat/service.py +454 -0
- atlas/application/chat/utilities/__init__.py +6 -0
- atlas/application/chat/utilities/error_handler.py +367 -0
- atlas/application/chat/utilities/event_notifier.py +546 -0
- atlas/application/chat/utilities/file_processor.py +613 -0
- atlas/application/chat/utilities/tool_executor.py +789 -0
- atlas/atlas_chat_cli.py +347 -0
- atlas/atlas_client.py +238 -0
- atlas/core/__init__.py +0 -0
- atlas/core/auth.py +205 -0
- atlas/core/authorization_manager.py +27 -0
- atlas/core/capabilities.py +123 -0
- atlas/core/compliance.py +215 -0
- atlas/core/domain_whitelist.py +147 -0
- atlas/core/domain_whitelist_middleware.py +82 -0
- atlas/core/http_client.py +28 -0
- atlas/core/log_sanitizer.py +102 -0
- atlas/core/metrics_logger.py +59 -0
- atlas/core/middleware.py +131 -0
- atlas/core/otel_config.py +242 -0
- atlas/core/prompt_risk.py +200 -0
- atlas/core/rate_limit.py +0 -0
- atlas/core/rate_limit_middleware.py +64 -0
- atlas/core/security_headers_middleware.py +51 -0
- atlas/domain/__init__.py +37 -0
- atlas/domain/chat/__init__.py +1 -0
- atlas/domain/chat/dtos.py +85 -0
- atlas/domain/errors.py +96 -0
- atlas/domain/messages/__init__.py +12 -0
- atlas/domain/messages/models.py +160 -0
- atlas/domain/rag_mcp_service.py +664 -0
- atlas/domain/sessions/__init__.py +7 -0
- atlas/domain/sessions/models.py +36 -0
- atlas/domain/unified_rag_service.py +371 -0
- atlas/infrastructure/__init__.py +10 -0
- atlas/infrastructure/app_factory.py +135 -0
- atlas/infrastructure/events/__init__.py +1 -0
- atlas/infrastructure/events/cli_event_publisher.py +140 -0
- atlas/infrastructure/events/websocket_publisher.py +140 -0
- atlas/infrastructure/sessions/in_memory_repository.py +56 -0
- atlas/infrastructure/transport/__init__.py +7 -0
- atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
- atlas/init_cli.py +226 -0
- atlas/interfaces/__init__.py +15 -0
- atlas/interfaces/events.py +134 -0
- atlas/interfaces/llm.py +54 -0
- atlas/interfaces/rag.py +40 -0
- atlas/interfaces/sessions.py +75 -0
- atlas/interfaces/tools.py +57 -0
- atlas/interfaces/transport.py +24 -0
- atlas/main.py +564 -0
- atlas/mcp/api_key_demo/README.md +76 -0
- atlas/mcp/api_key_demo/main.py +172 -0
- atlas/mcp/api_key_demo/run.sh +56 -0
- atlas/mcp/basictable/main.py +147 -0
- atlas/mcp/calculator/main.py +149 -0
- atlas/mcp/code-executor/execution_engine.py +98 -0
- atlas/mcp/code-executor/execution_environment.py +95 -0
- atlas/mcp/code-executor/main.py +528 -0
- atlas/mcp/code-executor/result_processing.py +276 -0
- atlas/mcp/code-executor/script_generation.py +195 -0
- atlas/mcp/code-executor/security_checker.py +140 -0
- atlas/mcp/corporate_cars/main.py +437 -0
- atlas/mcp/csv_reporter/main.py +545 -0
- atlas/mcp/duckduckgo/main.py +182 -0
- atlas/mcp/elicitation_demo/README.md +171 -0
- atlas/mcp/elicitation_demo/main.py +262 -0
- atlas/mcp/env-demo/README.md +158 -0
- atlas/mcp/env-demo/main.py +199 -0
- atlas/mcp/file_size_test/main.py +284 -0
- atlas/mcp/filesystem/main.py +348 -0
- atlas/mcp/image_demo/main.py +113 -0
- atlas/mcp/image_demo/requirements.txt +4 -0
- atlas/mcp/logging_demo/README.md +72 -0
- atlas/mcp/logging_demo/main.py +103 -0
- atlas/mcp/many_tools_demo/main.py +50 -0
- atlas/mcp/order_database/__init__.py +0 -0
- atlas/mcp/order_database/main.py +369 -0
- atlas/mcp/order_database/signal_data.csv +1001 -0
- atlas/mcp/pdfbasic/main.py +394 -0
- atlas/mcp/pptx_generator/main.py +760 -0
- atlas/mcp/pptx_generator/requirements.txt +13 -0
- atlas/mcp/pptx_generator/run_test.sh +1 -0
- atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
- atlas/mcp/progress_demo/main.py +167 -0
- atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
- atlas/mcp/progress_updates_demo/README.md +120 -0
- atlas/mcp/progress_updates_demo/main.py +497 -0
- atlas/mcp/prompts/main.py +222 -0
- atlas/mcp/public_demo/main.py +189 -0
- atlas/mcp/sampling_demo/README.md +169 -0
- atlas/mcp/sampling_demo/main.py +234 -0
- atlas/mcp/thinking/main.py +77 -0
- atlas/mcp/tool_planner/main.py +240 -0
- atlas/mcp/ui-demo/badmesh.png +0 -0
- atlas/mcp/ui-demo/main.py +383 -0
- atlas/mcp/ui-demo/templates/button_demo.html +32 -0
- atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
- atlas/mcp/ui-demo/templates/form_demo.html +28 -0
- atlas/mcp/username-override-demo/README.md +320 -0
- atlas/mcp/username-override-demo/main.py +308 -0
- atlas/modules/__init__.py +0 -0
- atlas/modules/config/__init__.py +34 -0
- atlas/modules/config/cli.py +231 -0
- atlas/modules/config/config_manager.py +1096 -0
- atlas/modules/file_storage/__init__.py +22 -0
- atlas/modules/file_storage/cli.py +330 -0
- atlas/modules/file_storage/content_extractor.py +290 -0
- atlas/modules/file_storage/manager.py +295 -0
- atlas/modules/file_storage/mock_s3_client.py +402 -0
- atlas/modules/file_storage/s3_client.py +417 -0
- atlas/modules/llm/__init__.py +19 -0
- atlas/modules/llm/caller.py +287 -0
- atlas/modules/llm/litellm_caller.py +675 -0
- atlas/modules/llm/models.py +19 -0
- atlas/modules/mcp_tools/__init__.py +17 -0
- atlas/modules/mcp_tools/client.py +2123 -0
- atlas/modules/mcp_tools/token_storage.py +556 -0
- atlas/modules/prompts/prompt_provider.py +130 -0
- atlas/modules/rag/__init__.py +24 -0
- atlas/modules/rag/atlas_rag_client.py +336 -0
- atlas/modules/rag/client.py +129 -0
- atlas/routes/admin_routes.py +865 -0
- atlas/routes/config_routes.py +484 -0
- atlas/routes/feedback_routes.py +361 -0
- atlas/routes/files_routes.py +274 -0
- atlas/routes/health_routes.py +40 -0
- atlas/routes/mcp_auth_routes.py +223 -0
- atlas/server_cli.py +164 -0
- atlas/tests/conftest.py +20 -0
- atlas/tests/integration/test_mcp_auth_integration.py +152 -0
- atlas/tests/manual_test_sampling.py +87 -0
- atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
- atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
- atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
- atlas/tests/test_agent_roa.py +135 -0
- atlas/tests/test_app_factory_smoke.py +47 -0
- atlas/tests/test_approval_manager.py +439 -0
- atlas/tests/test_atlas_client.py +188 -0
- atlas/tests/test_atlas_rag_client.py +447 -0
- atlas/tests/test_atlas_rag_integration.py +224 -0
- atlas/tests/test_attach_file_flow.py +287 -0
- atlas/tests/test_auth_utils.py +165 -0
- atlas/tests/test_backend_public_url.py +185 -0
- atlas/tests/test_banner_logging.py +287 -0
- atlas/tests/test_capability_tokens_and_injection.py +203 -0
- atlas/tests/test_compliance_level.py +54 -0
- atlas/tests/test_compliance_manager.py +253 -0
- atlas/tests/test_config_manager.py +617 -0
- atlas/tests/test_config_manager_paths.py +12 -0
- atlas/tests/test_core_auth.py +18 -0
- atlas/tests/test_core_utils.py +190 -0
- atlas/tests/test_docker_env_sync.py +202 -0
- atlas/tests/test_domain_errors.py +329 -0
- atlas/tests/test_domain_whitelist.py +359 -0
- atlas/tests/test_elicitation_manager.py +408 -0
- atlas/tests/test_elicitation_routing.py +296 -0
- atlas/tests/test_env_demo_server.py +88 -0
- atlas/tests/test_error_classification.py +113 -0
- atlas/tests/test_error_flow_integration.py +116 -0
- atlas/tests/test_feedback_routes.py +333 -0
- atlas/tests/test_file_content_extraction.py +1134 -0
- atlas/tests/test_file_extraction_routes.py +158 -0
- atlas/tests/test_file_library.py +107 -0
- atlas/tests/test_file_manager_unit.py +18 -0
- atlas/tests/test_health_route.py +49 -0
- atlas/tests/test_http_client_stub.py +8 -0
- atlas/tests/test_imports_smoke.py +30 -0
- atlas/tests/test_interfaces_llm_response.py +9 -0
- atlas/tests/test_issue_access_denied_fix.py +136 -0
- atlas/tests/test_llm_env_expansion.py +836 -0
- atlas/tests/test_log_level_sensitive_data.py +285 -0
- atlas/tests/test_mcp_auth_routes.py +341 -0
- atlas/tests/test_mcp_client_auth.py +331 -0
- atlas/tests/test_mcp_data_injection.py +270 -0
- atlas/tests/test_mcp_get_authorized_servers.py +95 -0
- atlas/tests/test_mcp_hot_reload.py +512 -0
- atlas/tests/test_mcp_image_content.py +424 -0
- atlas/tests/test_mcp_logging.py +172 -0
- atlas/tests/test_mcp_progress_updates.py +313 -0
- atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
- atlas/tests/test_mcp_prompts_server.py +39 -0
- atlas/tests/test_mcp_tool_result_parsing.py +296 -0
- atlas/tests/test_metrics_logger.py +56 -0
- atlas/tests/test_middleware_auth.py +379 -0
- atlas/tests/test_prompt_risk_and_acl.py +141 -0
- atlas/tests/test_rag_mcp_aggregator.py +204 -0
- atlas/tests/test_rag_mcp_service.py +224 -0
- atlas/tests/test_rate_limit_middleware.py +45 -0
- atlas/tests/test_routes_config_smoke.py +60 -0
- atlas/tests/test_routes_files_download_token.py +41 -0
- atlas/tests/test_routes_files_health.py +18 -0
- atlas/tests/test_runtime_imports.py +53 -0
- atlas/tests/test_sampling_integration.py +482 -0
- atlas/tests/test_security_admin_routes.py +61 -0
- atlas/tests/test_security_capability_tokens.py +65 -0
- atlas/tests/test_security_file_stats_scope.py +21 -0
- atlas/tests/test_security_header_injection.py +191 -0
- atlas/tests/test_security_headers_and_filename.py +63 -0
- atlas/tests/test_shared_session_repository.py +101 -0
- atlas/tests/test_system_prompt_loading.py +181 -0
- atlas/tests/test_token_storage.py +505 -0
- atlas/tests/test_tool_approval_config.py +93 -0
- atlas/tests/test_tool_approval_utils.py +356 -0
- atlas/tests/test_tool_authorization_group_filtering.py +223 -0
- atlas/tests/test_tool_details_in_config.py +108 -0
- atlas/tests/test_tool_planner.py +300 -0
- atlas/tests/test_unified_rag_service.py +398 -0
- atlas/tests/test_username_override_in_approval.py +258 -0
- atlas/tests/test_websocket_auth_header.py +168 -0
- atlas/version.py +6 -0
- atlas_chat-0.1.0.data/data/.env.example +253 -0
- atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
- atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
- atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
- atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
- atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
- atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
- atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
- atlas_chat-0.1.0.dist-info/METADATA +236 -0
- atlas_chat-0.1.0.dist-info/RECORD +250 -0
- atlas_chat-0.1.0.dist-info/WHEEL +5 -0
- atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
- atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Domain models for sessions."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
from uuid import UUID, uuid4
|
|
7
|
+
|
|
8
|
+
from ..messages.models import ConversationHistory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Session:
    """A chat session, optionally tied to a user's email.

    Both timestamps default to the current UTC time at construction, and
    ``history`` / ``context`` receive fresh per-instance containers.
    """

    id: UUID = field(default_factory=uuid4)
    user_email: Optional[str] = None
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    history: ConversationHistory = field(default_factory=ConversationHistory)
    context: Dict[str, Any] = field(default_factory=dict)
    active: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the session into a plain dictionary.

        The UUID is rendered as a string, both timestamps via ``isoformat()``,
        and the history through ``ConversationHistory.to_dict()``. ``context``
        is placed in the result by reference (it is not copied).
        """
        return dict(
            id=str(self.id),
            user_email=self.user_email,
            created_at=self.created_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            history=self.history.to_dict(),
            context=self.context,
            active=self.active,
        )

    def update_timestamp(self) -> None:
        """Stamp ``updated_at`` with the current UTC time."""
        now = datetime.now(timezone.utc)
        self.updated_at = now
|
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
"""Unified RAG Service that aggregates HTTP and MCP RAG sources.
|
|
2
|
+
|
|
3
|
+
This service provides a single interface for:
|
|
4
|
+
- Discovering data sources across all configured RAG backends
|
|
5
|
+
- Querying RAG sources with automatic routing based on source type
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
from atlas.core.compliance import get_compliance_manager
|
|
14
|
+
from atlas.core.log_sanitizer import sanitize_for_logging
|
|
15
|
+
from atlas.modules.config.config_manager import ConfigManager, RAGSourceConfig, resolve_env_var
|
|
16
|
+
from atlas.modules.rag.atlas_rag_client import AtlasRAGClient
|
|
17
|
+
from atlas.modules.rag.client import RAGResponse
|
|
18
|
+
|
|
19
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class UnifiedRAGService:
    """Aggregates RAG discovery and querying across HTTP and MCP sources.

    Source definitions come from ``config_manager.rag_sources_config.sources``
    (a mapping of source name -> ``RAGSourceConfig``). HTTP sources are served
    through cached ``AtlasRAGClient`` instances; MCP sources are delegated to
    the injected ``RAGMCPService``.
    """

    def __init__(
        self,
        config_manager: ConfigManager,
        mcp_manager: Optional[Any] = None,
        auth_check_func: Optional[Callable] = None,
        rag_mcp_service: Optional[Any] = None,
    ) -> None:
        """Initialize the unified RAG service.

        Args:
            config_manager: Configuration manager for loading RAG sources config.
            mcp_manager: MCP tool manager for MCP-based RAG sources (stored;
                not used directly by this class).
            auth_check_func: Async callable ``(username, group) -> bool`` used
                for group checks. When ``None``, group restrictions are not
                enforced.
            rag_mcp_service: Optional RAGMCPService instance for MCP RAG queries.
        """
        self.config_manager = config_manager
        self.mcp_manager = mcp_manager
        self.auth_check_func = auth_check_func
        self.rag_mcp_service = rag_mcp_service

        # Cache of HTTP RAG clients by source name; dropped by invalidate_cache().
        self._http_clients: Dict[str, AtlasRAGClient] = {}

    def _get_http_client(self, source_name: str, config: RAGSourceConfig) -> AtlasRAGClient:
        """Return the cached HTTP RAG client for ``source_name``, creating it on first use.

        ``config.url`` (required) and ``config.bearer_token`` (optional) are passed
        through ``resolve_env_var`` only when the client is created, so config
        changes take effect only after ``invalidate_cache`` drops the client.
        """
        client = self._http_clients.get(source_name)
        if client is None:
            # Resolve environment variables in config
            url = resolve_env_var(config.url, required=True)
            bearer_token = resolve_env_var(config.bearer_token, required=False)

            client = AtlasRAGClient(
                base_url=url,
                bearer_token=bearer_token,
                default_model=config.default_model or "openai/gpt-oss-120b",
                top_k=config.top_k,
                timeout=config.timeout,
            )
            self._http_clients[source_name] = client
            logger.info("Created HTTP RAG client for source: %s", sanitize_for_logging(source_name))

        return client

    async def _is_user_authorized(self, username: str, groups: List[str]) -> bool:
        """Return True if ``username`` may use a source restricted to ``groups``.

        An empty ``groups`` list means the source is unrestricted, and a missing
        ``auth_check_func`` disables enforcement. Otherwise, membership in any
        one of ``groups`` is sufficient.
        """
        if not groups or not self.auth_check_func:
            return True

        for group in groups:
            if await self.auth_check_func(username, group):
                return True
        return False

    async def discover_data_sources(
        self,
        username: str,
        user_compliance_level: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Discover data sources across all configured RAG backends.

        Disabled sources, sources the user is not authorized for (by group),
        and sources whose compliance level is not accessible from
        ``user_compliance_level`` are skipped. Errors for one source are logged
        and do not stop discovery of the others.

        Returns a list of RAG servers with their sources in the format expected by the UI:
        [
            {
                "server": "atlas_rag",
                "displayName": "ATLAS RAG",
                "icon": "database",
                "complianceLevel": "Internal",
                "sources": [
                    {"id": "technical-docs", "name": "technical-docs", ...}
                ]
            }
        ]
        """
        rag_servers: List[Dict[str, Any]] = []
        rag_config = self.config_manager.rag_sources_config

        for source_name, source_config in rag_config.sources.items():
            try:
                if not source_config.enabled:
                    continue

                # Check group authorization
                if not await self._is_user_authorized(username, source_config.groups):
                    logger.debug(
                        "User %s not authorized for RAG source %s (groups: %s)",
                        sanitize_for_logging(username),
                        sanitize_for_logging(source_name),
                        source_config.groups,
                    )
                    continue

                # Compliance filtering applies only when both sides declare a level.
                if user_compliance_level and source_config.compliance_level:
                    compliance_mgr = get_compliance_manager()
                    if not compliance_mgr.is_accessible(
                        user_level=user_compliance_level,
                        resource_level=source_config.compliance_level,
                    ):
                        logger.info(
                            "Skipping RAG source %s due to compliance level mismatch (user: %s, source: %s)",
                            sanitize_for_logging(source_name),
                            sanitize_for_logging(user_compliance_level),
                            sanitize_for_logging(source_config.compliance_level),
                        )
                        continue

                if source_config.type == "http":
                    # Discover from HTTP RAG API
                    server_info = await self._discover_http_source(
                        source_name, source_config, username
                    )
                    if server_info:
                        rag_servers.append(server_info)

                elif source_config.type == "mcp":
                    # MCP sources from rag-sources.json are handled by RAGMCPService
                    # which reads them via config_manager.rag_mcp_config
                    logger.debug(
                        "Skipping MCP source %s (handled by RAGMCPService)",
                        sanitize_for_logging(source_name),
                    )

            except Exception as e:
                # One misbehaving source must not hide the remaining ones.
                logger.error(
                    "Error discovering RAG source %s, continuing with remaining sources: %s",
                    sanitize_for_logging(source_name),
                    e,
                )

        return rag_servers

    async def _discover_http_source(
        self,
        source_name: str,
        config: RAGSourceConfig,
        username: str,
    ) -> Optional[Dict[str, Any]]:
        """Discover data sources from an HTTP RAG API and shape them for the UI.

        Returns:
            A server entry (see ``discover_data_sources``), or ``None`` when the
            API reports no data sources or any error occurs. Errors are logged,
            not raised.
        """
        try:
            client = self._get_http_client(source_name, config)
            data_sources = await client.discover_data_sources(username)

            if not data_sources:
                logger.debug(
                    "No data sources found for HTTP source %s", sanitize_for_logging(source_name)
                )
                return None

            # Build UI sources array
            ui_sources = [
                {
                    "id": ds.name,
                    "name": ds.name,
                    "authRequired": True,
                    "selected": False,
                    "complianceLevel": ds.compliance_level,
                }
                for ds in data_sources
            ]

            return {
                "server": source_name,
                "displayName": config.display_name or source_name,
                "icon": config.icon or "database",
                "complianceLevel": config.compliance_level,
                "sources": ui_sources,
            }

        except Exception as e:
            logger.error(
                "Failed to discover HTTP source %s: %s", sanitize_for_logging(source_name), e
            )
            return None

    async def query_rag(
        self,
        username: str,
        qualified_data_source: str,
        messages: List[Dict],
    ) -> RAGResponse:
        """Query a RAG source.

        Args:
            username: The user making the query.
            qualified_data_source: Data source in format "server:source_id" (e.g., "atlas_rag:technical-docs").
                Only the first ':' separates server from source id. Unqualified ids
                are accepted only if ``_find_server_for_source`` resolves them,
                which it currently never does.
            messages: List of message dictionaries.

        Returns:
            RAGResponse with content and metadata.

        Raises:
            ValueError: If the server cannot be resolved or is not configured,
                the source is disabled, the user is not in any of its configured
                groups, an MCP source is queried without a RAGMCPService, or the
                source type is unknown.
        """
        logger.debug(
            "[RAG] query_rag called: qualified_source=%s, user=%s, message_count=%d",
            sanitize_for_logging(qualified_data_source),
            sanitize_for_logging(username),
            len(messages),
        )

        # Parse the qualified data source
        if ":" in qualified_data_source:
            server_name, source_id = qualified_data_source.split(":", 1)
        else:
            # No prefix - assume it's the source ID and try to find the server
            source_id = qualified_data_source
            server_name = self._find_server_for_source(source_id)
            if not server_name:
                logger.error(
                    "[RAG] Could not find server for source: %s", sanitize_for_logging(source_id)
                )
                raise ValueError(f"Could not find server for source: {source_id}")

        # Both parts originate from the caller, so sanitize before logging.
        safe_server = sanitize_for_logging(server_name)
        logger.info(
            "[RAG] Routing query: server=%s, source=%s, user=%s",
            safe_server,
            sanitize_for_logging(source_id),
            sanitize_for_logging(username),
        )

        source_config = self.config_manager.rag_sources_config.sources.get(server_name)
        if not source_config:
            logger.error("[RAG] Source not found in config: %s", safe_server)
            raise ValueError(f"RAG source not found: {server_name}")

        logger.debug(
            "[RAG] Source config: type=%s, enabled=%s, compliance_level=%s",
            source_config.type,
            source_config.enabled,
            source_config.compliance_level,
        )

        # Enforce the enabled flag and group restrictions that discovery also
        # applies, so a source hidden from the user cannot be queried by naming it.
        if not source_config.enabled:
            logger.warning("[RAG] Rejected query to disabled source: %s", safe_server)
            raise ValueError(f"RAG source is disabled: {server_name}")
        if not await self._is_user_authorized(username, source_config.groups):
            logger.warning(
                "[RAG] User %s not authorized for RAG source %s (groups: %s)",
                sanitize_for_logging(username),
                safe_server,
                source_config.groups,
            )
            raise ValueError(f"User not authorized for RAG source: {server_name}")

        if source_config.type == "http":
            return await self._query_http_source(
                username, server_name, source_id, source_config, messages
            )
        if source_config.type == "mcp":
            return await self._query_mcp_source(username, server_name, source_id, messages)
        raise ValueError(f"Unknown RAG source type: {source_config.type}")

    async def _query_http_source(
        self,
        username: str,
        server_name: str,
        source_id: str,
        source_config: RAGSourceConfig,
        messages: List[Dict],
    ) -> RAGResponse:
        """Send the query to the HTTP RAG API configured under ``server_name``."""
        logger.debug(
            "[RAG] Routing to HTTP RAG client for server: %s", sanitize_for_logging(server_name)
        )
        client = self._get_http_client(server_name, source_config)
        # Pass the unqualified source_id to the HTTP API
        response = await client.query_rag(username, source_id, messages)
        logger.debug(
            "[RAG] HTTP RAG response received: content_length=%d, has_metadata=%s",
            len(response.content) if response.content else 0,
            response.metadata is not None,
        )
        return response

    async def _query_mcp_source(
        self,
        username: str,
        server_name: str,
        source_id: str,
        messages: List[Dict],
    ) -> RAGResponse:
        """Answer the query via ``RAGMCPService.synthesize`` and adapt it to RAGResponse.

        Raises:
            ValueError: If no RAGMCPService was injected.
        """
        logger.debug(
            "[RAG] Routing to MCP RAG service for server: %s", sanitize_for_logging(server_name)
        )
        if not self.rag_mcp_service:
            logger.error("[RAG] RAGMCPService not configured for MCP RAG queries")
            raise ValueError("RAGMCPService not configured for MCP RAG queries")

        query = self._last_user_message(messages)
        logger.debug(
            "[RAG] MCP RAG query: server=%s, source=%s, query_preview=%s...",
            sanitize_for_logging(server_name),
            sanitize_for_logging(source_id),
            sanitize_for_logging(query[:100]) if query else "(empty)",
        )

        mcp_response = await self.rag_mcp_service.synthesize(
            username=username,
            query=query,
            sources=[f"{server_name}:{source_id}"],  # Format: "server:source_id"
        )

        # Treat explicit None values like missing keys instead of crashing on .get/.keys.
        meta_data = mcp_response.get("meta_data") or {}
        results = mcp_response.get("results") or {}
        logger.debug(
            "[RAG] MCP RAG response received: has_results=%s, meta_data_keys=%s",
            "results" in mcp_response,
            list(meta_data.keys()),
        )

        answer = results.get("answer", "No response from MCP RAG.")
        logger.debug(
            "[RAG] MCP RAG answer: length=%d, preview=%s...",
            len(answer) if answer else 0,
            sanitize_for_logging(answer[:200]) if answer else "(empty)",
        )

        metadata = self._build_mcp_metadata(meta_data, server_name)
        return RAGResponse(content=answer, metadata=metadata)

    @staticmethod
    def _last_user_message(messages: List[Dict]) -> Any:
        """Return the content of the most recent ``role == "user"`` message.

        Returns "" when there is no user message or its content is missing/None.
        Non-string content is returned as stored.
        """
        for msg in reversed(messages):
            if msg.get("role") == "user":
                return msg.get("content") or ""
        return ""

    @staticmethod
    def _build_mcp_metadata(meta_data: Dict[str, Any], server_name: str) -> Any:
        """Build RAGMetadata from an MCP ``meta_data`` dict, or None without providers.

        Each provider flagged ``used_synth`` becomes one DocumentMetadata entry.
        """
        providers_info = meta_data.get("providers")
        if not providers_info:
            return None

        from atlas.modules.rag.client import DocumentMetadata, RAGMetadata

        docs_found = [
            DocumentMetadata(
                source=provider_name,
                content_type="mcp_synthesis",
                confidence_score=1.0,
            )
            for provider_name, provider_info in providers_info.items()
            if provider_info.get("used_synth")
        ]
        return RAGMetadata(
            query_processing_time_ms=0,
            total_documents_searched=len(providers_info),
            documents_found=docs_found,
            data_source_name=server_name,
            retrieval_method="mcp_synthesis",
        )

    def _find_server_for_source(self, source_id: str) -> Optional[str]:
        """Try to find which server a source belongs to (best effort).

        Currently always returns None, so callers must pass a qualified
        "server:source_id" identifier.
        """
        return None

    def _enabled_sources_of_type(self, source_type: str) -> Dict[str, RAGSourceConfig]:
        """Return the enabled configured sources whose ``type`` equals ``source_type``."""
        rag_config = self.config_manager.rag_sources_config
        return {
            name: config
            for name, config in rag_config.sources.items()
            if config.type == source_type and config.enabled
        }

    def get_http_sources(self) -> Dict[str, RAGSourceConfig]:
        """Get all enabled HTTP-type RAG sources from config."""
        return self._enabled_sources_of_type("http")

    def get_mcp_sources(self) -> Dict[str, RAGSourceConfig]:
        """Get all enabled MCP-type RAG sources from config."""
        return self._enabled_sources_of_type("mcp")

    def invalidate_cache(self, source_name: Optional[str] = None) -> None:
        """Invalidate cached HTTP clients.

        Call this when configuration changes to ensure clients are recreated
        with updated settings (URLs, tokens, etc.).

        Args:
            source_name: Specific source to invalidate, or None to invalidate all.
                Any other value (including "") only drops that exact entry.
        """
        if source_name is not None:
            if source_name in self._http_clients:
                del self._http_clients[source_name]
                logger.info(
                    "Invalidated HTTP client cache for source: %s",
                    sanitize_for_logging(source_name),
                )
        else:
            self._http_clients.clear()
            logger.info("Invalidated all HTTP client caches")
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
# Public API of this module: `from ... import *` exports only the service class.
__all__ = ["UnifiedRAGService"]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Infrastructure layer - external adapters and wiring."""
|
|
2
|
+
|
|
3
|
+
from .app_factory import AppFactory, app_factory
|
|
4
|
+
from .transport.websocket_connection_adapter import WebSocketConnectionAdapter
|
|
5
|
+
|
|
6
|
+
# Names re-exported by the infrastructure package (same order as before).
__all__ = ["AppFactory", "app_factory", "WebSocketConnectionAdapter"]
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Application factory for dependency injection and wiring."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, Union
|
|
5
|
+
|
|
6
|
+
from atlas.application.chat.service import ChatService
|
|
7
|
+
from atlas.core.auth import is_user_in_group
|
|
8
|
+
from atlas.domain.rag_mcp_service import RAGMCPService
|
|
9
|
+
from atlas.domain.unified_rag_service import UnifiedRAGService
|
|
10
|
+
from atlas.infrastructure.sessions.in_memory_repository import InMemorySessionRepository
|
|
11
|
+
from atlas.interfaces.transport import ChatConnectionProtocol
|
|
12
|
+
from atlas.modules.config import ConfigManager
|
|
13
|
+
from atlas.modules.file_storage import FileManager, S3StorageClient
|
|
14
|
+
from atlas.modules.file_storage.mock_s3_client import MockS3StorageClient
|
|
15
|
+
from atlas.modules.llm.litellm_caller import LiteLLMCaller
|
|
16
|
+
from atlas.modules.mcp_tools import MCPToolManager
|
|
17
|
+
|
|
18
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AppFactory:
    """Application factory that wires dependencies (simple in-memory DI).

    ``__init__`` synchronously builds the shared components: config, MCP tool
    manager, optional RAG services, LLM caller, file storage and session
    repository. ``initialize()`` performs the async MCP setup separately.
    """

    def __init__(self) -> None:
        # Configuration
        self.config_manager = ConfigManager()

        # MCP tools manager
        self.mcp_tools = MCPToolManager()

        # RAG services stay None unless the RAG feature flag is enabled.
        self.rag_mcp_service: Optional[RAGMCPService] = None
        self.unified_rag_service: Optional[UnifiedRAGService] = None
        if self.config_manager.app_settings.feature_rag_enabled:
            # RAG MCP service for MCP-based RAG servers (created first so it can
            # be injected into the unified service below)
            self.rag_mcp_service = RAGMCPService(
                mcp_manager=self.mcp_tools,
                config_manager=self.config_manager,
                auth_check_func=is_user_in_group,
            )

            # Unified RAG service for HTTP and MCP RAG sources (configured via rag-sources.json)
            self.unified_rag_service = UnifiedRAGService(
                config_manager=self.config_manager,
                mcp_manager=self.mcp_tools,
                auth_check_func=is_user_in_group,
                rag_mcp_service=self.rag_mcp_service,
            )
            logger.info("RAG services initialized (FEATURE_RAG_ENABLED=true)")
        else:
            logger.info("RAG services disabled (FEATURE_RAG_ENABLED=false)")

        # LLM caller; rag_service receives None when RAG is disabled
        self.llm_caller = LiteLLMCaller(
            self.config_manager.llm_config,
            debug_mode=self.config_manager.app_settings.debug_mode,
            rag_service=self.unified_rag_service,
        )

        # File storage & manager: in-process mock or real S3-compatible client
        self.file_storage: Union[S3StorageClient, MockS3StorageClient]
        if self.config_manager.app_settings.use_mock_s3:
            logger.info("Using MockS3StorageClient (in-process, no Docker required)")
            self.file_storage = MockS3StorageClient()
        else:
            logger.info("Using S3StorageClient (MinIO/AWS S3)")
            self.file_storage = S3StorageClient()
        self.file_manager = FileManager(self.file_storage)

        # Shared session repository for all ChatService instances
        self.session_repository = InMemorySessionRepository()

        logger.info("AppFactory initialized")

    async def initialize(self) -> None:
        """Initialize async resources (MCP clients, tool discovery) for headless use.

        Runs client initialization, tool discovery and prompt discovery in order.
        The first failure stops the remaining steps and is logged, not raised.
        """
        try:
            await self.mcp_tools.initialize_clients()
            await self.mcp_tools.discover_tools()
            await self.mcp_tools.discover_prompts()
            logger.info("AppFactory async initialization complete")
        except Exception as e:
            # Best-effort by design: keep the traceback so the cause is diagnosable.
            logger.warning(
                "MCP initialization failed; continuing without tools: %s", e, exc_info=True
            )

    def create_chat_service(
        self, connection: Optional[ChatConnectionProtocol] = None
    ) -> ChatService:
        """Create a ChatService bound to ``connection`` using the shared components."""
        return ChatService(
            llm=self.llm_caller,
            tool_manager=self.mcp_tools,
            connection=connection,
            config_manager=self.config_manager,
            file_manager=self.file_manager,
            session_repository=self.session_repository,
        )

    def create_headless_chat_service(
        self, event_publisher=None
    ) -> ChatService:
        """Create a ChatService for headless/CLI use with a custom event publisher.

        No connection is attached. ``event_publisher`` is forwarded unchanged to
        ChatService.
        """
        return ChatService(
            llm=self.llm_caller,
            tool_manager=self.mcp_tools,
            connection=None,
            config_manager=self.config_manager,
            file_manager=self.file_manager,
            session_repository=self.session_repository,
            event_publisher=event_publisher,
        )

    # Accessors
    def get_config_manager(self) -> ConfigManager:  # noqa: D401
        return self.config_manager

    def get_llm_caller(self) -> LiteLLMCaller:  # noqa: D401
        return self.llm_caller

    def get_mcp_manager(self) -> MCPToolManager:  # noqa: D401
        return self.mcp_tools

    def get_rag_mcp_service(self) -> Optional[RAGMCPService]:  # noqa: D401
        return self.rag_mcp_service

    def get_unified_rag_service(self) -> Optional[UnifiedRAGService]:  # noqa: D401
        return self.unified_rag_service

    def get_file_storage(self) -> Union[S3StorageClient, MockS3StorageClient]:  # noqa: D401
        # May be the mock client when app_settings.use_mock_s3 is set.
        return self.file_storage

    def get_file_manager(self) -> FileManager:  # noqa: D401
        return self.file_manager
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Temporary global instance during migration away from singletons
|
|
135
|
+
# NOTE: created at import time, so importing this module runs AppFactory.__init__
# (constructs ConfigManager, MCPToolManager, LiteLLMCaller and a file-storage client).
app_factory = AppFactory()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Infrastructure event implementations."""
|