llama-stack 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/run.py +3 -0
- llama_stack/core/library_client.py +80 -3
- llama_stack/core/routing_tables/common.py +11 -0
- llama_stack/core/routing_tables/vector_stores.py +4 -0
- llama_stack/core/stack.py +38 -11
- llama_stack/core/storage/kvstore/kvstore.py +11 -0
- llama_stack/core/storage/kvstore/mongodb/mongodb.py +5 -0
- llama_stack/core/storage/kvstore/postgres/postgres.py +8 -0
- llama_stack/core/storage/kvstore/redis/redis.py +5 -0
- llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py +8 -0
- llama_stack/core/storage/sqlstore/sqlstore.py +8 -0
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +60 -34
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +4 -0
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +9 -1
- llama_stack/providers/inline/tool_runtime/rag/memory.py +8 -3
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +13 -1
- llama_stack/providers/utils/inference/embedding_mixin.py +20 -16
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +33 -0
- llama_stack/providers/utils/memory/vector_store.py +9 -4
- llama_stack/providers/utils/tools/mcp.py +258 -16
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/METADATA +2 -2
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/RECORD +96 -29
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/WHEEL +1 -1
- llama_stack_api/internal/kvstore.py +2 -0
- llama_stack_api/internal/sqlstore.py +2 -0
- llama_stack_api/llama_stack_api/__init__.py +945 -0
- llama_stack_api/llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/llama_stack_api/admin/api.py +72 -0
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/llama_stack_api/admin/models.py +113 -0
- llama_stack_api/llama_stack_api/agents.py +173 -0
- llama_stack_api/llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/llama_stack_api/batches/api.py +53 -0
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/llama_stack_api/batches/models.py +78 -0
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/llama_stack_api/benchmarks/models.py +109 -0
- llama_stack_api/llama_stack_api/common/__init__.py +5 -0
- llama_stack_api/llama_stack_api/common/content_types.py +101 -0
- llama_stack_api/llama_stack_api/common/errors.py +95 -0
- llama_stack_api/llama_stack_api/common/job_types.py +38 -0
- llama_stack_api/llama_stack_api/common/responses.py +77 -0
- llama_stack_api/llama_stack_api/common/training_types.py +47 -0
- llama_stack_api/llama_stack_api/common/type_system.py +146 -0
- llama_stack_api/llama_stack_api/connectors.py +146 -0
- llama_stack_api/llama_stack_api/conversations.py +270 -0
- llama_stack_api/llama_stack_api/datasetio.py +55 -0
- llama_stack_api/llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/llama_stack_api/datasets/models.py +152 -0
- llama_stack_api/llama_stack_api/datatypes.py +373 -0
- llama_stack_api/llama_stack_api/eval.py +137 -0
- llama_stack_api/llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/llama_stack_api/files/api.py +51 -0
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/llama_stack_api/files/models.py +107 -0
- llama_stack_api/llama_stack_api/inference.py +1169 -0
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/llama_stack_api/inspect_api/models.py +28 -0
- llama_stack_api/llama_stack_api/internal/__init__.py +9 -0
- llama_stack_api/llama_stack_api/internal/kvstore.py +28 -0
- llama_stack_api/llama_stack_api/internal/sqlstore.py +81 -0
- llama_stack_api/llama_stack_api/models.py +171 -0
- llama_stack_api/llama_stack_api/openai_responses.py +1468 -0
- llama_stack_api/llama_stack_api/post_training.py +370 -0
- llama_stack_api/llama_stack_api/prompts.py +203 -0
- llama_stack_api/llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/llama_stack_api/providers/api.py +16 -0
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/llama_stack_api/providers/models.py +24 -0
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +168 -0
- llama_stack_api/llama_stack_api/resource.py +37 -0
- llama_stack_api/llama_stack_api/router_utils.py +160 -0
- llama_stack_api/llama_stack_api/safety.py +132 -0
- llama_stack_api/llama_stack_api/schema_utils.py +208 -0
- llama_stack_api/llama_stack_api/scoring.py +93 -0
- llama_stack_api/llama_stack_api/scoring_functions.py +211 -0
- llama_stack_api/llama_stack_api/shields.py +93 -0
- llama_stack_api/llama_stack_api/tools.py +226 -0
- llama_stack_api/llama_stack_api/vector_io.py +941 -0
- llama_stack_api/llama_stack_api/vector_stores.py +53 -0
- llama_stack_api/llama_stack_api/version.py +9 -0
- llama_stack_api/vector_stores.py +2 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/top_level.txt +0 -0
llama_stack/providers/utils/inference/embedding_mixin.py

```diff
@@ -25,7 +25,8 @@ from llama_stack_api import (
     OpenAIEmbeddingUsage,
 )
 
-EMBEDDING_MODELS = {}
+EMBEDDING_MODELS: dict[str, "SentenceTransformer"] = {}
+EMBEDDING_MODELS_LOCK = asyncio.Lock()
 
 DARWIN = "Darwin"
 
@@ -76,26 +77,29 @@ class SentenceTransformerEmbeddingMixin:
         )
 
     async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
-        global EMBEDDING_MODELS
-
         loaded_model = EMBEDDING_MODELS.get(model)
         if loaded_model is not None:
             return loaded_model
 
-        … (1 removed line not captured in this extract)
+        async with EMBEDDING_MODELS_LOCK:
+            loaded_model = EMBEDDING_MODELS.get(model)
+            if loaded_model is not None:
+                return loaded_model
+
+            log.info(f"Loading sentence transformer for {model}...")
 
-        … (2 removed lines not captured in this extract)
+            def _load_model():
+                from sentence_transformers import SentenceTransformer
 
-        … (6 removed lines not captured in this extract)
+                platform_name = platform.system()
+                if platform_name == DARWIN:
+                    # PyTorch's OpenMP kernels can segfault on macOS when spawned from background
+                    # threads with the default parallel settings, so force a single-threaded CPU run.
+                    log.debug(f"Constraining torch threads on {platform_name} to a single worker")
+                    torch.set_num_threads(1)
 
-        … (1 removed line not captured in this extract)
+                return SentenceTransformer(model, trust_remote_code=True)
 
-        … (3 removed lines not captured in this extract)
+            loaded_model = await asyncio.to_thread(_load_model)
+            EMBEDDING_MODELS[model] = loaded_model
+            return loaded_model
```
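The embedding-mixin change above is a double-checked cache guard: the fast path reads the cache without taking the lock, and the re-check after acquiring `EMBEDDING_MODELS_LOCK` ensures two concurrent requests cannot load the same SentenceTransformer twice. A minimal standalone sketch of the same pattern (the names below are illustrative, not from the diff):

```python
import asyncio
from typing import Any

_CACHE: dict[str, Any] = {}
_LOCK = asyncio.Lock()

def expensive_load(key: str) -> Any:
    # Stand-in for SentenceTransformer(model, trust_remote_code=True).
    return f"model:{key}"

async def get_or_load(key: str) -> Any:
    # Fast path: no lock taken when the value is already cached.
    if (value := _CACHE.get(key)) is not None:
        return value
    async with _LOCK:
        # Re-check under the lock: another task may have finished loading
        # while this one was waiting, so the load runs at most once per key.
        if (value := _CACHE.get(key)) is not None:
            return value
        # Run the blocking load off the event loop, as the diff does with
        # asyncio.to_thread(_load_model).
        value = await asyncio.to_thread(expensive_load, key)
        _CACHE[key] = value
        return value
```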
llama_stack/providers/utils/memory/openai_vector_store_mixin.py

```diff
@@ -122,6 +122,39 @@ class OpenAIVectorStoreMixin(ABC):
         # update in-memory cache
         self.openai_vector_stores[store_id] = store_info
 
+    async def _ensure_openai_metadata_exists(self, vector_store: VectorStore, name: str | None = None) -> None:
+        """
+        Ensure OpenAI-compatible metadata exists for a vector store.
+        """
+        if vector_store.identifier not in self.openai_vector_stores:
+            store_info = {
+                "id": vector_store.identifier,
+                "object": "vector_store",
+                "created_at": int(time.time()),
+                "name": name or vector_store.vector_store_name or vector_store.identifier,
+                "usage_bytes": 0,
+                "file_counts": VectorStoreFileCounts(
+                    cancelled=0,
+                    completed=0,
+                    failed=0,
+                    in_progress=0,
+                    total=0,
+                ).model_dump(),
+                "status": "completed",
+                "expires_after": None,
+                "expires_at": None,
+                "last_active_at": int(time.time()),
+                "file_ids": [],
+                "chunking_strategy": None,
+                "metadata": {
+                    "provider_id": vector_store.provider_id,
+                    "provider_vector_store_id": vector_store.provider_resource_id,
+                    "embedding_model": vector_store.embedding_model,
+                    "embedding_dimension": str(vector_store.embedding_dimension),
+                },
+            }
+            await self._save_openai_vector_store(vector_store.identifier, store_info)
+
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
         assert self.kvstore
```
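`_ensure_openai_metadata_exists` is idempotent: it writes a default OpenAI-compatible record only when the store is missing from the cache, and `_save_openai_vector_store` (context lines above) also updates the in-memory map. A hedged sketch of how a caller might lean on that; the helper below is hypothetical, only the mixin method and cache are from the diff:

```python
# Safe to call on every access: the mixin writes a default record only when
# vector_store.identifier is absent from openai_vector_stores.
async def store_info_for(mixin, vector_store) -> dict:
    await mixin._ensure_openai_metadata_exists(vector_store)
    return mixin.openai_vector_stores[vector_store.identifier]
```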
llama_stack/providers/utils/memory/vector_store.py

```diff
@@ -135,15 +135,20 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
 
 async def content_from_doc(doc: RAGDocument) -> str:
     if isinstance(doc.content, URL):
-        … (2 removed lines not captured in this extract)
+        uri = doc.content.uri
+        if uri.startswith("file://"):
+            raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
+        if uri.startswith("data:"):
+            return content_from_data(uri)
         async with httpx.AsyncClient() as client:
-            r = await client.get(
+            r = await client.get(uri)
         if doc.mime_type == "application/pdf":
             return parse_pdf(r.content)
         return r.text
     elif isinstance(doc.content, str):
-        … (1 removed line not captured in this extract)
+        if doc.content.startswith("file://"):
+            raise ValueError("file:// URIs are not supported. Please use the Files API (/v1/files) to upload files.")
+        pattern = re.compile("^(https?://|data:)")
         if pattern.match(doc.content):
             if doc.content.startswith("data:"):
                 return content_from_data(doc.content)
```
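The patched `content_from_doc` now dispatches on the URI scheme: `data:` URIs are decoded inline, `file://` URIs are rejected in favor of the Files API, and anything else is fetched over HTTP. An illustrative sketch, assuming `RAGDocument` accepts a plain string `content` and is importable from `llama_stack_api` (the import path and documents below are assumptions):

```python
import asyncio

from llama_stack.providers.utils.memory.vector_store import content_from_doc
from llama_stack_api import RAGDocument  # import path assumed

async def main() -> None:
    # data: URIs are decoded inline, with no network round-trip.
    doc = RAGDocument(document_id="d1", content="data:text/plain;base64,aGVsbG8=")
    print(await content_from_doc(doc))  # should print "hello"

    # file:// URIs are rejected with a pointer to the Files API.
    try:
        await content_from_doc(RAGDocument(document_id="d2", content="file:///etc/hosts"))
    except ValueError as e:
        print(e)

asyncio.run(main())
```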
llama_stack/providers/utils/tools/mcp.py

```diff
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
+import hashlib
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from enum import Enum
@@ -73,6 +75,207 @@ class MCPProtol(Enum):
     SSE = 2
 
 
+class MCPSessionManager:
+    """Manages MCP session lifecycle within a request scope.
+
+    This class caches MCP sessions by (endpoint, headers_hash) to avoid redundant
+    connection establishment and tools/list calls when making multiple tool
+    invocations to the same MCP server within a single request.
+
+    Fix for GitHub issue #4452: MCP tools/list called redundantly before every
+    tool invocation.
+
+    Usage:
+        async with MCPSessionManager() as session_manager:
+            # Multiple tool calls will reuse the same session
+            result1 = await invoke_mcp_tool(..., session_manager=session_manager)
+            result2 = await invoke_mcp_tool(..., session_manager=session_manager)
+    """
+
+    def __init__(self):
+        # Cache of active sessions: key -> (session, client_context, session_context)
+        self._sessions: dict[str, tuple[ClientSession, Any, Any]] = {}
+        # Locks to prevent concurrent session creation for the same key
+        self._locks: dict[str, asyncio.Lock] = {}
+        # Global lock for managing the locks dict
+        self._global_lock = asyncio.Lock()
+
+    def _make_key(self, endpoint: str, headers: dict[str, str]) -> str:
+        """Create a cache key from endpoint and headers."""
+        # Sort headers for consistent hashing
+        headers_str = str(sorted(headers.items()))
+        headers_hash = hashlib.sha256(headers_str.encode()).hexdigest()[:16]
+        return f"{endpoint}:{headers_hash}"
+
+    async def _get_lock(self, key: str) -> asyncio.Lock:
+        """Get or create a lock for a specific cache key."""
+        async with self._global_lock:
+            if key not in self._locks:
+                self._locks[key] = asyncio.Lock()
+            return self._locks[key]
+
+    async def get_session(self, endpoint: str, headers: dict[str, str]) -> ClientSession:
+        """Get or create an MCP session for the given endpoint and headers.
+
+        Args:
+            endpoint: MCP server endpoint URL
+            headers: Headers including authorization
+
+        Returns:
+            An initialized ClientSession ready for tool calls
+        """
+        key = self._make_key(endpoint, headers)
+
+        # Check if session already exists (fast path)
+        if key in self._sessions:
+            session, _, _ = self._sessions[key]
+            return session
+
+        # Acquire lock for this specific key to prevent concurrent creation
+        lock = await self._get_lock(key)
+        async with lock:
+            # Double-check after acquiring lock
+            if key in self._sessions:
+                session, _, _ = self._sessions[key]
+                return session
+
+            # Create new session
+            session, client_ctx, session_ctx = await self._create_session(endpoint, headers)
+            self._sessions[key] = (session, client_ctx, session_ctx)
+            logger.debug(f"Created new MCP session for {endpoint} (key: {key[:32]}...)")
+            return session
+
+    async def _create_session(self, endpoint: str, headers: dict[str, str]) -> tuple[ClientSession, Any, Any]:
+        """Create a new MCP session.
+
+        Returns:
+            Tuple of (session, client_context, session_context) for lifecycle management
+        """
+        # Use the same protocol detection logic as client_wrapper
+        connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
+        mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
+        if mcp_protocol == MCPProtol.SSE:
+            connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
+
+        last_exception: BaseException | None = None
+
+        for i, strategy in enumerate(connection_strategies):
+            try:
+                client = streamablehttp_client
+                if strategy == MCPProtol.SSE:
+                    client = cast(Any, sse_client)
+
+                # Enter the client context manager manually
+                client_ctx = client(endpoint, headers=headers)
+                client_streams = await client_ctx.__aenter__()
+
+                try:
+                    # Enter the session context manager manually
+                    session = ClientSession(read_stream=client_streams[0], write_stream=client_streams[1])
+                    session_ctx = session
+                    await session.__aenter__()
+
+                    try:
+                        await session.initialize()
+                        protocol_cache[endpoint] = strategy
+                        return session, client_ctx, session_ctx
+                    except BaseException:
+                        await session.__aexit__(None, None, None)
+                        raise
+                except BaseException:
+                    await client_ctx.__aexit__(None, None, None)
+                    raise
+
+            except* httpx.HTTPStatusError as eg:
+                for exc in eg.exceptions:
+                    err = cast(httpx.HTTPStatusError, exc)
+                    if err.response.status_code == 401:
+                        raise AuthenticationRequiredError(exc) from exc
+                if i == len(connection_strategies) - 1:
+                    raise
+                last_exception = eg
+            except* httpx.ConnectError as eg:
+                if i == len(connection_strategies) - 1:
+                    error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
+                    logger.error(f"MCP connection error: {error_msg}")
+                    raise ConnectionError(error_msg) from eg
+                else:
+                    logger.warning(
+                        f"failed to connect to MCP server at {endpoint} via {strategy.name}, "
+                        f"falling back to {connection_strategies[i + 1].name}"
+                    )
+                    last_exception = eg
+            except* httpx.TimeoutException as eg:
+                if i == len(connection_strategies) - 1:
+                    error_msg = f"MCP server at {endpoint} timed out"
+                    logger.error(f"MCP timeout error: {error_msg}")
+                    raise TimeoutError(error_msg) from eg
+                else:
+                    logger.warning(
+                        f"MCP server at {endpoint} timed out via {strategy.name}, "
+                        f"falling back to {connection_strategies[i + 1].name}"
+                    )
+                    last_exception = eg
+            except* httpx.RequestError as eg:
+                if i == len(connection_strategies) - 1:
+                    exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
+                    error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
+                    logger.error(f"MCP network error: {error_msg}")
+                    raise ConnectionError(error_msg) from eg
+                else:
+                    logger.warning(
+                        f"network error connecting to MCP server at {endpoint} via {strategy.name}, "
+                        f"falling back to {connection_strategies[i + 1].name}"
+                    )
+                    last_exception = eg
+            except* McpError:
+                if i < len(connection_strategies) - 1:
+                    logger.warning(
+                        f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                    )
+                else:
+                    raise
+
+        # Should not reach here, but just in case
+        if last_exception:
+            raise last_exception
+        raise RuntimeError(f"Failed to create MCP session for {endpoint}")
+
+    async def close_all(self) -> None:
+        """Close all cached sessions.
+
+        Should be called at the end of a request to clean up resources.
+
+        Note: We catch BaseException (not just Exception) because:
+        1. CancelledError is a BaseException and can occur during cleanup
+        2. anyio cancel scope errors can occur if cleanup runs in a different
+           task context than where the session was created
+        These are expected in streaming response scenarios and are handled gracefully.
+        """
+        errors = []
+        session_count = len(self._sessions)
+        for key, (session, client_ctx, _) in list(self._sessions.items()):
+            try:
+                await session.__aexit__(None, None, None)
+            except BaseException as e:
+                # Debug level since these errors are expected in streaming scenarios
+                # where cleanup runs in a different async context than session creation
+                logger.debug(f"Error closing MCP session {key}: {e}")
+                errors.append(e)
+            try:
+                await client_ctx.__aexit__(None, None, None)
+            except BaseException as e:
+                logger.debug(f"Error closing MCP client context {key}: {e}")
+                errors.append(e)
+
+        self._sessions.clear()
+        self._locks.clear()
+        logger.debug(f"Closed {session_count} MCP sessions")
+
+        if errors:
+            logger.debug(f"Encountered {len(errors)} errors while closing MCP sessions (expected in streaming)")
+
+
 @asynccontextmanager
 async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
     # we use a ttl'd dict to cache the happy path protocol for each endpoint
```
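The class docstring shows `async with MCPSessionManager()` usage, but only `close_all()` appears in the hunk above, so the sketch below uses an explicit `try/finally` instead. Endpoint, token, and tool names are made up; the function signatures match the hunks that follow:

```python
from llama_stack.providers.utils.tools.mcp import (
    MCPSessionManager,
    invoke_mcp_tool,
    list_mcp_tools,
)

async def handle_request(endpoint: str, token: str) -> None:
    session_manager = MCPSessionManager()
    try:
        # The first call connects and initializes; the invocations below
        # reuse the cached session instead of reconnecting per call.
        tools = await list_mcp_tools(endpoint, authorization=token, session_manager=session_manager)
        print([t.name for t in tools.data])
        result = await invoke_mcp_tool(
            endpoint, "search", {"q": "llama"},
            authorization=token, session_manager=session_manager,
        )
        print(result.error_code)
    finally:
        # Request-scoped cleanup for all cached sessions.
        await session_manager.close_all()
```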
llama_stack/providers/utils/tools/mcp.py (continued)

```diff
@@ -151,6 +354,7 @@ async def list_mcp_tools(
     endpoint: str,
     headers: dict[str, str] | None = None,
     authorization: str | None = None,
+    session_manager: MCPSessionManager | None = None,
 ) -> ListToolDefsResponse:
     """List tools available from an MCP server.
 
@@ -158,6 +362,10 @@ async def list_mcp_tools(
         endpoint: MCP server endpoint URL
         headers: Optional base headers to include
         authorization: Optional OAuth access token (just the token, not "Bearer <token>")
+        session_manager: Optional MCPSessionManager for session reuse within a request.
+            When provided, sessions are cached and reused, avoiding redundant session
+            creation when list_mcp_tools and invoke_mcp_tool are called for the same
+            server within a request. (Fix for #4452)
 
     Returns:
         List of tool definitions from the MCP server
@@ -169,7 +377,9 @@ async def list_mcp_tools(
     final_headers = prepare_mcp_headers(headers, authorization)
 
     tools = []
-    … (1 removed line not captured in this extract)
+
+    # Helper function to process session and list tools
+    async def _list_tools_from_session(session):
         tools_result = await session.list_tools()
         for tool in tools_result.tools:
             tools.append(
@@ -183,15 +393,51 @@ async def list_mcp_tools(
                     },
                )
            )
+
+    # If a session manager is provided, use it for session reuse (fix for #4452)
+    if session_manager is not None:
+        session = await session_manager.get_session(endpoint, final_headers)
+        await _list_tools_from_session(session)
+    else:
+        # Fallback to original behavior: create a new session for this call
+        async with client_wrapper(endpoint, final_headers) as session:
+            await _list_tools_from_session(session)
+
     return ListToolDefsResponse(data=tools)
 
 
+def _parse_mcp_result(result) -> ToolInvocationResult:
+    """Parse MCP tool call result into ToolInvocationResult.
+
+    Args:
+        result: The raw MCP tool call result
+
+    Returns:
+        ToolInvocationResult with parsed content
+    """
+    content: list[InterleavedContentItem] = []
+    for item in result.content:
+        if isinstance(item, mcp_types.TextContent):
+            content.append(TextContentItem(text=item.text))
+        elif isinstance(item, mcp_types.ImageContent):
+            content.append(ImageContentItem(image=_URLOrData(data=item.data)))
+        elif isinstance(item, mcp_types.EmbeddedResource):
+            logger.warning(f"EmbeddedResource is not supported: {item}")
+        else:
+            raise ValueError(f"Unknown content type: {type(item)}")
+    return ToolInvocationResult(
+        content=content,
+        error_code=1 if result.isError else 0,
+    )
+
+
 async def invoke_mcp_tool(
     endpoint: str,
     tool_name: str,
     kwargs: dict[str, Any],
     headers: dict[str, str] | None = None,
     authorization: str | None = None,
+    session_manager: MCPSessionManager | None = None,
 ) -> ToolInvocationResult:
     """Invoke an MCP tool with the given arguments.
 
```
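A small sketch of what `_parse_mcp_result` (a private helper from the hunk above) produces for a successful text-only result, built with the real `mcp.types` models; the values are made up:

```python
from mcp import types as mcp_types

from llama_stack.providers.utils.tools.mcp import _parse_mcp_result

raw = mcp_types.CallToolResult(
    content=[mcp_types.TextContent(type="text", text="42")],
    isError=False,
)
parsed = _parse_mcp_result(raw)
assert parsed.error_code == 0          # isError=False maps to error_code 0
assert parsed.content[0].text == "42"  # TextContent becomes TextContentItem
```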
llama_stack/providers/utils/tools/mcp.py (continued)

```diff
@@ -201,6 +447,9 @@ async def invoke_mcp_tool(
         kwargs: Tool invocation arguments
         headers: Optional base headers to include
         authorization: Optional OAuth access token (just the token, not "Bearer <token>")
+        session_manager: Optional MCPSessionManager for session reuse within a request.
+            When provided, sessions are cached and reused for multiple tool calls to
+            the same endpoint, avoiding redundant tools/list calls. (Fix for #4452)
 
     Returns:
         Tool invocation result with content and error information
@@ -211,20 +460,13 @@ async def invoke_mcp_tool(
     # Prepare headers with authorization handling
     final_headers = prepare_mcp_headers(headers, authorization)
 
-    … (1 removed line not captured in this extract)
+    # If a session manager is provided, use it for session reuse (fix for #4452)
+    if session_manager is not None:
+        session = await session_manager.get_session(endpoint, final_headers)
         result = await session.call_tool(tool_name, kwargs)
+        return _parse_mcp_result(result)
 
-    … (4 removed lines not captured in this extract)
-        elif isinstance(item, mcp_types.ImageContent):
-            content.append(ImageContentItem(image=_URLOrData(data=item.data)))
-        elif isinstance(item, mcp_types.EmbeddedResource):
-            logger.warning(f"EmbeddedResource is not supported: {item}")
-        else:
-            raise ValueError(f"Unknown content type: {type(item)}")
-    return ToolInvocationResult(
-        content=content,
-        error_code=1 if result.isError else 0,
-    )
+    # Fallback to original behavior: create a new session for each call
+    async with client_wrapper(endpoint, final_headers) as session:
+        result = await session.call_tool(tool_name, kwargs)
+        return _parse_mcp_result(result)
```
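Session reuse is keyed on the endpoint plus a hash of the sorted headers, so two callers hitting the same server with different credentials never share a session. A standalone re-derivation of `_make_key` for illustration (the endpoint and tokens are made up):

```python
import hashlib

def make_key(endpoint: str, headers: dict[str, str]) -> str:
    # Same derivation as MCPSessionManager._make_key in the hunk above.
    headers_str = str(sorted(headers.items()))
    return f"{endpoint}:{hashlib.sha256(headers_str.encode()).hexdigest()[:16]}"

k1 = make_key("https://mcp.example/v1", {"Authorization": "Bearer alice"})
k2 = make_key("https://mcp.example/v1", {"Authorization": "Bearer bob"})
assert k1 != k2  # different credentials never share a cached session
```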
{llama_stack-0.4.1.dist-info → llama_stack-0.4.3.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: llama_stack
-Version: 0.4.1
+Version: 0.4.3
 Summary: Llama Stack
 Author-email: Meta Llama <llama-oss@meta.com>
 License: MIT
@@ -46,7 +46,7 @@ Requires-Dist: psycopg2-binary
 Requires-Dist: tornado>=6.5.3
 Requires-Dist: urllib3>=2.6.3
 Provides-Extra: client
-Requires-Dist: llama-stack-client==0.4.1; extra == "client"
+Requires-Dist: llama-stack-client==0.4.3; extra == "client"
 Dynamic: license-file
 
 # Llama Stack
```