nvidia-nat-mcp 1.3.0a20250926__py3-none-any.whl → 1.3.0a20251111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-mcp might be problematic. Click here for more details.

nat/meta/pypi.md CHANGED
@@ -19,9 +19,9 @@ limitations under the License.
19
19
 
20
20
 
21
21
  # NVIDIA NeMo Agent Toolkit MCP Subpackage
22
- Subpackage for MCP client integration in NeMo Agent toolkit.
22
+ Subpackage for MCP integration in NeMo Agent toolkit.
23
23
 
24
- This package provides MCP (Model Context Protocol) client functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
24
+ This package provides MCP (Model Context Protocol) functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
25
25
 
26
26
  ## Features
27
27
 
@@ -0,0 +1,208 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ import secrets
19
+ import webbrowser
20
+
21
+ import pkce
22
+ from authlib.integrations.httpx_client import AsyncOAuth2Client
23
+ from fastapi import FastAPI
24
+
25
+ from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
26
+ from nat.data_models.authentication import AuthenticatedContext
27
+ from nat.data_models.authentication import AuthFlowType
28
+ from nat.data_models.authentication import AuthProviderBaseConfig
29
+ from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
30
+ from nat.front_ends.console.authentication_flow_handler import _FlowState
31
+ from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
32
+
33
logger = logging.getLogger(__name__)


