nvidia-nat-mcp 1.3.0a20250926__py3-none-any.whl → 1.3.0a20251111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +2 -2
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +149 -86
- nat/plugins/mcp/auth/auth_provider_config.py +10 -2
- nat/plugins/mcp/auth/register.py +1 -1
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/client_base.py +165 -71
- nat/plugins/mcp/client_config.py +131 -0
- nat/plugins/mcp/client_impl.py +469 -99
- nat/plugins/mcp/exception_handler.py +1 -1
- nat/plugins/mcp/tool.py +6 -7
- nat/plugins/mcp/utils.py +167 -34
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/METADATA +13 -4
- nvidia_nat_mcp-1.3.0a20251111.dist-info/RECORD +23 -0
- nvidia_nat_mcp-1.3.0a20251111.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.3.0a20251111.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.3.0a20250926.dist-info/RECORD +0 -18
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import hashlib
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from abc import ABC
|
|
20
|
+
from abc import abstractmethod
|
|
21
|
+
|
|
22
|
+
from nat.data_models.authentication import AuthResult
|
|
23
|
+
from nat.data_models.authentication import BasicAuthCred
|
|
24
|
+
from nat.data_models.authentication import BearerTokenCred
|
|
25
|
+
from nat.data_models.authentication import CookieCred
|
|
26
|
+
from nat.data_models.authentication import HeaderCred
|
|
27
|
+
from nat.data_models.authentication import QueryCred
|
|
28
|
+
from nat.data_models.object_store import NoSuchKeyError
|
|
29
|
+
from nat.object_store.interfaces import ObjectStore
|
|
30
|
+
from nat.object_store.models import ObjectStoreItem
|
|
31
|
+
|
|
32
|
+
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TokenStorageBase(ABC):
    """
    Interface for persisting MCP OAuth2 authentication results per user.

    Concrete subclasses decide where the tokens live (an object store, a
    database, process memory, ...). Every operation is a coroutine and must
    be awaited.
    """

    @abstractmethod
    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Save ``auth_result`` as the authentication result for ``user_id``.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        ...

    @abstractmethod
    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Look up the authentication result saved for ``user_id``.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The stored authentication result, or None when nothing is stored
        """
        ...

    @abstractmethod
    async def delete(self, user_id: str) -> None:
        """
        Remove the authentication result saved for ``user_id``.

        Args:
            user_id: The unique identifier for the user
        """
        ...

    @abstractmethod
    async def clear_all(self) -> None:
        """
        Remove every stored authentication result.
        """
        ...
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ObjectStoreTokenStorage(TokenStorageBase):
    """
    Token storage implementation backed by a NeMo Agent toolkit object store.

    Each user's ``AuthResult`` is serialized to JSON, with credential secrets
    written in clear text so they survive a round trip. It is stored under a
    key derived from a SHA-256 hash of the user ID.

    Encryption at rest, access controls, and persistence across restarts are
    properties of the configured object store backend (e.g. S3, MySQL, or
    Redis), not of this class.
    """

    def __init__(self, object_store: ObjectStore):
        """
        Initialize the object store token storage.

        Args:
            object_store: The object store instance to use for token persistence
        """
        self._object_store = object_store

    def _get_key(self, user_id: str) -> str:
        """
        Generate the object store key for a user's token.

        Uses a SHA-256 hash so the key is S3-compatible and never contains
        characters such as "://" that are invalid in object keys.

        Args:
            user_id: The user identifier

        Returns:
            The object store key, of the form ``tokens/<sha256 hex digest>``
        """
        # Hash the user_id to create an S3-safe key
        user_hash = hashlib.sha256(user_id.encode('utf-8')).hexdigest()
        return f"tokens/{user_hash}"

    @staticmethod
    def _expose_secrets(cred_obj, cred_dict: dict) -> None:
        """
        Replace masked secret values in one dumped credential with the real values.

        Args:
            cred_obj: A credential instance taken from ``AuthResult.credentials``
            cred_dict: The ``model_dump(mode='json')`` output for ``cred_obj``;
                updated in place
        """
        if isinstance(cred_obj, BearerTokenCred):
            cred_dict['token'] = cred_obj.token.get_secret_value()
        elif isinstance(cred_obj, BasicAuthCred):
            cred_dict['username'] = cred_obj.username.get_secret_value()
            cred_dict['password'] = cred_obj.password.get_secret_value()
        elif isinstance(cred_obj, HeaderCred | QueryCred | CookieCred):
            cred_dict['value'] = cred_obj.value.get_secret_value()
        else:
            # Unrecognized credential type. Leaving the masked placeholder would
            # make the stored token unusable after retrieval. So reveal every
            # dumped field whose value exposes a string secret.
            # Only existing keys are reassigned, so mutating while iterating is safe.
            for field_name in cred_dict:
                reveal = getattr(getattr(cred_obj, field_name, None), 'get_secret_value', None)
                if callable(reveal):
                    secret = reveal()
                    if isinstance(secret, str):
                        cred_dict[field_name] = secret

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in the object store.

        An existing entry for the same user is overwritten.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        key = self._get_key(user_id)

        # SecretStr values are masked by model_dump(), so expose them explicitly
        # before serializing; otherwise the persisted token could not be reused.
        auth_dict = auth_result.model_dump(mode='json')
        for cred_obj, cred_dict in zip(auth_result.credentials, auth_dict['credentials']):
            self._expose_secrets(cred_obj, cred_dict)
        data = json.dumps(auth_dict).encode('utf-8')

        # Record the token expiry as object metadata when one is known
        metadata = {}
        if auth_result.token_expires_at:
            metadata["expires_at"] = auth_result.token_expires_at.isoformat()

        item = ObjectStoreItem(data=data, content_type="application/json", metadata=metadata or None)

        # Store using upsert to handle both new and existing tokens
        await self._object_store.upsert_object(key, item)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from the object store.

        This is best-effort. A backend error or an undecodable payload is
        logged with its traceback and reported as "no token" (None).

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found and decodable, None otherwise
        """
        key = self._get_key(user_id)

        try:
            item = await self._object_store.get_object(key)
        except NoSuchKeyError:
            return None
        except Exception:
            logger.exception("Error retrieving token for user %s", user_id)
            return None

        try:
            # Deserialize the AuthResult from JSON
            return AuthResult.model_validate_json(item.data)
        except Exception:
            logger.exception("Error deserializing token for user %s", user_id)
            return None

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from the object store.

        Deleting a user with no stored token is not an error.

        Args:
            user_id: The unique identifier for the user
        """
        key = self._get_key(user_id)

        try:
            await self._object_store.delete_object(key)
        except NoSuchKeyError:
            # Token doesn't exist, which is fine for delete operations
            pass

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results.

        Note: This implementation does not support clearing all tokens as the
        object store interface doesn't provide a list operation. It only logs
        a warning; individual tokens must be deleted explicitly.
        """
        logger.warning("clear_all() is not supported for ObjectStoreTokenStorage")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class InMemoryTokenStorage(TokenStorageBase):
    """
    In-memory token storage using NeMo Agent toolkit's built-in object store.

    Tokens are kept in a dedicated ``InMemoryObjectStore`` owned by this
    instance, so no external storage configuration is required. Nothing is
    persisted: tokens are lost when the process exits.

    Every operation delegates to an ``ObjectStoreTokenStorage`` that wraps
    that store.
    """

    def __init__(self):
        """
        Initialize the in-memory token storage.
        """
        self._reset_backing_store()
        logger.debug("Initialized in-memory token storage")

    def _reset_backing_store(self) -> None:
        """Create a fresh, empty in-memory object store and the storage wrapper around it."""
        # Local import: the in-memory backend is only loaded when this class is used.
        from nat.object_store.in_memory_object_store import InMemoryObjectStore

        # Create a dedicated in-memory object store for tokens
        self._object_store = InMemoryObjectStore()

        # Wrap with ObjectStoreTokenStorage for the actual implementation
        self._storage = ObjectStoreTokenStorage(self._object_store)

    async def store(self, user_id: str, auth_result: AuthResult) -> None:
        """
        Store an authentication result in memory.

        Args:
            user_id: The unique identifier for the user
            auth_result: The authentication result to store
        """
        await self._storage.store(user_id, auth_result)

    async def retrieve(self, user_id: str) -> AuthResult | None:
        """
        Retrieve an authentication result from memory.

        Args:
            user_id: The unique identifier for the user

        Returns:
            The authentication result if found, None otherwise
        """
        return await self._storage.retrieve(user_id)

    async def delete(self, user_id: str) -> None:
        """
        Delete an authentication result from memory.

        Args:
            user_id: The unique identifier for the user
        """
        await self._storage.delete(user_id)

    async def clear_all(self) -> None:
        """
        Clear all stored authentication results from memory.

        The backing object store is dedicated to this instance. Replacing it
        with a fresh, empty one discards every token through public
        constructors only, without reaching into the store's private state.
        """
        self._reset_backing_store()
