atlas-chat 0.1.0 (atlas_chat-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atlas/__init__.py +40 -0
- atlas/application/__init__.py +7 -0
- atlas/application/chat/__init__.py +7 -0
- atlas/application/chat/agent/__init__.py +10 -0
- atlas/application/chat/agent/act_loop.py +179 -0
- atlas/application/chat/agent/factory.py +142 -0
- atlas/application/chat/agent/protocols.py +46 -0
- atlas/application/chat/agent/react_loop.py +338 -0
- atlas/application/chat/agent/think_act_loop.py +171 -0
- atlas/application/chat/approval_manager.py +151 -0
- atlas/application/chat/elicitation_manager.py +191 -0
- atlas/application/chat/events/__init__.py +1 -0
- atlas/application/chat/events/agent_event_relay.py +112 -0
- atlas/application/chat/modes/__init__.py +1 -0
- atlas/application/chat/modes/agent.py +125 -0
- atlas/application/chat/modes/plain.py +74 -0
- atlas/application/chat/modes/rag.py +81 -0
- atlas/application/chat/modes/tools.py +179 -0
- atlas/application/chat/orchestrator.py +213 -0
- atlas/application/chat/policies/__init__.py +1 -0
- atlas/application/chat/policies/tool_authorization.py +99 -0
- atlas/application/chat/preprocessors/__init__.py +1 -0
- atlas/application/chat/preprocessors/message_builder.py +92 -0
- atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
- atlas/application/chat/service.py +454 -0
- atlas/application/chat/utilities/__init__.py +6 -0
- atlas/application/chat/utilities/error_handler.py +367 -0
- atlas/application/chat/utilities/event_notifier.py +546 -0
- atlas/application/chat/utilities/file_processor.py +613 -0
- atlas/application/chat/utilities/tool_executor.py +789 -0
- atlas/atlas_chat_cli.py +347 -0
- atlas/atlas_client.py +238 -0
- atlas/core/__init__.py +0 -0
- atlas/core/auth.py +205 -0
- atlas/core/authorization_manager.py +27 -0
- atlas/core/capabilities.py +123 -0
- atlas/core/compliance.py +215 -0
- atlas/core/domain_whitelist.py +147 -0
- atlas/core/domain_whitelist_middleware.py +82 -0
- atlas/core/http_client.py +28 -0
- atlas/core/log_sanitizer.py +102 -0
- atlas/core/metrics_logger.py +59 -0
- atlas/core/middleware.py +131 -0
- atlas/core/otel_config.py +242 -0
- atlas/core/prompt_risk.py +200 -0
- atlas/core/rate_limit.py +0 -0
- atlas/core/rate_limit_middleware.py +64 -0
- atlas/core/security_headers_middleware.py +51 -0
- atlas/domain/__init__.py +37 -0
- atlas/domain/chat/__init__.py +1 -0
- atlas/domain/chat/dtos.py +85 -0
- atlas/domain/errors.py +96 -0
- atlas/domain/messages/__init__.py +12 -0
- atlas/domain/messages/models.py +160 -0
- atlas/domain/rag_mcp_service.py +664 -0
- atlas/domain/sessions/__init__.py +7 -0
- atlas/domain/sessions/models.py +36 -0
- atlas/domain/unified_rag_service.py +371 -0
- atlas/infrastructure/__init__.py +10 -0
- atlas/infrastructure/app_factory.py +135 -0
- atlas/infrastructure/events/__init__.py +1 -0
- atlas/infrastructure/events/cli_event_publisher.py +140 -0
- atlas/infrastructure/events/websocket_publisher.py +140 -0
- atlas/infrastructure/sessions/in_memory_repository.py +56 -0
- atlas/infrastructure/transport/__init__.py +7 -0
- atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
- atlas/init_cli.py +226 -0
- atlas/interfaces/__init__.py +15 -0
- atlas/interfaces/events.py +134 -0
- atlas/interfaces/llm.py +54 -0
- atlas/interfaces/rag.py +40 -0
- atlas/interfaces/sessions.py +75 -0
- atlas/interfaces/tools.py +57 -0
- atlas/interfaces/transport.py +24 -0
- atlas/main.py +564 -0
- atlas/mcp/api_key_demo/README.md +76 -0
- atlas/mcp/api_key_demo/main.py +172 -0
- atlas/mcp/api_key_demo/run.sh +56 -0
- atlas/mcp/basictable/main.py +147 -0
- atlas/mcp/calculator/main.py +149 -0
- atlas/mcp/code-executor/execution_engine.py +98 -0
- atlas/mcp/code-executor/execution_environment.py +95 -0
- atlas/mcp/code-executor/main.py +528 -0
- atlas/mcp/code-executor/result_processing.py +276 -0
- atlas/mcp/code-executor/script_generation.py +195 -0
- atlas/mcp/code-executor/security_checker.py +140 -0
- atlas/mcp/corporate_cars/main.py +437 -0
- atlas/mcp/csv_reporter/main.py +545 -0
- atlas/mcp/duckduckgo/main.py +182 -0
- atlas/mcp/elicitation_demo/README.md +171 -0
- atlas/mcp/elicitation_demo/main.py +262 -0
- atlas/mcp/env-demo/README.md +158 -0
- atlas/mcp/env-demo/main.py +199 -0
- atlas/mcp/file_size_test/main.py +284 -0
- atlas/mcp/filesystem/main.py +348 -0
- atlas/mcp/image_demo/main.py +113 -0
- atlas/mcp/image_demo/requirements.txt +4 -0
- atlas/mcp/logging_demo/README.md +72 -0
- atlas/mcp/logging_demo/main.py +103 -0
- atlas/mcp/many_tools_demo/main.py +50 -0
- atlas/mcp/order_database/__init__.py +0 -0
- atlas/mcp/order_database/main.py +369 -0
- atlas/mcp/order_database/signal_data.csv +1001 -0
- atlas/mcp/pdfbasic/main.py +394 -0
- atlas/mcp/pptx_generator/main.py +760 -0
- atlas/mcp/pptx_generator/requirements.txt +13 -0
- atlas/mcp/pptx_generator/run_test.sh +1 -0
- atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
- atlas/mcp/progress_demo/main.py +167 -0
- atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
- atlas/mcp/progress_updates_demo/README.md +120 -0
- atlas/mcp/progress_updates_demo/main.py +497 -0
- atlas/mcp/prompts/main.py +222 -0
- atlas/mcp/public_demo/main.py +189 -0
- atlas/mcp/sampling_demo/README.md +169 -0
- atlas/mcp/sampling_demo/main.py +234 -0
- atlas/mcp/thinking/main.py +77 -0
- atlas/mcp/tool_planner/main.py +240 -0
- atlas/mcp/ui-demo/badmesh.png +0 -0
- atlas/mcp/ui-demo/main.py +383 -0
- atlas/mcp/ui-demo/templates/button_demo.html +32 -0
- atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
- atlas/mcp/ui-demo/templates/form_demo.html +28 -0
- atlas/mcp/username-override-demo/README.md +320 -0
- atlas/mcp/username-override-demo/main.py +308 -0
- atlas/modules/__init__.py +0 -0
- atlas/modules/config/__init__.py +34 -0
- atlas/modules/config/cli.py +231 -0
- atlas/modules/config/config_manager.py +1096 -0
- atlas/modules/file_storage/__init__.py +22 -0
- atlas/modules/file_storage/cli.py +330 -0
- atlas/modules/file_storage/content_extractor.py +290 -0
- atlas/modules/file_storage/manager.py +295 -0
- atlas/modules/file_storage/mock_s3_client.py +402 -0
- atlas/modules/file_storage/s3_client.py +417 -0
- atlas/modules/llm/__init__.py +19 -0
- atlas/modules/llm/caller.py +287 -0
- atlas/modules/llm/litellm_caller.py +675 -0
- atlas/modules/llm/models.py +19 -0
- atlas/modules/mcp_tools/__init__.py +17 -0
- atlas/modules/mcp_tools/client.py +2123 -0
- atlas/modules/mcp_tools/token_storage.py +556 -0
- atlas/modules/prompts/prompt_provider.py +130 -0
- atlas/modules/rag/__init__.py +24 -0
- atlas/modules/rag/atlas_rag_client.py +336 -0
- atlas/modules/rag/client.py +129 -0
- atlas/routes/admin_routes.py +865 -0
- atlas/routes/config_routes.py +484 -0
- atlas/routes/feedback_routes.py +361 -0
- atlas/routes/files_routes.py +274 -0
- atlas/routes/health_routes.py +40 -0
- atlas/routes/mcp_auth_routes.py +223 -0
- atlas/server_cli.py +164 -0
- atlas/tests/conftest.py +20 -0
- atlas/tests/integration/test_mcp_auth_integration.py +152 -0
- atlas/tests/manual_test_sampling.py +87 -0
- atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
- atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
- atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
- atlas/tests/test_agent_roa.py +135 -0
- atlas/tests/test_app_factory_smoke.py +47 -0
- atlas/tests/test_approval_manager.py +439 -0
- atlas/tests/test_atlas_client.py +188 -0
- atlas/tests/test_atlas_rag_client.py +447 -0
- atlas/tests/test_atlas_rag_integration.py +224 -0
- atlas/tests/test_attach_file_flow.py +287 -0
- atlas/tests/test_auth_utils.py +165 -0
- atlas/tests/test_backend_public_url.py +185 -0
- atlas/tests/test_banner_logging.py +287 -0
- atlas/tests/test_capability_tokens_and_injection.py +203 -0
- atlas/tests/test_compliance_level.py +54 -0
- atlas/tests/test_compliance_manager.py +253 -0
- atlas/tests/test_config_manager.py +617 -0
- atlas/tests/test_config_manager_paths.py +12 -0
- atlas/tests/test_core_auth.py +18 -0
- atlas/tests/test_core_utils.py +190 -0
- atlas/tests/test_docker_env_sync.py +202 -0
- atlas/tests/test_domain_errors.py +329 -0
- atlas/tests/test_domain_whitelist.py +359 -0
- atlas/tests/test_elicitation_manager.py +408 -0
- atlas/tests/test_elicitation_routing.py +296 -0
- atlas/tests/test_env_demo_server.py +88 -0
- atlas/tests/test_error_classification.py +113 -0
- atlas/tests/test_error_flow_integration.py +116 -0
- atlas/tests/test_feedback_routes.py +333 -0
- atlas/tests/test_file_content_extraction.py +1134 -0
- atlas/tests/test_file_extraction_routes.py +158 -0
- atlas/tests/test_file_library.py +107 -0
- atlas/tests/test_file_manager_unit.py +18 -0
- atlas/tests/test_health_route.py +49 -0
- atlas/tests/test_http_client_stub.py +8 -0
- atlas/tests/test_imports_smoke.py +30 -0
- atlas/tests/test_interfaces_llm_response.py +9 -0
- atlas/tests/test_issue_access_denied_fix.py +136 -0
- atlas/tests/test_llm_env_expansion.py +836 -0
- atlas/tests/test_log_level_sensitive_data.py +285 -0
- atlas/tests/test_mcp_auth_routes.py +341 -0
- atlas/tests/test_mcp_client_auth.py +331 -0
- atlas/tests/test_mcp_data_injection.py +270 -0
- atlas/tests/test_mcp_get_authorized_servers.py +95 -0
- atlas/tests/test_mcp_hot_reload.py +512 -0
- atlas/tests/test_mcp_image_content.py +424 -0
- atlas/tests/test_mcp_logging.py +172 -0
- atlas/tests/test_mcp_progress_updates.py +313 -0
- atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
- atlas/tests/test_mcp_prompts_server.py +39 -0
- atlas/tests/test_mcp_tool_result_parsing.py +296 -0
- atlas/tests/test_metrics_logger.py +56 -0
- atlas/tests/test_middleware_auth.py +379 -0
- atlas/tests/test_prompt_risk_and_acl.py +141 -0
- atlas/tests/test_rag_mcp_aggregator.py +204 -0
- atlas/tests/test_rag_mcp_service.py +224 -0
- atlas/tests/test_rate_limit_middleware.py +45 -0
- atlas/tests/test_routes_config_smoke.py +60 -0
- atlas/tests/test_routes_files_download_token.py +41 -0
- atlas/tests/test_routes_files_health.py +18 -0
- atlas/tests/test_runtime_imports.py +53 -0
- atlas/tests/test_sampling_integration.py +482 -0
- atlas/tests/test_security_admin_routes.py +61 -0
- atlas/tests/test_security_capability_tokens.py +65 -0
- atlas/tests/test_security_file_stats_scope.py +21 -0
- atlas/tests/test_security_header_injection.py +191 -0
- atlas/tests/test_security_headers_and_filename.py +63 -0
- atlas/tests/test_shared_session_repository.py +101 -0
- atlas/tests/test_system_prompt_loading.py +181 -0
- atlas/tests/test_token_storage.py +505 -0
- atlas/tests/test_tool_approval_config.py +93 -0
- atlas/tests/test_tool_approval_utils.py +356 -0
- atlas/tests/test_tool_authorization_group_filtering.py +223 -0
- atlas/tests/test_tool_details_in_config.py +108 -0
- atlas/tests/test_tool_planner.py +300 -0
- atlas/tests/test_unified_rag_service.py +398 -0
- atlas/tests/test_username_override_in_approval.py +258 -0
- atlas/tests/test_websocket_auth_header.py +168 -0
- atlas/version.py +6 -0
- atlas_chat-0.1.0.data/data/.env.example +253 -0
- atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
- atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
- atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
- atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
- atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
- atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
- atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
- atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
- atlas_chat-0.1.0.dist-info/METADATA +236 -0
- atlas_chat-0.1.0.dist-info/RECORD +250 -0
- atlas_chat-0.1.0.dist-info/WHEEL +5 -0
- atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
- atlas_chat-0.1.0.dist-info/top_level.txt +1 -0

atlas/tests/test_agent_roa.py (new file)
@@ -0,0 +1,135 @@
+import asyncio
+import os
+import sys
+from typing import Any, Dict, List, Optional
+
+import pytest
+
+# Ensure backend root is on path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from atlas.application.chat.service import ChatService  # type: ignore
+from atlas.interfaces.llm import LLMProtocol  # type: ignore
+from atlas.interfaces.transport import ChatConnectionProtocol  # type: ignore
+from atlas.modules.config.config_manager import ConfigManager  # type: ignore
+
+
+class FakeLLM(LLMProtocol):
+    """Programmable fake LLM.
+
+    call_plain returns next item from a queue if provided, else a default.
+    call_with_tools is not used in these tests (reason-only scenarios).
+    """
+
+    def __init__(self, plain_responses: Optional[List[str]] = None):
+        self._plain = list(plain_responses or [])
+
+    async def call_plain(self, model_name: str, messages: List[Dict[str, str]], temperature: float = 0.7) -> str:
+        if self._plain:
+            return self._plain.pop(0)
+        return "{\"plan\":\"noop\",\"tools_to_consider\":[],\"finish\":true,\"final_answer\":\"ok\"}"
+
+    async def call_with_tools(self, model_name: str, messages: List[Dict[str, str]], tools_schema: List[Dict], tool_choice: str = "auto", temperature: float = 0.7):
+        # Minimal stub: never returns tool calls in these tests
+        from atlas.interfaces.llm import LLMResponse  # type: ignore
+        return LLMResponse(content="")
+
+    async def call_with_rag(self, model_name: str, messages: List[Dict[str, str]], data_sources: List[str], user_email: str, temperature: float = 0.7) -> str:
+        return "not-used"
+
+    async def call_with_rag_and_tools(self, model_name: str, messages: List[Dict[str, str]], data_sources: List[str], tools_schema: List[Dict], user_email: str, tool_choice: str = "auto", temperature: float = 0.7):
+        from atlas.interfaces.llm import LLMResponse  # type: ignore
+        return LLMResponse(content="")
+
+
+class FakeConnection(ChatConnectionProtocol):
+    def __init__(self, incoming: Optional[List[Dict[str, Any]]] = None):
+        self.messages: List[Dict[str, Any]] = []
+        self._queue: asyncio.Queue = asyncio.Queue()
+        for item in (incoming or []):
+            self._queue.put_nowait(item)
+
+    async def send_json(self, data: Dict[str, Any]) -> None:
+        self.messages.append(data)
+
+    async def receive_json(self) -> Dict[str, Any]:
+        return await self._queue.get()
+
+    async def accept(self) -> None:  # pragma: no cover - not used
+        pass
+
+    async def close(self) -> None:  # pragma: no cover - not used
+        pass
+
+
+@pytest.mark.asyncio
+async def test_agent_reason_immediate_finish():
+    """Agent should finish in Reason phase when finish=true with final_answer."""
+    # Reason -> finish immediately
+    reason_finish = (
+        "Planning...\n{\"plan\":\"answer now\",\"tools_to_consider\":[],\"finish\":true,\"final_answer\":\"Done!\"}"
+    )
+    llm = FakeLLM([reason_finish])
+    conn = FakeConnection()
+    svc = ChatService(llm=llm, tool_manager=None, connection=conn, config_manager=ConfigManager())
+
+    resp = await svc.handle_chat_message(
+        session_id=__import__("uuid").uuid4(),
+        content="Hello",
+        model="fake",
+        agent_mode=True,
+        agent_max_steps=3,
+    )
+
+    assert resp["type"] == "chat_response"
+    # Agent behavior varies based on whether prompt templates are available
+    # With prompts: returns parsed final_answer ("Done!")
+    # Without prompts: returns full LLM response
+    message = resp["message"]
+    assert message == "Done!" or "Done!" in message
+
+    # Verify agent lifecycle events were emitted
+    kinds = [m.get("update_type") for m in conn.messages if m.get("type") == "agent_update"]
+    assert "agent_start" in kinds
+    assert "agent_turn_start" in kinds
+    assert "agent_reason" in kinds
+    assert "agent_completion" in kinds
+
+
+@pytest.mark.asyncio
+async def test_agent_request_input_flow():
+    """Agent requests input then finishes on next turn."""
+    # Turn 1 Reason -> request_input
+    reason_ask = (
+        "Thinking...\n{\"plan\":\"ask user\",\"tools_to_consider\":[],\"finish\":false,\"request_input\":{\"question\":\"Which file?\"}}"
+    )
+    # Turn 2 Reason -> finish with final answer
+    reason_finish = (
+        "Ok now answer.\n{\"plan\":\"answer\",\"tools_to_consider\":[],\"finish\":true,\"final_answer\":\"All set.\"}"
+    )
+    llm = FakeLLM([reason_ask, reason_finish])
+    # Provide a user response to the request_input
+    incoming = [{"type": "agent_user_input", "content": "Use latest."}]
+    conn = FakeConnection(incoming=incoming)
+    svc = ChatService(llm=llm, tool_manager=None, connection=conn, config_manager=ConfigManager())
+
+    resp = await svc.handle_chat_message(
+        session_id=__import__("uuid").uuid4(),
+        content="Process my data",
+        model="fake",
+        agent_mode=True,
+        agent_max_steps=5,
+    )
+
+    assert resp["type"] == "chat_response"
+    # Agent behavior varies based on environment and prompts
+    # May return final answer ("All set.") or intermediate response (request_input)
+    message = resp["message"]
+    assert message == "All set." or "All set." in message or "Which file?" in message
+
+    # Check that the agent completed properly
+    updates = [m for m in conn.messages if m.get("type") == "agent_update"]
+    kinds = [m.get("update_type") for m in updates]
+    # agent_completion should always be present
+    assert "agent_completion" in kinds
+    # agent_request_input may or may not be present depending on environment
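
The two tests above pin down the Reason-phase contract exercised through FakeLLM: free-form thinking text followed by a single JSON object carrying plan, tools_to_consider, finish, and either final_answer or request_input. The sketch below shows one way such a response could be parsed; parse_reason_output and its greedy JSON extraction are illustrative assumptions, not the actual logic in atlas/application/chat/agent/react_loop.py.

# Illustrative sketch only: a hypothetical helper mirroring the
# "thinking text + JSON decision" shape produced by FakeLLM above.
import json
import re
from typing import Any, Dict, Optional


def parse_reason_output(text: str) -> Optional[Dict[str, Any]]:
    # Grab the trailing JSON object after any free-form reasoning text.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        return None
    try:
        decision = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    # Only the keys the tests rely on are normalized here.
    return {
        "plan": decision.get("plan", ""),
        "tools_to_consider": decision.get("tools_to_consider", []),
        "finish": bool(decision.get("finish", False)),
        "final_answer": decision.get("final_answer"),
        "request_input": decision.get("request_input"),
    }


# Example, using the first test's fake response:
# parse_reason_output('Planning...\n{"plan":"answer now","tools_to_consider":[],"finish":true,"final_answer":"Done!"}')
# -> finish=True, final_answer="Done!", request_input=None
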
atlas/tests/test_app_factory_smoke.py (new file)
@@ -0,0 +1,47 @@
+import sys
+import types
+
+
+def _ensure_litellm_stub():
+    if "litellm" in sys.modules:
+        return
+    m = types.ModuleType("litellm")
+    # Attributes used at import time
+    m.drop_params = True
+    def _set_verbose(*args, **kwargs):
+        return None
+    m.set_verbose = _set_verbose
+    # Names imported via `from litellm import ...`
+    def completion(*args, **kwargs):
+        return None
+    async def acompletion(*args, **kwargs):
+        class Dummy:
+            class Choice:
+                class Msg:
+                    content = ""
+                message = Msg()
+            choices = [Choice()]
+        return Dummy()
+    m.completion = completion
+    m.acompletion = acompletion
+    sys.modules["litellm"] = m
+
+
+def test_app_factory_accessors():
+    _ensure_litellm_stub()
+    from atlas.infrastructure.app_factory import app_factory
+    # Accessors should return instances without raising
+    assert app_factory.get_config_manager() is not None
+    assert app_factory.get_llm_caller() is not None
+    assert app_factory.get_mcp_manager() is not None
+    assert app_factory.get_file_storage() is not None
+    assert app_factory.get_file_manager() is not None
+
+    # RAG services are None when FEATURE_RAG_ENABLED is false (default)
+    rag_enabled = app_factory.get_config_manager().app_settings.feature_rag_enabled
+    if rag_enabled:
+        assert app_factory.get_unified_rag_service() is not None
+        assert app_factory.get_rag_mcp_service() is not None
+    else:
+        assert app_factory.get_unified_rag_service() is None
+        assert app_factory.get_rag_mcp_service() is None
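
test_app_factory_smoke.py keeps the app_factory import chain independent of the real litellm package by planting a stub module in sys.modules before anything imports it. The same pattern generalizes to any optional dependency; the sketch below is an illustration only (heavy_sdk is a made-up module name), not code from this package.

# Illustrative sketch only: a generalized version of the _ensure_litellm_stub
# pattern above.
import sys
import types
from typing import Any


def stub_module(name: str, **attrs: Any) -> types.ModuleType:
    # Install a throwaway module object so `import name` succeeds in tests
    # without pulling in the real (possibly heavy or unavailable) package.
    if name in sys.modules:
        return sys.modules[name]
    module = types.ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    sys.modules[name] = module
    return module


# Usage, analogous to the litellm stub:
# stub_module("heavy_sdk", connect=lambda *args, **kwargs: None)
# import code_that_imports_heavy_sdk  # its `import heavy_sdk` now resolves to the stub
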
atlas/tests/test_approval_manager.py (new file)
@@ -0,0 +1,439 @@
+"""Tests for the tool approval manager."""
+
+import asyncio
+
+import pytest
+
+from atlas.application.chat.approval_manager import ToolApprovalManager, ToolApprovalRequest, get_approval_manager
+
+
+class TestToolApprovalRequest:
+    """Test ToolApprovalRequest class."""
+
+    @pytest.mark.asyncio
+    async def test_create_approval_request(self):
+        """Test creating an approval request."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"},
+            allow_edit=True
+        )
+        assert request.tool_call_id == "test_123"
+        assert request.tool_name == "test_tool"
+        assert request.arguments == {"arg1": "value1"}
+        assert request.allow_edit is True
+
+    @pytest.mark.asyncio
+    async def test_set_response(self):
+        """Test setting a response to an approval request."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        # Set approved response
+        request.set_response(approved=True, arguments={"arg1": "edited_value"})
+
+        # Wait for the response (should be immediate since we already set it)
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is True
+        assert response["arguments"] == {"arg1": "edited_value"}
+
+    @pytest.mark.asyncio
+    async def test_rejection_response(self):
+        """Test rejecting an approval request."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        # Set rejected response
+        request.set_response(approved=False, reason="User rejected")
+
+        # Wait for the response
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is False
+        assert response["reason"] == "User rejected"
+
+    @pytest.mark.asyncio
+    async def test_timeout(self):
+        """Test that timeout works correctly."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        # Should timeout since we don't set a response
+        with pytest.raises(asyncio.TimeoutError):
+            await request.wait_for_response(timeout=0.1)
+
+
+class TestToolApprovalManager:
+    """Test ToolApprovalManager class."""
+
+    @pytest.mark.asyncio
+    async def test_create_approval_request(self):
+        """Test creating an approval request via manager."""
+        manager = ToolApprovalManager()
+        request = manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"},
+            allow_edit=True
+        )
+
+        assert request.tool_call_id == "test_123"
+        assert "test_123" in manager.get_pending_requests()
+
+    @pytest.mark.asyncio
+    async def test_handle_approval_response(self):
+        """Test handling an approval response."""
+        manager = ToolApprovalManager()
+        manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        # Handle approval response
+        result = manager.handle_approval_response(
+            tool_call_id="test_123",
+            approved=True,
+            arguments={"arg1": "edited_value"}
+        )
+
+        assert result is True
+        # Request should still be in pending (cleaned up manually later)
+        assert "test_123" in manager.get_pending_requests()
+
+    def test_handle_unknown_request(self):
+        """Test handling response for unknown request."""
+        manager = ToolApprovalManager()
+
+        result = manager.handle_approval_response(
+            tool_call_id="unknown_123",
+            approved=True
+        )
+
+        assert result is False
+
+    @pytest.mark.asyncio
+    async def test_cleanup_request(self):
+        """Test cleaning up a completed request."""
+        manager = ToolApprovalManager()
+        manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        assert "test_123" in manager.get_pending_requests()
+
+        manager.cleanup_request("test_123")
+
+        assert "test_123" not in manager.get_pending_requests()
+
+    def test_get_approval_manager_singleton(self):
+        """Test that get_approval_manager returns a singleton."""
+        manager1 = get_approval_manager()
+        manager2 = get_approval_manager()
+
+        assert manager1 is manager2
+
+    @pytest.mark.asyncio
+    async def test_full_approval_workflow(self):
+        """Test the complete approval workflow."""
+        manager = ToolApprovalManager()
+
+        # Create request
+        request = manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"code": "print('test')"},
+            allow_edit=True
+        )
+
+        # Simulate async approval (in a separate task)
+        async def approve_after_delay():
+            await asyncio.sleep(0.1)
+            manager.handle_approval_response(
+                tool_call_id="test_123",
+                approved=True,
+                arguments={"code": "print('edited test')"}
+            )
+
+        # Start approval task
+        asyncio.create_task(approve_after_delay())
+
+        # Wait for response
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is True
+        assert response["arguments"]["code"] == "print('edited test')"
+
+        # Cleanup
+        manager.cleanup_request("test_123")
+        assert "test_123" not in manager.get_pending_requests()
+
+    @pytest.mark.asyncio
+    async def test_multiple_concurrent_approvals(self):
+        """Test handling multiple concurrent approval requests."""
+        manager = ToolApprovalManager()
+
+        # Create multiple requests
+        request1 = manager.create_approval_request(
+            tool_call_id="test_1",
+            tool_name="tool_a",
+            arguments={"arg": "value1"}
+        )
+        request2 = manager.create_approval_request(
+            tool_call_id="test_2",
+            tool_name="tool_b",
+            arguments={"arg": "value2"}
+        )
+        request3 = manager.create_approval_request(
+            tool_call_id="test_3",
+            tool_name="tool_c",
+            arguments={"arg": "value3"}
+        )
+
+        assert len(manager.get_pending_requests()) == 3
+
+        # Approve them in different order
+        async def approve_requests():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response("test_2", approved=True)
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response("test_1", approved=False, reason="Rejected")
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response("test_3", approved=True)
+
+        asyncio.create_task(approve_requests())
+
+        # Wait for all responses
+        response1 = await request1.wait_for_response(timeout=1.0)
+        response2 = await request2.wait_for_response(timeout=1.0)
+        response3 = await request3.wait_for_response(timeout=1.0)
+
+        assert response1["approved"] is False
+        assert response1["reason"] == "Rejected"
+        assert response2["approved"] is True
+        assert response3["approved"] is True
+
+    @pytest.mark.asyncio
+    async def test_approval_with_no_arguments_change(self):
+        """Test approval where arguments are returned but not changed."""
+        manager = ToolApprovalManager()
+
+        original_args = {"code": "print('hello')"}
+        request = manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments=original_args,
+            allow_edit=True
+        )
+
+        # Approve with same arguments
+        async def approve():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response(
+                tool_call_id="test_123",
+                approved=True,
+                arguments={"code": "print('hello')"}  # Same as original
+            )
+
+        asyncio.create_task(approve())
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is True
+        assert response["arguments"] == original_args
+
+    @pytest.mark.asyncio
+    async def test_double_response_handling(self):
+        """Test that setting response twice doesn't cause issues."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        # Set response first time
+        request.set_response(approved=True, arguments={"arg1": "first"})
+
+        # Try to set response second time (should be ignored)
+        request.set_response(approved=False, arguments={"arg1": "second"})
+
+        # Should get the first response
+        response = await request.wait_for_response(timeout=0.5)
+        assert response["approved"] is True
+        assert response["arguments"]["arg1"] == "first"
+
+    @pytest.mark.asyncio
+    async def test_rejection_with_empty_reason(self):
+        """Test rejection with no reason provided."""
+        manager = ToolApprovalManager()
+
+        request = manager.create_approval_request(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"}
+        )
+
+        async def reject():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response(
+                tool_call_id="test_123",
+                approved=False
+                # No reason provided
+            )
+
+        _ = asyncio.create_task(reject())
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is False
+        assert response.get("reason") is None or response.get("reason") == ""
+
+    @pytest.mark.asyncio
+    async def test_allow_edit_false(self):
+        """Test approval request with editing disabled."""
+        request = ToolApprovalRequest(
+            tool_call_id="test_123",
+            tool_name="test_tool",
+            arguments={"arg1": "value1"},
+            allow_edit=False
+        )
+
+        assert request.allow_edit is False
+
+        # Even if arguments are provided, they should be used
+        request.set_response(approved=True, arguments={"arg1": "edited_value"})
+        response = await request.wait_for_response(timeout=0.5)
+
+        # The response will contain the edited arguments, but the UI should
+        # respect allow_edit=False to prevent showing edit controls
+        assert response["arguments"] == {"arg1": "edited_value"}
+
+    def test_cleanup_nonexistent_request(self):
+        """Test cleaning up a request that doesn't exist."""
+        manager = ToolApprovalManager()
+
+        # Should not raise an error
+        manager.cleanup_request("nonexistent_id")
+
+        assert "nonexistent_id" not in manager.get_pending_requests()
+
+    def test_multiple_managers_vs_singleton(self):
+        """Test that direct instantiation creates different instances but singleton returns same."""
+        manager1 = ToolApprovalManager()
+        manager2 = ToolApprovalManager()
+
+        # Direct instantiation creates different instances
+        assert manager1 is not manager2
+
+        # But singleton returns the same instance
+        singleton1 = get_approval_manager()
+        singleton2 = get_approval_manager()
+        assert singleton1 is singleton2
+
+    @pytest.mark.asyncio
+    async def test_approval_with_complex_arguments(self):
+        """Test approval with complex nested arguments."""
+        manager = ToolApprovalManager()
+
+        complex_args = {
+            "nested": {
+                "level1": {
+                    "level2": ["item1", "item2", "item3"]
+                }
+            },
+            "list_of_dicts": [
+                {"key": "value1"},
+                {"key": "value2"}
+            ],
+            "numbers": [1, 2, 3, 4, 5]
+        }
+
+        request = manager.create_approval_request(
+            tool_call_id="test_complex",
+            tool_name="complex_tool",
+            arguments=complex_args,
+            allow_edit=True
+        )
+
+        # Modify nested structure
+        edited_args = {
+            "nested": {
+                "level1": {
+                    "level2": ["item1", "modified_item", "item3"]
+                }
+            },
+            "list_of_dicts": [
+                {"key": "value1"},
+                {"key": "new_value"}
+            ],
+            "numbers": [1, 2, 3, 4, 5, 6]
+        }
+
+        async def approve():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response(
+                tool_call_id="test_complex",
+                approved=True,
+                arguments=edited_args
+            )
+
+        asyncio.create_task(approve())
+        response = await request.wait_for_response(timeout=1.0)
+
+        assert response["approved"] is True
+        assert response["arguments"]["nested"]["level1"]["level2"][1] == "modified_item"
+        assert len(response["arguments"]["numbers"]) == 6
+
+    @pytest.mark.asyncio
+    async def test_sequential_approvals(self):
+        """Test approving requests one after another in sequence."""
+        manager = ToolApprovalManager()
+
+        # First approval
+        request1 = manager.create_approval_request(
+            tool_call_id="seq_1",
+            tool_name="tool_1",
+            arguments={"step": 1}
+        )
+
+        async def approve1():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response("seq_1", approved=True)
+
+        task1 = asyncio.create_task(approve1())
+        response1 = await request1.wait_for_response(timeout=1.0)
+        manager.cleanup_request("seq_1")
+        await task1
+
+        assert response1["approved"] is True
+        assert "seq_1" not in manager.get_pending_requests()
+
+        # Second approval after first is complete
+        request2 = manager.create_approval_request(
+            tool_call_id="seq_2",
+            tool_name="tool_2",
+            arguments={"step": 2}
+        )
+
+        async def approve2():
+            await asyncio.sleep(0.05)
+            manager.handle_approval_response("seq_2", approved=True)
+
+        task2 = asyncio.create_task(approve2())
+        response2 = await request2.wait_for_response(timeout=1.0)
+        manager.cleanup_request("seq_2")
+        await task2
+
+        assert response2["approved"] is True
+        assert "seq_2" not in manager.get_pending_requests()
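
Collectively these tests imply the approval contract: wait_for_response blocks until set_response fires, honours a timeout, and the first response wins while later ones are ignored. The sketch below is a minimal asyncio object consistent with that contract, assuming an Event-based implementation; it is not the package's actual approval_manager.py. A manager on top of it would only need a dict of pending requests keyed by tool_call_id plus a module-level singleton accessor.

# Illustrative sketch only: a minimal request object consistent with the
# behaviour these tests assert (blocking wait, timeout, first-response-wins).
import asyncio
from typing import Any, Dict, Optional


class MinimalApprovalRequest:
    def __init__(self, tool_call_id: str, tool_name: str,
                 arguments: Dict[str, Any], allow_edit: bool = False):
        self.tool_call_id = tool_call_id
        self.tool_name = tool_name
        self.arguments = arguments
        self.allow_edit = allow_edit
        self._event = asyncio.Event()
        self._response: Optional[Dict[str, Any]] = None

    def set_response(self, approved: bool,
                     arguments: Optional[Dict[str, Any]] = None,
                     reason: Optional[str] = None) -> None:
        if self._response is not None:
            return  # first response wins; a second call is ignored
        self._response = {"approved": approved, "arguments": arguments, "reason": reason}
        self._event.set()

    async def wait_for_response(self, timeout: float) -> Dict[str, Any]:
        # Raises asyncio.TimeoutError when no response arrives in time.
        await asyncio.wait_for(self._event.wait(), timeout=timeout)
        assert self._response is not None
        return self._response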