nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-mcp might be problematic. Click here for more details.
- nat/meta/pypi.md +32 -0
- nat/plugins/mcp/__init__.py +14 -0
- nat/plugins/mcp/auth/__init__.py +14 -0
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +431 -0
- nat/plugins/mcp/auth/auth_provider_config.py +86 -0
- nat/plugins/mcp/auth/register.py +33 -0
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1051 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/client/client_base.py +665 -0
- nat/plugins/mcp/client/client_config.py +146 -0
- nat/plugins/mcp/client/client_impl.py +782 -0
- nat/plugins/mcp/exception_handler.py +211 -0
- nat/plugins/mcp/exceptions.py +142 -0
- nat/plugins/mcp/register.py +23 -0
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +286 -0
- nat/plugins/mcp/utils.py +228 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import Awaitable
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from urllib.parse import urljoin
|
|
20
|
+
from urllib.parse import urlparse
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import Field
|
|
25
|
+
from pydantic import HttpUrl
|
|
26
|
+
from pydantic import TypeAdapter
|
|
27
|
+
|
|
28
|
+
from mcp.shared.auth import OAuthClientInformationFull
|
|
29
|
+
from mcp.shared.auth import OAuthClientMetadata
|
|
30
|
+
from mcp.shared.auth import OAuthMetadata
|
|
31
|
+
from mcp.shared.auth import ProtectedResourceMetadata
|
|
32
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
|
33
|
+
from nat.authentication.interfaces import AuthFlowType
|
|
34
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
35
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
|
36
|
+
from nat.data_models.authentication import AuthResult
|
|
37
|
+
from nat.data_models.common import get_secret_value
|
|
38
|
+
from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
|
|
39
|
+
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
40
|
+
|
|
41
|
+
# Module logger, named after this module via the standard ``__name__`` convention.
logger = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class OAuth2Endpoints(BaseModel):
    """Authorization-server endpoints (plus advertised scopes) produced by MCP OAuth2 discovery."""
    # Required: where the user agent is sent to grant authorization.
    authorization_url: HttpUrl = Field(description="OAuth2 authorization endpoint URL")
    # Required: where authorization codes and refresh tokens are exchanged for access tokens.
    token_url: HttpUrl = Field(description="OAuth2 token endpoint URL")
    # Optional dynamic client registration endpoint; None when the metadata does not advertise one.
    registration_url: HttpUrl | None = Field(None, description="OAuth2 client registration endpoint URL")
    # Scopes advertised by the authorization server metadata (``scopes_supported``), if any.
    scopes: list[str] | None = Field(None, description="OAuth2 scopes to be used for the authentication")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class OAuth2Credentials(BaseModel):
    """OAuth2 client credentials, either supplied manually or obtained via dynamic client registration."""
    client_id: str = Field(..., description="OAuth2 client identifier")
    # repr=False keeps the secret out of repr()/str() output, so a log line or traceback that renders
    # this model cannot leak it. The value itself is unchanged and still available as an attribute.
    client_secret: str | None = Field(default=None, repr=False, description="OAuth2 client secret")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class DiscoverOAuth2Endpoints:
    """
    MCP-SDK parity discovery flow:
      1) If 401 + WWW-Authenticate has resource_metadata (RFC 9728), fetch it.
      2) Else fetch RS well-known protected resource metadata: the path-aware
         /.well-known/oauth-protected-resource{path} first, then the host-root variant.
      3) If PR metadata lists authorization_servers, pick first as issuer.
      4) Do path-aware RFC 8414 / OIDC discovery against issuer (or server base).

    Discovered endpoints are cached; the cache is bypassed only when discovery is triggered by a 401.
    """

    def __init__(self, config: MCPOAuth2ProviderConfig):
        self.config = config
        self._cached_endpoints: OAuth2Endpoints | None = None

        # NOTE(review): not referenced by any method of this class; kept in case external code reads it.
        self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()

    async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
        """
        Discover OAuth2 endpoints from MCP server.

        Args:
            response: The HTTP response that triggered discovery, if any. When it is a 401 the cache is
                bypassed and its ``WWW-Authenticate`` header is checked for an RFC 9728
                ``resource_metadata`` hint.

        Returns:
            A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.

        Raises:
            RuntimeError: If no discovery URL yields usable authorization server metadata.
        """
        is_401_retry = response is not None and response.status_code == 401
        # Fast path: reuse cache when not a 401 retry
        if not is_401_retry and self._cached_endpoints is not None:
            return self._cached_endpoints, False

        # Stays None until some protected-resource metadata names an authorization server.
        issuer: str | None = None

        # 1) 401 hint (RFC 9728) if present
        if is_401_retry and response is not None:
            hint_url = self._extract_from_www_authenticate_header(response.headers.get("WWW-Authenticate", ""))
            if hint_url:
                logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
                try:
                    issuer = await self._fetch_pr_issuer(hint_url)
                except Exception as e:
                    # A broken or unreachable hint must not abort discovery; fall back to the well-known lookup.
                    logger.debug("resource_metadata hint %s could not be used: %s", hint_url, e)

        # 2) Try RS protected resource well-known if the hint did not yield an issuer
        if issuer is None:
            issuer = await self._fetch_pr_issuer_from_well_known()

        # 3) Path-aware RFC 8414 / OIDC discovery using issuer (or server base)
        endpoints = await self._discover_via_issuer_or_base(issuer or str(self.config.server_url))
        if endpoints is None:
            raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")

        changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
        self._cached_endpoints = endpoints
        logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
        return self._cached_endpoints, changed

    # --------------------------- helpers ---------------------------
    def _authorization_base_url(self) -> str:
        """Get the authorization base URL (``scheme://netloc``) from the MCP server URL."""
        p = urlparse(str(self.config.server_url))
        return f"{p.scheme}://{p.netloc}"

    def _build_protected_resource_urls(self) -> list[str]:
        """Build RFC 9728 protected resource metadata URLs: path-aware first, then the host root."""
        base = self._authorization_base_url()
        path = (urlparse(str(self.config.server_url)).path or "").rstrip("/")
        urls: list[str] = []
        if path:
            urls.append(f"{base}/.well-known/oauth-protected-resource{path}")
        urls.append(f"{base}/.well-known/oauth-protected-resource")
        return urls

    async def _fetch_pr_issuer_from_well_known(self) -> str | None:
        """Try each well-known protected resource URL in order and return the first issuer found."""
        for pr_url in self._build_protected_resource_urls():
            try:
                logger.debug("Fetching protected resource metadata: %s", pr_url)
                pr_issuer = await self._fetch_pr_issuer(pr_url)
            except Exception as e:
                logger.debug("Protected resource metadata not available: %s", e)
                continue
            if pr_issuer:
                return pr_issuer
        return None

    def _extract_from_www_authenticate_header(self, hdr: str) -> str | None:
        """Extract the resource_metadata URL from the WWW-Authenticate header (None if absent)."""
        import re

        if not hdr:
            return None
        # resource_metadata="url" | 'url' | url (case-insensitive; stop on space/comma/semicolon)
        m = re.search(r'(?i)\bresource_metadata\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|([^\s,;]+))', hdr)
        if not m:
            return None
        url = next((g for g in m.groups() if g), None)
        if url:
            logger.debug("Extracted resource_metadata URL: %s", url)
        return url

    async def _fetch_pr_issuer(self, url: str) -> str | None:
        """Fetch RFC 9728 Protected Resource Metadata and return the first issuer (authorization_server).

        Raises ``httpx.HTTPStatusError`` on a non-2xx answer; returns None when the body is not valid
        metadata or lists no authorization servers.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(url, headers={"Accept": "application/json"})
            resp.raise_for_status()
            body = await resp.aread()
        try:
            pr = ProtectedResourceMetadata.model_validate_json(body)
        except Exception as e:
            logger.debug("Invalid ProtectedResourceMetadata at %s: %s", url, e)
            return None
        if pr.authorization_servers:
            return str(pr.authorization_servers[0])
        return None

    async def _discover_via_issuer_or_base(self, base_or_issuer: str) -> OAuth2Endpoints | None:
        """Perform path-aware RFC 8414 / OIDC discovery given an issuer or base URL.

        Returns the endpoints from the first URL whose JSON metadata contains both an authorization and a
        token endpoint, or None if every candidate fails.
        """
        urls = self._build_path_aware_discovery_urls(base_or_issuer)
        async with httpx.AsyncClient(timeout=10.0) as client:
            for url in urls:
                try:
                    resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
                    if resp.status_code != 200:
                        continue

                    # Check content type before attempting JSON parsing
                    content_type = resp.headers.get("content-type", "").lower()
                    if "application/json" not in content_type:
                        logger.info(
                            "Discovery endpoint %s returned non-JSON content type: %s. "
                            "This may indicate the endpoint doesn't support discovery or requires authentication.",
                            url,
                            content_type)
                        # If it's HTML, log a more helpful message
                        if "text/html" in content_type:
                            logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
                                        "This often means:")
                            logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
                            logger.info("2. The server requires authentication before providing discovery metadata")
                            logger.info("3. The URL is pointing to a web application instead of an OAuth server")
                        continue

                    body = await resp.aread()

                    try:
                        meta = OAuthMetadata.model_validate_json(body)
                    except Exception as e:
                        logger.debug("Invalid OAuthMetadata at %s: %s", url, e)
                        continue
                    if meta.authorization_endpoint and meta.token_endpoint:
                        logger.info("Discovered OAuth2 endpoints from %s", url)
                        # Convert AnyHttpUrl to HttpUrl using TypeAdapter
                        http_url_adapter = TypeAdapter(HttpUrl)
                        return OAuth2Endpoints(
                            authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)),
                            token_url=http_url_adapter.validate_python(str(meta.token_endpoint)),
                            registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint))
                            if meta.registration_endpoint else None,
                            scopes=meta.scopes_supported,
                        )
                except Exception as e:
                    logger.debug("Discovery failed at %s: %s", url, e)

        # If we get here, all discovery URLs failed
        logger.info("OAuth discovery failed for all attempted URLs.")
        logger.info("Attempted URLs: %s", urls)
        return None

    def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
        """Build path-aware authorization server metadata discovery URLs, most specific first.

        For an issuer with a path component, RFC 8414 inserts the well-known segment between host and
        path (``/.well-known/oauth-authorization-server{path}``). The path-appended forms are kept as
        backward-compatibility fallbacks. Duplicates are removed while preserving order.
        """
        p = urlparse(base_or_issuer)
        base = f"{p.scheme}://{p.netloc}"
        path = (p.path or "").rstrip("/")
        urls: list[str] = []
        if path:
            # RFC 8414 path insertion, as specified by the MCP spec
            urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
            # this is fallback for backward compatibility
            urls.append(f"{base}{path}/.well-known/oauth-authorization-server")
        urls.append(f"{base}/.well-known/oauth-authorization-server")
        if path:
            # OIDC path insertion, as specified by the MCP spec
            urls.append(f"{base}/.well-known/openid-configuration{path}")
            # this is fallback for backward compatibility
            urls.append(f"{base}{path}/.well-known/openid-configuration")
        urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
        # dict.fromkeys drops repeats (e.g. the last OIDC URL equals the path-appended one) but keeps order
        return list(dict.fromkeys(urls))
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class DynamicClientRegistration:
    """Registers an OAuth2 client on the fly with the discovered authorization server."""

    def __init__(self, config: MCPOAuth2ProviderConfig):
        self.config = config

    def _authorization_base_url(self) -> str:
        """Return ``scheme://netloc`` of the configured MCP server URL."""
        parsed = urlparse(str(self.config.server_url))
        return f"{parsed.scheme}://{parsed.netloc}"

    async def register(self, endpoints: OAuth2Endpoints, scopes: list[str] | None) -> OAuth2Credentials:
        """
        Register an OAuth2 client with the Authorization Server using OIDC client registration.

        Args:
            endpoints: Discovered endpoints; ``registration_url`` is used when present.
            scopes: Scopes to request, sent space-separated as ``scope``; omitted when empty or None.

        Returns:
            The client_id / client_secret issued by the server.

        Raises:
            httpx.HTTPStatusError: If the registration endpoint answers with a non-2xx status.
            RuntimeError: If the response is not valid client information or carries no client_id.
        """
        if endpoints.registration_url:
            target = str(endpoints.registration_url)
        else:
            # Metadata advertised no registration endpoint: fall back to /register on the MCP server host.
            target = urljoin(self._authorization_base_url(), "/register")

        auth_method = getattr(self.config, "token_endpoint_auth_method", None) or "client_secret_post"
        client_metadata = OAuthClientMetadata(
            redirect_uris=[self.config.redirect_uri],
            token_endpoint_auth_method=auth_method,
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope=" ".join(scopes) if scopes else None,
            client_name=self.config.client_name or None,
        )
        request_body = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
        request_headers = {"Content-Type": "application/json", "Accept": "application/json"}

        async with httpx.AsyncClient(timeout=30.0) as http:
            reply = await http.post(target, json=request_body, headers=request_headers)
            reply.raise_for_status()
            raw = await reply.aread()

        try:
            client_info = OAuthClientInformationFull.model_validate_json(raw)
        except Exception as exc:
            raise RuntimeError(
                f"Registration response was not valid OAuthClientInformation from {target}") from exc

        if not client_info.client_id:
            raise RuntimeError("No client_id received from registration")

        logger.info("Successfully registered OAuth2 client: %s", client_info.client_id)
        return OAuth2Credentials(client_id=client_info.client_id, client_secret=client_info.client_secret)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
    """MCP OAuth2 authentication provider that delegates to NAT framework.

    Endpoint discovery and client registration run when ``authenticate`` receives a 401 response.
    The authorization-code flow itself, including per-user token caching and refresh, is delegated
    to ``OAuth2AuthCodeFlowProvider``.
    """

    def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
        super().__init__(config)
        self._builder = builder

        # Discovery
        self._discoverer = DiscoverOAuth2Endpoints(config)
        self._cached_endpoints: OAuth2Endpoints | None = None

        # Client registration
        self._registrar = DynamicClientRegistration(config)
        self._cached_credentials: OAuth2Credentials | None = None

        # For the OAuth2 flow
        self._auth_code_provider = None
        self._flow_handler = MCPAuthenticationFlowHandler()

        self._auth_callback = None

        # Initialize token storage
        self._token_storage = None
        self._token_storage_object_store_name = None

        if self.config.token_storage_object_store:
            # Store object store name, will be resolved later when builder context is available
            self._token_storage_object_store_name = self.config.token_storage_object_store
            logger.info("Configured to use object store '%s' for token storage",
                        self._token_storage_object_store_name)
        else:
            # Default: use in-memory token storage
            from .token_storage import InMemoryTokenStorage
            self._token_storage = InMemoryTokenStorage()

    def _set_custom_auth_callback(self,
                                  auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
                                                          Awaitable[AuthenticatedContext]]):
        """Set the custom authentication callback.

        The callback is remembered for any auth-code provider built later. It is also pushed to the
        current provider when one exists and supports it.
        """
        # NOTE(review): the source this was recovered from lost its indentation; confirm whether a later
        # callback is meant to replace an earlier one (as here) or only the first one should be applied.
        if not self._auth_callback:
            logger.info("Using custom authentication callback")
        self._auth_callback = auth_callback
        # Same capability check as used when the provider is built.
        if self._auth_code_provider and hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
            self._auth_code_provider._set_custom_auth_callback(self._auth_callback)  # type: ignore[arg-type]

    async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
        """
        Authenticate using MCP OAuth2 flow via NAT framework.

        If response is provided in kwargs (typically from a 401), performs:
        1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
        2. Client registration (RFC7591)
        3. Authentication

        Otherwise, performs standard authentication flow.

        Raises:
            RuntimeError: If no user_id is given, or if endpoint discovery fails.
        """
        if not user_id:
            # MCP tool calls cannot be made without an authorized user
            raise RuntimeError("User is not authorized to call the tool")

        response = kwargs.get('response')
        if response and response.status_code == 401:
            await self._discover_and_register(response=response)

        return await self._nat_oauth2_authenticate(user_id=user_id)

    @property
    def _effective_scopes(self) -> list[str]:
        """Get the effective scopes: configured scopes, else those advertised by discovery, else []."""
        return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []

    async def _discover_and_register(self, response: httpx.Response | None = None):
        """
        Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
        using OIDC client registration.

        Cached credentials and the built auth-code provider are discarded whenever the discovered
        endpoints change, because both are tied to the previous authorization server.
        """
        # Discover OAuth2 endpoints
        self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
        if endpoints_changed:
            logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
            self._cached_credentials = None  # invalidate credentials tied to old AS
            self._auth_code_provider = None

        if self._cached_credentials:
            return

        # Client registration
        if self.config.client_id:
            # Manual registration mode
            self._cached_credentials = OAuth2Credentials(
                client_id=self.config.client_id,
                client_secret=get_secret_value(self.config.client_secret),
            )
            logger.info("Using manual client_id: %s", self._cached_credentials.client_id)
        else:
            # Dynamic registration mode requires registration endpoint
            self._cached_credentials = await self._registrar.register(self._cached_endpoints,
                                                                      self._effective_scopes)
            logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)

    async def _ensure_token_storage(self) -> None:
        """Resolve the configured object store into token storage; fall back to in-memory storage on failure."""
        if not self._token_storage_object_store_name or self._token_storage:
            return
        try:
            if not self._builder:
                raise RuntimeError("Builder not available for resolving object store")
            object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name)
            from .token_storage import ObjectStoreTokenStorage
            self._token_storage = ObjectStoreTokenStorage(object_store)
            logger.info("Initialized token storage with object store '%s'", self._token_storage_object_store_name)
        except Exception as e:
            logger.warning(
                "Failed to resolve object store '%s' for token storage: %s. Falling back to in-memory storage.",
                self._token_storage_object_store_name,
                e)
            from .token_storage import InMemoryTokenStorage
            self._token_storage = InMemoryTokenStorage()

    def _build_auth_code_provider(self, endpoints: OAuth2Endpoints, credentials: OAuth2Credentials):
        """Create the NAT auth-code flow provider for the discovered endpoints and registered client."""
        from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider

        oauth2_config = OAuth2AuthCodeFlowProviderConfig(
            client_id=credentials.client_id,
            client_secret=credentials.client_secret or "",
            authorization_url=str(endpoints.authorization_url),
            token_url=str(endpoints.token_url),
            token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
            redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
            scopes=self._effective_scopes,
            use_pkce=bool(self.config.use_pkce),
            # `resource` parameter (RFC 8707 resource indicator) binds the requested token to this MCP server.
            authorization_kwargs={"resource": str(self.config.server_url)})
        provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage)

        # Use MCP-specific authentication method if available
        if hasattr(provider, "_set_custom_auth_callback"):
            callback = self._auth_callback or self._flow_handler.authenticate
            provider._set_custom_auth_callback(callback)  # type: ignore[arg-type]
        return provider

    async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
        """Perform the OAuth2 flow using MCP-specific authentication flow handler.

        Returns an empty ``AuthResult`` until discovery and registration have populated endpoints and
        credentials.
        """
        if not self._cached_endpoints or not self._cached_credentials:
            # if discovery is yet to be done return empty auth result
            return AuthResult(credentials=[], token_expires_at=None, raw={})

        # Resolve object store reference if needed
        await self._ensure_token_storage()

        # Build the OAuth2 provider if not already built
        if self._auth_code_provider is None:
            self._auth_code_provider = self._build_auth_code_provider(self._cached_endpoints,
                                                                      self._cached_credentials)

        # Auth code provider is responsible for per-user cache + refresh
        return await self._auth_code_provider.authenticate(user_id=user_id)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from pydantic import HttpUrl
