nvidia-nat-mcp 1.3.0a20250925__py3-none-any.whl → 1.3.0a20250928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,144 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ import secrets
19
+ import webbrowser
20
+
21
+ import pkce
22
+ from authlib.integrations.httpx_client import AsyncOAuth2Client
23
+ from fastapi import FastAPI
24
+
25
+ from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
26
+ from nat.data_models.authentication import AuthenticatedContext
27
+ from nat.data_models.authentication import AuthFlowType
28
+ from nat.data_models.authentication import AuthProviderBaseConfig
29
+ from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
30
+ from nat.front_ends.console.authentication_flow_handler import _FlowState
31
+ from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
    """
    Authentication helper for MCP environments.

    This handler is specifically designed for MCP tool discovery scenarios where
    authentication needs to happen before the default auth_callback is available
    in the Context. It handles OAuth2 authorization code flow during MCP client
    startup and tool discovery phases.

    Key differences from console handler:
    - Only supports OAuth2 Authorization Code flow (no HTTP Basic)
    - Optimized for MCP tool discovery workflows
    - Designed for single-use authentication during startup

    Several flows may be pending at the same time; they share one local redirect
    server, which is started on demand and stopped only after the last pending
    flow has finished.
    """

    # How long (in seconds) to wait for the user to finish the browser login before failing.
    auth_timeout_seconds: float = 300

    def __init__(self):
        super().__init__()
        self._server_controller: _FastApiFrontEndController | None = None
        self._redirect_app: FastAPI | None = None
        # Guards starting/stopping the redirect server and (de)registering entries in
        # self._flows (presumably initialized by the base class; it is not created here).
        self._server_lock = asyncio.Lock()
        self._oauth_client: AsyncOAuth2Client | None = None

    async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
        """
        Handle the OAuth2 authorization code flow for MCP environments.

        Args:
            config: OAuth2 configuration for MCP server
            method: Authentication method (only OAUTH2_AUTHORIZATION_CODE supported)

        Returns:
            AuthenticatedContext with Bearer token for MCP server access

        Raises:
            ValueError: If config is invalid for MCP use case
            NotImplementedError: If method is not OAuth2 Authorization Code
            RuntimeError: If the user does not finish the login within ``auth_timeout_seconds``
        """
        logger.info("Starting MCP authentication flow")

        if method != AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
            raise NotImplementedError(f'Auth method "{method}" not supported for MCP environments')

        if not isinstance(config, OAuth2AuthCodeFlowProviderConfig):
            raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")

        # MCP-specific validation: the browser must be sent back to the local redirect server.
        if not config.redirect_uri:
            raise ValueError("MCP authentication requires redirect_uri to be configured")

        # The OAuth2 config may not carry a `server_url`; fall back to the `resource`
        # authorization parameter (RFC 8707), which the MCP provider fills with the server URL.
        server = getattr(config, "server_url", None) or (config.authorization_kwargs or {}).get("resource", "unknown")
        logger.info("MCP authentication configured for server: %s", server)
        return await self._handle_oauth2_auth_code_flow(config)

    async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
        """
        Run one browser-based authorization code flow and wait for the redirect callback.

        Args:
            cfg: Validated OAuth2 authorization code flow configuration.

        Returns:
            AuthenticatedContext carrying the Bearer access token and the raw token.

        Raises:
            RuntimeError: If the flow does not complete within ``auth_timeout_seconds``.
        """
        logger.info("Starting MCP OAuth2 authorization code flow")

        # `state` is sent to the authorization server and keys this flow in self._flows so the
        # redirect handler can resolve the matching future.
        state = secrets.token_urlsafe(16)
        flow_state = _FlowState()
        client = self.construct_oauth_client(cfg)

        flow_state.token_url = cfg.token_url
        flow_state.use_pkce = cfg.use_pkce

        # PKCE bits
        if cfg.use_pkce:
            verifier, challenge = pkce.generate_pkce_pair()
            flow_state.verifier = verifier
            flow_state.challenge = challenge
            logger.debug("PKCE enabled for MCP authentication")

        auth_url, _ = client.create_authorization_url(
            cfg.authorization_url,
            state=state,
            code_verifier=flow_state.verifier if cfg.use_pkce else None,
            code_challenge=flow_state.challenge if cfg.use_pkce else None,
            **(cfg.authorization_kwargs or {}),
        )

        async with self._server_lock:
            if self._redirect_app is None:
                self._redirect_app = await self._build_redirect_app()

            await self._start_redirect_server()
            self._flows[state] = flow_state

        timeout = self.auth_timeout_seconds

        # Everything after registration is inside try/finally so the flow entry is always removed,
        # even if opening the browser fails.
        try:
            logger.info("MCP authentication: Your browser has been opened for authentication.")
            logger.info("This will authenticate you with the MCP server for tool discovery.")
            webbrowser.open(auth_url)

            token = await asyncio.wait_for(flow_state.future, timeout=timeout)
            logger.info("MCP authentication successful, token obtained")
        except asyncio.TimeoutError as exc:
            logger.error("MCP authentication timed out")
            raise RuntimeError(f"MCP authentication timed out ({timeout} seconds). Please try again.") from exc
        finally:
            async with self._server_lock:
                self._flows.pop(state, None)
                # Other flows may still be waiting on the shared redirect server; only shut it
                # down once no flow is pending anymore.
                if not self._flows:
                    await self._stop_redirect_server()

        return AuthenticatedContext(
            headers={"Authorization": f"Bearer {token['access_token']}"},
            metadata={
                "expires_at": token.get("expires_at"),
                "raw_token": token,
            },
        )
@@ -14,6 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ from collections.abc import Awaitable
18
+ from typing import Callable
17
19
  from urllib.parse import urljoin
18
20
  from urllib.parse import urlparse
19
21
 
@@ -26,10 +28,12 @@ from mcp.shared.auth import OAuthClientInformationFull
26
28
  from mcp.shared.auth import OAuthClientMetadata
27
29
  from mcp.shared.auth import OAuthMetadata
28
30
  from mcp.shared.auth import ProtectedResourceMetadata
31
+ from nat.authentication.interfaces import AuthenticatedContext
32
+ from nat.authentication.interfaces import AuthFlowType
29
33
  from nat.authentication.interfaces import AuthProviderBase
30
- from nat.data_models.authentication import AuthReason
31
- from nat.data_models.authentication import AuthRequest
34
+ from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
32
35
  from nat.data_models.authentication import AuthResult
36
+ from nat.plugins.mcp.auth.auth_flow_handler import MCPAuthenticationFlowHandler
33
37
  from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
