lightspeed-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
app/endpoints/query.py ADDED
@@ -0,0 +1,360 @@
1
+ """Handler for REST API call to provide answer to query."""
2
+
3
+ from datetime import datetime, UTC
4
+ import json
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from cachetools import TTLCache # type: ignore
11
+
12
+ from llama_stack_client.lib.agents.agent import Agent
13
+ from llama_stack_client import APIConnectionError
14
+ from llama_stack_client import LlamaStackClient # type: ignore
15
+ from llama_stack_client.types import UserMessage # type: ignore
16
+ from llama_stack_client.types.agents.turn_create_params import (
17
+ ToolgroupAgentToolGroupWithArgs,
18
+ Toolgroup,
19
+ )
20
+ from llama_stack_client.types.model_list_response import ModelListResponse
21
+
22
+ from fastapi import APIRouter, HTTPException, status, Depends
23
+
24
+ from client import get_llama_stack_client
25
+ from configuration import configuration
26
+ from models.responses import QueryResponse
27
+ from models.requests import QueryRequest, Attachment
28
+ import constants
29
+ from auth import get_auth_dependency
30
+ from utils.common import retrieve_user_id
31
+ from utils.endpoints import check_configuration_loaded, get_system_prompt
32
+ from utils.mcp_headers import mcp_headers_dependency
33
+ from utils.suid import get_suid
34
+
35
# Module-level logger. The logger name is a fixed string rather than
# ``__name__``, so records are emitted under "app.endpoints.handlers".
logger = logging.getLogger("app.endpoints.handlers")
# Router on which the /query handler in this module is registered.
router = APIRouter(tags=["query"])
# Auth dependency injected into the handler through ``Depends``.
auth_dependency = get_auth_dependency()

# Global agent registry to persist agents across requests.
# Entries are keyed by conversation ID (see ``get_agent``); the cache is
# bounded to 1000 entries and configured with a TTL of 3600 seconds.
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)
41
+
42
+ query_response: dict[int | str, dict[str, Any]] = {
43
+ 200: {
44
+ "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
45
+ "response": "LLM ansert",
46
+ },
47
+ 503: {
48
+ "detail": {
49
+ "response": "Unable to connect to Llama Stack",
50
+ "cause": "Connection error.",
51
+ }
52
+ },
53
+ }
54
+
55
+
56
def is_transcripts_enabled() -> bool:
    """Tell whether transcript collection is switched on.

    The answer is the negation of the ``transcripts_disabled`` flag found in
    the user data collection section of the loaded configuration.

    Returns:
        bool: True when transcripts are collected, False otherwise.
    """
    data_collection = configuration.user_data_collection_configuration
    return not data_collection.transcripts_disabled
63
+
64
+
65
def get_agent(
    client: LlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_shields: list[str],
    conversation_id: str | None,
) -> tuple[Agent, str]:
    """Return the agent for a conversation, creating it when necessary.

    If ``conversation_id`` is given and an agent is cached under it, that
    agent is reused. Otherwise a new agent is built, a new session is opened
    on it, and the agent is cached under the ID of that session. A supplied
    conversation ID that is not in the cache is therefore replaced by the
    newly created session ID in the returned tuple.

    Args:
        client: Llama Stack client the new agent is bound to.
        model_id: Identifier of the model the new agent should use.
        system_prompt: Instructions passed to the new agent.
        available_shields: Shield identifiers used as input shields.
        conversation_id: ID of an existing conversation, or None.

    Returns:
        The agent together with the conversation ID it is registered under.
    """
    if conversation_id is not None and (cached := _agent_cache.get(conversation_id)):
        logger.debug("Reusing existing agent with key: %s", conversation_id)
        return cached, conversation_id

    logger.debug("Creating new agent")
    # TODO(lucasagomes): move to ReActAgent
    new_agent = Agent(
        client,
        model=model_id,
        instructions=system_prompt,
        input_shields=available_shields or [],
        enable_session_persistence=True,
    )
    session_id = new_agent.create_session(get_suid())
    # remember the agent so follow-up requests can continue the session
    _agent_cache[session_id] = new_agent
    return new_agent, session_id