|
|
18
|
+
from pydantic import model_validator
|
|
19
|
+
|
|
20
|
+
from nat.authentication.interfaces import AuthProviderBaseConfig
|
|
21
|
+
from nat.data_models.common import OptionalSecretStr
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
    """
    MCP OAuth2 provider with endpoints discovery, optional DCR, and authentication flow via the OAuth2AuthCodeFlow
    provider.

    Supported modes:
    - Endpoints discovery + Dynamic Client Registration (DCR) (enable_dynamic_registration=True, no client_id)
    - Endpoints discovery + Manual Client Registration (client_id + client_secret provided)
    """
    server_url: HttpUrl = Field(
        ...,
        description=
        "URL of the MCP server. This is the MCP server that provides tools, NOT the OAuth2 authorization server.")

    # Client registration (manual registration vs DCR)
    client_id: str | None = Field(default=None, description="OAuth2 client ID for pre-registered clients")
    client_secret: OptionalSecretStr = Field(default=None,
                                             description="OAuth2 client secret for pre-registered clients")
    enable_dynamic_registration: bool = Field(default=True,
                                              description="Enable OAuth2 Dynamic Client Registration (RFC 7591)")
    client_name: str = Field(default="NAT MCP Client", description="OAuth2 client name for dynamic registration")

    # OAuth2 flow configuration
    redirect_uri: HttpUrl = Field(..., description="OAuth2 redirect URI.")
    token_endpoint_auth_method: str = Field(default="client_secret_post",
                                            description="The authentication method for the token endpoint.")
    scopes: list[str] = Field(default_factory=list,
                              description="OAuth2 scopes, discovered from MCP server if not provided")
    # Advanced options
    use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")

    default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
    allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")

    # Token storage configuration
    token_storage_object_store: str | None = Field(
        default=None,
        description="Reference to object store for secure token storage. If None, uses in-memory storage.")

    @model_validator(mode="after")
    def validate_auth_config(self):
        """Validate authentication configuration for MCP-specific options.

        Fills ``default_user_id`` from ``server_url`` when it is unset, then checks that one of the two
        supported registration modes is configured.

        Raises:
            ValueError: If neither dynamic registration (enabled, with no client_id) nor a complete
                client_id + client_secret pair is configured.
        """
        # if default_user_id is not provided, use the server_url as the default user id
        if not self.default_user_id:
            self.default_user_id = str(self.server_url)

        # Dynamic registration + MCP discovery: no explicit credentials needed
        dynamic_mode = self.enable_dynamic_registration and not self.client_id
        # Manual registration + MCP discovery: has credentials but will discover URLs from MCP server
        manual_mode = bool(self.client_id and self.client_secret)

        if not (dynamic_mode or manual_mode):
            # Spell out that a client_id disables dynamic mode, so users who already set
            # enable_dynamic_registration=True are not told to set it again.
            raise ValueError("Must provide either: "
                             "1) enable_dynamic_registration=True with no client_id (dynamic), or "
                             "2) client_id + client_secret (manual)")

        return self
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.builder.builder import Builder
|
|
17
|
+
from nat.cli.register_workflow import register_auth_provider
|
|
18
|
+
from nat.plugins.mcp.auth.auth_provider import MCPOAuth2Provider
|
|
19
|
+
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
20
|
+
from nat.plugins.mcp.auth.service_account.provider import MCPServiceAccountProvider
|
|
21
|
+
from nat.plugins.mcp.auth.service_account.provider_config import MCPServiceAccountProviderConfig
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
    """Build the MCP OAuth2 auth provider for an ``mcp_oauth2`` config and hand it to the NAT registry.

    The builder is passed through so the provider can later resolve a configured token-storage object store.
    """
    provider = MCPOAuth2Provider(authentication_provider, builder=builder)
    yield provider
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register_auth_provider(config_type=MCPServiceAccountProviderConfig)
async def mcp_service_account_provider(authentication_provider: MCPServiceAccountProviderConfig, builder: Builder):
    """Build the MCP service-account auth provider for its config and hand it to the NAT registry."""
    provider = MCPServiceAccountProvider(authentication_provider, builder=builder)
    yield provider
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|