atlas_chat-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. atlas/__init__.py +40 -0
  2. atlas/application/__init__.py +7 -0
  3. atlas/application/chat/__init__.py +7 -0
  4. atlas/application/chat/agent/__init__.py +10 -0
  5. atlas/application/chat/agent/act_loop.py +179 -0
  6. atlas/application/chat/agent/factory.py +142 -0
  7. atlas/application/chat/agent/protocols.py +46 -0
  8. atlas/application/chat/agent/react_loop.py +338 -0
  9. atlas/application/chat/agent/think_act_loop.py +171 -0
  10. atlas/application/chat/approval_manager.py +151 -0
  11. atlas/application/chat/elicitation_manager.py +191 -0
  12. atlas/application/chat/events/__init__.py +1 -0
  13. atlas/application/chat/events/agent_event_relay.py +112 -0
  14. atlas/application/chat/modes/__init__.py +1 -0
  15. atlas/application/chat/modes/agent.py +125 -0
  16. atlas/application/chat/modes/plain.py +74 -0
  17. atlas/application/chat/modes/rag.py +81 -0
  18. atlas/application/chat/modes/tools.py +179 -0
  19. atlas/application/chat/orchestrator.py +213 -0
  20. atlas/application/chat/policies/__init__.py +1 -0
  21. atlas/application/chat/policies/tool_authorization.py +99 -0
  22. atlas/application/chat/preprocessors/__init__.py +1 -0
  23. atlas/application/chat/preprocessors/message_builder.py +92 -0
  24. atlas/application/chat/preprocessors/prompt_override_service.py +104 -0
  25. atlas/application/chat/service.py +454 -0
  26. atlas/application/chat/utilities/__init__.py +6 -0
  27. atlas/application/chat/utilities/error_handler.py +367 -0
  28. atlas/application/chat/utilities/event_notifier.py +546 -0
  29. atlas/application/chat/utilities/file_processor.py +613 -0
  30. atlas/application/chat/utilities/tool_executor.py +789 -0
  31. atlas/atlas_chat_cli.py +347 -0
  32. atlas/atlas_client.py +238 -0
  33. atlas/core/__init__.py +0 -0
  34. atlas/core/auth.py +205 -0
  35. atlas/core/authorization_manager.py +27 -0
  36. atlas/core/capabilities.py +123 -0
  37. atlas/core/compliance.py +215 -0
  38. atlas/core/domain_whitelist.py +147 -0
  39. atlas/core/domain_whitelist_middleware.py +82 -0
  40. atlas/core/http_client.py +28 -0
  41. atlas/core/log_sanitizer.py +102 -0
  42. atlas/core/metrics_logger.py +59 -0
  43. atlas/core/middleware.py +131 -0
  44. atlas/core/otel_config.py +242 -0
  45. atlas/core/prompt_risk.py +200 -0
  46. atlas/core/rate_limit.py +0 -0
  47. atlas/core/rate_limit_middleware.py +64 -0
  48. atlas/core/security_headers_middleware.py +51 -0
  49. atlas/domain/__init__.py +37 -0
  50. atlas/domain/chat/__init__.py +1 -0
  51. atlas/domain/chat/dtos.py +85 -0
  52. atlas/domain/errors.py +96 -0
  53. atlas/domain/messages/__init__.py +12 -0
  54. atlas/domain/messages/models.py +160 -0
  55. atlas/domain/rag_mcp_service.py +664 -0
  56. atlas/domain/sessions/__init__.py +7 -0
  57. atlas/domain/sessions/models.py +36 -0
  58. atlas/domain/unified_rag_service.py +371 -0
  59. atlas/infrastructure/__init__.py +10 -0
  60. atlas/infrastructure/app_factory.py +135 -0
  61. atlas/infrastructure/events/__init__.py +1 -0
  62. atlas/infrastructure/events/cli_event_publisher.py +140 -0
  63. atlas/infrastructure/events/websocket_publisher.py +140 -0
  64. atlas/infrastructure/sessions/in_memory_repository.py +56 -0
  65. atlas/infrastructure/transport/__init__.py +7 -0
  66. atlas/infrastructure/transport/websocket_connection_adapter.py +33 -0
  67. atlas/init_cli.py +226 -0
  68. atlas/interfaces/__init__.py +15 -0
  69. atlas/interfaces/events.py +134 -0
  70. atlas/interfaces/llm.py +54 -0
  71. atlas/interfaces/rag.py +40 -0
  72. atlas/interfaces/sessions.py +75 -0
  73. atlas/interfaces/tools.py +57 -0
  74. atlas/interfaces/transport.py +24 -0
  75. atlas/main.py +564 -0
  76. atlas/mcp/api_key_demo/README.md +76 -0
  77. atlas/mcp/api_key_demo/main.py +172 -0
  78. atlas/mcp/api_key_demo/run.sh +56 -0
  79. atlas/mcp/basictable/main.py +147 -0
  80. atlas/mcp/calculator/main.py +149 -0
  81. atlas/mcp/code-executor/execution_engine.py +98 -0
  82. atlas/mcp/code-executor/execution_environment.py +95 -0
  83. atlas/mcp/code-executor/main.py +528 -0
  84. atlas/mcp/code-executor/result_processing.py +276 -0
  85. atlas/mcp/code-executor/script_generation.py +195 -0
  86. atlas/mcp/code-executor/security_checker.py +140 -0
  87. atlas/mcp/corporate_cars/main.py +437 -0
  88. atlas/mcp/csv_reporter/main.py +545 -0
  89. atlas/mcp/duckduckgo/main.py +182 -0
  90. atlas/mcp/elicitation_demo/README.md +171 -0
  91. atlas/mcp/elicitation_demo/main.py +262 -0
  92. atlas/mcp/env-demo/README.md +158 -0
  93. atlas/mcp/env-demo/main.py +199 -0
  94. atlas/mcp/file_size_test/main.py +284 -0
  95. atlas/mcp/filesystem/main.py +348 -0
  96. atlas/mcp/image_demo/main.py +113 -0
  97. atlas/mcp/image_demo/requirements.txt +4 -0
  98. atlas/mcp/logging_demo/README.md +72 -0
  99. atlas/mcp/logging_demo/main.py +103 -0
  100. atlas/mcp/many_tools_demo/main.py +50 -0
  101. atlas/mcp/order_database/__init__.py +0 -0
  102. atlas/mcp/order_database/main.py +369 -0
  103. atlas/mcp/order_database/signal_data.csv +1001 -0
  104. atlas/mcp/pdfbasic/main.py +394 -0
  105. atlas/mcp/pptx_generator/main.py +760 -0
  106. atlas/mcp/pptx_generator/requirements.txt +13 -0
  107. atlas/mcp/pptx_generator/run_test.sh +1 -0
  108. atlas/mcp/pptx_generator/test_pptx_generator_security.py +169 -0
  109. atlas/mcp/progress_demo/main.py +167 -0
  110. atlas/mcp/progress_updates_demo/QUICKSTART.md +273 -0
  111. atlas/mcp/progress_updates_demo/README.md +120 -0
  112. atlas/mcp/progress_updates_demo/main.py +497 -0
  113. atlas/mcp/prompts/main.py +222 -0
  114. atlas/mcp/public_demo/main.py +189 -0
  115. atlas/mcp/sampling_demo/README.md +169 -0
  116. atlas/mcp/sampling_demo/main.py +234 -0
  117. atlas/mcp/thinking/main.py +77 -0
  118. atlas/mcp/tool_planner/main.py +240 -0
  119. atlas/mcp/ui-demo/badmesh.png +0 -0
  120. atlas/mcp/ui-demo/main.py +383 -0
  121. atlas/mcp/ui-demo/templates/button_demo.html +32 -0
  122. atlas/mcp/ui-demo/templates/data_visualization.html +32 -0
  123. atlas/mcp/ui-demo/templates/form_demo.html +28 -0
  124. atlas/mcp/username-override-demo/README.md +320 -0
  125. atlas/mcp/username-override-demo/main.py +308 -0
  126. atlas/modules/__init__.py +0 -0
  127. atlas/modules/config/__init__.py +34 -0
  128. atlas/modules/config/cli.py +231 -0
  129. atlas/modules/config/config_manager.py +1096 -0
  130. atlas/modules/file_storage/__init__.py +22 -0
  131. atlas/modules/file_storage/cli.py +330 -0
  132. atlas/modules/file_storage/content_extractor.py +290 -0
  133. atlas/modules/file_storage/manager.py +295 -0
  134. atlas/modules/file_storage/mock_s3_client.py +402 -0
  135. atlas/modules/file_storage/s3_client.py +417 -0
  136. atlas/modules/llm/__init__.py +19 -0
  137. atlas/modules/llm/caller.py +287 -0
  138. atlas/modules/llm/litellm_caller.py +675 -0
  139. atlas/modules/llm/models.py +19 -0
  140. atlas/modules/mcp_tools/__init__.py +17 -0
  141. atlas/modules/mcp_tools/client.py +2123 -0
  142. atlas/modules/mcp_tools/token_storage.py +556 -0
  143. atlas/modules/prompts/prompt_provider.py +130 -0
  144. atlas/modules/rag/__init__.py +24 -0
  145. atlas/modules/rag/atlas_rag_client.py +336 -0
  146. atlas/modules/rag/client.py +129 -0
  147. atlas/routes/admin_routes.py +865 -0
  148. atlas/routes/config_routes.py +484 -0
  149. atlas/routes/feedback_routes.py +361 -0
  150. atlas/routes/files_routes.py +274 -0
  151. atlas/routes/health_routes.py +40 -0
  152. atlas/routes/mcp_auth_routes.py +223 -0
  153. atlas/server_cli.py +164 -0
  154. atlas/tests/conftest.py +20 -0
  155. atlas/tests/integration/test_mcp_auth_integration.py +152 -0
  156. atlas/tests/manual_test_sampling.py +87 -0
  157. atlas/tests/modules/mcp_tools/test_client_auth.py +226 -0
  158. atlas/tests/modules/mcp_tools/test_client_env.py +191 -0
  159. atlas/tests/test_admin_mcp_server_management_routes.py +141 -0
  160. atlas/tests/test_agent_roa.py +135 -0
  161. atlas/tests/test_app_factory_smoke.py +47 -0
  162. atlas/tests/test_approval_manager.py +439 -0
  163. atlas/tests/test_atlas_client.py +188 -0
  164. atlas/tests/test_atlas_rag_client.py +447 -0
  165. atlas/tests/test_atlas_rag_integration.py +224 -0
  166. atlas/tests/test_attach_file_flow.py +287 -0
  167. atlas/tests/test_auth_utils.py +165 -0
  168. atlas/tests/test_backend_public_url.py +185 -0
  169. atlas/tests/test_banner_logging.py +287 -0
  170. atlas/tests/test_capability_tokens_and_injection.py +203 -0
  171. atlas/tests/test_compliance_level.py +54 -0
  172. atlas/tests/test_compliance_manager.py +253 -0
  173. atlas/tests/test_config_manager.py +617 -0
  174. atlas/tests/test_config_manager_paths.py +12 -0
  175. atlas/tests/test_core_auth.py +18 -0
  176. atlas/tests/test_core_utils.py +190 -0
  177. atlas/tests/test_docker_env_sync.py +202 -0
  178. atlas/tests/test_domain_errors.py +329 -0
  179. atlas/tests/test_domain_whitelist.py +359 -0
  180. atlas/tests/test_elicitation_manager.py +408 -0
  181. atlas/tests/test_elicitation_routing.py +296 -0
  182. atlas/tests/test_env_demo_server.py +88 -0
  183. atlas/tests/test_error_classification.py +113 -0
  184. atlas/tests/test_error_flow_integration.py +116 -0
  185. atlas/tests/test_feedback_routes.py +333 -0
  186. atlas/tests/test_file_content_extraction.py +1134 -0
  187. atlas/tests/test_file_extraction_routes.py +158 -0
  188. atlas/tests/test_file_library.py +107 -0
  189. atlas/tests/test_file_manager_unit.py +18 -0
  190. atlas/tests/test_health_route.py +49 -0
  191. atlas/tests/test_http_client_stub.py +8 -0
  192. atlas/tests/test_imports_smoke.py +30 -0
  193. atlas/tests/test_interfaces_llm_response.py +9 -0
  194. atlas/tests/test_issue_access_denied_fix.py +136 -0
  195. atlas/tests/test_llm_env_expansion.py +836 -0
  196. atlas/tests/test_log_level_sensitive_data.py +285 -0
  197. atlas/tests/test_mcp_auth_routes.py +341 -0
  198. atlas/tests/test_mcp_client_auth.py +331 -0
  199. atlas/tests/test_mcp_data_injection.py +270 -0
  200. atlas/tests/test_mcp_get_authorized_servers.py +95 -0
  201. atlas/tests/test_mcp_hot_reload.py +512 -0
  202. atlas/tests/test_mcp_image_content.py +424 -0
  203. atlas/tests/test_mcp_logging.py +172 -0
  204. atlas/tests/test_mcp_progress_updates.py +313 -0
  205. atlas/tests/test_mcp_prompt_override_system_prompt.py +102 -0
  206. atlas/tests/test_mcp_prompts_server.py +39 -0
  207. atlas/tests/test_mcp_tool_result_parsing.py +296 -0
  208. atlas/tests/test_metrics_logger.py +56 -0
  209. atlas/tests/test_middleware_auth.py +379 -0
  210. atlas/tests/test_prompt_risk_and_acl.py +141 -0
  211. atlas/tests/test_rag_mcp_aggregator.py +204 -0
  212. atlas/tests/test_rag_mcp_service.py +224 -0
  213. atlas/tests/test_rate_limit_middleware.py +45 -0
  214. atlas/tests/test_routes_config_smoke.py +60 -0
  215. atlas/tests/test_routes_files_download_token.py +41 -0
  216. atlas/tests/test_routes_files_health.py +18 -0
  217. atlas/tests/test_runtime_imports.py +53 -0
  218. atlas/tests/test_sampling_integration.py +482 -0
  219. atlas/tests/test_security_admin_routes.py +61 -0
  220. atlas/tests/test_security_capability_tokens.py +65 -0
  221. atlas/tests/test_security_file_stats_scope.py +21 -0
  222. atlas/tests/test_security_header_injection.py +191 -0
  223. atlas/tests/test_security_headers_and_filename.py +63 -0
  224. atlas/tests/test_shared_session_repository.py +101 -0
  225. atlas/tests/test_system_prompt_loading.py +181 -0
  226. atlas/tests/test_token_storage.py +505 -0
  227. atlas/tests/test_tool_approval_config.py +93 -0
  228. atlas/tests/test_tool_approval_utils.py +356 -0
  229. atlas/tests/test_tool_authorization_group_filtering.py +223 -0
  230. atlas/tests/test_tool_details_in_config.py +108 -0
  231. atlas/tests/test_tool_planner.py +300 -0
  232. atlas/tests/test_unified_rag_service.py +398 -0
  233. atlas/tests/test_username_override_in_approval.py +258 -0
  234. atlas/tests/test_websocket_auth_header.py +168 -0
  235. atlas/version.py +6 -0
  236. atlas_chat-0.1.0.data/data/.env.example +253 -0
  237. atlas_chat-0.1.0.data/data/config/defaults/compliance-levels.json +44 -0
  238. atlas_chat-0.1.0.data/data/config/defaults/domain-whitelist.json +123 -0
  239. atlas_chat-0.1.0.data/data/config/defaults/file-extractors.json +74 -0
  240. atlas_chat-0.1.0.data/data/config/defaults/help-config.json +198 -0
  241. atlas_chat-0.1.0.data/data/config/defaults/llmconfig-buggy.yml +11 -0
  242. atlas_chat-0.1.0.data/data/config/defaults/llmconfig.yml +19 -0
  243. atlas_chat-0.1.0.data/data/config/defaults/mcp.json +138 -0
  244. atlas_chat-0.1.0.data/data/config/defaults/rag-sources.json +17 -0
  245. atlas_chat-0.1.0.data/data/config/defaults/splash-config.json +16 -0
  246. atlas_chat-0.1.0.dist-info/METADATA +236 -0
  247. atlas_chat-0.1.0.dist-info/RECORD +250 -0
  248. atlas_chat-0.1.0.dist-info/WHEEL +5 -0
  249. atlas_chat-0.1.0.dist-info/entry_points.txt +4 -0
  250. atlas_chat-0.1.0.dist-info/top_level.txt +1 -0
atlas/modules/llm/litellm_caller.py
@@ -0,0 +1,675 @@
+"""
+LiteLLM-based LLM calling interface that handles all modes of LLM interaction.
+
+This module provides a clean interface for calling LLMs using LiteLLM in different modes:
+- Plain LLM calls (no tools)
+- LLM calls with RAG integration
+- LLM calls with tool support
+- LLM calls with both RAG and tools
+
+LiteLLM provides unified access to multiple LLM providers with automatic
+fallbacks, cost tracking, and provider-specific optimizations.
+"""
+
+import asyncio
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import litellm
+from litellm import acompletion
+
+from atlas.core.metrics_logger import log_metric
+from atlas.modules.config.config_manager import resolve_env_var
+
+from .models import LLMResponse
+
+logger = logging.getLogger(__name__)
+
+# Configure LiteLLM settings
+litellm.drop_params = True  # Drop unsupported params instead of erroring
+
+
+class LiteLLMCaller:
+    """Clean interface for all LLM calling patterns using LiteLLM.
+
+    Note: this class may set provider-specific LLM API key environment
+    variables (for example ``OPENAI_API_KEY``) to maintain compatibility
+    with LiteLLM's internal provider detection. These mutations are
+    best-effort only and are not intended to provide strong isolation
+    guarantees in multi-tenant or highly concurrent environments.
+    """
+
+    def __init__(self, llm_config=None, debug_mode: bool = False, rag_service=None):
+        """Initialize with optional config dependency injection.
+
+        Args:
+            llm_config: LLM configuration object
+            debug_mode: Enable verbose LiteLLM logging (overridden by feature flag)
+            rag_service: UnifiedRAGService for RAG-augmented calls
+        """
+        if llm_config is None:
+            from atlas.modules.config import config_manager
+            self.llm_config = config_manager.llm_config
+        else:
+            self.llm_config = llm_config
+
+        # Store RAG service for RAG queries
+        self._rag_service = rag_service
+
+        # Set litellm verbosity based on debug mode, but respect the suppress feature flag
+        # The feature flag takes precedence - if suppression is enabled, never set verbose
+        from atlas.modules.config.config_manager import get_app_settings
+        app_settings = get_app_settings()
+        if app_settings.feature_suppress_litellm_logging:
+            litellm.set_verbose = False
+        else:
+            litellm.set_verbose = debug_mode
+
+    @staticmethod
+    def _parse_qualified_data_source(qualified_data_source: str) -> str:
+        """Extract corpus name from a qualified data source identifier.
+
+        Qualified data sources have format "server:source_id" (e.g., "atlas_rag:technical-docs").
+        The prefix is used for routing in multi-RAG setups, but the RAG API expects just
+        the corpus name.
+
+        Args:
+            qualified_data_source: Data source ID, optionally prefixed with server name.
+
+        Returns:
+            The corpus/source name without the server prefix.
+        """
+        if ":" in qualified_data_source:
+            _, data_source = qualified_data_source.split(":", 1)
+            logger.debug("Stripped RAG server prefix: %s -> %s", qualified_data_source, data_source)
+            return data_source
+        return qualified_data_source
+
+    def _build_rag_completion_response(
+        self,
+        rag_response,
+        display_source: str
+    ) -> str:
+        """Build formatted response for direct RAG completions.
+
+        Args:
+            rag_response: RAGResponse object with is_completion=True
+            display_source: Display name of the data source
+
+        Returns:
+            Formatted response string with RAG completion note and metadata
+        """
+        response_parts = []
+        response_parts.append(f"*Response from {display_source} (RAG completions endpoint):*\n")
+        response_parts.append(rag_response.content)
+
+        # Append metadata if available
+        if rag_response.metadata:
+            metadata_summary = self._format_rag_metadata(rag_response.metadata)
+            if metadata_summary and metadata_summary != "Metadata unavailable":
+                response_parts.append(f"\n\n---\n**RAG Sources & Processing Info:**\n{metadata_summary}")
+
+        return "\n".join(response_parts)
+
+    async def _query_all_rag_sources(
+        self,
+        data_sources: List[str],
+        rag_service,
+        user_email: str,
+        messages: List[Dict[str, str]],
+    ) -> List[Tuple[str, Any]]:
+        """Query all RAG data sources in parallel.
+
+        Args:
+            data_sources: Qualified data source identifiers (server:source_id).
+            rag_service: UnifiedRAGService instance.
+            user_email: User email for access control.
+            messages: Conversation messages for RAG context.
+
+        Returns:
+            List of (display_source, rag_response) tuples, one per source.
+        """
+
+        async def _query_single(qualified_source: str):
+            display = self._parse_qualified_data_source(qualified_source)
+            response = await rag_service.query_rag(user_email, qualified_source, messages)
+            return (display, response)
+
+        results = await asyncio.gather(
+            *[_query_single(src) for src in data_sources],
+            return_exceptions=True,
+        )
+
+        successful: List[Tuple[str, Any]] = []
+        for src, result in zip(data_sources, results):
+            if isinstance(result, Exception):
+                logger.error("[RAG] Failed to query source %s: %s", src, result)
+            else:
+                successful.append(result)
+
+        return successful
+
+    @staticmethod
+    def _combine_rag_contexts(
+        source_responses: List[Tuple[str, Any]],
+    ) -> Tuple[str, Optional[Any]]:
+        """Combine RAG responses from multiple sources into a single context block.
+
+        Args:
+            source_responses: List of (display_source, rag_response) tuples.
+
+        Returns:
+            (combined_content, merged_metadata) -- merged_metadata is the metadata
+            from the first source that has it, or None.
+        """
+        parts: List[str] = []
+        merged_metadata = None
+
+        for display_source, rag_response in source_responses:
+            content = rag_response.content if rag_response.content else ""
+            parts.append(f"### Context from {display_source}:\n{content}")
+            if rag_response.metadata and merged_metadata is None:
+                merged_metadata = rag_response.metadata
+
+        combined = "\n\n".join(parts)
+        return combined, merged_metadata
+
+    def _get_litellm_model_name(self, model_name: str) -> str:
+        """Convert internal model name to LiteLLM compatible format."""
+        if model_name not in self.llm_config.models:
+            raise ValueError(f"Model {model_name} not found in configuration")
+
+        model_config = self.llm_config.models[model_name]
+        model_id = model_config.model_name
+
+        # Map common providers to LiteLLM format
+        if "openrouter" in model_config.model_url:
+            return f"openrouter/{model_id}"
+        elif "openai" in model_config.model_url:
+            return f"openai/{model_id}"
+        elif "anthropic" in model_config.model_url:
+            return f"anthropic/{model_id}"
+        elif "google" in model_config.model_url:
+            return f"google/{model_id}"
+        elif "cerebras" in model_config.model_url:
+            return f"cerebras/{model_id}"
+        else:
+            # For custom endpoints, use the model_id directly
+            return model_id
+
+    def _get_model_kwargs(self, model_name: str, temperature: Optional[float] = None) -> Dict[str, Any]:
+        """Get LiteLLM kwargs for a specific model."""
+        if model_name not in self.llm_config.models:
+            raise ValueError(f"Model {model_name} not found in configuration")
+
+        model_config = self.llm_config.models[model_name]
+        kwargs = {
+            "max_tokens": model_config.max_tokens or 1000,
+        }
+
+        # Use provided temperature or fall back to config temperature
+        if temperature is not None:
+            kwargs["temperature"] = temperature
+        else:
+            kwargs["temperature"] = model_config.temperature or 0.7
+
+        # Set API key - resolve environment variables
+        try:
+            api_key = resolve_env_var(model_config.api_key)
+        except ValueError as e:
+            logger.error(f"Failed to resolve API key for model {model_name}: {e}")
+            raise
+
+        if api_key:
+            # Always pass api_key to LiteLLM for all providers
+            kwargs["api_key"] = api_key
+
+            # Additionally set provider-specific env vars for LiteLLM's internal logic
+            def _set_env_var_if_needed(env_key: str, value: str) -> None:
+                existing = os.environ.get(env_key)
+                if existing is None:
+                    os.environ[env_key] = value
+                elif existing != value:
+                    logger.warning(
+                        "Overwriting existing environment variable %s for model %s",
+                        env_key,
+                        model_name,
+                    )
+                    os.environ[env_key] = value
+
+            if "openrouter" in model_config.model_url:
+                _set_env_var_if_needed("OPENROUTER_API_KEY", api_key)
+            elif "openai" in model_config.model_url:
+                _set_env_var_if_needed("OPENAI_API_KEY", api_key)
+            elif "anthropic" in model_config.model_url:
+                _set_env_var_if_needed("ANTHROPIC_API_KEY", api_key)
+            elif "google" in model_config.model_url:
+                _set_env_var_if_needed("GOOGLE_API_KEY", api_key)
+            elif "cerebras" in model_config.model_url:
+                _set_env_var_if_needed("CEREBRAS_API_KEY", api_key)
+            else:
+                # Custom endpoint - set OPENAI_API_KEY as a fallback for
+                # OpenAI-compatible endpoints. This is a heuristic: the env
+                # var is left untouched when it already holds the same value
+                # and overwritten (with a logged warning) when it differs.
+                _set_env_var_if_needed("OPENAI_API_KEY", api_key)
+
+        # Set custom API base for non-standard endpoints
+        if hasattr(model_config, 'model_url') and model_config.model_url:
+            if not any(provider in model_config.model_url for provider in ["openrouter", "api.openai.com", "api.anthropic.com", "api.cerebras.ai"]):
+                kwargs["api_base"] = model_config.model_url
+
+        # Handle extra headers with environment variable expansion
+        if model_config.extra_headers:
+            extra_headers_resolved = {}
+            for header_key, header_value in model_config.extra_headers.items():
+                try:
+                    resolved_value = resolve_env_var(header_value)
+                    extra_headers_resolved[header_key] = resolved_value
+                except ValueError as e:
+                    logger.error(f"Failed to resolve extra header '{header_key}' for model {model_name}: {e}")
+                    raise
+            kwargs["extra_headers"] = extra_headers_resolved
+
+        return kwargs
+
+    async def call_plain(
+        self,
+        model_name: str,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        user_email: Optional[str] = None
+    ) -> str:
+        """Plain LLM call - no tools, no RAG.
+
+        Args:
+            model_name: Name of the model to use
+            messages: List of message dicts with 'role' and 'content'
+            temperature: Optional temperature override (uses config default if None)
+            max_tokens: Optional max_tokens override (uses config default if None)
+            user_email: Optional user email for metrics logging
+        """
+        litellm_model = self._get_litellm_model_name(model_name)
+        model_kwargs = self._get_model_kwargs(model_name, temperature)
+
+        # Override max_tokens if provided
+        if max_tokens is not None:
+            model_kwargs["max_tokens"] = max_tokens
+
+        try:
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            logger.info(f"Plain LLM call: {len(messages)} messages, {total_chars} chars")
+
+            response = await acompletion(
+                model=litellm_model,
+                messages=messages,
+                **model_kwargs
+            )
+
+            content = response.choices[0].message.content or ""
+            # Log response preview only at DEBUG level to avoid logging sensitive data
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"LLM response preview: '{content[:200]}{'...' if len(content) > 200 else ''}'")
+            else:
+                logger.info(f"LLM response length: {len(content)} chars")
+
+            log_metric("llm_call", user_email, model=model_name, message_count=len(messages))
+
+            return content
+
+        except Exception as exc:
+            logger.error("Error calling LLM: %s", exc, exc_info=True)
+            raise Exception(f"Failed to call LLM: {exc}")
+
+    async def call_with_rag(
+        self,
+        model_name: str,
+        messages: List[Dict[str, str]],
+        data_sources: List[str],
+        user_email: str,
+        rag_service=None,
+        temperature: float = 0.7,
+    ) -> str:
+        """LLM call with RAG integration."""
+        logger.debug(
+            "[LLM+RAG] call_with_rag called: model=%s, data_sources=%s, user=%s, message_count=%d",
+            model_name,
+            data_sources,
+            user_email,
+            len(messages),
+        )
+
+        if not data_sources:
+            logger.debug("[LLM+RAG] No data sources provided, falling back to plain LLM call")
+            return await self.call_plain(model_name, messages, temperature=temperature, user_email=user_email)
+
+        # Use provided service or instance service
+        if rag_service is None:
+            rag_service = self._rag_service
+        if rag_service is None:
+            logger.error("[LLM+RAG] RAG service not configured")
+            raise ValueError("RAG service not configured")
+
+        multi_source = len(data_sources) > 1
+        if multi_source:
+            logger.warning(
+                "[LLM+RAG] Multiple RAG sources selected (%d). All results will be "
+                "treated as raw context and sent through LLM, even if some sources "
+                "return pre-interpreted completions.",
+                len(data_sources),
+            )
+
+        logger.info(
+            "[LLM+RAG] Querying RAG: sources=%s, user=%s",
+            data_sources,
+            user_email,
+        )
+
+        try:
+            # Query all RAG sources in parallel
+            source_responses = await self._query_all_rag_sources(
+                data_sources, rag_service, user_email, messages,
+            )
+
+            if not source_responses:
+                logger.warning("[LLM+RAG] All RAG sources failed, falling back to plain LLM call")
+                return await self.call_plain(model_name, messages, temperature=temperature, user_email=user_email)
+
+            # Single source: preserve existing is_completion shortcut
+            if not multi_source:
+                display_source, rag_response = source_responses[0]
+
+                logger.debug(
+                    "[LLM+RAG] RAG response received: content_length=%d, has_metadata=%s, is_completion=%s",
+                    len(rag_response.content) if rag_response.content else 0,
+                    rag_response.metadata is not None,
+                    rag_response.is_completion,
+                )
+
+                if rag_response.is_completion:
+                    logger.info(
+                        "[LLM+RAG] RAG returned chat completion - returning directly without LLM processing"
+                    )
+                    final_response = self._build_rag_completion_response(rag_response, display_source)
+                    logger.info(
+                        "[LLM+RAG] Returning RAG completion directly: response_length=%d",
+                        len(final_response),
+                    )
+                    return final_response
+
+                rag_content = rag_response.content
+                rag_metadata = rag_response.metadata
+                context_label = f"Retrieved context from {display_source}"
+            else:
+                # Multiple sources: combine all as raw context
+                rag_content, rag_metadata = self._combine_rag_contexts(source_responses)
+                context_label = f"Retrieved context from {len(source_responses)} RAG sources"
+
+            # Integrate RAG context into messages
+            messages_with_rag = messages.copy()
+            rag_context_message = {
+                "role": "system",
+                "content": f"{context_label}:\n\n{rag_content}\n\nUse this context to inform your response."
+            }
+            messages_with_rag.insert(-1, rag_context_message)
+
+            logger.debug("[LLM+RAG] Calling LLM with RAG-enriched context...")
+            llm_response = await self.call_plain(model_name, messages_with_rag, temperature=temperature, user_email=user_email)
+
+            # Only append metadata if RAG actually provided useful content
+            rag_content_useful = bool(
+                rag_content
+                and rag_content.strip()
+                and rag_content not in (
+                    "No response from RAG system.",
+                    "No response from MCP RAG.",
+                    "No matching vehicles found.",
+                )
+            )
+
+            if rag_content_useful and rag_metadata:
+                metadata_summary = self._format_rag_metadata(rag_metadata)
+                if metadata_summary and metadata_summary != "Metadata unavailable":
+                    llm_response += f"\n\n---\n**RAG Sources & Processing Info:**\n{metadata_summary}"
+
+            logger.info(
+                "[LLM+RAG] RAG-integrated query complete: response_length=%d, rag_content_useful=%s",
+                len(llm_response),
+                rag_content_useful,
+            )
+            return llm_response
+
+        except Exception as exc:
+            logger.error("[LLM+RAG] Error in RAG-integrated query: %s", exc, exc_info=True)
+            logger.warning("[LLM+RAG] Falling back to plain LLM call due to RAG error")
+            return await self.call_plain(model_name, messages, temperature=temperature, user_email=user_email)
+
+    async def call_with_tools(
+        self,
+        model_name: str,
+        messages: List[Dict[str, str]],
+        tools_schema: List[Dict],
+        tool_choice: str = "auto",
+        temperature: float = 0.7,
+        user_email: Optional[str] = None
+    ) -> LLMResponse:
+        """LLM call with tool support using LiteLLM."""
+        if not tools_schema:
+            content = await self.call_plain(model_name, messages, temperature=temperature, user_email=user_email)
+            return LLMResponse(content=content, model_used=model_name)
+
+        litellm_model = self._get_litellm_model_name(model_name)
+        model_kwargs = self._get_model_kwargs(model_name, temperature)
+
+        # Handle tool_choice parameter - try "required" first, fallback to "auto" if unsupported
+        final_tool_choice = tool_choice
+
+        try:
+            total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+            logger.info(f"LLM call with tools: {len(messages)} messages, {total_chars} chars, {len(tools_schema)} tools")
+
+            response = await acompletion(
+                model=litellm_model,
+                messages=messages,
+                tools=tools_schema,
+                tool_choice=final_tool_choice,
+                **model_kwargs
+            )
+
+            message = response.choices[0].message
+
+            if tool_choice == "required" and not getattr(message, 'tool_calls', None):
+                logger.error(f"LLM failed to return tool calls when tool_choice was 'required'. Full response: {response}")
+                raise ValueError("LLM failed to return tool calls when tool_choice was 'required'.")
+
+            tool_calls = getattr(message, 'tool_calls', None)
+            tool_count = len(tool_calls) if tool_calls else 0
+            log_metric("llm_call", user_email, model=model_name, message_count=len(messages), tool_count=tool_count)
+
+            return LLMResponse(
+                content=getattr(message, 'content', None) or "",
+                tool_calls=tool_calls,
+                model_used=model_name
+            )
+
+        except Exception as exc:
+            # If we used "required" and it failed, try again with "auto"
+            if tool_choice == "required" and final_tool_choice == "required":
+                logger.warning(f"Tool choice 'required' failed, retrying with 'auto': {exc}")
+                try:
+                    response = await acompletion(
+                        model=litellm_model,
+                        messages=messages,
+                        tools=tools_schema,
+                        tool_choice="auto",
+                        **model_kwargs
+                    )
+
+                    message = response.choices[0].message
+                    return LLMResponse(
+                        content=getattr(message, 'content', None) or "",
+                        tool_calls=getattr(message, 'tool_calls', None),
+                        model_used=model_name
+                    )
+                except Exception as retry_exc:
+                    logger.error("Retry with tool_choice='auto' also failed: %s", retry_exc, exc_info=True)
+                    raise Exception(f"Failed to call LLM with tools: {retry_exc}")
+
+            logger.error("Error calling LLM with tools: %s", exc, exc_info=True)
+            raise Exception(f"Failed to call LLM with tools: {exc}")
+
+    async def call_with_rag_and_tools(
+        self,
+        model_name: str,
+        messages: List[Dict[str, str]],
+        data_sources: List[str],
+        tools_schema: List[Dict],
+        user_email: str,
+        tool_choice: str = "auto",
+        rag_service=None,
+        temperature: float = 0.7,
+    ) -> LLMResponse:
+        """Full integration: RAG + Tools."""
+        logger.debug(
+            "[LLM+RAG+Tools] call_with_rag_and_tools called: model=%s, data_sources=%s, user=%s, tools_count=%d",
+            model_name,
+            data_sources,
+            user_email,
+            len(tools_schema) if tools_schema else 0,
+        )
+
+        if not data_sources:
+            logger.debug("[LLM+RAG+Tools] No data sources provided, falling back to tools-only call")
+            return await self.call_with_tools(model_name, messages, tools_schema, tool_choice, temperature=temperature, user_email=user_email)
+
+        # Use provided service or instance service
+        if rag_service is None:
+            rag_service = self._rag_service
+        if rag_service is None:
+            logger.error("[LLM+RAG+Tools] RAG service not configured")
+            raise ValueError("RAG service not configured")
+
+        multi_source = len(data_sources) > 1
+        if multi_source:
+            logger.warning(
+                "[LLM+RAG+Tools] Multiple RAG sources selected (%d). All results will be "
+                "treated as raw context and sent through LLM, even if some sources "
+                "return pre-interpreted completions.",
+                len(data_sources),
+            )
+
+        logger.info(
+            "[LLM+RAG+Tools] Querying RAG: sources=%s, user=%s",
+            data_sources,
+            user_email,
+        )
+
+        try:
+            # Query all RAG sources in parallel
+            source_responses = await self._query_all_rag_sources(
+                data_sources, rag_service, user_email, messages,
+            )
+
+            if not source_responses:
+                logger.warning("[LLM+RAG+Tools] All RAG sources failed, falling back to tools-only call")
+                return await self.call_with_tools(model_name, messages, tools_schema, tool_choice, temperature=temperature, user_email=user_email)
+
+            # Single source: preserve existing is_completion shortcut
+            if not multi_source:
+                display_source, rag_response = source_responses[0]
+
+                logger.debug(
+                    "[LLM+RAG+Tools] RAG response received: content_length=%d, has_metadata=%s, is_completion=%s",
+                    len(rag_response.content) if rag_response.content else 0,
+                    rag_response.metadata is not None,
+                    rag_response.is_completion,
+                )
+
+                if rag_response.is_completion:
+                    logger.info(
+                        "[LLM+RAG+Tools] RAG returned chat completion - returning directly without LLM processing"
+                    )
+                    final_response = self._build_rag_completion_response(rag_response, display_source)
+                    logger.info(
+                        "[LLM+RAG+Tools] Returning RAG completion directly: response_length=%d",
+                        len(final_response),
+                    )
+                    return LLMResponse(content=final_response)
+
+                rag_content = rag_response.content
+                rag_metadata = rag_response.metadata
+                context_label = f"Retrieved context from {display_source}"
+            else:
+                # Multiple sources: combine all as raw context
+                rag_content, rag_metadata = self._combine_rag_contexts(source_responses)
+                context_label = f"Retrieved context from {len(source_responses)} RAG sources"
+
+            # Integrate RAG context into messages
+            messages_with_rag = messages.copy()
+            rag_context_message = {
+                "role": "system",
+                "content": f"{context_label}:\n\n{rag_content}\n\nUse this context to inform your response."
+            }
+            messages_with_rag.insert(-1, rag_context_message)
+
+            logger.debug("[LLM+RAG+Tools] Calling LLM with RAG-enriched context and tools...")
+            llm_response = await self.call_with_tools(model_name, messages_with_rag, tools_schema, tool_choice, temperature=temperature, user_email=user_email)
+
+            # Only append metadata if RAG actually provided useful content
+            rag_content_useful = bool(
+                rag_content
+                and rag_content.strip()
+                and rag_content not in (
+                    "No response from RAG system.",
+                    "No response from MCP RAG.",
+                    "No matching vehicles found.",
+                )
+            )
+
+            if rag_content_useful and rag_metadata and not llm_response.has_tool_calls():
+                metadata_summary = self._format_rag_metadata(rag_metadata)
+                if metadata_summary and metadata_summary != "Metadata unavailable":
+                    llm_response.content += f"\n\n---\n**RAG Sources & Processing Info:**\n{metadata_summary}"
+
+            logger.info(
+                "[LLM+RAG+Tools] RAG+tools query complete: response_length=%d, has_tool_calls=%s, rag_content_useful=%s",
+                len(llm_response.content) if llm_response.content else 0,
+                llm_response.has_tool_calls(),
+                rag_content_useful,
+            )
+            return llm_response
+
+        except Exception as exc:
+            logger.error("[LLM+RAG+Tools] Error in RAG+tools integrated query: %s", exc, exc_info=True)
+            logger.warning("[LLM+RAG+Tools] Falling back to tools-only call due to RAG error")
+            return await self.call_with_tools(model_name, messages, tools_schema, tool_choice, temperature=temperature, user_email=user_email)
+
+    def _format_rag_metadata(self, metadata) -> str:
+        """Format RAG metadata into a user-friendly summary."""
+        # Import here to avoid circular imports
+        try:
+            from atlas.modules.rag.models import RAGMetadata
+            if not isinstance(metadata, RAGMetadata):
+                return "Metadata unavailable"
+        except ImportError:
+            return "Metadata unavailable"
+
+        summary_parts = []
+        summary_parts.append(f" **Data Source:** {metadata.data_source_name}")
+        summary_parts.append(f" **Processing Time:** {metadata.query_processing_time_ms}ms")
+
+        if metadata.documents_found:
+            summary_parts.append(f" **Documents Found:** {len(metadata.documents_found)} (searched {metadata.total_documents_searched})")
+
+            for i, doc in enumerate(metadata.documents_found[:3]):
+                confidence_percent = int(doc.confidence_score * 100)
+                summary_parts.append(f" • {doc.source} ({confidence_percent}% relevance, {doc.content_type})")
+
+            if len(metadata.documents_found) > 3:
+                remaining = len(metadata.documents_found) - 3
+                summary_parts.append(f" • ... and {remaining} more document(s)")
+
+        summary_parts.append(f" **Retrieval Method:** {metadata.retrieval_method}")
+        return "\n".join(summary_parts)
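For orientation, here is a minimal usage sketch of the `LiteLLMCaller` class added above, assuming the package is installed and that a model entry named `my-gpt4` exists under `models:` in `llmconfig.yml`. The model name, email, and asyncio scaffolding are illustrative placeholders, not shipped package code.

    import asyncio

    from atlas.modules.llm.litellm_caller import LiteLLMCaller

    async def main() -> None:
        # With no llm_config argument, the caller falls back to
        # config_manager.llm_config (see __init__ in the diff above).
        caller = LiteLLMCaller()
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the Atlas architecture."},
        ]
        # Plain mode: no tools, no RAG. "my-gpt4" is a hypothetical model
        # name that must match a configured model.
        reply = await caller.call_plain("my-gpt4", messages, user_email="dev@example.com")
        print(reply)

    asyncio.run(main())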
atlas/modules/llm/models.py
@@ -0,0 +1,19 @@
+"""
+Data models for LLM responses and related structures.
+"""
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+
+@dataclass
+class LLMResponse:
+    """Response from LLM call with metadata."""
+    content: str
+    tool_calls: Optional[List[Dict]] = None
+    model_used: str = ""
+    tokens_used: int = 0
+
+    def has_tool_calls(self) -> bool:
+        """Check if response contains tool calls."""
+        return self.tool_calls is not None and len(self.tool_calls) > 0
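And a minimal sketch of how `LLMResponse` is consumed by the caller code above; the tool-call dict mirrors the OpenAI-style structure LiteLLM returns, but the specific values are hypothetical.

    from atlas.modules.llm.models import LLMResponse

    # A response carrying one (hypothetical) tool call
    resp = LLMResponse(
        content="",
        tool_calls=[{
            "id": "call_1",
            "type": "function",
            "function": {"name": "calculator", "arguments": '{"a": 2, "b": 3}'},
        }],
        model_used="my-gpt4",
    )
    assert resp.has_tool_calls()

    # A plain text response with no tool calls
    plain = LLMResponse(content="Hello!", model_used="my-gpt4")
    assert not plain.has_tool_calls()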