class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
    """
    Authentication helper for MCP environments.

    This handler is specifically designed for MCP tool discovery scenarios where
    authentication needs to happen before the default auth_callback is available
    in the Context. It handles OAuth2 authorization code flow during MCP client
    startup and tool discovery phases.

    Key differences from console handler:
    - Only supports OAuth2 Authorization Code flow (no HTTP Basic)
    - Optimized for MCP tool discovery workflows
    - Designed for single-use authentication during startup
    - Binds the local redirect server to the host/port taken from ``redirect_uri``

    One redirect server is shared by every flow running on the same handler
    instance: the first flow starts it, and it is stopped only when the last
    in-flight flow finishes. Concurrent authentications therefore do not tear
    down each other's callback endpoint.
    """

    # Seconds to wait for the user to complete the browser-based authorization.
    auth_timeout_seconds: float = 300
    # Seconds to wait for the redirect server to report that it is listening.
    server_ready_timeout_seconds: float = 10

    def __init__(self):
        super().__init__()
        self._server_controller: _FastApiFrontEndController | None = None
        self._redirect_app: FastAPI | None = None
        self._server_lock = asyncio.Lock()
        self._oauth_client: AsyncOAuth2Client | None = None
        self._redirect_host: str = "localhost"  # Default host, will be overridden from config
        self._redirect_port: int = 8000  # Default port, will be overridden from config
        self._server_task: asyncio.Task | None = None
        # Number of flows currently registered against the redirect server; guarded by _server_lock.
        self._active_mcp_flows: int = 0

    async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
        """
        Handle the OAuth2 authorization code flow for MCP environments.

        Args:
            config: OAuth2 configuration for MCP server
            method: Authentication method (only OAUTH2_AUTHORIZATION_CODE supported)

        Returns:
            AuthenticatedContext with Bearer token for MCP server access

        Raises:
            ValueError: If config is invalid for MCP use case
            NotImplementedError: If method is not OAuth2 Authorization Code
            RuntimeError: If the redirect server cannot start, the user does not finish
                in time, or the token response has no access token
        """
        logger.info("Starting MCP authentication flow")

        if method == AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
            if not isinstance(config, OAuth2AuthCodeFlowProviderConfig):
                raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")

            # MCP-specific validation
            if not config.redirect_uri:
                raise ValueError("MCP authentication requires redirect_uri to be configured")

            logger.info("MCP authentication configured for server: %s", getattr(config, 'server_url', 'unknown'))
            return await self._handle_oauth2_auth_code_flow(config)

        raise NotImplementedError(f'Auth method "{method}" not supported for MCP environments')

    @staticmethod
    def _resolve_redirect_address(redirect_uri: str) -> tuple[str, int]:
        """
        Return the ``(host, port)`` the local callback server should bind for ``redirect_uri``.

        Raises:
            ValueError: If the scheme is not http/https, the hostname is missing,
                or the port is outside 1-65535.
        """
        from urllib.parse import urlparse
        parsed_uri = urlparse(redirect_uri)

        # Validate scheme/host and choose a safe non-privileged bind port
        scheme = (parsed_uri.scheme or "http").lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"redirect_uri must use http or https scheme, got '{scheme}'")

        host = parsed_uri.hostname
        if not host:
            raise ValueError("redirect_uri must include a hostname, for example http://localhost:8000/auth/redirect")

        # Never auto-bind to 80/443; default to 8000 when port is not specified
        port = parsed_uri.port or 8000
        if not (1 <= port <= 65535):
            raise ValueError(f"Invalid redirect port: {port}. Expected 1-65535.")

        if scheme == "https" and parsed_uri.port is None:
            logger.warning(
                "redirect_uri uses https without an explicit port; binding to %d (plain HTTP). "
                "Terminate TLS at a reverse proxy and forward to this port.",
                port)

        return host, port

    def _select_redirect_address(self, host: str, port: int) -> None:
        """
        Record the address the redirect server should bind on its next start.

        The caller must hold ``_server_lock``. If a server is already running on a
        different address, the running server is kept, because other flows depend
        on it, and a warning is logged.
        """
        if self._server_controller is not None and (host, port) != (self._redirect_host, self._redirect_port):
            logger.warning(
                "MCP redirect server is already running on %s:%d; requested %s:%d will not be bound "
                "until in-flight authentications finish.",
                self._redirect_host,
                self._redirect_port,
                host,
                port)
            return
        self._redirect_host = host
        self._redirect_port = port
        logger.info("MCP redirect server will use %s:%d", self._redirect_host, self._redirect_port)

    @staticmethod
    def _open_authorization_url(auth_url: str) -> None:
        """Open ``auth_url`` in a browser, logging the URL for manual use if no browser can be launched."""
        if webbrowser.open(auth_url):
            logger.info("MCP authentication: Your browser has been opened for authentication.")
        else:
            # Headless hosts (containers, SSH sessions) have no browser; surface the URL instead.
            logger.warning(
                "MCP authentication: could not open a browser automatically. "
                "Open this URL to authenticate: %s",
                auth_url)
        logger.info("This will authenticate you with the MCP server for tool discovery.")

    async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
        """
        Run one authorization-code flow and wait for its browser callback.

        The method validates the redirect address, builds the authorization URL
        (with PKCE when enabled) and registers the flow under a fresh ``state``
        while holding ``_server_lock``. It then waits on ``flow_state.future``,
        which is presumably resolved by the redirect app from the parent class.

        Raises:
            ValueError: If ``cfg.redirect_uri`` cannot be used for a local callback server.
            RuntimeError: On redirect-server start failure, timeout, or a token without ``access_token``.
        """
        logger.info("Starting MCP OAuth2 authorization code flow")

        # Validate redirect_uri before any browser/server work so bad config fails fast.
        host, port = self._resolve_redirect_address(str(cfg.redirect_uri))

        state = secrets.token_urlsafe(16)
        flow_state = _FlowState()
        client = self.construct_oauth_client(cfg)

        flow_state.token_url = cfg.token_url
        flow_state.use_pkce = cfg.use_pkce

        # PKCE bits
        if cfg.use_pkce:
            verifier, challenge = pkce.generate_pkce_pair()
            flow_state.verifier = verifier
            flow_state.challenge = challenge
            logger.debug("PKCE enabled for MCP authentication")

        auth_url, _ = client.create_authorization_url(
            cfg.authorization_url,
            state=state,
            code_verifier=flow_state.verifier if cfg.use_pkce else None,
            code_challenge=flow_state.challenge if cfg.use_pkce else None,
            **(cfg.authorization_kwargs or {}),
        )

        async with self._server_lock:
            self._select_redirect_address(host, port)
            if self._redirect_app is None:
                self._redirect_app = await self._build_redirect_app()

            await self._start_redirect_server()
            # Register only after the server is up, so a start failure leaves no dangling flow.
            self._flows[state] = flow_state
            self._active_mcp_flows += 1

        timeout = self.auth_timeout_seconds
        try:
            self._open_authorization_url(auth_url)
            token = await asyncio.wait_for(flow_state.future, timeout=timeout)
            logger.info("MCP authentication successful, token obtained")
        except asyncio.TimeoutError as exc:
            logger.error("MCP authentication timed out")
            raise RuntimeError(f"MCP authentication timed out ({timeout:g} seconds). Please try again.") from exc
        finally:
            async with self._server_lock:
                self._flows.pop(state, None)
                self._active_mcp_flows -= 1
                # The server is shared: only the last in-flight flow may stop it.
                if self._active_mcp_flows == 0:
                    await self._stop_redirect_server()
                    self._server_task = None

        access_token = token.get("access_token")
        if not access_token:
            raise RuntimeError("MCP authentication finished but the token response contained no access_token")

        return AuthenticatedContext(
            headers={"Authorization": f"Bearer {access_token}"},
            metadata={
                "expires_at": token.get("expires_at"),
                "raw_token": token,
            },
        )

    async def _start_redirect_server(self) -> None:
        """
        Override to use the host and port from redirect_uri config instead of hardcoded localhost:8000.

        This allows MCP authentication to work with custom redirect hosts and ports
        specified in the configuration. If startup fails, the half-started server is
        discarded so that a later call retries instead of assuming it is running.

        Raises:
            RuntimeError: If the app is not built, or the server does not become ready.
        """
        # If the server is already running, do nothing
        if self._server_controller:
            return
        try:
            if not self._redirect_app:
                raise RuntimeError("Redirect app not built.")

            self._server_controller = _FastApiFrontEndController(self._redirect_app)

            self._server_task = asyncio.create_task(
                self._server_controller.start_server(host=self._redirect_host, port=self._redirect_port))
            logger.debug("MCP redirect server starting on %s:%d", self._redirect_host, self._redirect_port)

            await self._wait_for_server_ready()
        except Exception as exc:
            # Without this reset, the stale controller would make every later call return early.
            self._discard_failed_server()
            raise RuntimeError(
                f"Failed to start MCP redirect server on {self._redirect_host}:{self._redirect_port}: {exc}") from exc

    async def _wait_for_server_ready(self) -> None:
        """
        Poll until the controller's server reports ``started``.

        Fails early if the server task has already finished (for example, because
        the port is in use), and fails after ``server_ready_timeout_seconds``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.server_ready_timeout_seconds
        while True:
            server = getattr(self._server_controller, "_server", None)
            if server and getattr(server, "started", False):
                return

            task = self._server_task
            if task is not None and task.done():
                if task.cancelled():
                    raise RuntimeError("Redirect server task was cancelled before becoming ready")
                failure = task.exception()
                if failure is not None:
                    raise RuntimeError(f"Redirect server exited before becoming ready: {failure}") from failure
                raise RuntimeError("Redirect server exited before becoming ready")

            if loop.time() > deadline:
                raise RuntimeError(f"Redirect server did not report ready within {self.server_ready_timeout_seconds:g}s")
            await asyncio.sleep(0.1)

    def _discard_failed_server(self) -> None:
        """Cancel a redirect-server task that never became ready and forget its controller."""
        task = self._server_task
        self._server_task = None
        self._server_controller = None
        if task is not None and not task.done():
            task.cancel()
@@ -14,6 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ from collections.abc import Awaitable
18
+ from collections.abc import Callable
17
19
  from urllib.parse import urljoin
18
20
  from urllib.parse import urlparse
19
21
 
@@ -21,15 +23,18 @@ import httpx
21
23
  from pydantic import BaseModel
22
24
  from pydantic import Field
23
25
  from pydantic import HttpUrl
26
+ from pydantic import TypeAdapter
24
27
 
25
28
  from mcp.shared.auth import OAuthClientInformationFull
26
29
  from mcp.shared.auth import OAuthClientMetadata
27
30
  from mcp.shared.auth import OAuthMetadata
28
31
  from mcp.shared.auth import ProtectedResourceMetadata
32
+ from nat.authentication.interfaces import AuthenticatedContext
33
+ from nat.authentication.interfaces import AuthFlowType
29
34
  from nat.authentication.interfaces import AuthProviderBase
30
- from nat.data_models.authentication import AuthReason
31
- from nat.data_models.authentication import AuthRequest
35
+ from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
32
36
  from nat.data_models.authentication import AuthResult
37
+ from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
33
38
  from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
34
39
 
35
40
  logger = logging.getLogger(__name__)
@@ -40,6 +45,7 @@ class OAuth2Endpoints(BaseModel):
40
45
  authorization_url: HttpUrl = Field(..., description="OAuth2 authorization endpoint URL")
41
46
  token_url: HttpUrl = Field(..., description="OAuth2 token endpoint URL")
42
47
  registration_url: HttpUrl | None = Field(default=None, description="OAuth2 client registration endpoint URL")
48
+ scopes: list[str] | None = Field(default=None, description="OAuth2 scopes to be used for the authentication")
43
49
 
44
50
 
45
51
  class OAuth2Credentials(BaseModel):
@@ -60,9 +66,10 @@ class DiscoverOAuth2Endpoints:
60
66
  def __init__(self, config: MCPOAuth2ProviderConfig):
61
67
  self.config = config
62
68
  self._cached_endpoints: OAuth2Endpoints | None = None
63
- self._last_oauth_scopes: list[str] | None = None
64
69
 
65
- async def discover(self, reason: AuthReason, www_authenticate: str | None) -> tuple[OAuth2Endpoints, bool]:
70
+ self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
71
+
72
+ async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
66
73
  """
