atlas-chat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,188 @@
1
+ """Tests for AtlasClient and CLI arg parsing."""
2
+
3
+ import json
4
+ from uuid import uuid4
5
+
6
+ import pytest
7
+
8
+ from atlas.infrastructure.events.cli_event_publisher import CLIEventPublisher
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # CLIEventPublisher unit tests
12
+ # ---------------------------------------------------------------------------
13
+
14
class TestCLIEventPublisher:
    """Exercise CLIEventPublisher while collecting and while streaming."""

    @pytest.fixture
    def publisher(self):
        # Non-streaming publisher: events are gathered into the result object.
        return CLIEventPublisher(streaming=False)

    @pytest.mark.asyncio
    async def test_publish_chat_response_accumulates(self, publisher):
        for chunk in ("Hello ", "world"):
            await publisher.publish_chat_response(chunk)
        assert publisher.get_result().message == "Hello world"

    @pytest.mark.asyncio
    async def test_publish_tool_start_and_complete(self, publisher):
        await publisher.publish_tool_start("my_tool")
        started_calls = publisher.get_result().tool_calls
        assert len(started_calls) == 1
        assert started_calls[0]["status"] == "started"

        await publisher.publish_tool_complete("my_tool", result="ok")
        # Re-read the result after completion rather than reusing the old view.
        finished_call = publisher.get_result().tool_calls[0]
        assert finished_call["status"] == "complete"
        assert finished_call["result"] == "ok"

    @pytest.mark.asyncio
    async def test_publish_files_update(self, publisher):
        files_payload = {"report.pdf": {"key": "s3/report.pdf"}}
        await publisher.publish_files_update(files_payload)
        assert "report.pdf" in publisher.get_result().files

    @pytest.mark.asyncio
    async def test_publish_canvas_content(self, publisher):
        html = "<h1>Hi</h1>"
        await publisher.publish_canvas_content(html)
        assert publisher.get_result().canvas_content == html

    @pytest.mark.asyncio
    async def test_send_json_records_event(self, publisher):
        await publisher.send_json({"type": "custom", "data": 1})
        recorded = publisher.get_result().raw_events
        assert len(recorded) == 1

    @pytest.mark.asyncio
    async def test_streaming_writes_to_stdout(self, capsys):
        streaming_pub = CLIEventPublisher(streaming=True)
        await streaming_pub.publish_chat_response("token1")
        out, _err = capsys.readouterr()
        assert "token1" in out

    @pytest.mark.asyncio
    async def test_quiet_suppresses_status(self, capsys):
        quiet_pub = CLIEventPublisher(streaming=True, quiet=True)
        await quiet_pub.publish_tool_start("some_tool")
        _out, err = capsys.readouterr()
        # Quiet mode: the tool-status notice must not reach stderr.
        assert "some_tool" not in err
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # CLI arg parsing tests
71
+ # ---------------------------------------------------------------------------
72
+
73
class TestCLIArgParsing:
    """Tests for ``atlas.atlas_chat_cli`` argument parsing."""

    @staticmethod
    def _parse(argv):
        """Build a fresh CLI parser and parse *argv* with it.

        The CLI module lives inside the ``atlas`` package
        (``atlas/atlas_chat_cli.py``), so it is imported by its
        package-qualified name, the same way this file imports other
        ``atlas`` modules. A bare ``import atlas_chat_cli`` only resolves when
        ``atlas/`` itself happens to be on ``sys.path``.
        """
        from atlas.atlas_chat_cli import build_parser

        return build_parser().parse_args(argv)

    def test_basic_prompt(self):
        args = self._parse(["Hello world"])
        assert args.prompt == "Hello world"
        assert args.json_output is False

    def test_tools_flag(self):
        args = self._parse(["Do stuff", "--tools", "toolA,toolB"])
        assert args.tools == "toolA,toolB"

    def test_json_output_flag(self):
        args = self._parse(["prompt", "--json"])
        assert args.json_output is True

    def test_output_file_flag(self):
        args = self._parse(["prompt", "-o", "/tmp/out.txt"])
        assert args.output == "/tmp/out.txt"

    def test_list_tools_flag(self):
        # No positional prompt: listing tools must not require one.
        args = self._parse(["--list-tools"])
        assert args.list_tools is True

    def test_env_file_flag(self):
        args = self._parse(["prompt", "--env-file", "/path/to/custom.env"])
        assert args.env_file == "/path/to/custom.env"

    def test_env_file_flag_equals_syntax(self):
        args = self._parse(["prompt", "--env-file=/other/path.env"])
        assert args.env_file == "/other/path.env"

    def test_data_sources_flag(self):
        args = self._parse(["prompt", "--data-sources", "source1,source2"])
        assert args.data_sources == "source1,source2"

    def test_only_rag_flag(self):
        args = self._parse(["prompt", "--only-rag"])
        assert args.only_rag is True

    def test_list_data_sources_flag(self):
        args = self._parse(["--list-data-sources"])
        assert args.list_data_sources is True

    def test_combined_tools_and_data_sources(self):
        args = self._parse([
            "prompt",
            "--tools", "calculator_evaluate",
            "--data-sources", "atlas_rag",
        ])
        assert args.tools == "calculator_evaluate"
        assert args.data_sources == "atlas_rag"
        assert args.only_rag is False
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # ChatResult serialization tests
165
+ # ---------------------------------------------------------------------------
166
+
167
class TestChatResult:
    """Tests for ``ChatResult.to_dict`` serialization."""

    def test_to_dict(self):
        # Package-qualified import: atlas_client.py lives inside the atlas package.
        from atlas.atlas_client import ChatResult

        sid = uuid4()
        result = ChatResult(
            message="Hi",
            tool_calls=[{"tool": "x", "status": "complete"}],
            files={"a.txt": {}},
            canvas_content="<p>c</p>",
            session_id=sid,
        )
        d = result.to_dict()
        assert d["message"] == "Hi"
        # The UUID session id is expected to come back as its string form.
        assert d["session_id"] == str(sid)
        assert len(d["tool_calls"]) == 1

    def test_to_dict_json_serializable(self):
        from atlas.atlas_client import ChatResult

        result = ChatResult(message="ok")
        # Round-trip through JSON: dumps must not raise, and the payload must
        # survive decoding intact (not merely "not raise").
        decoded = json.loads(json.dumps(result.to_dict()))
        assert decoded["message"] == "ok"
