lightspeed-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,321 @@
1
+ """Handler for REST API call to provide answer to streaming query."""
2
+
3
import ast
import json
import logging
import re
from typing import Any, AsyncIterator
7
+
8
+ from cachetools import TTLCache # type: ignore
9
+
10
+ from llama_stack_client import APIConnectionError
11
+ from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
12
+ from llama_stack_client import AsyncLlamaStackClient # type: ignore
13
+ from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
14
+ from llama_stack_client.types import UserMessage # type: ignore
15
+
16
+ from fastapi import APIRouter, HTTPException, Request, Depends, status
17
+ from fastapi.responses import StreamingResponse
18
+
19
+ from auth import get_auth_dependency
20
+ from client import get_async_llama_stack_client
21
+ from configuration import configuration
22
+ from models.requests import QueryRequest
23
+ from utils.endpoints import check_configuration_loaded, get_system_prompt
24
+ from utils.common import retrieve_user_id
25
+ from utils.mcp_headers import mcp_headers_dependency
26
+ from utils.suid import get_suid
27
+
28
+
29
+ from app.endpoints.query import (
30
+ get_rag_toolgroups,
31
+ is_transcripts_enabled,
32
+ store_transcript,
33
+ select_model_id,
34
+ validate_attachments_metadata,
35
+ )
36
+
37
# Every handler module logs through the shared "app.endpoints.handlers" logger.
logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["streaming_query"])
# Resolved once, at import time, from the loaded configuration.
auth_dependency = get_auth_dependency()

# Global agent registry so agents persist across requests, keyed by
# conversation (session) ID. Bounded to 1000 entries; ttl=3600 — presumably
# one hour, assuming cachetools' default seconds-based timer.
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1_000, ttl=60 * 60)
43
+
44
+
45
async def get_agent(
    client: AsyncLlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_shields: list[str],
    conversation_id: str | None,
) -> tuple[AsyncAgent, str]:
    """Return the cached agent for a conversation, or build and cache a new one.

    If ``conversation_id`` is given and an agent is still held in
    ``_agent_cache`` under that key, that agent is returned unchanged. In that
    case ``model_id``, ``system_prompt`` and ``available_shields`` are not used.
    Otherwise a new agent with session persistence is created. A fresh session
    is opened on it, and the agent is cached under the new session ID. Note that
    a supplied ID that is no longer cached is NOT reused: the new session ID is
    returned instead.

    Args:
        client: Async Llama Stack client the agent talks to.
        model_id: Model for a newly created agent.
        system_prompt: Instructions for a newly created agent.
        available_shields: Input shields for a newly created agent.
        conversation_id: Conversation (session) ID from the request, if any.

    Returns:
        tuple[AsyncAgent, str]: The agent and the conversation ID to use.
    """
    if conversation_id is not None:
        cached_agent = _agent_cache.get(conversation_id)
        if cached_agent:
            logger.debug("Reusing existing agent with key: %s", conversation_id)
            return cached_agent, conversation_id

    logger.debug("Creating new agent")
    fresh_agent = AsyncAgent(
        client,  # type: ignore[arg-type]
        model=model_id,
        instructions=system_prompt,
        input_shields=available_shields or [],
        enable_session_persistence=True,
    )
    session_id = await fresh_agent.create_session(get_suid())
    _agent_cache[session_id] = fresh_agent
    return fresh_agent, session_id
70
+
71
+
72
# Matches a "Metadata: {...}" line that is preceded and followed by a newline
# in knowledge_search tool output, capturing the brace-delimited payload.
METADATA_PATTERN = re.compile(
    r"\nMetadata: (\{.+})\n"
)
73
+
74
+
75
def format_stream_data(d: dict) -> str:
    """Serialize ``d`` to JSON and wrap it as one Event Stream ``data:`` frame.

    Args:
        d: JSON-serializable event payload.

    Returns:
        str: ``"data: <json>"`` terminated by a blank line.
    """
    return "data: " + json.dumps(d) + "\n\n"
79
+
80
+
81
def stream_start_event(conversation_id: str) -> str:
    """Return the opening event of the data stream.

    Args:
        conversation_id: The conversation ID (UUID).

    Returns:
        str: A formatted ``start`` event carrying the conversation ID.
    """
    start_payload = {"event": "start", "data": {"conversation_id": conversation_id}}
    return format_stream_data(start_payload)