67
74
  Discover OAuth2 endpoints from MCP server.
68
75
 
@@ -73,21 +80,24 @@ class DiscoverOAuth2Endpoints:
73
80
  Returns:
74
81
  A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.
75
82
  """
83
+ is_401_retry = response is not None and response.status_code == 401
76
84
  # Fast path: reuse cache when not a 401 retry
77
- if reason != AuthReason.RETRY_AFTER_401 and self._cached_endpoints is not None:
85
+ if not is_401_retry and self._cached_endpoints is not None:
78
86
  return self._cached_endpoints, False
79
87
 
80
88
  issuer: str = str(self.config.server_url) # default to server URL
81
89
  endpoints: OAuth2Endpoints | None = None
82
90
 
83
91
  # 1) 401 hint (RFC 9728) if present
84
- if reason == AuthReason.RETRY_AFTER_401 and www_authenticate:
85
- hint_url = self._extract_from_www_authenticate_header(www_authenticate)
86
- if hint_url:
87
- logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
88
- issuer_hint = await self._fetch_pr_issuer(hint_url)
89
- if issuer_hint:
90
- issuer = issuer_hint
92
+ if is_401_retry and response:
93
+ www_authenticate = response.headers.get("WWW-Authenticate")
94
+ if www_authenticate:
95
+ hint_url = self._extract_from_www_authenticate_header(www_authenticate)
96
+ if hint_url:
97
+ logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
98
+ issuer_hint = await self._fetch_pr_issuer(hint_url)
99
+ if issuer_hint:
100
+ issuer = issuer_hint
91
101
 
92
102
  # 2) Try RS protected resource well-known if we still only have default issuer
93
103
  if issuer == str(self.config.server_url):
@@ -105,10 +115,7 @@ class DiscoverOAuth2Endpoints:
105
115
  if endpoints is None:
106
116
  raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")
107
117
 
108
- changed = (self._cached_endpoints is None
109
- or endpoints.authorization_url != self._cached_endpoints.authorization_url
110
- or endpoints.token_url != self._cached_endpoints.token_url
111
- or endpoints.registration_url != self._cached_endpoints.registration_url)
118
+ changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
112
119
  self._cached_endpoints = endpoints
113
120
  logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
114
121
  return self._cached_endpoints, changed
@@ -155,10 +162,29 @@ class DiscoverOAuth2Endpoints:
155
162
  async with httpx.AsyncClient(timeout=10.0) as client:
156
163
  for url in urls:
157
164
  try:
158
- resp = await client.get(url, headers={"Accept": "application/json"})
165
+ resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
159
166
  if resp.status_code != 200:
160
167
  continue
168
+
169
+ # Check content type before attempting JSON parsing
170
+ content_type = resp.headers.get("content-type", "").lower()
171
+ if "application/json" not in content_type:
172
+ logger.info(
173
+ "Discovery endpoint %s returned non-JSON content type: %s. "
174
+ "This may indicate the endpoint doesn't support discovery or requires authentication.",
175
+ url,
176
+ content_type)
177
+ # If it's HTML, log a more helpful message
178
+ if "text/html" in content_type:
179
+ logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
180
+ "This often means:")
181
+ logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
182
+ logger.info("2. The server requires authentication before providing discovery metadata")
183
+ logger.info("3. The URL is pointing to a web application instead of an OAuth server")
184
+ continue
185
+
161
186
  body = await resp.aread()
187
+
162
188
  try:
163
189
  meta = OAuthMetadata.model_validate_json(body)
164
190
  except Exception as e:
@@ -166,15 +192,21 @@ class DiscoverOAuth2Endpoints:
166
192
  continue
167
193
  if meta.authorization_endpoint and meta.token_endpoint:
168
194
  logger.info("Discovered OAuth2 endpoints from %s", url)
169
- # this is bit of a hack to get the scopes supported by the auth server
170
- self._last_oauth_scopes = meta.scopes_supported
195
+ # Convert AnyHttpUrl to HttpUrl using TypeAdapter
196
+ http_url_adapter = TypeAdapter(HttpUrl)
171
197
  return OAuth2Endpoints(
172
- authorization_url=str(meta.authorization_endpoint),
173
- token_url=str(meta.token_endpoint),
174
- registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
198
+ authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)),
199
+ token_url=http_url_adapter.validate_python(str(meta.token_endpoint)),
200
+ registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint))
201
+ if meta.registration_endpoint else None,
202
+ scopes=meta.scopes_supported,
175
203
  )
176
204
  except Exception as e:
177
205
  logger.debug("Discovery failed at %s: %s", url, e)
206
+
207
+ # If we get here, all discovery URLs failed
208
+ logger.info("OAuth discovery failed for all attempted URLs.")
209
+ logger.info("Attempted URLs: %s", urls)
178
210
  return None
179
211
 
180
212
  def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
@@ -184,17 +216,19 @@ class DiscoverOAuth2Endpoints:
184
216
  path = (p.path or "").rstrip("/")
185
217
  urls: list[str] = []
186
218
  if path:
187
- urls.append(urljoin(base, f"/.well-known/oauth-authorization-server{path}"))
219
+ # this is the specified by the MCP spec
220
+ urls.append(urljoin(base, f".well-known/oauth-protected-resource{path}"))
221
+ # this is fallback for backward compatibility
222
+ urls.append(urljoin(base, f"{path}/.well-known/oauth-authorization-server"))
188
223
  urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
189
224
  if path:
190
- urls.append(urljoin(base, f"/.well-known/openid-configuration{path}"))
225
+ # this is the specified by the MCP spec
226
+ urls.append(urljoin(base, f".well-known/openid-configuration{path}"))
227
+ # this is fallback for backward compatibility
228
+ urls.append(urljoin(base, f"{path}/.well-known/openid-configuration"))
191
229
  urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
192
230
  return urls
193
231
 
194
- def scopes_supported(self) -> list[str] | None:
195
- """Get the last OAuth scopes discovered from the AS."""
196
- return self._last_oauth_scopes
197
-
198
232
 
199
233
  class DynamicClientRegistration:
200
234
  """Dynamic client registration utility."""
@@ -251,8 +285,9 @@ class DynamicClientRegistration:
251
285
  class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
252
286
  """MCP OAuth2 authentication provider that delegates to NAT framework."""
253
287
 
254
- def __init__(self, config: MCPOAuth2ProviderConfig):
288
+ def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
255
289
  super().__init__(config)
290
+ self._builder = builder
256
291
 
257
292
  # Discovery
258
293
  self._discoverer = DiscoverOAuth2Endpoints(config)
@@ -264,51 +299,71 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
264
299
 
265
300
  # For the OAuth2 flow
266
301
  self._auth_code_provider = None
267
-
268
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
302
+ self._flow_handler = MCPAuthenticationFlowHandler()
303
+
304
+ self._auth_callback = None
305
+
306
+ # Initialize token storage
307
+ self._token_storage = None
308
+ self._token_storage_object_store_name = None
309
+
310
+ if self.config.token_storage_object_store:
311
+ # Store object store name, will be resolved later when builder context is available
312
+ self._token_storage_object_store_name = self.config.token_storage_object_store
313
+ logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage")
314
+ else:
315
+ # Default: use in-memory token storage
316
+ from .token_storage import InMemoryTokenStorage
317
+ self._token_storage = InMemoryTokenStorage()
318
+
319
+ def _set_custom_auth_callback(self,
320
+ auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
321
+ Awaitable[AuthenticatedContext]]):
322
+ """Set the custom authentication callback."""
323
+ if not self._auth_callback:
324
+ logger.info("Using custom authentication callback")
325
+ self._auth_callback = auth_callback
326
+ if self._auth_code_provider:
327
+ self._auth_code_provider._set_custom_auth_callback(self._auth_callback) # type: ignore[arg-type]
328
+
329
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
269
330
  """