91
+
92
+
93
@router.post("/query", responses=query_response)
def query_endpoint_handler(
    query_request: QueryRequest,
    auth: Any = Depends(auth_dependency),
    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> QueryResponse:
    """Handle request to the /query endpoint.

    Picks a model, asks the agent for an answer and, when transcript
    collection is enabled, stores a transcript of the exchange.

    Args:
        query_request: The parsed request body.
        auth: Value produced by the auth dependency.
        mcp_headers: Per-server headers produced by the MCP headers dependency.

    Returns:
        QueryResponse: The conversation ID and the LLM response.

    Raises:
        HTTPException: With status 500 when Llama Stack cannot be reached.
    """
    check_configuration_loaded(configuration)

    llama_stack_config = configuration.llama_stack_configuration
    logger.info("LLama stack config: %s", llama_stack_config)

    try:
        # try to get Llama Stack client
        client = get_llama_stack_client(llama_stack_config)
        available_models = client.models.list()
        model_id = select_model_id(available_models, query_request)
        response, conversation_id = retrieve_response(
            client,
            model_id,
            query_request,
            auth,
            mcp_headers=mcp_headers,
        )

        if is_transcripts_enabled():
            store_transcript(
                user_id=retrieve_user_id(auth),
                conversation_id=conversation_id,
                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
                query=query_request.query,
                query_request=query_request,
                response=response,
                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
                attachments=query_request.attachments or [],
            )
        else:
            logger.debug("Transcript collection is disabled in the configuration")

        return QueryResponse(conversation_id=conversation_id, response=response)
    except APIConnectionError as e:
        # connection to Llama Stack server failed
        logger.error("Unable to connect to Llama Stack: %s", e)
        error_detail = {
            "response": "Unable to connect to Llama Stack",
            "cause": str(e),
        }
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=error_detail,
        ) from e
143
+
144
+
145
def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
    """Pick the model ID to use for a query.

    When the request names a model, that model/provider pair must be present
    in ``models``. When it does not, the identifier of the first model whose
    ``model_type`` is "llm" is used.

    Args:
        models: Models reported by Llama Stack.
        query_request: The request that may carry ``model`` and ``provider``.

    Returns:
        The identifier of the selected model.

    Raises:
        HTTPException: With status 400 when the requested model is not
            available or when no LLM model can be found.
    """
    model_id = query_request.model
    provider_id = query_request.provider

    if model_id:
        logger.info("Searching for model: %s, provider: %s", model_id, provider_id)
        is_available = any(
            m.identifier == model_id and m.provider_id == provider_id for m in models
        )
        if is_available:
            return model_id

        message = f"Model {model_id} from provider {provider_id} not found in available models"
        logger.error(message)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": message,
            },
        )

    # TODO(lucasagomes): support default model selection via configuration
    logger.info("No model specified in request, using the first available LLM")
    try:
        first_llm = next(
            m
            for m in models
            if m.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
        )
        selected = first_llm.identifier
    except (StopIteration, AttributeError) as e:
        message = "No LLM model found in available models"
        logger.error(message)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": message,
            },
        ) from e
    else:
        logger.info("Selected model: %s", selected)
        return selected
187
+
188
+
189
def retrieve_response(
    client: LlamaStackClient,
    model_id: str,
    query_request: QueryRequest,
    token: str,
    mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
    """Retrieve response from LLMs and agents.

    Args:
        client: Llama Stack client used to list shields and vector DBs and
            to back the agent.
        model_id: Identifier of the model to use when a new agent is created.
        query_request: The request carrying the query, optional attachments
            and an optional conversation ID.
        token: Value placed in a ``Bearer`` Authorization header for every
            configured MCP server when no MCP headers were supplied.
        mcp_headers: Headers per MCP server URL. The mapping passed in is
            never modified; a copy is used internally.

    Returns:
        The response content converted to a string, and the conversation ID
        of the session that produced it.

    Raises:
        HTTPException: Propagated from attachment validation when an
            attachment has an unsupported type or content type.
    """
    available_shields = [shield.identifier for shield in client.shields.list()]
    if not available_shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields found: %s", available_shields)

    # use system prompt from request or default one
    system_prompt = get_system_prompt(query_request, configuration)
    logger.debug("Using system prompt: %s", system_prompt)

    # TODO(lucasagomes): redact attachments content before sending to LLM
    # if attachments are provided, validate them
    if query_request.attachments:
        validate_attachments_metadata(query_request.attachments)

    agent, conversation_id = get_agent(
        client,
        model_id,
        system_prompt,
        available_shields,
        query_request.conversation_id,
    )

    # Preserve compatibility when mcp_headers is not provided, and work on a
    # private copy so the caller's dictionary is never mutated below.
    mcp_headers = dict(mcp_headers) if mcp_headers else {}
    if not mcp_headers and token:
        # no explicit headers: fall back to the token for every MCP server
        for mcp_server in configuration.mcp_servers:
            mcp_headers[mcp_server.url] = {
                "Authorization": f"Bearer {token}",
            }

    agent.extra_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(
            {
                "mcp_headers": mcp_headers,
            }
        ),
    }

    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
    # RAG tool group (if any vector DB exists) followed by the MCP server names
    toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
        mcp_server.name for mcp_server in configuration.mcp_servers
    ]
    response = agent.create_turn(
        messages=[UserMessage(role="user", content=query_request.query)],
        session_id=conversation_id,
        documents=query_request.get_documents(),
        stream=False,
        toolgroups=toolgroups or None,
    )

    return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
