llama-stack 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/run.py +3 -0
- llama_stack/core/library_client.py +80 -3
- llama_stack/core/routing_tables/common.py +11 -0
- llama_stack/core/routing_tables/vector_stores.py +4 -0
- llama_stack/core/stack.py +38 -11
- llama_stack/core/storage/kvstore/kvstore.py +11 -0
- llama_stack/core/storage/kvstore/mongodb/mongodb.py +5 -0
- llama_stack/core/storage/kvstore/postgres/postgres.py +8 -0
- llama_stack/core/storage/kvstore/redis/redis.py +5 -0
- llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py +8 -0
- llama_stack/core/storage/sqlstore/sqlstore.py +8 -0
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +60 -34
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +4 -0
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +9 -1
- llama_stack/providers/inline/tool_runtime/rag/memory.py +8 -3
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +13 -1
- llama_stack/providers/utils/inference/embedding_mixin.py +20 -16
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +33 -0
- llama_stack/providers/utils/memory/vector_store.py +9 -4
- llama_stack/providers/utils/tools/mcp.py +258 -16
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/METADATA +2 -2
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/RECORD +96 -29
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/WHEEL +1 -1
- llama_stack_api/internal/kvstore.py +2 -0
- llama_stack_api/internal/sqlstore.py +2 -0
- llama_stack_api/llama_stack_api/__init__.py +945 -0
- llama_stack_api/llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/llama_stack_api/admin/api.py +72 -0
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/llama_stack_api/admin/models.py +113 -0
- llama_stack_api/llama_stack_api/agents.py +173 -0
- llama_stack_api/llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/llama_stack_api/batches/api.py +53 -0
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/llama_stack_api/batches/models.py +78 -0
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/llama_stack_api/benchmarks/models.py +109 -0
- llama_stack_api/llama_stack_api/common/__init__.py +5 -0
- llama_stack_api/llama_stack_api/common/content_types.py +101 -0
- llama_stack_api/llama_stack_api/common/errors.py +95 -0
- llama_stack_api/llama_stack_api/common/job_types.py +38 -0
- llama_stack_api/llama_stack_api/common/responses.py +77 -0
- llama_stack_api/llama_stack_api/common/training_types.py +47 -0
- llama_stack_api/llama_stack_api/common/type_system.py +146 -0
- llama_stack_api/llama_stack_api/connectors.py +146 -0
- llama_stack_api/llama_stack_api/conversations.py +270 -0
- llama_stack_api/llama_stack_api/datasetio.py +55 -0
- llama_stack_api/llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/llama_stack_api/datasets/models.py +152 -0
- llama_stack_api/llama_stack_api/datatypes.py +373 -0
- llama_stack_api/llama_stack_api/eval.py +137 -0
- llama_stack_api/llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/llama_stack_api/files/api.py +51 -0
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/llama_stack_api/files/models.py +107 -0
- llama_stack_api/llama_stack_api/inference.py +1169 -0
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/llama_stack_api/inspect_api/models.py +28 -0
- llama_stack_api/llama_stack_api/internal/__init__.py +9 -0
- llama_stack_api/llama_stack_api/internal/kvstore.py +28 -0
- llama_stack_api/llama_stack_api/internal/sqlstore.py +81 -0
- llama_stack_api/llama_stack_api/models.py +171 -0
- llama_stack_api/llama_stack_api/openai_responses.py +1468 -0
- llama_stack_api/llama_stack_api/post_training.py +370 -0
- llama_stack_api/llama_stack_api/prompts.py +203 -0
- llama_stack_api/llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/llama_stack_api/providers/api.py +16 -0
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/llama_stack_api/providers/models.py +24 -0
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +168 -0
- llama_stack_api/llama_stack_api/resource.py +37 -0
- llama_stack_api/llama_stack_api/router_utils.py +160 -0
- llama_stack_api/llama_stack_api/safety.py +132 -0
- llama_stack_api/llama_stack_api/schema_utils.py +208 -0
- llama_stack_api/llama_stack_api/scoring.py +93 -0
- llama_stack_api/llama_stack_api/scoring_functions.py +211 -0
- llama_stack_api/llama_stack_api/shields.py +93 -0
- llama_stack_api/llama_stack_api/tools.py +226 -0
- llama_stack_api/llama_stack_api/vector_io.py +941 -0
- llama_stack_api/llama_stack_api/vector_stores.py +53 -0
- llama_stack_api/llama_stack_api/version.py +9 -0
- llama_stack_api/vector_stores.py +2 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/top_level.txt +0 -0
llama_stack/cli/stack/run.py
CHANGED
@@ -202,6 +202,9 @@ class StackRun(Subcommand):
         # Set the config file in environment so create_app can find it
         os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
 
+        # disable together banner that spams llama stack run every time
+        os.environ["TOGETHER_NO_BANNER"] = "1"
+
         uvicorn_config = {
             "factory": True,
             "host": host,
llama_stack/core/library_client.py
CHANGED

@@ -161,6 +161,45 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         """
         pass
 
+    def shutdown(self) -> None:
+        """Shutdown the client and release all resources.
+
+        This method should be called when you're done using the client to properly
+        close database connections and release other resources. Failure to call this
+        method may result in the program hanging on exit while waiting for background
+        threads to complete.
+
+        This method is idempotent and can be called multiple times safely.
+
+        Example:
+            client = LlamaStackAsLibraryClient("starter")
+            # ... use the client ...
+            client.shutdown()
+        """
+        loop = self.loop
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_until_complete(self.async_client.shutdown())
+        finally:
+            loop.close()
+            asyncio.set_event_loop(None)
+
+    def __enter__(self) -> "LlamaStackAsLibraryClient":
+        """Enter the context manager.
+
+        The client is already initialized in __init__, so this just returns self.
+
+        Example:
+            with LlamaStackAsLibraryClient("starter") as client:
+                response = client.models.list()
+            # Client is automatically shut down here
+        """
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Exit the context manager and shut down the client."""
+        self.shutdown()
+
     def request(self, *args, **kwargs):
         loop = self.loop
         asyncio.set_event_loop(loop)

@@ -224,6 +263,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.custom_provider_registry = custom_provider_registry
         self.provider_data = provider_data
         self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
+        self.stack: Stack | None = None
 
     def _remove_root_logger_handlers(self):
         """

@@ -246,9 +286,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         try:
             self.route_impls = None
 
-            stack = Stack(self.config, self.custom_provider_registry)
-            await stack.initialize()
-            self.impls = stack.impls
+            self.stack = Stack(self.config, self.custom_provider_registry)
+            await self.stack.initialize()
+            self.impls = self.stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(

@@ -283,6 +323,43 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.route_impls = initialize_route_impls(self.impls)
         return True
 
+    async def shutdown(self) -> None:
+        """Shutdown the client and release all resources.
+
+        This method should be called when you're done using the client to properly
+        close database connections and release other resources. Failure to call this
+        method may result in the program hanging on exit while waiting for background
+        threads to complete.
+
+        This method is idempotent and can be called multiple times safely.
+
+        Example:
+            client = AsyncLlamaStackAsLibraryClient("starter")
+            await client.initialize()
+            # ... use the client ...
+            await client.shutdown()
+        """
+        if self.stack:
+            await self.stack.shutdown()
+            self.stack = None
+
+    async def __aenter__(self) -> "AsyncLlamaStackAsLibraryClient":
+        """Enter the async context manager.
+
+        Initializes the client and returns it.
+
+        Example:
+            async with AsyncLlamaStackAsLibraryClient("starter") as client:
+                response = await client.models.list()
+            # Client is automatically shut down here
+        """
+        await self.initialize()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Exit the async context manager and shut down the client."""
+        await self.shutdown()
+
     async def request(
         self,
         cast_to: Any,
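Taken together, the library-client changes add explicit resource cleanup plus sync and async context-manager support. A minimal usage sketch assembled from the docstrings added above; the "starter" distribution name is simply the one those docstrings use:

```python
import asyncio

from llama_stack.core.library_client import (
    AsyncLlamaStackAsLibraryClient,
    LlamaStackAsLibraryClient,
)

# Synchronous client: __exit__ calls shutdown() automatically.
with LlamaStackAsLibraryClient("starter") as client:
    print(client.models.list())


# Async client: __aenter__ runs initialize(), __aexit__ runs shutdown().
async def main() -> None:
    async with AsyncLlamaStackAsLibraryClient("starter") as client:
        print(await client.models.list())


asyncio.run(main())
```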
llama_stack/core/routing_tables/common.py
CHANGED

@@ -209,6 +209,17 @@ class CommonRoutingTableImpl(RoutingTable):
             logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
 
         registered_obj = await register_object_with_provider(obj, p)
+
+        # Ensure OpenAI metadata exists for vector stores
+        if obj.type == ResourceType.vector_store.value:
+            if hasattr(p, "_ensure_openai_metadata_exists"):
+                await p._ensure_openai_metadata_exists(obj)
+            else:
+                logger.warning(
+                    f"Provider {obj.provider_id} does not support OpenAI metadata creation. "
+                    f"Vector store {obj.identifier} may not work with OpenAI-compatible APIs."
+                )
+
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
llama_stack/core/routing_tables/vector_stores.py
CHANGED

@@ -55,6 +55,10 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
 
     # Internal methods only - no public API exposure
 
+    async def list_vector_stores(self) -> list[VectorStoreWithOwner]:
+        """List all registered vector stores."""
+        return await self.get_all_with_type(ResourceType.vector_store.value)
+
     async def register_vector_store(
         self,
         vector_store_id: str,
llama_stack/core/stack.py
CHANGED
@@ -53,6 +53,7 @@ from llama_stack_api import (
     PostTraining,
     Prompts,
     Providers,
+    RegisterBenchmarkRequest,
     Safety,
     Scoring,
     ScoringFunctions,

@@ -61,6 +62,7 @@ from llama_stack_api import (
     ToolRuntime,
     VectorIO,
 )
+from llama_stack_api.datasets import RegisterDatasetRequest
 
 logger = get_logger(name=__name__, category="core")
 

@@ -91,18 +93,22 @@ class LlamaStack(
     pass
 
 
+# Resources to register based on configuration.
+# If a request class is specified, the configuration object will be converted to this class before invoking the registration method.
 RESOURCES = [
-    ("models", Api.models, "register_model", "list_models"),
-    ("shields", Api.shields, "register_shield", "list_shields"),
-    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+    ("models", Api.models, "register_model", "list_models", None),
+    ("shields", Api.shields, "register_shield", "list_shields", None),
+    ("datasets", Api.datasets, "register_dataset", "list_datasets", RegisterDatasetRequest),
     (
         "scoring_fns",
         Api.scoring_functions,
         "register_scoring_function",
         "list_scoring_functions",
+        None,
     ),
-    ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
-    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
+    ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks", RegisterBenchmarkRequest),
+    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups", None),
+    ("vector_stores", Api.vector_stores, "register_vector_store", "list_vector_stores", None),
 ]
 
 

@@ -199,7 +205,7 @@ async def invoke_with_optional_request(method: Any) -> Any:
 
 
 async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
-    for rsrc, api, register_method, list_method in RESOURCES:
+    for rsrc, api, register_method, list_method, request_class in RESOURCES:
         objects = getattr(run_config.registered_resources, rsrc)
         if api not in impls:
             continue

@@ -213,10 +219,17 @@ async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
                 continue
             logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
 
-            #
-            #
  [2 removed lines not shown in this diff view]
+            # TODO: Once all register methods are migrated to accept request objects,
+            # remove this conditional and always use the request_class pattern.
+            if request_class is not None:
+                request = request_class(**obj.model_dump())
+                await method(request)
+            else:
+                # we want to maintain the type information in arguments to method.
+                # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
+                # we use model_dump() to find all the attrs and then getattr to get the still typed
+                # value.
+                await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
 
         method = getattr(impls[api], list_method)
         response = await invoke_with_optional_request(method)

@@ -608,7 +621,7 @@ class Stack:
     async def shutdown(self):
         for impl in self.impls.values():
             impl_name = impl.__class__.__name__
-            logger.
+            logger.debug(f"Shutting down {impl_name}")
             try:
                 if hasattr(impl, "shutdown"):
                     await asyncio.wait_for(impl.shutdown(), timeout=5)

@@ -630,6 +643,20 @@ class Stack:
         if REGISTRY_REFRESH_TASK:
             REGISTRY_REFRESH_TASK.cancel()
 
+        # Shutdown storage backends
+        from llama_stack.core.storage.kvstore.kvstore import shutdown_kvstore_backends
+        from llama_stack.core.storage.sqlstore.sqlstore import shutdown_sqlstore_backends
+
+        try:
+            await shutdown_kvstore_backends()
+        except Exception as e:
+            logger.exception(f"Failed to shutdown KV store backends: {e}")
+
+        try:
+            await shutdown_sqlstore_backends()
+        except Exception as e:
+            logger.exception(f"Failed to shutdown SQL store backends: {e}")
+
 
 async def refresh_registry_once(impls: dict[Api, Any]):
     logger.debug("refreshing registry")
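The else-branch comment about keeping arguments typed refers to Pydantic's model_dump(), which recursively converts nested models into plain dicts, whereas getattr() preserves the typed attribute. A small self-contained illustration (the Inner/Outer models here are hypothetical, not from the package):

```python
from pydantic import BaseModel


class Inner(BaseModel):
    value: int


class Outer(BaseModel):
    name: str
    inner: Inner


obj = Outer(name="x", inner=Inner(value=1))

# model_dump() recursively serializes nested models into dicts...
assert isinstance(obj.model_dump()["inner"], dict)

# ...whereas getattr() keeps the typed attribute, which is what typed
# register_* signatures expect.
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
assert isinstance(kwargs["inner"], Inner)
```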
llama_stack/core/storage/kvstore/kvstore.py
CHANGED

@@ -62,6 +62,9 @@ class InmemoryKVStoreImpl(KVStore):
     async def delete(self, key: str) -> None:
         del self._store[key]
 
+    async def shutdown(self) -> None:
+        self._store.clear()
+
 
 _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
 _KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}

@@ -126,3 +129,11 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
         await impl.initialize()
         _KVSTORE_INSTANCES[cache_key] = impl
     return impl
+
+
+async def shutdown_kvstore_backends() -> None:
+    """Shutdown all cached KV store instances."""
+    global _KVSTORE_INSTANCES
+    for instance in _KVSTORE_INSTANCES.values():
+        await instance.shutdown()
+    _KVSTORE_INSTANCES.clear()
llama_stack/core/storage/kvstore/postgres/postgres.py
CHANGED

@@ -123,3 +123,11 @@ class PostgresKVStoreImpl(KVStore):
             (start_key, end_key),
         )
         return [row[0] for row in cursor.fetchall()]
+
+    async def shutdown(self) -> None:
+        if self._cursor:
+            self._cursor.close()
+            self._cursor = None
+        if self._conn:
+            self._conn.close()
+            self._conn = None
llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py
CHANGED

@@ -107,6 +107,14 @@ class SqlAlchemySqlStoreImpl(SqlStore):
 
         return engine
 
+    async def shutdown(self) -> None:
+        """Dispose the session maker's engine and close all connections."""
+        # The async_session holds a reference to the engine created in __init__
+        if self.async_session:
+            engine = self.async_session.kw.get("bind")
+            if engine:
+                await engine.dispose()
+
     async def create_table(
         self,
         table: str,
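The new shutdown() leans on two SQLAlchemy behaviors: async_sessionmaker keeps its bind engine in the .kw mapping, and dispose() closes the engine's pooled connections. A standalone sketch of that, assuming the aiosqlite driver is installed (the in-memory URL is illustrative only):

```python
import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    session_maker = async_sessionmaker(engine)

    # async_sessionmaker stores its constructor keywords, including the bind,
    # in .kw -- the same place shutdown() reaches via kw.get("bind").
    assert session_maker.kw.get("bind") is engine

    # dispose() closes every connection held in the engine's pool.
    await engine.dispose()


asyncio.run(main())
```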
llama_stack/core/storage/sqlstore/sqlstore.py
CHANGED

@@ -85,3 +85,11 @@ def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     _SQLSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _SQLSTORE_BACKENDS[name] = cfg
+
+
+async def shutdown_sqlstore_backends() -> None:
+    """Shutdown all cached SQL store instances."""
+    global _SQLSTORE_INSTANCES
+    for instance in _SQLSTORE_INSTANCES.values():
+        await instance.shutdown()
+    _SQLSTORE_INSTANCES.clear()
llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
CHANGED

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import re
 import time
 import uuid

@@ -16,6 +17,7 @@ from llama_stack.providers.utils.responses.responses_store import (
     ResponsesStore,
     _OpenAIResponseObjectWithInputAndMessages,
 )
+from llama_stack.providers.utils.tools.mcp import MCPSessionManager
 from llama_stack_api import (
     ConversationItem,
     Conversations,

@@ -489,6 +491,19 @@ class OpenAIResponsesImpl:
         response_id = f"resp_{uuid.uuid4()}"
         created_at = int(time.time())
 
+        # Create a per-request MCP session manager for session reuse (fix for #4452)
+        # This avoids redundant tools/list calls when making multiple MCP tool invocations
+        mcp_session_manager = MCPSessionManager()
+
+        # Create a per-request ToolExecutor with the session manager
+        request_tool_executor = ToolExecutor(
+            tool_groups_api=self.tool_groups_api,
+            tool_runtime_api=self.tool_runtime_api,
+            vector_io_api=self.vector_io_api,
+            vector_stores_config=self.tool_executor.vector_stores_config,
+            mcp_session_manager=mcp_session_manager,
+        )
+
         orchestrator = StreamingResponseOrchestrator(
             inference_api=self.inference_api,
             ctx=ctx,

@@ -498,7 +513,7 @@
             text=text,
             max_infer_iters=max_infer_iters,
             parallel_tool_calls=parallel_tool_calls,
-            tool_executor=
+            tool_executor=request_tool_executor,
             safety_api=self.safety_api,
             guardrail_ids=guardrail_ids,
             instructions=instructions,

@@ -513,41 +528,52 @@
 
         # Type as ConversationItem to avoid list invariance issues
         output_items: list[ConversationItem] = []
  [22 removed lines not shown in this diff view]
-        if store:
-            # TODO: we really should work off of output_items instead of "final_messages"
-            await self._store_response(
-                response=final_response,
-                input=all_input,
-                messages=messages_to_store,
+        try:
+            async for stream_chunk in orchestrator.create_response():
+                match stream_chunk.type:
+                    case "response.completed" | "response.incomplete":
+                        final_response = stream_chunk.response
+                    case "response.failed":
+                        failed_response = stream_chunk.response
+                    case "response.output_item.done":
+                        item = stream_chunk.item
+                        output_items.append(item)
+                    case _:
+                        pass  # Other event types
+
+                # Store and sync before yielding terminal events
+                # This ensures the storage/syncing happens even if the consumer breaks after receiving the event
+                if (
+                    stream_chunk.type in {"response.completed", "response.incomplete"}
+                    and final_response
+                    and failed_response is None
+                ):
+                    messages_to_store = list(
+                        filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
                     )
+                    if store:
+                        # TODO: we really should work off of output_items instead of "final_messages"
+                        await self._store_response(
+                            response=final_response,
+                            input=all_input,
+                            messages=messages_to_store,
+                        )
 
  [5 removed lines not shown in this diff view]
+                    if conversation:
+                        await self._sync_response_to_conversation(conversation, input, output_items)
+                        await self.responses_store.store_conversation_messages(conversation, messages_to_store)
+
+                yield stream_chunk
+        finally:
+            # Clean up MCP sessions at the end of the request (fix for #4452)
+            # Use shield() to prevent cancellation from interrupting cleanup and leaking resources
+            # Wrap in try/except as cleanup errors should not mask the original response
+            try:
+                await asyncio.shield(mcp_session_manager.close_all())
+            except BaseException as e:
+                # Debug level - cleanup errors are expected in streaming scenarios where
+                # anyio cancel scopes may be in a different task context
+                logger.debug(f"Error during MCP session cleanup: {e}")
 
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
         return await self.responses_store.delete_response_object(response_id)
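The finally block shields the MCP cleanup from cancellation, so sessions are not leaked when a caller cancels the stream mid-flight. A generic standalone sketch of that shield-in-finally pattern (the function names here are illustrative, not from the package):

```python
import asyncio


async def cleanup() -> None:
    # Stand-in for mcp_session_manager.close_all(): release sessions, sockets, etc.
    await asyncio.sleep(0.1)


async def handle_request() -> None:
    try:
        await asyncio.sleep(10)  # request work; the caller may cancel this task
    finally:
        try:
            # shield() lets cleanup() run to completion even if handle_request()
            # was cancelled; the outer await may still raise CancelledError.
            await asyncio.shield(cleanup())
        except BaseException:
            pass  # cleanup errors must not mask the original outcome


async def main() -> None:
    task = asyncio.create_task(handle_request())
    await asyncio.sleep(0.01)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```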
llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
CHANGED

@@ -1200,6 +1200,9 @@ class StreamingResponseOrchestrator:
             "mcp_list_tools_id": list_id,
         }
 
+        # Get session manager from tool_executor if available (fix for #4452)
+        session_manager = getattr(self.tool_executor, "mcp_session_manager", None)
+
         # TODO: follow semantic conventions for Open Telemetry tool spans
         # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
         with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):

@@ -1207,6 +1210,7 @@
                 endpoint=mcp_tool.server_url,
                 headers=mcp_tool.headers,
                 authorization=mcp_tool.authorization,
+                session_manager=session_manager,
             )
 
         # Create the MCP list tools message
llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
CHANGED

@@ -54,11 +54,14 @@ class ToolExecutor:
         tool_runtime_api: ToolRuntime,
         vector_io_api: VectorIO,
         vector_stores_config=None,
+        mcp_session_manager=None,
     ):
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
         self.vector_io_api = vector_io_api
         self.vector_stores_config = vector_stores_config
+        # Optional MCPSessionManager for session reuse within a request (fix for #4452)
+        self.mcp_session_manager = mcp_session_manager
 
     async def execute_tool_call(
         self,

@@ -233,6 +236,7 @@
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
+                "attributes": [r.attributes or {} for r in search_results],
                 "citation_files": citation_files,
             },
         )

@@ -327,12 +331,14 @@
             # TODO: follow semantic conventions for Open Telemetry tool spans
             # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
             with tracer.start_as_current_span("invoke_mcp_tool", attributes=attributes):
+                # Pass session_manager for session reuse within request (fix for #4452)
                 result = await invoke_mcp_tool(
                     endpoint=mcp_tool.server_url,
                     tool_name=function_name,
                     kwargs=tool_kwargs,
                     headers=mcp_tool.headers,
                     authorization=mcp_tool.authorization,
+                    session_manager=self.mcp_session_manager,
                 )
         elif function_name == "knowledge_search":
             response_file_search_tool = (

@@ -464,16 +470,18 @@
         )
         if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
             message.results = []
+            attributes_list = metadata.get("attributes", [])
             for i, doc_id in enumerate(metadata["document_ids"]):
                 text = metadata["chunks"][i] if "chunks" in metadata else None
                 score = metadata["scores"][i] if "scores" in metadata else None
+                attrs = attributes_list[i] if i < len(attributes_list) else {}
                 message.results.append(
                     OpenAIResponseOutputMessageFileSearchToolCallResults(
                         file_id=doc_id,
                         filename=doc_id,
                         text=text if text is not None else "",
                         score=score if score is not None else 0.0,
-                        attributes=
+                        attributes=attrs,
                     )
                 )
         if has_error:
llama_stack/providers/inline/tool_runtime/rag/memory.py
CHANGED

@@ -50,8 +50,11 @@ log = get_logger(name=__name__, category="tool_runtime")
 async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     """Get raw binary data and mime type from a RAGDocument for file upload."""
     if isinstance(doc.content, URL):
  [2 removed lines not shown in this diff view]
+        uri = doc.content.uri
+        if uri.startswith("file://"):
+            raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
+        if uri.startswith("data:"):
+            parts = parse_data_url(uri)
             mime_type = parts["mimetype"]
             data = parts["data"]
 

@@ -63,7 +66,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
             return file_data, mime_type
         else:
             async with httpx.AsyncClient() as client:
-                r = await client.get(
+                r = await client.get(uri)
                 r.raise_for_status()
                 mime_type = r.headers.get("content-type", "application/octet-stream")
                 return r.content, mime_type

@@ -73,6 +76,8 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     else:
         content_str = interleaved_content_as_str(doc.content)
 
+        if content_str.startswith("file://"):
+            raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
         if content_str.startswith("data:"):
             parts = parse_data_url(content_str)
             mime_type = parts["mimetype"]
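With file:// URIs now rejected, local content has to reach the RAG tool either through the Files API (/v1/files) or inline. A small helper for the inline route, relying on the data: URL handling shown above (the helper name is illustrative, not part of the package):

```python
import base64
import mimetypes
from pathlib import Path


def to_data_url(path: str) -> str:
    """Encode a local file as a data: URL so it can be passed as document
    content instead of an unsupported file:// URI."""
    mime, _ = mimetypes.guess_type(path)
    payload = base64.b64encode(Path(path).read_bytes()).decode("ascii")
    return f"data:{mime or 'application/octet-stream'};base64,{payload}"
```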
llama_stack/providers/remote/vector_io/pgvector/pgvector.py
CHANGED

@@ -10,6 +10,7 @@ from typing import Any
 import psycopg2
 from numpy.typing import NDArray
 from psycopg2 import sql
+from psycopg2.extensions import cursor
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter
 

@@ -54,6 +55,17 @@ def check_extension_version(cur):
     return result[0] if result else None
 
 
+def create_vector_extension(cur: cursor) -> None:
+    try:
+        log.info("Vector extension not found, creating...")
+        cur.execute("CREATE EXTENSION vector;")
+        log.info("Vector extension created successfully")
+        log.info(f"Vector extension version: {check_extension_version(cur)}")
+
+    except psycopg2.Error as e:
+        raise RuntimeError(f"Failed to create vector extension for PGVector: {e}") from e
+
+
 def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]):
     with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
         query = sql.SQL(

@@ -364,7 +376,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             if version:
                 log.info(f"Vector extension version: {version}")
             else:
  [1 removed line not shown in this diff view]
+                create_vector_extension(cur)
 
             cur.execute(
                 """