270
331
  Authenticate using MCP OAuth2 flow via NAT framework.
332
+
333
+ If response is provided in kwargs (typically from a 401), performs:
271
334
  1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
272
335
  2. Client registration (RFC7591)
273
- 3. Use NAT's standard OAuth2 flow (OAuth2AuthCodeFlowProvider)
336
+ 3. Authentication
337
+
338
+ Otherwise, performs standard authentication flow.
274
339
  """
275
- auth_request = self.config.auth_request
276
- if not auth_request:
277
- auth_request = AuthRequest(reason=AuthReason.NORMAL)
278
-
279
- if auth_request.reason != AuthReason.RETRY_AFTER_401:
280
- # auth provider is expected to be setup via 401, till that time we return empty auth result
281
- if not self._auth_code_provider:
282
- return AuthResult(credentials=[], token_expires_at=None, raw={})
283
-
284
- await self._discover_and_register(auth_request)
285
- # Use NAT's standard OAuth2 flow
286
- if auth_request.reason == AuthReason.RETRY_AFTER_401:
287
- # force fresh delegate (clears in-mem token cache)
288
- self._auth_code_provider = None
289
- # preserve other fields, just normalize reason & inject user_id
290
- auth_request = auth_request.model_copy(update={
291
- "reason": AuthReason.NORMAL, "user_id": user_id, "www_authenticate": None
292
- })
293
- # back-compat: propagate user_id if provided but not set in the request
294
- elif user_id is not None and auth_request.user_id is None:
295
- auth_request = auth_request.model_copy(update={"user_id": user_id})
296
-
297
- # Perform the OAuth2 flow without lock
298
- return await self._perform_oauth2_flow(auth_request=auth_request)
299
-
300
- async def _discover_and_register(self, auth_request: AuthRequest):
340
+ if not user_id:
341
+ # MCP tool calls cannot be made without an authorized user
342
+ raise RuntimeError("User is not authorized to call the tool")
343
+
344
+ response = kwargs.get('response')
345
+ if response and response.status_code == 401:
346
+ await self._discover_and_register(response=response)
347
+
348
+ return await self._nat_oauth2_authenticate(user_id=user_id)
349
+
350
+ @property
351
+ def _effective_scopes(self) -> list[str]:
352
+ """Get the effective scopes to be used for the authentication."""
353
+ return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []
354
+
355
+ async def _discover_and_register(self, response: httpx.Response | None = None):
301
356
  """
