nvidia-nat-mcp 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,367 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from urllib.parse import urljoin
18
+ from urllib.parse import urlparse
19
+
20
+ import httpx
21
+ from pydantic import BaseModel
22
+ from pydantic import Field
23
+ from pydantic import HttpUrl
24
+
25
+ from mcp.shared.auth import OAuthClientInformationFull
26
+ from mcp.shared.auth import OAuthClientMetadata
27
+ from mcp.shared.auth import OAuthMetadata
28
+ from mcp.shared.auth import ProtectedResourceMetadata
29
+ from nat.authentication.interfaces import AuthProviderBase
30
+ from nat.data_models.authentication import AuthReason
31
+ from nat.data_models.authentication import AuthRequest
32
+ from nat.data_models.authentication import AuthResult
33
+ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
class OAuth2Endpoints(BaseModel):
    """Authorization-server endpoints resolved by OAuth2 discovery for an MCP server.

    The authorization and token endpoints are mandatory; the client
    registration endpoint is optional and defaults to ``None``.
    """

    authorization_url: HttpUrl = Field(
        ...,
        description="OAuth2 authorization endpoint URL",
    )
    token_url: HttpUrl = Field(
        ...,
        description="OAuth2 token endpoint URL",
    )
    registration_url: HttpUrl | None = Field(
        default=None,
        description="OAuth2 client registration endpoint URL",
    )
43
+
44
+
45
class OAuth2Credentials(BaseModel):
    """Client identity used for the OAuth2 flow (manually configured or obtained via registration).

    ``client_secret`` is optional and defaults to ``None``.
    """

    client_id: str = Field(
        ...,
        description="OAuth2 client identifier",
    )
    client_secret: str | None = Field(
        default=None,
        description="OAuth2 client secret",
    )
