nvidia-nat-mcp 1.3.0a20250929__py3-none-any.whl → 1.3.0a20251001__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/mcp/auth/auth_flow_handler.py +1 -1
- nat/plugins/mcp/auth/auth_provider.py +5 -1
- nat/plugins/mcp/client_base.py +145 -42
- nat/plugins/mcp/client_impl.py +74 -5
- nat/plugins/mcp/exception_handler.py +1 -1
- {nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/METADATA +2 -2
- {nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/RECORD +10 -10
- {nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/top_level.txt +0 -0
@@ -127,7 +127,7 @@ class MCPAuthenticationFlowHandler(ConsoleAuthenticationFlowHandler):
|
|
127
127
|
try:
|
128
128
|
token = await asyncio.wait_for(flow_state.future, timeout=timeout)
|
129
129
|
logger.info("MCP authentication successful, token obtained")
|
130
|
-
except
|
130
|
+
except TimeoutError as exc:
|
131
131
|
logger.error("MCP authentication timed out")
|
132
132
|
raise RuntimeError(f"MCP authentication timed out ({timeout} seconds). Please try again.") from exc
|
133
133
|
finally:
|
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
import logging
|
17
17
|
from collections.abc import Awaitable
|
18
|
-
from
|
18
|
+
from collections.abc import Callable
|
19
19
|
from urllib.parse import urljoin
|
20
20
|
from urllib.parse import urlparse
|
21
21
|
|
@@ -321,6 +321,10 @@ class MCPOAuth2Provider(AuthProviderBase[MCPOAuth2ProviderConfig]):
|
|
321
321
|
|
322
322
|
Otherwise, performs standard authentication flow.
|
323
323
|
"""
|
324
|
+
if not user_id:
|
325
|
+
# MCP tool calls cannot be made without an authorized user
|
326
|
+
raise RuntimeError("User is not authorized to call the tool")
|
327
|
+
|
324
328
|
response = kwargs.get('response')
|
325
329
|
if response and response.status_code == 401:
|
326
330
|
await self._discover_and_register(response=response)
|
nat/plugins/mcp/client_base.py
CHANGED
@@ -59,6 +59,8 @@ class AuthAdapter(httpx.Auth):
|
|
59
59
|
self.auth_provider = auth_provider
|
60
60
|
# each adapter instance has its own lock to avoid unnecessary delays for multiple clients
|
61
61
|
self._lock = anyio.Lock()
|
62
|
+
# Track whether we're currently in an interactive authentication flow
|
63
|
+
self.is_authenticating = False
|
62
64
|
|
63
65
|
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
64
66
|
"""Add authentication headers to the request using NAT auth provider."""
|
@@ -85,11 +87,21 @@ class AuthAdapter(httpx.Auth):
|
|
85
87
|
# 4. The auth headers are revoked
|
86
88
|
# 5. Auth config on the MCP server has changed
|
87
89
|
# In this case we attempt to re-run discovery and authentication
|
90
|
+
|
91
|
+
# Signal that we're entering interactive auth flow
|
92
|
+
self.is_authenticating = True
|
93
|
+
logger.debug("Starting authentication flow due to 401 response")
|
94
|
+
|
88
95
|
auth_headers = await self._get_auth_headers(request=request, response=response)
|
89
96
|
request.headers.update(auth_headers)
|
90
97
|
yield request # Retry the request
|
91
98
|
except Exception as e:
|
92
99
|
logger.info("Failed to refresh auth after 401: %s", e)
|
100
|
+
raise
|
101
|
+
finally:
|
102
|
+
# Signal that auth flow is complete
|
103
|
+
self.is_authenticating = False
|
104
|
+
logger.debug("Authentication flow completed")
|
93
105
|
return
|
94
106
|
|
95
107
|
def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
|
@@ -120,11 +132,8 @@ class AuthAdapter(httpx.Auth):
|
|
120
132
|
session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
|
121
133
|
|
122
134
|
if is_tool_call:
|
123
|
-
# Tool call requests should use the session id
|
124
|
-
|
125
|
-
user_id = session_id or self.auth_provider.config.default_user_id
|
126
|
-
else:
|
127
|
-
user_id = session_id
|
135
|
+
# Tool call requests should use the session id
|
136
|
+
user_id = session_id
|
128
137
|
else:
|
129
138
|
# Non-tool call requests should use the session id if it exists and fallback to default user id
|
130
139
|
user_id = session_id or self.auth_provider.config.default_user_id
|
@@ -151,12 +160,19 @@ class MCPBaseClient(ABC):
|
|
151
160
|
Args:
|
152
161
|
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
153
162
|
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
163
|
+
tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
|
164
|
+
auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
|
165
|
+
reconnect_enabled (bool): Whether to automatically reconnect on connection failures
|
166
|
+
reconnect_max_attempts (int): Maximum number of reconnection attempts
|
167
|
+
reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
|
168
|
+
reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts
|
154
169
|
"""
|
155
170
|
|
156
171
|
def __init__(self,
|
157
172
|
transport: str = 'streamable-http',
|
158
173
|
auth_provider: AuthProviderBase | None = None,
|
159
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
174
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
175
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
160
176
|
reconnect_enabled: bool = True,
|
161
177
|
reconnect_max_attempts: int = 2,
|
162
178
|
reconnect_initial_backoff: float = 0.5,
|
@@ -176,6 +192,7 @@ class MCPBaseClient(ABC):
|
|
176
192
|
self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
|
177
193
|
|
178
194
|
self._tool_call_timeout = tool_call_timeout
|
195
|
+
self._auth_flow_timeout = auth_flow_timeout
|
179
196
|
|
180
197
|
# Reconnect configuration
|
181
198
|
self._reconnect_enabled = reconnect_enabled
|
@@ -184,6 +201,10 @@ class MCPBaseClient(ABC):
|
|
184
201
|
self._reconnect_max_backoff = reconnect_max_backoff
|
185
202
|
self._reconnect_lock: asyncio.Lock = asyncio.Lock()
|
186
203
|
|
204
|
+
@property
|
205
|
+
def auth_provider(self) -> AuthProviderBase | None:
|
206
|
+
return self._auth_provider
|
207
|
+
|
187
208
|
@property
|
188
209
|
def transport(self) -> str:
|
189
210
|
return self._transport
|
@@ -266,12 +287,25 @@ class MCPBaseClient(ABC):
|
|
266
287
|
async def _with_reconnect(self, coro):
|
267
288
|
"""
|
268
289
|
Execute an awaited operation, reconnecting once on errors.
|
290
|
+
Does not reconnect if the error occurs during an active authentication flow.
|
269
291
|
"""
|
270
292
|
try:
|
271
293
|
return await coro()
|
272
294
|
except Exception as e:
|
295
|
+
# Check if error happened during active authentication flow
|
296
|
+
if self._httpx_auth and self._httpx_auth.is_authenticating:
|
297
|
+
# Provide specific error message for authentication timeouts
|
298
|
+
if isinstance(e, TimeoutError):
|
299
|
+
logger.error("Timeout during user authentication flow - user may have abandoned authentication")
|
300
|
+
raise RuntimeError(
|
301
|
+
"Authentication timed out. User did not complete authentication in browser within "
|
302
|
+
f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
|
303
|
+
else:
|
304
|
+
logger.error("Error during authentication flow: %s", e)
|
305
|
+
raise
|
306
|
+
|
307
|
+
# Normal error - attempt reconnect if enabled
|
273
308
|
if self._reconnect_enabled:
|
274
|
-
logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
|
275
309
|
try:
|
276
310
|
await self._reconnect()
|
277
311
|
except Exception as reconnect_err:
|
@@ -280,7 +314,49 @@ class MCPBaseClient(ABC):
|
|
280
314
|
return await coro()
|
281
315
|
raise
|
282
316
|
|
283
|
-
async def
|
317
|
+
async def _has_cached_auth_token(self) -> bool:
|
318
|
+
"""
|
319
|
+
Check if we have a cached, non-expired authentication token.
|
320
|
+
|
321
|
+
Returns:
|
322
|
+
bool: True if we have a valid cached token, False if authentication may be needed
|
323
|
+
"""
|
324
|
+
if not self._auth_provider:
|
325
|
+
return True # No auth needed
|
326
|
+
|
327
|
+
try:
|
328
|
+
# Check if OAuth2 provider has tokens cached
|
329
|
+
if hasattr(self._auth_provider, '_auth_code_provider'):
|
330
|
+
provider = self._auth_provider._auth_code_provider
|
331
|
+
if provider and hasattr(provider, '_authenticated_tokens'):
|
332
|
+
# Check if we have at least one non-expired token
|
333
|
+
for auth_result in provider._authenticated_tokens.values():
|
334
|
+
if not auth_result.is_expired():
|
335
|
+
return True
|
336
|
+
|
337
|
+
return False
|
338
|
+
except Exception:
|
339
|
+
# If we can't check, assume we need auth to be safe
|
340
|
+
return False
|
341
|
+
|
342
|
+
async def _get_tool_call_timeout(self) -> timedelta:
|
343
|
+
"""
|
344
|
+
Determine the appropriate timeout for a tool call based on authentication state.
|
345
|
+
|
346
|
+
Returns:
|
347
|
+
timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
|
348
|
+
"""
|
349
|
+
if self._auth_provider:
|
350
|
+
has_token = await self._has_cached_auth_token()
|
351
|
+
timeout = self._tool_call_timeout if has_token else self._auth_flow_timeout
|
352
|
+
if not has_token:
|
353
|
+
logger.debug("Using extended timeout (%s) for potential interactive authentication", timeout)
|
354
|
+
return timeout
|
355
|
+
else:
|
356
|
+
return self._tool_call_timeout
|
357
|
+
|
358
|
+
@mcp_exception_handler
|
359
|
+
async def get_tools(self) -> dict[str, MCPToolClient]:
|
284
360
|
"""
|
285
361
|
Retrieve a dictionary of all tools served by the MCP server.
|
286
362
|
Uses unauthenticated session for discovery.
|
@@ -288,7 +364,16 @@ class MCPBaseClient(ABC):
|
|
288
364
|
|
289
365
|
async def _get_tools():
|
290
366
|
session = self._session
|
291
|
-
|
367
|
+
try:
|
368
|
+
# Add timeout to the list_tools call.
|
369
|
+
# This is needed because MCP SDK does not support timeout for list_tools()
|
370
|
+
with anyio.fail_after(self._tool_call_timeout.total_seconds()):
|
371
|
+
tools = await session.list_tools()
|
372
|
+
except TimeoutError as e:
|
373
|
+
from nat.plugins.mcp.exceptions import MCPTimeoutError
|
374
|
+
raise MCPTimeoutError(self.server_name, e)
|
375
|
+
|
376
|
+
return tools
|
292
377
|
|
293
378
|
try:
|
294
379
|
response = await self._with_reconnect(_get_tools)
|
@@ -302,8 +387,7 @@ class MCPBaseClient(ABC):
|
|
302
387
|
tool_name=tool.name,
|
303
388
|
tool_description=tool.description,
|
304
389
|
tool_input_schema=tool.inputSchema,
|
305
|
-
parent_client=self
|
306
|
-
tool_call_timeout=self._tool_call_timeout)
|
390
|
+
parent_client=self)
|
307
391
|
for tool in response.tools
|
308
392
|
}
|
309
393
|
|
@@ -350,12 +434,7 @@ class MCPBaseClient(ABC):
|
|
350
434
|
async def _call_tool_with_meta():
|
351
435
|
params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
|
352
436
|
req = ClientRequest(CallToolRequest(params=params))
|
353
|
-
|
354
|
-
# auth is enabled.
|
355
|
-
if self._auth_provider and self._tool_call_timeout.total_seconds() < 300:
|
356
|
-
timeout = timedelta(seconds=300)
|
357
|
-
else:
|
358
|
-
timeout = self._tool_call_timeout
|
437
|
+
timeout = await self._get_tool_call_timeout()
|
359
438
|
return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
|
360
439
|
|
361
440
|
return await self._with_reconnect(_call_tool_with_meta)
|
@@ -365,7 +444,8 @@ class MCPBaseClient(ABC):
|
|
365
444
|
|
366
445
|
async def _call_tool():
|
367
446
|
session = self._session
|
368
|
-
|
447
|
+
timeout = await self._get_tool_call_timeout()
|
448
|
+
return await session.call_tool(tool_name, tool_args, read_timeout_seconds=timeout)
|
369
449
|
|
370
450
|
return await self._with_reconnect(_call_tool)
|
371
451
|
|
@@ -380,13 +460,15 @@ class MCPSSEClient(MCPBaseClient):
|
|
380
460
|
|
381
461
|
def __init__(self,
|
382
462
|
url: str,
|
383
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
463
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
464
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
384
465
|
reconnect_enabled: bool = True,
|
385
466
|
reconnect_max_attempts: int = 2,
|
386
467
|
reconnect_initial_backoff: float = 0.5,
|
387
468
|
reconnect_max_backoff: float = 50.0):
|
388
469
|
super().__init__("sse",
|
389
470
|
tool_call_timeout=tool_call_timeout,
|
471
|
+
auth_flow_timeout=auth_flow_timeout,
|
390
472
|
reconnect_enabled=reconnect_enabled,
|
391
473
|
reconnect_max_attempts=reconnect_max_attempts,
|
392
474
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -429,13 +511,15 @@ class MCPStdioClient(MCPBaseClient):
|
|
429
511
|
command: str,
|
430
512
|
args: list[str] | None = None,
|
431
513
|
env: dict[str, str] | None = None,
|
432
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
514
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
515
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
433
516
|
reconnect_enabled: bool = True,
|
434
517
|
reconnect_max_attempts: int = 2,
|
435
518
|
reconnect_initial_backoff: float = 0.5,
|
436
519
|
reconnect_max_backoff: float = 50.0):
|
437
520
|
super().__init__("stdio",
|
438
521
|
tool_call_timeout=tool_call_timeout,
|
522
|
+
auth_flow_timeout=auth_flow_timeout,
|
439
523
|
reconnect_enabled=reconnect_enabled,
|
440
524
|
reconnect_max_attempts=reconnect_max_attempts,
|
441
525
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -486,7 +570,8 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
486
570
|
def __init__(self,
|
487
571
|
url: str,
|
488
572
|
auth_provider: AuthProviderBase | None = None,
|
489
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
573
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
574
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
490
575
|
reconnect_enabled: bool = True,
|
491
576
|
reconnect_max_attempts: int = 2,
|
492
577
|
reconnect_initial_backoff: float = 0.5,
|
@@ -494,6 +579,7 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
494
579
|
super().__init__("streamable-http",
|
495
580
|
auth_provider=auth_provider,
|
496
581
|
tool_call_timeout=tool_call_timeout,
|
582
|
+
auth_flow_timeout=auth_flow_timeout,
|
497
583
|
reconnect_enabled=reconnect_enabled,
|
498
584
|
reconnect_max_attempts=reconnect_max_attempts,
|
499
585
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -536,17 +622,15 @@ class MCPToolClient:
|
|
536
622
|
|
537
623
|
def __init__(self,
|
538
624
|
session: ClientSession,
|
539
|
-
parent_client:
|
625
|
+
parent_client: MCPBaseClient,
|
540
626
|
tool_name: str,
|
541
627
|
tool_description: str | None,
|
542
|
-
tool_input_schema: dict | None = None
|
543
|
-
tool_call_timeout: timedelta = timedelta(seconds=5)):
|
628
|
+
tool_input_schema: dict | None = None):
|
544
629
|
self._session = session
|
545
630
|
self._tool_name = tool_name
|
546
631
|
self._tool_description = tool_description
|
547
632
|
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
548
633
|
self._parent_client = parent_client
|
549
|
-
self._tool_call_timeout = tool_call_timeout
|
550
634
|
|
551
635
|
if self._parent_client is None:
|
552
636
|
raise RuntimeError("MCPToolClient initialized without a parent client.")
|
@@ -578,6 +662,32 @@ class MCPToolClient:
|
|
578
662
|
"""
|
579
663
|
self._tool_description = description
|
580
664
|
|
665
|
+
def _get_session_id(self) -> str | None:
|
666
|
+
"""
|
667
|
+
Get the session id from the context.
|
668
|
+
"""
|
669
|
+
from nat.builder.context import Context as _Ctx
|
670
|
+
|
671
|
+
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
672
|
+
# on first tool call
|
673
|
+
auth_callback = _Ctx.get().user_auth_callback
|
674
|
+
if auth_callback and self._parent_client:
|
675
|
+
# set custom auth callback
|
676
|
+
self._parent_client.set_user_auth_callback(auth_callback)
|
677
|
+
|
678
|
+
# get session id from context, authentication is done per-websocket session for tool calls
|
679
|
+
session_id = None
|
680
|
+
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
681
|
+
if cookies:
|
682
|
+
session_id = cookies.get("nat-session")
|
683
|
+
|
684
|
+
if not session_id:
|
685
|
+
# use default user id if allowed
|
686
|
+
if self._parent_client.auth_provider and \
|
687
|
+
self._parent_client.auth_provider.config.allow_default_user_id_for_tool_calls:
|
688
|
+
session_id = self._parent_client.auth_provider.config.default_user_id
|
689
|
+
return session_id
|
690
|
+
|
581
691
|
async def acall(self, tool_args: dict) -> str:
|
582
692
|
"""
|
583
693
|
Call the MCP tool with the provided arguments.
|
@@ -589,31 +699,24 @@ class MCPToolClient:
|
|
589
699
|
raise RuntimeError("No session available for tool call")
|
590
700
|
|
591
701
|
# Extract context information
|
592
|
-
session_id = None
|
593
702
|
try:
|
594
|
-
|
595
|
-
|
596
|
-
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
597
|
-
# on first tool call
|
598
|
-
auth_callback = _Ctx.get().user_auth_callback
|
599
|
-
if auth_callback and self._parent_client:
|
600
|
-
# set custom auth callback
|
601
|
-
self._parent_client.set_user_auth_callback(auth_callback)
|
602
|
-
|
603
|
-
# get session id from context, authentication is done per-websocket session for tool calls
|
604
|
-
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
605
|
-
if cookies:
|
606
|
-
session_id = cookies.get("nat-session")
|
703
|
+
session_id = self._get_session_id()
|
607
704
|
except Exception:
|
608
|
-
|
705
|
+
session_id = None
|
609
706
|
|
610
707
|
try:
|
708
|
+
# if auth is enabled and session id is not available return user is not authorized to call the tool
|
709
|
+
if self._parent_client.auth_provider and not session_id:
|
710
|
+
result_str = "User is not authorized to call the tool"
|
711
|
+
mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
|
712
|
+
raise mcp_error
|
713
|
+
|
611
714
|
if session_id:
|
612
715
|
logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
|
613
716
|
result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
|
614
717
|
else:
|
615
718
|
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
616
|
-
result = await self.
|
719
|
+
result = await self._parent_client.call_tool(self._tool_name, tool_args)
|
617
720
|
|
618
721
|
output = []
|
619
722
|
for res in result.content:
|
@@ -630,6 +733,6 @@ class MCPToolClient:
|
|
630
733
|
|
631
734
|
except MCPError as e:
|
632
735
|
format_mcp_error(e, include_traceback=False)
|
633
|
-
result_str = "MCPToolClient tool call failed:
|
736
|
+
result_str = f"MCPToolClient tool call failed: {e.original_exception}"
|
634
737
|
|
635
738
|
return result_str
|
nat/plugins/mcp/client_impl.py
CHANGED
@@ -32,6 +32,50 @@ from nat.plugins.mcp.tool import mcp_tool_function
|
|
32
32
|
logger = logging.getLogger(__name__)
|
33
33
|
|
34
34
|
|
35
|
+
class MCPFunctionGroup(FunctionGroup):
|
36
|
+
"""
|
37
|
+
A specialized FunctionGroup for MCP clients that includes MCP-specific attributes
|
38
|
+
with proper type safety.
|
39
|
+
"""
|
40
|
+
|
41
|
+
def __init__(self, *args, **kwargs):
|
42
|
+
super().__init__(*args, **kwargs)
|
43
|
+
# MCP client attributes with proper typing
|
44
|
+
self._mcp_client = None # Will be set to the actual MCP client instance
|
45
|
+
self._mcp_client_server_name: str | None = None
|
46
|
+
self._mcp_client_transport: str | None = None
|
47
|
+
|
48
|
+
@property
|
49
|
+
def mcp_client(self):
|
50
|
+
"""Get the MCP client instance."""
|
51
|
+
return self._mcp_client
|
52
|
+
|
53
|
+
@mcp_client.setter
|
54
|
+
def mcp_client(self, client):
|
55
|
+
"""Set the MCP client instance."""
|
56
|
+
self._mcp_client = client
|
57
|
+
|
58
|
+
@property
|
59
|
+
def mcp_client_server_name(self) -> str | None:
|
60
|
+
"""Get the MCP client server name."""
|
61
|
+
return self._mcp_client_server_name
|
62
|
+
|
63
|
+
@mcp_client_server_name.setter
|
64
|
+
def mcp_client_server_name(self, server_name: str | None):
|
65
|
+
"""Set the MCP client server name."""
|
66
|
+
self._mcp_client_server_name = server_name
|
67
|
+
|
68
|
+
@property
|
69
|
+
def mcp_client_transport(self) -> str | None:
|
70
|
+
"""Get the MCP client transport type."""
|
71
|
+
return self._mcp_client_transport
|
72
|
+
|
73
|
+
@mcp_client_transport.setter
|
74
|
+
def mcp_client_transport(self, transport: str | None):
|
75
|
+
"""Set the MCP client transport type."""
|
76
|
+
self._mcp_client_transport = transport
|
77
|
+
|
78
|
+
|
35
79
|
class MCPToolOverrideConfig(BaseModel):
|
36
80
|
"""
|
37
81
|
Configuration for overriding tool properties when exposing from MCP server.
|
@@ -95,6 +139,10 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
|
95
139
|
tool_call_timeout: timedelta = Field(
|
96
140
|
default=timedelta(seconds=60),
|
97
141
|
description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.")
|
142
|
+
auth_flow_timeout: timedelta = Field(
|
143
|
+
default=timedelta(seconds=300),
|
144
|
+
description="Timeout (in seconds) for the MCP auth flow. When the tool call requires interactive \
|
145
|
+
authentication, this timeout is used. Defaults to 300 seconds.")
|
98
146
|
reconnect_enabled: bool = Field(
|
99
147
|
default=True,
|
100
148
|
description="Whether to enable reconnecting to the MCP server if the connection is lost. \
|
@@ -152,7 +200,8 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
152
200
|
client = MCPStdioClient(config.server.command,
|
153
201
|
config.server.args,
|
154
202
|
config.server.env,
|
155
|
-
config.tool_call_timeout,
|
203
|
+
tool_call_timeout=config.tool_call_timeout,
|
204
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
156
205
|
reconnect_enabled=config.reconnect_enabled,
|
157
206
|
reconnect_max_attempts=config.reconnect_max_attempts,
|
158
207
|
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
@@ -160,6 +209,7 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
160
209
|
elif config.server.transport == "sse":
|
161
210
|
client = MCPSSEClient(str(config.server.url),
|
162
211
|
tool_call_timeout=config.tool_call_timeout,
|
212
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
163
213
|
reconnect_enabled=config.reconnect_enabled,
|
164
214
|
reconnect_max_attempts=config.reconnect_max_attempts,
|
165
215
|
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
@@ -168,6 +218,7 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
168
218
|
client = MCPStreamableHTTPClient(str(config.server.url),
|
169
219
|
auth_provider=auth_provider,
|
170
220
|
tool_call_timeout=config.tool_call_timeout,
|
221
|
+
auth_flow_timeout=config.auth_flow_timeout,
|
171
222
|
reconnect_enabled=config.reconnect_enabled,
|
172
223
|
reconnect_max_attempts=config.reconnect_max_attempts,
|
173
224
|
reconnect_initial_backoff=config.reconnect_initial_backoff,
|
@@ -177,10 +228,16 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
177
228
|
|
178
229
|
logger.info("Configured to use MCP server at %s", client.server_name)
|
179
230
|
|
180
|
-
# Create the function group
|
181
|
-
group =
|
231
|
+
# Create the MCP function group
|
232
|
+
group = MCPFunctionGroup(config=config)
|
182
233
|
|
183
234
|
async with client:
|
235
|
+
# Expose the live MCP client on the function group instance so other components (e.g., HTTP endpoints)
|
236
|
+
# can reuse the already-established session instead of creating a new client per request.
|
237
|
+
group.mcp_client = client
|
238
|
+
group.mcp_client_server_name = client.server_name
|
239
|
+
group.mcp_client_transport = client.transport
|
240
|
+
|
184
241
|
all_tools = await client.get_tools()
|
185
242
|
tool_overrides = mcp_apply_tool_alias_and_description(all_tools, config.tool_overrides)
|
186
243
|
|
@@ -196,12 +253,24 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
|
|
196
253
|
# Create the tool function
|
197
254
|
tool_fn = mcp_tool_function(tool)
|
198
255
|
|
256
|
+
# Normalize optional typing for linter/type-checker compatibility
|
257
|
+
single_fn = tool_fn.single_fn
|
258
|
+
if single_fn is None:
|
259
|
+
# Should not happen because mcp_tool_function always sets a single_fn
|
260
|
+
logger.warning("Skipping tool %s because single_fn is None", function_name)
|
261
|
+
continue
|
262
|
+
|
263
|
+
input_schema = tool_fn.input_schema
|
264
|
+
# Convert NoneType sentinel to None for FunctionGroup.add_function signature
|
265
|
+
if input_schema is type(None): # noqa: E721
|
266
|
+
input_schema = None
|
267
|
+
|
199
268
|
# Add to group
|
200
269
|
logger.info("Adding tool %s to group", function_name)
|
201
270
|
group.add_function(name=function_name,
|
202
271
|
description=description,
|
203
|
-
fn=
|
204
|
-
input_schema=
|
272
|
+
fn=single_fn,
|
273
|
+
input_schema=input_schema,
|
205
274
|
converters=tool_fn.converters)
|
206
275
|
|
207
276
|
yield group
|
@@ -94,7 +94,7 @@ def extract_primary_exception(exceptions: list[Exception]) -> Exception:
|
|
94
94
|
"""
|
95
95
|
# Prioritize connection errors
|
96
96
|
for exc in exceptions:
|
97
|
-
if isinstance(exc,
|
97
|
+
if isinstance(exc, httpx.ConnectError | ConnectionError):
|
98
98
|
return exc
|
99
99
|
|
100
100
|
# Then timeout errors
|
{nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-mcp
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0a20251001
|
4
4
|
Summary: Subpackage for MCP client integration in NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents,mcp
|
6
6
|
Classifier: Programming Language :: Python
|
@@ -9,7 +9,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.13
|
10
10
|
Requires-Python: <3.14,>=3.11
|
11
11
|
Description-Content-Type: text/markdown
|
12
|
-
Requires-Dist: nvidia-nat==v1.3.
|
12
|
+
Requires-Dist: nvidia-nat==v1.3.0a20251001
|
13
13
|
Requires-Dist: mcp~=1.14
|
14
14
|
|
15
15
|
<!--
|
@@ -1,19 +1,19 @@
|
|
1
1
|
nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
|
2
2
|
nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
3
|
-
nat/plugins/mcp/client_base.py,sha256=
|
4
|
-
nat/plugins/mcp/client_impl.py,sha256=
|
5
|
-
nat/plugins/mcp/exception_handler.py,sha256=
|
3
|
+
nat/plugins/mcp/client_base.py,sha256=x_mgrCORXqYfJcYZg8zd0wm1AN1uFl51A7bYye0X-rc,30151
|
4
|
+
nat/plugins/mcp/client_impl.py,sha256=FGWlpzyBDR2tNkV9Ek7brSX9EWh98kfAGCr49zMzDSU,13657
|
5
|
+
nat/plugins/mcp/exception_handler.py,sha256=4JVdZDJL4LyumZEcMIEBK2LYC6djuSMzqUhQDZZ6dUo,7648
|
6
6
|
nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
|
7
7
|
nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
|
8
8
|
nat/plugins/mcp/tool.py,sha256=v3MFsiaLJy8Ourcfqa6ohtAE2Nn-vqpC6Q6gsCdJ28Q,6165
|
9
9
|
nat/plugins/mcp/utils.py,sha256=3fuzYpC14wrfMOTOGvY2KHWcxZvBWqrxdDZD17lhmC8,4055
|
10
10
|
nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
11
|
-
nat/plugins/mcp/auth/auth_flow_handler.py,sha256=
|
12
|
-
nat/plugins/mcp/auth/auth_provider.py,sha256=
|
11
|
+
nat/plugins/mcp/auth/auth_flow_handler.py,sha256=2JgK0aH-5ouQCd2ov0lDMJAD5ZWIQJ7SVcXaLArxn6Y,6010
|
12
|
+
nat/plugins/mcp/auth/auth_provider.py,sha256=OfxPCEaXuhP8anOdrTRH-_E78CrbJtzW6i81_kebpDk,19321
|
13
13
|
nat/plugins/mcp/auth/auth_provider_config.py,sha256=vhU47Vcp_30M8tWu0FumbJ6pdUnFbBZm-ABdNlup__U,3821
|
14
14
|
nat/plugins/mcp/auth/register.py,sha256=yzphsn1I4a5G39_IacbuX0ZQqGM8fevvTUM_B94UXKE,1211
|
15
|
-
nvidia_nat_mcp-1.3.
|
16
|
-
nvidia_nat_mcp-1.3.
|
17
|
-
nvidia_nat_mcp-1.3.
|
18
|
-
nvidia_nat_mcp-1.3.
|
19
|
-
nvidia_nat_mcp-1.3.
|
15
|
+
nvidia_nat_mcp-1.3.0a20251001.dist-info/METADATA,sha256=RFkgBHpMSGrticp7yB2GeO5jvQh_OVW0iZ_QZTR1nUQ,1997
|
16
|
+
nvidia_nat_mcp-1.3.0a20251001.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
17
|
+
nvidia_nat_mcp-1.3.0a20251001.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
|
18
|
+
nvidia_nat_mcp-1.3.0a20251001.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
19
|
+
nvidia_nat_mcp-1.3.0a20251001.dist-info/RECORD,,
|
File without changes
|
{nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/entry_points.txt
RENAMED
File without changes
|
{nvidia_nat_mcp-1.3.0a20250929.dist-info → nvidia_nat_mcp-1.3.0a20251001.dist-info}/top_level.txt
RENAMED
File without changes
|