nvidia-nat-mcp 1.4.0a20260107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nat/meta/pypi.md +32 -0
  2. nat/plugins/mcp/__init__.py +14 -0
  3. nat/plugins/mcp/auth/__init__.py +14 -0
  4. nat/plugins/mcp/auth/auth_flow_handler.py +208 -0
  5. nat/plugins/mcp/auth/auth_provider.py +431 -0
  6. nat/plugins/mcp/auth/auth_provider_config.py +86 -0
  7. nat/plugins/mcp/auth/register.py +33 -0
  8. nat/plugins/mcp/auth/service_account/__init__.py +14 -0
  9. nat/plugins/mcp/auth/service_account/provider.py +136 -0
  10. nat/plugins/mcp/auth/service_account/provider_config.py +137 -0
  11. nat/plugins/mcp/auth/service_account/token_client.py +156 -0
  12. nat/plugins/mcp/auth/token_storage.py +265 -0
  13. nat/plugins/mcp/cli/__init__.py +15 -0
  14. nat/plugins/mcp/cli/commands.py +1051 -0
  15. nat/plugins/mcp/client/__init__.py +15 -0
  16. nat/plugins/mcp/client/client_base.py +665 -0
  17. nat/plugins/mcp/client/client_config.py +146 -0
  18. nat/plugins/mcp/client/client_impl.py +782 -0
  19. nat/plugins/mcp/exception_handler.py +211 -0
  20. nat/plugins/mcp/exceptions.py +142 -0
  21. nat/plugins/mcp/register.py +23 -0
  22. nat/plugins/mcp/server/__init__.py +15 -0
  23. nat/plugins/mcp/server/front_end_config.py +109 -0
  24. nat/plugins/mcp/server/front_end_plugin.py +155 -0
  25. nat/plugins/mcp/server/front_end_plugin_worker.py +411 -0
  26. nat/plugins/mcp/server/introspection_token_verifier.py +72 -0
  27. nat/plugins/mcp/server/memory_profiler.py +320 -0
  28. nat/plugins/mcp/server/register_frontend.py +27 -0
  29. nat/plugins/mcp/server/tool_converter.py +286 -0
  30. nat/plugins/mcp/utils.py +228 -0
  31. nvidia_nat_mcp-1.4.0a20260107.dist-info/METADATA +55 -0
  32. nvidia_nat_mcp-1.4.0a20260107.dist-info/RECORD +37 -0
  33. nvidia_nat_mcp-1.4.0a20260107.dist-info/WHEEL +5 -0
  34. nvidia_nat_mcp-1.4.0a20260107.dist-info/entry_points.txt +9 -0
  35. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  36. nvidia_nat_mcp-1.4.0a20260107.dist-info/licenses/LICENSE.md +201 -0
  37. nvidia_nat_mcp-1.4.0a20260107.dist-info/top_level.txt +1 -0
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MCP client components."""
@@ -0,0 +1,665 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import logging
20
+ from abc import ABC
21
+ from abc import abstractmethod
22
+ from collections.abc import AsyncGenerator
23
+ from collections.abc import Callable
24
+ from contextlib import AsyncExitStack
25
+ from contextlib import asynccontextmanager
26
+ from datetime import timedelta
27
+
28
+ import anyio
29
+ import httpx
30
+
31
+ from mcp import ClientSession
32
+ from mcp.client.sse import sse_client
33
+ from mcp.client.stdio import StdioServerParameters
34
+ from mcp.client.stdio import stdio_client
35
+ from mcp.client.streamable_http import streamablehttp_client
36
+ from mcp.types import TextContent
37
+ from nat.authentication.interfaces import AuthenticatedContext
38
+ from nat.authentication.interfaces import AuthFlowType
39
+ from nat.authentication.interfaces import AuthProviderBase
40
+ from nat.plugins.mcp.exception_handler import convert_to_mcp_error
41
+ from nat.plugins.mcp.exception_handler import format_mcp_error
42
+ from nat.plugins.mcp.exception_handler import mcp_exception_handler
43
+ from nat.plugins.mcp.exceptions import MCPError
44
+ from nat.plugins.mcp.exceptions import MCPToolNotFoundError
45
+ from nat.plugins.mcp.utils import model_from_mcp_schema
46
+ from nat.utils.type_utils import override
47
+
48
logger = logging.getLogger(__name__)


class AuthAdapter(httpx.Auth):
    """
    Bridge a NAT ``AuthProviderBase`` into httpx's ``Auth`` hook.

    Every outgoing request is decorated with headers produced by the provider. If the server
    answers 401, the provider is queried again (this time with the failing response) and the
    request is re-sent once with whatever headers that second query produced.

    Args:
        auth_provider: The NAT authentication provider to query for credentials.
        user_id: Optional user id forwarded to the provider so cached credentials stay
            isolated per session/user.
    """

    def __init__(self, auth_provider: AuthProviderBase, user_id: str | None = None):
        # One lock per adapter: requests through the same adapter run their auth flow one at a
        # time, while adapters belonging to different clients never wait on each other.
        self._lock = anyio.Lock()
        self.auth_provider = auth_provider
        # Session-specific user id, passed to the provider on every authenticate() call.
        self.user_id = user_id
        # True only while the post-401 re-authentication (including its retry) is in progress.
        self.is_authenticating = False

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """
        httpx auth flow: attach provider headers, send, and retry once after a 401.

        Before the first send, the provider is queried without a response. Depending on the
        provider's state this may produce cached headers or none at all, in which case the
        request goes out unauthenticated. A 401 reply triggers a second query that includes
        the response, after which the request is yielded again.
        """
        async with self._lock:
            try:
                request.headers.update(await self._get_auth_headers(request=request, response=None))
            except Exception as exc:
                # Header lookup is best-effort; the request is still sent without them.
                logger.info("Failed to get auth headers: %s", exc)

            response = yield request

            if response.status_code != 401:
                return

            # A 401 can mean: no header was sent, the token is invalid/expired/revoked, or the
            # server's auth configuration changed. Re-run discovery/authentication and retry.
            try:
                self.is_authenticating = True
                logger.debug("Starting authentication flow due to 401 response")

                refreshed = await self._get_auth_headers(request=request, response=response)
                request.headers.update(refreshed)
                yield request  # retry with the refreshed headers
            except Exception as exc:
                logger.info("Failed to refresh auth after 401: %s", exc)
                raise
            finally:
                self.is_authenticating = False
                logger.debug("Authentication flow completed")

    async def _get_auth_headers(self,
                                request: httpx.Request | None = None,
                                response: httpx.Response | None = None) -> dict[str, str]:
        """
        Translate the provider's credentials into HTTP headers.

        Args:
            request: Not used by this method; accepted for call-site symmetry.
            response: The failing response when re-authenticating after a 401, otherwise None.

        Returns:
            A header mapping. ``BearerTokenCred`` becomes ``Authorization: Bearer <token>`` and
            ``HeaderCred`` becomes ``<name>: <value>``. Other credential types are ignored. Any
            error is logged as a warning and an empty mapping is returned.
        """
        try:
            result = await self.auth_provider.authenticate(user_id=self.user_id, response=response)

            # Credential types are imported at call time.
            from nat.data_models.authentication import BearerTokenCred
            from nat.data_models.authentication import HeaderCred

            collected: dict[str, str] = {}
            for credential in result.credentials:
                if isinstance(credential, BearerTokenCred):
                    collected["Authorization"] = f"Bearer {credential.token.get_secret_value()}"
                elif isinstance(credential, HeaderCred):
                    # Generic header credential (custom formats, service accounts)
                    collected[credential.name] = credential.value.get_secret_value()
            return collected
        except Exception as exc:
            logger.warning("Failed to get auth token: %s", exc)
            return {}
133
+
134
+
135
class MCPBaseClient(ABC):
    """
    Base client for creating a MCP transport session and connecting to an MCP server.

    Subclasses implement :meth:`connect_to_server`. This class owns the session lifecycle
    (``async with``), tool discovery, tool calls, and reconnect-with-backoff.

    Args:
        transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http'); case-insensitive.
        auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
        user_id (str | None): User id forwarded to the auth provider; falls back to the provider
            config's ``default_user_id`` when omitted.
        tool_call_timeout (timedelta): Timeout for tool calls when authentication is not required
        auth_flow_timeout (timedelta): Extended timeout for tool calls that may require interactive authentication
        reconnect_enabled (bool): Whether to automatically reconnect on connection failures
        reconnect_max_attempts (int): Maximum number of reconnection attempts
        reconnect_initial_backoff (float): Initial backoff delay in seconds for reconnection attempts
        reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts

    Raises:
        ValueError: If ``transport`` is not one of the supported values.
    """

    def __init__(self,
                 transport: str = 'streamable-http',
                 auth_provider: AuthProviderBase | None = None,
                 user_id: str | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        self._tools = None
        self._transport = transport.lower()
        if self._transport not in ['sse', 'stdio', 'streamable-http']:
            raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")

        self._exit_stack: AsyncExitStack | None = None
        self._session: ClientSession | None = None  # Main session
        self._connection_established = False
        self._initial_connection = False

        # Wrap the auth provider in an httpx.Auth adapter for transports that accept httpx auth
        self._auth_provider = auth_provider
        # Use provided user_id or fall back to auth provider's default_user_id (if available)
        effective_user_id = user_id or (getattr(auth_provider.config, 'default_user_id', None)
                                        if auth_provider else None)
        self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None

        self._tool_call_timeout = tool_call_timeout
        self._auth_flow_timeout = auth_flow_timeout

        # Reconnect configuration
        self._reconnect_enabled = reconnect_enabled
        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_initial_backoff = reconnect_initial_backoff
        self._reconnect_max_backoff = reconnect_max_backoff
        # Ensures only one reconnect runs at a time.
        self._reconnect_lock: asyncio.Lock = asyncio.Lock()

    @property
    def auth_provider(self) -> AuthProviderBase | None:
        """The auth provider given at construction, or None."""
        return self._auth_provider

    @property
    def transport(self) -> str:
        """The lower-cased transport name."""
        return self._transport

    async def __aenter__(self):
        """
        Open the transport session.

        Raises:
            RuntimeError: If the client has already been entered.
        """
        if self._exit_stack:
            raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")

        exit_stack = AsyncExitStack()
        # Establish connection (HTTP transports carry the httpx.Auth adapter, if any).
        self._session = await exit_stack.enter_async_context(self.connect_to_server())
        # Publish the stack only once the connection is up. If connecting raises, nothing was
        # pushed onto the stack and the client stays un-entered, so it can be entered again
        # instead of being stuck in an "already initialized" state.
        self._exit_stack = exit_stack

        self._initial_connection = True
        self._connection_established = True

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the session and reset client state, even if closing the transport raises."""
        exit_stack, self._exit_stack = self._exit_stack, None
        try:
            if exit_stack:
                await exit_stack.aclose()
        finally:
            self._session = None
            self._connection_established = False
            self._tools = None

    @property
    def server_name(self):
        """
        Provide server name for logging
        """
        return self._transport

    @abstractmethod
    @asynccontextmanager
    async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
        """
        Establish a session with an MCP server within an async context.

        Implementations yield an initialized ``ClientSession``.
        """
        yield

    async def _reconnect(self):
        """
        Attempt to reconnect by tearing down and re-establishing the session.

        Up to ``reconnect_max_attempts`` attempts are made. Between failed attempts the method
        sleeps with exponential backoff capped at ``reconnect_max_backoff``; there is no sleep
        after the final attempt. On success the cached tools are cleared.

        Raises:
            Exception: The error from the last failed attempt once all attempts are exhausted.
                If ``reconnect_max_attempts`` is not positive, no attempt is made and nothing is raised.
        """
        async with self._reconnect_lock:
            backoff = self._reconnect_initial_backoff
            last_error: Exception | None = None

            for attempt in range(1, self._reconnect_max_attempts + 1):
                try:
                    # Swap in a fresh stack first, then tear down the old one best-effort. The old
                    # session is usually the broken one, and a failure while closing it should not
                    # consume a reconnect attempt. (AsyncExitStack runs every callback even when
                    # one raises, so nothing is left to retry on the stale stack.)
                    stale_stack, self._exit_stack = self._exit_stack, AsyncExitStack()
                    if stale_stack:
                        try:
                            await stale_stack.aclose()
                        except Exception as close_err:
                            logger.warning("Error while closing previous MCP session for %s: %s",
                                           self.server_name,
                                           close_err)

                    self._session = await self._exit_stack.enter_async_context(self.connect_to_server())

                    self._connection_established = True
                    self._tools = None

                    logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
                    return

                except Exception as e:
                    last_error = e
                    logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
                    # Back off only if another attempt follows; sleeping before giving up only delays the error.
                    if attempt < self._reconnect_max_attempts:
                        await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
                        backoff = min(backoff * 2, self._reconnect_max_backoff)

            # All attempts failed
            self._connection_established = False
            if last_error:
                raise last_error

    async def _with_reconnect(self, coro):
        """
        Execute an awaited operation, reconnecting once on errors.

        Does not reconnect if the error occurs during an active authentication flow.

        Args:
            coro: A zero-argument callable returning an awaitable. It is called again after a
                successful reconnect.

        Raises:
            RuntimeError: If a TimeoutError occurs while interactive authentication is in progress.
        """
        try:
            return await coro()
        except Exception as e:
            # Check if error happened during active authentication flow
            if self._httpx_auth and self._httpx_auth.is_authenticating:
                # Provide specific error message for authentication timeouts
                if isinstance(e, TimeoutError):
                    logger.error("Timeout during user authentication flow - user may have abandoned authentication")
                    raise RuntimeError(
                        "Authentication timed out. User did not complete authentication in browser within "
                        f"{self._auth_flow_timeout.total_seconds()} seconds.") from e
                else:
                    logger.error("Error during authentication flow: %s", e)
                    raise

            # Normal error - attempt reconnect if enabled
            if self._reconnect_enabled:
                try:
                    await self._reconnect()
                except Exception as reconnect_err:
                    logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
                    raise
                return await coro()
            raise

    async def _has_cached_auth_token(self) -> bool:
        """
        Check if we have a cached, non-expired authentication token.

        Only providers exposing ``_auth_code_provider._authenticated_tokens`` can be inspected;
        anything else, and any error during the check, is reported as "no cached token".

        Returns:
            bool: True if we have a valid cached token (or no auth provider at all), False if
            authentication may be needed
        """
        if not self._auth_provider:
            return True  # No auth needed

        try:
            # Check if OAuth2 provider has tokens cached
            if hasattr(self._auth_provider, '_auth_code_provider'):
                provider = self._auth_provider._auth_code_provider
                if provider and hasattr(provider, '_authenticated_tokens'):
                    # Check if we have at least one non-expired token
                    for auth_result in provider._authenticated_tokens.values():
                        if not auth_result.is_expired():
                            return True

            return False
        except Exception:
            # If we can't check, assume we need auth to be safe
            return False

    async def _get_tool_call_timeout(self) -> timedelta:
        """
        Determine the appropriate timeout for a tool call based on authentication state.

        Returns:
            timedelta: auth_flow_timeout if authentication may be needed, tool_call_timeout otherwise
        """
        if self._auth_provider:
            has_token = await self._has_cached_auth_token()
            timeout = self._tool_call_timeout if has_token else self._auth_flow_timeout
            if not has_token:
                logger.debug("Using extended timeout (%s) for potential interactive authentication", timeout)
            return timeout
        else:
            return self._tool_call_timeout

    @mcp_exception_handler
    async def get_tools(self) -> dict[str, MCPToolClient]:
        """
        Retrieve a dictionary of all tools served by the MCP server, keyed by tool name.

        Discovery runs on the current session (through ``_with_reconnect``, so it may reconnect
        once) and is bounded by ``tool_call_timeout``.

        Raises:
            MCPTimeoutError: If listing tools exceeds ``tool_call_timeout``.
        """

        async def _get_tools():
            session = self._session
            try:
                # Add timeout to the list_tools call.
                # This is needed because MCP SDK does not support timeout for list_tools()
                with anyio.fail_after(self._tool_call_timeout.total_seconds()):
                    tools = await session.list_tools()
            except TimeoutError as e:
                from nat.plugins.mcp.exceptions import MCPTimeoutError
                raise MCPTimeoutError(self.server_name, e) from e

            return tools

        try:
            response = await self._with_reconnect(_get_tools)
        except Exception as e:
            logger.warning("Failed to get tools: %s", e)
            raise

        return {
            tool.name:
            MCPToolClient(session=self._session,
                          tool_name=tool.name,
                          tool_description=tool.description,
                          tool_input_schema=tool.inputSchema,
                          parent_client=self)
            for tool in response.tools
        }

    @mcp_exception_handler
    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        The tool list is fetched on first use and cached until reconnect or exit.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            RuntimeError: If the client has not been entered with ``async with``.
            MCPToolNotFoundError: If no tool is available with that name.
        """
        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        if not self._tools:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if not tool:
            raise MCPToolNotFoundError(tool_name, self.server_name)
        return tool

    def set_user_auth_callback(self, auth_callback: Callable[[AuthFlowType], AuthenticatedContext]):
        """Set the user authentication callback (no-op if the provider does not support one)."""
        if self._auth_provider and hasattr(self._auth_provider, "_set_custom_auth_callback"):
            self._auth_provider._set_custom_auth_callback(auth_callback)

    @mcp_exception_handler
    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """
        Call a tool on the MCP server by name.

        The read timeout is chosen per call: ``tool_call_timeout`` normally, or
        ``auth_flow_timeout`` when an auth provider is set and no unexpired cached token is found.

        Args:
            tool_name: Name of the tool to invoke.
            tool_args: Arguments passed to the tool, or None.

        Returns:
            The result object returned by ``ClientSession.call_tool``.
        """

        async def _call_tool():
            session = self._session
            timeout = await self._get_tool_call_timeout()
            return await session.call_tool(tool_name, tool_args, read_timeout_seconds=timeout)

        return await self._with_reconnect(_call_tool)
416
+
417
+
418
class MCPSSEClient(MCPBaseClient):
    """
    MCP client that talks to a server over Server-Sent Events (SSE).

    No auth provider is passed to the base class for this transport.

    Args:
        url (str): The url of the MCP server
        tool_call_timeout, auth_flow_timeout, reconnect_enabled, reconnect_max_attempts,
        reconnect_initial_backoff, reconnect_max_backoff: Forwarded unchanged to MCPBaseClient.
    """

    def __init__(self,
                 url: str,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="sse",
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout,
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff)
        self._url = url

    @property
    def url(self) -> str:
        """The SSE endpoint this client connects to."""
        return self._url

    @property
    def server_name(self):
        """Identifier used in log and error messages: ``sse:<url>``."""
        return f"sse:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open an SSE transport to the configured url and yield an initialized ClientSession.
        """
        transport_cm = sse_client(url=self._url)
        async with transport_cm as (read, write), ClientSession(read, write) as session:
            await session.initialize()
            yield session
