nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +32 -0
- nat/plugins/mcp/__init__.py +14 -0
- nat/plugins/mcp/auth/__init__.py +14 -0
- nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
- nat/plugins/mcp/auth/auth_provider.py +431 -0
- nat/plugins/mcp/auth/auth_provider_config.py +86 -0
- nat/plugins/mcp/auth/register.py +33 -0
- nat/plugins/mcp/auth/service_account/__init__.py +14 -0
- nat/plugins/mcp/auth/service_account/provider.py +136 -0
- nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
- nat/plugins/mcp/auth/service_account/token_client.py +156 -0
- nat/plugins/mcp/auth/token_storage.py +265 -0
- nat/plugins/mcp/cli/__init__.py +15 -0
- nat/plugins/mcp/cli/commands.py +1051 -0
- nat/plugins/mcp/client/__init__.py +15 -0
- nat/plugins/mcp/client/client_base.py +665 -0
- nat/plugins/mcp/client/client_config.py +146 -0
- nat/plugins/mcp/client/client_impl.py +782 -0
- nat/plugins/mcp/exception_handler.py +211 -0
- nat/plugins/mcp/exceptions.py +142 -0
- nat/plugins/mcp/register.py +23 -0
- nat/plugins/mcp/server/__init__.py +15 -0
- nat/plugins/mcp/server/front_end_config.py +109 -0
- nat/plugins/mcp/server/front_end_plugin.py +155 -0
- nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
- nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
- nat/plugins/mcp/server/memory_profiler.py +320 -0
- nat/plugins/mcp/server/register_frontend.py +27 -0
- nat/plugins/mcp/server/tool_converter.py +286 -0
- nat/plugins/mcp/utils.py +228 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""MCP client components."""
|
|
@@ -0,0 +1,665 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import logging
|
|
20
|
+
from abc import ABC
|
|
21
|
+
from abc import abstractmethod
|
|
22
|
+
from collections.abc import AsyncGenerator
|
|
23
|
+
from collections.abc import Callable
|
|
24
|
+
from contextlib import AsyncExitStack
|
|
25
|
+
from contextlib import asynccontextmanager
|
|
26
|
+
from datetime import timedelta
|
|
27
|
+
|
|
28
|
+
import anyio
|
|
29
|
+
import httpx
|
|
30
|
+
|
|
31
|
+
from mcp import ClientSession
|
|
32
|
+
from mcp.client.sse import sse_client
|
|
33
|
+
from mcp.client.stdio import StdioServerParameters
|
|
34
|
+
from mcp.client.stdio import stdio_client
|
|
35
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
36
|
+
from mcp.types import TextContent
|
|
37
|
+
from nat.authentication.interfaces import AuthenticatedContext
|
|
38
|
+
from nat.authentication.interfaces import AuthFlowType
|
|
39
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
40
|
+
from nat.plugins.mcp.exception_handler import convert_to_mcp_error
|
|
41
|
+
from nat.plugins.mcp.exception_handler import format_mcp_error
|
|
42
|
+
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
|
43
|
+
from nat.plugins.mcp.exceptions import MCPError
|
|
44
|
+
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
|
45
|
+
from nat.plugins.mcp.utils import model_from_mcp_schema
|
|
46
|
+
from nat.utils.type_utils import override
|
|
47
|
+
|
|
48
|
+
# Module-level logger shared by the MCP client classes defined in this module.
logger = logging.getLogger(__name__)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class AuthAdapter(httpx.Auth):
    """
    httpx.Auth adapter for authentication providers.

    Converts AuthProviderBase to httpx.Auth interface for dynamic token management. Each outgoing
    request gets headers built from the provider's credentials. A 401 response is forwarded to the
    provider's ``authenticate`` call and the request is retried once with the refreshed headers.

    Args:
        auth_provider (AuthProviderBase): Provider whose ``authenticate`` result supplies credentials.
        user_id (str | None): Session-specific user ID forwarded to the provider for cache isolation.
    """

    def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
        self.auth_provider = auth_provider
        self.user_id = user_id  # Session-specific user ID for cache isolation
        # each adapter instance has its own lock to avoid unnecessary delays for multiple clients
        self._lock = anyio.Lock()
        # Track whether we're currently in an interactive authentication flow
        # (checked by MCPBaseClient to avoid reconnecting mid-authentication).
        self.is_authenticating = False

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """
        Add authentication headers to the request using NAT auth provider.

        The adapter lock is held for the whole exchange, including the optional retry after a 401,
        so requests through a single adapter are authenticated one at a time.
        """
        async with self._lock:
            try:
                # Get auth headers from the NAT auth provider:
                # 1. If discovery is yet to be done, no credentials are available and the request is sent
                #    without an auth header.
                # 2. If discovery is done, this returns the auth header from cache if the token is still valid.
                auth_headers = await self._get_auth_headers(request=request, response=None)
                request.headers.update(auth_headers)
            except Exception as e:
                logger.info("Failed to get auth headers: %s", e)
                # Continue without auth headers if auth fails

            response = yield request

            # Handle 401 responses by retrying with fresh auth
            if response.status_code == 401:
                try:
                    # 401 can happen if:
                    # 1. The request was sent without auth header
                    # 2. The auth headers are invalid
                    # 3. The auth headers are expired
                    # 4. The auth headers are revoked
                    # 5. Auth config on the MCP server has changed
                    # In this case we attempt to re-run discovery and authentication

                    # Signal that we're entering interactive auth flow
                    self.is_authenticating = True
                    logger.debug("Starting authentication flow due to 401 response")

                    auth_headers = await self._get_auth_headers(request=request, response=response)
                    request.headers.update(auth_headers)
                    yield request  # Retry the request
                except Exception as e:
                    logger.info("Failed to refresh auth after 401: %s", e)
                    raise
                finally:
                    # Signal that auth flow is complete (also runs if the generator is closed early)
                    self.is_authenticating = False
                    logger.debug("Authentication flow completed")
        return

    async def _get_auth_headers(self,
                                request: httpx.Request | None = None,
                                response: httpx.Response | None = None) -> dict[str, str]:
        """
        Get authentication headers from the NAT auth provider.

        Args:
            request: The outgoing request (not used by this method; kept for interface compatibility).
            response: The previous response, if any; forwarded to the provider's ``authenticate`` call.

        Returns:
            A header dict built from the provider's credentials: ``BearerTokenCred`` becomes an
            ``Authorization: Bearer <token>`` header and ``HeaderCred`` becomes ``<name>: <value>``.
            An empty dict is returned when the provider returns no result or when any step fails
            (failures are logged, never raised).
        """
        try:
            # Use the user_id passed to this AuthAdapter instance
            auth_result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)
            if auth_result is None:
                # No credentials yet (e.g. discovery has not completed); send the request unauthenticated
                # instead of tripping over ``None.credentials`` and logging a misleading warning.
                logger.debug("Auth provider returned no credentials; sending request without auth headers")
                return {}

            # Build headers from credentials
            from nat.data_models.authentication import BearerTokenCred
            from nat.data_models.authentication import HeaderCred
            headers: dict[str, str] = {}

            for cred in auth_result.credentials:
                if isinstance(cred, BearerTokenCred):
                    # Standard Bearer token
                    token = cred.token.get_secret_value()
                    headers["Authorization"] = f"Bearer {token}"
                elif isinstance(cred, HeaderCred):
                    # Generic header credential (supports custom formats and service accounts)
                    headers[cred.name] = cred.value.get_secret_value()

            return headers
        except Exception as e:
            logger.warning("Failed to get auth token: %s", e)
            return {}
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class MCPBaseClient(ABC):
    """
    Base client for creating a MCP transport session and connecting to an MCP server

    Args:
        transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
        auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
        user_id (str | None): Optional user ID used for credential cache isolation; falls back to the auth
            provider's ``config.default_user_id`` when available
        tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
        auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
        reconnect_enabled (bool): Whether to automatically reconnect on connection failures
        reconnect_max_attempts (int): Maximum number of reconnection attempts
        reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
        reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts
    """

    def __init__(self,
                 transport: str = 'streamable-http',
                 auth_provider: AuthProviderBase | None = None,
                 user_id: str | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        self._tools = None
        self._transport = transport.lower()
        if self._transport not in ['sse', 'stdio', 'streamable-http']:
            raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")

        self._exit_stack: AsyncExitStack | None = None
        self._session: ClientSession | None = None  # Main session
        self._connection_established = False
        self._initial_connection = False

        # Keep the raw provider and wrap it in an httpx.Auth adapter
        self._auth_provider = auth_provider
        # Use provided user_id or fall back to auth provider's default_user_id (if available).
        # getattr on ``config`` tolerates providers that do not expose a config object.
        provider_config = getattr(auth_provider, 'config', None) if auth_provider else None
        effective_user_id = user_id or getattr(provider_config, 'default_user_id', None)
        self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None

        self._tool_call_timeout = tool_call_timeout
        self._auth_flow_timeout = auth_flow_timeout

        # Reconnect configuration
        self._reconnect_enabled = reconnect_enabled
        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_initial_backoff = reconnect_initial_backoff
        self._reconnect_max_backoff = reconnect_max_backoff
        self._reconnect_lock: asyncio.Lock = asyncio.Lock()

    @property
    def auth_provider(self) -> AuthProviderBase | None:
        return self._auth_provider

    @property
    def transport(self) -> str:
        return self._transport

    async def __aenter__(self):
        if self._exit_stack:
            raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")

        self._exit_stack = AsyncExitStack()

        try:
            # Establish connection with httpx.Auth
            self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
        except BaseException:
            # __aexit__ is not called when __aenter__ fails; drop the stack so the client can be
            # entered again instead of permanently reporting "already initialized".
            self._exit_stack = None
            self._session = None
            raise

        self._initial_connection = True
        self._connection_established = True

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if self._exit_stack:
            # Close session
            await self._exit_stack.aclose()
            self._session = None
            self._exit_stack = None

        self._connection_established = False
        self._tools = None

    @property
    def server_name(self):
        """
        Provide server name for logging
        """
        return self._transport

    @abstractmethod
    @asynccontextmanager
    async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
        """
        Establish a session with an MCP server within an async context
        """
        yield

    async def _reconnect(self):
        """
        Attempt to reconnect by tearing down and re-establishing the session.

        Tries up to ``reconnect_max_attempts`` times with exponential backoff (capped at
        ``reconnect_max_backoff``) between attempts. No delay is added after the final attempt.
        Clears the cached tool map on success; re-raises the last error if every attempt fails.
        """
        async with self._reconnect_lock:
            backoff = self._reconnect_initial_backoff
            last_error: Exception | None = None

            for attempt in range(1, self._reconnect_max_attempts + 1):
                try:
                    # Close the existing stack and ClientSession
                    if self._exit_stack:
                        await self._exit_stack.aclose()
                    # Create a fresh stack and session
                    self._exit_stack = AsyncExitStack()
                    self._session = await self._exit_stack.enter_async_context(self.connect_to_server())

                    self._connection_established = True
                    self._tools = None

                    logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
                    return

                except Exception as e:
                    last_error = e
                    logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
                    # Only back off when another attempt will follow
                    if attempt < self._reconnect_max_attempts:
                        await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
                        backoff = min(backoff * 2, self._reconnect_max_backoff)

            # All attempts failed
            self._connection_established = False
            if last_error:
                raise last_error

    async def _with_reconnect(self, coro):
        """
        Execute an awaited operation, reconnecting once on errors.

        Does not reconnect if the error occurs during an active authentication flow; a timeout in that
        state is reported as a RuntimeError describing the incomplete browser authentication.

        Args:
            coro: Zero-argument callable returning an awaitable. It is invoked a second time after a
                successful reconnect.
        """
        try:
            return await coro()
        except Exception as e:
            # Check if error happened during active authentication flow
            if self._httpx_auth and self._httpx_auth.is_authenticating:
                # Provide specific error message for authentication timeouts
                if isinstance(e, TimeoutError):
                    logger.error("Timeout during user authentication flow - user may have abandoned authentication")
                    raise RuntimeError(
                        "Authentication timed out. User did not complete authentication in browser within "
                        f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
                else:
                    logger.error("Error during authentication flow: %s", e)
                    raise

            # Normal error - attempt reconnect if enabled
            if self._reconnect_enabled:
                try:
                    await self._reconnect()
                except Exception as reconnect_err:
                    logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
                    raise
                return await coro()
            raise

    async def _has_cached_auth_token(self) -> bool:
        """
        Check if we have a cached, non-expired authentication token.

        Inspects the provider's private ``_auth_code_provider._authenticated_tokens`` mapping when present.

        Returns:
            bool: True if we have a valid cached token (or no auth provider), False if authentication may be needed
        """
        if not self._auth_provider:
            return True  # No auth needed

        try:
            # Check if OAuth2 provider has tokens cached
            if hasattr(self._auth_provider, '_auth_code_provider'):
                provider = self._auth_provider._auth_code_provider
                if provider and hasattr(provider, '_authenticated_tokens'):
                    # Check if we have at least one non-expired token
                    for auth_result in provider._authenticated_tokens.values():
                        if not auth_result.is_expired():
                            return True

            return False
        except Exception:
            # If we can't check, assume we need auth to be safe
            return False

    async def _get_tool_call_timeout(self) -> timedelta:
        """
        Determine the appropriate timeout for a tool call based on authentication state.

        Returns:
            timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
        """
        if self._auth_provider:
            has_token = await self._has_cached_auth_token()
            timeout = self._tool_call_timeout if has_token else self._auth_flow_timeout
            if not has_token:
                logger.debug("Using extended timeout (%s) for potential interactive authentication", timeout)
            return timeout
        else:
            return self._tool_call_timeout

    @mcp_exception_handler
    async def get_tools(self) -> dict[str, MCPToolClient]:
        """
        Retrieve a dictionary of all tools served by the MCP server.

        Uses the client's current main session for discovery, with the reconnect policy applied.

        Returns:
            dict[str, MCPToolClient]: Tool clients keyed by tool name.

        Raises:
            MCPTimeoutError: If listing tools exceeds ``tool_call_timeout``.
        """

        async def _get_tools():
            session = self._session
            try:
                # Add timeout to the list_tools call.
                # This is needed because MCP SDK does not support timeout for list_tools()
                with anyio.fail_after(self._tool_call_timeout.total_seconds()):
                    tools = await session.list_tools()
            except TimeoutError as e:
                from nat.plugins.mcp.exceptions import MCPTimeoutError
                raise MCPTimeoutError(self.server_name, e) from e

            return tools

        try:
            response = await self._with_reconnect(_get_tools)
        except Exception as e:
            logger.warning("Failed to get tools: %s", e)
            raise

        return {
            tool.name:
                MCPToolClient(session=self._session,
                              tool_name=tool.name,
                              tool_description=tool.description,
                              tool_input_schema=tool.inputSchema,
                              parent_client=self)
            for tool in response.tools
        }

    @mcp_exception_handler
    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            RuntimeError: If the client has not been entered with ``async with``.
            MCPToolNotFoundError: If no tool is available with that name.
        """
        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        if not self._tools:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if not tool:
            raise MCPToolNotFoundError(tool_name, self.server_name)
        return tool

    def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
        """Set the user authentication callback (only if the provider supports a custom callback)."""
        if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
            self._auth_provider._set_custom_auth_callback(auth_callback)

    @mcp_exception_handler
    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """
        Call a tool on the MCP server through the current session.

        The read timeout is chosen by ``_get_tool_call_timeout`` (extended when interactive auth may be
        needed), and the call goes through the reconnect policy of ``_with_reconnect``.
        """

        async def _call_tool():
            session = self._session
            timeout = await self._get_tool_call_timeout()
            return await session.call_tool(tool_name, tool_args, read_timeout_seconds=timeout)

        return await self._with_reconnect(_call_tool)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