34
38
 
35
39
  logger = logging.getLogger(__name__)
@@ -40,6 +44,7 @@ class OAuth2Endpoints(BaseModel):
40
44
  authorization_url: HttpUrl = Field(..., description="OAuth2 authorization endpoint URL")
41
45
  token_url: HttpUrl = Field(..., description="OAuth2 token endpoint URL")
42
46
  registration_url: HttpUrl | None = Field(default=None, description="OAuth2 client registration endpoint URL")
47
+ scopes: list[str] | None = Field(default=None, description="OAuth2 scopes to be used for the authentication")
43
48
 
44
49
 
45
50
  class OAuth2Credentials(BaseModel):
@@ -60,9 +65,11 @@ class DiscoverOAuth2Endpoints:
60
65
  def __init__(self, config: MCPOAuth2ProviderConfig):
61
66
  self.config = config
62
67
  self._cached_endpoints: OAuth2Endpoints | None = None
63
- self._last_oauth_scopes: list[str] | None = None
68
+ self._authenticated_servers: dict[str, AuthResult] = {}
64
69
 
65
- async def discover(self, reason: AuthReason, www_authenticate: str | None) -> tuple[OAuth2Endpoints, bool]:
70
+ self._flow_handler: MCPAuthenticationFlowHandler = MCPAuthenticationFlowHandler()
71
+
72
+ async def discover(self, response: httpx.Response | None = None) -> tuple[OAuth2Endpoints, bool]:
66
73
  """
67
74
  Discover OAuth2 endpoints from MCP server.
68
75
 
@@ -73,21 +80,24 @@ class DiscoverOAuth2Endpoints:
73
80
  Returns:
74
81
  A tuple of OAuth2Endpoints and a boolean indicating if the endpoints have changed.
75
82
  """
83
+ is_401_retry = response is not None and response.status_code == 401
76
84
  # Fast path: reuse cache when not a 401 retry
77
- if reason != AuthReason.RETRY_AFTER_401 and self._cached_endpoints is not None:
85
+ if not is_401_retry and self._cached_endpoints is not None:
78
86
  return self._cached_endpoints, False
79
87
 
80
88
  issuer: str = str(self.config.server_url) # default to server URL
81
89
  endpoints: OAuth2Endpoints | None = None
82
90
 
83
91
  # 1) 401 hint (RFC 9728) if present
84
- if reason == AuthReason.RETRY_AFTER_401 and www_authenticate:
85
- hint_url = self._extract_from_www_authenticate_header(www_authenticate)
86
- if hint_url:
87
- logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
88
- issuer_hint = await self._fetch_pr_issuer(hint_url)
89
- if issuer_hint:
90
- issuer = issuer_hint
92
+ if is_401_retry and response:
93
+ www_authenticate = response.headers.get("WWW-Authenticate")
94
+ if www_authenticate:
95
+ hint_url = self._extract_from_www_authenticate_header(www_authenticate)
96
+ if hint_url:
97
+ logger.info("Using RFC 9728 resource_metadata hint: %s", hint_url)
98
+ issuer_hint = await self._fetch_pr_issuer(hint_url)
99
+ if issuer_hint:
100
+ issuer = issuer_hint
91
101
 
92
102
  # 2) Try RS protected resource well-known if we still only have default issuer
93
103
  if issuer == str(self.config.server_url):
@@ -105,10 +115,7 @@ class DiscoverOAuth2Endpoints:
105
115
  if endpoints is None:
106
116
  raise RuntimeError("Could not discover OAuth2 endpoints from MCP server")
107
117
 
108
- changed = (self._cached_endpoints is None
109
- or endpoints.authorization_url != self._cached_endpoints.authorization_url
110
- or endpoints.token_url != self._cached_endpoints.token_url
111
- or endpoints.registration_url != self._cached_endpoints.registration_url)
118
+ changed = (self._cached_endpoints is None or endpoints.model_dump() != self._cached_endpoints.model_dump())
112
119
  self._cached_endpoints = endpoints
113
120
  logger.info("OAuth2 endpoints selected: %s", self._cached_endpoints)
114
121
  return self._cached_endpoints, changed
@@ -155,10 +162,29 @@ class DiscoverOAuth2Endpoints:
155
162
  async with httpx.AsyncClient(timeout=10.0) as client:
156
163
  for url in urls:
157
164
  try:
158
- resp = await client.get(url, headers={"Accept": "application/json"})
165
+ resp = await client.get(url, follow_redirects=True, headers={"Accept": "application/json"})
159
166
  if resp.status_code != 200:
160
167
  continue
168
+
169
+ # Check content type before attempting JSON parsing
170
+ content_type = resp.headers.get("content-type", "").lower()
171
+ if "application/json" not in content_type:
172
+ logger.info(
173
+ "Discovery endpoint %s returned non-JSON content type: %s. "
174
+ "This may indicate the endpoint doesn't support discovery or requires authentication.",
175
+ url,
176
+ content_type)
177
+ # If it's HTML, log a more helpful message
178
+ if "text/html" in content_type:
179
+ logger.info("The endpoint appears to be returning an HTML page instead of OAuth metadata. "
180
+ "This often means:")
181
+ logger.info("1. The OAuth discovery endpoint doesn't exist at this URL")
182
+ logger.info("2. The server requires authentication before providing discovery metadata")
183
+ logger.info("3. The URL is pointing to a web application instead of an OAuth server")
184
+ continue
185
+
161
186
  body = await resp.aread()
187
+
162
188
  try:
163
189
  meta = OAuthMetadata.model_validate_json(body)
164
190
  except Exception as e:
@@ -167,14 +193,18 @@ class DiscoverOAuth2Endpoints:
167
193
  if meta.authorization_endpoint and meta.token_endpoint:
168
194
  logger.info("Discovered OAuth2 endpoints from %s", url)
169
195
  # this is bit of a hack to get the scopes supported by the auth server
170
- self._last_oauth_scopes = meta.scopes_supported
171
196
  return OAuth2Endpoints(
172
197
  authorization_url=str(meta.authorization_endpoint),
173
198
  token_url=str(meta.token_endpoint),
174
199
  registration_url=str(meta.registration_endpoint) if meta.registration_endpoint else None,
200
+ scopes=meta.scopes_supported,
175
201
  )
176
202
  except Exception as e:
177
203
  logger.debug("Discovery failed at %s: %s", url, e)
204
+
205
+ # If we get here, all discovery URLs failed
206
+ logger.info("OAuth discovery failed for all attempted URLs.")
207
+ logger.info("Attempted URLs: %s", urls)
178
208
  return None