95
+
96
+
97
def stream_end_event(metadata_map: dict) -> str:
    """Return the closing event of the data stream.

    Every entry of ``metadata_map`` that has both a ``docs_url`` and a
    ``title`` is reported as a referenced document. Truncation, token counts
    and quotas are placeholders for now.

    Args:
        metadata_map: Document metadata collected while streaming.

    Returns:
        str: A formatted ``end`` event.
    """
    referenced_documents = []
    for meta in metadata_map.values():
        if "docs_url" in meta and "title" in meta:
            referenced_documents.append(
                {"doc_url": meta["docs_url"], "doc_title": meta["title"]}
            )

    end_payload = {
        "event": "end",
        "data": {
            "referenced_documents": referenced_documents,
            "truncated": None,  # TODO(jboos): implement truncated
            "input_tokens": 0,  # TODO(jboos): implement input tokens
            "output_tokens": 0,  # TODO(jboos): implement output tokens
        },
        "available_quotas": {},  # TODO(jboos): implement available quotas
    }
    return format_stream_data(end_payload)
120
+
121
+
122
def _parse_metadata(blob: str) -> dict[str, Any] | None:
    """Parse one captured ``Metadata: {...}`` payload into a dict, or return None.

    Payloads appear to be single-quoted Python dict literals. They are read with
    ``ast.literal_eval``, which handles apostrophes inside values and
    ``True``/``False``/``None``; naive quote swapping breaks on those.
    ``literal_eval`` only evaluates literals, never arbitrary code. As a
    fallback, the payload is JSON-decoded with single quotes replaced by double
    quotes (covers JSON-style ``true``/``false``/``null``).
    """
    try:
        parsed = ast.literal_eval(blob)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        try:
            parsed = json.loads(blob.replace("'", '"'))
        except (ValueError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def _collect_rag_metadata(tool_responses: Any, metadata_map: dict) -> None:
    """Record document metadata found in knowledge_search tool responses."""
    for tool_response in tool_responses:
        if tool_response.tool_name != "knowledge_search" or not tool_response.content:
            continue
        for content_item in tool_response.content:
            if not isinstance(content_item, TextContentItem):
                continue
            for blob in METADATA_PATTERN.findall(content_item.text):
                meta = _parse_metadata(blob)
                if meta is None or "document_id" not in meta:
                    # Skip rather than raise: one malformed entry must not
                    # abort the whole streamed response.
                    logger.warning(
                        "Skipping unparsable knowledge_search metadata: %s", blob
                    )
                    continue
                metadata_map[meta["document_id"]] = meta


def _token_event(chunk_id: int, role: Any, token: str) -> str:
    """Format a ``token`` event carrying one piece of streamed output."""
    return format_stream_data(
        {
            "event": "token",
            "data": {
                "id": chunk_id,
                "role": role,
                "token": token,
            },
        }
    )


def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
    """Build a streaming event from a chunk response.

    Processes chunks from the Llama Stack streaming response and formats them
    into Server-Sent Events (SSE) ``data:`` frames for the client. Two payload
    event types are handled:

    1. step_progress: a text delta from model inference becomes a token event.
    2. step_complete (tool_execution step): knowledge_search responses are
       scanned for document metadata (stored into ``metadata_map``), and the
       name of the step's first tool call is emitted as a token event.

    Args:
        chunk: The streaming chunk from Llama Stack containing event data.
        chunk_id: ID to stamp on the emitted event. It is not incremented
            here; the caller advances it.
        metadata_map: Mapping of document_id -> metadata dict, updated in
            place with metadata parsed from knowledge_search responses.
            Unparsable metadata entries are skipped with a warning.

    Returns:
        str | None: A formatted SSE data string, or None if the chunk doesn't
        contain processable event data.
    """
    if not hasattr(chunk.event, "payload"):
        return None
    payload = chunk.event.payload

    if payload.event_type == "step_progress":
        # Only text deltas become tokens; other delta kinds produce no event.
        if hasattr(payload.delta, "text"):
            return _token_event(chunk_id, payload.step_type, payload.delta.text)
        return None

    if (
        payload.event_type == "step_complete"
        and payload.step_details.step_type == "tool_execution"
    ):
        _collect_rag_metadata(payload.step_details.tool_responses, metadata_map)
        if payload.step_details.tool_calls:
            # Only the first tool call of the step is reported to the client.
            tool_name = str(payload.step_details.tool_calls[0].tool_name)
            return _token_event(chunk_id, payload.step_type, tool_name)

    return None
183
+
184
+
185
def _token_from_auth(auth: Any) -> str:
    """Return the user token carried by the auth dependency result ("" if none).

    ``AuthInterface.__call__`` is declared to return ``tuple[str, str, str]``.
    NOTE(review): the token is assumed to be the third element, i.e. the tuple
    is ``(user_id, user_name, token)``; confirm against the auth modules. A
    plain string is passed through unchanged for callers that already supply
    just the token.
    """
    if isinstance(auth, str):
        return auth
    if isinstance(auth, (tuple, list)) and len(auth) >= 3 and isinstance(auth[2], str):
        return auth[2]
    return ""


@router.post("/streaming_query")
async def streaming_query_endpoint_handler(
    _request: Request,
    query_request: QueryRequest,
    auth: Any = Depends(auth_dependency),
    mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> StreamingResponse:
    """Handle request to the /streaming_query endpoint.

    Starts a streaming agent turn on Llama Stack and relays it to the client
    as SSE-formatted ``data:`` frames. The stream has three parts:
    - a ``start`` event with the conversation ID;
    - one ``token`` event per processable chunk;
    - a closing ``end`` event listing referenced documents.

    Once the stream has been fully consumed, the transcript is stored if
    transcript collection is enabled.

    Raises:
        HTTPException: 500 if Llama Stack cannot be reached while setting up
            the turn.
    """
    check_configuration_loaded(configuration)

    llama_stack_config = configuration.llama_stack_configuration
    logger.info("LLama stack config: %s", llama_stack_config)

    try:
        # try to get Llama Stack client
        client = await get_async_llama_stack_client(llama_stack_config)
        model_id = select_model_id(await client.models.list(), query_request)
        response, conversation_id = await retrieve_response(
            client,
            model_id,
            query_request,
            # retrieve_response expects the token string, not the whole auth
            # tuple (a tuple would be rendered into the Bearer header).
            _token_from_auth(auth),
            mcp_headers=mcp_headers,
        )
        metadata_map: dict[str, dict[str, Any]] = {}

        async def response_generator(turn_response: Any) -> AsyncIterator[str]:
            """Generate SSE formatted streaming response."""
            chunk_id = 0
            response_parts: list[str] = []

            # Send start event
            yield stream_start_event(conversation_id)

            async for chunk in turn_response:
                if event := stream_build_event(chunk, chunk_id, metadata_map):
                    # Strip only the leading SSE marker: str.replace would also
                    # delete any "data: " occurring inside the token text.
                    event_payload = json.loads(event.removeprefix("data: "))
                    response_parts.append(event_payload["data"]["token"])
                    chunk_id += 1
                    yield event

            yield stream_end_event(metadata_map)

            if not is_transcripts_enabled():
                logger.debug("Transcript collection is disabled in the configuration")
            else:
                store_transcript(
                    user_id=retrieve_user_id(auth),
                    conversation_id=conversation_id,
                    query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
                    query=query_request.query,
                    query_request=query_request,
                    response="".join(response_parts),
                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
                    truncated=False,  # TODO(lucasagomes): implement truncation as part
                    # of quota work
                    attachments=query_request.attachments or [],
                )

        return StreamingResponse(response_generator(response))
    # connection to Llama Stack server
    except APIConnectionError as e:
        logger.error("Unable to connect to Llama Stack: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "response": "Unable to connect to Llama Stack",
                "cause": str(e),
            },
        ) from e
256
+
257
+
258
async def retrieve_response(
    client: AsyncLlamaStackClient,
    model_id: str,
    query_request: QueryRequest,
    token: str,
    mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
    """Retrieve response from LLMs and agents.

    Gets (or creates) the agent for the request's conversation and starts a
    streaming turn on it. The turn uses the RAG toolgroups for the available
    vector DBs plus every configured MCP server.

    Args:
        client: Async Llama Stack client.
        model_id: Model used when a new agent has to be created.
        query_request: The incoming query request.
        token: User token. When non-empty and no MCP headers are supplied, it
            is sent as ``Authorization: Bearer <token>`` to every configured
            MCP server.
        mcp_headers: Headers to forward, keyed by MCP server URL. The caller's
            dict is never modified.

    Returns:
        tuple[Any, str]: The value returned by ``agent.create_turn(stream=True)``
        and the conversation (session) ID.
    """
    available_shields = [shield.identifier for shield in await client.shields.list()]
    if not available_shields:
        logger.info("No available shields. Disabling safety")
    else:
        logger.info("Available shields found: %s", available_shields)

    # use system prompt from request or default one
    system_prompt = get_system_prompt(query_request, configuration)
    logger.debug("Using system prompt: %s", system_prompt)

    # TODO(lucasagomes): redact attachments content before sending to LLM
    # if attachments are provided, validate them
    if query_request.attachments:
        validate_attachments_metadata(query_request.attachments)

    agent, conversation_id = await get_agent(
        client,
        model_id,
        system_prompt,
        available_shields,
        query_request.conversation_id,
    )

    # Work on a copy so the token fallback below never mutates the caller's
    # dict; None is still accepted for callers that omit the argument.
    forwarded_headers: dict[str, dict[str, str]] = dict(mcp_headers or {})
    if not forwarded_headers and token:
        for mcp_server in configuration.mcp_servers:
            forwarded_headers[mcp_server.url] = {
                "Authorization": f"Bearer {token}",
            }

    # NOTE(review): the agent may come from the shared cache, so this
    # overwrites headers set by an earlier request on the same conversation.
    agent.extra_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(
            {
                "mcp_headers": forwarded_headers,
            }
        ),
    }

    logger.debug("Session ID: %s", conversation_id)
    vector_db_ids = [
        vector_db.identifier for vector_db in await client.vector_dbs.list()
    ]
    toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
        mcp_server.name for mcp_server in configuration.mcp_servers
    ]
    response = await agent.create_turn(
        messages=[UserMessage(role="user", content=query_request.query)],
        session_id=conversation_id,
        documents=query_request.get_documents(),
        stream=True,
        toolgroups=toolgroups or None,
    )

    return response, conversation_id