49
+
50
+
51
class DiscoverOAuth2Endpoints:
    """
    MCP-SDK parity discovery flow:
    1) If 401 + WWW-Authenticate has resource_metadata (RFC 9728), fetch it.
    2) If that produced no issuer (no hint, fetch failed, or invalid metadata),
       fetch RS well-known /.well-known/oauth-protected-resource.
    3) If PR metadata lists authorization_servers, pick first as issuer.
    4) Do path-aware RFC 8414 / OIDC discovery against issuer (or server base).

    The most recently discovered endpoints are cached on the instance and are
    reused for every request that is not a retry after a 401.
    """

    def __init__(self, config: MCPOAuth2ProviderConfig):
        """Store the provider config and start with empty endpoint/scope caches."""
        self.config = config
        self._cached_endpoints: OAuth2Endpoints | None = None
        self._last_oauth_scopes: list[str] | None = None

    async def discover(self, reason: AuthReason, www_authenticate: str | None) -> tuple[OAuth2Endpoints, bool]:
        """
        Discover OAuth2 endpoints from MCP server.

        Args:
            reason: The reason for the discovery.
            www_authenticate: The WWW-Authenticate header from a 401 response.

        Returns:
            A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.

        Raises:
            RuntimeError: If no discovery URL yielded usable authorization-server metadata.
        """
        # Fast path: reuse cache when not a 401 retry
        if reason != AuthReason.RETRY_AFTER_401 and self._cached_endpoints is not None:
            return self._cached_endpoints, False

        server_url = str(self.config.server_url)
        issuer: str = server_url  # default to server URL

        # 1) 401 hint (RFC 9728) if present
        if reason == AuthReason.RETRY_AFTER_401 and www_authenticate:
            hint_url = self._extract_from_www_authenticate_header(www_authenticate)
            if hint_url:
                logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
                # The hint comes straight from a response header; a dead or
                # rejecting URL must not abort discovery, so treat the fetch as
                # best-effort exactly like the well-known fallback below.
                try:
                    issuer_hint = await self._fetch_pr_issuer(hint_url)
                except Exception as e:
                    logger.warning("Protected resource metadata hint not usable (%s): %s", hint_url, e)
                    issuer_hint = None
                if issuer_hint:
                    issuer = issuer_hint

        # 2) Try RS protected resource well-known if we still only have default issuer
        if issuer == server_url:
            pr_url = urljoin(self._authorization_base_url(), "/.well-known/oauth-protected-resource")
            try:
                logger.debug("Fetching protected resource metadata: %s", pr_url)
                issuer2 = await self._fetch_pr_issuer(pr_url)
                if issuer2:
                    issuer = issuer2
            except Exception as e:
                logger.debug("Protected resource metadata not available: %s", e)

        # 3) Path-aware RFC 8414 / OIDC discovery using issuer (or server base)
        endpoints = await self._discover_via_issuer_or_base(issuer)
        if endpoints is None:
            raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")

        # "changed" tells the caller whether anything tied to the previous
        # endpoints (e.g. registered client credentials) must be invalidated.
        changed = (self._cached_endpoints is None
                   or endpoints.authorization_url != self._cached_endpoints.authorization_url
                   or endpoints.token_url != self._cached_endpoints.token_url
                   or endpoints.registration_url != self._cached_endpoints.registration_url)
        self._cached_endpoints = endpoints
        logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
        return self._cached_endpoints, changed

    # --------------------------- helpers ---------------------------
    def _authorization_base_url(self) -> str:
        """Get the authorization base URL (scheme + host[:port]) from the MCP server URL."""
        p = urlparse(str(self.config.server_url))
        return f"{p.scheme}://{p.netloc}"

    def _extract_from_www_authenticate_header(self, hdr: str) -> str | None:
        """Extract the resource_metadata URL from the WWW-Authenticate header.

        Returns ``None`` when the header is empty or carries no
        ``resource_metadata`` parameter.
        """
        import re

        if not hdr:
            return None
        # resource_metadata="url" | 'url' | url (case-insensitive; stop on space/comma/semicolon)
        m = re.search(r'(?i)\bresource_metadata\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|([^\s,;]+))', hdr)
        if not m:
            return None
        # Exactly one of the three alternation groups can have matched.
        url = next((g for g in m.groups() if g), None)
        if url:
            logger.debug("Extracted resource_metadata URL: %s", url)
        return url

    async def _fetch_pr_issuer(self, url: str) -> str | None:
        """Fetch RFC 9728 Protected Resource Metadata and return the first issuer (authorization_server).

        Returns ``None`` when the body is not valid metadata or lists no
        authorization servers. Transport errors and error statuses (via
        ``raise_for_status()``) propagate to the caller.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(url, headers={"Accept": "application/json"})
            resp.raise_for_status()
            body = await resp.aread()
        try:
            pr = ProtectedResourceMetadata.model_validate_json(body)
        except Exception as e:
            logger.debug("Invalid ProtectedResourceMetadata at %s: %s", url, e)
            return None
        if pr.authorization_servers:
            return str(pr.authorization_servers[0])
        return None

    async def _discover_via_issuer_or_base(self, base_or_issuer: str) -> OAuth2Endpoints | None:
        """Perform path-aware RFC 8414 / OIDC discovery given an issuer or base URL.

        Candidate URLs are tried in order; the first one that returns HTTP 200
        with metadata containing both an authorization and a token endpoint
        wins. Returns ``None`` when every candidate fails.
        """
        urls = self._build_path_aware_discovery_urls(base_or_issuer)
        async with httpx.AsyncClient(timeout=10.0) as client:
            for url in urls:
                # Best-effort per candidate: any failure just moves on to the next URL.
                try:
                    resp = await client.get(url, headers={"Accept": "application/json"})
                    if resp.status_code != 200:
                        continue
                    body = await resp.aread()
                    try:
                        meta = OAuthMetadata.model_validate_json(body)
                    except Exception as e:
                        logger.debug("Invalid OAuthMetadata at %s: %s", url, e)
                        continue
                    if meta.authorization_endpoint and meta.token_endpoint:
                        logger.info("Discovered OAuth2 endpoints from %s", url)
                        # HACK: remember the scopes advertised by the auth server as a side
                        # effect so that scopes_supported() can expose them later.
                        self._last_oauth_scopes = meta.scopes_supported
                        return OAuth2Endpoints(
                            authorization_url=str(meta.authorization_endpoint),
                            token_url=str(meta.token_endpoint),
                            registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
                        )
                except Exception as e:
                    logger.debug("Discovery failed at %s: %s", url, e)
        return None

    def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
        """Build path-aware discovery URLs, in the order they should be tried.

        For an issuer with a path component the path-inserted well-known forms
        are included ahead of their root equivalents; the last entry is always
        the OIDC form with the well-known suffix appended to the issuer.
        """
        p = urlparse(base_or_issuer)
        base = f"{p.scheme}://{p.netloc}"
        path = (p.path or "").rstrip("/")
        urls: list[str] = []
        if path:
            urls.append(urljoin(base, f"/.well-known/oauth-authorization-server{path}"))
        urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
        if path:
            urls.append(urljoin(base, f"/.well-known/openid-configuration{path}"))
        urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
        return urls

    def scopes_supported(self) -> list[str] | None:
        """Get the last OAuth scopes discovered from the AS (``None`` if none were recorded)."""
        return self._last_oauth_scopes
197
+
198
+
199
class DynamicClientRegistration:
    """Helper that registers an OAuth2 client with the authorization server."""

    def __init__(self, config: MCPOAuth2ProviderConfig):
        """Keep a reference to the provider configuration."""
        self.config = config

    def _authorization_base_url(self) -> str:
        """Return ``scheme://netloc`` of the configured MCP server URL."""
        parsed = urlparse(str(self.config.server_url))
        return f"{parsed.scheme}://{parsed.netloc}"

    async def register(self, endpoints: OAuth2Endpoints, scopes: list[str] | None) -> OAuth2Credentials:
        """Register an OAuth2 client and return the credentials issued for it.

        Args:
            endpoints: Discovered endpoints; ``registration_url`` is used when present.
            scopes: Scopes to request, sent as one space-separated string (omitted when empty).

        Returns:
            The ``client_id`` / ``client_secret`` taken from the registration response.

        Raises:
            RuntimeError: If the response cannot be parsed as client information
                or carries no ``client_id``. Errors raised by the HTTP request
                itself (including ``raise_for_status()``) are not wrapped.
        """
        # Without an advertised endpoint, fall back to /register on the server base URL.
        if endpoints.registration_url:
            registration_url = str(endpoints.registration_url)
        else:
            registration_url = urljoin(self._authorization_base_url(), "/register")

        client_metadata = OAuthClientMetadata(
            redirect_uris=[self.config.redirect_uri],
            token_endpoint_auth_method=(getattr(self.config, "token_endpoint_auth_method", None)
                                        or "client_secret_post"),
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope=" ".join(scopes) if scopes else None,
            client_name=self.config.client_name or None,
        )
        request_body = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
        request_headers = {"Content-Type": "application/json", "Accept": "application/json"}

        async with httpx.AsyncClient(timeout=30.0) as http_client:
            response = await http_client.post(registration_url, json=request_body, headers=request_headers)
            response.raise_for_status()
            raw_body = await response.aread()

        try:
            client_info = OAuthClientInformationFull.model_validate_json(raw_body)
        except Exception as e:
            raise RuntimeError(
                f"Registration response was not valid OAuthClientInformation from {registration_url}") from e

        if not client_info.client_id:
            raise RuntimeError("No client_id received from registration")

        logger.info("Successfully registered OAuth2 client: %s", client_info.client_id)
        return OAuth2Credentials(client_id=client_info.client_id, client_secret=client_info.client_secret)