179
209
 
180
210
  def _build_path_aware_discovery_urls(self, base_or_issuer: str) -> list[str]:
@@ -184,17 +214,19 @@ class DiscoverOAuth2Endpoints:
184
214
  path = (p.path or "").rstrip("/")
185
215
  urls: list[str] = []
186
216
  if path:
187
- urls.append(urljoin(base, f"/.well-known/oauth-authorization-server{path}"))
217
+ # this is the specified by the MCP spec
218
+ urls.append(urljoin(base, f".well-known/oauth-protected-resource{path}"))
219
+ # this is fallback for backward compatibility
220
+ urls.append(urljoin(base, f"{path}/.well-known/oauth-authorization-server"))
188
221
  urls.append(urljoin(base, "/.well-known/oauth-authorization-server"))
189
222
  if path:
190
- urls.append(urljoin(base, f"/.well-known/openid-configuration{path}"))
223
+ # this is the specified by the MCP spec
224
+ urls.append(urljoin(base, f".well-known/openid-configuration{path}"))
225
+ # this is fallback for backward compatibility
226
+ urls.append(urljoin(base, f"{path}/.well-known/openid-configuration"))
191
227
  urls.append(base_or_issuer.rstrip("/") + "/.well-known/openid-configuration")
192
228
  return urls
193
229
 
194
- def scopes_supported(self) -> list[str] | None:
195
- """Get the last OAuth scopes discovered from the AS."""
196
- return self._last_oauth_scopes
197
-
198
230
 
199
231
  class DynamicClientRegistration:
200
232
  """Dynamic client registration utility."""
@@ -264,51 +296,54 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
264
296
 
265
297
  # For the OAuth2 flow
266
298
  self._auth_code_provider = None
299
+ self._flow_handler = MCPAuthenticationFlowHandler()
267
300
 
268
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
301
+ self._auth_callback = None
302
+
303
+ def _set_custom_auth_callback(self,
304
+ auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
305
+ Awaitable[AuthenticatedContext]]):
306
+ """Set the custom authentication callback."""
307
+ if not self._auth_callback:
308
+ logger.info("Using custom authentication callback")
309
+ self._auth_callback = auth_callback
310
+ if self._auth_code_provider:
311
+ self._auth_code_provider._set_custom_auth_callback(self._auth_callback)
312
+
313
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
269
314
  """
270
315
  Authenticate using MCP OAuth2 flow via NAT framework.
316
+
317
+ If response is provided in kwargs (typically from a 401), performs:
271
318
  1. Dynamic endpoints discovery (RFC9728 + RFC 8414 + OIDC)
272
319
  2. Client registration (RFC7591)
273
- 3. Use NAT's standard OAuth2 flow (OAuth2AuthCodeFlowProvider)
320
+ 3. Authentication
321
+
322
+ Otherwise, performs standard authentication flow.
274
323
  """
275
- auth_request = self.config.auth_request
276
- if not auth_request:
277
- auth_request = AuthRequest(reason=AuthReason.NORMAL)
278
-
279
- if auth_request.reason != AuthReason.RETRY_AFTER_401:
280
- # auth provider is expected to be setup via 401, till that time we return empty auth result
281
- if not self._auth_code_provider:
282
- return AuthResult(credentials=[], token_expires_at=None, raw={})
283
-
284
- await self._discover_and_register(auth_request)
285
- # Use NAT's standard OAuth2 flow
286
- if auth_request.reason == AuthReason.RETRY_AFTER_401:
287
- # force fresh delegate (clears in-mem token cache)
288
- self._auth_code_provider = None
289
- # preserve other fields, just normalize reason & inject user_id
290
- auth_request = auth_request.model_copy(update={
291
- "reason": AuthReason.NORMAL, "user_id": user_id, "www_authenticate": None
292
- })
293
- # back-compat: propagate user_id if provided but not set in the request
294
- elif user_id is not None and auth_request.user_id is None:
295
- auth_request = auth_request.model_copy(update={"user_id": user_id})
296
-
297
- # Perform the OAuth2 flow without lock
298
- return await self._perform_oauth2_flow(auth_request=auth_request)
299
-
300
- async def _discover_and_register(self, auth_request: AuthRequest):
324
+ response = kwargs.get('response')
325
+ if response and response.status_code == 401:
326
+ await self._discover_and_register(response=response)
327
+
328
+ return await self._nat_oauth2_authenticate(user_id=user_id)
329
+
330
+ @property
331
+ def _effective_scopes(self) -> list[str]:
332
+ """Get the effective scopes to be used for the authentication."""
333
+ return self.config.scopes or (self._cached_endpoints.scopes if self._cached_endpoints else []) or []
334
+
335
+ async def _discover_and_register(self, response: httpx.Response | None = None):
301
336
  """
302
337
  Discover OAuth2 endpoints and register an OAuth2 client with the Authorization Server
303
338
  using OIDC client registration.
304
339
  """
305
340
  # Discover OAuth2 endpoints
306
- self._cached_endpoints, endpoints_changed = await self._discoverer.discover(reason=auth_request.reason,
307
- www_authenticate=auth_request.www_authenticate)
341
+ self._cached_endpoints, endpoints_changed = await self._discoverer.discover(response=response)
308
342
  if endpoints_changed:
309
343
  logger.info("OAuth2 endpoints: %s", self._cached_endpoints)
310
344
  self._cached_credentials = None # invalidate credentials tied to old AS
311
- effective_scopes = self._effective_scopes()
345
+ self._auth_code_provider = None
346
+ effective_scopes = self._effective_scopes
312
347
 
313
348
  # Client registration
314
349
  if not self._cached_credentials:
@@ -324,21 +359,20 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
324
359
  self._cached_credentials = await self._registrar.register(self._cached_endpoints, effective_scopes)
325
360
  logger.info("Registered OAuth2 client: %s", self._cached_credentials.client_id)
326
361
 
327
- def _effective_scopes(self) -> list[str] | None:
328
- """
329
- Prefer caller-provided scopes; otherwise fall back to AS-advertised scopes_supported.
330
- """
331
- return self.config.scopes or self._discoverer.scopes_supported()
332
-
333
- async def _build_oauth2_delegate(self):
334
- """Build NAT OAuth2 provider and delegate auth token acquisition and refresh to it"""
362
+ async def _nat_oauth2_authenticate(self, user_id: str | None = None) -> AuthResult:
363
+ """Perform the OAuth2 flow using MCP-specific authentication flow handler."""
335
364
  from nat.authentication.oauth2.oauth2_auth_code_flow_provider import OAuth2AuthCodeFlowProvider
