lightspeed-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -0
- app/endpoints/.ruff_cache/.gitignore +2 -0
- app/endpoints/.ruff_cache/0.9.1/5703048272820174433 +0 -0
- app/endpoints/.ruff_cache/0.9.1/9961612457335986079 +0 -0
- app/endpoints/.ruff_cache/CACHEDIR.TAG +1 -0
- app/endpoints/__init__.py +1 -0
- app/endpoints/config.py +64 -0
- app/endpoints/feedback.py +129 -0
- app/endpoints/health.py +111 -0
- app/endpoints/info.py +26 -0
- app/endpoints/models.py +79 -0
- app/endpoints/query.py +360 -0
- app/endpoints/root.py +777 -0
- app/endpoints/streaming_query.py +321 -0
- app/main.py +38 -0
- app/routers.py +30 -0
- auth/__init__.py +38 -0
- auth/interface.py +13 -0
- auth/k8s.py +270 -0
- auth/noop.py +42 -0
- auth/noop_with_token.py +46 -0
- auth/utils.py +26 -0
- lightspeed_stack-0.1.0.dist-info/METADATA +443 -0
- lightspeed_stack-0.1.0.dist-info/RECORD +43 -0
- lightspeed_stack-0.1.0.dist-info/WHEEL +4 -0
- lightspeed_stack-0.1.0.dist-info/entry_points.txt +4 -0
- lightspeed_stack-0.1.0.dist-info/licenses/LICENSE +201 -0
- models/__init__.py +1 -0
- models/config.py +161 -0
- models/requests.py +208 -0
- models/responses.py +244 -0
- runners/__init__.py +1 -0
- runners/uvicorn.py +31 -0
- utils/.ruff_cache/.gitignore +2 -0
- utils/.ruff_cache/0.9.1/18446581155718949728 +0 -0
- utils/.ruff_cache/0.9.1/4991844299736624256 +0 -0
- utils/.ruff_cache/CACHEDIR.TAG +1 -0
- utils/__init__.py +1 -0
- utils/checks.py +27 -0
- utils/common.py +111 -0
- utils/endpoints.py +34 -0
- utils/mcp_headers.py +48 -0
- utils/suid.py +28 -0
app/endpoints/query.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""Handler for REST API call to provide answer to query."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, UTC
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from cachetools import TTLCache # type: ignore
|
|
11
|
+
|
|
12
|
+
from llama_stack_client.lib.agents.agent import Agent
|
|
13
|
+
from llama_stack_client import APIConnectionError
|
|
14
|
+
from llama_stack_client import LlamaStackClient # type: ignore
|
|
15
|
+
from llama_stack_client.types import UserMessage # type: ignore
|
|
16
|
+
from llama_stack_client.types.agents.turn_create_params import (
|
|
17
|
+
ToolgroupAgentToolGroupWithArgs,
|
|
18
|
+
Toolgroup,
|
|
19
|
+
)
|
|
20
|
+
from llama_stack_client.types.model_list_response import ModelListResponse
|
|
21
|
+
|
|
22
|
+
from fastapi import APIRouter, HTTPException, status, Depends
|
|
23
|
+
|
|
24
|
+
from client import get_llama_stack_client
|
|
25
|
+
from configuration import configuration
|
|
26
|
+
from models.responses import QueryResponse
|
|
27
|
+
from models.requests import QueryRequest, Attachment
|
|
28
|
+
import constants
|
|
29
|
+
from auth import get_auth_dependency
|
|
30
|
+
from utils.common import retrieve_user_id
|
|
31
|
+
from utils.endpoints import check_configuration_loaded, get_system_prompt
|
|
32
|
+
from utils.mcp_headers import mcp_headers_dependency
|
|
33
|
+
from utils.suid import get_suid
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
auth_dependency = get_auth_dependency()

# Global agent registry to persist agents across requests.
# Keys are the conversation (session) IDs returned by ``Agent.create_session``;
# at most 1000 agents are kept and each entry expires after 3600 seconds.
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)