249
+
250
+
251
class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
    """MCP OAuth2 authentication provider that delegates to NAT framework.

    Endpoint discovery and client registration are handled here; token
    acquisition is delegated to an ``OAuth2AuthCodeFlowProvider`` that is built
    lazily and rebuilt after every 401-triggered retry.
    """

    def __init__(self, config: MCPOAuth2ProviderConfig):
        """Create the discovery/registration helpers; nothing is fetched until ``authenticate``."""
        super().__init__(config)

        # Discovery
        self._discoverer = DiscoverOAuth2Endpoints(config)
        self._cached_endpoints: OAuth2Endpoints | None = None

        # Client registration
        self._registrar = DynamicClientRegistration(config)
        self._cached_credentials: OAuth2Credentials | None = None

        # For the OAuth2 flow
        self._auth_code_provider = None

    async def authenticate(self, user_id: str | None = None) -> AuthResult:
        """
        Authenticate using MCP OAuth2 flow via NAT framework.
        1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
        2. Client registration (RFC7591)
        3. Use NAT's standard OAuth2 flow (OAuth2AuthCodeFlowProvider)

        Args:
            user_id: Optional user identifier forwarded to the delegated OAuth2 flow.

        Returns:
            The delegate's AuthResult, or an empty AuthResult when no 401 has
            been seen yet and therefore no delegate has been set up.
        """
        auth_request = self.config.auth_request
        if not auth_request:
            auth_request = AuthRequest(reason=AuthReason.NORMAL)

        if auth_request.reason != AuthReason.RETRY_AFTER_401:
            # auth provider is expected to be setup via 401, till that time we return empty auth result
            if not self._auth_code_provider:
                return AuthResult(credentials=[], token_expires_at=None, raw={})

        await self._discover_and_register(auth_request)
        # Use NAT's standard OAuth2 flow
        if auth_request.reason == AuthReason.RETRY_AFTER_401:
            # force fresh delegate (clears in-mem token cache)
            self._auth_code_provider = None
            # preserve other fields, just normalize reason & inject user_id;
            # an explicit user_id wins, but a missing one must not erase the
            # user_id already carried by the request.
            auth_request = auth_request.model_copy(
                update={
                    "reason": AuthReason.NORMAL,
                    "user_id": user_id if user_id is not None else auth_request.user_id,
                    "www_authenticate": None,
                })
        # back-compat: propagate user_id if provided but not set in the request
        elif user_id is not None and auth_request.user_id is None:
            auth_request = auth_request.model_copy(update={"user_id": user_id})

        # Perform the OAuth2 flow without lock
        return await self._perform_oauth2_flow(auth_request=auth_request)

    async def _discover_and_register(self, auth_request: AuthRequest):
        """
        Discover OAuth2 endpoints and make sure client credentials are available.

        Credentials come from the configured ``client_id`` / ``client_secret``
        when a ``client_id`` is set, otherwise from dynamic client registration.
        Cached credentials are dropped whenever discovery reports changed endpoints.
        """
        # Discover OAuth2 endpoints
        self._cached_endpoints, endpoints_changed = await self._discoverer.discover(
            reason=auth_request.reason, www_authenticate=auth_request.www_authenticate)
        if endpoints_changed:
            logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
            self._cached_credentials = None  # invalidate credentials tied to old AS
        effective_scopes = self._effective_scopes()

        # Client registration
        if not self._cached_credentials:
            if self.config.client_id:
                # Manual registration mode
                self._cached_credentials = OAuth2Credentials(
                    client_id=self.config.client_id,
                    client_secret=self.config.client_secret,
                )
                logger.info("Using manual client_id: %s", self._cached_credentials.client_id)
            else:
                # Dynamic registration mode
                self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
                logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)

    def _effective_scopes(self) -> list[str] | None:
        """
        Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
        """
        return self.config.scopes or self._discoverer.scopes_supported()

    async def _build_oauth2_delegate(self):
        """Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it.

        No-op when a delegate already exists. Callers must have populated the
        cached endpoints and credentials first.
        """
        from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
        from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig

        endpoints = self._cached_endpoints
        credentials = self._cached_credentials

        if self._auth_code_provider is None:
            oauth2_config = OAuth2AuthCodeFlowProviderConfig(
                client_id=credentials.client_id,
                client_secret=credentials.client_secret or "",
                authorization_url=str(endpoints.authorization_url),
                token_url=str(endpoints.token_url),
                token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
                redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
                scopes=self._effective_scopes() or [],
                use_pkce=bool(self.config.use_pkce),
            )

            self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)

    async def _perform_oauth2_flow(self, auth_request: AuthRequest | None = None) -> AuthResult:
        """Perform the OAuth2 flow using NAT OAuth2 provider.

        Raises:
            RuntimeError: If called with a RETRY_AFTER_401 request, or before
                discovery/registration has populated endpoints and credentials.
        """
        # This helper is only for non-401 flows
        if auth_request and auth_request.reason == AuthReason.RETRY_AFTER_401:
            raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")

        if not self._cached_endpoints or not self._cached_credentials:
            raise RuntimeError("OAuth2 flow called before discovery/registration")

        # (Re)build the delegate if needed
        await self._build_oauth2_delegate()
        # Let the delegate handle per-user cache + refresh. The user_id that
        # authenticate() injected into the request has to be forwarded here,
        # otherwise it is silently dropped.
        # NOTE(review): assumes the delegate accepts the same ``user_id``
        # keyword as this class's own authenticate() - confirm against
        # OAuth2AuthCodeFlowProvider.
        user_id = auth_request.user_id if auth_request else None
        return await self._auth_code_provider.authenticate(user_id=user_id)
