nvidia-nat-mcp 1.3.0a20250926__py3-none-any.whl → 1.3.0a20251111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +2 -2
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +149 -86
- nat/plugins/mcp/auth/auth_provider_config.py +10 -2
- nat/plugins/mcp/auth/register.py +1 -1
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/client_base.py +165 -71
- nat/plugins/mcp/client_config.py +131 -0
- nat/plugins/mcp/client_impl.py +469 -99
- nat/plugins/mcp/exception_handler.py +1 -1
- nat/plugins/mcp/tool.py +6 -7
- nat/plugins/mcp/utils.py +167 -34
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/METADATA +13 -4
- nvidia_nat_mcp-1.3.0a20251111.dist-info/RECORD +23 -0
- nvidia_nat_mcp-1.3.0a20251111.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.3.0a20251111.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.3.0a20250926.dist-info/RECORD +0 -18
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20250926.dist-info → nvidia_nat_mcp-1.3.0a20251111.dist-info}/top_level.txt +0 -0
nat/meta/pypi.md
CHANGED
|
@@ -19,9 +19,9 @@ limitations under the License.
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
# NVIDIA NeMo Agent Toolkit MCP Subpackage
|
|
22
|
-
Subpackage for MCP
|
|
22
|
+
Subpackage for MCP integration in NeMo Agent toolkit.
|
|
23
23
|
|
|
24
|
-
This package provides MCP (Model Context Protocol)
|
|
24
|
+
This package provides MCP (Model Context Protocol) functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
|
|
25
25
|
|
|
26
26
|
## Features
|
|
27
27
|
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
import secrets
|
|
19
|
+
import webbrowser
|
|
20
|
+
|
|
21
|
+
import pkce
|
|
22
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
23
|
+
from fastapi import FastAPI
|
|
24
|
+
|
|
25
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
|
26
|
+
from nat.data_models.authentication import AuthenticatedContext
|
|
27
|
+
from nat.data_models.authentication import AuthFlowType
|
|
28
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
29
|
+
from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
|
|
30
|
+
from nat.front_ends.console.authentication_flow_handler import _FlowState
|
|
31
|
+
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
    """
    Authentication helper for MCP environments.

    This handler is specifically designed for MCP tool discovery scenarios where
    authentication needs to happen before the default auth_callback is available
    in the Context. It handles OAuth2 authorization code flow during MCP client
    startup and tool discovery phases.

    Key differences from console handler:
    - Only supports OAuth2 Authorization Code flow (no HTTP Basic)
    - Optimized for MCP tool discovery workflows
    - Designed for single-use authentication during startup
    """

    # Seconds to wait for the redirect server to report that it is listening.
    _SERVER_READY_TIMEOUT_SECONDS = 10
    # Delay between readiness polls of the redirect server.
    _SERVER_READY_POLL_SECONDS = 0.1

    def __init__(self, auth_timeout: float = 300):
        """
        Args:
            auth_timeout: Seconds to wait for the user to finish the browser login before
                giving up. Defaults to 300, the value previously hard-coded.

        Raises:
            ValueError: If ``auth_timeout`` is not positive.
        """
        super().__init__()
        if auth_timeout <= 0:
            raise ValueError(f"auth_timeout must be positive, got {auth_timeout}")
        self._server_controller: _FastApiFrontEndController | None = None
        self._redirect_app: FastAPI | None = None
        self._server_lock = asyncio.Lock()
        self._oauth_client: AsyncOAuth2Client | None = None
        self._redirect_host: str = "localhost"  # Default host, will be overridden from config
        self._redirect_port: int = 8000  # Default port, will be overridden from config
        self._server_task: asyncio.Task | None = None
        self._auth_timeout = auth_timeout

    async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
        """
        Handle the OAuth2 authorization code flow for MCP environments.

        Args:
            config: OAuth2 configuration for MCP server
            method: Authentication method (only OAUTH2_AUTHORIZATION_CODE supported)

        Returns:
            AuthenticatedContext with Bearer token for MCP server access

        Raises:
            ValueError: If config is invalid for MCP use case
            NotImplementedError: If method is not OAuth2 Authorization Code
        """
        logger.info("Starting MCP authentication flow")

        if method == AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
            if not isinstance(config, OAuth2AuthCodeFlowProviderConfig):
                raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")

            # MCP-specific validation
            if not config.redirect_uri:
                raise ValueError("MCP authentication requires redirect_uri to be configured")

            logger.info("MCP authentication configured for server: %s", getattr(config, 'server_url', 'unknown'))
            return await self._handle_oauth2_auth_code_flow(config)

        raise NotImplementedError(f'Auth method "{method}" not supported for MCP environments')

    @staticmethod
    def _parse_redirect_bind_address(redirect_uri: str) -> tuple[str, int]:
        """
        Derive the (host, port) the local callback server binds to from ``redirect_uri``.

        Raises:
            ValueError: If the scheme is not http/https, the hostname is missing, or the
                port is outside 1-65535.
        """
        from urllib.parse import urlparse

        parsed_uri = urlparse(redirect_uri)

        # Validate scheme/host and choose a safe non-privileged bind port
        scheme = (parsed_uri.scheme or "http").lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"redirect_uri must use http or https scheme, got '{scheme}'")

        host = parsed_uri.hostname
        if not host:
            raise ValueError("redirect_uri must include a hostname, for example http://localhost:8000/auth/redirect")

        # ``.port`` itself raises ValueError for non-numeric or out-of-range ports.
        explicit_port = parsed_uri.port
        # Never auto-bind to 80/443; default to 8000 only when no port is given. An explicit
        # port 0 is rejected below rather than silently rewritten to 8000, because binding
        # to a port other than the one in redirect_uri would break the callback.
        port = 8000 if explicit_port is None else explicit_port
        if not (1 <= port <= 65535):
            raise ValueError(f"Invalid redirect port: {port}. Expected 1-65535.")

        if scheme == "https" and explicit_port is None:
            logger.warning(
                "redirect_uri uses https without an explicit port; binding to %d (plain HTTP). "
                "Terminate TLS at a reverse proxy and forward to this port.",
                port)

        return host, port

    @staticmethod
    def _open_browser(auth_url: str) -> None:
        """Open ``auth_url`` in a browser; log the URL for manual use when that is not possible."""
        try:
            opened = webbrowser.open(auth_url)
        except webbrowser.Error:
            opened = False

        if opened:
            logger.info("MCP authentication: Your browser has been opened for authentication.")
        else:
            # Headless/remote environments have no browser; the user must open the URL manually.
            logger.warning(
                "MCP authentication: unable to open a browser automatically. "
                "Open the following URL to authenticate: %s",
                auth_url)
        logger.info("This will authenticate you with the MCP server for tool discovery.")

    async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
        """
        Run one interactive authorization-code flow and return a Bearer context.

        Starts the local redirect server (if not already running), registers the flow
        under a fresh ``state`` value, opens the authorization URL, and waits up to
        ``auth_timeout`` seconds for the flow's future to produce a token (presumably
        resolved by the redirect app's callback route built in the parent class).

        Raises:
            ValueError: If ``redirect_uri`` cannot be used as a bind address.
            RuntimeError: If the redirect server fails to start or the user does not
                finish within the timeout.
        """
        logger.info("Starting MCP OAuth2 authorization code flow")

        # Extract and validate host and port from redirect_uri for callback server
        self._redirect_host, self._redirect_port = self._parse_redirect_bind_address(str(cfg.redirect_uri))
        logger.info("MCP redirect server will use %s:%d", self._redirect_host, self._redirect_port)

        state = secrets.token_urlsafe(16)
        flow_state = _FlowState()
        client = self.construct_oauth_client(cfg)

        flow_state.token_url = cfg.token_url
        flow_state.use_pkce = cfg.use_pkce

        # PKCE bits
        if cfg.use_pkce:
            verifier, challenge = pkce.generate_pkce_pair()
            flow_state.verifier = verifier
            flow_state.challenge = challenge
            logger.debug("PKCE enabled for MCP authentication")

        auth_url, _ = client.create_authorization_url(
            cfg.authorization_url,
            state=state,
            code_verifier=flow_state.verifier if cfg.use_pkce else None,
            code_challenge=flow_state.challenge if cfg.use_pkce else None,
            **(cfg.authorization_kwargs or {})
        )

        async with self._server_lock:
            if self._redirect_app is None:
                self._redirect_app = await self._build_redirect_app()

            # On failure this tears down its own partial state, so no cleanup is needed here.
            await self._start_redirect_server()
            self._flows[state] = flow_state

        timeout = self._auth_timeout

        try:
            # Inside the try so a browser failure still unregisters the flow and stops the server.
            self._open_browser(auth_url)
            try:
                token = await asyncio.wait_for(flow_state.future, timeout=timeout)
            except asyncio.TimeoutError as exc:
                # asyncio.TimeoutError is distinct from the builtin TimeoutError on Python 3.10.
                logger.error("MCP authentication timed out")
                raise RuntimeError(f"MCP authentication timed out ({timeout} seconds). Please try again.") from exc
            logger.info("MCP authentication successful, token obtained")
        finally:
            async with self._server_lock:
                self._flows.pop(state, None)
                await self._stop_redirect_server()

        return AuthenticatedContext(
            headers={"Authorization": f"Bearer {token['access_token']}"},
            metadata={
                "expires_at": token.get("expires_at"),
                "raw_token": token,
            },
        )

    async def _start_redirect_server(self) -> None:
        """
        Override to use the host and port from redirect_uri config instead of hardcoded localhost:8000.

        This allows MCP authentication to work with custom redirect hosts and ports
        specified in the configuration.

        Raises:
            RuntimeError: If the redirect app has not been built, the server start task
                fails, or the server does not report ready in time. On failure the
                partially started server is torn down and the controller/task are reset,
                so a later call retries instead of assuming a server is running.
        """
        # If the server is already running, do nothing
        if self._server_controller:
            return
        try:
            if not self._redirect_app:
                raise RuntimeError("Redirect app not built.")

            self._server_controller = _FastApiFrontEndController(self._redirect_app)

            self._server_task = asyncio.create_task(
                self._server_controller.start_server(host=self._redirect_host, port=self._redirect_port))
            logger.debug("MCP redirect server starting on %s:%d", self._redirect_host, self._redirect_port)

            await self._wait_for_server_ready()
        except Exception as exc:
            await self._abort_redirect_server()
            raise RuntimeError(
                f"Failed to start MCP redirect server on {self._redirect_host}:{self._redirect_port}: {exc}") from exc

    async def _wait_for_server_ready(self) -> None:
        """Poll until the server reports ``started``; fail fast if the start task raised."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._SERVER_READY_TIMEOUT_SECONDS
        while True:
            # Relies on the controller's private ``_server`` object exposing ``started``.
            server = getattr(self._server_controller, "_server", None)
            if server and getattr(server, "started", False):
                return

            task = self._server_task
            if task is not None and task.done() and not task.cancelled():
                error = task.exception()
                if error is not None:
                    # e.g. the port is already in use; don't wait out the full timeout.
                    raise RuntimeError(f"Redirect server task failed: {error}") from error

            if loop.time() > deadline:
                raise RuntimeError(f"Redirect server did not report ready within {self._SERVER_READY_TIMEOUT_SECONDS}s")
            await asyncio.sleep(self._SERVER_READY_POLL_SECONDS)

    async def _abort_redirect_server(self) -> None:
        """Best-effort teardown of a redirect server whose startup failed; resets state for a retry."""
        controller, task = self._server_controller, self._server_task
        self._server_controller = None
        self._server_task = None

        server = getattr(controller, "_server", None)
        if server is not None and hasattr(server, "should_exit"):
            # Ask the serving loop to stop if startup got far enough to create one.
            server.should_exit = True

        if task is not None:
            if not task.done():
                task.cancel()
            # return_exceptions=True collects the task's error/cancellation instead of raising it.
            await asyncio.gather(task, return_exceptions=True)
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
from collections.abc import Awaitable
|
|
18
|
+
from collections.abc import Callable
|
|
17
19
|
from urllib.parse import urljoin
|
|
18
20
|
from urllib.parse import urlparse
|
|
19
21
|
|
|
@@ -21,15 +23,18 @@ import httpx
|
|
|
21
23
|
from pydantic import BaseModel
|
|
22
24
|
from pydantic import Field
|
|
23
25
|
from pydantic import HttpUrl
|
|
26
|
+
from pydantic import TypeAdapter
|
|
24
27
|
|
|
25
28
|
from mcp.shared.auth import OAuthClientInformationFull
|
|
26
29
|
from mcp.shared.auth import OAuthClientMetadata
|
|
27
30
|
from mcp.shared.auth import OAuthMetadata
|
|
28
31
|
from mcp.shared.auth import ProtectedResourceMetadata
|
|
32
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
|
33
|
+
from nat.authentication.interfaces import AuthFlowType
|
|
29
34
|
from nat.authentication.interfaces import AuthProviderBase
|
|
30
|
-
from nat.
|
|
31
|
-
from nat.data_models.authentication import AuthRequest
|
|
35
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
|
32
36
|
from nat.data_models.authentication import AuthResult
|
|
37
|
+
from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
|
|
33
38
|
from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
34
39
|
|
|
35
40
|
logger = logging.getLogger(__name__)
|
|
@@ -40,6 +45,7 @@ class OAuth2Endpoints(BaseModel):
|
|
|
40
45
|
authorization_url: HttpUrl = Field(..., description="OAuth2 authorization endpoint URL")
|
|
41
46
|
token_url: HttpUrl = Field(..., description="OAuth2 token endpoint URL")
|
|
42
47
|
registration_url: HttpUrl | None = Field(default=None, description="OAuth2 client registration endpoint URL")
|
|
48
|
+
scopes: list[str] | None = Field(default=None, description="OAuth2 scopes to be used for the authentication")
|
|
43
49
|
|
|
44
50
|
|
|
45
51
|
class OAuth2Credentials(BaseModel):
|
|
@@ -60,9 +66,10 @@ class DiscoverOAuth2Endpoints:
|
|
|
60
66
|
def __init__(self, config: MCPOAuth2ProviderConfig):
|
|
61
67
|
self.config = config
|
|
62
68
|
self._cached_endpoints: OAuth2Endpoints | None = None
|
|
63
|
-
self._last_oauth_scopes: list[str] | None = None
|
|
64
69
|
|
|
65
|
-
|
|
70
|
+
self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
|
|
71
|
+
|
|
72
|
+
async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
|
|
66
73
|
"""
|
|
67
74
|
Discover OAuth2 endpoints from MCP server.
|
|
68
75
|
|
|
@@ -73,21 +80,24 @@ class DiscoverOAuth2Endpoints:
|
|
|
73
80
|
Returns:
|
|
74
81
|
A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.
|
|
75
82
|
"""
|
|
83
|
+
is_401_retry = response is not None and response.status_code == 401
|
|
76
84
|
# Fast path: reuse cache when not a 401 retry
|
|
77
|
-
if
|
|
85
|
+
if not is_401_retry and self._cached_endpoints is not None:
|
|
78
86
|
return self._cached_endpoints, False
|
|
79
87
|
|
|
80
88
|
issuer: str = str(self.config.server_url) # default to server URL
|
|
81
89
|
endpoints: OAuth2Endpoints | None = None
|
|
82
90
|
|
|
83
91
|
# 1) 401 hint (RFC 9728) if present
|
|
84
|
-
if
|
|
85
|
-
|
|
86
|
-
if
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
92
|
+
if is_401_retry and response:
|
|
93
|
+
www_authenticate = response.headers.get("WWW-Authenticate")
|
|
94
|
+
if www_authenticate:
|
|
95
|
+
hint_url = self._extract_from_www_authenticate_header(www_authenticate)
|
|
96
|
+
if hint_url:
|
|
97
|
+
logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
|
|
98
|
+
issuer_hint = await self._fetch_pr_issuer(hint_url)
|
|
99
|
+
if issuer_hint:
|
|
100
|
+
issuer = issuer_hint
|
|
91
101
|
|
|
92
102
|
# 2) Try RS protected resource well-known if we still only have default issuer
|
|
93
103
|
if issuer == str(self.config.server_url):
|
|
@@ -105,10 +115,7 @@ class DiscoverOAuth2Endpoints:
|
|
|
105
115
|
if endpoints is None:
|
|
106
116
|
raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")
|
|
107
117
|
|
|
108
|
-
changed = (self._cached_endpoints is None
|
|
109
|
-
or endpoints.authorization_url != self._cached_endpoints.authorization_url
|
|
110
|
-
or endpoints.token_url != self._cached_endpoints.token_url
|
|
111
|
-
or endpoints.registration_url != self._cached_endpoints.registration_url)
|
|
118
|
+
changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
|
|
112
119
|
self._cached_endpoints = endpoints
|
|
113
120
|
logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
|
|
114
121
|
return self._cached_endpoints, changed
|
|
@@ -155,10 +162,29 @@ class DiscoverOAuth2Endpoints:
|
|
|
155
162
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
156
163
|
for url in urls:
|
|
157
164
|
try:
|
|
158
|
-
resp = await client.get(url, headers={"Accept": "application/json"})
|
|
165
|
+
resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
|
|
159
166
|
if resp.status_code != 200:
|
|
160
167
|
continue
|
|
168
|
+
|
|
169
|
+
# Check content type before attempting JSON parsing
|
|
170
|
+
content_type = resp.headers.get("content-type", "").lower()
|
|
171
|
+
if "application/json" not in content_type:
|
|
172
|
+
logger.info(
|
|
173
|
+
"Discovery endpoint %s returned non-JSON content type: %s. "
|
|
174
|
+
"This may indicate the endpoint doesn't support discovery or requires authentication.",
|
|
175
|
+
url,
|
|
176
|
+
content_type)
|
|
177
|
+
# If it's HTML, log a more helpful message
|
|
178
|
+
if "text/html" in content_type:
|
|
179
|
+
logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
|
|
180
|
+
"This often means:")
|
|
181
|
+
logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
|
|
182
|
+
logger.info("2. The server requires authentication before providing discovery metadata")
|
|
183
|
+
logger.info("3. The URL is pointing to a web application instead of an OAuth server")
|
|
184
|
+
continue
|
|
185
|
+
|
|
161
186
|
body = await resp.aread()
|
|
187
|
+
|
|
162
188
|
try:
|
|
163
189
|
meta = OAuthMetadata.model_validate_json(body)
|
|
164
190
|
except Exception as e:
|
|
@@ -166,15 +192,21 @@ class DiscoverOAuth2Endpoints:
|
|
|
166
192
|
continue
|
|
167
193
|
if meta.authorization_endpoint and meta.token_endpoint:
|
|
168
194
|
logger.info("Discovered OAuth2 endpoints from %s", url)
|
|
169
|
-
#
|
|
170
|
-
|
|
195
|
+
# Convert AnyHttpUrl to HttpUrl using TypeAdapter
|
|
196
|
+
http_url_adapter = TypeAdapter(HttpUrl)
|
|
171
197
|
return OAuth2Endpoints(
|
|
172
|
-
authorization_url=str(meta.authorization_endpoint),
|
|
173
|
-
token_url=str(meta.token_endpoint),
|
|
174
|
-
registration_url=str(meta.registration_endpoint)
|
|
198
|
+
authorization_url=http_url_adapter.validate_python(str(meta.authorization_endpoint)),
|
|
199
|
+
token_url=http_url_adapter.validate_python(str(meta.token_endpoint)),
|
|
200
|
+
registration_url=http_url_adapter.validate_python(str(meta.registration_endpoint))
|
|
201
|
+
if meta.registration_endpoint else None,
|
|
202
|
+
scopes=meta.scopes_supported,
|
|
175
203
|
)
|
|
176
204
|
except Exception as e:
|
|
177
205
|
logger.debug("Discovery failed at %s: %s", url, e)
|
|
206
|
+
|
|
207
|
+
# If we get here, all discovery URLs failed
|
|
208
|
+
logger.info("OAuth discovery failed for all attempted URLs.")
|
|
209
|
+
logger.info("Attempted URLs: %s", urls)
|
|
178
210
|
return None
|
|
179
211
|
|
|
180
212
|
def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
|
|
@@ -184,17 +216,19 @@ class DiscoverOAuth2Endpoints:
|
|
|
184
216
|
path = (p.path or "").rstrip("/")
|
|
185
217
|
urls: list[str] = []
|
|
186
218
|
if path:
|
|
187
|
-
|
|
219
|
+
# this is the path specified by the MCP spec
|
|
220
|
+
urls.append(urljoin(base, f".well-known/oauth-protected-resource{path}"))
|
|
221
|
+
# this is a fallback for backward compatibility
|
|
222
|
+
urls.append(urljoin(base, f"{path}/.well-known/oauth-authorization-server"))
|
|
188
223
|
urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
|
|
189
224
|
if path:
|
|
190
|
-
|
|
225
|
+
# this is the path specified by the MCP spec
|
|
226
|
+
urls.append(urljoin(base, f".well-known/openid-configuration{path}"))
|
|
227
|
+
# this is a fallback for backward compatibility
|
|
228
|
+
urls.append(urljoin(base, f"{path}/.well-known/openid-configuration"))
|
|
191
229
|
urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
|
|
192
230
|
return urls
|
|
193
231
|
|
|
194
|
-
def scopes_supported(self) -> list[str] | None:
|
|
195
|
-
"""Get the last OAuth scopes discovered from the AS."""
|
|
196
|
-
return self._last_oauth_scopes
|
|
197
|
-
|
|
198
232
|
|
|
199
233
|
class DynamicClientRegistration:
|
|
200
234
|
"""Dynamic client registration utility."""
|
|
@@ -251,8 +285,9 @@ class DynamicClientRegistration:
|
|
|
251
285
|
class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
252
286
|
"""MCP OAuth2 authentication provider that delegates to NAT framework."""
|
|
253
287
|
|
|
254
|
-
def __init__(self, config: MCPOAuth2ProviderConfig):
|
|
288
|
+
def __init__(self, config: MCPOAuth2ProviderConfig, builder=None):
|
|
255
289
|
super().__init__(config)
|
|
290
|
+
self._builder = builder
|
|
256
291
|
|
|
257
292
|
# Discovery
|
|
258
293
|
self._discoverer = DiscoverOAuth2Endpoints(config)
|
|
@@ -264,51 +299,71 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
|
264
299
|
|
|
265
300
|
# For the OAuth2 flow
|
|
266
301
|
self._auth_code_provider = None
|
|
267
|
-
|
|
268
|
-
|
|
302
|
+
self._flow_handler = MCPAuthenticationFlowHandler()
|
|
303
|
+
|
|
304
|
+
self._auth_callback = None
|
|
305
|
+
|
|
306
|
+
# Initialize token storage
|
|
307
|
+
self._token_storage = None
|
|
308
|
+
self._token_storage_object_store_name = None
|
|
309
|
+
|
|
310
|
+
if self.config.token_storage_object_store:
|
|
311
|
+
# Store object store name, will be resolved later when builder context is available
|
|
312
|
+
self._token_storage_object_store_name = self.config.token_storage_object_store
|
|
313
|
+
logger.info(f"Configured to use object store '{self._token_storage_object_store_name}' for token storage")
|
|
314
|
+
else:
|
|
315
|
+
# Default: use in-memory token storage
|
|
316
|
+
from .token_storage import InMemoryTokenStorage
|
|
317
|
+
self._token_storage = InMemoryTokenStorage()
|
|
318
|
+
|
|
319
|
+
def _set_custom_auth_callback(self,
                              auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
                                                      Awaitable[AuthenticatedContext]]):
    """
    Install ``auth_callback`` as the interactive authentication callback.

    The informational log line is emitted only when no callback was previously set.
    If the delegate auth-code provider already exists, the new callback is forwarded to it
    immediately; otherwise it is stored and applied when the delegate is built.
    """
    had_callback = bool(self._auth_callback)
    if not had_callback:
        logger.info("Using custom authentication callback")

    self._auth_callback = auth_callback

    delegate = self._auth_code_provider
    if not delegate:
        return
    delegate._set_custom_auth_callback(self._auth_callback)  # type: ignore[arg-type]
|
|
328
|
+
|
|
329
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
269
330
|
"""
|
|
270
331
|
Authenticate using MCP OAuth2 flow via NAT framework.
|
|
332
|
+
|
|
333
|
+
If response is provided in kwargs (typically from a 401), performs:
|
|
271
334
|
1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
|
|
272
335
|
2. Client registration (RFC7591)
|
|
273
|
-
3.
|
|
336
|
+
3. Authentication
|
|
337
|
+
|
|
338
|
+
Otherwise, performs standard authentication flow.
|
|
274
339
|
"""
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
"reason": AuthReason.NORMAL, "user_id": user_id, "www_authenticate": None
|
|
292
|
-
})
|
|
293
|
-
# back-compat: propagate user_id if provided but not set in the request
|
|
294
|
-
elif user_id is not None and auth_request.user_id is None:
|
|
295
|
-
auth_request = auth_request.model_copy(update={"user_id": user_id})
|
|
296
|
-
|
|
297
|
-
# Perform the OAuth2 flow without lock
|
|
298
|
-
return await self._perform_oauth2_flow(auth_request=auth_request)
|
|
299
|
-
|
|
300
|
-
async def _discover_and_register(self, auth_request: AuthRequest):
|
|
340
|
+
if not user_id:
|
|
341
|
+
# MCP tool calls cannot be made without an authorized user
|
|
342
|
+
raise RuntimeError("User is not authorized to call the tool")
|
|
343
|
+
|
|
344
|
+
response = kwargs.get('response')
|
|
345
|
+
if response and response.status_code == 401:
|
|
346
|
+
await self._discover_and_register(response=response)
|
|
347
|
+
|
|
348
|
+
return await self._nat_oauth2_authenticate(user_id=user_id)
|
|
349
|
+
|
|
350
|
+
@property
|
|
351
|
+
def _effective_scopes(self) -> list[str]:
|
|
352
|
+
"""Get the effective scopes to be used for the authentication."""
|
|
353
|
+
return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []
|
|
354
|
+
|
|
355
|
+
async def _discover_and_register(self, response: httpx.Response | None = None):
|
|
301
356
|
"""
|
|
302
357
|
Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
|
|
303
358
|
using OIDC client registration.
|
|
304
359
|
"""
|
|
305
360
|
# Discover OAuth2 endpoints
|
|
306
|
-
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(
|
|
307
|
-
www_authenticate=auth_request.www_authenticate)
|
|
361
|
+
self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
|
|
308
362
|
if endpoints_changed:
|
|
309
363
|
logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
|
|
310
364
|
self._cached_credentials = None # invalidate credentials tied to old AS
|
|
311
|
-
|
|
365
|
+
self._auth_code_provider = None
|
|
366
|
+
effective_scopes = self._effective_scopes
|
|
312
367
|
|
|
313
368
|
# Client registration
|
|
314
369
|
if not self._cached_credentials:
|
|
@@ -324,21 +379,36 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
|
324
379
|
self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
|
|
325
380
|
logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
|
|
326
381
|
|
|
327
|
-
def
|
|
328
|
-
"""
|
|
329
|
-
Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
|
|
330
|
-
"""
|
|
331
|
-
return self.config.scopes or self._discoverer.scopes_supported()
|
|
332
|
-
|
|
333
|
-
async def _build_oauth2_delegate(self):
|
|
334
|
-
"""Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it"""
|
|
382
|
+
async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
|
|
383
|
+
"""Perform the OAuth2 flow using MCP-specific authentication flow handler."""
|
|
335
384
|
from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
|
|
336
|
-
|
|
385
|
+
|
|
386
|
+
if not self._cached_endpoints or not self._cached_credentials:
|
|
387
|
+
# if discovery has not been done yet, return an empty auth result
|
|
388
|
+
return AuthResult(credentials=[], token_expires_at=None, raw={})
|
|
337
389
|
|
|
338
390
|
endpoints = self._cached_endpoints
|
|
339
391
|
credentials = self._cached_credentials
|
|
340
392
|
|
|
393
|
+
# Resolve object store reference if needed
|
|
394
|
+
if self._token_storage_object_store_name and not self._token_storage:
|
|
395
|
+
try:
|
|
396
|
+
if not self._builder:
|
|
397
|
+
raise RuntimeError("Builder not available for resolving object store")
|
|
398
|
+
object_store = await self._builder.get_object_store_client(self._token_storage_object_store_name)
|
|
399
|
+
from .token_storage import ObjectStoreTokenStorage
|
|
400
|
+
self._token_storage = ObjectStoreTokenStorage(object_store)
|
|
401
|
+
logger.info(f"Initialized token storage with object store '{self._token_storage_object_store_name}'")
|
|
402
|
+
except Exception as e:
|
|
403
|
+
logger.warning(
|
|
404
|
+
f"Failed to resolve object store '{self._token_storage_object_store_name}' for token storage: {e}. "
|
|
405
|
+
"Falling back to in-memory storage.")
|
|
406
|
+
from .token_storage import InMemoryTokenStorage
|
|
407
|
+
self._token_storage = InMemoryTokenStorage()
|
|
408
|
+
|
|
409
|
+
# Build the OAuth2 provider if not already built
|
|
341
410
|
if self._auth_code_provider is None:
|
|
411
|
+
scopes = self._effective_scopes
|
|
342
412
|
oauth2_config = OAuth2AuthCodeFlowProviderConfig(
|
|
343
413
|
client_id=credentials.client_id,
|
|
344
414
|
client_secret=credentials.client_secret or "",
|
|
@@ -346,22 +416,15 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
|
346
416
|
token_url=str(endpoints.token_url),
|
|
347
417
|
token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
|
|
348
418
|
redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
|
|
349
|
-
scopes=
|
|
419
|
+
scopes=scopes,
|
|
350
420
|
use_pkce=bool(self.config.use_pkce),
|
|
351
|
-
|
|
421
|
+
authorization_kwargs={"resource": str(self.config.server_url)})
|
|
422
|
+
self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config, token_storage=self._token_storage)
|
|
352
423
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
# This helper is only for non-401 flows
|
|
358
|
-
if auth_request and auth_request.reason == AuthReason.RETRY_AFTER_401:
|
|
359
|
-
raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")
|
|
360
|
-
|
|
361
|
-
if not self._cached_endpoints or not self._cached_credentials:
|
|
362
|
-
raise RuntimeError("OAuth2 flow called before discovery/registration")
|
|
424
|
+
# Use MCP-specific authentication method if available
|
|
425
|
+
if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
|
|
426
|
+
callback = self._auth_callback or self._flow_handler.authenticate
|
|
427
|
+
self._auth_code_provider._set_custom_auth_callback(callback) # type: ignore[arg-type]
|
|
363
428
|
|
|
364
|
-
#
|
|
365
|
-
await self.
|
|
366
|
-
# Let the delegate handle per-user cache + refresh
|
|
367
|
-
return await self._auth_code_provider.authenticate()
|
|
429
|
+
# Auth code provider is responsible for per-user cache + refresh
|
|
430
|
+
return await self._auth_code_provider.authenticate(user_id=user_id)
|
|
@@ -18,7 +18,6 @@ from pydantic import HttpUrl
|
|
|
18
18
|
from pydantic import model_validator
|
|
19
19
|
|
|
20
20
|
from nat.authentication.interfaces import AuthProviderBaseConfig
|
|
21
|
-
from nat.data_models.authentication import AuthRequest
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
|
@@ -51,12 +50,21 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
|
|
|
51
50
|
# Advanced options
|
|
52
51
|
use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")
|
|
53
52
|
|
|
54
|
-
|
|
53
|
+
default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
|
|
54
|
+
allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
|
|
55
|
+
|
|
56
|
+
# Token storage configuration
|
|
57
|
+
token_storage_object_store: str | None = Field(
|
|
58
|
+
default=None,
|
|
59
|
+
description="Reference to object store for secure token storage. If None, uses in-memory storage.")
|
|
55
60
|
|
|
56
61
|
@model_validator(mode="after")
|
|
57
62
|
def validate_auth_config(self):
|
|
58
63
|
"""Validate authentication configuration for MCP-specific options."""
|
|
59
64
|
|
|
65
|
+
# if default_user_id is not provided, use the server_url as the default user id
|
|
66
|
+
if not self.default_user_id:
|
|
67
|
+
self.default_user_id = str(self.server_url)
|
|
60
68
|
# Dynamic registration + MCP discovery
|
|
61
69
|
if self.enable_dynamic_registration and not self.client_id:
|
|
62
70
|
# Pure dynamic registration - no explicit credentials needed
|
nat/plugins/mcp/auth/register.py
CHANGED
|
@@ -22,4 +22,4 @@ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
|
|
|
22
22
|
@register_auth_provider(config_type=MCPOAuth2ProviderConfig)
|
|
23
23
|
async def mcp_oauth2_provider(authentication_provider: MCPOAuth2ProviderConfig, builder: Builder):
|
|
24
24
|
"""Register MCP OAuth2 authentication provider with NAT system."""
|
|
25
|
-
yield MCPOAuth2Provider(authentication_provider)
|
|
25
|
+
yield MCPOAuth2Provider(authentication_provider, builder=builder)
|