461
+
462
+
463
class MCPStdioClient(MCPBaseClient):
    """
    MCP client that launches a local server process and speaks MCP over its stdin/stdout.

    Args:
        command (str): The command to run
        args (list[str] | None): Extra command-line arguments; None is sent to the process as an empty list.
        env (dict[str, str] | None): Environment for the process, handed to StdioServerParameters as-is.
        tool_call_timeout, auth_flow_timeout, reconnect_enabled, reconnect_max_attempts,
        reconnect_initial_backoff, reconnect_max_backoff: Forwarded unchanged to MCPBaseClient.
    """

    def __init__(self,
                 command: str,
                 args: list[str] | None = None,
                 env: dict[str, str] | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="stdio",
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout,
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff)
        self._command = command
        self._args = args
        self._env = env

    @property
    def command(self) -> str:
        """The executable launched for the server process."""
        return self._command

    @property
    def args(self) -> list[str] | None:
        """Extra arguments exactly as given to the constructor (may be None)."""
        return self._args

    @property
    def env(self) -> dict[str, str] | None:
        """Environment mapping exactly as given to the constructor (may be None)."""
        return self._env

    @property
    def server_name(self):
        """Identifier used in log and error messages: ``stdio:<command>``."""
        return f"stdio:{self._command}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Spawn the configured command and yield an initialized ClientSession bound to its stdio.
        """
        params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
        async with stdio_client(params) as (read, write), ClientSession(read, write) as session:
            await session.initialize()
            yield session
524
+
525
+
526
class MCPStreamableHTTPClient(MCPBaseClient):
    """
    MCP client that talks to a server over the streamable-http transport.

    This is the transport in this module that wires the base class's httpx auth adapter into
    the HTTP connection.

    Args:
        url (str): The url of the MCP server
        auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
        user_id (str | None): Optional user id forwarded to the auth provider.
        tool_call_timeout, auth_flow_timeout, reconnect_enabled, reconnect_max_attempts,
        reconnect_initial_backoff, reconnect_max_backoff: Forwarded unchanged to MCPBaseClient.
    """

    def __init__(self,
                 url: str,
                 auth_provider: AuthProviderBase | None = None,
                 user_id: str | None = None,
                 tool_call_timeout: timedelta = timedelta(seconds=60),
                 auth_flow_timeout: timedelta = timedelta(seconds=300),
                 reconnect_enabled: bool = True,
                 reconnect_max_attempts: int = 2,
                 reconnect_initial_backoff: float = 0.5,
                 reconnect_max_backoff: float = 50.0):
        super().__init__(transport="streamable-http",
                         auth_provider=auth_provider,
                         user_id=user_id,
                         tool_call_timeout=tool_call_timeout,
                         auth_flow_timeout=auth_flow_timeout,
                         reconnect_enabled=reconnect_enabled,
                         reconnect_max_attempts=reconnect_max_attempts,
                         reconnect_initial_backoff=reconnect_initial_backoff,
                         reconnect_max_backoff=reconnect_max_backoff)
        self._url = url

    @property
    def url(self) -> str:
        """The streamable-http endpoint this client connects to."""
        return self._url

    @property
    def server_name(self):
        """Identifier used in log and error messages: ``streamable-http:<url>``."""
        return f"streamable-http:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open a streamable-http transport (authenticated via the httpx.Auth adapter, if any)
        and yield an initialized ClientSession.
        """
        transport_cm = streamablehttp_client(url=self._url, auth=self._httpx_auth)
        # The transport yields a 3-tuple; only the read/write streams are needed here.
        async with transport_cm as (read, write, _), ClientSession(read, write) as session:
            await session.initialize()
            yield session