@@ -0,0 +1,447 @@
1
+ """Unit tests for AtlasRAGClient."""
2
+
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import pytest
6
+ from fastapi import HTTPException
7
+
8
+ from atlas.modules.rag.atlas_rag_client import AtlasRAGClient
9
+ from atlas.modules.rag.client import DataSource, RAGResponse
10
+
11
+
12
@pytest.fixture
def client():
    """Provide an authenticated AtlasRAGClient with every option set explicitly."""
    settings = {
        "base_url": "https://rag-api.example.com",
        "bearer_token": "test-token",
        "default_model": "test-model",
        "top_k": 4,
        "timeout": 30.0,
    }
    return AtlasRAGClient(**settings)
22
+
23
+
24
@pytest.fixture
def client_no_auth():
    """Provide an AtlasRAGClient configured without a bearer token."""
    unauthenticated = AtlasRAGClient(
        base_url="https://rag-api.example.com", bearer_token=None
    )
    return unauthenticated
31
+
32
+
33
class TestAtlasRAGClientInit:
    """Constructor behaviour of AtlasRAGClient."""

    def test_init_with_all_params(self, client):
        """Every explicitly supplied option is stored on the instance."""
        expected = {
            "base_url": "https://rag-api.example.com",
            "bearer_token": "test-token",
            "default_model": "test-model",
            "top_k": 4,
            "timeout": 30.0,
        }
        for attr_name, wanted in expected.items():
            assert getattr(client, attr_name) == wanted

    def test_init_strips_trailing_slash(self):
        """A trailing '/' on base_url is dropped."""
        stripped = AtlasRAGClient(base_url="https://rag-api.example.com/")
        assert stripped.base_url == "https://rag-api.example.com"

    def test_init_defaults(self):
        """Options left out fall back to the client's built-in defaults."""
        defaulted = AtlasRAGClient(base_url="https://rag-api.example.com")
        assert defaulted.bearer_token is None
        assert (defaulted.default_model, defaulted.top_k, defaulted.timeout) == (
            "openai/gpt-oss-120b",
            4,
            60.0,
        )
56
+
57
+
58
class TestGetHeaders:
    """Header construction via ``AtlasRAGClient._get_headers``."""

    def test_headers_with_auth(self, client):
        """A configured bearer token yields an Authorization header."""
        hdrs = client._get_headers()
        assert (hdrs["Content-Type"], hdrs["Authorization"]) == (
            "application/json",
            "Bearer test-token",
        )

    def test_headers_without_auth(self, client_no_auth):
        """With no bearer token, no Authorization header is sent."""
        hdrs = client_no_auth._get_headers()
        assert "Authorization" not in hdrs
        assert hdrs["Content-Type"] == "application/json"
