llama-stack 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/list_deps.py +4 -0
- llama_stack/core/routers/inference.py +66 -40
- llama_stack/distributions/starter/build.yaml +1 -0
- llama_stack/distributions/starter/run-with-postgres-store.yaml +285 -0
- llama_stack/distributions/starter/starter.py +86 -68
- llama_stack/distributions/starter-gpu/build.yaml +1 -0
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +288 -0
- llama_stack/providers/inline/vector_io/faiss/faiss.py +25 -2
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +15 -4
- llama_stack/providers/remote/inference/vertexai/vertexai.py +10 -0
- llama_stack/providers/remote/vector_io/chroma/chroma.py +9 -3
- llama_stack/providers/remote/vector_io/milvus/milvus.py +7 -4
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +32 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +11 -6
- llama_stack/providers/remote/vector_io/weaviate/weaviate.py +7 -4
- llama_stack/providers/utils/inference/embedding_mixin.py +1 -2
- llama_stack/providers/utils/inference/inference_store.py +30 -10
- llama_stack/providers/utils/inference/model_registry.py +1 -1
- llama_stack/providers/utils/inference/openai_mixin.py +33 -10
- llama_stack/providers/utils/responses/responses_store.py +12 -58
- llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +25 -9
- llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +31 -1
- llama_stack/ui/node_modules/flatted/python/flatted.py +149 -0
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/METADATA +3 -3
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/RECORD +29 -26
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/WHEEL +0 -0
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/top_level.txt +0 -0
|
@@ -368,6 +368,22 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|
|
368
368
|
log.exception("Could not connect to PGVector database server")
|
|
369
369
|
raise RuntimeError("Could not connect to PGVector database server") from e
|
|
370
370
|
|
|
371
|
+
# Load existing vector stores from KV store into cache
|
|
372
|
+
start_key = VECTOR_DBS_PREFIX
|
|
373
|
+
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
|
374
|
+
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
|
|
375
|
+
for vector_store_data in stored_vector_stores:
|
|
376
|
+
vector_store = VectorStore.model_validate_json(vector_store_data)
|
|
377
|
+
pgvector_index = PGVectorIndex(
|
|
378
|
+
vector_store=vector_store,
|
|
379
|
+
dimension=vector_store.embedding_dimension,
|
|
380
|
+
conn=self.conn,
|
|
381
|
+
kvstore=self.kvstore,
|
|
382
|
+
)
|
|
383
|
+
await pgvector_index.initialize()
|
|
384
|
+
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
|
|
385
|
+
self.cache[vector_store.identifier] = index
|
|
386
|
+
|
|
371
387
|
async def shutdown(self) -> None:
|
|
372
388
|
if self.conn is not None:
|
|
373
389
|
self.conn.close()
|
|
@@ -377,7 +393,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|
|
377
393
|
|
|
378
394
|
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
|
379
395
|
# Persist vector DB metadata in the KV store
|
|
380
|
-
|
|
396
|
+
if self.kvstore is None:
|
|
397
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
|
|
398
|
+
|
|
399
|
+
# Save to kvstore for persistence
|
|
400
|
+
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
|
401
|
+
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
|
402
|
+
|
|
381
403
|
# Upsert model metadata in Postgres
|
|
382
404
|
upsert_models(self.conn, [(vector_store.identifier, vector_store)])
|
|
383
405
|
|
|
@@ -396,7 +418,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|
|
396
418
|
del self.cache[vector_store_id]
|
|
397
419
|
|
|
398
420
|
# Delete vector DB metadata from KV store
|
|
399
|
-
|
|
421
|
+
if self.kvstore is None:
|
|
422
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")
|
|
400
423
|
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
|
401
424
|
|
|
402
425
|
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
|
@@ -413,13 +436,16 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
|
|
|
413
436
|
if vector_store_id in self.cache:
|
|
414
437
|
return self.cache[vector_store_id]
|
|
415
438
|
|
|
416
|
-
|
|
417
|
-
|
|
439
|
+
# Try to load from kvstore
|
|
440
|
+
if self.kvstore is None:
|
|
441
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
|
|
418
442
|
|
|
419
|
-
|
|
420
|
-
|
|
443
|
+
key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
|
|
444
|
+
vector_store_data = await self.kvstore.get(key)
|
|
445
|
+
if not vector_store_data:
|
|
421
446
|
raise VectorStoreNotFoundError(vector_store_id)
|
|
422
447
|
|
|
448
|
+
vector_store = VectorStore.model_validate_json(vector_store_data)
|
|
423
449
|
index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
|
|
424
450
|
await index.initialize()
|
|
425
451
|
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
|
@@ -183,7 +183,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|
|
183
183
|
await super().shutdown()
|
|
184
184
|
|
|
185
185
|
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
|
186
|
-
|
|
186
|
+
if self.kvstore is None:
|
|
187
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
|
|
187
188
|
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
|
|
188
189
|
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
|
189
190
|
|
|
@@ -200,20 +201,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
|
|
|
200
201
|
await self.cache[vector_store_id].index.delete()
|
|
201
202
|
del self.cache[vector_store_id]
|
|
202
203
|
|
|
203
|
-
|
|
204
|
+
if self.kvstore is None:
|
|
205
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
|
|
204
206
|
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
|
205
207
|
|
|
206
208
|
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
|
|
207
209
|
if vector_store_id in self.cache:
|
|
208
210
|
return self.cache[vector_store_id]
|
|
209
211
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
+
# Try to load from kvstore
|
|
213
|
+
if self.kvstore is None:
|
|
214
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
|
|
212
215
|
|
|
213
|
-
|
|
214
|
-
|
|
216
|
+
key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
|
|
217
|
+
vector_store_data = await self.kvstore.get(key)
|
|
218
|
+
if not vector_store_data:
|
|
215
219
|
raise VectorStoreNotFoundError(vector_store_id)
|
|
216
220
|
|
|
221
|
+
vector_store = VectorStore.model_validate_json(vector_store_data)
|
|
217
222
|
index = VectorStoreWithIndex(
|
|
218
223
|
vector_store=vector_store,
|
|
219
224
|
index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
|
|
@@ -346,13 +346,16 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
|
|
|
346
346
|
if vector_store_id in self.cache:
|
|
347
347
|
return self.cache[vector_store_id]
|
|
348
348
|
|
|
349
|
-
|
|
350
|
-
|
|
349
|
+
# Try to load from kvstore
|
|
350
|
+
if self.kvstore is None:
|
|
351
|
+
raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
|
|
351
352
|
|
|
352
|
-
|
|
353
|
-
|
|
353
|
+
key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
|
|
354
|
+
vector_store_data = await self.kvstore.get(key)
|
|
355
|
+
if not vector_store_data:
|
|
354
356
|
raise VectorStoreNotFoundError(vector_store_id)
|
|
355
357
|
|
|
358
|
+
vector_store = VectorStore.model_validate_json(vector_store_data)
|
|
356
359
|
client = self._get_client()
|
|
357
360
|
sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
|
|
358
361
|
if not client.collections.exists(sanitized_collection_name):
|
|
@@ -46,8 +46,7 @@ class SentenceTransformerEmbeddingMixin:
|
|
|
46
46
|
raise ValueError("Empty list not supported")
|
|
47
47
|
|
|
48
48
|
# Get the model and generate embeddings
|
|
49
|
-
|
|
50
|
-
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
|
49
|
+
embedding_model = await self._load_sentence_transformer_model(params.model)
|
|
51
50
|
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
|
|
52
51
|
|
|
53
52
|
# Convert embeddings to the requested format
|
|
@@ -35,6 +35,7 @@ class InferenceStore:
|
|
|
35
35
|
self.reference = reference
|
|
36
36
|
self.sql_store = None
|
|
37
37
|
self.policy = policy
|
|
38
|
+
self.enable_write_queue = True
|
|
38
39
|
|
|
39
40
|
# Async write queue and worker control
|
|
40
41
|
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
|
|
@@ -47,14 +48,13 @@ class InferenceStore:
|
|
|
47
48
|
base_store = sqlstore_impl(self.reference)
|
|
48
49
|
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
|
49
50
|
|
|
50
|
-
# Disable write queue for SQLite
|
|
51
|
-
|
|
52
|
-
backend_config = _SQLSTORE_BACKENDS.get(
|
|
53
|
-
if backend_config
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
|
|
51
|
+
# Disable write queue for SQLite since WAL mode handles concurrency
|
|
52
|
+
# Keep it enabled for other backends (like Postgres) for performance
|
|
53
|
+
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
|
54
|
+
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
|
55
|
+
self.enable_write_queue = False
|
|
56
|
+
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
|
57
|
+
|
|
58
58
|
await self.sql_store.create_table(
|
|
59
59
|
"chat_completions",
|
|
60
60
|
{
|
|
@@ -70,8 +70,9 @@ class InferenceStore:
|
|
|
70
70
|
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
|
71
71
|
for _ in range(self._num_writers):
|
|
72
72
|
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
|
73
|
-
|
|
74
|
-
|
|
73
|
+
logger.debug(
|
|
74
|
+
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
|
75
|
+
)
|
|
75
76
|
|
|
76
77
|
async def shutdown(self) -> None:
|
|
77
78
|
if not self._worker_tasks:
|
|
@@ -93,10 +94,29 @@ class InferenceStore:
|
|
|
93
94
|
if self.enable_write_queue and self._queue is not None:
|
|
94
95
|
await self._queue.join()
|
|
95
96
|
|
|
97
|
+
async def _ensure_workers_started(self) -> None:
|
|
98
|
+
"""Ensure the async write queue workers run on the current loop."""
|
|
99
|
+
if not self.enable_write_queue:
|
|
100
|
+
return
|
|
101
|
+
|
|
102
|
+
if self._queue is None:
|
|
103
|
+
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
|
104
|
+
logger.debug(
|
|
105
|
+
f"Inference store write queue created with max size {self._max_write_queue_size} "
|
|
106
|
+
f"and {self._num_writers} writers"
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
if not self._worker_tasks:
|
|
110
|
+
loop = asyncio.get_running_loop()
|
|
111
|
+
for _ in range(self._num_writers):
|
|
112
|
+
task = loop.create_task(self._worker_loop())
|
|
113
|
+
self._worker_tasks.append(task)
|
|
114
|
+
|
|
96
115
|
async def store_chat_completion(
|
|
97
116
|
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
|
|
98
117
|
) -> None:
|
|
99
118
|
if self.enable_write_queue:
|
|
119
|
+
await self._ensure_workers_started()
|
|
100
120
|
if self._queue is None:
|
|
101
121
|
raise ValueError("Inference store is not initialized")
|
|
102
122
|
try:
|
|
@@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class RemoteInferenceProviderConfig(BaseModel):
|
|
23
|
-
allowed_models: list[str] | None = Field(
|
|
23
|
+
allowed_models: list[str] | None = Field(
|
|
24
24
|
default=None,
|
|
25
25
|
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
|
26
26
|
)
|
|
@@ -82,9 +82,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
82
82
|
# This is set in list_models() and used in check_model_availability()
|
|
83
83
|
_model_cache: dict[str, Model] = {}
|
|
84
84
|
|
|
85
|
-
# List of allowed models for this provider, if empty all models allowed
|
|
86
|
-
allowed_models: list[str] = []
|
|
87
|
-
|
|
88
85
|
# Optional field name in provider data to look for API key, which takes precedence
|
|
89
86
|
provider_data_api_key_field: str | None = None
|
|
90
87
|
|
|
@@ -191,6 +188,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
191
188
|
|
|
192
189
|
return api_key
|
|
193
190
|
|
|
191
|
+
def _validate_model_allowed(self, provider_model_id: str) -> None:
|
|
192
|
+
"""
|
|
193
|
+
Validate that the model is in the allowed_models list if configured.
|
|
194
|
+
|
|
195
|
+
:param provider_model_id: The provider-specific model ID to validate
|
|
196
|
+
:raises ValueError: If the model is not in the allowed_models list
|
|
197
|
+
"""
|
|
198
|
+
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"Model '{provider_model_id}' is not in the allowed models list. "
|
|
201
|
+
f"Allowed models: {self.config.allowed_models}"
|
|
202
|
+
)
|
|
203
|
+
|
|
194
204
|
async def _get_provider_model_id(self, model: str) -> str:
|
|
195
205
|
"""
|
|
196
206
|
Get the provider-specific model ID from the model store.
|
|
@@ -201,8 +211,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
201
211
|
:param model: The registered model name/identifier
|
|
202
212
|
:return: The provider-specific model ID (e.g., "gpt-4")
|
|
203
213
|
"""
|
|
204
|
-
# Look up the registered model to get the provider-specific model ID
|
|
205
214
|
# self.model_store is injected by the distribution system at runtime
|
|
215
|
+
if not await self.model_store.has_model(model): # type: ignore[attr-defined]
|
|
216
|
+
return model
|
|
217
|
+
|
|
218
|
+
# Look up the registered model to get the provider-specific model ID
|
|
206
219
|
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
|
|
207
220
|
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
|
|
208
221
|
if model_obj.provider_resource_id is None:
|
|
@@ -234,8 +247,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
234
247
|
Direct OpenAI completion API call.
|
|
235
248
|
"""
|
|
236
249
|
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
|
250
|
+
provider_model_id = await self._get_provider_model_id(params.model)
|
|
251
|
+
self._validate_model_allowed(provider_model_id)
|
|
252
|
+
|
|
237
253
|
completion_kwargs = await prepare_openai_completion_params(
|
|
238
|
-
model=
|
|
254
|
+
model=provider_model_id,
|
|
239
255
|
prompt=params.prompt,
|
|
240
256
|
best_of=params.best_of,
|
|
241
257
|
echo=params.echo,
|
|
@@ -267,6 +283,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
267
283
|
"""
|
|
268
284
|
Direct OpenAI chat completion API call.
|
|
269
285
|
"""
|
|
286
|
+
provider_model_id = await self._get_provider_model_id(params.model)
|
|
287
|
+
self._validate_model_allowed(provider_model_id)
|
|
288
|
+
|
|
270
289
|
messages = params.messages
|
|
271
290
|
|
|
272
291
|
if self.download_images:
|
|
@@ -288,7 +307,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
288
307
|
messages = [await _localize_image_url(m) for m in messages]
|
|
289
308
|
|
|
290
309
|
request_params = await prepare_openai_completion_params(
|
|
291
|
-
model=
|
|
310
|
+
model=provider_model_id,
|
|
292
311
|
messages=messages,
|
|
293
312
|
frequency_penalty=params.frequency_penalty,
|
|
294
313
|
function_call=params.function_call,
|
|
@@ -326,9 +345,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
326
345
|
"""
|
|
327
346
|
Direct OpenAI embeddings API call.
|
|
328
347
|
"""
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
348
|
+
provider_model_id = await self._get_provider_model_id(params.model)
|
|
349
|
+
self._validate_model_allowed(provider_model_id)
|
|
350
|
+
|
|
351
|
+
# Build request params conditionally to avoid NotGiven/Omit type mismatch
|
|
352
|
+
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
|
|
353
|
+
request_params: dict[str, Any] = {
|
|
354
|
+
"model": provider_model_id,
|
|
332
355
|
"input": params.input,
|
|
333
356
|
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
|
334
357
|
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
|
@@ -413,7 +436,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|
|
413
436
|
for provider_model_id in provider_models_ids:
|
|
414
437
|
if not isinstance(provider_model_id, str):
|
|
415
438
|
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
|
416
|
-
if self.allowed_models and provider_model_id not in self.allowed_models:
|
|
439
|
+
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
|
417
440
|
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
|
418
441
|
continue
|
|
419
442
|
if metadata := self.embedding_model_metadata.get(provider_model_id):
|
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
5
|
# the root directory of this source tree.
|
|
6
6
|
import asyncio
|
|
7
|
-
from typing import Any
|
|
8
7
|
|
|
9
8
|
from llama_stack.apis.agents import (
|
|
10
9
|
Order,
|
|
@@ -55,28 +54,19 @@ class ResponsesStore:
|
|
|
55
54
|
|
|
56
55
|
self.policy = policy
|
|
57
56
|
self.sql_store = None
|
|
58
|
-
self.enable_write_queue = True
|
|
59
|
-
|
|
60
|
-
# Async write queue and worker control
|
|
61
|
-
self._queue: (
|
|
62
|
-
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
|
|
63
|
-
) = None
|
|
64
|
-
self._worker_tasks: list[asyncio.Task[Any]] = []
|
|
65
|
-
self._max_write_queue_size: int = self.reference.max_write_queue_size
|
|
66
|
-
self._num_writers: int = max(1, self.reference.num_writers)
|
|
67
57
|
|
|
68
58
|
async def initialize(self):
|
|
69
59
|
"""Create the necessary tables if they don't exist."""
|
|
70
60
|
base_store = sqlstore_impl(self.reference)
|
|
71
61
|
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
|
72
62
|
|
|
63
|
+
# Disable write queue for SQLite since WAL mode handles concurrency
|
|
64
|
+
# Keep it enabled for other backends (like Postgres) for performance
|
|
73
65
|
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
|
74
|
-
if backend_config
|
|
75
|
-
raise ValueError(
|
|
76
|
-
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
|
77
|
-
)
|
|
78
|
-
if backend_config.type == StorageBackendType.SQL_SQLITE:
|
|
66
|
+
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
|
79
67
|
self.enable_write_queue = False
|
|
68
|
+
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
|
69
|
+
|
|
80
70
|
await self.sql_store.create_table(
|
|
81
71
|
"openai_responses",
|
|
82
72
|
{
|
|
@@ -99,28 +89,16 @@ class ResponsesStore:
|
|
|
99
89
|
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
|
100
90
|
for _ in range(self._num_writers):
|
|
101
91
|
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
|
102
|
-
|
|
103
|
-
|
|
92
|
+
logger.debug(
|
|
93
|
+
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
|
94
|
+
)
|
|
104
95
|
|
|
105
96
|
async def shutdown(self) -> None:
|
|
106
|
-
|
|
107
|
-
return
|
|
108
|
-
if self._queue is not None:
|
|
109
|
-
await self._queue.join()
|
|
110
|
-
for t in self._worker_tasks:
|
|
111
|
-
if not t.done():
|
|
112
|
-
t.cancel()
|
|
113
|
-
for t in self._worker_tasks:
|
|
114
|
-
try:
|
|
115
|
-
await t
|
|
116
|
-
except asyncio.CancelledError:
|
|
117
|
-
pass
|
|
118
|
-
self._worker_tasks.clear()
|
|
97
|
+
return
|
|
119
98
|
|
|
120
99
|
async def flush(self) -> None:
|
|
121
|
-
"""
|
|
122
|
-
|
|
123
|
-
await self._queue.join()
|
|
100
|
+
"""Maintained for compatibility; no-op now that writes are synchronous."""
|
|
101
|
+
return
|
|
124
102
|
|
|
125
103
|
async def store_response_object(
|
|
126
104
|
self,
|
|
@@ -128,31 +106,7 @@ class ResponsesStore:
|
|
|
128
106
|
input: list[OpenAIResponseInput],
|
|
129
107
|
messages: list[OpenAIMessageParam],
|
|
130
108
|
) -> None:
|
|
131
|
-
|
|
132
|
-
if self._queue is None:
|
|
133
|
-
raise ValueError("Responses store is not initialized")
|
|
134
|
-
try:
|
|
135
|
-
self._queue.put_nowait((response_object, input, messages))
|
|
136
|
-
except asyncio.QueueFull:
|
|
137
|
-
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
|
|
138
|
-
await self._queue.put((response_object, input, messages))
|
|
139
|
-
else:
|
|
140
|
-
await self._write_response_object(response_object, input, messages)
|
|
141
|
-
|
|
142
|
-
async def _worker_loop(self) -> None:
|
|
143
|
-
assert self._queue is not None
|
|
144
|
-
while True:
|
|
145
|
-
try:
|
|
146
|
-
item = await self._queue.get()
|
|
147
|
-
except asyncio.CancelledError:
|
|
148
|
-
break
|
|
149
|
-
response_object, input, messages = item
|
|
150
|
-
try:
|
|
151
|
-
await self._write_response_object(response_object, input, messages)
|
|
152
|
-
except Exception as e: # noqa: BLE001
|
|
153
|
-
logger.error(f"Error writing response object: {e}")
|
|
154
|
-
finally:
|
|
155
|
-
self._queue.task_done()
|
|
109
|
+
await self._write_response_object(response_object, input, messages)
|
|
156
110
|
|
|
157
111
|
async def _write_response_object(
|
|
158
112
|
self,
|
|
@@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
|
|
|
45
45
|
enhanced["owner_principal"] = current_user.principal
|
|
46
46
|
enhanced["access_attributes"] = current_user.attributes
|
|
47
47
|
else:
|
|
48
|
-
|
|
49
|
-
|
|
48
|
+
# IMPORTANT: Use empty string and null value (not None) to match public access filter
|
|
49
|
+
# The public access filter in _get_public_access_conditions() expects:
|
|
50
|
+
# - owner_principal = '' (empty string)
|
|
51
|
+
# - access_attributes = null (JSON null, which serializes to the string 'null')
|
|
52
|
+
# Setting them to None (SQL NULL) will cause rows to be filtered out on read.
|
|
53
|
+
enhanced["owner_principal"] = ""
|
|
54
|
+
enhanced["access_attributes"] = None # Pydantic/JSON will serialize this as JSON null
|
|
50
55
|
return enhanced
|
|
51
56
|
|
|
52
57
|
|
|
@@ -188,8 +193,9 @@ class AuthorizedSqlStore:
|
|
|
188
193
|
enhanced_data["owner_principal"] = current_user.principal
|
|
189
194
|
enhanced_data["access_attributes"] = current_user.attributes
|
|
190
195
|
else:
|
|
191
|
-
|
|
192
|
-
enhanced_data["
|
|
196
|
+
# IMPORTANT: Use empty string for owner_principal to match public access filter
|
|
197
|
+
enhanced_data["owner_principal"] = ""
|
|
198
|
+
enhanced_data["access_attributes"] = None # Will serialize as JSON null
|
|
193
199
|
|
|
194
200
|
await self.sql_store.update(table, enhanced_data, where)
|
|
195
201
|
|
|
@@ -245,14 +251,24 @@ class AuthorizedSqlStore:
|
|
|
245
251
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
|
246
252
|
|
|
247
253
|
def _get_public_access_conditions(self) -> list[str]:
|
|
248
|
-
"""Get the SQL conditions for public access.
|
|
249
|
-
|
|
254
|
+
"""Get the SQL conditions for public access.
|
|
255
|
+
|
|
256
|
+
Public records are those with:
|
|
257
|
+
- owner_principal = '' (empty string)
|
|
258
|
+
- access_attributes is either SQL NULL or JSON null
|
|
259
|
+
|
|
260
|
+
Note: Different databases serialize None differently:
|
|
261
|
+
- SQLite: None → JSON null (text = 'null')
|
|
262
|
+
- Postgres: None → SQL NULL (IS NULL)
|
|
263
|
+
"""
|
|
250
264
|
conditions = ["owner_principal = ''"]
|
|
251
265
|
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
|
252
|
-
#
|
|
253
|
-
|
|
266
|
+
# Accept both SQL NULL and JSON null for Postgres compatibility
|
|
267
|
+
# This handles both old rows (SQL NULL) and new rows (JSON null)
|
|
268
|
+
conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
|
|
254
269
|
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
|
255
|
-
|
|
270
|
+
# SQLite serializes None as JSON null
|
|
271
|
+
conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
|
|
256
272
|
else:
|
|
257
273
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
|
258
274
|
return conditions
|
|
@@ -17,6 +17,7 @@ from sqlalchemy import (
|
|
|
17
17
|
String,
|
|
18
18
|
Table,
|
|
19
19
|
Text,
|
|
20
|
+
event,
|
|
20
21
|
inspect,
|
|
21
22
|
select,
|
|
22
23
|
text,
|
|
@@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|
|
75
76
|
self.metadata = MetaData()
|
|
76
77
|
|
|
77
78
|
def create_engine(self) -> AsyncEngine:
|
|
78
|
-
|
|
79
|
+
# Configure connection args for better concurrency support
|
|
80
|
+
connect_args = {}
|
|
81
|
+
if "sqlite" in self.config.engine_str:
|
|
82
|
+
# SQLite-specific optimizations for concurrent access
|
|
83
|
+
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
|
84
|
+
connect_args["timeout"] = 5.0
|
|
85
|
+
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
|
|
86
|
+
|
|
87
|
+
engine = create_async_engine(
|
|
88
|
+
self.config.engine_str,
|
|
89
|
+
pool_pre_ping=True,
|
|
90
|
+
connect_args=connect_args,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Enable WAL mode for SQLite to support concurrent readers and writers
|
|
94
|
+
if "sqlite" in self.config.engine_str:
|
|
95
|
+
|
|
96
|
+
@event.listens_for(engine.sync_engine, "connect")
|
|
97
|
+
def set_sqlite_pragma(dbapi_conn, connection_record):
|
|
98
|
+
cursor = dbapi_conn.cursor()
|
|
99
|
+
# Enable Write-Ahead Logging for better concurrency
|
|
100
|
+
cursor.execute("PRAGMA journal_mode=WAL")
|
|
101
|
+
# Set busy timeout to 5 seconds (retry instead of immediate failure)
|
|
102
|
+
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
|
|
103
|
+
cursor.execute("PRAGMA busy_timeout=5000")
|
|
104
|
+
# Use NORMAL synchronous mode for better performance (still safe with WAL)
|
|
105
|
+
cursor.execute("PRAGMA synchronous=NORMAL")
|
|
106
|
+
cursor.close()
|
|
107
|
+
|
|
108
|
+
return engine
|
|
79
109
|
|
|
80
110
|
async def create_table(
|
|
81
111
|
self,
|