336
- from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
365
+
366
+ if not self._cached_endpoints or not self._cached_credentials:
367
+ # if discovery is yet to to be done return empty auth result
368
+ return AuthResult(credentials=[], token_expires_at=None, raw={})
337
369
 
338
370
  endpoints = self._cached_endpoints
339
371
  credentials = self._cached_credentials
340
372
 
373
+ # Build the OAuth2 provider if not already built
341
374
  if self._auth_code_provider is None:
375
+ scopes = self._effective_scopes
342
376
  oauth2_config = OAuth2AuthCodeFlowProviderConfig(
343
377
  client_id=credentials.client_id,
344
378
  client_secret=credentials.client_secret or "",
@@ -346,22 +380,15 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
346
380
  token_url=str(endpoints.token_url),
347
381
  token_endpoint_auth_method=getattr(self.config, "token_endpoint_auth_method", None),
348
382
  redirect_uri=str(self.config.redirect_uri) if self.config.redirect_uri else "",
349
- scopes=self._effective_scopes() or [],
383
+ scopes=scopes,
350
384
  use_pkce=bool(self.config.use_pkce),
351
- )
352
-
385
+ authorization_kwargs={"resource": str(self.config.server_url)})
353
386
  self._auth_code_provider = OAuth2AuthCodeFlowProvider(oauth2_config)
354
387
 
355
- async def _perform_oauth2_flow(self, auth_request: AuthRequest | None = None) -> AuthResult:
356
- """Perform the OAuth2 flow using NAT OAuth2 provider."""
357
- # This helper is only for non-401 flows
358
- if auth_request and auth_request.reason == AuthReason.RETRY_AFTER_401:
359
- raise RuntimeError("_perform_oauth2_flow should not be called for RETRY_AFTER_401")
360
-
361
- if not self._cached_endpoints or not self._cached_credentials:
362
- raise RuntimeError("OAuth2 flow called before discovery/registration")
388
+ # Use MCP-specific authentication method if available
389
+ if hasattr(self._auth_code_provider, "_set_custom_auth_callback"):
390
+ self._auth_code_provider._set_custom_auth_callback(self._auth_callback
391
+ or self._flow_handler.authenticate)
363
392
 
364
- # (Re)build the delegate if needed
365
- await self._build_oauth2_delegate()
366
- # Let the delegate handle per-user cache + refresh
367
- return await self._auth_code_provider.authenticate()
393
+ # Auth code provider is responsible for per-user cache + refresh
394
+ return await self._auth_code_provider.authenticate(user_id=user_id)
@@ -18,7 +18,6 @@ from pydantic import HttpUrl
18
18
  from pydantic import model_validator
19
19
 
20
20
  from nat.authentication.interfaces import AuthProviderBaseConfig
21
- from nat.data_models.authentication import AuthRequest
22
21
 
23
22
 
24
23
  class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
@@ -51,12 +50,16 @@ class MCPOAuth2ProviderConfig(AuthProviderBaseConfig, name="mcp_oauth2"):
51
50
  # Advanced options
52
51
  use_pkce: bool = Field(default=True, description="Use PKCE for authorization code flow")
53
52
 
54
- auth_request: AuthRequest | None = Field(default=None, description="Auth request for authentication (metadata)")
53
+ default_user_id: str | None = Field(default=None, description="Default user ID for authentication")
54
+ allow_default_user_id_for_tool_calls: bool = Field(default=True, description="Allow default user ID for tool calls")
55
55
 
56
56
  @model_validator(mode="after")
57
57
  def validate_auth_config(self):
58
58
  """Validate authentication configuration for MCP-specific options."""
59
59
 
60
+ # if default_user_id is not provided, use the server_url as the default user id
61
+ if not self.default_user_id:
62
+ self.default_user_id = str(self.server_url)
60
63
  # Dynamic registration + MCP discovery
61
64
  if self.enable_dynamic_registration and not self.client_id:
62
65
  # Pure dynamic registration - no explicit credentials needed
@@ -15,13 +15,18 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import asyncio
19
+ import json
18
20
  import logging
19
21
  from abc import ABC
20
22
  from abc import abstractmethod
23
+ from collections.abc import AsyncGenerator
24
+ from collections.abc import Callable
21
25
  from contextlib import AsyncExitStack
22
26
  from contextlib import asynccontextmanager
23
- from typing import AsyncGenerator
27
+ from datetime import timedelta
24
28
 
29
+ import anyio
25
30
  import httpx
26
31
 
27
32
  from mcp import ClientSession
@@ -30,10 +35,13 @@ from mcp.client.stdio import StdioServerParameters
30
35
  from mcp.client.stdio import stdio_client
31
36
  from mcp.client.streamable_http import streamablehttp_client
32
37
  from mcp.types import TextContent
38
+ from nat.authentication.interfaces import AuthenticatedContext
39
+ from nat.authentication.interfaces import AuthFlowType
33
40
  from nat.authentication.interfaces import AuthProviderBase
34
- from nat.data_models.authentication import AuthReason
35
- from nat.data_models.authentication import AuthRequest
41
+ from nat.plugins.mcp.exception_handler import convert_to_mcp_error
42
+ from nat.plugins.mcp.exception_handler import format_mcp_error
36
43
  from nat.plugins.mcp.exception_handler import mcp_exception_handler
44
+ from nat.plugins.mcp.exceptions import MCPError
37
45
  from nat.plugins.mcp.exceptions import MCPToolNotFoundError
38
46
  from nat.plugins.mcp.utils import model_from_mcp_schema
39
47
  from nat.utils.type_utils import override
@@ -47,75 +55,89 @@ class AuthAdapter(httpx.Auth):
47
55
  Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
