nvidia-nat-mcp 1.3.0a20250925__py3-none-any.whl → 1.3.0a20250928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/mcp/auth/auth_flow_handler.py +144 -0
- nat/plugins/mcp/auth/auth_provider.py +106 -79
- nat/plugins/mcp/auth/auth_provider_config.py +5 -2
- nat/plugins/mcp/client_base.py +278 -79
- nat/plugins/mcp/client_impl.py +45 -4
- {nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/METADATA +2 -2
- {nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/RECORD +10 -9
- {nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
import logging
|
18
|
+
import secrets
|
19
|
+
import webbrowser
|
20
|
+
|
21
|
+
import pkce
|
22
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
23
|
+
from fastapi import FastAPI
|
24
|
+
|
25
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
26
|
+
from nat.data_models.authentication import AuthenticatedContext
|
27
|
+
from nat.data_models.authentication import AuthFlowType
|
28
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
29
|
+
from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
|
30
|
+
from nat.front_ends.console.authentication_flow_handler import _FlowState
|
31
|
+
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
32
|
+
|
33
|
+
logger = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
    """
    Authentication helper for MCP environments.

    This handler is specifically designed for MCP tool discovery scenarios where
    authentication needs to happen before the default auth_callback is available
    in the Context. It handles OAuth2 authorization code flow during MCP client
    startup and tool discovery phases.

    Key differences from console handler:
    - Only supports OAuth2 Authorization Code flow (no HTTP Basic)
    - Optimized for MCP tool discovery workflows
    - Designed for single-use authentication during startup
    """

    # Seconds to wait for the user to finish the browser login before giving up.
    DEFAULT_AUTH_TIMEOUT_SECONDS: float = 300

    def __init__(self, auth_timeout: float = DEFAULT_AUTH_TIMEOUT_SECONDS):
        """
        Args:
            auth_timeout: Maximum number of seconds to wait for the OAuth2 redirect to
                deliver a token. Defaults to ``DEFAULT_AUTH_TIMEOUT_SECONDS`` (300).

        Raises:
            ValueError: If ``auth_timeout`` is not a positive number.
        """
        super().__init__()
        if auth_timeout <= 0:
            raise ValueError("auth_timeout must be a positive number of seconds")
        self._auth_timeout = auth_timeout
        self._server_controller: _FastApiFrontEndController | None = None
        self._redirect_app: FastAPI | None = None
        # Serializes redirect-server start/stop and mutation of the shared flow registry.
        self._server_lock = asyncio.Lock()
        # Not referenced by this class's flow; kept for compatibility.
        self._oauth_client: AsyncOAuth2Client | None = None

    async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
        """
        Handle the OAuth2 authorization code flow for MCP environments.

        Args:
            config: OAuth2 configuration for MCP server
            method: Authentication method (only OAUTH2_AUTHORIZATION_CODE supported)

        Returns:
            AuthenticatedContext with Bearer token for MCP server access

        Raises:
            ValueError: If config is invalid for MCP use case
            NotImplementedError: If method is not OAuth2 Authorization Code
        """
        logger.info("Starting MCP authentication flow")

        if method == AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
            if not isinstance(config, OAuth2AuthCodeFlowProviderConfig):
                raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")

            # MCP-specific validation: the browser must be redirected back to us.
            if not config.redirect_uri:
                raise ValueError("MCP authentication requires redirect_uri to be configured")

            logger.info("MCP authentication configured for server: %s", getattr(config, 'server_url', 'unknown'))
            return await self._handle_oauth2_auth_code_flow(config)

        raise NotImplementedError(f'Auth method "{method}" not supported for MCP environments')

    async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
        """
        Run one OAuth2 authorization-code round trip and return the bearer token.

        Starts the local redirect server, registers this flow under a random ``state``,
        opens the authorization URL in a browser, and waits up to the configured timeout
        for ``flow_state.future`` to resolve with the token. Once the server has been
        started, the flow is always unregistered and the server stopped. This holds even
        if opening the browser fails or the wait is cancelled.

        Args:
            cfg: OAuth2 authorization-code configuration.

        Returns:
            AuthenticatedContext whose ``Authorization`` header carries the access token.

        Raises:
            RuntimeError: If authentication does not complete within the timeout, or the
                token returned has no access token.
        """
        logger.info("Starting MCP OAuth2 authorization code flow")

        # Unguessable state value; it keys this flow in self._flows.
        state = secrets.token_urlsafe(16)
        flow_state = _FlowState()
        client = self.construct_oauth_client(cfg)

        flow_state.token_url = cfg.token_url
        flow_state.use_pkce = cfg.use_pkce

        # PKCE: the challenge goes in the authorization URL; the verifier is kept on the
        # flow state (presumably consumed by the redirect handler's token exchange).
        if cfg.use_pkce:
            verifier, challenge = pkce.generate_pkce_pair()
            flow_state.verifier = verifier
            flow_state.challenge = challenge
            logger.debug("PKCE enabled for MCP authentication")

        auth_url, _ = client.create_authorization_url(
            cfg.authorization_url,
            state=state,
            code_verifier=flow_state.verifier if cfg.use_pkce else None,
            code_challenge=flow_state.challenge if cfg.use_pkce else None,
            **(cfg.authorization_kwargs or {})
        )

        async with self._server_lock:
            if self._redirect_app is None:
                self._redirect_app = await self._build_redirect_app()

            await self._start_redirect_server()
            self._flows[state] = flow_state

        timeout = self._auth_timeout
        try:
            # The browser launch lives inside the try so that a failure here still
            # stops the redirect server and unregisters the flow in `finally`.
            if webbrowser.open(auth_url):
                logger.info("MCP authentication: Your browser has been opened for authentication.")
            else:
                # Headless / no registered browser: the user must open the URL manually.
                logger.warning("MCP authentication: Could not open a browser automatically. "
                               "Open this URL to authenticate: %s",
                               auth_url)
            logger.info("This will authenticate you with the MCP server for tool discovery.")

            token = await asyncio.wait_for(flow_state.future, timeout=timeout)
            logger.info("MCP authentication successful, token obtained")
        except asyncio.TimeoutError as exc:
            logger.error("MCP authentication timed out")
            raise RuntimeError(f"MCP authentication timed out ({timeout} seconds). Please try again.") from exc
        finally:
            async with self._server_lock:
                self._flows.pop(state, None)
                await self._stop_redirect_server()

        access_token = token.get("access_token")
        if not access_token:
            raise RuntimeError("MCP authentication completed but no access token was returned")

        return AuthenticatedContext(
            headers={"Authorization": f"Bearer {access_token}"},
            metadata={
                "expires_at": token.get("expires_at"),
                "raw_token": token,
            },
        )
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
import logging
|
17
|
+
from collections.abc import Awaitable
|
18
|
+
from typing import Callable
|
17
19
|
from urllib.parse import urljoin
|
18
20
|
from urllib.parse import urlparse
|
19
21
|
|
@@ -26,10 +28,12 @@ from mcp.shared.auth import OAuthClientInformationFull
|
|
26
28
|
from mcp.shared.auth import OAuthClientMetadata
|
27
29
|
from mcp.shared.auth import OAuthMetadata
|
28
30
|
from mcp.shared.auth import ProtectedResourceMetadata
|
31
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
32
|
+
from nat.authentication.interfaces import AuthFlowType
|
29
33
|
from nat.authentication.interfaces import AuthProviderBase
|
30
|
-
from nat.
|
31
|
-
from nat.data_models.authentication import AuthRequest
|
34
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
32
35
|
from nat.data_models.authentication import AuthResult
|
36
|
+
from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
|
33
37
|
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
34
38
|
|
35
39
|
logger = logging.getLogger(__name__)
|
@@ -40,6 +44,7 @@ class OAuth2Endpoints(BaseModel):
|
|
40
44
|
authorization_url: HttpUrl = Field(..., description="OAuth2 authorization endpoint URL")
|
41
45
|
token_url: HttpUrl = Field(..., description="OAuth2 token endpoint URL")
|
42
46
|
registration_url: HttpUrl | None = Field(default=None, description="OAuth2 client registration endpoint URL")
|
47
|
+
scopes: list[str] | None = Field(default=None, description="OAuth2 scopes to be used for the authentication")
|
43
48
|
|
44
49
|
|
45
50
|
class OAuth2Credentials(BaseModel):
|
@@ -60,9 +65,11 @@ class DiscoverOAuth2Endpoints:
|
|
60
65
|
def __init__(self, config: MCPOAuth2ProviderConfig):
|
61
66
|
self.config = config
|
62
67
|
self._cached_endpoints: OAuth2Endpoints | None = None
|
63
|
-
self.
|
68
|
+
self._authenticated_servers: dict[str, AuthResult] = {}
|
64
69
|
|
65
|
-
|
70
|
+
self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
|
71
|
+
|
72
|
+
async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
|
66
73
|
"""
|
67
74
|
Discover OAuth2 endpoints from MCP server.
|
68
75
|
|
@@ -73,21 +80,24 @@ class DiscoverOAuth2Endpoints:
|
|
73
80
|
Returns:
|
74
81
|
A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.
|
75
82
|
"""
|
83
|
+
is_401_retry = response is not None and response.status_code == 401
|
76
84
|
# Fast path: reuse cache when not a 401 retry
|
77
|
-
if
|
85
|
+
if not is_401_retry and self._cached_endpoints is not None:
|
78
86
|
return self._cached_endpoints, False
|
79
87
|
|
80
88
|
issuer: str = str(self.config.server_url) # default to server URL
|
81
89
|
endpoints: OAuth2Endpoints | None = None
|
82
90
|
|
83
91
|
# 1) 401 hint (RFC 9728) if present
|
84
|
-
if
|
85
|
-
|
86
|
-
if
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
92
|
+
if is_401_retry and response:
|
93
|
+
www_authenticate = response.headers.get("WWW-Authenticate")
|
94
|
+
if www_authenticate:
|
95
|
+
hint_url = self._extract_from_www_authenticate_header(www_authenticate)
|
96
|
+
if hint_url:
|
97
|
+
logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
|
98
|
+
issuer_hint = await self._fetch_pr_issuer(hint_url)
|
99
|
+
if issuer_hint:
|
100
|
+
issuer = issuer_hint
|
91
101
|
|
92
102
|
# 2) Try RS protected resource well-known if we still only have default issuer
|
93
103
|
if issuer == str(self.config.server_url):
|
@@ -105,10 +115,7 @@ class DiscoverOAuth2Endpoints:
|
|
105
115
|
if endpoints is None:
|
106
116
|
raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")
|
107
117
|
|
108
|
-
changed = (self._cached_endpoints is None
|
109
|
-
or endpoints.authorization_url != self._cached_endpoints.authorization_url
|
110
|
-
or endpoints.token_url != self._cached_endpoints.token_url
|
111
|
-
or endpoints.registration_url != self._cached_endpoints.registration_url)
|
118
|
+
changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
|
112
119
|
self._cached_endpoints = endpoints
|
113
120
|
logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
|
114
121
|
return self._cached_endpoints, changed
|
@@ -155,10 +162,29 @@ class DiscoverOAuth2Endpoints:
|
|
155
162
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
156
163
|
for url in urls:
|
157
164
|
try:
|
158
|
-
resp = await client.get(url, headers={"Accept": "application/json"})
|
165
|
+
resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
|
159
166
|
if resp.status_code != 200:
|
160
167
|
continue
|
168
|
+
|
169
|
+
# Check content type before attempting JSON parsing
|
170
|
+
content_type = resp.headers.get("content-type", "").lower()
|
171
|
+
if "application/json" not in content_type:
|
172
|
+
logger.info(
|
173
|
+
"Discovery endpoint %s returned non-JSON content type: %s. "
|
174
|
+
"This may indicate the endpoint doesn't support discovery or requires authentication.",
|
175
|
+
url,
|
176
|
+
content_type)
|
177
|
+
# If it's HTML, log a more helpful message
|
178
|
+
if "text/html" in content_type:
|
179
|
+
logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
|
180
|
+
"This often means:")
|
181
|
+
logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
|
182
|
+
logger.info("2. The server requires authentication before providing discovery metadata")
|
183
|
+
logger.info("3. The URL is pointing to a web application instead of an OAuth server")
|
184
|
+
continue
|
185
|
+
|
161
186
|
body = await resp.aread()
|
187
|
+
|
162
188
|
try:
|
163
189
|
meta = OAuthMetadata.model_validate_json(body)
|
164
190
|
except Exception as e:
|
@@ -167,14 +193,18 @@ class DiscoverOAuth2Endpoints:
|
|
167
193
|
if meta.authorization_endpoint and meta.token_endpoint:
|
168
194
|
logger.info("Discovered OAuth2 endpoints from %s", url)
|
169
195
|
# this is bit of a hack to get the scopes supported by the auth server
|
170
|
-
self._last_oauth_scopes = meta.scopes_supported
|
171
196
|
return OAuth2Endpoints(
|
172
197
|
authorization_url=str(meta.authorization_endpoint),
|
173
198
|
token_url=str(meta.token_endpoint),
|
174
199
|
registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
|
200
|
+
scopes=meta.scopes_supported,
|
175
201
|
)
|
176
202
|
except Exception as e:
|
177
203
|
logger.debug("Discovery failed at %s: %s", url, e)
|
204
|
+
|
205
|
+
# If we get here, all discovery URLs failed
|
206
|
+
logger.info("OAuth discovery failed for all attempted URLs.")
|
207
|
+
logger.info("Attempted URLs: %s", urls)
|
178
208
|
return None
|
179
209
|
|
180
210
|
def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
|
@@ -184,17 +214,19 @@ class DiscoverOAuth2Endpoints:
|
|
184
214
|
path = (p.path or "").rstrip("/")
|
185
215
|
urls: list[str] = []
|
186
216
|
if path:
|
187
|
-
|
217
|
+
# this is the specified by the MCP spec
|
218
|
+
urls.append(urljoin(base, f".well-known/oauth-protected-resource{path}"))
|
219
|
+
# this is fallback for backward compatibility
|
220
|
+
urls.append(urljoin(base, f"{path}/.well-known/oauth-authorization-server"))
|
188
221
|
urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
|
189
222
|
if path:
|
190
|
-
|
223
|
+
# this is the specified by the MCP spec
|
224
|
+
urls.append(urljoin(base, f".well-known/openid-configuration{path}"))
|
225
|
+
# this is fallback for backward compatibility
|
226
|
+
urls.append(urljoin(base, f"{path}/.well-known/openid-configuration"))
|
191
227
|
urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
|
192
228
|
return urls
|
193
229
|
|
194
|
-
def scopes_supported(self) -> list[str] | None:
|
195
|
-
"""Get the last OAuth scopes discovered from the AS."""
|
196
|
-
return self._last_oauth_scopes
|
197
|
-
|
198
230
|
|
199
231
|
class DynamicClientRegistration:
|
200
232
|
"""Dynamic client registration utility."""
|
@@ -264,51 +296,54 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
264
296
|
|
265
297
|
# For the OAuth2 flow
|
266
298
|
self._auth_code_provider = None
|
299
|
+
self._flow_handler = MCPAuthenticationFlowHandler()
|
267
300
|
|
268
|
-
|
301
|
+
self._auth_callback = None
|
302
|
+
|
303
|
+
def _set_custom_auth_callback(self,
|
304
|
+
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
305
|
+
Awaitable[AuthenticatedContext]]):
|
306
|
+
"""Set the custom authentication callback."""
|
307
|
+
if not self._auth_callback:
|
308
|
+
logger.info("Using custom authentication callback")
|
309
|
+
self._auth_callback = auth_callback
|
310
|
+
if self._auth_code_provider:
|
311
|
+
self._auth_code_provider._set_custom_auth_callback(self._auth_callback)
|
312
|
+
|
313
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
    """
    Obtain MCP OAuth2 credentials through the NAT framework.

    When ``kwargs`` carries a ``response`` whose status code is 401, the provider first:
      1. Rediscovers endpoints (RFC 9728 + RFC 8414 + OIDC).
      2. Registers a client (RFC 7591).
    Authentication is then delegated to ``_nat_oauth2_authenticate`` in every case.

    Args:
        user_id: Identity passed through to ``_nat_oauth2_authenticate``.
        **kwargs: May contain ``response``, the HTTP response of the previous attempt.

    Returns:
        Whatever ``_nat_oauth2_authenticate`` returns for ``user_id``.
    """
    prior_response = kwargs.get('response')
    got_unauthorized = bool(prior_response) and prior_response.status_code == 401
    if got_unauthorized:
        # The server rejected us: refresh endpoint/client state before authenticating.
        await self._discover_and_register(response=prior_response)
    return await self._nat_oauth2_authenticate(user_id=user_id)
|
329
|
+
|
330
|
+
@property
|
331
|
+
def _effective_scopes(self) -> list[str]:
|
332
|
+
"""Get the effective scopes to be used for the authentication."""
|
333
|
+
return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []
|
334
|
+
|
335
|
+
async def _discover_and_register(self, response: httpx.Response | None = None):
|
301
336
|
"""
|
302
337
|
Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
|
303
338
|
using OIDC client registration.
|
304
339
|
"""
|
305
340
|
# Discover OAuth2 endpoints
|
306
|
-
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(
|
307
|
-
www_authenticate=auth_request.www_authenticate)
|
341
|
+
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
|
308
342
|
if endpoints_changed:
|
309
343
|
logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
|
310
344
|
self._cached_credentials = None # invalidate credentials tied to old AS
|
311
|
-
|
345
|
+
self._auth_code_provider = None
|
346
|
+
effective_scopes = self._effective_scopes
|
312
347
|
|
313
348
|
# Client registration
|
314
349
|
if not self._cached_credentials:
|
@@ -324,21 +359,20 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
324
359
|
self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
|
325
360
|
logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
|
326
361
|
|
327
|
-
def
|
328
|
-
"""
|
329
|
-
Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
|
330
|
-
"""
|
331
|
-
return self.config.scopes or self._discoverer.scopes_supported()
|
332
|
-
|
333
|
-
async def _build_oauth2_delegate(self):
|
334
|
-
"""Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it"""
|
362
|
+
async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
|
363
|
+
"""Perform the OAuth2 flow using MCP-specific authentication flow handler."""
|
335
364
|
from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
|
336
|
-
|
365
|
+
|
366
|
+
if not self._cached_endpoints or not self._cached_credentials:
|
367
|
+
# if discovery is yet to to be done return empty auth result
|
368
|
+
return AuthResult(credentials=[], token_expires_at=None, raw={})
|
337
369
|
|
338
370
|
endpoints = self._cached_endpoints
|
339
371
|
credentials = self._cached_credentials
|
340
372
|
|
373
|
+
# Build the OAuth2 provider if not already built
|
341
374
|
if self._auth_code_provider is None:
|
375
|
+
scopes = self._effective_scopes
|
342
376
|
oauth2_config = OAuth2AuthCodeFlowProviderConfig(
|
343
377
|
client_id=credentials.client_id,
|
344
378
|
client_secret=credentials.client_secret or "",
|
@@ -346,22 +380,15 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
346
380
|
token_url=str(endpoints.token_url),
|
347
381
|
token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
|
348
382
|
redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
|
349
|
-
scopes=
|
383
|
+
scopes=scopes,
|
350
384
|
use_pkce=bool(self.config.use_pkce),
|
351
|
-
|
352
|
-
|
385
|
+
authorization_kwargs={"resource": str(self.config.server_url)})
|
353
386
|
self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
|
354
387
|
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")
|
360
|
-
|
361
|
-
if not self._cached_endpoints or not self._cached_credentials:
|
362
|
-
raise RuntimeError("OAuth2 flow called before discovery/registration")
|
388
|
+
# Use MCP-specific authentication method if available
|
389
|
+
if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
|
390
|
+
self._auth_code_provider._set_custom_auth_callback(self._auth_callback
|
391
|
+
or self._flow_handler.authenticate)
|
363
392
|
|
364
|
-
#
|
365
|
-
await self.
|
366
|
-
# Let the delegate handle per-user cache + refresh
|
367
|
-
return await self._auth_code_provider.authenticate()
|
393
|
+
# Auth code provider is responsible for per-user cache + refresh
|
394
|
+
return await self._auth_code_provider.authenticate(user_id=user_id)
|
@@ -18,7 +18,6 @@ from pydantic import HttpUrl
|
|
18
18
|
from pydantic import model_validator
|
19
19
|
|
20
20
|
from nat.authentication.interfaces import AuthProviderBaseConfig
|
21
|
-
from nat.data_models.authentication import AuthRequest
|
22
21
|
|
23
22
|
|
24
23
|
class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
@@ -51,12 +50,16 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
|
51
50
|
# Advanced options
|
52
51
|
use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")
|
53
52
|
|
54
|
-
|
53
|
+
default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
|
54
|
+
allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
|
55
55
|
|
56
56
|
@model_validator(mode="after")
|
57
57
|
def validate_auth_config(self):
|
58
58
|
"""Validate authentication configuration for MCP-specific options."""
|
59
59
|
|
60
|
+
# if default_user_id is not provided, use the server_url as the default user id
|
61
|
+
if not self.default_user_id:
|
62
|
+
self.default_user_id = str(self.server_url)
|
60
63
|
# Dynamic registration + MCP discovery
|
61
64
|
if self.enable_dynamic_registration and not self.client_id:
|
62
65
|
# Pure dynamic registration - no explicit credentials needed
|
nat/plugins/mcp/client_base.py
CHANGED
@@ -15,13 +15,18 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import asyncio
|
19
|
+
import json
|
18
20
|
import logging
|
19
21
|
from abc import ABC
|
20
22
|
from abc import abstractmethod
|
23
|
+
from collections.abc import AsyncGenerator
|
24
|
+
from collections.abc import Callable
|
21
25
|
from contextlib import AsyncExitStack
|
22
26
|
from contextlib import asynccontextmanager
|
23
|
-
from
|
27
|
+
from datetime import timedelta
|
24
28
|
|
29
|
+
import anyio
|
25
30
|
import httpx
|
26
31
|
|
27
32
|
from mcp import ClientSession
|
@@ -30,10 +35,13 @@ from mcp.client.stdio import StdioServerParameters
|
|
30
35
|
from mcp.client.stdio import stdio_client
|
31
36
|
from mcp.client.streamable_http import streamablehttp_client
|
32
37
|
from mcp.types import TextContent
|
38
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
39
|
+
from nat.authentication.interfaces import AuthFlowType
|
33
40
|
from nat.authentication.interfaces import AuthProviderBase
|
34
|
-
from nat.
|
35
|
-
from nat.
|
41
|
+
from nat.plugins.mcp.exception_handler import convert_to_mcp_error
|
42
|
+
from nat.plugins.mcp.exception_handler import format_mcp_error
|
36
43
|
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
44
|
+
from nat.plugins.mcp.exceptions import MCPError
|
37
45
|
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
38
46
|
from nat.plugins.mcp.utils import model_from_mcp_schema
|
39
47
|
from nat.utils.type_utils import override
|
@@ -47,75 +55,89 @@ class AuthAdapter(httpx.Auth):
|
|
47
55
|
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
48
56
|
"""
|
49
57
|
|
50
|
-
def __init__(self, auth_provider: AuthProviderBase
|
58
|
+
def __init__(self, auth_provider: AuthProviderBase):
    """
    Args:
        auth_provider: NAT auth provider queried for headers on each request.
    """
    # One lock per adapter instance, so separate clients never wait on each other.
    self._lock = anyio.Lock()
    self.auth_provider = auth_provider
|
53
62
|
|
54
63
|
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
55
64
|
"""Add authentication headers to the request using NAT auth provider."""
|
56
|
-
|
57
|
-
if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
|
58
|
-
# Skip auth for non-tool calls
|
59
|
-
yield request
|
60
|
-
return
|
61
|
-
|
62
|
-
try:
|
63
|
-
# Get fresh auth headers from the NAT auth provider
|
64
|
-
auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
|
65
|
-
request.headers.update(auth_headers)
|
66
|
-
except Exception as e:
|
67
|
-
logger.info("Failed to get auth headers: %s", e)
|
68
|
-
# Continue without auth headers if auth fails
|
69
|
-
|
70
|
-
response = yield request
|
71
|
-
|
72
|
-
# Handle 401 responses by retrying with fresh auth
|
73
|
-
if response.status_code == 401:
|
65
|
+
async with self._lock:
|
74
66
|
try:
|
75
|
-
# Get
|
76
|
-
|
67
|
+
# Get auth headers from the NAT auth provider:
|
68
|
+
# 1. If discovery is yet to done this will return None and request will be sent without auth header.
|
69
|
+
# 2. If discovery is done, this will return the auth header from cache if the token is still valid
|
70
|
+
auth_headers = await self._get_auth_headers(request=request, response=None)
|
77
71
|
request.headers.update(auth_headers)
|
78
|
-
yield request # Retry the request
|
79
72
|
except Exception as e:
|
80
|
-
logger.info("Failed to
|
73
|
+
logger.info("Failed to get auth headers: %s", e)
|
74
|
+
# Continue without auth headers if auth fails
|
75
|
+
|
76
|
+
response = yield request
|
77
|
+
|
78
|
+
# Handle 401 responses by retrying with fresh auth
|
79
|
+
if response.status_code == 401:
|
80
|
+
try:
|
81
|
+
# 401 can happen if:
|
82
|
+
# 1. The request was sent without auth header
|
83
|
+
# 2. The auth headers are invalid
|
84
|
+
# 3. The auth headers are expired
|
85
|
+
# 4. The auth headers are revoked
|
86
|
+
# 5. Auth config on the MCP server has changed
|
87
|
+
# In this case we attempt to re-run discovery and authentication
|
88
|
+
auth_headers = await self._get_auth_headers(request=request, response=response)
|
89
|
+
request.headers.update(auth_headers)
|
90
|
+
yield request # Retry the request
|
91
|
+
except Exception as e:
|
92
|
+
logger.info("Failed to refresh auth after 401: %s", e)
|
81
93
|
return
|
82
94
|
|
83
|
-
def
|
84
|
-
"""Check if this is a tool call request based on the request body.
|
95
|
+
def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
|
96
|
+
"""Check if this is a tool call request based on the request body.
|
97
|
+
Return the session id if it exists and a boolean indicating if it is a tool call request
|
98
|
+
"""
|
85
99
|
try:
|
86
100
|
# Check if the request body contains a tool call
|
87
101
|
if request.content:
|
88
|
-
import json
|
89
102
|
body = json.loads(request.content.decode('utf-8'))
|
90
103
|
# Check if it's a JSON-RPC request with method "tools/call"
|
91
104
|
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
92
|
-
|
105
|
+
session_id = body.get("params").get("_meta").get("session_id")
|
106
|
+
return session_id, True
|
93
107
|
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
94
108
|
# If we can't parse the body, assume it's not a tool call
|
95
109
|
pass
|
96
|
-
return False
|
110
|
+
return None, False
|
97
111
|
|
98
|
-
async def _get_auth_headers(self,
|
112
|
+
async def _get_auth_headers(self,
|
113
|
+
request: httpx.Request | None = None,
|
114
|
+
response: httpx.Response | None = None) -> dict[str, str]:
|
99
115
|
"""Get authentication headers from the NAT auth provider."""
|
100
|
-
# Build auth request
|
101
|
-
www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
|
102
|
-
auth_request = AuthRequest(
|
103
|
-
reason=reason,
|
104
|
-
www_authenticate=www_authenticate,
|
105
|
-
)
|
106
116
|
try:
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
117
|
+
session_id = None
|
118
|
+
is_tool_call = False
|
119
|
+
if request:
|
120
|
+
session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
|
121
|
+
|
122
|
+
if is_tool_call:
|
123
|
+
# Tool call requests should use the session id if it exists, default user id can be used if allowed
|
124
|
+
if self.auth_provider.config.allow_default_user_id_for_tool_calls:
|
125
|
+
user_id = session_id or self.auth_provider.config.default_user_id
|
126
|
+
else:
|
127
|
+
user_id = session_id
|
128
|
+
else:
|
129
|
+
# Non-tool call requests should use the session id if it exists and fallback to default user id
|
130
|
+
user_id = session_id or self.auth_provider.config.default_user_id
|
131
|
+
|
132
|
+
auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
|
133
|
+
|
112
134
|
# Check if we have BearerTokenCred
|
113
135
|
from nat.data_models.authentication import BearerTokenCred
|
114
136
|
if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
|
115
137
|
token = auth_result.credentials[0].token.get_secret_value()
|
116
138
|
return {"Authorization": f"Bearer {token}"}
|
117
139
|
else:
|
118
|
-
logger.
|
140
|
+
logger.info("Auth provider did not return BearerTokenCred")
|
119
141
|
return {}
|
120
142
|
except Exception as e:
|
121
143
|
logger.warning("Failed to get auth token: %s", e)
|
@@ -131,7 +153,14 @@ class MCPBaseClient(ABC):
|
|
131
153
|
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
132
154
|
"""
|
133
155
|
|
134
|
-
def __init__(self,
|
156
|
+
def __init__(self,
|
157
|
+
transport: str = 'streamable-http',
|
158
|
+
auth_provider: AuthProviderBase | None = None,
|
159
|
+
tool_call_timeout: timedelta = timedelta(seconds=5),
|
160
|
+
reconnect_enabled: bool = True,
|
161
|
+
reconnect_max_attempts: int = 2,
|
162
|
+
reconnect_initial_backoff: float = 0.5,
|
163
|
+
reconnect_max_backoff: float = 50.0):
|
135
164
|
self._tools = None
|
136
165
|
self._transport = transport.lower()
|
137
166
|
if self._transport not in ['sse', 'stdio', 'streamable-http']:
|
@@ -143,8 +172,18 @@ class MCPBaseClient(ABC):
|
|
143
172
|
self._initial_connection = False
|
144
173
|
|
145
174
|
# Convert auth provider to AuthAdapter
|
175
|
+
self._auth_provider = auth_provider
|
146
176
|
self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
|
147
177
|
|
178
|
+
self._tool_call_timeout = tool_call_timeout
|
179
|
+
|
180
|
+
# Reconnect configuration
|
181
|
+
self._reconnect_enabled = reconnect_enabled
|
182
|
+
self._reconnect_max_attempts = reconnect_max_attempts
|
183
|
+
self._reconnect_initial_backoff = reconnect_initial_backoff
|
184
|
+
self._reconnect_max_backoff = reconnect_max_backoff
|
185
|
+
self._reconnect_lock: asyncio.Lock = asyncio.Lock()
|
186
|
+
|
148
187
|
@property
|
149
188
|
def transport(self) -> str:
|
150
189
|
return self._transport
|
@@ -164,13 +203,14 @@ class MCPBaseClient(ABC):
|
|
164
203
|
return self
|
165
204
|
|
166
205
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
167
|
-
if
|
168
|
-
|
206
|
+
if self._exit_stack:
|
207
|
+
# Close session
|
208
|
+
await self._exit_stack.aclose()
|
209
|
+
self._session = None
|
210
|
+
self._exit_stack = None
|
169
211
|
|
170
|
-
|
171
|
-
|
172
|
-
self._session = None
|
173
|
-
self._exit_stack = None
|
212
|
+
self._connection_established = False
|
213
|
+
self._tools = None
|
174
214
|
|
175
215
|
@property
|
176
216
|
def server_name(self):
|
@@ -181,22 +221,80 @@ class MCPBaseClient(ABC):
|
|
181
221
|
|
182
222
|
@abstractmethod
|
183
223
|
@asynccontextmanager
|
184
|
-
async def connect_to_server(self):
|
224
|
+
async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
|
185
225
|
"""
|
186
226
|
Establish a session with an MCP server within an async context
|
187
227
|
"""
|
188
228
|
yield
|
189
229
|
|
230
|
+
async def _reconnect(self):
|
231
|
+
"""
|
232
|
+
Attempt to reconnect by tearing down and re-establishing the session.
|
233
|
+
"""
|
234
|
+
async with self._reconnect_lock:
|
235
|
+
backoff = self._reconnect_initial_backoff
|
236
|
+
attempt = 0
|
237
|
+
last_error: Exception | None = None
|
238
|
+
|
239
|
+
while attempt in range(0, self._reconnect_max_attempts):
|
240
|
+
attempt += 1
|
241
|
+
try:
|
242
|
+
# Close the existing stack and ClientSession
|
243
|
+
if self._exit_stack:
|
244
|
+
await self._exit_stack.aclose()
|
245
|
+
# Create a fresh stack and session
|
246
|
+
self._exit_stack = AsyncExitStack()
|
247
|
+
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
|
248
|
+
|
249
|
+
self._connection_established = True
|
250
|
+
self._tools = None
|
251
|
+
|
252
|
+
logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
|
253
|
+
return
|
254
|
+
|
255
|
+
except Exception as e:
|
256
|
+
last_error = e
|
257
|
+
logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
|
258
|
+
await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
|
259
|
+
backoff = min(backoff * 2, self._reconnect_max_backoff)
|
260
|
+
|
261
|
+
# All attempts failed
|
262
|
+
self._connection_established = False
|
263
|
+
if last_error:
|
264
|
+
raise last_error
|
265
|
+
|
266
|
+
async def _with_reconnect(self, coro):
|
267
|
+
"""
|
268
|
+
Execute an awaited operation, reconnecting once on errors.
|
269
|
+
"""
|
270
|
+
try:
|
271
|
+
return await coro()
|
272
|
+
except Exception as e:
|
273
|
+
if self._reconnect_enabled:
|
274
|
+
logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
|
275
|
+
try:
|
276
|
+
await self._reconnect()
|
277
|
+
except Exception as reconnect_err:
|
278
|
+
logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
|
279
|
+
raise
|
280
|
+
return await coro()
|
281
|
+
raise
|
282
|
+
|
190
283
|
async def get_tools(self):
|
191
284
|
"""
|
192
285
|
Retrieve a dictionary of all tools served by the MCP server.
|
193
286
|
Uses unauthenticated session for discovery.
|
194
287
|
"""
|
195
288
|
|
196
|
-
|
197
|
-
|
289
|
+
async def _get_tools():
|
290
|
+
session = self._session
|
291
|
+
return await session.list_tools()
|
198
292
|
|
199
|
-
|
293
|
+
try:
|
294
|
+
response = await self._with_reconnect(_get_tools)
|
295
|
+
except Exception as e:
|
296
|
+
logger.warning("Failed to get tools: %s", e)
|
297
|
+
raise
|
200
298
|
|
201
299
|
return {
|
202
300
|
tool.name:
|
@@ -204,7 +302,8 @@ class MCPBaseClient(ABC):
|
|
204
302
|
tool_name=tool.name,
|
205
303
|
tool_description=tool.description,
|
206
304
|
tool_input_schema=tool.inputSchema,
|
207
|
-
parent_client=self
|
305
|
+
parent_client=self,
|
306
|
+
tool_call_timeout=self._tool_call_timeout)
|
208
307
|
for tool in response.tools
|
209
308
|
}
|
210
309
|
|
@@ -233,13 +332,42 @@ class MCPBaseClient(ABC):
|
|
233
332
|
raise MCPToolNotFoundError(tool_name, self.server_name)
|
234
333
|
return tool
|
235
334
|
|
335
|
+
def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
|
336
|
+
"""Set the user authentication callback."""
|
337
|
+
if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
|
338
|
+
self._auth_provider._set_custom_auth_callback(auth_callback)
|
339
|
+
|
236
340
|
@mcp_exception_handler
|
237
|
-
async def
|
341
|
+
async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
|
342
|
+
from mcp.types import CallToolRequest
|
343
|
+
from mcp.types import CallToolRequestParams
|
344
|
+
from mcp.types import CallToolResult
|
345
|
+
from mcp.types import ClientRequest
|
346
|
+
|
238
347
|
if not self._session:
|
239
348
|
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
240
349
|
|
241
|
-
|
242
|
-
|
350
|
+
async def _call_tool_with_meta():
|
351
|
+
params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
|
352
|
+
req = ClientRequest(CallToolRequest(params=params))
|
353
|
+
# We will increase the timeout to 5 minutes if the tool call timeout is less than 5 min and
|
354
|
+
# auth is enabled.
|
355
|
+
if self._auth_provider and self._tool_call_timeout.total_seconds() < 300:
|
356
|
+
timeout = timedelta(seconds=300)
|
357
|
+
else:
|
358
|
+
timeout = self._tool_call_timeout
|
359
|
+
return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
|
360
|
+
|
361
|
+
return await self._with_reconnect(_call_tool_with_meta)
|
362
|
+
|
363
|
+
@mcp_exception_handler
|
364
|
+
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
365
|
+
|
366
|
+
async def _call_tool():
|
367
|
+
session = self._session
|
368
|
+
return await session.call_tool(tool_name, tool_args, read_timeout_seconds=self._tool_call_timeout)
|
369
|
+
|
370
|
+
return await self._with_reconnect(_call_tool)
|
243
371
|
|
244
372
|
|
245
373
|
class MCPSSEClient(MCPBaseClient):
|
@@ -250,8 +378,19 @@ class MCPSSEClient(MCPBaseClient):
|
|
250
378
|
url (str): The url of the MCP server
|
251
379
|
"""
|
252
380
|
|
253
|
-
def __init__(self,
|
254
|
-
|
381
|
+
def __init__(self,
|
382
|
+
url: str,
|
383
|
+
tool_call_timeout: timedelta = timedelta(seconds=5),
|
384
|
+
reconnect_enabled: bool = True,
|
385
|
+
reconnect_max_attempts: int = 2,
|
386
|
+
reconnect_initial_backoff: float = 0.5,
|
387
|
+
reconnect_max_backoff: float = 50.0):
|
388
|
+
super().__init__("sse",
|
389
|
+
tool_call_timeout=tool_call_timeout,
|
390
|
+
reconnect_enabled=reconnect_enabled,
|
391
|
+
reconnect_max_attempts=reconnect_max_attempts,
|
392
|
+
reconnect_initial_backoff=reconnect_initial_backoff,
|
393
|
+
reconnect_max_backoff=reconnect_max_backoff)
|
255
394
|
self._url = url
|
256
395
|
|
257
396
|
@property
|
@@ -286,8 +425,21 @@ class MCPStdioClient(MCPBaseClient):
|
|
286
425
|
env (dict[str, str] | None): Environment variables to set for the process
|
287
426
|
"""
|
288
427
|
|
289
|
-
def __init__(self,
|
290
|
-
|
428
|
+
def __init__(self,
|
429
|
+
command: str,
|
430
|
+
args: list[str] | None = None,
|
431
|
+
env: dict[str, str] | None = None,
|
432
|
+
tool_call_timeout: timedelta = timedelta(seconds=5),
|
433
|
+
reconnect_enabled: bool = True,
|
434
|
+
reconnect_max_attempts: int = 2,
|
435
|
+
reconnect_initial_backoff: float = 0.5,
|
436
|
+
reconnect_max_backoff: float = 50.0):
|
437
|
+
super().__init__("stdio",
|
438
|
+
tool_call_timeout=tool_call_timeout,
|
439
|
+
reconnect_enabled=reconnect_enabled,
|
440
|
+
reconnect_max_attempts=reconnect_max_attempts,
|
441
|
+
reconnect_initial_backoff=reconnect_initial_backoff,
|
442
|
+
reconnect_max_backoff=reconnect_max_backoff)
|
291
443
|
self._command = command
|
292
444
|
self._args = args
|
293
445
|
self._env = env
|
@@ -331,8 +483,21 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
331
483
|
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
332
484
|
"""
|
333
485
|
|
334
|
-
def __init__(self,
|
335
|
-
|
486
|
+
def __init__(self,
|
487
|
+
url: str,
|
488
|
+
auth_provider: AuthProviderBase | None = None,
|
489
|
+
tool_call_timeout: timedelta = timedelta(seconds=5),
|
490
|
+
reconnect_enabled: bool = True,
|
491
|
+
reconnect_max_attempts: int = 2,
|
492
|
+
reconnect_initial_backoff: float = 0.5,
|
493
|
+
reconnect_max_backoff: float = 50.0):
|
494
|
+
super().__init__("streamable-http",
|
495
|
+
auth_provider=auth_provider,
|
496
|
+
tool_call_timeout=tool_call_timeout,
|
497
|
+
reconnect_enabled=reconnect_enabled,
|
498
|
+
reconnect_max_attempts=reconnect_max_attempts,
|
499
|
+
reconnect_initial_backoff=reconnect_initial_backoff,
|
500
|
+
reconnect_max_backoff=reconnect_max_backoff)
|
336
501
|
self._url = url
|
337
502
|
|
338
503
|
@property
|
@@ -371,15 +536,20 @@ class MCPToolClient:
|
|
371
536
|
|
372
537
|
def __init__(self,
|
373
538
|
session: ClientSession,
|
539
|
+
parent_client: "MCPBaseClient",
|
374
540
|
tool_name: str,
|
375
541
|
tool_description: str | None,
|
376
542
|
tool_input_schema: dict | None = None,
|
377
|
-
|
543
|
+
tool_call_timeout: timedelta = timedelta(seconds=5)):
|
378
544
|
self._session = session
|
379
545
|
self._tool_name = tool_name
|
380
546
|
self._tool_description = tool_description
|
381
547
|
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
382
548
|
self._parent_client = parent_client
|
549
|
+
self._tool_call_timeout = tool_call_timeout
|
550
|
+
|
551
|
+
if self._parent_client is None:
|
552
|
+
raise RuntimeError("MCPToolClient initialized without a parent client.")
|
383
553
|
|
384
554
|
@property
|
385
555
|
def name(self):
|
@@ -417,20 +587,49 @@ class MCPToolClient:
|
|
417
587
|
"""
|
418
588
|
if self._session is None:
|
419
589
|
raise RuntimeError("No session available for tool call")
|
420
|
-
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
421
|
-
result = await self._session.call_tool(self._tool_name, tool_args)
|
422
590
|
|
423
|
-
|
591
|
+
# Extract context information
|
592
|
+
session_id = None
|
593
|
+
try:
|
594
|
+
from nat.builder.context import Context as _Ctx
|
595
|
+
|
596
|
+
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
597
|
+
# on first tool call
|
598
|
+
auth_callback = _Ctx.get().user_auth_callback
|
599
|
+
if auth_callback and self._parent_client:
|
600
|
+
# set custom auth callback
|
601
|
+
self._parent_client.set_user_auth_callback(auth_callback)
|
602
|
+
|
603
|
+
# get session id from context, authentication is done per-websocket session for tool calls
|
604
|
+
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
605
|
+
if cookies:
|
606
|
+
session_id = cookies.get("nat-session")
|
607
|
+
except Exception:
|
608
|
+
pass
|
424
609
|
|
425
|
-
|
426
|
-
if
|
427
|
-
|
610
|
+
try:
|
611
|
+
if session_id:
|
612
|
+
logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
|
613
|
+
result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
|
428
614
|
else:
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
615
|
+
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
616
|
+
result = await self._session.call_tool(self._tool_name, tool_args)
|
617
|
+
|
618
|
+
output = []
|
619
|
+
for res in result.content:
|
620
|
+
if isinstance(res, TextContent):
|
621
|
+
output.append(res.text)
|
622
|
+
else:
|
623
|
+
# Log non-text content for now
|
624
|
+
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
|
625
|
+
result_str = "\n".join(output)
|
626
|
+
|
627
|
+
if result.isError:
|
628
|
+
mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
|
629
|
+
raise mcp_error
|
630
|
+
|
631
|
+
except MCPError as e:
|
632
|
+
format_mcp_error(e, include_traceback=False)
|
633
|
+
result_str = "MCPToolClient tool call failed: %s" % e.original_exception
|
435
634
|
|
436
635
|
return result_str
|
nat/plugins/mcp/client_impl.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
import logging
|
17
|
+
from datetime import timedelta
|
17
18
|
from typing import Literal
|
18
19
|
|
19
20
|
from pydantic import BaseModel
|
@@ -55,7 +56,8 @@ class MCPServerConfig(BaseModel):
|
|
55
56
|
env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
|
56
57
|
|
57
58
|
# Authentication configuration
|
58
|
-
auth_provider: AuthenticationRef | None = Field(default=None,
|
59
|
+
auth_provider: str | AuthenticationRef | None = Field(default=None,
|
60
|
+
description="Reference to authentication provider")
|
59
61
|
|
60
62
|
@model_validator(mode="after")
|
61
63
|
def validate_model(self):
|
@@ -90,6 +92,20 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
90
92
|
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
91
93
|
"""
|
92
94
|
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
95
|
+
tool_call_timeout: timedelta = Field(
|
96
|
+
default=timedelta(seconds=60),
|
97
|
+
description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.")
|
98
|
+
reconnect_enabled: bool = Field(
|
99
|
+
default=True,
|
100
|
+
description="Whether to enable reconnecting to the MCP server if the connection is lost. \
|
101
|
+
Defaults to True.")
|
102
|
+
reconnect_max_attempts: int = Field(default=2,
|
103
|
+
ge=0,
|
104
|
+
description="Maximum number of reconnect attempts. Defaults to 2.")
|
105
|
+
reconnect_initial_backoff: float = Field(
|
106
|
+
default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.")
|
107
|
+
reconnect_max_backoff: float = Field(
|
108
|
+
default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.")
|
93
109
|
tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
|
94
110
|
default=None,
|
95
111
|
description="""Optional tool name overrides and description changes.
|
@@ -102,6 +118,13 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
102
118
|
description: "Multiply two numbers" # alias defaults to original name
|
103
119
|
""")
|
104
120
|
|
121
|
+
@model_validator(mode="after")
|
122
|
+
def _validate_reconnect_backoff(self) -> "MCPClientConfig":
|
123
|
+
"""Validate reconnect backoff values."""
|
124
|
+
if self.reconnect_max_backoff < self.reconnect_initial_backoff:
|
125
|
+
raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
|
126
|
+
return self
|
127
|
+
|
105
128
|
|
106
129
|
@register_function_group(config_type=MCPClientConfig)
|
107
130
|
async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
@@ -126,11 +149,29 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
126
149
|
if config.server.transport == "stdio":
|
127
150
|
if not config.server.command:
|
128
151
|
raise ValueError("command is required for stdio transport")
|
129
|
-
client = MCPStdioClient(config.server.command,
|
152
|
+
client = MCPStdioClient(config.server.command,
|
153
|
+
config.server.args,
|
154
|
+
config.server.env,
|
155
|
+
config.tool_call_timeout,
|
156
|
+
reconnect_enabled=config.reconnect_enabled,
|
157
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
158
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
159
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
130
160
|
elif config.server.transport == "sse":
|
131
|
-
client = MCPSSEClient(str(config.server.url)
|
161
|
+
client = MCPSSEClient(str(config.server.url),
|
162
|
+
tool_call_timeout=config.tool_call_timeout,
|
163
|
+
reconnect_enabled=config.reconnect_enabled,
|
164
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
165
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
166
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
132
167
|
elif config.server.transport == "streamable-http":
|
133
|
-
client = MCPStreamableHTTPClient(str(config.server.url),
|
168
|
+
client = MCPStreamableHTTPClient(str(config.server.url),
|
169
|
+
auth_provider=auth_provider,
|
170
|
+
tool_call_timeout=config.tool_call_timeout,
|
171
|
+
reconnect_enabled=config.reconnect_enabled,
|
172
|
+
reconnect_max_attempts=config.reconnect_max_attempts,
|
173
|
+
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
174
|
+
reconnect_max_backoff=config.reconnect_max_backoff)
|
134
175
|
else:
|
135
176
|
raise ValueError(f"Unsupported transport: {config.server.transport}")
|
136
177
|
|
{nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-mcp
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0a20250928
|
4
4
|
Summary: Subpackage for MCP client integration in NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents,mcp
|
6
6
|
Classifier: Programming Language :: Python
|
@@ -9,7 +9,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.13
|
10
10
|
Requires-Python: <3.14,>=3.11
|
11
11
|
Description-Content-Type: text/markdown
|
12
|
-
Requires-Dist: nvidia-nat==v1.3.
|
12
|
+
Requires-Dist: nvidia-nat==v1.3.0a20250928
|
13
13
|
Requires-Dist: mcp~=1.14
|
14
14
|
|
15
15
|
<!--
|
@@ -1,18 +1,19 @@
|
|
1
1
|
nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
|
2
2
|
nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
3
|
-
nat/plugins/mcp/client_base.py,sha256=
|
4
|
-
nat/plugins/mcp/client_impl.py,sha256=
|
3
|
+
nat/plugins/mcp/client_base.py,sha256=x6NHs3X2alyB6Wo4pjWZrGyuU7lLqGC_cLO9fuL4Zgw,25194
|
4
|
+
nat/plugins/mcp/client_impl.py,sha256=M1gTMlp3RLhFaAHOvwkk38boFy05MixV_glrIEcMjvo,10759
|
5
5
|
nat/plugins/mcp/exception_handler.py,sha256=JdPdZG1NgWpdRnIz7JTGHiJASS5wot9nJiD3SRWV4Kw,7649
|
6
6
|
nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
|
7
7
|
nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
|
8
8
|
nat/plugins/mcp/tool.py,sha256=v3MFsiaLJy8Ourcfqa6ohtAE2Nn-vqpC6Q6gsCdJ28Q,6165
|
9
9
|
nat/plugins/mcp/utils.py,sha256=3fuzYpC14wrfMOTOGvY2KHWcxZvBWqrxdDZD17lhmC8,4055
|
10
10
|
nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
11
|
-
nat/plugins/mcp/auth/
|
12
|
-
nat/plugins/mcp/auth/
|
11
|
+
nat/plugins/mcp/auth/auth_flow_handler.py,sha256=eRqRS2t-YzJ3kEFaG0PEC8DzctYzaJzr9XLZGkvuxq0,6018
|
12
|
+
nat/plugins/mcp/auth/auth_provider.py,sha256=5TTaPlIXMgwrE4YcZ_HO9-GNBBFaDTdUKAa5fuALayI,19142
|
13
|
+
nat/plugins/mcp/auth/auth_provider_config.py,sha256=vhU47Vcp_30M8tWu0FumbJ6pdUnFbBZm-ABdNlup__U,3821
|
13
14
|
nat/plugins/mcp/auth/register.py,sha256=yzphsn1I4a5G39_IacbuX0ZQqGM8fevvTUM_B94UXKE,1211
|
14
|
-
nvidia_nat_mcp-1.3.
|
15
|
-
nvidia_nat_mcp-1.3.
|
16
|
-
nvidia_nat_mcp-1.3.
|
17
|
-
nvidia_nat_mcp-1.3.
|
18
|
-
nvidia_nat_mcp-1.3.
|
15
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/METADATA,sha256=VnuxYtll39Hir1_hDQvhqRWPM00M6A9JD3NCd7JISaw,1997
|
16
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
17
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
|
18
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
19
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/RECORD,,
|
File without changes
|
{nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/entry_points.txt
RENAMED
File without changes
|
{nvidia_nat_mcp-1.3.0a20250925.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/top_level.txt
RENAMED
File without changes
|