app/main.py ADDED
@@ -0,0 +1,38 @@
1
+ """Definition of FastAPI based web service."""
2
+
3
+ from fastapi import FastAPI
4
+ from app import routers
5
+
6
+ import version
7
+ from log import get_logger
8
+ from configuration import configuration
9
+ from utils.common import register_mcp_servers_async
10
+
11
logger = get_logger(__name__)
logger.info("Initializing app")

# The service name from the loaded configuration is used in the OpenAPI
# title and description below.
service_name = configuration.configuration.name

app = FastAPI(
    title=f"{service_name} service - OpenAPI",
    description=f"{service_name} service API specification.",
    version=version.__version__,
    license_info={
        "name": "Apache 2.0",
        "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
    },
)

# Attach every endpoint router to the application.
logger.info("Including routers")
routers.include_routers(app)
30
+
31
+
32
@app.on_event("startup")
async def startup_event() -> None:
    """Run one-time initialisation when the service starts.

    Registers the MCP servers from the loaded configuration. It then calls
    ``get_logger`` for the shared "app.endpoints.handlers" logger (presumably
    so that logger is configured; see ``log.get_logger``).
    """
    logger.info("Registering MCP servers")
    service_config = configuration.configuration
    await register_mcp_servers_async(logger, service_config)
    get_logger("app.endpoints.handlers")
    logger.info("App startup complete")