|
nat/plugins/mcp/client_base.py
CHANGED
|
@@ -16,15 +16,16 @@
|
|
|
16
16
|
from __future__ import annotations
|
|
17
17
|
|
|
18
18
|
import asyncio
|
|
19
|
-
import json
|
|
20
19
|
import logging
|
|
21
20
|
from abc import ABC
|
|
22
21
|
from abc import abstractmethod
|
|
23
22
|
from collections.abc import AsyncGenerator
|
|
23
|
+
from collections.abc import Callable
|
|
24
24
|
from contextlib import AsyncExitStack
|
|
25
25
|
from contextlib import asynccontextmanager
|
|
26
26
|
from datetime import timedelta
|
|
27
27
|
|
|
28
|
+
import anyio
|
|
28
29
|
import httpx
|
|
29
30
|
|
|
30
31
|
from mcp import ClientSession
|
|
@@ -33,9 +34,9 @@ from mcp.client.stdio import StdioServerParameters
|
|
|
33
34
|
from mcp.client.stdio import stdio_client
|
|
34
35
|
from mcp.client.streamable_http import streamablehttp_client
|
|
35
36
|
from mcp.types import TextContent
|
|
37
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
|
38
|
+
from nat.authentication.interfaces import AuthFlowType
|
|
36
39
|
from nat.authentication.interfaces import AuthProviderBase
|
|
37
|
-
from nat.data_models.authentication import AuthReason
|
|
38
|
-
from nat.data_models.authentication import AuthRequest
|
|
39
40
|
from nat.plugins.mcp.exception_handler import convert_to_mcp_error
|
|
40
41
|
from nat.plugins.mcp.exception_handler import format_mcp_error
|
|
41
42
|
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
|
@@ -53,74 +54,71 @@ class AuthAdapter(httpx.Auth):
|
|
|
53
54
|
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
|
54
55
|
"""
|
|
55
56
|
|
|
56
|
-
def __init__(self, auth_provider: AuthProviderBase,
|
|
57
|
+
def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
|
|
57
58
|
self.auth_provider = auth_provider
|
|
58
|
-
self.
|
|
59
|
+
self.user_id = user_id # Session-specific user ID for cache isolation
|
|
60
|
+
# each adapter instance has its own lock to avoid unnecessary delays for multiple clients
|
|
61
|
+
self._lock = anyio.Lock()
|
|
62
|
+
# Track whether we're currently in an interactive authentication flow
|
|
63
|
+
self.is_authenticating = False
|
|
59
64
|
|
|
60
65
|
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
|
61
66
|
"""Add authentication headers to the request using NAT auth provider."""
|
|
62
|
-
|
|
63
|
-
if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
|
|
64
|
-
# Skip auth for non-tool calls
|
|
65
|
-
yield request
|
|
66
|
-
return
|
|
67
|
-
|
|
68
|
-
try:
|
|
69
|
-
# Get fresh auth headers from the NAT auth provider
|
|
70
|
-
auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
|
|
71
|
-
request.headers.update(auth_headers)
|
|
72
|
-
except Exception as e:
|
|
73
|
-
logger.info("Failed to get auth headers: %s", e)
|
|
74
|
-
# Continue without auth headers if auth fails
|
|
75
|
-
|
|
76
|
-
response = yield request
|
|
77
|
-
|
|
78
|
-
# Handle 401 responses by retrying with fresh auth
|
|
79
|
-
if response.status_code == 401:
|
|
67
|
+
async with self._lock:
|
|
80
68
|
try:
|
|
81
|
-
# Get
|
|
82
|
-
|
|
69
|
+
# Get auth headers from the NAT auth provider:
|
|
70
|
+
# 1. If discovery is yet to done this will return None and request will be sent without auth header.
|
|
71
|
+
# 2. If discovery is done, this will return the auth header from cache if the token is still valid
|
|
72
|
+
auth_headers = await self._get_auth_headers(request=request, response=None)
|
|
83
73
|
request.headers.update(auth_headers)
|
|
84
|
-
yield request # Retry the request
|
|
85
74
|
except Exception as e:
|
|
86
|
-
logger.info("Failed to
|
|
75
|
+
logger.info("Failed to get auth headers: %s", e)
|
|
76
|
+
# Continue without auth headers if auth fails
|
|
77
|
+
|
|
78
|
+
response = yield request
|
|
79
|
+
|
|
80
|
+
# Handle 401 responses by retrying with fresh auth
|
|
81
|
+
if response.status_code == 401:
|
|
82
|
+
try:
|
|
83
|
+
# 401 can happen if:
|
|
84
|
+
# 1. The request was sent without auth header
|
|
85
|
+
# 2. The auth headers are invalid
|
|
86
|
+
# 3. The auth headers are expired
|
|
87
|
+
# 4. The auth headers are revoked
|
|
88
|
+
# 5. Auth config on the MCP server has changed
|
|
89
|
+
# In this case we attempt to re-run discovery and authentication
|
|
90
|
+
|
|
91
|
+
# Signal that we're entering interactive auth flow
|
|
92
|
+
self.is_authenticating = True
|
|
93
|
+
logger.debug("Starting authentication flow due to 401 response")
|
|
94
|
+
|
|
95
|
+
auth_headers = await self._get_auth_headers(request=request, response=response)
|
|
96
|
+
request.headers.update(auth_headers)
|
|
97
|
+
yield request # Retry the request
|
|
98
|
+
except Exception as e:
|
|
99
|
+
logger.info("Failed to refresh auth after 401: %s", e)
|
|
100
|
+
raise
|
|
101
|
+
finally:
|
|
102
|
+
# Signal that auth flow is complete
|
|
103
|
+
self.is_authenticating = False
|
|
104
|
+
logger.debug("Authentication flow completed")
|
|
87
105
|
return
|
|
88
106
|
|
|
89
|
-
def
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
# Check if the request body contains a tool call
|
|
93
|
-
if request.content:
|
|
94
|
-
body = json.loads(request.content.decode('utf-8'))
|
|
95
|
-
# Check if it's a JSON-RPC request with method "tools/call"
|
|
96
|
-
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
|
97
|
-
return True
|
|
98
|
-
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
|
99
|
-
# If we can't parse the body, assume it's not a tool call
|
|
100
|
-
pass
|
|
101
|
-
return False
|
|
102
|
-
|
|
103
|
-
async def _get_auth_headers(self, reason: AuthReason, response: httpx.Response | None = None) -> dict[str, str]:
|
|
107
|
+
async def _get_auth_headers(self,
|
|
108
|
+
request: httpx.Request | None = None,
|
|
109
|
+
response: httpx.Response | None = None) -> dict[str, str]:
|
|
104
110
|
"""Get authentication headers from the NAT auth provider."""
|
|
105
|
-
# Build auth request
|
|
106
|
-
www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
|
|
107
|
-
auth_request = AuthRequest(
|
|
108
|
-
reason=reason,
|
|
109
|
-
www_authenticate=www_authenticate,
|
|
110
|
-
)
|
|
111
111
|
try:
|
|
112
|
-
#
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
self.auth_provider.config.auth_request = auth_request
|
|
116
|
-
auth_result = await self.auth_provider.authenticate()
|
|
112
|
+
# Use the user_id passed to this AuthAdapter instance
|
|
113
|
+
auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
|
|
114
|
+
|
|
117
115
|
# Check if we have BearerTokenCred
|
|
118
116
|
from nat.data_models.authentication import BearerTokenCred
|
|
119
117
|
if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
|
|
120
118
|
token = auth_result.credentials[0].token.get_secret_value()
|
|
121
119
|
return {"Authorization": f"Bearer {token}"}
|
|
122
120
|
else:
|
|
123
|
-
logger.
|
|
121
|
+
logger.info("Auth provider did not return BearerTokenCred")
|
|
124
122
|
return {}
|
|
125
123
|
except Exception as e:
|
|
126
124
|
logger.warning("Failed to get auth token: %s", e)
|
|
@@ -134,12 +132,20 @@ class MCPBaseClient(ABC):
|
|
|
134
132
|
Args:
|
|
135
133
|
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
|
136
134
|
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
|
135
|
+
tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
|
|
136
|
+
auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
|
|
137
|
+
reconnect_enabled (bool): Whether to automatically reconnect on connection failures
|
|
138
|
+
reconnect_max_attempts (int): Maximum number of reconnection attempts
|
|
139
|
+
reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
|
|
140
|
+
reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts
|
|
137
141
|
"""
|
|
138
142
|
|
|
139
143
|
def __init__(self,
|
|
140
144
|
transport: str = 'streamable-http',
|
|
141
145
|
auth_provider: AuthProviderBase | None = None,
|
|
142
|
-
|
|
146
|
+
user_id: str | None = None,
|
|
147
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
|
148
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
|
143
149
|
reconnect_enabled: bool = True,
|
|
144
150
|
reconnect_max_attempts: int = 2,
|
|
145
151
|
reconnect_initial_backoff: float = 0.5,
|
|
@@ -155,9 +161,13 @@ class MCPBaseClient(ABC):
|
|
|
155
161
|
self._initial_connection = False
|
|
156
162
|
|
|
157
163
|
# Convert auth provider to AuthAdapter
|
|
158
|
-
self.
|
|
164
|
+
self._auth_provider = auth_provider
|
|
165
|
+
# Use provided user_id or fall back to auth provider's default_user_id
|
|
166
|
+
effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
|
|
167
|
+
self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
|
|
159
168
|
|
|
160
169
|
self._tool_call_timeout = tool_call_timeout
|
|
170
|
+
self._auth_flow_timeout = auth_flow_timeout
|
|
161
171
|
|
|
162
172
|
# Reconnect configuration
|
|
163
173
|
self._reconnect_enabled = reconnect_enabled
|
|
@@ -166,6 +176,10 @@ class MCPBaseClient(ABC):
|
|
|
166
176
|
self._reconnect_max_backoff = reconnect_max_backoff
|
|
167
177
|
self._reconnect_lock: asyncio.Lock = asyncio.Lock()
|
|
168
178
|
|
|
179
|
+
@property
def auth_provider(self) -> AuthProviderBase | None:
    """The authentication provider this client was constructed with, or ``None`` if there is none."""
    provider = self._auth_provider
    return provider
|
|
182
|
+
|
|
169
183
|
@property
|
|
170
184
|
def transport(self) -> str:
|
|
171
185
|
return self._transport
|
|
@@ -248,12 +262,25 @@ class MCPBaseClient(ABC):
|
|
|
248
262
|
async def _with_reconnect(self, coro):
|
|
249
263
|
"""
|
|
250
264
|
Execute an awaited operation, reconnecting once on errors.
|
|
265
|
+
Does not reconnect if the error occurs during an active authentication flow.
|
|
251
266
|
"""
|
|
252
267
|
try:
|
|
253
268
|
return await coro()
|
|
254
269
|
except Exception as e:
|
|
270
|
+
# Check if error happened during active authentication flow
|
|
271
|
+
if self._httpx_auth and self._httpx_auth.is_authenticating:
|
|
272
|
+
# Provide specific error message for authentication timeouts
|
|
273
|
+
if isinstance(e, TimeoutError):
|
|
274
|
+
logger.error("Timeout during user authentication flow - user may have abandoned authentication")
|
|
275
|
+
raise RuntimeError(
|
|
276
|
+
"Authentication timed out. User did not complete authentication in browser within "
|
|
277
|
+
f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
|
|
278
|
+
else:
|
|
279
|
+
logger.error("Error during authentication flow: %s", e)
|
|
280
|
+
raise
|
|
281
|
+
|
|
282
|
+
# Normal error - attempt reconnect if enabled
|
|
255
283
|
if self._reconnect_enabled:
|
|
256
|
-
logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
|
|
257
284
|
try:
|
|
258
285
|
await self._reconnect()
|
|
259
286
|
except Exception as reconnect_err:
|
|
@@ -262,7 +289,49 @@ class MCPBaseClient(ABC):
|
|
|
262
289
|
return await coro()
|
|
263
290
|
raise
|
|
264
291
|
|
|
265
|
-
async def
|
|
292
|
+
async def _has_cached_auth_token(self) -> bool:
    """
    Report whether a usable (non-expired) authentication token is already cached.

    Without an auth provider, no authentication is needed, so this returns True.
    Otherwise it inspects the provider's private ``_auth_code_provider`` and
    that object's ``_authenticated_tokens`` mapping.

    Any failure while inspecting is treated as "authentication may be needed".

    Returns:
        bool: True if we have a valid cached token, False if authentication may be needed
    """
    provider = self._auth_provider
    if not provider:
        return True  # No auth needed

    try:
        if not hasattr(provider, '_auth_code_provider'):
            return False
        code_provider = provider._auth_code_provider
        if not code_provider or not hasattr(code_provider, '_authenticated_tokens'):
            return False
        # NOTE(review): this accepts a non-expired token cached under ANY entry of
        # _authenticated_tokens, not specifically this client's user — confirm that
        # is intended when several users share one provider.
        return any(not cached.is_expired() for cached in code_provider._authenticated_tokens.values())
    except Exception:
        # If we can't check, assume we need auth to be safe
        return False
|
|
316
|
+
|
|
317
|
+
async def _get_tool_call_timeout(self) -> timedelta:
    """
    Pick the timeout for a tool call from the current authentication state.

    The regular tool-call timeout applies when no auth provider is configured,
    or when a cached token is available. Otherwise the call may trigger an
    interactive authentication flow, so the longer auth-flow timeout is used.

    Returns:
        timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
    """
    if not self._auth_provider:
        return self._tool_call_timeout

    if await self._has_cached_auth_token():
        return self._tool_call_timeout

    logger.debug("Using extended timeout (%s) for potential interactive authentication", self._auth_flow_timeout)
    return self._auth_flow_timeout
|
|
332
|
+
|
|
333
|
+
@mcp_exception_handler
|
|
334
|
+
async def get_tools(self) -> dict[str, MCPToolClient]:
|
|
266
335
|
"""
|
|
267
336
|
Retrieve a dictionary of all tools served by the MCP server.
|
|
268
337
|
Uses unauthenticated session for discovery.
|
|
@@ -270,7 +339,16 @@ class MCPBaseClient(ABC):
|
|
|
270
339
|
|
|
271
340
|
async def _get_tools():
|
|
272
341
|
session = self._session
|
|
273
|
-
|
|
342
|
+
try:
|
|
343
|
+
# Add timeout to the list_tools call.
|
|
344
|
+
# This is needed because MCP SDK does not support timeout for list_tools()
|
|
345
|
+
with anyio.fail_after(self._tool_call_timeout.total_seconds()):
|
|
346
|
+
tools = await session.list_tools()
|
|
347
|
+
except TimeoutError as e:
|
|
348
|
+
from nat.plugins.mcp.exceptions import MCPTimeoutError
|
|
349
|
+
raise MCPTimeoutError(self.server_name, e)
|
|
350
|
+
|
|
351
|
+
return tools
|
|
274
352
|
|
|
275
353
|
try:
|
|
276
354
|
response = await self._with_reconnect(_get_tools)
|
|
@@ -284,8 +362,7 @@ class MCPBaseClient(ABC):
|
|
|
284
362
|
tool_name=tool.name,
|
|
285
363
|
tool_description=tool.description,
|
|
286
364
|
tool_input_schema=tool.inputSchema,
|
|
287
|
-
parent_client=self
|
|
288
|
-
tool_call_timeout=self._tool_call_timeout)
|
|
365
|
+
parent_client=self)
|
|
289
366
|
for tool in response.tools
|
|
290
367
|
}
|
|
291
368
|
|
|
@@ -314,12 +391,18 @@ class MCPBaseClient(ABC):
|
|
|
314
391
|
raise MCPToolNotFoundError(tool_name, self.server_name)
|
|
315
392
|
return tool
|
|
316
393
|
|
|
394
|
+
def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
    """
    Forward a custom user-authentication callback to the auth provider.

    The callback is handed to the provider's ``_set_custom_auth_callback`` hook.
    This is a no-op when the client has no auth provider, or when the provider
    does not expose that hook.

    Args:
        auth_callback: Callable mapping an ``AuthFlowType`` to an ``AuthenticatedContext``
    """
    provider = self._auth_provider
    if not provider or not hasattr(provider, "_set_custom_auth_callback"):
        return
    provider._set_custom_auth_callback(auth_callback)
|
|
398
|
+
|
|
317
399
|
@mcp_exception_handler
|
|
318
400
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
|
319
401
|
|
|
320
402
|
async def _call_tool():
|
|
321
403
|
session = self._session
|
|
322
|
-
|
|
404
|
+
timeout = await self._get_tool_call_timeout()
|
|
405
|
+
return await session.call_tool(tool_name, tool_args, read_timeout_seconds=timeout)
|
|
323
406
|
|
|
324
407
|
return await self._with_reconnect(_call_tool)
|
|
325
408
|
|
|
@@ -334,13 +417,15 @@ class MCPSSEClient(MCPBaseClient):
|
|
|
334
417
|
|
|
335
418
|
def __init__(self,
|
|
336
419
|
url: str,
|
|
337
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
|
420
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
|
421
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
|
338
422
|
reconnect_enabled: bool = True,
|
|
339
423
|
reconnect_max_attempts: int = 2,
|
|
340
424
|
reconnect_initial_backoff: float = 0.5,
|
|
341
425
|
reconnect_max_backoff: float = 50.0):
|
|
342
426
|
super().__init__("sse",
|
|
343
427
|
tool_call_timeout=tool_call_timeout,
|
|
428
|
+
auth_flow_timeout=auth_flow_timeout,
|
|
344
429
|
reconnect_enabled=reconnect_enabled,
|
|
345
430
|
reconnect_max_attempts=reconnect_max_attempts,
|
|
346
431
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
|
@@ -383,13 +468,15 @@ class MCPStdioClient(MCPBaseClient):
|
|
|
383
468
|
command: str,
|
|
384
469
|
args: list[str] | None = None,
|
|
385
470
|
env: dict[str, str] | None = None,
|
|
386
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
|
471
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
|
472
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
|
387
473
|
reconnect_enabled: bool = True,
|
|
388
474
|
reconnect_max_attempts: int = 2,
|
|
389
475
|
reconnect_initial_backoff: float = 0.5,
|
|
390
476
|
reconnect_max_backoff: float = 50.0):
|
|
391
477
|
super().__init__("stdio",
|
|
392
478
|
tool_call_timeout=tool_call_timeout,
|
|
479
|
+
auth_flow_timeout=auth_flow_timeout,
|
|
393
480
|
reconnect_enabled=reconnect_enabled,
|
|
394
481
|
reconnect_max_attempts=reconnect_max_attempts,
|
|
395
482
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
|
@@ -440,14 +527,18 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
|
440
527
|
def __init__(self,
|
|
441
528
|
url: str,
|
|
442
529
|
auth_provider: AuthProviderBase | None = None,
|
|
443
|
-
|
|
530
|
+
user_id: str | None = None,
|
|
531
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
|
532
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
|
444
533
|
reconnect_enabled: bool = True,
|
|
445
534
|
reconnect_max_attempts: int = 2,
|
|
446
535
|
reconnect_initial_backoff: float = 0.5,
|
|
447
536
|
reconnect_max_backoff: float = 50.0):
|
|
448
537
|
super().__init__("streamable-http",
|
|
449
538
|
auth_provider=auth_provider,
|
|
539
|
+
user_id=user_id,
|
|
450
540
|
tool_call_timeout=tool_call_timeout,
|
|
541
|
+
auth_flow_timeout=auth_flow_timeout,
|
|
451
542
|
reconnect_enabled=reconnect_enabled,
|
|
452
543
|
reconnect_max_attempts=reconnect_max_attempts,
|
|
453
544
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
|
@@ -490,17 +581,15 @@ class MCPToolClient:
|
|
|
490
581
|
|
|
491
582
|
def __init__(self,
|
|
492
583
|
session: ClientSession,
|
|
493
|
-
parent_client:
|
|
584
|
+
parent_client: MCPBaseClient,
|
|
494
585
|
tool_name: str,
|
|
495
586
|
tool_description: str | None,
|
|
496
|
-
tool_input_schema: dict | None = None
|
|
497
|
-
tool_call_timeout: timedelta = timedelta(seconds=5)):
|
|
587
|
+
tool_input_schema: dict | None = None):
|
|
498
588
|
self._session = session
|
|
499
589
|
self._tool_name = tool_name
|
|
500
590
|
self._tool_description = tool_description
|
|
501
591
|
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
|
502
592
|
self._parent_client = parent_client
|
|
503
|
-
self._tool_call_timeout = tool_call_timeout
|
|
504
593
|
|
|
505
594
|
if self._parent_client is None:
|
|
506
595
|
raise RuntimeError("MCPToolClient initialized without a parent client.")
|
|
@@ -535,12 +624,17 @@ class MCPToolClient:
|
|
|
535
624
|
async def acall(self, tool_args: dict) -> str:
|
|
536
625
|
"""
|
|
537
626
|
Call the MCP tool with the provided arguments.
|
|
627
|
+
Session context is now handled at the client level, eliminating the need for metadata injection.
|
|
538
628
|
|
|
539
629
|
Args:
|
|
540
630
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
|
541
631
|
"""
|
|
542
|
-
|
|
632
|
+
if self._session is None:
|
|
633
|
+
raise RuntimeError("No session available for tool call")
|
|
634
|
+
|
|
543
635
|
try:
|
|
636
|
+
# Simple tool call - session context is already in the client instance
|
|
637
|
+
logger.info("Calling tool %s", self._tool_name)
|
|
544
638
|
result = await self._parent_client.call_tool(self._tool_name, tool_args)
|
|
545
639
|
|
|
546
640
|
output = []
|
|
@@ -558,6 +652,6 @@ class MCPToolClient:
|
|
|
558
652
|
|
|
559
653
|
except MCPError as e:
|
|
560
654
|
format_mcp_error(e, include_traceback=False)
|
|
561
|
-
result_str = "MCPToolClient tool call failed:
|
|
655
|
+
result_str = f"MCPToolClient tool call failed: {e.original_exception}"
|
|
562
656
|
|
|
563
657
|
return result_str
|