48
56
  """
49
57
 
50
- def __init__(self, auth_provider: AuthProviderBase, auth_for_tool_calls_only: bool = False):
58
+ def __init__(self, auth_provider: AuthProviderBase):
51
59
  self.auth_provider = auth_provider
52
- self.auth_for_tool_calls_only = auth_for_tool_calls_only
60
+ # each adapter instance has its own lock to avoid unnecessary delays for multiple clients
61
+ self._lock = anyio.Lock()
53
62
 
54
63
  async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
55
64
  """Add authentication headers to the request using NAT auth provider."""
56
- # Check if we should only auth tool calls, Is this needed?
57
- if self.auth_for_tool_calls_only and not self._is_tool_call_request(request):
58
- # Skip auth for non-tool calls
59
- yield request
60
- return
61
-
62
- try:
63
- # Get fresh auth headers from the NAT auth provider
64
- auth_headers = await self._get_auth_headers(reason=AuthReason.NORMAL)
65
- request.headers.update(auth_headers)
66
- except Exception as e:
67
- logger.info("Failed to get auth headers: %s", e)
68
- # Continue without auth headers if auth fails
69
-
70
- response = yield request
71
-
72
- # Handle 401 responses by retrying with fresh auth
73
- if response.status_code == 401:
65
+ async with self._lock:
74
66
  try:
75
- # Get fresh auth headers with 401 context
76
- auth_headers = await self._get_auth_headers(reason=AuthReason.RETRY_AFTER_401, response=response)
67
+ # Get auth headers from the NAT auth provider:
68
+ # 1. If discovery is yet to done this will return None and request will be sent without auth header.
69
+ # 2. If discovery is done, this will return the auth header from cache if the token is still valid
70
+ auth_headers = await self._get_auth_headers(request=request, response=None)
77
71
  request.headers.update(auth_headers)
78
- yield request # Retry the request
79
72
  except Exception as e:
80
- logger.info("Failed to refresh auth after 401: %s", e)
73
+ logger.info("Failed to get auth headers: %s", e)
74
+ # Continue without auth headers if auth fails
75
+
76
+ response = yield request
77
+
78
+ # Handle 401 responses by retrying with fresh auth
79
+ if response.status_code == 401:
80
+ try:
81
+ # 401 can happen if:
82
+ # 1. The request was sent without auth header
83
+ # 2. The auth headers are invalid
84
+ # 3. The auth headers are expired
85
+ # 4. The auth headers are revoked
86
+ # 5. Auth config on the MCP server has changed
87
+ # In this case we attempt to re-run discovery and authentication
88
+ auth_headers = await self._get_auth_headers(request=request, response=response)
89
+ request.headers.update(auth_headers)
90
+ yield request # Retry the request
91
+ except Exception as e:
92
+ logger.info("Failed to refresh auth after 401: %s", e)
81
93
  return
82
94
 
83
- def _is_tool_call_request(self, request: httpx.Request) -> bool:
84
- """Check if this is a tool call request based on the request body."""
95
+ def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
96
+ """Check if this is a tool call request based on the request body.
97
+ Return the session id if it exists and a boolean indicating if it is a tool call request
98
+ """
85
99
  try:
86
100
  # Check if the request body contains a tool call
87
101
  if request.content:
88
- import json
89
102
  body = json.loads(request.content.decode('utf-8'))
90
103
  # Check if it's a JSON-RPC request with method "tools/call"
91
104
  if (isinstance(body, dict) and body.get("method") == "tools/call"):
92
- return True
105
+ session_id = body.get("params").get("_meta").get("session_id")
106
+ return session_id, True
93
107
  except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
94
108
  # If we can't parse the body, assume it's not a tool call
95
109
  pass
96
- return False
110
+ return None, False
97
111
 
98
- async def _get_auth_headers(self, reason: AuthReason, response: httpx.Response | None = None) -> dict[str, str]:
112
+ async def _get_auth_headers(self,
113
+ request: httpx.Request | None = None,
114
+ response: httpx.Response | None = None) -> dict[str, str]:
99
115
  """Get authentication headers from the NAT auth provider."""
100
- # Build auth request
101
- www_authenticate = response.headers.get("WWW-Authenticate", None) if response else None
102
- auth_request = AuthRequest(
103
- reason=reason,
104
- www_authenticate=www_authenticate,
105
- )
106
116
  try:
107
- # Mutating the config is not thread-safe, so we need to lock here
108
- # Is mutating the config the only way to pass the auth request to the auth provider? This needs
109
- # to be re-visited.
110
- self.auth_provider.config.auth_request = auth_request
111
- auth_result = await self.auth_provider.authenticate()
117
+ session_id = None
118
+ is_tool_call = False
119
+ if request:
120
+ session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
121
+
122
+ if is_tool_call:
123
+ # Tool call requests should use the session id if it exists, default user id can be used if allowed
124
+ if self.auth_provider.config.allow_default_user_id_for_tool_calls:
125
+ user_id = session_id or self.auth_provider.config.default_user_id
126
+ else:
127
+ user_id = session_id
128
+ else:
129
+ # Non-tool call requests should use the session id if it exists and fallback to default user id
130
+ user_id = session_id or self.auth_provider.config.default_user_id
131
+
132
+ auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
133
+
112
134
  # Check if we have BearerTokenCred
113
135
  from nat.data_models.authentication import BearerTokenCred
114
136
  if auth_result.credentials and isinstance(auth_result.credentials[0], BearerTokenCred):
115
137
  token = auth_result.credentials[0].token.get_secret_value()
116
138
  return {"Authorization": f"Bearer {token}"}
117
139
  else:
118
- logger.warning("Auth provider did not return BearerTokenCred")
140
+ logger.info("Auth provider did not return BearerTokenCred")
119
141
  return {}
120
142
  except Exception as e:
121
143
  logger.warning("Failed to get auth token: %s", e)
@@ -131,7 +153,14 @@ class MCPBaseClient(ABC):
131
153
  auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
132
154
  """
133
155
 
134
- def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None):
156
+ def __init__(self,
157
+ transport: str = 'streamable-http',
158
+ auth_provider: AuthProviderBase | None = None,
159
+ tool_call_timeout: timedelta = timedelta(seconds=5),
160
+ reconnect_enabled: bool = True,
161
+ reconnect_max_attempts: int = 2,
162
+ reconnect_initial_backoff: float = 0.5,
163
+ reconnect_max_backoff: float = 50.0):
135
164
  self._tools = None
136
165
  self._transport = transport.lower()
137
166
  if self._transport not in ['sse', 'stdio', 'streamable-http']:
@@ -143,8 +172,18 @@ class MCPBaseClient(ABC):
143
172
  self._initial_connection = False
144
173
 
145
174
  # Convert auth provider to AuthAdapter
175
+ self._auth_provider = auth_provider
146
176
  self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
147
177
 
178
+ self._tool_call_timeout = tool_call_timeout
179
+
180
+ # Reconnect configuration
181
+ self._reconnect_enabled = reconnect_enabled
182
+ self._reconnect_max_attempts = reconnect_max_attempts
183
+ self._reconnect_initial_backoff = reconnect_initial_backoff
184
+ self._reconnect_max_backoff = reconnect_max_backoff
185
+ self._reconnect_lock: asyncio.Lock = asyncio.Lock()
186
+
148
187
  @property
149
188
  def transport(self) -> str:
150
189
  return self._transport
