atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250) hide show
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,135 @@
1
+ import asyncio
2
+ import os
3
+ import sys
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import pytest
7
+
8
+ # Ensure backend root is on path
9
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10
+
11
+ from atlas.application.chat.service import ChatService # type: ignore
12
+ from atlas.interfaces.llm import LLMProtocol # type: ignore
13
+ from atlas.interfaces.transport import ChatConnectionProtocol # type: ignore
14
+ from atlas.modules.config.config_manager import ConfigManager # type: ignore
15
+
16
+
17
class FakeLLM(LLMProtocol):
    """Scripted LLM double for the agent-loop tests.

    ``call_plain`` hands out the queued responses in order and falls back to
    a canned "finish" JSON payload once the queue is exhausted.
    ``call_with_tools`` is not used in these tests (reason-only scenarios);
    it and the RAG entry points are inert stubs.
    """

    # Returned by call_plain when no scripted responses remain.
    _FALLBACK_PLAIN = '{"plan":"noop","tools_to_consider":[],"finish":true,"final_answer":"ok"}'

    def __init__(self, plain_responses: Optional[List[str]] = None):
        # Copy the input so pop() never mutates the caller's list.
        self._plain = list(plain_responses) if plain_responses else []

    @staticmethod
    def _empty_response():
        """Build an ``LLMResponse`` whose content is the empty string."""
        from atlas.interfaces.llm import LLMResponse  # type: ignore

        return LLMResponse(content="")

    async def call_plain(
        self,
        model_name: str,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
    ) -> str:
        """Return the next scripted response, or the canned fallback payload."""
        if not self._plain:
            return self._FALLBACK_PLAIN
        return self._plain.pop(0)

    async def call_with_tools(
        self,
        model_name: str,
        messages: List[Dict[str, str]],
        tools_schema: List[Dict],
        tool_choice: str = "auto",
        temperature: float = 0.7,
    ):
        """Minimal stub: always answers with empty content."""
        return self._empty_response()

    async def call_with_rag(
        self,
        model_name: str,
        messages: List[Dict[str, str]],
        data_sources: List[str],
        user_email: str,
        temperature: float = 0.7,
    ) -> str:
        """Placeholder; returns a fixed marker string."""
        return "not-used"

    async def call_with_rag_and_tools(
        self,
        model_name: str,
        messages: List[Dict[str, str]],
        data_sources: List[str],
        tools_schema: List[Dict],
        user_email: str,
        tool_choice: str = "auto",
        temperature: float = 0.7,
    ):
        """Minimal stub: always answers with empty content."""
        return self._empty_response()