app/routers.py ADDED
@@ -0,0 +1,30 @@
1
+ """REST API routers."""
2
+
3
+ from fastapi import FastAPI
4
+
5
+ from app.endpoints import (
6
+ info,
7
+ models,
8
+ root,
9
+ query,
10
+ health,
11
+ config,
12
+ feedback,
13
+ streaming_query,
14
+ )
15
+
16
+
17
def include_routers(app: FastAPI) -> None:
    """Register every endpoint router on the application.

    The root router is mounted without a prefix. All other routers are mounted
    under ``/v1``, in a fixed order.

    Args:
        app: The `FastAPI` app instance.
    """
    app.include_router(root.router)

    versioned_endpoints = (
        info,
        models,
        query,
        health,
        config,
        feedback,
        streaming_query,
    )
    for endpoint_module in versioned_endpoints:
        app.include_router(endpoint_module.router, prefix="/v1")
auth/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ """This package contains authentication code and modules."""
2
+
3
+ import logging
4
+
5
+ from auth.interface import AuthInterface
6
+ from auth import noop, noop_with_token, k8s
7
+ from configuration import configuration
8
+ import constants
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def get_auth_dependency(
    virtual_path: str = constants.DEFAULT_VIRTUAL_PATH,
) -> AuthInterface:
    """Instantiate the authentication dependency named in the configuration.

    Args:
        virtual_path: Forwarded to the selected dependency's constructor.

    Returns:
        AuthInterface: The dependency for the configured module.

    Raises:
        ValueError: If the configured module is not a supported one.
    """
    module = configuration.authentication_configuration.module  # pyright: ignore

    logger.debug(
        "Initializing authentication dependency: module='%s', virtual_path='%s'",
        module,
        virtual_path,
    )

    if module == constants.AUTH_MOD_NOOP:
        return noop.NoopAuthDependency(virtual_path=virtual_path)
    if module == constants.AUTH_MOD_NOOP_WITH_TOKEN:
        return noop_with_token.NoopWithTokenAuthDependency(virtual_path=virtual_path)
    if module == constants.AUTH_MOD_K8S:
        return k8s.K8SAuthDependency(virtual_path=virtual_path)

    err_msg = f"Unsupported authentication module '{module}'"
    logger.error(err_msg)
    raise ValueError(err_msg)
auth/interface.py ADDED
@@ -0,0 +1,13 @@
1
+ """Abstract base class for authentication methods."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from fastapi import Request
6
+
7
+
8
class AuthInterface(ABC):  # pylint: disable=too-few-public-methods
    """Abstract base that every authentication method implementation extends.

    Instances are awaited with the incoming FastAPI request.
    """

    @abstractmethod
    async def __call__(self, request: Request) -> tuple[str, str, str]:
        """Authenticate and authorize a FastAPI request.

        Returns:
            tuple[str, str, str]: A 3-tuple of strings, presumably
            (user_id, user_name, token); confirm against the implementations.
        """