@@ -164,13 +203,14 @@ class MCPBaseClient(ABC):
164
203
  return self
165
204
 
166
205
  async def __aexit__(self, exc_type, exc_value, traceback):
167
- if not self._exit_stack:
168
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
206
+ if self._exit_stack:
207
+ # Close session
208
+ await self._exit_stack.aclose()
209
+ self._session = None
210
+ self._exit_stack = None
169
211
 
170
- # Close session
171
- await self._exit_stack.aclose()
172
- self._session = None
173
- self._exit_stack = None
212
+ self._connection_established = False
213
+ self._tools = None
174
214
 
175
215
  @property
176
216
  def server_name(self):
@@ -181,22 +221,80 @@ class MCPBaseClient(ABC):
181
221
 
182
222
  @abstractmethod
183
223
  @asynccontextmanager
184
- async def connect_to_server(self):
224
+ async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
185
225
  """
186
226
  Establish a session with an MCP server within an async context
187
227
  """
188
228
  yield
189
229
 
230
+ async def _reconnect(self):
231
+ """
232
+ Attempt to reconnect by tearing down and re-establishing the session.
233
+ """
234
+ async with self._reconnect_lock:
235
+ backoff = self._reconnect_initial_backoff
236
+ attempt = 0
237
+ last_error: Exception | None = None
238
+
239
+ while attempt in range(0, self._reconnect_max_attempts):
240
+ attempt += 1
241
+ try:
242
+ # Close the existing stack and ClientSession
243
+ if self._exit_stack:
244
+ await self._exit_stack.aclose()
245
+ # Create a fresh stack and session
246
+ self._exit_stack = AsyncExitStack()
247
+ self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
248
+
249
+ self._connection_established = True
250
+ self._tools = None
251
+
252
+ logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
253
+ return
254
+
255
+ except Exception as e:
256
+ last_error = e
257
+ logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
258
+ await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
259
+ backoff = min(backoff * 2, self._reconnect_max_backoff)
260
+
261
+ # All attempts failed
262
+ self._connection_established = False
263
+ if last_error:
264
+ raise last_error
265
+
266
+ async def _with_reconnect(self, coro):
267
+ """
268
+ Execute an awaited operation, reconnecting once on errors.
269
+ """
270
+ try:
271
+ return await coro()
272
+ except Exception as e:
273
+ if self._reconnect_enabled:
274
+ logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
275
+ try:
276
+ await self._reconnect()
277
+ except Exception as reconnect_err:
278
+ logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
279
+ raise
280
+ return await coro()
281
+ raise
282
+
190
283
  async def get_tools(self):
191
284
  """
192
285
  Retrieve a dictionary of all tools served by the MCP server.
193
286
  Uses unauthenticated session for discovery.
194
287
  """
195
288
 
196
- if not self._session:
197
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
289
+ async def _get_tools():
290
+ session = self._session
291
+ return await session.list_tools()
198
292
 
199
- response = await self._session.list_tools()
293
+ try:
294
+ response = await self._with_reconnect(_get_tools)
295
+ except Exception as e:
296
+ logger.warning("Failed to get tools: %s", e)
297
+ raise
200
298
 
201
299
  return {
202
300
  tool.name:
@@ -204,7 +302,8 @@ class MCPBaseClient(ABC):
204
302
  tool_name=tool.name,
205
303
  tool_description=tool.description,
206
304
  tool_input_schema=tool.inputSchema,
207
- parent_client=self)
305
+ parent_client=self,
306
+ tool_call_timeout=self._tool_call_timeout)
208
307
  for tool in response.tools
209
308
  }
210
309
 
@@ -233,13 +332,42 @@ class MCPBaseClient(ABC):
233
332
  raise MCPToolNotFoundError(tool_name, self.server_name)
234
333
  return tool
235
334
 
335
+ def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
336
+ """Set the user authentication callback."""
337
+ if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
338
+ self._auth_provider._set_custom_auth_callback(auth_callback)
339
+
236
340
  @mcp_exception_handler
237
- async def call_tool(self, tool_name: str, tool_args: dict | None):
341
+ async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
342
+ from mcp.types import CallToolRequest
343
+ from mcp.types import CallToolRequestParams
344
+ from mcp.types import CallToolResult
345
+ from mcp.types import ClientRequest
346
+
238
347
  if not self._session:
239
348
  raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
240
349
 
241
- result = await self._session.call_tool(tool_name, tool_args)
242
- return result
350
+ async def _call_tool_with_meta():
351
+ params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
352
+ req = ClientRequest(CallToolRequest(params=params))
353
+ # We will increase the timeout to 5 minutes if the tool call timeout is less than 5 min and
354
+ # auth is enabled.
355
+ if self._auth_provider and self._tool_call_timeout.total_seconds() < 300:
356
+ timeout = timedelta(seconds=300)
357
+ else:
358
+ timeout = self._tool_call_timeout
359
+ return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
360
+
361
+ return await self._with_reconnect(_call_tool_with_meta)
362
+
363
+ @mcp_exception_handler
364
+ async def call_tool(self, tool_name: str, tool_args: dict | None):
365
+
366
+ async def _call_tool():
367
+ session = self._session
368
+ return await session.call_tool(tool_name, tool_args, read_timeout_seconds=self._tool_call_timeout)
369
+
370
+ return await self._with_reconnect(_call_tool)
243
371
 
244
372
 
245
373
  class MCPSSEClient(MCPBaseClient):
@@ -250,8 +378,19 @@ class MCPSSEClient(MCPBaseClient):
250
378
  url (str): The url of the MCP server
251
379
  """
252
380
 
253
- def __init__(self, url: str):
254
- super().__init__("sse")
381
+ def __init__(self,
382
+ url: str,
383
+ tool_call_timeout: timedelta = timedelta(seconds=5),
384
+ reconnect_enabled: bool = True,
385
+ reconnect_max_attempts: int = 2,
386
+ reconnect_initial_backoff: float = 0.5,
387
+ reconnect_max_backoff: float = 50.0):
388
+ super().__init__("sse",
389
+ tool_call_timeout=tool_call_timeout,
390
+ reconnect_enabled=reconnect_enabled,
391
+ reconnect_max_attempts=reconnect_max_attempts,
392
+ reconnect_initial_backoff=reconnect_initial_backoff,
393
+ reconnect_max_backoff=reconnect_max_backoff)
255
394
  self._url = url
256
395
 
257
396
  @property
@@ -286,8 +425,21 @@ class MCPStdioClient(MCPBaseClient):
286
425
  env (dict[str, str] | None): Environment variables to set for the process
287
426
  """
288
427
 
