llama-stack 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. llama_stack/cli/stack/list_deps.py +4 -0
  2. llama_stack/core/routers/inference.py +66 -40
  3. llama_stack/distributions/starter/build.yaml +1 -0
  4. llama_stack/distributions/starter/run-with-postgres-store.yaml +285 -0
  5. llama_stack/distributions/starter/starter.py +86 -68
  6. llama_stack/distributions/starter-gpu/build.yaml +1 -0
  7. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +288 -0
  8. llama_stack/providers/inline/vector_io/faiss/faiss.py +25 -2
  9. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +15 -4
  10. llama_stack/providers/remote/inference/vertexai/vertexai.py +10 -0
  11. llama_stack/providers/remote/vector_io/chroma/chroma.py +9 -3
  12. llama_stack/providers/remote/vector_io/milvus/milvus.py +7 -4
  13. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +32 -6
  14. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +11 -6
  15. llama_stack/providers/remote/vector_io/weaviate/weaviate.py +7 -4
  16. llama_stack/providers/utils/inference/embedding_mixin.py +1 -2
  17. llama_stack/providers/utils/inference/inference_store.py +30 -10
  18. llama_stack/providers/utils/inference/model_registry.py +1 -1
  19. llama_stack/providers/utils/inference/openai_mixin.py +33 -10
  20. llama_stack/providers/utils/responses/responses_store.py +12 -58
  21. llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +25 -9
  22. llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +31 -1
  23. llama_stack/ui/node_modules/flatted/python/flatted.py +149 -0
  24. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/METADATA +3 -3
  25. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/RECORD +29 -26
  26. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/WHEEL +0 -0
  27. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/entry_points.txt +0 -0
  28. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/licenses/LICENSE +0 -0
  29. {llama_stack-0.3.1.dist-info → llama_stack-0.3.3.dist-info}/top_level.txt +0 -0
@@ -368,6 +368,22 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             log.exception("Could not connect to PGVector database server")
             raise RuntimeError("Could not connect to PGVector database server") from e
 
+        # Load existing vector stores from KV store into cache
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
+        for vector_store_data in stored_vector_stores:
+            vector_store = VectorStore.model_validate_json(vector_store_data)
+            pgvector_index = PGVectorIndex(
+                vector_store=vector_store,
+                dimension=vector_store.embedding_dimension,
+                conn=self.conn,
+                kvstore=self.kvstore,
+            )
+            await pgvector_index.initialize()
+            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+            self.cache[vector_store.identifier] = index
+
     async def shutdown(self) -> None:
         if self.conn is not None:
             self.conn.close()
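Note: the loader above enumerates persisted stores with a lexicographic range scan. A minimal sketch of the bound, assuming a sorted-key store; the key prefix shown here is illustrative only (the real VECTOR_DBS_PREFIX is defined elsewhere in the package):

# Hypothetical illustration: keys compare lexicographically, so the half-open
# range [prefix, prefix + "\xff") matches every key beginning with the prefix.
def values_in_range(store: dict[str, str], start_key: str, end_key: str) -> list[str]:
    return [v for k, v in sorted(store.items()) if start_key <= k < end_key]

store = {
    "vector_stores:alpha": '{"identifier": "alpha"}',  # illustrative keys only
    "vector_stores:beta": '{"identifier": "beta"}',
    "other:gamma": '{"identifier": "gamma"}',
}
prefix = "vector_stores:"
assert values_in_range(store, prefix, prefix + "\xff") == [
    '{"identifier": "alpha"}',
    '{"identifier": "beta"}',
]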
@@ -377,7 +393,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
 
     async def register_vector_store(self, vector_store: VectorStore) -> None:
         # Persist vector DB metadata in the KV store
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
+
+        # Save to kvstore for persistence
+        key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
+        await self.kvstore.set(key=key, value=vector_store.model_dump_json())
+
         # Upsert model metadata in Postgres
         upsert_models(self.conn, [(vector_store.identifier, vector_store)])
 
@@ -396,7 +418,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             del self.cache[vector_store_id]
 
         # Delete vector DB metadata from KV store
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")
 
     async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
@@ -413,13 +436,16 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]
 
-        if self.vector_store_table is None:
-            raise VectorStoreNotFoundError(vector_store_id)
+        # Try to load from kvstore
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
 
-        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
-        if not vector_store:
+        key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+        vector_store_data = await self.kvstore.get(key)
+        if not vector_store_data:
             raise VectorStoreNotFoundError(vector_store_id)
 
+        vector_store = VectorStore.model_validate_json(vector_store_data)
         index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
         await index.initialize()
         self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
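The persistence round trip relies on Pydantic's JSON serialization: register_vector_store writes model_dump_json() and the loaders call model_validate_json(). A self-contained sketch with a hypothetical stand-in model:

from pydantic import BaseModel

class VectorStoreStub(BaseModel):  # hypothetical stand-in for VectorStore
    identifier: str
    embedding_dimension: int

original = VectorStoreStub(identifier="docs", embedding_dimension=384)
payload = original.model_dump_json()                     # written on register
restored = VectorStoreStub.model_validate_json(payload)  # read back on lookup
assert restored == original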
@@ -183,7 +183,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         await super().shutdown()
 
     async def register_vector_store(self, vector_store: VectorStore) -> None:
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
         key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
         await self.kvstore.set(key=key, value=vector_store.model_dump_json())
 
@@ -200,20 +201,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             await self.cache[vector_store_id].index.delete()
             del self.cache[vector_store_id]
 
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
         await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
 
     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]
 
-        if self.vector_store_table is None:
-            raise ValueError(f"Vector DB not found {vector_store_id}")
+        # Try to load from kvstore
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
 
-        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
-        if not vector_store:
+        key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+        vector_store_data = await self.kvstore.get(key)
+        if not vector_store_data:
             raise VectorStoreNotFoundError(vector_store_id)
 
+        vector_store = VectorStore.model_validate_json(vector_store_data)
         index = VectorStoreWithIndex(
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
@@ -346,13 +346,16 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]
 
-        if self.vector_store_table is None:
-            raise VectorStoreNotFoundError(vector_store_id)
+        # Try to load from kvstore
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
 
-        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
-        if not vector_store:
+        key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+        vector_store_data = await self.kvstore.get(key)
+        if not vector_store_data:
             raise VectorStoreNotFoundError(vector_store_id)
 
+        vector_store = VectorStore.model_validate_json(vector_store_data)
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
         if not client.collections.exists(sanitized_collection_name):
@@ -46,8 +46,7 @@ class SentenceTransformerEmbeddingMixin:
             raise ValueError("Empty list not supported")
 
         # Get the model and generate embeddings
-        model_obj = await self.model_store.get_model(params.model)
-        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embedding_model = await self._load_sentence_transformer_model(params.model)
         embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
 
         # Convert embeddings to the requested format
@@ -35,6 +35,7 @@ class InferenceStore:
         self.reference = reference
         self.sql_store = None
         self.policy = policy
+        self.enable_write_queue = True
 
         # Async write queue and worker control
         self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@@ -47,14 +48,13 @@ class InferenceStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)
 
-        # Disable write queue for SQLite to avoid concurrency issues
-        backend_name = self.reference.backend
-        backend_config = _SQLSTORE_BACKENDS.get(backend_name)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
+        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
+            self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
+
         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -70,8 +70,9 @@ class InferenceStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )
 
     async def shutdown(self) -> None:
         if not self._worker_tasks:
@@ -93,10 +94,29 @@ class InferenceStore:
         if self.enable_write_queue and self._queue is not None:
             await self._queue.join()
 
+    async def _ensure_workers_started(self) -> None:
+        """Ensure the async write queue workers run on the current loop."""
+        if not self.enable_write_queue:
+            return
+
+        if self._queue is None:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            logger.debug(
+                f"Inference store write queue created with max size {self._max_write_queue_size} "
+                f"and {self._num_writers} writers"
+            )
+
+        if not self._worker_tasks:
+            loop = asyncio.get_running_loop()
+            for _ in range(self._num_writers):
+                task = loop.create_task(self._worker_loop())
+                self._worker_tasks.append(task)
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
         if self.enable_write_queue:
+            await self._ensure_workers_started()
             if self._queue is None:
                 raise ValueError("Inference store is not initialized")
             try:
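The _ensure_workers_started() hook defers queue and worker creation until the first write, so the workers are bound to the event loop that is actually running at that point rather than whichever loop was active during initialize(). A condensed sketch of the pattern, with illustrative names (not the llama-stack class):

import asyncio

class LazyQueueWriter:
    """Minimal sketch of the lazy-start pattern, assuming a single writer type."""

    def __init__(self, num_writers: int = 2, maxsize: int = 100) -> None:
        self._num_writers = num_writers
        self._maxsize = maxsize
        self._queue: asyncio.Queue[str] | None = None
        self._workers: list[asyncio.Task[None]] = []

    async def _worker(self) -> None:
        assert self._queue is not None
        while True:
            item = await self._queue.get()
            try:
                pass  # stand-in for the real database write
            finally:
                self._queue.task_done()

    async def submit(self, item: str) -> None:
        # Create the queue and workers on first use so they belong to the
        # currently running loop, avoiding "attached to a different loop" errors.
        if self._queue is None:
            self._queue = asyncio.Queue(maxsize=self._maxsize)
        if not self._workers:
            loop = asyncio.get_running_loop()
            self._workers = [loop.create_task(self._worker()) for _ in range(self._num_writers)]
        await self._queue.put(item)

    async def flush(self) -> None:
        if self._queue is not None:
            await self._queue.join()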
@@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
 
 
 class RemoteInferenceProviderConfig(BaseModel):
-    allowed_models: list[str] | None = Field(  # TODO: make this non-optional and give a list() default
+    allowed_models: list[str] | None = Field(
         default=None,
         description="List of models that should be registered with the model registry. If None, all models are allowed.",
     )
@@ -82,9 +82,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}
 
-    # List of allowed models for this provider, if empty all models allowed
-    allowed_models: list[str] = []
-
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None
 
@@ -191,6 +188,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
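The new guard reduces to a simple allow-list check where None means unrestricted. A standalone sketch of the same logic, with a free function standing in for the method:

def validate_model_allowed(allowed_models: list[str] | None, provider_model_id: str) -> None:
    # Mirrors _validate_model_allowed: None disables filtering entirely.
    if allowed_models is not None and provider_model_id not in allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {allowed_models}"
        )

validate_model_allowed(None, "gpt-4o-mini")    # no allow-list: accepted
validate_model_allowed(["gpt-4o"], "gpt-4o")   # on the list: accepted
try:
    validate_model_allowed(["gpt-4o"], "gpt-4o-mini")
except ValueError as e:
    print(e)                                   # rejected with a clear message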
@@ -201,8 +211,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         :param model: The registered model name/identifier
         :return: The provider-specific model ID (e.g., "gpt-4")
         """
-        # Look up the registered model to get the provider-specific model ID
         # self.model_store is injected by the distribution system at runtime
+        if not await self.model_store.has_model(model):  # type: ignore[attr-defined]
+            return model
+
+        # Look up the registered model to get the provider-specific model ID
         model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
         # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
         if model_obj.provider_resource_id is None:
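With the has_model() guard, an identifier that was never registered is now passed through unchanged, so callers can address provider-native model IDs directly. A sketch of that fallback with a hypothetical in-memory model store:

import asyncio
from types import SimpleNamespace

class StubModelStore:  # hypothetical stand-in for the injected model_store
    def __init__(self, models: dict[str, str]) -> None:
        self._models = models

    async def has_model(self, model: str) -> bool:
        return model in self._models

    async def get_model(self, model: str):
        return SimpleNamespace(provider_resource_id=self._models[model])

async def resolve(store: StubModelStore, model: str) -> str:
    if not await store.has_model(model):
        return model  # unregistered: pass the ID through as-is
    return (await store.get_model(model)).provider_resource_id

async def demo() -> None:
    store = StubModelStore({"my-alias": "gpt-4"})
    assert await resolve(store, "my-alias") == "gpt-4"           # registered alias
    assert await resolve(store, "gpt-4o-mini") == "gpt-4o-mini"  # passthrough

asyncio.run(demo())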
@@ -234,8 +247,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -267,6 +283,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -288,7 +307,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -326,9 +345,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
-        # Prepare request parameters
-        request_params = {
-            "model": await self._get_provider_model_id(params.model),
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
+        # Build request params conditionally to avoid NotGiven/Omit type mismatch
+        # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
+        request_params: dict[str, Any] = {
+            "model": provider_model_id,
             "input": params.input,
             "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
             "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
@@ -413,7 +436,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         for provider_model_id in provider_models_ids:
             if not isinstance(provider_model_id, str):
                 raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
-            if self.allowed_models and provider_model_id not in self.allowed_models:
+            if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
                 logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
             if metadata := self.embedding_model_metadata.get(provider_model_id):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import asyncio
-from typing import Any
 
 from llama_stack.apis.agents import (
     Order,
@@ -55,28 +54,19 @@ class ResponsesStore:
 
         self.policy = policy
         self.sql_store = None
-        self.enable_write_queue = True
-
-        # Async write queue and worker control
-        self._queue: (
-            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
-        ) = None
-        self._worker_tasks: list[asyncio.Task[Any]] = []
-        self._max_write_queue_size: int = self.reference.max_write_queue_size
-        self._num_writers: int = max(1, self.reference.num_writers)
 
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)
 
+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
         backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        if backend_config.type == StorageBackendType.SQL_SQLITE:
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
             self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
+
         await self.sql_store.create_table(
             "openai_responses",
             {
@@ -99,28 +89,16 @@ class ResponsesStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )
 
     async def shutdown(self) -> None:
-        if not self._worker_tasks:
-            return
-        if self._queue is not None:
-            await self._queue.join()
-        for t in self._worker_tasks:
-            if not t.done():
-                t.cancel()
-        for t in self._worker_tasks:
-            try:
-                await t
-            except asyncio.CancelledError:
-                pass
-        self._worker_tasks.clear()
+        return
 
     async def flush(self) -> None:
-        """Wait for all queued writes to complete. Useful for testing."""
-        if self.enable_write_queue and self._queue is not None:
-            await self._queue.join()
+        """Maintained for compatibility; no-op now that writes are synchronous."""
+        return
 
     async def store_response_object(
         self,
@@ -128,31 +106,7 @@ class ResponsesStore:
         input: list[OpenAIResponseInput],
         messages: list[OpenAIMessageParam],
     ) -> None:
-        if self.enable_write_queue:
-            if self._queue is None:
-                raise ValueError("Responses store is not initialized")
-            try:
-                self._queue.put_nowait((response_object, input, messages))
-            except asyncio.QueueFull:
-                logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input, messages))
-        else:
-            await self._write_response_object(response_object, input, messages)
-
-    async def _worker_loop(self) -> None:
-        assert self._queue is not None
-        while True:
-            try:
-                item = await self._queue.get()
-            except asyncio.CancelledError:
-                break
-            response_object, input, messages = item
-            try:
-                await self._write_response_object(response_object, input, messages)
-            except Exception as e:  # noqa: BLE001
-                logger.error(f"Error writing response object: {e}")
-            finally:
-                self._queue.task_done()
+        await self._write_response_object(response_object, input, messages)
 
     async def _write_response_object(
         self,
@@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
         enhanced["owner_principal"] = current_user.principal
         enhanced["access_attributes"] = current_user.attributes
     else:
-        enhanced["owner_principal"] = None
-        enhanced["access_attributes"] = None
+        # IMPORTANT: Use empty string and null value (not None) to match public access filter
+        # The public access filter in _get_public_access_conditions() expects:
+        #   - owner_principal = '' (empty string)
+        #   - access_attributes = null (JSON null, which serializes to the string 'null')
+        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
+        enhanced["owner_principal"] = ""
+        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
     return enhanced
 
 
@@ -188,8 +193,9 @@ class AuthorizedSqlStore:
             enhanced_data["owner_principal"] = current_user.principal
             enhanced_data["access_attributes"] = current_user.attributes
         else:
-            enhanced_data["owner_principal"] = None
-            enhanced_data["access_attributes"] = None
+            # IMPORTANT: Use empty string for owner_principal to match public access filter
+            enhanced_data["owner_principal"] = ""
+            enhanced_data["access_attributes"] = None  # Will serialize as JSON null
 
         await self.sql_store.update(table, enhanced_data, where)
 
@@ -245,14 +251,24 @@ class AuthorizedSqlStore:
             raise ValueError(f"Unsupported database type: {self.database_type}")
 
     def _get_public_access_conditions(self) -> list[str]:
-        """Get the SQL conditions for public access."""
-        # Public records are records that have no owner_principal or access_attributes
+        """Get the SQL conditions for public access.
+
+        Public records are those with:
+        - owner_principal = '' (empty string)
+        - access_attributes is either SQL NULL or JSON null
+
+        Note: Different databases serialize None differently:
+        - SQLite: None → JSON null (text = 'null')
+        - Postgres: None → SQL NULL (IS NULL)
+        """
         conditions = ["owner_principal = ''"]
         if self.database_type == StorageBackendType.SQL_POSTGRES.value:
-            # Postgres stores JSON null as 'null'
-            conditions.append("access_attributes::text = 'null'")
+            # Accept both SQL NULL and JSON null for Postgres compatibility
+            # This handles both old rows (SQL NULL) and new rows (JSON null)
+            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
         elif self.database_type == StorageBackendType.SQL_SQLITE.value:
-            conditions.append("access_attributes = 'null'")
+            # SQLite serializes None as JSON null
+            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
         else:
             raise ValueError(f"Unsupported database type: {self.database_type}")
         return conditions
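Taken together with the write-side change above (owner_principal set to '' plus a JSON-null access_attributes), the widened filter accepts both old rows (SQL NULL) and new rows (JSON null). A quick SQLite demonstration of why both conditions are needed, using a throwaway table:

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (owner_principal TEXT, access_attributes TEXT)")
conn.execute("INSERT INTO t VALUES (?, ?)", ("", json.dumps(None)))  # new row: the string 'null'
conn.execute("INSERT INTO t VALUES (?, ?)", ("", None))              # old row: SQL NULL

(count,) = conn.execute(
    "SELECT COUNT(*) FROM t WHERE owner_principal = '' "
    "AND (access_attributes IS NULL OR access_attributes = 'null')"
).fetchone()
assert count == 2  # both generations of public rows match the filter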
@@ -17,6 +17,7 @@ from sqlalchemy import (
     String,
     Table,
     Text,
+    event,
     inspect,
     select,
     text,
@@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         self.metadata = MetaData()
 
     def create_engine(self) -> AsyncEngine:
-        return create_async_engine(self.config.engine_str, pool_pre_ping=True)
+        # Configure connection args for better concurrency support
+        connect_args = {}
+        if "sqlite" in self.config.engine_str:
+            # SQLite-specific optimizations for concurrent access
+            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
+            connect_args["timeout"] = 5.0
+            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks
+
+        engine = create_async_engine(
+            self.config.engine_str,
+            pool_pre_ping=True,
+            connect_args=connect_args,
+        )
+
+        # Enable WAL mode for SQLite to support concurrent readers and writers
+        if "sqlite" in self.config.engine_str:
+
+            @event.listens_for(engine.sync_engine, "connect")
+            def set_sqlite_pragma(dbapi_conn, connection_record):
+                cursor = dbapi_conn.cursor()
+                # Enable Write-Ahead Logging for better concurrency
+                cursor.execute("PRAGMA journal_mode=WAL")
+                # Set busy timeout to 5 seconds (retry instead of immediate failure)
+                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
+                cursor.execute("PRAGMA busy_timeout=5000")
+                # Use NORMAL synchronous mode for better performance (still safe with WAL)
+                cursor.execute("PRAGMA synchronous=NORMAL")
+                cursor.close()
+
+        return engine
 
     async def create_table(
         self,
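For reference, the same pragmas can be exercised against a plain sqlite3 connection; this sketch mirrors what the event listener applies per connection. Note that WAL requires a file-backed database, so an :memory: database will not report 'wal':

import sqlite3

conn = sqlite3.connect("example.db", timeout=5.0, check_same_thread=False)
cur = conn.cursor()
cur.execute("PRAGMA journal_mode=WAL")
print(cur.fetchone())  # ('wal',) once write-ahead logging is active
cur.execute("PRAGMA busy_timeout=5000")   # wait up to 5s on a locked database
cur.execute("PRAGMA synchronous=NORMAL")  # fewer fsyncs; still durable with WAL
cur.close()
conn.close()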