43
+
44
+
45
class FakeConnection(ChatConnectionProtocol):
    """In-memory connection double.

    Outbound payloads are appended to ``messages``; inbound payloads are
    replayed from a queue pre-loaded with the ``incoming`` items.
    """

    def __init__(self, incoming: Optional[List[Dict[str, Any]]] = None):
        self.messages: List[Dict[str, Any]] = []
        self._queue: asyncio.Queue = asyncio.Queue()
        if incoming:
            for payload in incoming:
                self._queue.put_nowait(payload)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Record an outbound payload so tests can inspect it later."""
        self.messages.append(data)

    async def receive_json(self) -> Dict[str, Any]:
        """Return the next queued inbound payload, waiting until one exists."""
        payload = await self._queue.get()
        return payload

    async def accept(self) -> None:  # pragma: no cover - not used
        return None

    async def close(self) -> None:  # pragma: no cover - not used
        return None
63
+
64
+
65
@pytest.mark.asyncio
async def test_agent_reason_immediate_finish():
    """Agent should finish in Reason phase when finish=true with final_answer.

    The fake LLM is scripted with a single Reason response that already
    carries ``finish=true``; the test checks the reply and the lifecycle
    events recorded on the fake connection.
    """
    # Imported locally so this block does not depend on the file-level imports.
    import uuid

    # Reason -> finish immediately
    reason_finish = (
        "Planning...\n{\"plan\":\"answer now\",\"tools_to_consider\":[],\"finish\":true,\"final_answer\":\"Done!\"}"
    )
    llm = FakeLLM([reason_finish])
    conn = FakeConnection()
    svc = ChatService(llm=llm, tool_manager=None, connection=conn, config_manager=ConfigManager())

    resp = await svc.handle_chat_message(
        session_id=uuid.uuid4(),
        content="Hello",
        model="fake",
        agent_mode=True,
        agent_max_steps=3,
    )

    assert resp["type"] == "chat_response"
    # Agent behavior varies based on whether prompt templates are available
    # With prompts: returns parsed final_answer ("Done!")
    # Without prompts: returns full LLM response
    # A substring check covers both cases (equality implies containment).
    message = resp["message"]
    assert "Done!" in message

    # Verify agent lifecycle events were emitted
    kinds = [m.get("update_type") for m in conn.messages if m.get("type") == "agent_update"]
    for expected_kind in ("agent_start", "agent_turn_start", "agent_reason", "agent_completion"):
        assert expected_kind in kinds
97
+
98
+
99
@pytest.mark.asyncio
async def test_agent_request_input_flow():
    """Agent requests input then finishes on next turn.

    The fake LLM is scripted with two Reason responses (ask, then finish)
    and the fake connection is pre-loaded with one user reply.
    """
    # Imported locally so this block does not depend on the file-level imports.
    import uuid

    # Turn 1 Reason -> request_input
    reason_ask = (
        "Thinking...\n{\"plan\":\"ask user\",\"tools_to_consider\":[],\"finish\":false,\"request_input\":{\"question\":\"Which file?\"}}"
    )
    # Turn 2 Reason -> finish with final answer
    reason_finish = (
        "Ok now answer.\n{\"plan\":\"answer\",\"tools_to_consider\":[],\"finish\":true,\"final_answer\":\"All set.\"}"
    )
    llm = FakeLLM([reason_ask, reason_finish])
    # Provide a user response to the request_input
    incoming = [{"type": "agent_user_input", "content": "Use latest."}]
    conn = FakeConnection(incoming=incoming)
    svc = ChatService(llm=llm, tool_manager=None, connection=conn, config_manager=ConfigManager())

    resp = await svc.handle_chat_message(
        session_id=uuid.uuid4(),
        content="Process my data",
        model="fake",
        agent_mode=True,
        agent_max_steps=5,
    )

    assert resp["type"] == "chat_response"
    # Agent behavior varies based on environment and prompts
    # May return final answer ("All set.") or intermediate response (request_input)
    # Substring checks are sufficient (equality implies containment).
    message = resp["message"]
    assert "All set." in message or "Which file?" in message

    # Check that the agent completed properly
    kinds = [m.get("update_type") for m in conn.messages if m.get("type") == "agent_update"]
    # agent_completion should always be present
    assert "agent_completion" in kinds
    # agent_request_input may or may not be present depending on environment
@@ -0,0 +1,47 @@
1
+ import sys
2
+ import types
3
+
4
+
5
def _ensure_litellm_stub():
    """Register a minimal fake ``litellm`` module unless one is already loaded.

    The stub exposes ``drop_params``, ``set_verbose``, ``completion`` and
    ``acompletion``. Nothing is replaced when ``sys.modules`` already has a
    ``litellm`` entry.
    """
    if "litellm" in sys.modules:
        return

    def _set_verbose(*args, **kwargs):
        return None

    def completion(*args, **kwargs):
        return None

    async def acompletion(*args, **kwargs):
        # Shape mirrors what callers read: response.choices[0].message.content
        class _Message:
            content = ""

        class _Choice:
            message = _Message()

        class _Response:
            choices = [_Choice()]

        return _Response()

    stub = types.ModuleType("litellm")
    # Attributes used at import time
    stub.drop_params = True
    stub.set_verbose = _set_verbose
    # Names imported via `from litellm import ...`
    stub.completion = completion
    stub.acompletion = acompletion
    sys.modules["litellm"] = stub
28
+
29
+
30
def test_app_factory_accessors():
    """Smoke-test the app_factory accessors after stubbing ``litellm``.

    The core accessors must return an object; the two RAG accessors must
    return an object only when ``feature_rag_enabled`` is truthy.
    """
    _ensure_litellm_stub()
    from atlas.infrastructure.app_factory import app_factory

    # Accessors should return instances without raising
    core_accessors = (
        "get_config_manager",
        "get_llm_caller",
        "get_mcp_manager",
        "get_file_storage",
        "get_file_manager",
    )
    for accessor_name in core_accessors:
        assert getattr(app_factory, accessor_name)() is not None

    # RAG services are None when FEATURE_RAG_ENABLED is false (default)
    settings = app_factory.get_config_manager().app_settings
    expect_rag = bool(settings.feature_rag_enabled)
    for accessor_name in ("get_unified_rag_service", "get_rag_mcp_service"):
        service = getattr(app_factory, accessor_name)()
        assert (service is not None) is expect_rag
@@ -0,0 +1,439 @@
1
+ """Tests for the tool approval manager."""
2
+
3
+ import asyncio
4
+
5
+ import pytest
6
+
7
+ from atlas.application.chat.approval_manager import ToolApprovalManager, ToolApprovalRequest, get_approval_manager
8
+
9
+
10
class TestToolApprovalRequest:
    """Exercise ToolApprovalRequest directly, without going through a manager."""

    @staticmethod
    def _build_request(**extra):
        """Create the request shared by these tests; extra kwargs pass through."""
        return ToolApprovalRequest(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"},
            **extra,
        )

    @pytest.mark.asyncio
    async def test_create_approval_request(self):
        """Constructor arguments are exposed as attributes."""
        pending = self._build_request(allow_edit=True)

        expected = {
            "tool_call_id": "test_123",
            "tool_name": "test_tool",
            "arguments": {"arg1": "value1"},
        }
        for attr, value in expected.items():
            assert getattr(pending, attr) == value
        assert pending.allow_edit is True

    @pytest.mark.asyncio
    async def test_set_response(self):
        """An approval set up front is returned by wait_for_response."""
        pending = self._build_request()

        # Respond before waiting, so the wait should return immediately.
        pending.set_response(approved=True, arguments={"arg1": "edited_value"})
        outcome = await pending.wait_for_response(timeout=1.0)

        assert outcome["approved"] is True
        assert outcome["arguments"] == {"arg1": "edited_value"}

    @pytest.mark.asyncio
    async def test_rejection_response(self):
        """A rejection carries its reason through to the waiter."""
        pending = self._build_request()

        pending.set_response(approved=False, reason="User rejected")
        outcome = await pending.wait_for_response(timeout=1.0)

        assert outcome["approved"] is False
        assert outcome["reason"] == "User rejected"

    @pytest.mark.asyncio
    async def test_timeout(self):
        """Waiting without any response raises asyncio.TimeoutError."""
        pending = self._build_request()

        # No response is ever set, so the short timeout must fire.
        with pytest.raises(asyncio.TimeoutError):
            await pending.wait_for_response(timeout=0.1)
75
+
76
+
77
class TestToolApprovalManager:
    """Test ToolApprovalManager class.

    Several tests deliver the approval from a background task while the test
    body waits on the request. Each such task is bound to a name and awaited
    before the test ends: the event loop only keeps weak references to tasks,
    and awaiting also surfaces any exception raised inside the task.
    """

    @pytest.mark.asyncio
    async def test_create_approval_request(self):
        """Test creating an approval request via manager."""
        manager = ToolApprovalManager()
        request = manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"},
            allow_edit=True
        )

        assert request.tool_call_id == "test_123"
        assert "test_123" in manager.get_pending_requests()

    @pytest.mark.asyncio
    async def test_handle_approval_response(self):
        """Test handling an approval response."""
        manager = ToolApprovalManager()
        manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"}
        )

        # Handle approval response
        result = manager.handle_approval_response(
            tool_call_id="test_123",
            approved=True,
            arguments={"arg1": "edited_value"}
        )

        assert result is True
        # Request should still be in pending (cleaned up manually later)
        assert "test_123" in manager.get_pending_requests()

    def test_handle_unknown_request(self):
        """Test handling response for unknown request."""
        manager = ToolApprovalManager()

        result = manager.handle_approval_response(
            tool_call_id="unknown_123",
            approved=True
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_cleanup_request(self):
        """Test cleaning up a completed request."""
        manager = ToolApprovalManager()
        manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"}
        )

        assert "test_123" in manager.get_pending_requests()

        manager.cleanup_request("test_123")

        assert "test_123" not in manager.get_pending_requests()

    def test_get_approval_manager_singleton(self):
        """Test that get_approval_manager returns a singleton."""
        manager1 = get_approval_manager()
        manager2 = get_approval_manager()

        assert manager1 is manager2

    @pytest.mark.asyncio
    async def test_full_approval_workflow(self):
        """Test the complete approval workflow."""
        manager = ToolApprovalManager()

        # Create request
        request = manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"code": "print('test')"},
            allow_edit=True
        )

        # Simulate async approval (in a separate task)
        async def approve_after_delay():
            await asyncio.sleep(0.1)
            manager.handle_approval_response(
                tool_call_id="test_123",
                approved=True,
                arguments={"code": "print('edited test')"}
            )

        # Start approval task; keep a reference so it cannot be collected early
        approval_task = asyncio.create_task(approve_after_delay())

        # Wait for response
        response = await request.wait_for_response(timeout=1.0)
        await approval_task

        assert response["approved"] is True
        assert response["arguments"]["code"] == "print('edited test')"

        # Cleanup
        manager.cleanup_request("test_123")
        assert "test_123" not in manager.get_pending_requests()

    @pytest.mark.asyncio
    async def test_multiple_concurrent_approvals(self):
        """Test handling multiple concurrent approval requests."""
        manager = ToolApprovalManager()

        # Create multiple requests
        request1 = manager.create_approval_request(
            tool_call_id="test_1",
            tool_name="tool_a",
            arguments={"arg": "value1"}
        )
        request2 = manager.create_approval_request(
            tool_call_id="test_2",
            tool_name="tool_b",
            arguments={"arg": "value2"}
        )
        request3 = manager.create_approval_request(
            tool_call_id="test_3",
            tool_name="tool_c",
            arguments={"arg": "value3"}
        )

        assert len(manager.get_pending_requests()) == 3

        # Approve them in different order
        async def approve_requests():
            await asyncio.sleep(0.05)
            manager.handle_approval_response("test_2", approved=True)
            await asyncio.sleep(0.05)
            manager.handle_approval_response("test_1", approved=False, reason="Rejected")
            await asyncio.sleep(0.05)
            manager.handle_approval_response("test_3", approved=True)

        # Keep a reference so the task cannot be collected early
        approvals_task = asyncio.create_task(approve_requests())

        # Wait for all responses
        response1 = await request1.wait_for_response(timeout=1.0)
        response2 = await request2.wait_for_response(timeout=1.0)
        response3 = await request3.wait_for_response(timeout=1.0)
        await approvals_task

        assert response1["approved"] is False
        assert response1["reason"] == "Rejected"
        assert response2["approved"] is True
        assert response3["approved"] is True

    @pytest.mark.asyncio
    async def test_approval_with_no_arguments_change(self):
        """Test approval where arguments are returned but not changed."""
        manager = ToolApprovalManager()

        original_args = {"code": "print('hello')"}
        request = manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments=original_args,
            allow_edit=True
        )

        # Approve with same arguments
        async def approve():
            await asyncio.sleep(0.05)
            manager.handle_approval_response(
                tool_call_id="test_123",
                approved=True,
                arguments={"code": "print('hello')"}  # Same as original
            )

        approval_task = asyncio.create_task(approve())
        response = await request.wait_for_response(timeout=1.0)
        await approval_task

        assert response["approved"] is True
        assert response["arguments"] == original_args

    @pytest.mark.asyncio
    async def test_double_response_handling(self):
        """Test that setting response twice doesn't cause issues."""
        request = ToolApprovalRequest(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"}
        )

        # Set response first time
        request.set_response(approved=True, arguments={"arg1": "first"})

        # Try to set response second time (should be ignored)
        request.set_response(approved=False, arguments={"arg1": "second"})

        # Should get the first response
        response = await request.wait_for_response(timeout=0.5)
        assert response["approved"] is True
        assert response["arguments"]["arg1"] == "first"

    @pytest.mark.asyncio
    async def test_rejection_with_empty_reason(self):
        """Test rejection with no reason provided."""
        manager = ToolApprovalManager()

        request = manager.create_approval_request(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"}
        )

        async def reject():
            await asyncio.sleep(0.05)
            manager.handle_approval_response(
                tool_call_id="test_123",
                approved=False
                # No reason provided
            )

        rejection_task = asyncio.create_task(reject())
        response = await request.wait_for_response(timeout=1.0)
        await rejection_task

        assert response["approved"] is False
        # Either a missing/None reason or an empty string is accepted
        assert response.get("reason") is None or response.get("reason") == ""

    @pytest.mark.asyncio
    async def test_allow_edit_false(self):
        """Test approval request with editing disabled."""
        request = ToolApprovalRequest(
            tool_call_id="test_123",
            tool_name="test_tool",
            arguments={"arg1": "value1"},
            allow_edit=False
        )

        assert request.allow_edit is False

        # Even if arguments are provided, they should be used
        request.set_response(approved=True, arguments={"arg1": "edited_value"})
        response = await request.wait_for_response(timeout=0.5)

        # The response will contain the edited arguments, but the UI should
        # respect allow_edit=False to prevent showing edit controls
        assert response["arguments"] == {"arg1": "edited_value"}

    def test_cleanup_nonexistent_request(self):
        """Test cleaning up a request that doesn't exist."""
        manager = ToolApprovalManager()

        # Should not raise an error
        manager.cleanup_request("nonexistent_id")

        assert "nonexistent_id" not in manager.get_pending_requests()

    def test_multiple_managers_vs_singleton(self):
        """Test that direct instantiation creates different instances but singleton returns same."""
        manager1 = ToolApprovalManager()
        manager2 = ToolApprovalManager()

        # Direct instantiation creates different instances
        assert manager1 is not manager2

        # But singleton returns the same instance
        singleton1 = get_approval_manager()
        singleton2 = get_approval_manager()
        assert singleton1 is singleton2

    @pytest.mark.asyncio
    async def test_approval_with_complex_arguments(self):
        """Test approval with complex nested arguments."""
        manager = ToolApprovalManager()

        complex_args = {
            "nested": {
                "level1": {
                    "level2": ["item1", "item2", "item3"]
                }
            },
            "list_of_dicts": [
                {"key": "value1"},
                {"key": "value2"}
            ],
            "numbers": [1, 2, 3, 4, 5]
        }

        request = manager.create_approval_request(
            tool_call_id="test_complex",
            tool_name="complex_tool",
            arguments=complex_args,
            allow_edit=True
        )

        # Modify nested structure
        edited_args = {
            "nested": {
                "level1": {
                    "level2": ["item1", "modified_item", "item3"]
                }
            },
            "list_of_dicts": [
                {"key": "value1"},
                {"key": "new_value"}
            ],
            "numbers": [1, 2, 3, 4, 5, 6]
        }

        async def approve():
            await asyncio.sleep(0.05)
            manager.handle_approval_response(
                tool_call_id="test_complex",
                approved=True,
                arguments=edited_args
            )

        approval_task = asyncio.create_task(approve())
        response = await request.wait_for_response(timeout=1.0)
        await approval_task

        assert response["approved"] is True
        assert response["arguments"]["nested"]["level1"]["level2"][1] == "modified_item"
        assert len(response["arguments"]["numbers"]) == 6

    @pytest.mark.asyncio
    async def test_sequential_approvals(self):
        """Test approving requests one after another in sequence."""
        manager = ToolApprovalManager()

        # First approval
        request1 = manager.create_approval_request(
            tool_call_id="seq_1",
            tool_name="tool_1",
            arguments={"step": 1}
        )

        async def approve1():
            await asyncio.sleep(0.05)
            manager.handle_approval_response("seq_1", approved=True)

        task1 = asyncio.create_task(approve1())
        response1 = await request1.wait_for_response(timeout=1.0)
        manager.cleanup_request("seq_1")
        await task1

        assert response1["approved"] is True
        assert "seq_1" not in manager.get_pending_requests()

        # Second approval after first is complete
        request2 = manager.create_approval_request(
            tool_call_id="seq_2",
            tool_name="tool_2",
            arguments={"step": 2}
        )

        async def approve2():
            await asyncio.sleep(0.05)
            manager.handle_approval_response("seq_2", approved=True)

        task2 = asyncio.create_task(approve2())
        response2 = await request2.wait_for_response(timeout=1.0)
        manager.cleanup_request("seq_2")
        await task2

        assert response2["approved"] is True
        assert "seq_2" not in manager.get_pending_requests()