250
+
251
+
252
def validate_attachments_metadata(attachments: list[Attachment]) -> None:
    """Check type and content type of every attachment in the request.

    Attachments are examined in order; for each one the attachment type is
    checked first and the content type second. The first violation found is
    logged and reported.

    Raises:
        HTTPException: With status 422 when an attachment has an improper
            type or content type.
    """
    for attachment in attachments:
        problem = None
        if attachment.attachment_type not in constants.ATTACHMENT_TYPES:
            problem = (
                f"Attachment with improper type {attachment.attachment_type} detected"
            )
        elif attachment.content_type not in constants.ATTACHMENT_CONTENT_TYPES:
            problem = f"Attachment with improper content type {attachment.content_type} detected"

        if problem is None:
            continue

        logger.error(problem)
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": problem,
            },
        )
280
+
281
+
282
def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
    """Build the directory path for a user's conversation transcripts.

    The result is ``<transcripts_storage>/<user_id>/<conversation_id>``, where
    the storage root comes from the user data collection configuration (an
    empty string is used when it is not set) and both IDs are normalized
    first.
    """
    storage_root = (
        configuration.user_data_collection_configuration.transcripts_storage or ""
    )
    # these two normalizations are required by Snyk as it detects
    # this Path sanitization pattern
    uid = os.path.normpath("/" + user_id).lstrip("/")
    cid = os.path.normpath("/" + conversation_id).lstrip("/")
    return Path(storage_root, uid, cid)
292
+
293
+
294
def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    user_id: str,
    conversation_id: str,
    query_is_valid: bool,
    query: str,
    query_request: QueryRequest,
    response: str,
    rag_chunks: list[str],
    truncated: bool,
    attachments: list[Attachment],
) -> None:
    """Write one transcript as a JSON file on the local filesystem.

    The file is created inside the directory returned by
    ``construct_transcripts_path`` (created on demand) and is named after a
    freshly generated ID with a ``.json`` suffix.

    Args:
        user_id: The user ID.
        conversation_id: The conversation ID.
        query_is_valid: The result of the query validation.
        query: The query (without attachments).
        query_request: The request; its provider and model go into metadata.
        response: The LLM response to store.
        rag_chunks: The RAG chunks to store.
        truncated: The flag indicating if the history was truncated.
        attachments: The attachments; each one is stored via ``model_dump()``.
    """
    target_dir = construct_transcripts_path(user_id, conversation_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    metadata = {
        "provider": query_request.provider,
        "model": query_request.model,
        "user_id": user_id,
        "conversation_id": conversation_id,
        "timestamp": datetime.now(UTC).isoformat(),
    }
    dumped_attachments = [attachment.model_dump() for attachment in attachments]
    data_to_store = {
        "metadata": metadata,
        "redacted_query": query,
        "query_is_valid": query_is_valid,
        "llm_response": response,
        "rag_chunks": rag_chunks,
        "truncated": truncated,
        "attachments": dumped_attachments,
    }

    # each transcript goes into its own file named after a unique ID
    transcript_file_path = target_dir / f"{get_suid()}.json"
    with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
        json.dump(data_to_store, transcript_file)

    logger.info("Transcript successfully stored at: %s", transcript_file_path)
343
+
344
+
345
def get_rag_toolgroups(
    vector_db_ids: list[str],
) -> list[Toolgroup] | None:
    """Build the RAG tool group list for the given vector DBs.

    Returns None when ``vector_db_ids`` is empty; otherwise a one-element
    list holding the ``builtin::rag/knowledge_search`` tool group configured
    with those vector DB IDs.
    """
    if not vector_db_ids:
        return None

    knowledge_search = ToolgroupAgentToolGroupWithArgs(
        name="builtin::rag/knowledge_search",
        args={
            "vector_db_ids": vector_db_ids,
        },
    )
    return [knowledge_search]