289
- def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
290
- super().__init__("stdio")
428
+ def __init__(self,
429
+ command: str,
430
+ args: list[str] | None = None,
431
+ env: dict[str, str] | None = None,
432
+ tool_call_timeout: timedelta = timedelta(seconds=5),
433
+ reconnect_enabled: bool = True,
434
+ reconnect_max_attempts: int = 2,
435
+ reconnect_initial_backoff: float = 0.5,
436
+ reconnect_max_backoff: float = 50.0):
437
+ super().__init__("stdio",
438
+ tool_call_timeout=tool_call_timeout,
439
+ reconnect_enabled=reconnect_enabled,
440
+ reconnect_max_attempts=reconnect_max_attempts,
441
+ reconnect_initial_backoff=reconnect_initial_backoff,
442
+ reconnect_max_backoff=reconnect_max_backoff)
291
443
  self._command = command
292
444
  self._args = args
293
445
  self._env = env
@@ -331,8 +483,21 @@ class MCPStreamableHTTPClient(MCPBaseClient):
331
483
  auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
332
484
  """
333
485
 
334
- def __init__(self, url: str, auth_provider: AuthProviderBase | None = None):
335
- super().__init__("streamable-http", auth_provider=auth_provider)
486
+ def __init__(self,
487
+ url: str,
488
+ auth_provider: AuthProviderBase | None = None,
489
+ tool_call_timeout: timedelta = timedelta(seconds=5),
490
+ reconnect_enabled: bool = True,
491
+ reconnect_max_attempts: int = 2,
492
+ reconnect_initial_backoff: float = 0.5,
493
+ reconnect_max_backoff: float = 50.0):
494
+ super().__init__("streamable-http",
495
+ auth_provider=auth_provider,
496
+ tool_call_timeout=tool_call_timeout,
497
+ reconnect_enabled=reconnect_enabled,
498
+ reconnect_max_attempts=reconnect_max_attempts,
499
+ reconnect_initial_backoff=reconnect_initial_backoff,
500
+ reconnect_max_backoff=reconnect_max_backoff)
336
501
  self._url = url
337
502
 
338
503
  @property
@@ -371,15 +536,20 @@ class MCPToolClient:
371
536
 
372
537
  def __init__(self,
373
538
  session: ClientSession,
539
+ parent_client: "MCPBaseClient",
374
540
  tool_name: str,
375
541
  tool_description: str | None,
376
542
  tool_input_schema: dict | None = None,
377
- parent_client: "MCPBaseClient | None" = None):
543
+ tool_call_timeout: timedelta = timedelta(seconds=5)):
378
544
  self._session = session
379
545
  self._tool_name = tool_name
380
546
  self._tool_description = tool_description
381
547
  self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
382
548
  self._parent_client = parent_client
549
+ self._tool_call_timeout = tool_call_timeout
550
+
551
+ if self._parent_client is None:
552
+ raise RuntimeError("MCPToolClient initialized without a parent client.")
383
553
 
384
554
  @property
385
555
  def name(self):
@@ -417,20 +587,49 @@ class MCPToolClient:
417
587
  """
418
588
  if self._session is None:
419
589
  raise RuntimeError("No session available for tool call")
420
- logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
421
- result = await self._session.call_tool(self._tool_name, tool_args)
422
590
 
423
- output = []
591
+ # Extract context information
592
+ session_id = None
593
+ try:
594
+ from nat.builder.context import Context as _Ctx
595
+
596
+ # get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
597
+ # on first tool call
598
+ auth_callback = _Ctx.get().user_auth_callback
599
+ if auth_callback and self._parent_client:
600
+ # set custom auth callback
601
+ self._parent_client.set_user_auth_callback(auth_callback)
602
+
603
+ # get session id from context, authentication is done per-websocket session for tool calls
604
+ cookies = getattr(_Ctx.get().metadata, "cookies", None)
605
+ if cookies:
606
+ session_id = cookies.get("nat-session")
607
+ except Exception:
608
+ pass
424
609
 
425
- for res in result.content:
426
- if isinstance(res, TextContent):
427
- output.append(res.text)
610
+ try:
611
+ if session_id:
612
+ logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
613
+ result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
428
614
  else:
429
- # Log non-text content for now
430
- logger.warning("Got not-text output from %s of type %s", self.name, type(res))
431
- result_str = "\n".join(output)
432
-
433
- if result.isError:
434
- raise RuntimeError(result_str)
615
+ logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
616
+ result = await self._session.call_tool(self._tool_name, tool_args)
617
+
618
+ output = []
619
+ for res in result.content:
620
+ if isinstance(res, TextContent):
621
+ output.append(res.text)
622
+ else:
623
+ # Log non-text content for now
624
+ logger.warning("Got not-text output from %s of type %s", self.name, type(res))
625
+ result_str = "\n".join(output)
626
+
627
+ if result.isError:
628
+ mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
629
+ raise mcp_error
630
+
631
+ except MCPError as e:
632
+ format_mcp_error(e, include_traceback=False)
633
+ result_str = "MCPToolClient tool call failed: %s" % e.original_exception
435
634
 
436
635
  return result_str
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ from datetime import timedelta
17
18
  from typing import Literal
18
19
 
19
20
  from pydantic import BaseModel
@@ -55,7 +56,8 @@ class MCPServerConfig(BaseModel):
55
56
  env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
56
57
 
57
58
  # Authentication configuration
58
- auth_provider: AuthenticationRef | None = Field(default=None, description="Reference to authentication provider")
59
+ auth_provider: str | AuthenticationRef | None = Field(default=None,
60
+ description="Reference to authentication provider")
59
61
 
60
62
  @model_validator(mode="after")
61
63
  def validate_model(self):
@@ -90,6 +92,20 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
90
92
  Configuration for connecting to an MCP server as a client and exposing selected tools.
91
93
  """
92
94
  server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
95
+ tool_call_timeout: timedelta = Field(
96
+ default=timedelta(seconds=60),
97
+ description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.")
98
+ reconnect_enabled: bool = Field(
99
+ default=True,
100
+ description="Whether to enable reconnecting to the MCP server if the connection is lost. \
101
+ Defaults to True.")
102
+ reconnect_max_attempts: int = Field(default=2,
103
+ ge=0,
104
+ description="Maximum number of reconnect attempts. Defaults to 2.")
105
+ reconnect_initial_backoff: float = Field(
106
+ default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.")
107
+ reconnect_max_backoff: float = Field(
108
+ default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.")
93
109
  tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
94
110
  default=None,
95
111
  description="""Optional tool name overrides and description changes.