# Example responses advertised in the OpenAPI schema for POST /query.
# The status codes here must match what ``query_endpoint_handler`` raises:
# a failure to reach Llama Stack is reported as HTTP 500.
query_response: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        "response": "LLM answer",
    },
    500: {
        "detail": {
            "response": "Unable to connect to Llama Stack",
            "cause": "Connection error.",
        }
    },
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def is_transcripts_enabled() -> bool:
    """Report whether transcript collection is switched on.

    The configuration stores the negated flag (``transcripts_disabled``),
    so this function simply inverts it.

    Returns:
        bool: True when transcripts should be stored, False otherwise.
    """
    collection_settings = configuration.user_data_collection_configuration
    disabled = collection_settings.transcripts_disabled
    return not disabled
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_agent(
    client: LlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_shields: list[str],
    conversation_id: str | None,
) -> tuple[Agent, str]:
    """Get an existing agent for a conversation or create a new one.

    Agents are cached in ``_agent_cache`` under the session ID returned by
    ``Agent.create_session``. If ``conversation_id`` names a cached agent,
    that agent is reused. Otherwise a new agent with session persistence is
    built, a new session is created for it, and the agent is cached under
    the new session ID.

    NOTE(review): a cached agent is returned as-is, even when ``model_id``,
    ``system_prompt`` or ``available_shields`` differ from the values it was
    created with.

    Args:
        client: Llama Stack client used to construct a new agent.
        model_id: Identifier of the model a new agent should use.
        system_prompt: Instructions given to a new agent.
        available_shields: Shield identifiers used as the agent's input shields.
        conversation_id: Conversation to resume, or None to start a new one.

    Returns:
        tuple[Agent, str]: The agent and the conversation (session) ID to use
        for the next turn. The ID differs from ``conversation_id`` whenever a
        new session had to be created.
    """
    if conversation_id is not None:
        cached_agent = _agent_cache.get(conversation_id)
        if cached_agent is not None:
            logger.debug("Reusing existing agent with key: %s", conversation_id)
            return cached_agent, conversation_id
        # The requested conversation is not in the cache (for example its
        # entry expired or this process was restarted). A new session with a
        # different ID is started below, so make the reset visible in logs.
        logger.warning(
            "No cached agent for conversation %s; starting a new conversation",
            conversation_id,
        )

    logger.debug("Creating new agent")
    # TODO(lucasagomes): move to ReActAgent
    agent = Agent(
        client,
        model=model_id,
        instructions=system_prompt,
        input_shields=available_shields if available_shields else [],
        enable_session_persistence=True,
    )
    # The value returned by create_session is used as the conversation ID.
    conversation_id = agent.create_session(get_suid())
    _agent_cache[conversation_id] = agent
    return agent, conversation_id
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@router.post("/query", responses=query_response)
def query_endpoint_handler(
    query_request: QueryRequest,
    auth: Any = Depends(auth_dependency),
    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> QueryResponse:
    """Answer a query sent to the /query endpoint.

    Picks a model, asks Llama Stack for a response, stores a transcript of
    the exchange when transcript collection is enabled, and returns the
    answer together with the conversation ID.

    Raises:
        HTTPException: 500 when the Llama Stack server cannot be reached.
            Errors raised by the helpers (for example 400 from
            ``select_model_id`` or 422 from attachment validation) propagate
            unchanged.
    """
    check_configuration_loaded(configuration)

    llama_stack_config = configuration.llama_stack_configuration
    # NOTE(review): this logs the whole Llama Stack configuration at INFO
    # level; confirm it carries no secrets such as API keys.
    logger.info("LLama stack config: %s", llama_stack_config)

    # Only the calls that talk to the Llama Stack server live in the try.
    try:
        client = get_llama_stack_client(llama_stack_config)
        available_models = client.models.list()
        model_id = select_model_id(available_models, query_request)
        # NOTE(review): ``auth`` is forwarded as retrieve_response's ``token``
        # argument; confirm the auth dependency yields a bearer token string.
        response, conversation_id = retrieve_response(
            client,
            model_id,
            query_request,
            auth,
            mcp_headers=mcp_headers,
        )
    except APIConnectionError as e:
        logger.error("Unable to connect to Llama Stack: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "response": "Unable to connect to Llama Stack",
                "cause": str(e),
            },
        ) from e

    if is_transcripts_enabled():
        store_transcript(
            user_id=retrieve_user_id(auth),
            conversation_id=conversation_id,
            query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
            query=query_request.query,
            query_request=query_request,
            response=response,
            rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
            truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
            attachments=query_request.attachments or [],
        )
    else:
        logger.debug("Transcript collection is disabled in the configuration")

    return QueryResponse(conversation_id=conversation_id, response=response)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
    """Choose the model identifier to answer the query with.

    If the request names a model, it must appear in ``models`` with the
    requested provider. If it names none, the first model whose
    ``model_type`` is ``"llm"`` is used.

    Args:
        models: Models reported by the Llama Stack server.
        query_request: Request carrying the optional ``model`` and
            ``provider`` fields.

    Returns:
        str: Identifier of the selected model.

    Raises:
        HTTPException: 400 Bad Request when the requested model/provider pair
            is unavailable, or when no LLM is available for default selection.
    """
    requested_model = query_request.model
    requested_provider = query_request.provider

    if requested_model:
        logger.info(
            "Searching for model: %s, provider: %s",
            requested_model,
            requested_provider,
        )
        is_available = any(
            candidate.identifier == requested_model
            and candidate.provider_id == requested_provider
            for candidate in models
        )
        if is_available:
            return requested_model

        message = f"Model {requested_model} from provider {requested_provider} not found in available models"
        logger.error(message)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": message,
            },
        )

    # TODO(lucasagomes): support default model selection via configuration
    logger.info("No model specified in request, using the first available LLM")
    try:
        first_llm = next(
            candidate
            for candidate in models
            if candidate.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
        )
        chosen_model = first_llm.identifier
    except (StopIteration, AttributeError) as e:
        # StopIteration: no LLM in the list; AttributeError: a model object
        # lacks ``model_type`` or ``identifier``.
        message = "No LLM model found in available models"
        logger.error(message)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": message,
            },
        ) from e

    logger.info("Selected model: %s", chosen_model)
    return chosen_model
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def retrieve_response(
    client: LlamaStackClient,
    model_id: str,
    query_request: QueryRequest,
    token: str,
    mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
    """Retrieve a response from the LLM through a Llama Stack agent.

    Args:
        client: Llama Stack client.
        model_id: Identifier of the model to query.
        query_request: Incoming request (query, optional conversation ID,
            attachments, and whatever ``get_system_prompt`` reads).
        token: When no MCP headers are supplied, sent as
            ``Authorization: Bearer <token>`` to every configured MCP server.
        mcp_headers: Per-MCP-server headers keyed by server URL. The mapping
            is read but never modified.

    Returns:
        tuple[str, str]: The response text and the conversation ID.

    Raises:
        HTTPException: 422 if an attachment has an unsupported type or content
            type (raised before any Llama Stack call is made).
    """
    # TODO(lucasagomes): redact attachments content before sending to LLM
    # Validate attachments first so invalid requests are rejected without
    # any round trip to the Llama Stack server.
    if query_request.attachments:
        validate_attachments_metadata(query_request.attachments)

    available_shields = [shield.identifier for shield in client.shields.list()]
    if not available_shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields found: %s", available_shields)

    # use system prompt from request or default one
    system_prompt = get_system_prompt(query_request, configuration)
    logger.debug("Using system prompt: %s", system_prompt)

    agent, conversation_id = get_agent(
        client,
        model_id,
        system_prompt,
        available_shields,
        query_request.conversation_id,
    )

    # Work on a copy so the caller's mapping is never mutated. When no MCP
    # headers were supplied, forward the token to every configured MCP server.
    effective_mcp_headers: dict[str, dict[str, str]] = dict(mcp_headers or {})
    if not effective_mcp_headers and token:
        for mcp_server in configuration.mcp_servers:
            effective_mcp_headers[mcp_server.url] = {
                "Authorization": f"Bearer {token}",
            }

    agent.extra_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(
            {
                "mcp_headers": effective_mcp_headers,
            }
        ),
    }

    # Tool groups: the RAG knowledge-search group (only when vector DBs
    # exist) followed by every configured MCP server, referenced by name.
    vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
    toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
        mcp_server.name for mcp_server in configuration.mcp_servers
    ]
    response = agent.create_turn(
        messages=[UserMessage(role="user", content=query_request.query)],
        session_id=conversation_id,
        documents=query_request.get_documents(),
        stream=False,
        # pass None rather than an empty list when there are no tool groups
        toolgroups=toolgroups or None,
    )

    # NOTE(review): str() is applied to the message content; if the content
    # is not a plain string, this yields its string representation.
    return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def validate_attachments_metadata(attachments: list[Attachment]) -> None:
    """Check that every attachment declares a supported type and content type.

    Attachments are examined in order and the first problem found is
    reported. For each attachment the attachment type is checked before the
    content type.

    Args:
        attachments: Attachments taken from the query request.

    Raises:
        HTTPException: 422 Unprocessable Entity for the first attachment whose
            ``attachment_type`` is not in ``constants.ATTACHMENT_TYPES`` or
            whose ``content_type`` is not in
            ``constants.ATTACHMENT_CONTENT_TYPES``.
    """
    for item in attachments:
        problem: str | None = None
        if item.attachment_type not in constants.ATTACHMENT_TYPES:
            problem = f"Attachment with improper type {item.attachment_type} detected"
        elif item.content_type not in constants.ATTACHMENT_CONTENT_TYPES:
            problem = f"Attachment with improper content type {item.content_type} detected"

        if problem is None:
            continue

        logger.error(problem)
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail={
                "response": constants.UNABLE_TO_PROCESS_RESPONSE,
                "cause": problem,
            },
        )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
    """Build the directory that holds transcripts for one conversation.

    The result is ``<transcripts_storage>/<user_id>/<conversation_id>``. Each
    ID is normalized as if it were an absolute path and then stripped of its
    leading separators, so ``..`` segments cannot escape the storage root.
    When ``transcripts_storage`` is unset or empty, the path is relative.

    Args:
        user_id: Identifier of the user owning the transcripts.
        conversation_id: Identifier of the conversation.

    Returns:
        Path: Directory in which transcript files for the conversation live.
    """
    # these two normalizations are required by Snyk as it detects
    # this Path sanitization pattern
    safe_user_id = os.path.normpath("/" + user_id).lstrip("/")
    safe_conversation_id = os.path.normpath("/" + conversation_id).lstrip("/")

    settings = configuration.user_data_collection_configuration
    storage_root = settings.transcripts_storage or ""
    return Path(storage_root, safe_user_id, safe_conversation_id)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    user_id: str,
    conversation_id: str,
    query_is_valid: bool,
    query: str,
    query_request: QueryRequest,
    response: str,
    rag_chunks: list[str],
    truncated: bool,
    attachments: list[Attachment],
) -> None:
    """Store a transcript of one query/response exchange as a JSON file.

    The file is created in ``construct_transcripts_path(user_id,
    conversation_id)``; the directory and its parents are created if
    missing. Each call writes a new file named after a fresh ID from
    ``get_suid()``.

    Args:
        user_id: The user ID.
        conversation_id: The conversation ID.
        query_is_valid: The result of the query validation.
        query: The query (without attachments); stored as ``redacted_query``.
        query_request: The request containing the query. Its ``provider`` and
            ``model`` fields are recorded in the metadata; they may be None
            when the request did not specify a model.
        response: The LLM response text to store.
        rag_chunks: The RAG chunks used to build the response, as strings.
        truncated: The flag indicating if the history was truncated.
        attachments: The attachments sent with the query; each one is
            serialized with ``model_dump()``.
    """
    transcripts_path = construct_transcripts_path(user_id, conversation_id)
    transcripts_path.mkdir(parents=True, exist_ok=True)

    data_to_store = {
        "metadata": {
            "provider": query_request.provider,
            "model": query_request.model,
            "user_id": user_id,
            "conversation_id": conversation_id,
            "timestamp": datetime.now(UTC).isoformat(),
        },
        "redacted_query": query,
        "query_is_valid": query_is_valid,
        "llm_response": response,
        "rag_chunks": rag_chunks,
        "truncated": truncated,
        "attachments": [attachment.model_dump() for attachment in attachments],
    }

    # store the transcript in a new file named after a unique ID
    transcript_file_path = transcripts_path / f"{get_suid()}.json"
    with open(transcript_file_path, "w", encoding="utf-8") as transcript_file:
        json.dump(data_to_store, transcript_file)

    logger.info("Transcript successfully stored at: %s", transcript_file_path)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def get_rag_toolgroups(
    vector_db_ids: list[str],
) -> list[Toolgroup] | None:
    """Build the RAG tool group list for the given vector databases.

    Args:
        vector_db_ids: Identifiers of the vector databases to search.

    Returns:
        list[Toolgroup] | None: A one-element list holding the
        ``builtin::rag/knowledge_search`` tool group configured with
        ``vector_db_ids``, or None when ``vector_db_ids`` is empty.
    """
    if not vector_db_ids:
        return None

    knowledge_search = ToolgroupAgentToolGroupWithArgs(
        name="builtin::rag/knowledge_search",
        args={"vector_db_ids": vector_db_ids},
    )
    return [knowledge_search]
|