302
357
  Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
303
358
  using OIDC client registration.
304
359
  """
305
360
  # Discover OAuth2 endpoints
306
- self._cached_endpoints, endpoints_changed = await self._discoverer.discover(reason=auth_request.reason,
307
- www_authenticate=auth_request.www_authenticate)
361
+ self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
308
362
  if endpoints_changed:
309
363
  logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
310
364
  self._cached_credentials = None # invalidate credentials tied to old AS
311
- effective_scopes = self._effective_scopes()
365
+ self._auth_code_provider = None
366
+ effective_scopes = self._effective_scopes
312
367
 
313
368
  # Client registration
314
369
  if not self._cached_credentials:
@@ -324,21 +379,36 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
324
379
  self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
325
380
  logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
326
381
 
327
- def _effective_scopes(self) -> list[str] | None:
328
- """
329
- Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
330
- """
331
- return self.config.scopes or self._discoverer.scopes_supported()
332
-
333
- async def _build_oauth2_delegate(self):
334
- """Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it"""
382
+ async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
383
+ """Perform the OAuth2 flow using MCP-specific authentication flow handler."""
335
384
  from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
336
- from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
385
+
386
+ if not self._cached_endpoints or not self._cached_credentials:
387
+ # if discovery is yet to to be done return empty auth result
388
+ return AuthResult(credentials=[], token_expires_at=None, raw={})
337
389
 
338
390
  endpoints = self._cached_endpoints
339
391
  credentials = self._cached_credentials
340
392
 
393
+ # Resolve object store reference if needed
394
+ if self._token_storage_object_store_name and not self._token_storage:
395
+ try:
396
+ if not self._builder:
397
+ raise RuntimeError("Builder not available for resolving object store")
398
+ object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name)
399
+ from .token_storage import ObjectStoreTokenStorage
400
+ self._token_storage = ObjectStoreTokenStorage(object_store)
401
+ logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'")
402
+ except Exception as e:
403
+ logger.warning(
404
+ f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. "
405
+ "Falling back to in-memory storage.")
406
+ from .token_storage import InMemoryTokenStorage
407
+ self._token_storage = InMemoryTokenStorage()
408
+
409
+ # Build the OAuth2 provider if not already built
341
410
  if self._auth_code_provider is None:
411
+ scopes = self._effective_scopes
342
412
  oauth2_config = OAuth2AuthCodeFlowProviderConfig(
343
413
  client_id=credentials.client_id,
344
414
  client_secret=credentials.client_secret or "",
@@ -346,22 +416,15 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
346
416
  token_url=str(endpoints.token_url),
347
417
  token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
348
418
  redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
349
- scopes=self._effective_scopes() or [],
419
+ scopes=scopes,
350
420
  use_pkce=bool(self.config.use_pkce),
351
- )
421
+ authorization_kwargs={"resource": str(self.config.server_url)})
422
+ self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage)
352
423
 
353
- self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
354
-
355
- async def _perform_oauth2_flow(self, auth_request: AuthRequest | None = None) -> AuthResult:
356
- """Perform the OAuth2 flow using NAT OAuth2 provider."""
357
- # This helper is only for non-401 flows
358
- if auth_request and auth_request.reason == AuthReason.RETRY_AFTER_401:
359
- raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")
360
-
361
- if not self._cached_endpoints or not self._cached_credentials:
362
- raise RuntimeError("OAuth2 flow called before discovery/registration")
424
+ # Use MCP-specific authentication method if available
425
+ if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
426
+ callback = self._auth_callback or self._flow_handler.authenticate
427
+ self._auth_code_provider._set_custom_auth_callback(callback) # type: ignore[arg-type]
363
428
 
364
- # (Re)build the delegate if needed
365
- await self._build_oauth2_delegate()
366
- # Let the delegate handle per-user cache + refresh
367
- return await self._auth_code_provider.authenticate()
429
+ # Auth code provider is responsible for per-user cache + refresh
430
+ return await self._auth_code_provider.authenticate(user_id=user_id)
@@ -18,7 +18,6 @@ from pydantic import HttpUrl
18
18
  from pydantic import model_validator
19
19
 
20
20
  from nat.authentication.interfaces import AuthProviderBaseConfig
21
- from nat.data_models.authentication import AuthRequest
22
21
 
23
22
 
24
23
  class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
@@ -51,12 +50,21 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
51
50
  # Advanced options
52
51
  use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")
53
52
 
54
- auth_request: AuthRequest | None = Field(default=None, description="Auth request for authentication (metadata)")
53
+ default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
54
+ allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
55
+
56
+ # Token storage configuration
57
+ token_storage_object_store: str | None = Field(
58
+ default=None,
59
+ description="Reference to object store for secure token storage. If None, uses in-memory storage.")
55
60
 
56
61
  @model_validator(mode="after")
57
62
  def validate_auth_config(self):
58
63
  """Validate authentication configuration for MCP-specific options."""
59
64
 
65
+ # if default_user_id is not provided, use the server_url as the default user id
66
+ if not self.default_user_id:
67
+ self.default_user_id = str(self.server_url)
60
68
  # Dynamic registration + MCP discovery
61
69
  if self.enable_dynamic_registration and not self.client_id:
62
70
  # Pure dynamic registration - no explicit credentials needed
@@ -22,4 +22,4 @@ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
22
22
  @register_auth_provider(config_type=MCPOAuth2ProviderConfig)
23
23
  async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
24
24
  """Register MCP OAuth2 authentication provider with NAT system."""
25
- yield MCPOAuth2Provider(authentication_provider)
25
+ yield MCPOAuth2Provider(authentication_provider, builder=builder)