llama-stack 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. llama_stack/cli/stack/run.py +3 -0
  2. llama_stack/core/library_client.py +80 -3
  3. llama_stack/core/routing_tables/common.py +11 -0
  4. llama_stack/core/routing_tables/vector_stores.py +4 -0
  5. llama_stack/core/stack.py +38 -11
  6. llama_stack/core/storage/kvstore/kvstore.py +11 -0
  7. llama_stack/core/storage/kvstore/mongodb/mongodb.py +5 -0
  8. llama_stack/core/storage/kvstore/postgres/postgres.py +8 -0
  9. llama_stack/core/storage/kvstore/redis/redis.py +5 -0
  10. llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py +8 -0
  11. llama_stack/core/storage/sqlstore/sqlstore.py +8 -0
  12. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +60 -34
  13. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +4 -0
  14. llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +9 -1
  15. llama_stack/providers/inline/tool_runtime/rag/memory.py +8 -3
  16. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +13 -1
  17. llama_stack/providers/utils/inference/embedding_mixin.py +20 -16
  18. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +33 -0
  19. llama_stack/providers/utils/memory/vector_store.py +9 -4
  20. llama_stack/providers/utils/tools/mcp.py +258 -16
  21. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/METADATA +2 -2
  22. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/RECORD +96 -29
  23. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/WHEEL +1 -1
  24. llama_stack_api/internal/kvstore.py +2 -0
  25. llama_stack_api/internal/sqlstore.py +2 -0
  26. llama_stack_api/llama_stack_api/__init__.py +945 -0
  27. llama_stack_api/llama_stack_api/admin/__init__.py +45 -0
  28. llama_stack_api/llama_stack_api/admin/api.py +72 -0
  29. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +117 -0
  30. llama_stack_api/llama_stack_api/admin/models.py +113 -0
  31. llama_stack_api/llama_stack_api/agents.py +173 -0
  32. llama_stack_api/llama_stack_api/batches/__init__.py +40 -0
  33. llama_stack_api/llama_stack_api/batches/api.py +53 -0
  34. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +113 -0
  35. llama_stack_api/llama_stack_api/batches/models.py +78 -0
  36. llama_stack_api/llama_stack_api/benchmarks/__init__.py +43 -0
  37. llama_stack_api/llama_stack_api/benchmarks/api.py +39 -0
  38. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +109 -0
  39. llama_stack_api/llama_stack_api/benchmarks/models.py +109 -0
  40. llama_stack_api/llama_stack_api/common/__init__.py +5 -0
  41. llama_stack_api/llama_stack_api/common/content_types.py +101 -0
  42. llama_stack_api/llama_stack_api/common/errors.py +95 -0
  43. llama_stack_api/llama_stack_api/common/job_types.py +38 -0
  44. llama_stack_api/llama_stack_api/common/responses.py +77 -0
  45. llama_stack_api/llama_stack_api/common/training_types.py +47 -0
  46. llama_stack_api/llama_stack_api/common/type_system.py +146 -0
  47. llama_stack_api/llama_stack_api/connectors.py +146 -0
  48. llama_stack_api/llama_stack_api/conversations.py +270 -0
  49. llama_stack_api/llama_stack_api/datasetio.py +55 -0
  50. llama_stack_api/llama_stack_api/datasets/__init__.py +61 -0
  51. llama_stack_api/llama_stack_api/datasets/api.py +35 -0
  52. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +104 -0
  53. llama_stack_api/llama_stack_api/datasets/models.py +152 -0
  54. llama_stack_api/llama_stack_api/datatypes.py +373 -0
  55. llama_stack_api/llama_stack_api/eval.py +137 -0
  56. llama_stack_api/llama_stack_api/file_processors/__init__.py +27 -0
  57. llama_stack_api/llama_stack_api/file_processors/api.py +64 -0
  58. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +78 -0
  59. llama_stack_api/llama_stack_api/file_processors/models.py +42 -0
  60. llama_stack_api/llama_stack_api/files/__init__.py +35 -0
  61. llama_stack_api/llama_stack_api/files/api.py +51 -0
  62. llama_stack_api/llama_stack_api/files/fastapi_routes.py +124 -0
  63. llama_stack_api/llama_stack_api/files/models.py +107 -0
  64. llama_stack_api/llama_stack_api/inference.py +1169 -0
  65. llama_stack_api/llama_stack_api/inspect_api/__init__.py +37 -0
  66. llama_stack_api/llama_stack_api/inspect_api/api.py +25 -0
  67. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +76 -0
  68. llama_stack_api/llama_stack_api/inspect_api/models.py +28 -0
  69. llama_stack_api/llama_stack_api/internal/__init__.py +9 -0
  70. llama_stack_api/llama_stack_api/internal/kvstore.py +28 -0
  71. llama_stack_api/llama_stack_api/internal/sqlstore.py +81 -0
  72. llama_stack_api/llama_stack_api/models.py +171 -0
  73. llama_stack_api/llama_stack_api/openai_responses.py +1468 -0
  74. llama_stack_api/llama_stack_api/post_training.py +370 -0
  75. llama_stack_api/llama_stack_api/prompts.py +203 -0
  76. llama_stack_api/llama_stack_api/providers/__init__.py +33 -0
  77. llama_stack_api/llama_stack_api/providers/api.py +16 -0
  78. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +57 -0
  79. llama_stack_api/llama_stack_api/providers/models.py +24 -0
  80. llama_stack_api/llama_stack_api/py.typed +0 -0
  81. llama_stack_api/llama_stack_api/rag_tool.py +168 -0
  82. llama_stack_api/llama_stack_api/resource.py +37 -0
  83. llama_stack_api/llama_stack_api/router_utils.py +160 -0
  84. llama_stack_api/llama_stack_api/safety.py +132 -0
  85. llama_stack_api/llama_stack_api/schema_utils.py +208 -0
  86. llama_stack_api/llama_stack_api/scoring.py +93 -0
  87. llama_stack_api/llama_stack_api/scoring_functions.py +211 -0
  88. llama_stack_api/llama_stack_api/shields.py +93 -0
  89. llama_stack_api/llama_stack_api/tools.py +226 -0
  90. llama_stack_api/llama_stack_api/vector_io.py +941 -0
  91. llama_stack_api/llama_stack_api/vector_stores.py +53 -0
  92. llama_stack_api/llama_stack_api/version.py +9 -0
  93. llama_stack_api/vector_stores.py +2 -0
  94. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/entry_points.txt +0 -0
  95. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/licenses/LICENSE +0 -0
  96. {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/top_level.txt +0 -0
llama_stack/cli/stack/run.py CHANGED
@@ -202,6 +202,9 @@ class StackRun(Subcommand):
  # Set the config file in environment so create_app can find it
  os.environ["LLAMA_STACK_CONFIG"] = str(config_file)

+ # disable together banner that spams llama stack run every time
+ os.environ["TOGETHER_NO_BANNER"] = "1"
+
  uvicorn_config = {
  "factory": True,
  "host": host,
llama_stack/core/library_client.py CHANGED
@@ -161,6 +161,45 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
  """
  pass

+ def shutdown(self) -> None:
+ """Shutdown the client and release all resources.
+
+ This method should be called when you're done using the client to properly
+ close database connections and release other resources. Failure to call this
+ method may result in the program hanging on exit while waiting for background
+ threads to complete.
+
+ This method is idempotent and can be called multiple times safely.
+
+ Example:
+ client = LlamaStackAsLibraryClient("starter")
+ # ... use the client ...
+ client.shutdown()
+ """
+ loop = self.loop
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(self.async_client.shutdown())
+ finally:
+ loop.close()
+ asyncio.set_event_loop(None)
+
+ def __enter__(self) -> "LlamaStackAsLibraryClient":
+ """Enter the context manager.
+
+ The client is already initialized in __init__, so this just returns self.
+
+ Example:
+ with LlamaStackAsLibraryClient("starter") as client:
+ response = client.models.list()
+ # Client is automatically shut down here
+ """
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ """Exit the context manager and shut down the client."""
+ self.shutdown()
+
  def request(self, *args, **kwargs):
  loop = self.loop
  asyncio.set_event_loop(loop)
@@ -224,6 +263,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
  self.custom_provider_registry = custom_provider_registry
  self.provider_data = provider_data
  self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
+ self.stack: Stack | None = None

  def _remove_root_logger_handlers(self):
  """
@@ -246,9 +286,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
  try:
  self.route_impls = None

- stack = Stack(self.config, self.custom_provider_registry)
- await stack.initialize()
- self.impls = stack.impls
+ self.stack = Stack(self.config, self.custom_provider_registry)
+ await self.stack.initialize()
+ self.impls = self.stack.impls
  except ModuleNotFoundError as _e:
  cprint(_e.msg, color="red", file=sys.stderr)
  cprint(
@@ -283,6 +323,43 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
  self.route_impls = initialize_route_impls(self.impls)
  return True

+ async def shutdown(self) -> None:
+ """Shutdown the client and release all resources.
+
+ This method should be called when you're done using the client to properly
+ close database connections and release other resources. Failure to call this
+ method may result in the program hanging on exit while waiting for background
+ threads to complete.
+
+ This method is idempotent and can be called multiple times safely.
+
+ Example:
+ client = AsyncLlamaStackAsLibraryClient("starter")
+ await client.initialize()
+ # ... use the client ...
+ await client.shutdown()
+ """
+ if self.stack:
+ await self.stack.shutdown()
+ self.stack = None
+
+ async def __aenter__(self) -> "AsyncLlamaStackAsLibraryClient":
+ """Enter the async context manager.
+
+ Initializes the client and returns it.
+
+ Example:
+ async with AsyncLlamaStackAsLibraryClient("starter") as client:
+ response = await client.models.list()
+ # Client is automatically shut down here
+ """
+ await self.initialize()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+ """Exit the async context manager and shut down the client."""
+ await self.shutdown()
+
  async def request(
  self,
  cast_to: Any,
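
Note on usage: with the context-manager support added above, callers no longer need to call shutdown() by hand. A minimal sketch based on the docstrings in this hunk (the "starter" distribution name and the models.list() call are taken from those docstrings; everything else is illustrative):

    import asyncio

    from llama_stack.core.library_client import (
        AsyncLlamaStackAsLibraryClient,
        LlamaStackAsLibraryClient,
    )

    # Synchronous client: __enter__ returns the already-initialized client,
    # __exit__ calls shutdown() so DB connections and background threads are released.
    with LlamaStackAsLibraryClient("starter") as client:
        print(client.models.list())

    # Async client: __aenter__ runs initialize(), __aexit__ awaits shutdown().
    async def main() -> None:
        async with AsyncLlamaStackAsLibraryClient("starter") as client:
            print(await client.models.list())

    asyncio.run(main())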
llama_stack/core/routing_tables/common.py CHANGED
@@ -209,6 +209,17 @@ class CommonRoutingTableImpl(RoutingTable):
  logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")

  registered_obj = await register_object_with_provider(obj, p)
+
+ # Ensure OpenAI metadata exists for vector stores
+ if obj.type == ResourceType.vector_store.value:
+ if hasattr(p, "_ensure_openai_metadata_exists"):
+ await p._ensure_openai_metadata_exists(obj)
+ else:
+ logger.warning(
+ f"Provider {obj.provider_id} does not support OpenAI metadata creation. "
+ f"Vector store {obj.identifier} may not work with OpenAI-compatible APIs."
+ )
+
  # TODO: This needs to be fixed for all APIs once they return the registered object
  if obj.type == ResourceType.model.value:
  await self.dist_registry.register(registered_obj)
llama_stack/core/routing_tables/vector_stores.py CHANGED
@@ -55,6 +55,10 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):

  # Internal methods only - no public API exposure

+ async def list_vector_stores(self) -> list[VectorStoreWithOwner]:
+ """List all registered vector stores."""
+ return await self.get_all_with_type(ResourceType.vector_store.value)
+
  async def register_vector_store(
  self,
  vector_store_id: str,
llama_stack/core/stack.py CHANGED
@@ -53,6 +53,7 @@ from llama_stack_api import (
  PostTraining,
  Prompts,
  Providers,
+ RegisterBenchmarkRequest,
  Safety,
  Scoring,
  ScoringFunctions,
@@ -61,6 +62,7 @@ from llama_stack_api import (
  ToolRuntime,
  VectorIO,
  )
+ from llama_stack_api.datasets import RegisterDatasetRequest

  logger = get_logger(name=__name__, category="core")

@@ -91,18 +93,22 @@ class LlamaStack(
  pass


+ # Resources to register based on configuration.
+ # If a request class is specified, the configuration object will be converted to this class before invoking the registration method.
  RESOURCES = [
- ("models", Api.models, "register_model", "list_models"),
- ("shields", Api.shields, "register_shield", "list_shields"),
- ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+ ("models", Api.models, "register_model", "list_models", None),
+ ("shields", Api.shields, "register_shield", "list_shields", None),
+ ("datasets", Api.datasets, "register_dataset", "list_datasets", RegisterDatasetRequest),
  (
  "scoring_fns",
  Api.scoring_functions,
  "register_scoring_function",
  "list_scoring_functions",
+ None,
  ),
- ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
- ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
+ ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks", RegisterBenchmarkRequest),
+ ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups", None),
+ ("vector_stores", Api.vector_stores, "register_vector_store", "list_vector_stores", None),
  ]


@@ -199,7 +205,7 @@ async def invoke_with_optional_request(method: Any) -> Any:


  async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
- for rsrc, api, register_method, list_method in RESOURCES:
+ for rsrc, api, register_method, list_method, request_class in RESOURCES:
  objects = getattr(run_config.registered_resources, rsrc)
  if api not in impls:
  continue
@@ -213,10 +219,17 @@ async def register_resources(run_config: StackConfig, impls: dict[Api, Any]):
  continue
  logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")

- # we want to maintain the type information in arguments to method.
- # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
- # we use model_dump() to find all the attrs and then getattr to get the still typed value.
- await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
+ # TODO: Once all register methods are migrated to accept request objects,
+ # remove this conditional and always use the request_class pattern.
+ if request_class is not None:
+ request = request_class(**obj.model_dump())
+ await method(request)
+ else:
+ # we want to maintain the type information in arguments to method.
+ # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
+ # we use model_dump() to find all the attrs and then getattr to get the still typed
+ # value.
+ await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})

  method = getattr(impls[api], list_method)
  response = await invoke_with_optional_request(method)
@@ -608,7 +621,7 @@ class Stack:
  async def shutdown(self):
  for impl in self.impls.values():
  impl_name = impl.__class__.__name__
- logger.info(f"Shutting down {impl_name}")
+ logger.debug(f"Shutting down {impl_name}")
  try:
  if hasattr(impl, "shutdown"):
  await asyncio.wait_for(impl.shutdown(), timeout=5)
@@ -630,6 +643,20 @@
  if REGISTRY_REFRESH_TASK:
  REGISTRY_REFRESH_TASK.cancel()

+ # Shutdown storage backends
+ from llama_stack.core.storage.kvstore.kvstore import shutdown_kvstore_backends
+ from llama_stack.core.storage.sqlstore.sqlstore import shutdown_sqlstore_backends
+
+ try:
+ await shutdown_kvstore_backends()
+ except Exception as e:
+ logger.exception(f"Failed to shutdown KV store backends: {e}")
+
+ try:
+ await shutdown_sqlstore_backends()
+ except Exception as e:
+ logger.exception(f"Failed to shutdown SQL store backends: {e}")
+

  async def refresh_registry_once(impls: dict[Api, Any]):
  logger.debug("refreshing registry")
llama_stack/core/storage/kvstore/kvstore.py CHANGED
@@ -62,6 +62,9 @@ class InmemoryKVStoreImpl(KVStore):
  async def delete(self, key: str) -> None:
  del self._store[key]

+ async def shutdown(self) -> None:
+ self._store.clear()
+

  _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
  _KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
@@ -126,3 +129,11 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
  await impl.initialize()
  _KVSTORE_INSTANCES[cache_key] = impl
  return impl
+
+
+ async def shutdown_kvstore_backends() -> None:
+ """Shutdown all cached KV store instances."""
+ global _KVSTORE_INSTANCES
+ for instance in _KVSTORE_INSTANCES.values():
+ await instance.shutdown()
+ _KVSTORE_INSTANCES.clear()
llama_stack/core/storage/kvstore/mongodb/mongodb.py CHANGED
@@ -83,3 +83,8 @@ class MongoDBKVStoreImpl(KVStore):
  async for doc in cursor:
  result.append(doc["key"])
  return result
+
+ async def shutdown(self) -> None:
+ if self.conn:
+ await self.conn.close()
+ self.conn = None
llama_stack/core/storage/kvstore/postgres/postgres.py CHANGED
@@ -123,3 +123,11 @@ class PostgresKVStoreImpl(KVStore):
  (start_key, end_key),
  )
  return [row[0] for row in cursor.fetchall()]
+
+ async def shutdown(self) -> None:
+ if self._cursor:
+ self._cursor.close()
+ self._cursor = None
+ if self._conn:
+ self._conn.close()
+ self._conn = None
llama_stack/core/storage/kvstore/redis/redis.py CHANGED
@@ -99,3 +99,8 @@ class RedisKVStoreImpl(KVStore):
  if cursor == 0:
  break
  return result
+
+ async def shutdown(self) -> None:
+ if self._redis:
+ await self._redis.close()
+ self._redis = None
llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py CHANGED
@@ -107,6 +107,14 @@ class SqlAlchemySqlStoreImpl(SqlStore):

  return engine

+ async def shutdown(self) -> None:
+ """Dispose the session maker's engine and close all connections."""
+ # The async_session holds a reference to the engine created in __init__
+ if self.async_session:
+ engine = self.async_session.kw.get("bind")
+ if engine:
+ await engine.dispose()
+
  async def create_table(
  self,
  table: str,
llama_stack/core/storage/sqlstore/sqlstore.py CHANGED
@@ -85,3 +85,11 @@ def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
  _SQLSTORE_LOCKS.clear()
  for name, cfg in backends.items():
  _SQLSTORE_BACKENDS[name] = cfg
+
+
+ async def shutdown_sqlstore_backends() -> None:
+ """Shutdown all cached SQL store instances."""
+ global _SQLSTORE_INSTANCES
+ for instance in _SQLSTORE_INSTANCES.values():
+ await instance.shutdown()
+ _SQLSTORE_INSTANCES.clear()
llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py CHANGED
@@ -4,6 +4,7 @@
  # This source code is licensed under the terms described in the LICENSE file in
  # the root directory of this source tree.

+ import asyncio
  import re
  import time
  import uuid
@@ -16,6 +17,7 @@ from llama_stack.providers.utils.responses.responses_store import (
  ResponsesStore,
  _OpenAIResponseObjectWithInputAndMessages,
  )
+ from llama_stack.providers.utils.tools.mcp import MCPSessionManager
  from llama_stack_api import (
  ConversationItem,
  Conversations,
@@ -489,6 +491,19 @@ class OpenAIResponsesImpl:
  response_id = f"resp_{uuid.uuid4()}"
  created_at = int(time.time())

+ # Create a per-request MCP session manager for session reuse (fix for #4452)
+ # This avoids redundant tools/list calls when making multiple MCP tool invocations
+ mcp_session_manager = MCPSessionManager()
+
+ # Create a per-request ToolExecutor with the session manager
+ request_tool_executor = ToolExecutor(
+ tool_groups_api=self.tool_groups_api,
+ tool_runtime_api=self.tool_runtime_api,
+ vector_io_api=self.vector_io_api,
+ vector_stores_config=self.tool_executor.vector_stores_config,
+ mcp_session_manager=mcp_session_manager,
+ )
+
  orchestrator = StreamingResponseOrchestrator(
  inference_api=self.inference_api,
  ctx=ctx,
@@ -498,7 +513,7 @@ class OpenAIResponsesImpl:
  text=text,
  max_infer_iters=max_infer_iters,
  parallel_tool_calls=parallel_tool_calls,
- tool_executor=self.tool_executor,
+ tool_executor=request_tool_executor,
  safety_api=self.safety_api,
  guardrail_ids=guardrail_ids,
  instructions=instructions,
@@ -513,41 +528,52 @@ class OpenAIResponsesImpl:

  # Type as ConversationItem to avoid list invariance issues
  output_items: list[ConversationItem] = []
- async for stream_chunk in orchestrator.create_response():
- match stream_chunk.type:
- case "response.completed" | "response.incomplete":
- final_response = stream_chunk.response
- case "response.failed":
- failed_response = stream_chunk.response
- case "response.output_item.done":
- item = stream_chunk.item
- output_items.append(item)
- case _:
- pass # Other event types
-
- # Store and sync before yielding terminal events
- # This ensures the storage/syncing happens even if the consumer breaks after receiving the event
- if (
- stream_chunk.type in {"response.completed", "response.incomplete"}
- and final_response
- and failed_response is None
- ):
- messages_to_store = list(
- filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
- )
- if store:
- # TODO: we really should work off of output_items instead of "final_messages"
- await self._store_response(
- response=final_response,
- input=all_input,
- messages=messages_to_store,
+ try:
+ async for stream_chunk in orchestrator.create_response():
+ match stream_chunk.type:
+ case "response.completed" | "response.incomplete":
+ final_response = stream_chunk.response
+ case "response.failed":
+ failed_response = stream_chunk.response
+ case "response.output_item.done":
+ item = stream_chunk.item
+ output_items.append(item)
+ case _:
+ pass # Other event types
+
+ # Store and sync before yielding terminal events
+ # This ensures the storage/syncing happens even if the consumer breaks after receiving the event
+ if (
+ stream_chunk.type in {"response.completed", "response.incomplete"}
+ and final_response
+ and failed_response is None
+ ):
+ messages_to_store = list(
+ filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
  )
+ if store:
+ # TODO: we really should work off of output_items instead of "final_messages"
+ await self._store_response(
+ response=final_response,
+ input=all_input,
+ messages=messages_to_store,
+ )

- if conversation:
- await self._sync_response_to_conversation(conversation, input, output_items)
- await self.responses_store.store_conversation_messages(conversation, messages_to_store)
-
- yield stream_chunk
+ if conversation:
+ await self._sync_response_to_conversation(conversation, input, output_items)
+ await self.responses_store.store_conversation_messages(conversation, messages_to_store)
+
+ yield stream_chunk
+ finally:
+ # Clean up MCP sessions at the end of the request (fix for #4452)
+ # Use shield() to prevent cancellation from interrupting cleanup and leaking resources
+ # Wrap in try/except as cleanup errors should not mask the original response
+ try:
+ await asyncio.shield(mcp_session_manager.close_all())
+ except BaseException as e:
+ # Debug level - cleanup errors are expected in streaming scenarios where
+ # anyio cancel scopes may be in a different task context
+ logger.debug(f"Error during MCP session cleanup: {e}")

  async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
  return await self.responses_store.delete_response_object(response_id)
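
Note on the session lifecycle: the pattern introduced here is one MCPSessionManager per response, threaded through a per-request ToolExecutor, with cleanup always running in a shielded finally block. A stripped-down sketch of that lifecycle (MCPSessionManager, close_all(), and the mcp_session_manager keyword come from this diff; the surrounding helper names are illustrative, not part of the package):

    import asyncio

    from llama_stack.providers.utils.tools.mcp import MCPSessionManager

    async def run_with_mcp_sessions(make_tool_executor, run_tools) -> None:
        # One session manager per request so repeated MCP calls reuse the same
        # session instead of re-issuing tools/list on every invocation.
        mcp_session_manager = MCPSessionManager()
        tool_executor = make_tool_executor(mcp_session_manager=mcp_session_manager)
        try:
            await run_tools(tool_executor)
        finally:
            # shield() keeps cleanup running even if the surrounding task is
            # cancelled mid-stream; cleanup errors are swallowed, not re-raised.
            try:
                await asyncio.shield(mcp_session_manager.close_all())
            except BaseException:
                pass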
llama_stack/providers/inline/agents/meta_reference/responses/streaming.py CHANGED
@@ -1200,6 +1200,9 @@ class StreamingResponseOrchestrator:
  "mcp_list_tools_id": list_id,
  }

+ # Get session manager from tool_executor if available (fix for #4452)
+ session_manager = getattr(self.tool_executor, "mcp_session_manager", None)
+
  # TODO: follow semantic conventions for Open Telemetry tool spans
  # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
  with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):
@@ -1207,6 +1210,7 @@
  endpoint=mcp_tool.server_url,
  headers=mcp_tool.headers,
  authorization=mcp_tool.authorization,
+ session_manager=session_manager,
  )

  # Create the MCP list tools message
llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py CHANGED
@@ -54,11 +54,14 @@ class ToolExecutor:
  tool_runtime_api: ToolRuntime,
  vector_io_api: VectorIO,
  vector_stores_config=None,
+ mcp_session_manager=None,
  ):
  self.tool_groups_api = tool_groups_api
  self.tool_runtime_api = tool_runtime_api
  self.vector_io_api = vector_io_api
  self.vector_stores_config = vector_stores_config
+ # Optional MCPSessionManager for session reuse within a request (fix for #4452)
+ self.mcp_session_manager = mcp_session_manager

  async def execute_tool_call(
  self,
@@ -233,6 +236,7 @@ class ToolExecutor:
  "document_ids": [r.file_id for r in search_results],
  "chunks": [r.content[0].text if r.content else "" for r in search_results],
  "scores": [r.score for r in search_results],
+ "attributes": [r.attributes or {} for r in search_results],
  "citation_files": citation_files,
  },
  )
@@ -327,12 +331,14 @@ class ToolExecutor:
  # TODO: follow semantic conventions for Open Telemetry tool spans
  # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
  with tracer.start_as_current_span("invoke_mcp_tool", attributes=attributes):
+ # Pass session_manager for session reuse within request (fix for #4452)
  result = await invoke_mcp_tool(
  endpoint=mcp_tool.server_url,
  tool_name=function_name,
  kwargs=tool_kwargs,
  headers=mcp_tool.headers,
  authorization=mcp_tool.authorization,
+ session_manager=self.mcp_session_manager,
  )
  elif function_name == "knowledge_search":
  response_file_search_tool = (
@@ -464,16 +470,18 @@ class ToolExecutor:
  )
  if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
  message.results = []
+ attributes_list = metadata.get("attributes", [])
  for i, doc_id in enumerate(metadata["document_ids"]):
  text = metadata["chunks"][i] if "chunks" in metadata else None
  score = metadata["scores"][i] if "scores" in metadata else None
+ attrs = attributes_list[i] if i < len(attributes_list) else {}
  message.results.append(
  OpenAIResponseOutputMessageFileSearchToolCallResults(
  file_id=doc_id,
  filename=doc_id,
  text=text if text is not None else "",
  score=score if score is not None else 0.0,
- attributes={},
+ attributes=attrs,
  )
  )
  if has_error:
llama_stack/providers/inline/tool_runtime/rag/memory.py CHANGED
@@ -50,8 +50,11 @@ log = get_logger(name=__name__, category="tool_runtime")
  async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
  """Get raw binary data and mime type from a RAGDocument for file upload."""
  if isinstance(doc.content, URL):
- if doc.content.uri.startswith("data:"):
- parts = parse_data_url(doc.content.uri)
+ uri = doc.content.uri
+ if uri.startswith("file://"):
+ raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
+ if uri.startswith("data:"):
+ parts = parse_data_url(uri)
  mime_type = parts["mimetype"]
  data = parts["data"]

@@ -63,7 +66,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
  return file_data, mime_type
  else:
  async with httpx.AsyncClient() as client:
- r = await client.get(doc.content.uri)
+ r = await client.get(uri)
  r.raise_for_status()
  mime_type = r.headers.get("content-type", "application/octet-stream")
  return r.content, mime_type
@@ -73,6 +76,8 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
  else:
  content_str = interleaved_content_as_str(doc.content)

+ if content_str.startswith("file://"):
+ raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
  if content_str.startswith("data:"):
  parts = parse_data_url(content_str)
  mime_type = parts["mimetype"]
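
Note for callers: file:// URIs are now rejected, so local content should either go through the Files API (/v1/files) or be inlined as a data: URL, which raw_data_from_doc still accepts. A small helper sketch for the latter (standard RFC 2397 data URLs are assumed; this helper is illustrative and not part of the package):

    import base64
    from pathlib import Path

    def local_file_as_data_url(path: str, mime_type: str = "application/pdf") -> str:
        # Inline the file bytes as a base64 data: URL instead of a file:// URI,
        # which now raises ValueError in raw_data_from_doc.
        encoded = base64.b64encode(Path(path).read_bytes()).decode("ascii")
        return f"data:{mime_type};base64,{encoded}"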
llama_stack/providers/remote/vector_io/pgvector/pgvector.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any
  import psycopg2
  from numpy.typing import NDArray
  from psycopg2 import sql
+ from psycopg2.extensions import cursor
  from psycopg2.extras import Json, execute_values
  from pydantic import BaseModel, TypeAdapter

@@ -54,6 +55,17 @@ def check_extension_version(cur):
  return result[0] if result else None


+ def create_vector_extension(cur: cursor) -> None:
+ try:
+ log.info("Vector extension not found, creating...")
+ cur.execute("CREATE EXTENSION vector;")
+ log.info("Vector extension created successfully")
+ log.info(f"Vector extension version: {check_extension_version(cur)}")
+
+ except psycopg2.Error as e:
+ raise RuntimeError(f"Failed to create vector extension for PGVector: {e}") from e
+
+
  def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]):
  with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
  query = sql.SQL(
@@ -364,7 +376,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
  if version:
  log.info(f"Vector extension version: {version}")
  else:
- raise RuntimeError("Vector extension is not installed.")
+ create_vector_extension(cur)

  cur.execute(
  """