72
+
73
+
74
class TestDiscoverDataSources:
    """Tests for discover_data_sources method."""

    @staticmethod
    def _json_response(payload):
        """Return a mock HTTP response whose ``json()`` yields *payload*.

        ``raise_for_status`` is a plain no-op mock, so the response looks
        successful to the code under test.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _http_status_error(status_code, text):
        """Build an ``httpx.HTTPStatusError`` wrapping a mock response."""
        import httpx

        response = MagicMock()
        response.status_code = status_code
        response.text = text
        return httpx.HTTPStatusError("Error", request=MagicMock(), response=response)

    @staticmethod
    def _mock_async_client(mock_client_cls, **get_behavior):
        """Wire a patched ``httpx.AsyncClient`` class to one AsyncMock instance.

        *get_behavior* is applied to the instance's ``get`` mock (for example
        ``return_value=...`` or ``side_effect=...``). ``__aenter__`` returns the
        instance itself, so code that uses the client as an async context
        manager talks to the same mock. The instance is returned so tests can
        inspect the recorded calls.
        """
        instance = AsyncMock()
        instance.get.configure_mock(**get_behavior)
        instance.__aenter__.return_value = instance
        instance.__aexit__.return_value = None
        mock_client_cls.return_value = instance
        return instance

    @pytest.mark.asyncio
    async def test_discover_success(self, client):
        """Test successful data source discovery."""
        mock_response = self._json_response(
            {
                "user_name": "test-user",
                "accessible_data_sources": [
                    {"name": "corpus1", "compliance_level": "CUI"},
                    {"name": "corpus2", "compliance_level": "Public"},
                ],
            }
        )

        with patch("httpx.AsyncClient") as mock_client:
            mock_instance = self._mock_async_client(
                mock_client, return_value=mock_response
            )

            result = await client.discover_data_sources("test-user")

            assert len(result) == 2
            assert isinstance(result[0], DataSource)
            assert result[0].name == "corpus1"
            assert result[0].compliance_level == "CUI"
            assert result[1].name == "corpus2"
            assert result[1].compliance_level == "Public"

            # Verify correct URL and params
            mock_instance.get.assert_called_once()
            call_args = mock_instance.get.call_args
            assert call_args[0][0] == "https://rag-api.example.com/discover/datasources"
            assert call_args[1]["params"] == {"as_user": "test-user"}
            assert call_args[1]["headers"]["Authorization"] == "Bearer test-token"

    @pytest.mark.asyncio
    async def test_discover_empty_response(self, client):
        """Test discovery with no accessible data sources."""
        mock_response = self._json_response(
            {"user_name": "test-user", "accessible_data_sources": []}
        )

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, return_value=mock_response)

            result = await client.discover_data_sources("test-user")

            assert result == []

    @pytest.mark.asyncio
    async def test_discover_http_error(self, client):
        """Test discovery handles HTTP errors gracefully."""
        error = self._http_status_error(500, "Internal Server Error")

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            result = await client.discover_data_sources("test-user")

            # Should return empty list on error
            assert result == []

    @pytest.mark.asyncio
    async def test_discover_request_error(self, client):
        """Test discovery handles network/request errors gracefully."""
        import httpx

        error = httpx.RequestError("Connection failed", request=MagicMock())

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            result = await client.discover_data_sources("test-user")

            # Should return empty list on error
            assert result == []
174
+
175
+
176
class TestQueryRag:
    """Tests for query_rag method."""

    @staticmethod
    def _json_response(payload):
        """Return a mock HTTP response whose ``json()`` yields *payload*.

        ``raise_for_status`` is a plain no-op mock, so the response looks
        successful to the code under test.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _http_status_error(status_code, text):
        """Build an ``httpx.HTTPStatusError`` wrapping a mock response."""
        import httpx

        response = MagicMock()
        response.status_code = status_code
        response.text = text
        return httpx.HTTPStatusError("Error", request=MagicMock(), response=response)

    @staticmethod
    def _mock_async_client(mock_client_cls, **post_behavior):
        """Wire a patched ``httpx.AsyncClient`` class to one AsyncMock instance.

        *post_behavior* is applied to the instance's ``post`` mock (for example
        ``return_value=...`` or ``side_effect=...``). ``__aenter__`` returns the
        instance itself, so code that uses the client as an async context
        manager talks to the same mock. The instance is returned so tests can
        inspect the recorded calls.
        """
        instance = AsyncMock()
        instance.post.configure_mock(**post_behavior)
        instance.__aenter__.return_value = instance
        instance.__aexit__.return_value = None
        mock_client_cls.return_value = instance
        return instance

    @pytest.mark.asyncio
    async def test_query_success(self, client):
        """Test successful RAG query."""
        mock_response = self._json_response(
            {
                "id": "chatcmpl-xxx",
                "object": "chat.completion",
                "model": "test-model",
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": "This is the answer."},
                        "finish_reason": "stop",
                    }
                ],
                "rag_metadata": {
                    "query_processing_time_ms": 150,
                    "documents_found": [
                        {
                            "corpus_id": "corpus1",
                            "text": "Some text",
                            "confidence_score": 0.95,
                            "content_type": "atlas-search",
                            "id": "doc-123",
                        }
                    ],
                    "data_sources": ["corpus1"],
                    "retrieval_method": "similarity",
                },
            }
        )

        with patch("httpx.AsyncClient") as mock_client:
            mock_instance = self._mock_async_client(
                mock_client, return_value=mock_response
            )

            messages = [{"role": "user", "content": "What is the answer?"}]
            result = await client.query_rag("test-user", "corpus1", messages)

            assert isinstance(result, RAGResponse)
            assert result.content == "This is the answer."
            assert result.is_completion is True  # Should detect chat.completion format
            assert result.metadata is not None
            assert result.metadata.query_processing_time_ms == 150
            assert result.metadata.data_source_name == "corpus1"
            assert result.metadata.retrieval_method == "similarity"
            assert len(result.metadata.documents_found) == 1
            assert result.metadata.documents_found[0].source == "corpus1"
            assert result.metadata.documents_found[0].confidence_score == 0.95

            # Verify correct URL, params, and payload
            mock_instance.post.assert_called_once()
            call_args = mock_instance.post.call_args
            assert call_args[0][0] == "https://rag-api.example.com/rag/completions"
            assert call_args[1]["params"] == {"as_user": "test-user"}
            payload = call_args[1]["json"]
            assert payload["messages"] == messages
            assert payload["stream"] is False
            assert payload["model"] == "test-model"
            assert payload["top_k"] == 4
            assert payload["corpora"] == ["corpus1"]

    @pytest.mark.asyncio
    async def test_query_without_metadata(self, client):
        """Test RAG query without metadata in response."""
        mock_response = self._json_response(
            {"choices": [{"message": {"role": "assistant", "content": "Simple answer."}}]}
        )

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, return_value=mock_response)

            messages = [{"role": "user", "content": "Question"}]
            result = await client.query_rag("test-user", "corpus1", messages)

            assert result.content == "Simple answer."
            assert result.metadata is None
            assert result.is_completion is False  # No 'object' field in response

    @pytest.mark.asyncio
    async def test_query_403_forbidden(self, client):
        """Test RAG query raises HTTPException on 403."""
        error = self._http_status_error(403, "Access denied")

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            messages = [{"role": "user", "content": "Question"}]
            with pytest.raises(HTTPException) as exc_info:
                await client.query_rag("test-user", "corpus1", messages)

            assert exc_info.value.status_code == 403
            assert "Access denied" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_query_404_not_found(self, client):
        """Test RAG query raises HTTPException on 404."""
        error = self._http_status_error(404, "Not found")

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            messages = [{"role": "user", "content": "Question"}]
            with pytest.raises(HTTPException) as exc_info:
                await client.query_rag("test-user", "corpus1", messages)

            assert exc_info.value.status_code == 404
            # Case-insensitive: the mocked body is "Not found", while the
            # client may phrase its own message with different casing.
            assert "not found" in exc_info.value.detail.lower()

    @pytest.mark.asyncio
    async def test_query_500_error(self, client):
        """Test RAG query raises HTTPException on 500."""
        error = self._http_status_error(500, "Internal error")

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            messages = [{"role": "user", "content": "Question"}]
            with pytest.raises(HTTPException) as exc_info:
                await client.query_rag("test-user", "corpus1", messages)

            assert exc_info.value.status_code == 500

    @pytest.mark.asyncio
    async def test_query_connection_error(self, client):
        """Test RAG query raises HTTPException on connection error."""
        import httpx

        error = httpx.RequestError("Connection failed", request=MagicMock())

        with patch("httpx.AsyncClient") as mock_client:
            self._mock_async_client(mock_client, side_effect=error)

            messages = [{"role": "user", "content": "Question"}]
            with pytest.raises(HTTPException) as exc_info:
                await client.query_rag("test-user", "corpus1", messages)

            assert exc_info.value.status_code == 500
            assert "connect" in exc_info.value.detail.lower()
