nvidia-nat-mcp 1.3.0rc1__py3-none-any.whl → 1.4.0a20251008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +2 -2
- nat/plugins/mcp/auth/auth_provider.py +42 -10
- nat/plugins/mcp/auth/auth_provider_config.py +5 -0
- nat/plugins/mcp/auth/register.py +1 -1
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/client_base.py +103 -111
- nat/plugins/mcp/client_config.py +131 -0
- nat/plugins/mcp/client_impl.py +293 -104
- nat/plugins/mcp/tool.py +5 -0
- nat/plugins/mcp/utils.py +16 -0
- {nvidia_nat_mcp-1.3.0rc1.dist-info → nvidia_nat_mcp-1.4.0a20251008.dist-info}/METADATA +5 -4
- nvidia_nat_mcp-1.4.0a20251008.dist-info/RECORD +21 -0
- nvidia_nat_mcp-1.3.0rc1.dist-info/RECORD +0 -19
- {nvidia_nat_mcp-1.3.0rc1.dist-info → nvidia_nat_mcp-1.4.0a20251008.dist-info}/WHEEL +0 -0
- {nvidia_nat_mcp-1.3.0rc1.dist-info → nvidia_nat_mcp-1.4.0a20251008.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_mcp-1.3.0rc1.dist-info → nvidia_nat_mcp-1.4.0a20251008.dist-info}/top_level.txt +0 -0
nat/plugins/mcp/client_base.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import asyncio
|
19
|
-
import json
|
20
19
|
import logging
|
21
20
|
from abc import ABC
|
22
21
|
from abc import abstractmethod
|
@@ -55,10 +54,13 @@ class AuthAdapter(httpx.Auth):
|
|
55
54
|
Converts AuthProviderBase to httpx.Auth interface for dynamic token management.
|
56
55
|
"""
|
57
56
|
|
58
|
-
def __init__(self, auth_provider: AuthProviderBase):
|
57
|
+
def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
|
59
58
|
self.auth_provider = auth_provider
|
59
|
+
self.user_id = user_id # Session-specific user ID for cache isolation
|
60
60
|
# each adapter instance has its own lock to avoid unnecessary delays for multiple clients
|
61
61
|
self._lock = anyio.Lock()
|
62
|
+
# Track whether we're currently in an interactive authentication flow
|
63
|
+
self.is_authenticating = False
|
62
64
|
|
63
65
|
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
64
66
|
"""Add authentication headers to the request using NAT auth provider."""
|
@@ -85,48 +87,30 @@ class AuthAdapter(httpx.Auth):
|
|
85
87
|
# 4. The auth headers are revoked
|
86
88
|
# 5. Auth config on the MCP server has changed
|
87
89
|
# In this case we attempt to re-run discovery and authentication
|
90
|
+
|
91
|
+
# Signal that we're entering interactive auth flow
|
92
|
+
self.is_authenticating = True
|
93
|
+
logger.debug("Starting authentication flow due to 401 response")
|
94
|
+
|
88
95
|
auth_headers = await self._get_auth_headers(request=request, response=response)
|
89
96
|
request.headers.update(auth_headers)
|
90
97
|
yield request # Retry the request
|
91
98
|
except Exception as e:
|
92
99
|
logger.info("Failed to refresh auth after 401: %s", e)
|
100
|
+
raise
|
101
|
+
finally:
|
102
|
+
# Signal that auth flow is complete
|
103
|
+
self.is_authenticating = False
|
104
|
+
logger.debug("Authentication flow completed")
|
93
105
|
return
|
94
106
|
|
95
|
-
def _get_session_id_from_tool_call_request(self, request: httpx.Request) -> tuple[str | None, bool]:
|
96
|
-
"""Check if this is a tool call request based on the request body.
|
97
|
-
Return the session id if it exists and a boolean indicating if it is a tool call request
|
98
|
-
"""
|
99
|
-
try:
|
100
|
-
# Check if the request body contains a tool call
|
101
|
-
if request.content:
|
102
|
-
body = json.loads(request.content.decode('utf-8'))
|
103
|
-
# Check if it's a JSON-RPC request with method "tools/call"
|
104
|
-
if (isinstance(body, dict) and body.get("method") == "tools/call"):
|
105
|
-
session_id = body.get("params").get("_meta").get("session_id")
|
106
|
-
return session_id, True
|
107
|
-
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
|
108
|
-
# If we can't parse the body, assume it's not a tool call
|
109
|
-
pass
|
110
|
-
return None, False
|
111
|
-
|
112
107
|
async def _get_auth_headers(self,
|
113
108
|
request: httpx.Request | None = None,
|
114
109
|
response: httpx.Response | None = None) -> dict[str, str]:
|
115
110
|
"""Get authentication headers from the NAT auth provider."""
|
116
111
|
try:
|
117
|
-
|
118
|
-
|
119
|
-
if request:
|
120
|
-
session_id, is_tool_call = self._get_session_id_from_tool_call_request(request)
|
121
|
-
|
122
|
-
if is_tool_call:
|
123
|
-
# Tool call requests should use the session id
|
124
|
-
user_id = session_id
|
125
|
-
else:
|
126
|
-
# Non-tool call requests should use the session id if it exists and fallback to default user id
|
127
|
-
user_id = session_id or self.auth_provider.config.default_user_id
|
128
|
-
|
129
|
-
auth_result = await self.auth_provider.authenticate(user_id=user_id, response=response)
|
112
|
+
# Use the user_id passed to this AuthAdapter instance
|
113
|
+
auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
|
130
114
|
|
131
115
|
# Check if we have BearerTokenCred
|
132
116
|
from nat.data_models.authentication import BearerTokenCred
|
@@ -148,12 +132,20 @@ class MCPBaseClient(ABC):
|
|
148
132
|
Args:
|
149
133
|
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
150
134
|
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
|
135
|
+
tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
|
136
|
+
auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
|
137
|
+
reconnect_enabled (bool): Whether to automatically reconnect on connection failures
|
138
|
+
reconnect_max_attempts (int): Maximum number of reconnection attempts
|
139
|
+
reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
|
140
|
+
reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts
|
151
141
|
"""
|
152
142
|
|
153
143
|
def __init__(self,
|
154
144
|
transport: str = 'streamable-http',
|
155
145
|
auth_provider: AuthProviderBase | None = None,
|
156
|
-
|
146
|
+
user_id: str | None = None,
|
147
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
148
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
157
149
|
reconnect_enabled: bool = True,
|
158
150
|
reconnect_max_attempts: int = 2,
|
159
151
|
reconnect_initial_backoff: float = 0.5,
|
@@ -170,9 +162,12 @@ class MCPBaseClient(ABC):
|
|
170
162
|
|
171
163
|
# Convert auth provider to AuthAdapter
|
172
164
|
self._auth_provider = auth_provider
|
173
|
-
|
165
|
+
# Use provided user_id or fall back to auth provider's default_user_id
|
166
|
+
effective_user_id = user_id or (auth_provider.config.default_user_id if auth_provider else None)
|
167
|
+
self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None
|
174
168
|
|
175
169
|
self._tool_call_timeout = tool_call_timeout
|
170
|
+
self._auth_flow_timeout = auth_flow_timeout
|
176
171
|
|
177
172
|
# Reconnect configuration
|
178
173
|
self._reconnect_enabled = reconnect_enabled
|
@@ -267,12 +262,25 @@ class MCPBaseClient(ABC):
|
|
267
262
|
async def _with_reconnect(self, coro):
|
268
263
|
"""
|
269
264
|
Execute an awaited operation, reconnecting once on errors.
|
265
|
+
Does not reconnect if the error occurs during an active authentication flow.
|
270
266
|
"""
|
271
267
|
try:
|
272
268
|
return await coro()
|
273
269
|
except Exception as e:
|
270
|
+
# Check if error happened during active authentication flow
|
271
|
+
if self._httpx_auth and self._httpx_auth.is_authenticating:
|
272
|
+
# Provide specific error message for authentication timeouts
|
273
|
+
if isinstance(e, TimeoutError):
|
274
|
+
logger.error("Timeout during user authentication flow - user may have abandoned authentication")
|
275
|
+
raise RuntimeError(
|
276
|
+
"Authentication timed out. User did not complete authentication in browser within "
|
277
|
+
f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
|
278
|
+
else:
|
279
|
+
logger.error("Error during authentication flow: %s", e)
|
280
|
+
raise
|
281
|
+
|
282
|
+
# Normal error - attempt reconnect if enabled
|
274
283
|
if self._reconnect_enabled:
|
275
|
-
logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
|
276
284
|
try:
|
277
285
|
await self._reconnect()
|
278
286
|
except Exception as reconnect_err:
|
@@ -281,6 +289,47 @@ class MCPBaseClient(ABC):
|
|
281
289
|
return await coro()
|
282
290
|
raise
|
283
291
|
|
292
|
+
async def _has_cached_auth_token(self) -> bool:
|
293
|
+
"""
|
294
|
+
Check if we have a cached, non-expired authentication token.
|
295
|
+
|
296
|
+
Returns:
|
297
|
+
bool: True if we have a valid cached token, False if authentication may be needed
|
298
|
+
"""
|
299
|
+
if not self._auth_provider:
|
300
|
+
return True # No auth needed
|
301
|
+
|
302
|
+
try:
|
303
|
+
# Check if OAuth2 provider has tokens cached
|
304
|
+
if hasattr(self._auth_provider, '_auth_code_provider'):
|
305
|
+
provider = self._auth_provider._auth_code_provider
|
306
|
+
if provider and hasattr(provider, '_authenticated_tokens'):
|
307
|
+
# Check if we have at least one non-expired token
|
308
|
+
for auth_result in provider._authenticated_tokens.values():
|
309
|
+
if not auth_result.is_expired():
|
310
|
+
return True
|
311
|
+
|
312
|
+
return False
|
313
|
+
except Exception:
|
314
|
+
# If we can't check, assume we need auth to be safe
|
315
|
+
return False
|
316
|
+
|
317
|
+
async def _get_tool_call_timeout(self) -> timedelta:
|
318
|
+
"""
|
319
|
+
Determine the appropriate timeout for a tool call based on authentication state.
|
320
|
+
|
321
|
+
Returns:
|
322
|
+
timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
|
323
|
+
"""
|
324
|
+
if self._auth_provider:
|
325
|
+
has_token = await self._has_cached_auth_token()
|
326
|
+
timeout = self._tool_call_timeout if has_token else self._auth_flow_timeout
|
327
|
+
if not has_token:
|
328
|
+
logger.debug("Using extended timeout (%s) for potential interactive authentication", timeout)
|
329
|
+
return timeout
|
330
|
+
else:
|
331
|
+
return self._tool_call_timeout
|
332
|
+
|
284
333
|
@mcp_exception_handler
|
285
334
|
async def get_tools(self) -> dict[str, MCPToolClient]:
|
286
335
|
"""
|
@@ -313,8 +362,7 @@ class MCPBaseClient(ABC):
|
|
313
362
|
tool_name=tool.name,
|
314
363
|
tool_description=tool.description,
|
315
364
|
tool_input_schema=tool.inputSchema,
|
316
|
-
parent_client=self
|
317
|
-
tool_call_timeout=self._tool_call_timeout)
|
365
|
+
parent_client=self)
|
318
366
|
for tool in response.tools
|
319
367
|
}
|
320
368
|
|
@@ -348,35 +396,13 @@ class MCPBaseClient(ABC):
|
|
348
396
|
if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
|
349
397
|
self._auth_provider._set_custom_auth_callback(auth_callback)
|
350
398
|
|
351
|
-
@mcp_exception_handler
|
352
|
-
async def call_tool_with_meta(self, tool_name: str, args: dict, session_id: str):
|
353
|
-
from mcp.types import CallToolRequest
|
354
|
-
from mcp.types import CallToolRequestParams
|
355
|
-
from mcp.types import CallToolResult
|
356
|
-
from mcp.types import ClientRequest
|
357
|
-
|
358
|
-
if not self._session:
|
359
|
-
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
360
|
-
|
361
|
-
async def _call_tool_with_meta():
|
362
|
-
params = CallToolRequestParams(name=tool_name, arguments=args, **{"_meta": {"session_id": session_id}})
|
363
|
-
req = ClientRequest(CallToolRequest(params=params))
|
364
|
-
# We will increase the timeout to 5 minutes if the tool call timeout is less than 5 min and
|
365
|
-
# auth is enabled.
|
366
|
-
if self._auth_provider and self._tool_call_timeout.total_seconds() < 300:
|
367
|
-
timeout = timedelta(seconds=300)
|
368
|
-
else:
|
369
|
-
timeout = self._tool_call_timeout
|
370
|
-
return await self._session.send_request(req, CallToolResult, request_read_timeout_seconds=timeout)
|
371
|
-
|
372
|
-
return await self._with_reconnect(_call_tool_with_meta)
|
373
|
-
|
374
399
|
@mcp_exception_handler
|
375
400
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
376
401
|
|
377
402
|
async def _call_tool():
|
378
403
|
session = self._session
|
379
|
-
|
404
|
+
timeout = await self._get_tool_call_timeout()
|
405
|
+
return await session.call_tool(tool_name, tool_args, read_timeout_seconds=timeout)
|
380
406
|
|
381
407
|
return await self._with_reconnect(_call_tool)
|
382
408
|
|
@@ -391,13 +417,15 @@ class MCPSSEClient(MCPBaseClient):
|
|
391
417
|
|
392
418
|
def __init__(self,
|
393
419
|
url: str,
|
394
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
420
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
421
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
395
422
|
reconnect_enabled: bool = True,
|
396
423
|
reconnect_max_attempts: int = 2,
|
397
424
|
reconnect_initial_backoff: float = 0.5,
|
398
425
|
reconnect_max_backoff: float = 50.0):
|
399
426
|
super().__init__("sse",
|
400
427
|
tool_call_timeout=tool_call_timeout,
|
428
|
+
auth_flow_timeout=auth_flow_timeout,
|
401
429
|
reconnect_enabled=reconnect_enabled,
|
402
430
|
reconnect_max_attempts=reconnect_max_attempts,
|
403
431
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -440,13 +468,15 @@ class MCPStdioClient(MCPBaseClient):
|
|
440
468
|
command: str,
|
441
469
|
args: list[str] | None = None,
|
442
470
|
env: dict[str, str] | None = None,
|
443
|
-
tool_call_timeout: timedelta = timedelta(seconds=
|
471
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
472
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
444
473
|
reconnect_enabled: bool = True,
|
445
474
|
reconnect_max_attempts: int = 2,
|
446
475
|
reconnect_initial_backoff: float = 0.5,
|
447
476
|
reconnect_max_backoff: float = 50.0):
|
448
477
|
super().__init__("stdio",
|
449
478
|
tool_call_timeout=tool_call_timeout,
|
479
|
+
auth_flow_timeout=auth_flow_timeout,
|
450
480
|
reconnect_enabled=reconnect_enabled,
|
451
481
|
reconnect_max_attempts=reconnect_max_attempts,
|
452
482
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -497,14 +527,18 @@ class MCPStreamableHTTPClient(MCPBaseClient):
|
|
497
527
|
def __init__(self,
|
498
528
|
url: str,
|
499
529
|
auth_provider: AuthProviderBase | None = None,
|
500
|
-
|
530
|
+
user_id: str | None = None,
|
531
|
+
tool_call_timeout: timedelta = timedelta(seconds=60),
|
532
|
+
auth_flow_timeout: timedelta = timedelta(seconds=300),
|
501
533
|
reconnect_enabled: bool = True,
|
502
534
|
reconnect_max_attempts: int = 2,
|
503
535
|
reconnect_initial_backoff: float = 0.5,
|
504
536
|
reconnect_max_backoff: float = 50.0):
|
505
537
|
super().__init__("streamable-http",
|
506
538
|
auth_provider=auth_provider,
|
539
|
+
user_id=user_id,
|
507
540
|
tool_call_timeout=tool_call_timeout,
|
541
|
+
auth_flow_timeout=auth_flow_timeout,
|
508
542
|
reconnect_enabled=reconnect_enabled,
|
509
543
|
reconnect_max_attempts=reconnect_max_attempts,
|
510
544
|
reconnect_initial_backoff=reconnect_initial_backoff,
|
@@ -550,14 +584,12 @@ class MCPToolClient:
|
|
550
584
|
parent_client: MCPBaseClient,
|
551
585
|
tool_name: str,
|
552
586
|
tool_description: str | None,
|
553
|
-
tool_input_schema: dict | None = None
|
554
|
-
tool_call_timeout: timedelta = timedelta(seconds=5)):
|
587
|
+
tool_input_schema: dict | None = None):
|
555
588
|
self._session = session
|
556
589
|
self._tool_name = tool_name
|
557
590
|
self._tool_description = tool_description
|
558
591
|
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
559
592
|
self._parent_client = parent_client
|
560
|
-
self._tool_call_timeout = tool_call_timeout
|
561
593
|
|
562
594
|
if self._parent_client is None:
|
563
595
|
raise RuntimeError("MCPToolClient initialized without a parent client.")
|
@@ -589,35 +621,10 @@ class MCPToolClient:
|
|
589
621
|
"""
|
590
622
|
self._tool_description = description
|
591
623
|
|
592
|
-
def _get_session_id(self) -> str | None:
|
593
|
-
"""
|
594
|
-
Get the session id from the context.
|
595
|
-
"""
|
596
|
-
from nat.builder.context import Context as _Ctx
|
597
|
-
|
598
|
-
# get auth callback (for example: WebSocketAuthenticationFlowHandler). this is lazily set in the client
|
599
|
-
# on first tool call
|
600
|
-
auth_callback = _Ctx.get().user_auth_callback
|
601
|
-
if auth_callback and self._parent_client:
|
602
|
-
# set custom auth callback
|
603
|
-
self._parent_client.set_user_auth_callback(auth_callback)
|
604
|
-
|
605
|
-
# get session id from context, authentication is done per-websocket session for tool calls
|
606
|
-
session_id = None
|
607
|
-
cookies = getattr(_Ctx.get().metadata, "cookies", None)
|
608
|
-
if cookies:
|
609
|
-
session_id = cookies.get("nat-session")
|
610
|
-
|
611
|
-
if not session_id:
|
612
|
-
# use default user id if allowed
|
613
|
-
if self._parent_client.auth_provider and \
|
614
|
-
self._parent_client.auth_provider.config.allow_default_user_id_for_tool_calls:
|
615
|
-
session_id = self._parent_client.auth_provider.config.default_user_id
|
616
|
-
return session_id
|
617
|
-
|
618
624
|
async def acall(self, tool_args: dict) -> str:
|
619
625
|
"""
|
620
626
|
Call the MCP tool with the provided arguments.
|
627
|
+
Session context is now handled at the client level, eliminating the need for metadata injection.
|
621
628
|
|
622
629
|
Args:
|
623
630
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
@@ -625,25 +632,10 @@ class MCPToolClient:
|
|
625
632
|
if self._session is None:
|
626
633
|
raise RuntimeError("No session available for tool call")
|
627
634
|
|
628
|
-
# Extract context information
|
629
|
-
try:
|
630
|
-
session_id = self._get_session_id()
|
631
|
-
except Exception:
|
632
|
-
session_id = None
|
633
|
-
|
634
635
|
try:
|
635
|
-
#
|
636
|
-
|
637
|
-
|
638
|
-
mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
|
639
|
-
raise mcp_error
|
640
|
-
|
641
|
-
if session_id:
|
642
|
-
logger.info("Calling tool %s with arguments %s for a user session", self._tool_name, tool_args)
|
643
|
-
result = await self._parent_client.call_tool_with_meta(self._tool_name, tool_args, session_id)
|
644
|
-
else:
|
645
|
-
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
646
|
-
result = await self._session.call_tool(self._tool_name, tool_args)
|
636
|
+
# Simple tool call - session context is already in the client instance
|
637
|
+
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
|
638
|
+
result = await self._parent_client.call_tool(self._tool_name, tool_args)
|
647
639
|
|
648
640
|
output = []
|
649
641
|
for res in result.content:
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from datetime import timedelta
|
17
|
+
from typing import Literal
|
18
|
+
|
19
|
+
from pydantic import BaseModel
|
20
|
+
from pydantic import Field
|
21
|
+
from pydantic import HttpUrl
|
22
|
+
from pydantic import model_validator
|
23
|
+
|
24
|
+
from nat.data_models.component_ref import AuthenticationRef
|
25
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
26
|
+
|
27
|
+
|
28
|
+
class MCPToolOverrideConfig(BaseModel):
|
29
|
+
"""
|
30
|
+
Configuration for overriding tool properties when exposing from MCP server.
|
31
|
+
"""
|
32
|
+
alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)")
|
33
|
+
description: str | None = Field(default=None, description="Override the tool description")
|
34
|
+
|
35
|
+
|
36
|
+
class MCPServerConfig(BaseModel):
|
37
|
+
"""
|
38
|
+
Server connection details for MCP client.
|
39
|
+
Supports stdio, sse, and streamable-http transports.
|
40
|
+
streamable-http is the recommended default for HTTP-based connections.
|
41
|
+
"""
|
42
|
+
transport: Literal["stdio", "sse", "streamable-http"] = Field(
|
43
|
+
..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)")
|
44
|
+
url: HttpUrl | None = Field(default=None,
|
45
|
+
description="URL of the MCP server (for sse or streamable-http transport)")
|
46
|
+
command: str | None = Field(default=None,
|
47
|
+
description="Command to run for stdio transport (e.g. 'python' or 'docker')")
|
48
|
+
args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
|
49
|
+
env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
|
50
|
+
|
51
|
+
# Authentication configuration
|
52
|
+
auth_provider: str | AuthenticationRef | None = Field(default=None,
|
53
|
+
description="Reference to authentication provider")
|
54
|
+
|
55
|
+
@model_validator(mode="after")
|
56
|
+
def validate_model(self):
|
57
|
+
"""Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
|
58
|
+
if self.transport == "stdio":
|
59
|
+
if self.url is not None:
|
60
|
+
raise ValueError("url should not be set when using stdio transport")
|
61
|
+
if not self.command:
|
62
|
+
raise ValueError("command is required when using stdio transport")
|
63
|
+
# Auth is not supported for stdio transport
|
64
|
+
if self.auth_provider is not None:
|
65
|
+
raise ValueError("Authentication is not supported for stdio transport")
|
66
|
+
elif self.transport == "sse":
|
67
|
+
if self.command is not None or self.args is not None or self.env is not None:
|
68
|
+
raise ValueError("command, args, and env should not be set when using sse transport")
|
69
|
+
if not self.url:
|
70
|
+
raise ValueError("url is required when using sse transport")
|
71
|
+
# Auth is not supported for SSE transport
|
72
|
+
if self.auth_provider is not None:
|
73
|
+
raise ValueError("Authentication is not supported for SSE transport.")
|
74
|
+
elif self.transport == "streamable-http":
|
75
|
+
if self.command is not None or self.args is not None or self.env is not None:
|
76
|
+
raise ValueError("command, args, and env should not be set when using streamable-http transport")
|
77
|
+
if not self.url:
|
78
|
+
raise ValueError("url is required when using streamable-http transport")
|
79
|
+
|
80
|
+
return self
|
81
|
+
|
82
|
+
|
83
|
+
class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
|
84
|
+
"""
|
85
|
+
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
86
|
+
"""
|
87
|
+
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
88
|
+
tool_call_timeout: timedelta = Field(
|
89
|
+
default=timedelta(seconds=60),
|
90
|
+
description="Timeout (in seconds) for the MCP tool call. Defaults to 60 seconds.")
|
91
|
+
auth_flow_timeout: timedelta = Field(
|
92
|
+
default=timedelta(seconds=300),
|
93
|
+
description="Timeout (in seconds) for the MCP auth flow. When the tool call requires interactive \
|
94
|
+
authentication, this timeout is used. Defaults to 300 seconds.")
|
95
|
+
reconnect_enabled: bool = Field(
|
96
|
+
default=True,
|
97
|
+
description="Whether to enable reconnecting to the MCP server if the connection is lost. \
|
98
|
+
Defaults to True.")
|
99
|
+
reconnect_max_attempts: int = Field(default=2,
|
100
|
+
ge=0,
|
101
|
+
description="Maximum number of reconnect attempts. Defaults to 2.")
|
102
|
+
reconnect_initial_backoff: float = Field(
|
103
|
+
default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.")
|
104
|
+
reconnect_max_backoff: float = Field(
|
105
|
+
default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.")
|
106
|
+
tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
|
107
|
+
default=None,
|
108
|
+
description="""Optional tool name overrides and description changes.
|
109
|
+
Example:
|
110
|
+
tool_overrides:
|
111
|
+
calculator_add:
|
112
|
+
alias: "add_numbers"
|
113
|
+
description: "Add two numbers together"
|
114
|
+
calculator_multiply:
|
115
|
+
description: "Multiply two numbers" # alias defaults to original name
|
116
|
+
""")
|
117
|
+
session_aware_tools: bool = Field(default=True,
|
118
|
+
description="Session-aware tools are created if True. Defaults to True.")
|
119
|
+
max_sessions: int = Field(default=100,
|
120
|
+
ge=1,
|
121
|
+
description="Maximum number of concurrent session clients. Defaults to 100.")
|
122
|
+
session_idle_timeout: timedelta = Field(
|
123
|
+
default=timedelta(hours=1),
|
124
|
+
description="Time after which inactive sessions are cleaned up. Defaults to 1 hour.")
|
125
|
+
|
126
|
+
@model_validator(mode="after")
|
127
|
+
def _validate_reconnect_backoff(self) -> "MCPClientConfig":
|
128
|
+
"""Validate reconnect backoff values."""
|
129
|
+
if self.reconnect_max_backoff < self.reconnect_initial_backoff:
|
130
|
+
raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
|
131
|
+
return self
|