class MCPSSEClient(MCPBaseClient):
    """
    MCP client that connects to a remote server over Server-Sent Events (SSE).

    No auth provider is handed to the base class, so connections made by this client carry no
    authentication.

    Args:
        url (str): The url of the MCP server
        tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
        auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
        reconnect_enabled (bool): Whether to automatically reconnect on connection failures
        reconnect_max_attempts (int): Maximum number of reconnection attempts
        reconnect_initial_backoff (float): Initial backoff delay in seconds between reconnection attempts
        reconnect_max_backoff (float): Upper bound in seconds for the reconnection backoff delay
    """

    def __init__(self,
                 url: str,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="sse",
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff,
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout)
        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        # Prefix with the transport so log lines identify how the server is reached.
        return "sse:" + f"{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open an SSE transport to the configured url and yield an initialized ClientSession.

        Both the session and the transport are closed when the surrounding context exits.
        """
        async with sse_client(url=self._url) as (read_stream, write_stream), \
                ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            yield session
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class MCPStdioClient(MCPBaseClient):
    """
    MCP client for a locally spawned server process reached over stdio.

    The server command is launched by the stdio transport and messages flow over the child's
    stdin/stdout.

    Args:
        command (str): The command to run
        args (list[str] | None): Additional arguments for the command (an empty list is used when None)
        env (dict[str, str] | None): Environment variables to set for the process
        tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
        auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
        reconnect_enabled (bool): Whether to automatically reconnect on connection failures
        reconnect_max_attempts (int): Maximum number of reconnection attempts
        reconnect_initial_backoff (float): Initial backoff delay in seconds between reconnection attempts
        reconnect_max_backoff (float): Upper bound in seconds for the reconnection backoff delay
    """

    def __init__(self,
                 command: str,
                 args: list[str] | None = None,
                 env: dict[str, str] | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="stdio",
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff,
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout)
        self._command, self._args, self._env = command, args, env

    @property
    def command(self) -> str:
        return self._command

    @property
    def server_name(self):
        # Prefix with the transport so log lines identify how the server is reached.
        return "stdio:" + f"{self._command}"

    @property
    def args(self) -> list[str] | None:
        return self._args

    @property
    def env(self) -> dict[str, str] | None:
        return self._env

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Spawn the configured command, wrap its stdio in a ClientSession and yield it once initialized.

        The session and the stdio transport are closed when the surrounding context exits.
        """
        params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
        async with stdio_client(params) as (read_stream, write_stream), \
                ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            yield session
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
class MCPStreamableHTTPClient(MCPBaseClient):
    """
    MCP client that connects to a remote server over the streamable-http transport.

    When an auth provider is supplied, the base class wraps it in an httpx.Auth adapter and this
    client passes that adapter to the transport, so credentials are attached per request.

    Args:
        url (str): The url of the MCP server
        auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
        user_id (str | None): Optional user ID for credential cache isolation
        tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
        auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
        reconnect_enabled (bool): Whether to automatically reconnect on connection failures
        reconnect_max_attempts (int): Maximum number of reconnection attempts
        reconnect_initial_backoff (float): Initial backoff delay in seconds between reconnection attempts
        reconnect_max_backoff (float): Upper bound in seconds for the reconnection backoff delay
    """

    def __init__(self,
                 url: str,
                 auth_provider: AuthProviderBase | None = None,
                 user_id: str | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="streamable-http",
                         auth_provider=auth_provider,
                         user_id=user_id,
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff,
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout)
        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        # Prefix with the transport so log lines identify how the server is reached.
        return "streamable-http:" + f"{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open a streamable-http transport (using the httpx.Auth adapter, if any) and yield an
        initialized ClientSession.

        The third value produced by the transport is ignored. Everything is closed when the
        surrounding context exits.
        """
        async with streamablehttp_client(url=self._url, auth=self._httpx_auth) as (read_stream, write_stream, _), \
                ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            yield session
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class MCPToolClient:
    """
    Thin wrapper exposing a single MCP tool; calls are routed through the parent client.

    The MCP transport session is expected to be set up already. Invocation is delegated to
    ``parent_client.call_tool`` so the parent's current session, timeouts and reconnect policy apply.

    Args:
        session (ClientSession): The MCP client session
        parent_client (MCPBaseClient): The parent MCP client for auth management; must not be None.
        tool_name (str): The name of the tool to wrap
        tool_description (str | None): The description of the tool provided by the MCP server.
        tool_input_schema (dict | None): The input schema for the tool; converted to a model when non-empty.

    Raises:
        RuntimeError: If ``parent_client`` is None.
    """

    def __init__(self,
                 session: ClientSession,
                 parent_client: MCPBaseClient,
                 tool_name: str,
                 tool_description: str | None,
                 tool_input_schema: dict | None = None):
        self._session = session
        self._tool_name = tool_name
        self._tool_description = tool_description
        # Only build an input model when the server advertised a non-empty schema.
        if tool_input_schema:
            self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema)
        else:
            self._input_schema = None
        self._parent_client = parent_client

        if parent_client is None:
            raise RuntimeError("MCPToolClient initialized without a parent client.")

    @property
    def name(self):
        """The name of the wrapped tool."""
        return self._tool_name

    @property
    def description(self):
        """
        The tool's description. When none (or an empty one) was provided, a generic
        description built from the tool's name is returned instead.
        """
        return self._tool_description if self._tool_description else f"MCP Tool {self._tool_name}"

    @property
    def input_schema(self):
        """
        The model built from the tool's input schema, or None when no schema was provided.
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Override the tool's description with the given string.
        """
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Call the MCP tool with the provided arguments and return its text output.

        Session context is handled at the client level, so no metadata injection is needed.
        Text content items are joined with newlines; non-text items are skipped with a warning.
        An ``MCPError`` (including one built from an ``isError`` result) is not propagated: it is
        passed to ``format_mcp_error`` and a failure message string is returned instead.

        Args:
            tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.

        Returns:
            str: The joined text output, or a ``"MCPToolClient tool call failed: ..."`` message.

        Raises:
            RuntimeError: If this tool client has no session.
        """
        if self._session is None:
            raise RuntimeError("No session available for tool call")

        try:
            # Session context already lives in the parent client instance
            logger.info("Calling tool %s", self._tool_name)
            call_result = await self._parent_client.call_tool(self._tool_name, tool_args)

            text_parts: list[str] = []
            for item in call_result.content:
                if not isinstance(item, TextContent):
                    # Log non-text content for now
                    logger.warning("Got not-text output from %s of type %s", self.name, type(item))
                    continue
                text_parts.append(item.text)
            result_str = "\n".join(text_parts)

            if call_result.isError:
                raise convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)

        except MCPError as err:
            format_mcp_error(err, include_traceback=False)
            result_str = f"MCPToolClient tool call failed: {err.original_exception}"

        return result_str
|