362
+
363
+
364
class TestParseRagMetadata:
    """Tests for the client's ``_parse_rag_metadata`` method.

    ``_parse_rag_metadata(data, fallback_name)`` reads ``data["rag_metadata"]``
    and returns a metadata object, or ``None`` when that key is missing/None.
    """

    def test_parse_full_metadata(self, client):
        """A complete ``rag_metadata`` payload is mapped onto the result.

        Checks the field mapping: ``data_sources[0]`` -> ``data_source_name``,
        and per document ``corpus_id`` -> ``source``, ``id`` -> ``chunk_id``.
        """
        data = {
            "rag_metadata": {
                "query_processing_time_ms": 200,
                "documents_found": [
                    {
                        "corpus_id": "test-corpus",
                        "content_type": "document",
                        "confidence_score": 0.85,
                        "id": "doc-456",
                        "last_modified": "2025-01-01T00:00:00Z",
                    }
                ],
                "data_sources": ["test-corpus"],
                "retrieval_method": "hybrid",
            }
        }

        result = client._parse_rag_metadata(data, "fallback-corpus")

        assert result is not None
        assert result.query_processing_time_ms == 200
        assert result.data_source_name == "test-corpus"
        assert result.retrieval_method == "hybrid"
        assert len(result.documents_found) == 1
        assert result.documents_found[0].source == "test-corpus"
        assert result.documents_found[0].chunk_id == "doc-456"
        assert result.documents_found[0].last_modified == "2025-01-01T00:00:00Z"

    def test_parse_metadata_with_fallback_datasource(self, client):
        """An empty ``data_sources`` list falls back to the supplied name.

        The remaining fields must still be passed through unchanged.
        """
        data = {
            "rag_metadata": {
                "query_processing_time_ms": 100,
                "documents_found": [],
                "data_sources": [],
                "retrieval_method": "similarity",
            }
        }

        result = client._parse_rag_metadata(data, "fallback-corpus")

        # Guard first: a None result should fail here with a clear message
        # rather than as an AttributeError on the next line.
        assert result is not None
        assert result.data_source_name == "fallback-corpus"
        assert result.query_processing_time_ms == 100
        assert result.retrieval_method == "similarity"
        assert len(result.documents_found) == 0

    def test_parse_no_metadata(self, client):
        """No ``rag_metadata`` key at all yields ``None``."""
        data = {"choices": []}
        result = client._parse_rag_metadata(data, "corpus")
        assert result is None

    def test_parse_empty_metadata(self, client):
        """A ``rag_metadata`` value of ``None`` yields ``None``."""
        data = {"rag_metadata": None}
        result = client._parse_rag_metadata(data, "corpus")
        assert result is None
423
+
424
+
425
class TestFactoryFunction:
    """Tests for the ``create_atlas_rag_client_from_config`` factory."""

    def test_factory_creates_client_from_config(self):
        """The factory builds an AtlasRAGClient from ``app_settings``.

        The ``external_rag_*`` settings are expected to land on the matching
        client attributes.
        """
        from atlas.modules.rag.atlas_rag_client import create_atlas_rag_client_from_config

        # Keyword arguments to MagicMock configure the attributes directly.
        mock_settings = MagicMock(
            external_rag_url="https://test-api.example.com",
            external_rag_bearer_token="factory-token",
            external_rag_default_model="factory-model",
            external_rag_top_k=8,
        )
        mock_config_manager = MagicMock(app_settings=mock_settings)

        client = create_atlas_rag_client_from_config(mock_config_manager)

        assert isinstance(client, AtlasRAGClient)
        expected_attributes = {
            "base_url": "https://test-api.example.com",
            "bearer_token": "factory-token",
            "default_model": "factory-model",
            "top_k": 8,
        }
        for attribute_name, expected_value in expected_attributes.items():
            assert getattr(client, attribute_name) == expected_value, attribute_name