575
+
576
+
577
class MCPToolClient:
    """
    Callable wrapper around a single tool exposed by an MCP server.

    The call itself is delegated to the parent client, so it always runs on the parent's
    current session. The transport session is assumed to be set up already.

    Args:
        session (ClientSession): The MCP client session; only checked for None before a call.
        parent_client (MCPBaseClient): The parent MCP client that performs calls. Must not be None.
        tool_name (str): The name of the tool to wrap
        tool_description (str | None): The description of the tool provided by the MCP server.
        tool_input_schema (dict | None): The input schema for the tool; converted to a model when non-empty.

    Raises:
        RuntimeError: If ``parent_client`` is None.
    """

    def __init__(self,
                 session: ClientSession,
                 parent_client: MCPBaseClient,
                 tool_name: str,
                 tool_description: str | None,
                 tool_input_schema: dict | None = None):
        self._session = session
        self._tool_name = tool_name
        self._tool_description = tool_description
        if tool_input_schema:
            self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema)
        else:
            self._input_schema = None
        self._parent_client = parent_client

        if parent_client is None:
            raise RuntimeError("MCPToolClient initialized without a parent client.")

    @property
    def name(self):
        """Returns the name of the tool."""
        return self._tool_name

    @property
    def description(self):
        """
        Returns the tool's description. If none was provided, returns a simple description
        built from the tool's name.
        """
        return self._tool_description or f"MCP Tool {self._tool_name}"

    @property
    def input_schema(self):
        """
        Returns the model built from the tool's input schema, or None if no schema was given.
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Override the tool's description with the provided string.
        """
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Call the MCP tool through the parent client and return its text output.

        Text content items are joined with newlines; other content types are skipped with a
        warning. If the server marks the result as an error, or an MCPError is raised during the
        call, the error is formatted and a failure message string is returned instead.

        Args:
            tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.

        Raises:
            RuntimeError: If no session is available.
        """
        if self._session is None:
            raise RuntimeError("No session available for tool call")

        try:
            logger.info("Calling tool %s", self._tool_name)
            result = await self._parent_client.call_tool(self._tool_name, tool_args)

            texts = []
            for item in result.content:
                if not isinstance(item, TextContent):
                    # Non-text content is only logged for now
                    logger.warning("Got not-text output from %s of type %s", self.name, type(item))
                    continue
                texts.append(item.text)
            result_str = "\n".join(texts)

            if result.isError:
                error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
                raise error

        except MCPError as exc:
            format_mcp_error(exc, include_traceback=False)
            result_str = f"MCPToolClient tool call failed: {exc.original_exception}"

        return result_str