@@ -102,6 +118,13 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
102
118
  description: "Multiply two numbers" # alias defaults to original name
103
119
  """)
104
120
 
121
+ @model_validator(mode="after")
122
+ def _validate_reconnect_backoff(self) -> "MCPClientConfig":
123
+ """Validate reconnect backoff values."""
124
+ if self.reconnect_max_backoff < self.reconnect_initial_backoff:
125
+ raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
126
+ return self
127
+
105
128
 
106
129
  @register_function_group(config_type=MCPClientConfig)
107
130
  async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
@@ -126,11 +149,29 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
126
149
  if config.server.transport == "stdio":
127
150
  if not config.server.command:
128
151
  raise ValueError("command is required for stdio transport")
129
- client = MCPStdioClient(config.server.command, config.server.args, config.server.env)
152
+ client = MCPStdioClient(config.server.command,
153
+ config.server.args,
154
+ config.server.env,
155
+ config.tool_call_timeout,
156
+ reconnect_enabled=config.reconnect_enabled,
157
+ reconnect_max_attempts=config.reconnect_max_attempts,
158
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
159
+ reconnect_max_backoff=config.reconnect_max_backoff)
130
160
  elif config.server.transport == "sse":
131
- client = MCPSSEClient(str(config.server.url))
161
+ client = MCPSSEClient(str(config.server.url),
162
+ tool_call_timeout=config.tool_call_timeout,
163
+ reconnect_enabled=config.reconnect_enabled,
164
+ reconnect_max_attempts=config.reconnect_max_attempts,
165
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
166
+ reconnect_max_backoff=config.reconnect_max_backoff)
132
167
  elif config.server.transport == "streamable-http":
133
- client = MCPStreamableHTTPClient(str(config.server.url), auth_provider=auth_provider)
168
+ client = MCPStreamableHTTPClient(str(config.server.url),
169
+ auth_provider=auth_provider,
170
+ tool_call_timeout=config.tool_call_timeout,
171
+ reconnect_enabled=config.reconnect_enabled,
172
+ reconnect_max_attempts=config.reconnect_max_attempts,
173
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
174
+ reconnect_max_backoff=config.reconnect_max_backoff)
134
175
  else:
135
176
  raise ValueError(f"Unsupported transport: {config.server.transport}")
136
177
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-mcp
3
- Version: 1.3.0a20250925
3
+ Version: 1.3.0a20250928
4
4
  Summary: Subpackage for MCP client integration in NeMo Agent toolkit
5
5
  Keywords: ai,rag,agents,mcp
6
6
  Classifier: Programming Language :: Python
@@ -9,7 +9,7 @@ Classifier: Programming Language :: Python :: 3.12
9
9
  Classifier: Programming Language :: Python :: 3.13
10
10
  Requires-Python: <3.14,>=3.11
11
11
  Description-Content-Type: text/markdown
12
- Requires-Dist: nvidia-nat==v1.3.0a20250925
12
+ Requires-Dist: nvidia-nat==v1.3.0a20250928
13
13
  Requires-Dist: mcp~=1.14
14
14
 
15
15
  <!--
@@ -1,18 +1,19 @@
1
1
  nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
2
2
  nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
- nat/plugins/mcp/client_base.py,sha256=eWIHhZIaX8KGJMS7BTQ2szU2Nm4h57zCZ7uR5CmbYiY,15585
4
- nat/plugins/mcp/client_impl.py,sha256=A1rSxVz1K29ZlqY-7BRrMNRAkCVZyUg7MS6vU0stYZc,8067
3
+ nat/plugins/mcp/client_base.py,sha256=x6NHs3X2alyB6Wo4pjWZrGyuU7lLqGC_cLO9fuL4Zgw,25194
4
+ nat/plugins/mcp/client_impl.py,sha256=M1gTMlp3RLhFaAHOvwkk38boFy05MixV_glrIEcMjvo,10759
5
5
  nat/plugins/mcp/exception_handler.py,sha256=JdPdZG1NgWpdRnIz7JTGHiJASS5wot9nJiD3SRWV4Kw,7649
6
6
  nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
7
7
  nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
8
8
  nat/plugins/mcp/tool.py,sha256=v3MFsiaLJy8Ourcfqa6ohtAE2Nn-vqpC6Q6gsCdJ28Q,6165
9
9
  nat/plugins/mcp/utils.py,sha256=3fuzYpC14wrfMOTOGvY2KHWcxZvBWqrxdDZD17lhmC8,4055
10
10
  nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
11
- nat/plugins/mcp/auth/auth_provider.py,sha256=GOmM9vCfVd0QiyD_hBj7zCfkiimHa1WDfTTWWfMsr_k,17466
12
- nat/plugins/mcp/auth/auth_provider_config.py,sha256=bE6IKV0yveo98KXr0xynrH5BMwPhRv8xbaMBwYu42YQ,3587
11
+ nat/plugins/mcp/auth/auth_flow_handler.py,sha256=eRqRS2t-YzJ3kEFaG0PEC8DzctYzaJzr9XLZGkvuxq0,6018
12
+ nat/plugins/mcp/auth/auth_provider.py,sha256=5TTaPlIXMgwrE4YcZ_HO9-GNBBFaDTdUKAa5fuALayI,19142
13
+ nat/plugins/mcp/auth/auth_provider_config.py,sha256=vhU47Vcp_30M8tWu0FumbJ6pdUnFbBZm-ABdNlup__U,3821
13
14
  nat/plugins/mcp/auth/register.py,sha256=yzphsn1I4a5G39_IacbuX0ZQqGM8fevvTUM_B94UXKE,1211
14
- nvidia_nat_mcp-1.3.0a20250925.dist-info/METADATA,sha256=cZEodxrhN0o9EhFmJASwebVXNsCH8GAZEe7AujtEpao,1997
15
- nvidia_nat_mcp-1.3.0a20250925.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- nvidia_nat_mcp-1.3.0a20250925.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
17
- nvidia_nat_mcp-1.3.0a20250925.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
- nvidia_nat_mcp-1.3.0a20250925.dist-info/RECORD,,
15
+ nvidia_nat_mcp-1.3.0a20250928.dist-info/METADATA,sha256=VnuxYtll39Hir1_hDQvhqRWPM00M6A9JD3NCd7JISaw,1997
16
+ nvidia_nat_mcp-1.3.0a20250928.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
17
+ nvidia_nat_mcp-1.3.0a20250928.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
18
+ nvidia_nat_mcp-1.3.0a20250928.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
19
+ nvidia_nat_mcp-1.3.0a20250928.dist-info/RECORD,,