@@ -0,0 +1,76 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+ from pydantic import HttpUrl
18
+ from pydantic import model_validator
19
+
20
+ from nat.authentication.interfaces import AuthProviderBaseConfig
21
+ from nat.data_models.authentication import AuthRequest
22
+
23
+
24
class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
    """
    Configuration for the MCP OAuth2 provider: endpoints discovery, optional DCR, and an authentication flow
    carried out through the OAuth2AuthCodeFlow provider.

    Supported modes:
    - Endpoints discovery + Dynamic Client Registration (DCR) (enable_dynamic_registration=True, no client_id)
    - Endpoints discovery + Manual Client Registration (client_id + client_secret provided)
    """

    server_url: HttpUrl = Field(
        ...,
        description="URL of the MCP server. This is the MCP server that provides tools, NOT the OAuth2 authorization server.",
    )

    # Client registration (manual registration vs DCR)
    client_id: str | None = Field(
        default=None,
        description="OAuth2 client ID for pre-registered clients",
    )
    client_secret: str | None = Field(
        default=None,
        description="OAuth2 client secret for pre-registered clients",
    )
    enable_dynamic_registration: bool = Field(
        default=True,
        description="Enable OAuth2 Dynamic Client Registration (RFC 7591)",
    )
    client_name: str = Field(
        default="NAT MCP Client",
        description="OAuth2 client name for dynamic registration",
    )

    # OAuth2 flow configuration
    redirect_uri: HttpUrl = Field(
        ...,
        description="OAuth2 redirect URI.",
    )
    token_endpoint_auth_method: str = Field(
        default="client_secret_post",
        description="The authentication method for the token endpoint.",
    )
    scopes: list[str] = Field(
        default_factory=list,
        description="OAuth2 scopes, discovered from MCP server if not provided",
    )

    # Advanced options
    use_pkce: bool = Field(
        default=True,
        description="Use PKCE for authorization code flow",
    )

    auth_request: AuthRequest | None = Field(
        default=None,
        description="Auth request for authentication (metadata)",
    )

    @model_validator(mode="after")
    def validate_auth_config(self):
        """Reject configurations that match neither supported registration mode.

        Accepted combinations:
        - dynamic: ``enable_dynamic_registration`` is set and no ``client_id`` is given;
        - manual: both ``client_id`` and ``client_secret`` are given.

        Raises:
            ValueError: For any other combination.
        """
        # Pure dynamic registration - no explicit credentials needed
        wants_dynamic = self.enable_dynamic_registration and not self.client_id
        # Has credentials but will discover URLs from MCP server
        has_manual_credentials = self.client_id and self.client_secret

        if not (wants_dynamic or has_manual_credentials):
            raise ValueError("Must provide either: "
                             "1) enable_dynamic_registration=True (dynamic), or "
                             "2) client_id + client_secret (hybrid)")

        return self
@@ -0,0 +1,25 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from nat.builder.builder import Builder
17
+ from nat.cli.register_workflow import register_auth_provider
18
+ from nat.plugins.mcp.auth.auth_provider import MCPOAuth2Provider
19
+ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
20
+
21
+
22
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
    """Yield an MCPOAuth2Provider built from the given configuration.

    Registered through ``register_auth_provider`` for ``MCPOAuth2ProviderConfig``.
    ``builder`` is accepted as part of the signature but is not used here.
    """
    provider = MCPOAuth2Provider(authentication_provider)
    yield provider