nvidia-nat-mcp 1.3.0a20250926__py3-none-any.whl → 1.3.0a20250928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/mcp/auth/auth_flow_handler.py +144 -0
- nat/plugins/mcp/auth/auth_provider.py +106 -79
- nat/plugins/mcp/auth/auth_provider_config.py +5 -2
- nat/plugins/mcp/client_base.py +117 -45
- nat/plugins/mcp/client_impl.py +4 -2
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/METADATA +2 -2
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/RECORD +10 -9
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import asyncio
|
17
|
+
import logging
|
18
|
+
import secrets
|
19
|
+
import webbrowser
|
20
|
+
|
21
|
+
import pkce
|
22
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
23
|
+
from fastapi import FastAPI
|
24
|
+
|
25
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
26
|
+
from nat.data_models.authentication import AuthenticatedContext
|
27
|
+
from nat.data_models.authentication import AuthFlowType
|
28
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
29
|
+
from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
|
30
|
+
from nat.front_ends.console.authentication_flow_handler import _FlowState
|
31
|
+
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
32
|
+
|
33
|
+
logger = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
|
37
|
+
"""
|
38
|
+
Authentication helper for MCP environments.
|
39
|
+
|
40
|
+
This handler is specifically designed for MCP tool discovery scenarios where
|
41
|
+
authentication needs to happen before the default auth_callback is available
|
42
|
+
in the Context. It handles OAuth2 authorization code flow during MCP client
|
43
|
+
startup and tool discovery phases.
|
44
|
+
|
45
|
+
Key differences from console handler:
|
46
|
+
- Only supports OAuth2 Authorization Code flow (no HTTP Basic)
|
47
|
+
- Optimized for MCP tool discovery workflows
|
48
|
+
- Designed for single-use authentication during startup
|
49
|
+
"""
|
50
|
+
|
51
|
+
def __init__(self):
|
52
|
+
super().__init__()
|
53
|
+
self._server_controller: _FastApiFrontEndController | None = None
|
54
|
+
self._redirect_app: FastAPI | None = None
|
55
|
+
self._server_lock = asyncio.Lock()
|
56
|
+
self._oauth_client: AsyncOAuth2Client | None = None
|
57
|
+
|
58
|
+
async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
|
59
|
+
"""
|
60
|
+
Handle the OAuth2 authorization code flow for MCP environments.
|
61
|
+
|
62
|
+
Args:
|
63
|
+
config: OAuth2 configuration for MCP server
|
64
|
+
method: Authentication method (only OAUTH2_AUTHORIZATION_CODE supported)
|
65
|
+
|
66
|
+
Returns:
|
67
|
+
AuthenticatedContext with Bearer token for MCP server access
|
68
|
+
|
69
|
+
Raises:
|
70
|
+
ValueError: If config is invalid for MCP use case
|
71
|
+
NotImplementedError: If method is not OAuth2 Authorization Code
|
72
|
+
"""
|
73
|
+
logger.info("Starting MCP authentication flow")
|
74
|
+
|
75
|
+
if method == AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
|
76
|
+
if not isinstance(config, OAuth2AuthCodeFlowProviderConfig):
|
77
|
+
raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")
|
78
|
+
|
79
|
+
# MCP-specific validation
|
80
|
+
if not config.redirect_uri:
|
81
|
+
raise ValueError("MCP authentication requires redirect_uri to be configured")
|
82
|
+
|
83
|
+
logger.info("MCP authentication configured for server: %s", getattr(config, 'server_url', 'unknown'))
|
84
|
+
return await self._handle_oauth2_auth_code_flow(config)
|
85
|
+
|
86
|
+
raise NotImplementedError(f'Auth method "{method}" not supported for MCP environments')
|
87
|
+
|
88
|
+
async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
|
89
|
+
logger.info("Starting MCP OAuth2 authorization code flow")
|
90
|
+
|
91
|
+
state = secrets.token_urlsafe(16)
|
92
|
+
flow_state = _FlowState()
|
93
|
+
client = self.construct_oauth_client(cfg)
|
94
|
+
|
95
|
+
flow_state.token_url = cfg.token_url
|
96
|
+
flow_state.use_pkce = cfg.use_pkce
|
97
|
+
|
98
|
+
# PKCE bits
|
99
|
+
if cfg.use_pkce:
|
100
|
+
verifier, challenge = pkce.generate_pkce_pair()
|
101
|
+
flow_state.verifier = verifier
|
102
|
+
flow_state.challenge = challenge
|
103
|
+
logger.debug("PKCE enabled for MCP authentication")
|
104
|
+
|
105
|
+
auth_url, _ = client.create_authorization_url(
|
106
|
+
cfg.authorization_url,
|
107
|
+
state=state,
|
108
|
+
code_verifier=flow_state.verifier if cfg.use_pkce else None,
|
109
|
+
code_challenge=flow_state.challenge if cfg.use_pkce else None,
|
110
|
+
**(cfg.authorization_kwargs or {})
|
111
|
+
)
|
112
|
+
|
113
|
+
async with self._server_lock:
|
114
|
+
if self._redirect_app is None:
|
115
|
+
self._redirect_app = await self._build_redirect_app()
|
116
|
+
|
117
|
+
await self._start_redirect_server()
|
118
|
+
self._flows[state] = flow_state
|
119
|
+
|
120
|
+
logger.info("MCP authentication: Your browser has been opened for authentication.")
|
121
|
+
logger.info("This will authenticate you with the MCP server for tool discovery.")
|
122
|
+
webbrowser.open(auth_url)
|
123
|
+
|
124
|
+
# Use default timeout for MCP tool discovery
|
125
|
+
timeout = 300
|
126
|
+
|
127
|
+
try:
|
128
|
+
token = await asyncio.wait_for(flow_state.future, timeout=timeout)
|
129
|
+
logger.info("MCP authentication successful, token obtained")
|
130
|
+
except asyncio.TimeoutError as exc:
|
131
|
+
logger.error("MCP authentication timed out")
|
132
|
+
raise RuntimeError(f"MCP authentication timed out ({timeout} seconds). Please try again.") from exc
|
133
|
+
finally:
|
134
|
+
async with self._server_lock:
|
135
|
+
self._flows.pop(state, None)
|
136
|
+
await self._stop_redirect_server()
|
137
|
+
|
138
|
+
return AuthenticatedContext(
|
139
|
+
headers={"Authorization": f"Bearer {token['access_token']}"},
|
140
|
+
metadata={
|
141
|
+
"expires_at": token.get("expires_at"),
|
142
|
+
"raw_token": token,
|
143
|
+
},
|
144
|
+
)
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
import logging
|
17
|
+
from collections.abc import Awaitable
|
18
|
+
from typing import Callable
|
17
19
|
from urllib.parse import urljoin
|
18
20
|
from urllib.parse import urlparse
|
19
21
|
|
@@ -26,10 +28,12 @@ from mcp.shared.auth import OAuthClientInformationFull
|
|
26
28
|
from mcp.shared.auth import OAuthClientMetadata
|
27
29
|
from mcp.shared.auth import OAuthMetadata
|
28
30
|
from mcp.shared.auth import ProtectedResourceMetadata
|
31
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
32
|
+
from nat.authentication.interfaces import AuthFlowType
|
29
33
|
from nat.authentication.interfaces import AuthProviderBase
|
30
|
-
from nat.
|
31
|
-
from nat.data_models.authentication import AuthRequest
|
34
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
32
35
|
from nat.data_models.authentication import AuthResult
|
36
|
+
from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
|
33
37
|
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
34
38
|
|
35
39
|
logger = logging.getLogger(__name__)
|
@@ -40,6 +44,7 @@ class OAuth2Endpoints(BaseModel):
|
|
40
44
|
authorization_url: HttpUrl = Field(..., description="OAuth2 authorization endpoint URL")
|
41
45
|
token_url: HttpUrl = Field(..., description="OAuth2 token endpoint URL")
|
42
46
|
registration_url: HttpUrl | None = Field(default=None, description="OAuth2 client registration endpoint URL")
|
47
|
+
scopes: list[str] | None = Field(default=None, description="OAuth2 scopes to be used for the authentication")
|
43
48
|
|
44
49
|
|
45
50
|
class OAuth2Credentials(BaseModel):
|
@@ -60,9 +65,11 @@ class DiscoverOAuth2Endpoints:
|
|
60
65
|
def __init__(self, config: MCPOAuth2ProviderConfig):
|
61
66
|
self.config = config
|
62
67
|
self._cached_endpoints: OAuth2Endpoints | None = None
|
63
|
-
self.
|
68
|
+
self._authenticated_servers: dict[str, AuthResult] = {}
|
64
69
|
|
65
|
-
|
70
|
+
self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
|
71
|
+
|
72
|
+
async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
|
66
73
|
"""
|
67
74
|
Discover OAuth2 endpoints from MCP server.
|
68
75
|
|
@@ -73,21 +80,24 @@ class DiscoverOAuth2Endpoints:
|
|
73
80
|
Returns:
|
74
81
|
A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.
|
75
82
|
"""
|
83
|
+
is_401_retry = response is not None and response.status_code == 401
|
76
84
|
# Fast path: reuse cache when not a 401 retry
|
77
|
-
if
|
85
|
+
if not is_401_retry and self._cached_endpoints is not None:
|
78
86
|
return self._cached_endpoints, False
|
79
87
|
|
80
88
|
issuer: str = str(self.config.server_url) # default to server URL
|
81
89
|
endpoints: OAuth2Endpoints | None = None
|
82
90
|
|
83
91
|
# 1) 401 hint (RFC 9728) if present
|
84
|
-
if
|
85
|
-
|
86
|
-
if
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
92
|
+
if is_401_retry and response:
|
93
|
+
www_authenticate = response.headers.get("WWW-Authenticate")
|
94
|
+
if www_authenticate:
|
95
|
+
hint_url = self._extract_from_www_authenticate_header(www_authenticate)
|
96
|
+
if hint_url:
|
97
|
+
logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
|
98
|
+
issuer_hint = await self._fetch_pr_issuer(hint_url)
|
99
|
+
if issuer_hint:
|
100
|
+
issuer = issuer_hint
|
91
101
|
|
92
102
|
# 2) Try RS protected resource well-known if we still only have default issuer
|
93
103
|
if issuer == str(self.config.server_url):
|
@@ -105,10 +115,7 @@ class DiscoverOAuth2Endpoints:
|
|
105
115
|
if endpoints is None:
|
106
116
|
raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")
|
107
117
|
|
108
|
-
changed = (self._cached_endpoints is None
|
109
|
-
or endpoints.authorization_url != self._cached_endpoints.authorization_url
|
110
|
-
or endpoints.token_url != self._cached_endpoints.token_url
|
111
|
-
or endpoints.registration_url != self._cached_endpoints.registration_url)
|
118
|
+
changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
|
112
119
|
self._cached_endpoints = endpoints
|
113
120
|
logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
|
114
121
|
return self._cached_endpoints, changed
|
@@ -155,10 +162,29 @@ class DiscoverOAuth2Endpoints:
|
|
155
162
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
156
163
|
for url in urls:
|
157
164
|
try:
|
158
|
-
resp = await client.get(url, headers={"Accept": "application/json"})
|
165
|
+
resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
|
159
166
|
if resp.status_code != 200:
|
160
167
|
continue
|
168
|
+
|
169
|
+
# Check content type before attempting JSON parsing
|
170
|
+
content_type = resp.headers.get("content-type", "").lower()
|
171
|
+
if "application/json" not in content_type:
|
172
|
+
logger.info(
|
173
|
+
"Discovery endpoint %s returned non-JSON content type: %s. "
|
174
|
+
"This may indicate the endpoint doesn't support discovery or requires authentication.",
|
175
|
+
url,
|
176
|
+
content_type)
|
177
|
+
# If it's HTML, log a more helpful message
|
178
|
+
if "text/html" in content_type:
|
179
|
+
logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
|
180
|
+
"This often means:")
|
181
|
+
logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
|
182
|
+
logger.info("2. The server requires authentication before providing discovery metadata")
|
183
|
+
logger.info("3. The URL is pointing to a web application instead of an OAuth server")
|
184
|
+
continue
|
185
|
+
|
161
186
|
body = await resp.aread()
|
187
|
+
|
162
188
|
try:
|
163
189
|
meta = OAuthMetadata.model_validate_json(body)
|
164
190
|
except Exception as e:
|
@@ -167,14 +193,18 @@ class DiscoverOAuth2Endpoints:
|
|
167
193
|
if meta.authorization_endpoint and meta.token_endpoint:
|
168
194
|
logger.info("Discovered OAuth2 endpoints from %s", url)
|
169
195
|
# this is a bit of a hack to get the scopes supported by the auth server
|
170
|
-
self._last_oauth_scopes = meta.scopes_supported
|
171
196
|
return OAuth2Endpoints(
|
172
197
|
authorization_url=str(meta.authorization_endpoint),
|
173
198
|
token_url=str(meta.token_endpoint),
|
174
199
|
registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
|
200
|
+
scopes=meta.scopes_supported,
|
175
201
|
)
|
176
202
|
except Exception as e:
|
177
203
|
logger.debug("Discovery failed at %s: %s", url, e)
|
204
|
+
|
205
|
+
# If we get here, all discovery URLs failed
|
206
|
+
logger.info("OAuth discovery failed for all attempted URLs.")
|
207
|
+
logger.info("Attempted URLs: %s", urls)
|
178
208
|
return None
|
179
209
|
|
180
210
|
def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
|
@@ -184,17 +214,19 @@ class DiscoverOAuth2Endpoints:
|
|
184
214
|
path = (p.path or "").rstrip("/")
|
185
215
|
urls: list[str] = []
|
186
216
|
if path:
|
187
|
-
|
217
|
+
# this is what is specified by the MCP spec
|
218
|
+
urls.append(urljoin(base, f".well-known/oauth-protected-resource{path}"))
|
219
|
+
# this is a fallback for backward compatibility
|
220
|
+
urls.append(urljoin(base, f"{path}/.well-known/oauth-authorization-server"))
|
188
221
|
urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
|
189
222
|
if path:
|
190
|
-
|
223
|
+
# this is what is specified by the MCP spec
|
224
|
+
urls.append(urljoin(base, f".well-known/openid-configuration{path}"))
|
225
|
+
# this is a fallback for backward compatibility
|
226
|
+
urls.append(urljoin(base, f"{path}/.well-known/openid-configuration"))
|
191
227
|
urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
|
192
228
|
return urls
|
193
229
|
|
194
|
-
def scopes_supported(self) -> list[str] | None:
|
195
|
-
"""Get the last OAuth scopes discovered from the AS."""
|
196
|
-
return self._last_oauth_scopes
|
197
|
-
|
198
230
|
|
199
231
|
class DynamicClientRegistration:
|
200
232
|
"""Dynamic client registration utility."""
|
@@ -264,51 +296,54 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
264
296
|
|
265
297
|
# For the OAuth2 flow
|
266
298
|
self._auth_code_provider = None
|
299
|
+
self._flow_handler = MCPAuthenticationFlowHandler()
|
267
300
|
|
268
|
-
|
301
|
+
self._auth_callback = None
|
302
|
+
|
303
|
+
def _set_custom_auth_callback(self,
|
304
|
+
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
305
|
+
Awaitable[AuthenticatedContext]]):
|
306
|
+
"""Set the custom authentication callback."""
|
307
|
+
if not self._auth_callback:
|
308
|
+
logger.info("Using custom authentication callback")
|
309
|
+
self._auth_callback = auth_callback
|
310
|
+
if self._auth_code_provider:
|
311
|
+
self._auth_code_provider._set_custom_auth_callback(self._auth_callback)
|
312
|
+
|
313
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
269
314
|
"""
|
270
315
|
Authenticate using MCP OAuth2 flow via NAT framework.
|
316
|
+
|
317
|
+
If response is provided in kwargs (typically from a 401), performs:
|
271
318
|
1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
|
272
319
|
2. Client registration (RFC7591)
|
273
|
-
3.
|
320
|
+
3. Authentication
|
321
|
+
|
322
|
+
Otherwise, performs standard authentication flow.
|
274
323
|
"""
|
275
|
-
|
276
|
-
if
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
# force fresh delegate (clears in-mem token cache)
|
288
|
-
self._auth_code_provider = None
|
289
|
-
# preserve other fields, just normalize reason & inject user_id
|
290
|
-
auth_request = auth_request.model_copy(update={
|
291
|
-
"reason": AuthReason.NORMAL, "user_id": user_id, "www_authenticate": None
|
292
|
-
})
|
293
|
-
# back-compat: propagate user_id if provided but not set in the request
|
294
|
-
elif user_id is not None and auth_request.user_id is None:
|
295
|
-
auth_request = auth_request.model_copy(update={"user_id": user_id})
|
296
|
-
|
297
|
-
# Perform the OAuth2 flow without lock
|
298
|
-
return await self._perform_oauth2_flow(auth_request=auth_request)
|
299
|
-
|
300
|
-
async def _discover_and_register(self, auth_request: AuthRequest):
|
324
|
+
response = kwargs.get('response')
|
325
|
+
if response and response.status_code == 401:
|
326
|
+
await self._discover_and_register(response=response)
|
327
|
+
|
328
|
+
return await self._nat_oauth2_authenticate(user_id=user_id)
|
329
|
+
|
330
|
+
@property
|
331
|
+
def _effective_scopes(self) -> list[str]:
|
332
|
+
"""Get the effective scopes to be used for the authentication."""
|
333
|
+
return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []
|
334
|
+
|
335
|
+
async def _discover_and_register(self, response: httpx.Response | None = None):
|
301
336
|
"""
|
302
337
|
Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
|
303
338
|
using OIDC client registration.
|
304
339
|
"""
|
305
340
|
# Discover OAuth2 endpoints
|
306
|
-
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(
|
307
|
-
www_authenticate=auth_request.www_authenticate)
|
341
|
+
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
|
308
342
|
if endpoints_changed:
|
309
343
|
logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
|
310
344
|
self._cached_credentials = None # invalidate credentials tied to old AS
|
311
|
-
|
345
|
+
self._auth_code_provider = None
|
346
|
+
effective_scopes = self._effective_scopes
|
312
347
|
|
313
348
|
# Client registration
|
314
349
|
if not self._cached_credentials:
|
@@ -324,21 +359,20 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
324
359
|
self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
|
325
360
|
logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
|
326
361
|
|
327
|
-
def
|
328
|
-
"""
|
329
|
-
Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
|
330
|
-
"""
|
331
|
-
return self.config.scopes or self._discoverer.scopes_supported()
|
332
|
-
|
333
|
-
async def _build_oauth2_delegate(self):
|
334
|
-
"""Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it"""
|
362
|
+
async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
|
363
|
+
"""Perform the OAuth2 flow using MCP-specific authentication flow handler."""
|
335
364
|
from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
|
336
|
-
|
365
|
+
|
366
|
+
if not self._cached_endpoints or not self._cached_credentials:
|
367
|
+
# if discovery is yet to be done, return an empty auth result
|
368
|
+
return AuthResult(credentials=[], token_expires_at=None, raw={})
|
337
369
|
|
338
370
|
endpoints = self._cached_endpoints
|
339
371
|
credentials = self._cached_credentials
|
340
372
|
|
373
|
+
# Build the OAuth2 provider if not already built
|
341
374
|
if self._auth_code_provider is None:
|
375
|
+
scopes = self._effective_scopes
|
342
376
|
oauth2_config = OAuth2AuthCodeFlowProviderConfig(
|
343
377
|
client_id=credentials.client_id,
|
344
378
|
client_secret=credentials.client_secret or "",
|
@@ -346,22 +380,15 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
346
380
|
token_url=str(endpoints.token_url),
|
347
381
|
token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
|
348
382
|
redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
|
349
|
-
scopes=
|
383
|
+
scopes=scopes,
|
350
384
|
use_pkce=bool(self.config.use_pkce),
|
351
|
-
|
352
|
-
|
385
|
+
authorization_kwargs={"resource": str(self.config.server_url)})
|
353
386
|
self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
|
354
387
|
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")
|
360
|
-
|
361
|
-
if not self._cached_endpoints or not self._cached_credentials:
|
362
|
-
raise RuntimeError("OAuth2 flow called before discovery/registration")
|
388
|
+
# Use MCP-specific authentication method if available
|
389
|
+
if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
|
390
|
+
self._auth_code_provider._set_custom_auth_callback(self._auth_callback
|
391
|
+
or self._flow_handler.authenticate)
|
363
392
|
|
364
|
-
#
|
365
|
-
await self.
|
366
|
-
# Let the delegate handle per-user cache + refresh
|
367
|
-
return await self._auth_code_provider.authenticate()
|
393
|
+
# Auth code provider is responsible for per-user cache + refresh
|
394
|
+
return await self._auth_code_provider.authenticate(user_id=user_id)
|
@@ -18,7 +18,6 @@ from pydantic import HttpUrl
|
|
18
18
|
from pydantic import model_validator
|
19
19
|
|
20
20
|
from nat.authentication.interfaces import AuthProviderBaseConfig
|
21
|
-
from nat.data_models.authentication import AuthRequest
|
22
21
|
|
23
22
|
|
24
23
|
class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
@@ -51,12 +50,16 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
|
51
50
|
# Advanced options
|
52
51
|
use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")
|
53
52
|
|
54
|
-
|
53
|
+
default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
|
54
|
+
allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
|
55
55
|
|
56
56
|
@model_validator(mode="after")
|
57
57
|
def validate_auth_config(self):
|
58
58
|
"""Validate authentication configuration for MCP-specific options."""
|
59
59
|
|
60
|
+
# if default_user_id is not provided, use the server_url as the default user id
|
61
|
+
if not self.default_user_id:
|
62
|
+
self.default_user_id = str(self.server_url)
|
60
63
|
# Dynamic registration + MCP discovery
|
61
64
|
if self.enable_dynamic_registration and not self.client_id:
|
62
65
|
# Pure dynamic registration - no explicit credentials needed
|
nat/plugins/mcp/client_base.py
CHANGED
@@ -21,10 +21,12 @@ import logging
|
|
21
21
|
from abc import ABC
|
22
22
|
from abc import abstractmethod
|
23
23
|
from collections.abc import AsyncGenerator
|
24
|
+
from collections.abc import Callable
|
24
25
|
from contextlib import AsyncExitStack
|
25
26
|
from contextlib import asynccontextmanager
|
26
27
|
from datetime import timedelta
|
27
28
|
|
29
|
+
import anyio
|
28
30
|
import httpx
|
29
31
|
|
30
32
|
from mcp import ClientSession
|
@@ -33,9 +35,9 @@ from mcp.client.stdio import StdioServerParameters
|
|
33
35
|
from mcp.client.stdio import stdio_client
|
34
36
|
from mcp.client.streamable_http import streamablehttp_client
|
35
37
|
from mcp.types import TextContent
|
38
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
39
|
+
from nat.authentication.interfaces import AuthFlowType
|
36
40
|
from nat.authentication.interfaces import AuthProviderBase
|
37
|
-
from nat.data_models.authentication import AuthReason
|
38
|
-
from nat.data_models.authentication import AuthRequest
|
39
41
|
from nat.plugins.mcp.exception_handler import convert_to_mcp_error
|
40
42
|
from nat.plugins.mcp.exception_handler import format_mcp_error
|
41
43
|
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
@@ -53,74 +55,89 @@ class AuthAdapter(httpx.Auth):
|
|
53
55
|
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
54
56
|
"""
|
55
57
|
|
56
|
-
def __init__(self, auth_provider: AuthProviderBase
|
58
|
+
def __init__(self, auth_provider: AuthProviderBase):
|
57
59
|
self.auth_provider = auth_provider
|
58
|
-
|
60
|
+
# each adapter instance has its own lock to avoid unnecessary delays for multiple clients
|
61
|
+
self._lock = anyio.Lock()
|
59
62
|
|
60
63
|
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
61
64
|
"""Add authentication headers to the request using NAT auth provider."""
|
62
|
-
|
63
|
-
if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
|
64
|
-
# Skip auth for non-tool calls
|
65
|
-
yield request
|
66
|
-
return
|
67
|
-
|
68
|
-
try:
|
69
|
-
# Get fresh auth headers from the NAT auth provider
|
70
|
-
auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
|
71
|
-
request.headers.update(auth_headers)
|
72
|
-
except Exception as e:
|
73
|
-
logger.info("Failed to get auth headers: %s", e)
|
74
|
-
# Continue without auth headers if auth fails
|
75
|
-
|
76
|
-
response = yield request
|
77
|
-
|
78
|
-
# Handle 401 responses by retrying with fresh auth
|
79
|
-
if response.status_code == 401:
|
65
|
+
async with self._lock:
|
80
66
|
try:
|
81
|
-
# Get
|
82
|
-
|
67
|
+
# Get auth headers from the NAT auth provider:
|
68
|
+
# 1. If discovery is yet to be done, this will return None and the request will be sent without an auth header.
|
69
|
+
# 2. If discovery is done, this will return the auth header from cache if the token is still valid
|
70
|
+
auth_headers = await self._get_auth_headers(request=request, response=None)
|
83
71
|
request.headers.update(auth_headers)
|
84
|
-
yield request # Retry the request
|
85
72
|
except Exception as e:
|
86
|
-
logger.info("Failed to
|
73
|
+
logger.info("Failed to get auth headers: %s", e)
|
74
|
+
# Continue without auth headers if auth fails
|
75
|
+
|
76
|
+
response = yield request
|
77
|
+
|
78
|
+
# Handle 401 responses by retrying with fresh auth
|
79
|
+
if response.status_code == 401:
|
80
|
+
try:
|
81
|
+
# 401 can happen if:
|
82
|
+
# 1. The request was sent without auth header
|
83
|
+
# 2. The auth headers are invalid
|
84
|
+
# 3. The auth headers are expired
|
85
|
+
# 4. The auth headers are revoked
|
86
|
+
# 5. Auth config on the MCP server has changed
|
87
|
+
# In this case we attempt to re-run discovery and authentication
|
88
|
+
auth_headers = await self._get_auth_headers(request=request, response=response)
|
89
|
+
request.headers.update(auth_headers)
|
90
|
+
yield request # Retry the request
|
91
|
+
except Exception as e:
|
92
|
+
logger.info("Failed to refresh auth after 401: %s", e)
|
87
93
|
return
|
88
94
|
|
89
|
-
def
|
90
|
-
"""Check if this is a tool call request based on the request body.
|
95
|
+
def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
|
96
|
+
"""Check if this is a tool call request based on the request body.
|
97
|
+
Return the session id if it exists and a boolean indicating if it is a tool call request
|
98
|
+
"""
|
91
99
|
try:
|
92
100
|
# Check if the request body contains a tool call
|
93
101
|
if request.content:
|
94
102
|
body = json.loads(request.content.decode('utf-8'))
|
95
103
|
# Check if it's a JSON-RPC request with method "tools/call"
|
96
104
|
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
97
|
-
|
105
|
+
session_id = body.get("params").get("_meta").get("session_id")
|
106
|
+
return session_id, True
|
98
107
|
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
99
108
|
# If we can't parse the body, assume it's not a tool call
|
100
109
|
pass
|
101
|
-
return False
|
110
|
+
return None, False
|
102
111
|
|
103
|
-
async def _get_auth_headers(self,
|
112
|
+
async def _get_auth_headers(self,
|
113
|
+
request: httpx.Request | None = None,
|
114
|
+
response: httpx.Response | None = None) -> dict[str, str]:
|
104
115
|
"""Get authentication headers from the NAT auth provider."""
|
105
|
-
# Build auth request
|
106
|
-
www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
|
107
|
-
auth_request = AuthRequest(
|
108
|
-
reason=reason,
|
109
|
-
www_authenticate=www_authenticate,
|
110
|
-
)
|
111
116
|
try:
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
+
session_id = None
|
118
|
+
is_tool_call = False
|
119
|
+
if request:
|
120
|
+
session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
|
121
|
+
|
122
|
+
if is_tool_call:
|
123
|
+
# Tool call requests should use the session id if it exists, default user id can be used if allowed
|
124
|
+
if self.auth_provider.config.allow_default_user_id_for_tool_calls:
|
125
|
+
user_id = session_id or self.auth_provider.config.default_user_id
|
126
|
+
else:
|
127
|
+
user_id = session_id
|
128
|
+
else:
|
129
|
+
# Non-tool call requests should use the session id if it exists and fallback to default user id
|
130
|
+
user_id = session_id or self.auth_provider.config.default_user_id
|
131
|
+
|
132
|
+
auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
|
133
|
+
|
117
134
|
# Check if we have BearerTokenCred
|
118
135
|
from nat.data_models.authentication import BearerTokenCred
|
119
136
|
if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
|
120
137
|
token = auth_result.credentials[0].token.get_secret_value()
|
121
138
|
return {"Authorization": f"Bearer {token}"}
|
122
139
|
else:
|
123
|
-
logger.
|
140
|
+
logger.info("Auth provider did not return BearerTokenCred")
|
124
141
|
return {}
|
125
142
|
except Exception as e:
|
126
143
|
logger.warning("Failed to get auth token: %s", e)
|
@@ -155,6 +172,7 @@ class MCPBaseClient(ABC):
|
|
155
172
|
self._initial_connection = False
|
156
173
|
|
157
174
|
# Convert auth provider to AuthAdapter
|
175
|
+
self._auth_provider = auth_provider
|
158
176
|
self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
|
159
177
|
|
160
178
|
self._tool_call_timeout = tool_call_timeout
|
@@ -314,6 +332,34 @@ class MCPBaseClient(ABC):
|
|
314
332
|
raise MCPToolNotFoundError(tool_name, self.server_name)
|
315
333
|
return tool
|
316
334
|
|
335
|
+
def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
|
336
|
+
"""Set the user authentication callback."""
|
337
|
+
if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
|
338
|
+
self._auth_provider._set_custom_auth_callback(auth_callback)
|
339
|
+
|
340
|
+
@mcp_exception_handler
|
341
|
+
async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
|
342
|
+
from mcp.types import CallToolRequest
|
343
|
+
from mcp.types import CallToolRequestParams
|
344
|
+
from mcp.types import CallToolResult
|
345
|
+
from mcp.types import ClientRequest
|
346
|
+
|
347
|
+
if not self._session:
|
348
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
349
|
+
|
350
|
+
async def _call_tool_with_meta():
|
351
|
+
params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
|
352
|
+
req = ClientRequest(CallToolRequest(params=params))
|
353
|
+
# We will increase the timeout to 5 minutes if the tool call timeout is less than 5 min and
|
354
|
+
# auth is enabled.
|
355
|
+
if self._auth_provider and self._tool_call_timeout.total_seconds() < 300:
|
356
|
+
timeout = timedelta(seconds=300)
|
357
|
+
else:
|
358
|
+
timeout = self._tool_call_timeout
|
359
|
+
return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
|
360
|
+
|
361
|
+
return await self._with_reconnect(_call_tool_with_meta)
|
362
|
+
|
317
363
|
@mcp_exception_handler
|
318
364
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
319
365
|
|
@@ -539,9 +585,35 @@ class MCPToolClient:
|
|
539
585
|
Args:
|
540
586
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
541
587
|
"""
|
542
|
-
|
588
|
+
if self._session is None:
|
589
|
+
raise RuntimeError("No session available for tool call")
|
590
|
+
|
591
|
+
# Extract context information
|
592
|
+
session_id = None
|
593
|
+
try:
|
594
|
+
from nat.builder.context import Context as _Ctx
|
595
|
+
|
596
|
+
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
597
|
+
# on first tool call
|
598
|
+
auth_callback = _Ctx.get().user_auth_callback
|
599
|
+
if auth_callback and self._parent_client:
|
600
|
+
# set custom auth callback
|
601
|
+
self._parent_client.set_user_auth_callback(auth_callback)
|
602
|
+
|
603
|
+
# get session id from context, authentication is done per-websocket session for tool calls
|
604
|
+
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
605
|
+
if cookies:
|
606
|
+
session_id = cookies.get("nat-session")
|
607
|
+
except Exception:
|
608
|
+
pass
|
609
|
+
|
543
610
|
try:
|
544
|
-
|
611
|
+
if session_id:
|
612
|
+
logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
|
613
|
+
result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
|
614
|
+
else:
|
615
|
+
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
616
|
+
result = await self._session.call_tool(self._tool_name, tool_args)
|
545
617
|
|
546
618
|
output = []
|
547
619
|
for res in result.content:
|
nat/plugins/mcp/client_impl.py
CHANGED
@@ -56,7 +56,8 @@ class MCPServerConfig(BaseModel):
|
|
56
56
|
env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
|
57
57
|
|
58
58
|
# Authentication configuration
|
59
|
-
auth_provider: AuthenticationRef | None = Field(default=None,
|
59
|
+
auth_provider: str | AuthenticationRef | None = Field(default=None,
|
60
|
+
description="Reference to authentication provider")
|
60
61
|
|
61
62
|
@model_validator(mode="after")
|
62
63
|
def validate_model(self):
|
@@ -92,7 +93,8 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
92
93
|
"""
|
93
94
|
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
94
95
|
tool_call_timeout: timedelta = Field(
|
95
|
-
default=timedelta(seconds=
|
96
|
+
default=timedelta(seconds=60),
|
97
|
+
description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.")
|
96
98
|
reconnect_enabled: bool = Field(
|
97
99
|
default=True,
|
98
100
|
description="Whether to enable reconnecting to the MCP server if the connection is lost. \
|
{nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-mcp
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0a20250928
|
4
4
|
Summary: Subpackage for MCP client integration in NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents,mcp
|
6
6
|
Classifier: Programming Language :: Python
|
@@ -9,7 +9,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.13
|
10
10
|
Requires-Python: <3.14,>=3.11
|
11
11
|
Description-Content-Type: text/markdown
|
12
|
-
Requires-Dist: nvidia-nat==v1.3.
|
12
|
+
Requires-Dist: nvidia-nat==v1.3.0a20250928
|
13
13
|
Requires-Dist: mcp~=1.14
|
14
14
|
|
15
15
|
<!--
|
@@ -1,18 +1,19 @@
|
|
1
1
|
nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
|
2
2
|
nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
3
|
-
nat/plugins/mcp/client_base.py,sha256=
|
4
|
-
nat/plugins/mcp/client_impl.py,sha256=
|
3
|
+
nat/plugins/mcp/client_base.py,sha256=x6NHs3X2alyB6Wo4pjWZrGyuU7lLqGC_cLO9fuL4Zgw,25194
|
4
|
+
nat/plugins/mcp/client_impl.py,sha256=M1gTMlp3RLhFaAHOvwkk38boFy05MixV_glrIEcMjvo,10759
|
5
5
|
nat/plugins/mcp/exception_handler.py,sha256=JdPdZG1NgWpdRnIz7JTGHiJASS5wot9nJiD3SRWV4Kw,7649
|
6
6
|
nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
|
7
7
|
nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
|
8
8
|
nat/plugins/mcp/tool.py,sha256=v3MFsiaLJy8Ourcfqa6ohtAE2Nn-vqpC6Q6gsCdJ28Q,6165
|
9
9
|
nat/plugins/mcp/utils.py,sha256=3fuzYpC14wrfMOTOGvY2KHWcxZvBWqrxdDZD17lhmC8,4055
|
10
10
|
nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
11
|
-
nat/plugins/mcp/auth/
|
12
|
-
nat/plugins/mcp/auth/
|
11
|
+
nat/plugins/mcp/auth/auth_flow_handler.py,sha256=eRqRS2t-YzJ3kEFaG0PEC8DzctYzaJzr9XLZGkvuxq0,6018
|
12
|
+
nat/plugins/mcp/auth/auth_provider.py,sha256=5TTaPlIXMgwrE4YcZ_HO9-GNBBFaDTdUKAa5fuALayI,19142
|
13
|
+
nat/plugins/mcp/auth/auth_provider_config.py,sha256=vhU47Vcp_30M8tWu0FumbJ6pdUnFbBZm-ABdNlup__U,3821
|
13
14
|
nat/plugins/mcp/auth/register.py,sha256=yzphsn1I4a5G39_IacbuX0ZQqGM8fevvTUM_B94UXKE,1211
|
14
|
-
nvidia_nat_mcp-1.3.
|
15
|
-
nvidia_nat_mcp-1.3.
|
16
|
-
nvidia_nat_mcp-1.3.
|
17
|
-
nvidia_nat_mcp-1.3.
|
18
|
-
nvidia_nat_mcp-1.3.
|
15
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/METADATA,sha256=VnuxYtll39Hir1_hDQvhqRWPM00M6A9JD3NCd7JISaw,1997
|
16
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
17
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
|
18
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
19
|
+
nvidia_nat_mcp-1.3.0a20250928.dist-info/RECORD,,
|
File without changes
|
{nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/entry_points.txt
RENAMED
File without changes
|
{nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20250928.dist-info}/top_level.txt
RENAMED
File without changes
|