nvidia-nat-mcp 1.3.0a20251006__py3-none-any.whl → 1.3.0a20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +2 -2
- nat/plugins/mcp/auth/auth_provider.py +42 -10
- nat/plugins/mcp/auth/auth_provider_config.py +5 -0
- nat/plugins/mcp/auth/register.py +1 -1
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/client_base.py +14 -95
- nat/plugins/mcp/client_config.py +131 -0
- nat/plugins/mcp/client_impl.py +411 -106
- nat/plugins/mcp/tool.py +5 -0
- nat/plugins/mcp/utils.py +16 -0
- {nvidia_nat_mcp-1.3.0a20251006.dist-info → nvidia_nat_mcp-1.3.0a20251008.dist-info}/METADATA +13 -4
- nvidia_nat_mcp-1.3.0a20251008.dist-info/RECORD +23 -0
- nvidia_nat_mcp-1.3.0a20251008.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.3.0a20251008.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.3.0a20251006.dist-info/RECORD +0 -19
- {nvidia_nat_mcp-1.3.0a20251006.dist-info → nvidia_nat_mcp-1.3.0a20251008.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20251006.dist-info → nvidia_nat_mcp-1.3.0a20251008.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20251006.dist-info → nvidia_nat_mcp-1.3.0a20251008.dist-info}/top_level.txt +0 -0
nat/meta/pypi.md
CHANGED
@@ -19,9 +19,9 @@ limitations under the License.
|
|
19
19
|
|
20
20
|
|
21
21
|
# NVIDIA NeMo Agent Toolkit MCP Subpackage
|
22
|
-
Subpackage for MCP
|
22
|
+
Subpackage for MCP integration in NeMo Agent toolkit.
|
23
23
|
|
24
|
-
This package provides MCP (Model Context Protocol)
|
24
|
+
This package provides MCP (Model Context Protocol) functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
|
25
25
|
|
26
26
|
## Features
|
27
27
|
|
@@ -23,6 +23,7 @@ import httpx
|
|
23
23
|
from pydantic import BaseModel
|
24
24
|
from pydantic import Field
|
25
25
|
from pydantic import HttpUrl
|
26
|
+
from pydantic import TypeAdapter
|
26
27
|
|
27
28
|
from mcp.shared.auth import OAuthClientInformationFull
|
28
29
|
from mcp.shared.auth import OAuthClientMetadata
|
@@ -65,7 +66,6 @@ class DiscoverOAuth2Endpoints:
|
|
65
66
|
def __init__(self, config: MCPOAuth2ProviderConfig):
|
66
67
|
self.config = config
|
67
68
|
self._cached_endpoints: OAuth2Endpoints | None = None
|
68
|
-
self._authenticated_servers: dict[str, AuthResult] = {}
|
69
69
|
|
70
70
|
self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
|
71
71
|
|
@@ -192,11 +192,13 @@ class DiscoverOAuth2Endpoints:
|
|
192
192
|
continue
|
193
193
|
if meta.authorization_endpoint and meta.token_endpoint:
|
194
194
|
logger.info("Discovered OAuth2 endpoints from %s", url)
|
195
|
-
#
|
195
|
+
# Convert AnyHttpUrl to HttpUrl using TypeAdapter
|
196
|
+
http_url_adapter = TypeAdapter(HttpUrl)
|
196
197
|
return OAuth2Endpoints(
|
197
|
-
authorization_url=str(meta.authorization_endpoint),
|
198
|
-
token_url=str(meta.token_endpoint),
|
199
|
-
registration_url=str(meta.registration_endpoint)
|
198
|
+
authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)),
|
199
|
+
token_url=http_url_adapter.validate_python(str(meta.token_endpoint)),
|
200
|
+
registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint))
|
201
|
+
if meta.registration_endpoint else None,
|
200
202
|
scopes=meta.scopes_supported,
|
201
203
|
)
|
202
204
|
except Exception as e:
|
@@ -283,8 +285,9 @@ class DynamicClientRegistration:
|
|
283
285
|
class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
284
286
|
"""MCP OAuth2 authentication provider that delegates to NAT framework."""
|
285
287
|
|
286
|
-
def __init__(self, config: MCPOAuth2ProviderConfig):
|
288
|
+
def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
|
287
289
|
super().__init__(config)
|
290
|
+
self._builder = builder
|
288
291
|
|
289
292
|
# Discovery
|
290
293
|
self._discoverer = DiscoverOAuth2Endpoints(config)
|
@@ -300,6 +303,19 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
300
303
|
|
301
304
|
self._auth_callback = None
|
302
305
|
|
306
|
+
# Initialize token storage
|
307
|
+
self._token_storage = None
|
308
|
+
self._token_storage_object_store_name = None
|
309
|
+
|
310
|
+
if self.config.token_storage_object_store:
|
311
|
+
# Store object store name, will be resolved later when builder context is available
|
312
|
+
self._token_storage_object_store_name = self.config.token_storage_object_store
|
313
|
+
logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage")
|
314
|
+
else:
|
315
|
+
# Default: use in-memory token storage
|
316
|
+
from .token_storage import InMemoryTokenStorage
|
317
|
+
self._token_storage = InMemoryTokenStorage()
|
318
|
+
|
303
319
|
def _set_custom_auth_callback(self,
|
304
320
|
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
305
321
|
Awaitable[AuthenticatedContext]]):
|
@@ -308,7 +324,7 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
308
324
|
logger.info("Using custom authentication callback")
|
309
325
|
self._auth_callback = auth_callback
|
310
326
|
if self._auth_code_provider:
|
311
|
-
self._auth_code_provider._set_custom_auth_callback(self._auth_callback)
|
327
|
+
self._auth_code_provider._set_custom_auth_callback(self._auth_callback) # type: ignore[arg-type]
|
312
328
|
|
313
329
|
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
314
330
|
"""
|
@@ -374,6 +390,22 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
374
390
|
endpoints = self._cached_endpoints
|
375
391
|
credentials = self._cached_credentials
|
376
392
|
|
393
|
+
# Resolve object store reference if needed
|
394
|
+
if self._token_storage_object_store_name and not self._token_storage:
|
395
|
+
try:
|
396
|
+
if not self._builder:
|
397
|
+
raise RuntimeError("Builder not available for resolving object store")
|
398
|
+
object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name)
|
399
|
+
from .token_storage import ObjectStoreTokenStorage
|
400
|
+
self._token_storage = ObjectStoreTokenStorage(object_store)
|
401
|
+
logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'")
|
402
|
+
except Exception as e:
|
403
|
+
logger.warning(
|
404
|
+
f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. "
|
405
|
+
"Falling back to in-memory storage.")
|
406
|
+
from .token_storage import InMemoryTokenStorage
|
407
|
+
self._token_storage = InMemoryTokenStorage()
|
408
|
+
|
377
409
|
# Build the OAuth2 provider if not already built
|
378
410
|
if self._auth_code_provider is None:
|
379
411
|
scopes = self._effective_scopes
|
@@ -387,12 +419,12 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
387
419
|
scopes=scopes,
|
388
420
|
use_pkce=bool(self.config.use_pkce),
|
389
421
|
authorization_kwargs={"resource": str(self.config.server_url)})
|
390
|
-
self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
|
422
|
+
self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage)
|
391
423
|
|
392
424
|
# Use MCP-specific authentication method if available
|
393
425
|
if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
|
394
|
-
self.
|
395
|
-
|
426
|
+
callback = self._auth_callback or self._flow_handler.authenticate
|
427
|
+
self._auth_code_provider._set_custom_auth_callback(callback) # type: ignore[arg-type]
|
396
428
|
|
397
429
|
# Auth code provider is responsible for per-user cache + refresh
|
398
430
|
return await self._auth_code_provider.authenticate(user_id=user_id)
|
@@ -53,6 +53,11 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
|
53
53
|
default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
|
54
54
|
allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
|
55
55
|
|
56
|
+
# Token storage configuration
|
57
|
+
token_storage_object_store: str | None = Field(
|
58
|
+
default=None,
|
59
|
+
description="Reference to object store for secure token storage. If None, uses in-memory storage.")
|
60
|
+
|
56
61
|
@model_validator(mode="after")
|
57
62
|
def validate_auth_config(self):
|
58
63
|
"""Validate authentication configuration for MCP-specific options."""
|
nat/plugins/mcp/auth/register.py
CHANGED
@@ -22,4 +22,4 @@ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
22
22
|
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
|
23
23
|
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
|
24
24
|
"""Register MCP OAuth2 authentication provider with NAT system."""
|
25
|
-
yield MCPOAuth2Provider(authentication_provider)
|
25
|
+
yield MCPOAuth2Provider(authentication_provider, builder=builder)
|
@@ -0,0 +1,265 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import hashlib
|
17
|
+
import json
|
18
|
+
import logging
|
19
|
+
from abc import ABC
|
20
|
+
from abc import abstractmethod
|
21
|
+
|
22
|
+
from nat.data_models.authentication import AuthResult
|
23
|
+
from nat.data_models.authentication import BasicAuthCred
|
24
|
+
from nat.data_models.authentication import BearerTokenCred
|
25
|
+
from nat.data_models.authentication import CookieCred
|
26
|
+
from nat.data_models.authentication import HeaderCred
|
27
|
+
from nat.data_models.authentication import QueryCred
|
28
|
+
from nat.data_models.object_store import NoSuchKeyError
|
29
|
+
from nat.object_store.interfaces import ObjectStore
|
30
|
+
from nat.object_store.models import ObjectStoreItem
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
class TokenStorageBase(ABC):
|
36
|
+
"""
|
37
|
+
Abstract base class for token storage implementations.
|
38
|
+
|
39
|
+
Token storage implementations handle the secure persistence of authentication
|
40
|
+
tokens for MCP OAuth2 flows. Implementations can use various backends such as
|
41
|
+
object stores, databases, or in-memory storage.
|
42
|
+
"""
|
43
|
+
|
44
|
+
@abstractmethod
|
45
|
+
async def store(self, user_id: str, auth_result: AuthResult) -> None:
|
46
|
+
"""
|
47
|
+
Store an authentication result for a user.
|
48
|
+
|
49
|
+
Args:
|
50
|
+
user_id: The unique identifier for the user
|
51
|
+
auth_result: The authentication result to store
|
52
|
+
"""
|
53
|
+
pass
|
54
|
+
|
55
|
+
@abstractmethod
|
56
|
+
async def retrieve(self, user_id: str) -> AuthResult | None:
|
57
|
+
"""
|
58
|
+
Retrieve an authentication result for a user.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
user_id: The unique identifier for the user
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
The authentication result if found, None otherwise
|
65
|
+
"""
|
66
|
+
pass
|
67
|
+
|
68
|
+
@abstractmethod
|
69
|
+
async def delete(self, user_id: str) -> None:
|
70
|
+
"""
|
71
|
+
Delete an authentication result for a user.
|
72
|
+
|
73
|
+
Args:
|
74
|
+
user_id: The unique identifier for the user
|
75
|
+
"""
|
76
|
+
pass
|
77
|
+
|
78
|
+
@abstractmethod
|
79
|
+
async def clear_all(self) -> None:
|
80
|
+
"""
|
81
|
+
Clear all stored authentication results.
|
82
|
+
"""
|
83
|
+
pass
|
84
|
+
|
85
|
+
|
86
|
+
class ObjectStoreTokenStorage(TokenStorageBase):
|
87
|
+
"""
|
88
|
+
Token storage implementation backed by a NeMo Agent toolkit object store.
|
89
|
+
|
90
|
+
This implementation uses the object store infrastructure to persist tokens,
|
91
|
+
which provides encryption at rest, access controls, and persistence across
|
92
|
+
restarts when using backends like S3, MySQL, or Redis.
|
93
|
+
"""
|
94
|
+
|
95
|
+
def __init__(self, object_store: ObjectStore):
|
96
|
+
"""
|
97
|
+
Initialize the object store token storage.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
object_store: The object store instance to use for token persistence
|
101
|
+
"""
|
102
|
+
self._object_store = object_store
|
103
|
+
|
104
|
+
def _get_key(self, user_id: str) -> str:
|
105
|
+
"""
|
106
|
+
Generate the object store key for a user's token.
|
107
|
+
|
108
|
+
Uses SHA256 hash to ensure the key is S3-compatible and doesn't
|
109
|
+
contain special characters like "://" that are invalid in object keys.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
user_id: The user identifier
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
The object store key
|
116
|
+
"""
|
117
|
+
# Hash the user_id to create an S3-safe key
|
118
|
+
user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
|
119
|
+
return f"tokens/{user_hash}"
|
120
|
+
|
121
|
+
async def store(self, user_id: str, auth_result: AuthResult) -> None:
|
122
|
+
"""
|
123
|
+
Store an authentication result in the object store.
|
124
|
+
|
125
|
+
Args:
|
126
|
+
user_id: The unique identifier for the user
|
127
|
+
auth_result: The authentication result to store
|
128
|
+
"""
|
129
|
+
key = self._get_key(user_id)
|
130
|
+
|
131
|
+
# Serialize the AuthResult to JSON with secrets exposed
|
132
|
+
# SecretStr values are masked by default, so we need to expose them manually
|
133
|
+
# Create a serializable dict with exposed secrets
|
134
|
+
auth_dict = auth_result.model_dump(mode='json')
|
135
|
+
# Manually expose SecretStr values in credentials
|
136
|
+
for i, cred_obj in enumerate(auth_result.credentials):
|
137
|
+
if isinstance(cred_obj, BearerTokenCred):
|
138
|
+
auth_dict['credentials'][i]['token'] = cred_obj.token.get_secret_value()
|
139
|
+
elif isinstance(cred_obj, BasicAuthCred):
|
140
|
+
auth_dict['credentials'][i]['username'] = cred_obj.username.get_secret_value()
|
141
|
+
auth_dict['credentials'][i]['password'] = cred_obj.password.get_secret_value()
|
142
|
+
elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred):
|
143
|
+
auth_dict['credentials'][i]['value'] = cred_obj.value.get_secret_value()
|
144
|
+
|
145
|
+
data = json.dumps(auth_dict).encode('utf-8')
|
146
|
+
|
147
|
+
# Prepare metadata
|
148
|
+
metadata = {}
|
149
|
+
if auth_result.token_expires_at:
|
150
|
+
metadata["expires_at"] = auth_result.token_expires_at.isoformat()
|
151
|
+
|
152
|
+
# Create the object store item
|
153
|
+
item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata if metadata else None)
|
154
|
+
|
155
|
+
# Store using upsert to handle both new and existing tokens
|
156
|
+
await self._object_store.upsert_object(key, item)
|
157
|
+
|
158
|
+
async def retrieve(self, user_id: str) -> AuthResult | None:
|
159
|
+
"""
|
160
|
+
Retrieve an authentication result from the object store.
|
161
|
+
|
162
|
+
Args:
|
163
|
+
user_id: The unique identifier for the user
|
164
|
+
|
165
|
+
Returns:
|
166
|
+
The authentication result if found, None otherwise
|
167
|
+
"""
|
168
|
+
key = self._get_key(user_id)
|
169
|
+
|
170
|
+
try:
|
171
|
+
item = await self._object_store.get_object(key)
|
172
|
+
# Deserialize the AuthResult from JSON
|
173
|
+
auth_result = AuthResult.model_validate_json(item.data)
|
174
|
+
return auth_result
|
175
|
+
except NoSuchKeyError:
|
176
|
+
return None
|
177
|
+
except Exception as e:
|
178
|
+
logger.error(f"Error deserializing token for user {user_id}: {e}", exc_info=True)
|
179
|
+
return None
|
180
|
+
|
181
|
+
async def delete(self, user_id: str) -> None:
|
182
|
+
"""
|
183
|
+
Delete an authentication result from the object store.
|
184
|
+
|
185
|
+
Args:
|
186
|
+
user_id: The unique identifier for the user
|
187
|
+
"""
|
188
|
+
key = self._get_key(user_id)
|
189
|
+
|
190
|
+
try:
|
191
|
+
await self._object_store.delete_object(key)
|
192
|
+
except NoSuchKeyError:
|
193
|
+
# Token doesn't exist, which is fine for delete operations
|
194
|
+
pass
|
195
|
+
|
196
|
+
async def clear_all(self) -> None:
|
197
|
+
"""
|
198
|
+
Clear all stored authentication results.
|
199
|
+
|
200
|
+
Note: This implementation does not support clearing all tokens as the
|
201
|
+
object store interface doesn't provide a list operation. Individual
|
202
|
+
tokens must be deleted explicitly.
|
203
|
+
"""
|
204
|
+
logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
|
205
|
+
|
206
|
+
|
207
|
+
class InMemoryTokenStorage(TokenStorageBase):
|
208
|
+
"""
|
209
|
+
In-memory token storage using NeMo Agent toolkit's built-in object store.
|
210
|
+
|
211
|
+
This implementation uses the in-memory object store for token persistence,
|
212
|
+
which provides a secure default option that doesn't require external storage
|
213
|
+
configuration. Tokens are stored in memory and cleared when the process exits.
|
214
|
+
"""
|
215
|
+
|
216
|
+
def __init__(self):
|
217
|
+
"""
|
218
|
+
Initialize the in-memory token storage.
|
219
|
+
"""
|
220
|
+
from nat.object_store.in_memory_object_store import InMemoryObjectStore
|
221
|
+
|
222
|
+
# Create a dedicated in-memory object store for tokens
|
223
|
+
self._object_store = InMemoryObjectStore()
|
224
|
+
|
225
|
+
# Wrap with ObjectStoreTokenStorage for the actual implementation
|
226
|
+
self._storage = ObjectStoreTokenStorage(self._object_store)
|
227
|
+
logger.debug("Initialized in-memory token storage")
|
228
|
+
|
229
|
+
async def store(self, user_id: str, auth_result: AuthResult) -> None:
|
230
|
+
"""
|
231
|
+
Store an authentication result in memory.
|
232
|
+
|
233
|
+
Args:
|
234
|
+
user_id: The unique identifier for the user
|
235
|
+
auth_result: The authentication result to store
|
236
|
+
"""
|
237
|
+
await self._storage.store(user_id, auth_result)
|
238
|
+
|
239
|
+
async def retrieve(self, user_id: str) -> AuthResult | None:
|
240
|
+
"""
|
241
|
+
Retrieve an authentication result from memory.
|
242
|
+
|
243
|
+
Args:
|
244
|
+
user_id: The unique identifier for the user
|
245
|
+
|
246
|
+
Returns:
|
247
|
+
The authentication result if found, None otherwise
|
248
|
+
"""
|
249
|
+
return await self._storage.retrieve(user_id)
|
250
|
+
|
251
|
+
async def delete(self, user_id: str) -> None:
|
252
|
+
"""
|
253
|
+
Delete an authentication result from memory.
|
254
|
+
|
255
|
+
Args:
|
256
|
+
user_id: The unique identifier for the user
|
257
|
+
"""
|
258
|
+
await self._storage.delete(user_id)
|
259
|
+
|
260
|
+
async def clear_all(self) -> None:
|
261
|
+
"""
|
262
|
+
Clear all stored authentication results from memory.
|
263
|
+
"""
|
264
|
+
# For in-memory storage, we can access the internal storage
|
265
|
+
self._object_store._store.clear()
|
nat/plugins/mcp/client_base.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import asyncio
|
19
|
-
import json
|
20
19
|
import logging
|
21
20
|
from abc import ABC
|
22
21
|
from abc import abstractmethod
|
@@ -55,8 +54,9 @@ class AuthAdapter(httpx.Auth):
|
|
55
54
|
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
56
55
|
"""
|
57
56
|
|
58
|
-
def __init__(self, auth_provider: AuthProviderBase):
|
57
|
+
def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
|
59
58
|
self.auth_provider = auth_provider
|
59
|
+
self.user_id = user_id # Session-specific user ID for cache isolation
|
60
60
|
# each adapter instance has its own lock to avoid unnecessary delays for multiple clients
|
61
61
|
self._lock = anyio.Lock()
|
62
62
|
# Track whether we're currently in an interactive authentication flow
|
@@ -104,41 +104,13 @@ class AuthAdapter(httpx.Auth):
|
|
104
104
|
logger.debug("Authentication flow completed")
|
105
105
|
return
|
106
106
|
|
107
|
-
def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
|
108
|
-
"""Check if this is a tool call request based on the request body.
|
109
|
-
Return the session id if it exists and a boolean indicating if it is a tool call request
|
110
|
-
"""
|
111
|
-
try:
|
112
|
-
# Check if the request body contains a tool call
|
113
|
-
if request.content:
|
114
|
-
body = json.loads(request.content.decode('utf-8'))
|
115
|
-
# Check if it's a JSON-RPC request with method "tools/call"
|
116
|
-
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
117
|
-
session_id = body.get("params").get("_meta").get("session_id")
|
118
|
-
return session_id, True
|
119
|
-
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
120
|
-
# If we can't parse the body, assume it's not a tool call
|
121
|
-
pass
|
122
|
-
return None, False
|
123
|
-
|
124
107
|
async def _get_auth_headers(self,
|
125
108
|
request: httpx.Request | None = None,
|
126
109
|
response: httpx.Response | None = None) -> dict[str, str]:
|
127
110
|
"""Get authentication headers from the NAT auth provider."""
|
128
111
|
try:
|
129
|
-
|
130
|
-
|
131
|
-
if request:
|
132
|
-
session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
|
133
|
-
|
134
|
-
if is_tool_call:
|
135
|
-
# Tool call requests should use the session id
|
136
|
-
user_id = session_id
|
137
|
-
else:
|
138
|
-
# Non-tool call requests should use the session id if it exists and fallback to default user id
|
139
|
-
user_id = session_id or self.auth_provider.config.default_user_id
|
140
|
-
|
141
|
-
auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
|
112
|
+
# Use the user_id passed to this AuthAdapter instance
|
113
|
+
auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
|
142
114
|
|
143
115
|
# Check if we have BearerTokenCred
|
144
116
|
from nat.data_models.authentication import BearerTokenCred
|
@@ -171,6 +143,7 @@ class MCPBaseClient(ABC):
|
|
171
143
|
def __init__(self,
|
172
144
|
transport: str = 'streamable-http',
|
173
145
|
auth_provider: AuthProviderBase | None = None,
|
146
|
+
user_id: str | None = None,
|
174
147
|
tool_call_timeout: timedelta = timedelta(seconds=60),
|
175
148
|
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
176
149
|
reconnect_enabled: bool = True,
|
@@ -189,7 +162,9 @@ class MCPBaseClient(ABC):
|
|
189
162
|
|
190
163
|
# Convert auth provider to AuthAdapter
|
191
164
|
self._auth_provider = auth_provider
|
192
|
-
|
165
|
+
# Use provided user_id or fall back to auth provider's default_user_id
|
166
|
+
effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
|
167
|
+
self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
|
193
168
|
|
194
169
|
self._tool_call_timeout = tool_call_timeout
|
195
170
|
self._auth_flow_timeout = auth_flow_timeout
|
@@ -421,24 +396,6 @@ class MCPBaseClient(ABC):
|
|
421
396
|
if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
|
422
397
|
self._auth_provider._set_custom_auth_callback(auth_callback)
|
423
398
|
|
424
|
-
@mcp_exception_handler
|
425
|
-
async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
|
426
|
-
from mcp.types import CallToolRequest
|
427
|
-
from mcp.types import CallToolRequestParams
|
428
|
-
from mcp.types import CallToolResult
|
429
|
-
from mcp.types import ClientRequest
|
430
|
-
|
431
|
-
if not self._session:
|
432
|
-
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
433
|
-
|
434
|
-
async def _call_tool_with_meta():
|
435
|
-
params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
|
436
|
-
req = ClientRequest(CallToolRequest(params=params))
|
437
|
-
timeout = await self._get_tool_call_timeout()
|
438
|
-
return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
|
439
|
-
|
440
|
-
return await self._with_reconnect(_call_tool_with_meta)
|
441
|
-
|
442
399
|
@mcp_exception_handler
|
443
400
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
444
401
|
|
@@ -570,6 +527,7 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
570
527
|
def __init__(self,
|
571
528
|
url: str,
|
572
529
|
auth_provider: AuthProviderBase | None = None,
|
530
|
+
user_id: str | None = None,
|
573
531
|
tool_call_timeout: timedelta = timedelta(seconds=60),
|
574
532
|
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
575
533
|
reconnect_enabled: bool = True,
|
@@ -578,6 +536,7 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
578
536
|
reconnect_max_backoff: float = 50.0):
|
579
537
|
super().__init__("streamable-http",
|
580
538
|
auth_provider=auth_provider,
|
539
|
+
user_id=user_id,
|
581
540
|
tool_call_timeout=tool_call_timeout,
|
582
541
|
auth_flow_timeout=auth_flow_timeout,
|
583
542
|
reconnect_enabled=reconnect_enabled,
|
@@ -662,35 +621,10 @@ class MCPToolClient:
|
|
662
621
|
"""
|
663
622
|
self._tool_description = description
|
664
623
|
|
665
|
-
def _get_session_id(self) -> str | None:
|
666
|
-
"""
|
667
|
-
Get the session id from the context.
|
668
|
-
"""
|
669
|
-
from nat.builder.context import Context as _Ctx
|
670
|
-
|
671
|
-
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
672
|
-
# on first tool call
|
673
|
-
auth_callback = _Ctx.get().user_auth_callback
|
674
|
-
if auth_callback and self._parent_client:
|
675
|
-
# set custom auth callback
|
676
|
-
self._parent_client.set_user_auth_callback(auth_callback)
|
677
|
-
|
678
|
-
# get session id from context, authentication is done per-websocket session for tool calls
|
679
|
-
session_id = None
|
680
|
-
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
681
|
-
if cookies:
|
682
|
-
session_id = cookies.get("nat-session")
|
683
|
-
|
684
|
-
if not session_id:
|
685
|
-
# use default user id if allowed
|
686
|
-
if self._parent_client.auth_provider and \
|
687
|
-
self._parent_client.auth_provider.config.allow_default_user_id_for_tool_calls:
|
688
|
-
session_id = self._parent_client.auth_provider.config.default_user_id
|
689
|
-
return session_id
|
690
|
-
|
691
624
|
async def acall(self, tool_args: dict) -> str:
|
692
625
|
"""
|
693
626
|
Call the MCP tool with the provided arguments.
|
627
|
+
Session context is now handled at the client level, eliminating the need for metadata injection.
|
694
628
|
|
695
629
|
Args:
|
696
630
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
@@ -698,25 +632,10 @@ class MCPToolClient:
|
|
698
632
|
if self._session is None:
|
699
633
|
raise RuntimeError("No session available for tool call")
|
700
634
|
|
701
|
-
# Extract context information
|
702
635
|
try:
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
try:
|
708
|
-
# if auth is enabled and session id is not available return user is not authorized to call the tool
|
709
|
-
if self._parent_client.auth_provider and not session_id:
|
710
|
-
result_str = "User is not authorized to call the tool"
|
711
|
-
mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
|
712
|
-
raise mcp_error
|
713
|
-
|
714
|
-
if session_id:
|
715
|
-
logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
|
716
|
-
result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
|
717
|
-
else:
|
718
|
-
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
719
|
-
result = await self._parent_client.call_tool(self._tool_name, tool_args)
|
636
|
+
# Simple tool call - session context is already in the client instance
|
637
|
+
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
638
|
+
result = await self._parent_client.call_tool(self._tool_name, tool_args)
|
720
639
|
|
721
640
|
output = []
|
722
641
|
for res in result.content:
|