nvidia-nat-mcp 1.3.0a20250925__py3-none-any.whl → 1.3.0a20250926__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,12 +15,15 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import asyncio
19
+ import json
18
20
  import logging
19
21
  from abc import ABC
20
22
  from abc import abstractmethod
23
+ from collections.abc import AsyncGenerator
21
24
  from contextlib import AsyncExitStack
22
25
  from contextlib import asynccontextmanager
23
- from typing import AsyncGenerator
26
+ from datetime import timedelta
24
27
 
25
28
  import httpx
26
29
 
@@ -33,7 +36,10 @@ from mcp.types import TextContent
33
36
  from nat.authentication.interfaces import AuthProviderBase
34
37
  from nat.data_models.authentication import AuthReason
35
38
  from nat.data_models.authentication import AuthRequest
39
+ from nat.plugins.mcp.exception_handler import convert_to_mcp_error
40
+ from nat.plugins.mcp.exception_handler import format_mcp_error
36
41
  from nat.plugins.mcp.exception_handler import mcp_exception_handler
42
+ from nat.plugins.mcp.exceptions import MCPError
37
43
  from nat.plugins.mcp.exceptions import MCPToolNotFoundError
38
44
  from nat.plugins.mcp.utils import model_from_mcp_schema
39
45
  from nat.utils.type_utils import override
@@ -85,7 +91,6 @@ class AuthAdapter(httpx.Auth):
85
91
  try:
86
92
  # Check if the request body contains a tool call
87
93
  if request.content:
88
- import json
89
94
  body = json.loads(request.content.decode('utf-8'))
90
95
  # Check if it's a JSON-RPC request with method "tools/call"
91
96
  if (isinstance(body, dict) and body.get("method") == "tools/call"):
@@ -131,7 +136,14 @@ class MCPBaseClient(ABC):
131
136
  auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
132
137
  """
133
138
 
134
- def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None):
139
+ def __init__(self,
140
+ transport: str = 'streamable-http',
141
+ auth_provider: AuthProviderBase | None = None,
142
+ tool_call_timeout: timedelta = timedelta(seconds=5),
143
+ reconnect_enabled: bool = True,
144
+ reconnect_max_attempts: int = 2,
145
+ reconnect_initial_backoff: float = 0.5,
146
+ reconnect_max_backoff: float = 50.0):
135
147
  self._tools = None
136
148
  self._transport = transport.lower()
137
149
  if self._transport not in ['sse', 'stdio', 'streamable-http']:
@@ -145,6 +157,15 @@ class MCPBaseClient(ABC):
145
157
  # Convert auth provider to AuthAdapter
146
158
  self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
147
159
 
160
+ self._tool_call_timeout = tool_call_timeout
161
+
162
+ # Reconnect configuration
163
+ self._reconnect_enabled = reconnect_enabled
164
+ self._reconnect_max_attempts = reconnect_max_attempts
165
+ self._reconnect_initial_backoff = reconnect_initial_backoff
166
+ self._reconnect_max_backoff = reconnect_max_backoff
167
+ self._reconnect_lock: asyncio.Lock = asyncio.Lock()
168
+
148
169
  @property
149
170
  def transport(self) -> str:
150
171
  return self._transport
@@ -164,13 +185,14 @@ class MCPBaseClient(ABC):
164
185
  return self
165
186
 
166
187
  async def __aexit__(self, exc_type, exc_value, traceback):
167
- if not self._exit_stack:
168
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
188
+ if self._exit_stack:
189
+ # Close session
190
+ await self._exit_stack.aclose()
191
+ self._session = None
192
+ self._exit_stack = None
169
193
 
170
- # Close session
171
- await self._exit_stack.aclose()
172
- self._session = None
173
- self._exit_stack = None
194
+ self._connection_established = False
195
+ self._tools = None
174
196
 
175
197
  @property
176
198
  def server_name(self):
@@ -181,22 +203,80 @@ class MCPBaseClient(ABC):
181
203
 
182
204
  @abstractmethod
183
205
  @asynccontextmanager
184
- async def connect_to_server(self):
206
+ async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
185
207
  """
186
208
  Establish a session with an MCP server within an async context
187
209
  """
188
210
  yield
189
211
 
212
+ async def _reconnect(self):
213
+ """
214
+ Attempt to reconnect by tearing down and re-establishing the session.
215
+ """
216
+ async with self._reconnect_lock:
217
+ backoff = self._reconnect_initial_backoff
218
+ attempt = 0
219
+ last_error: Exception | None = None
220
+
221
+ while attempt in range(0, self._reconnect_max_attempts):
222
+ attempt += 1
223
+ try:
224
+ # Close the existing stack and ClientSession
225
+ if self._exit_stack:
226
+ await self._exit_stack.aclose()
227
+ # Create a fresh stack and session
228
+ self._exit_stack = AsyncExitStack()
229
+ self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
230
+
231
+ self._connection_established = True
232
+ self._tools = None
233
+
234
+ logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
235
+ return
236
+
237
+ except Exception as e:
238
+ last_error = e
239
+ logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
240
+ await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
241
+ backoff = min(backoff * 2, self._reconnect_max_backoff)
242
+
243
+ # All attempts failed
244
+ self._connection_established = False
245
+ if last_error:
246
+ raise last_error
247
+
248
+ async def _with_reconnect(self, coro):
249
+ """
250
+ Execute an awaited operation, reconnecting once on errors.
251
+ """
252
+ try:
253
+ return await coro()
254
+ except Exception as e:
255
+ if self._reconnect_enabled:
256
+ logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
257
+ try:
258
+ await self._reconnect()
259
+ except Exception as reconnect_err:
260
+ logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
261
+ raise
262
+ return await coro()
263
+ raise
264
+
190
265
  async def get_tools(self):
191
266
  """
192
267
  Retrieve a dictionary of all tools served by the MCP server.
193
268
  Uses unauthenticated session for discovery.
194
269
  """
195
270
 
196
- if not self._session:
197
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
271
+ async def _get_tools():
272
+ session = self._session
273
+ return await session.list_tools()
198
274
 
199
- response = await self._session.list_tools()
275
+ try:
276
+ response = await self._with_reconnect(_get_tools)
277
+ except Exception as e:
278
+ logger.warning("Failed to get tools: %s", e)
279
+ raise
200
280
 
201
281
  return {
202
282
  tool.name:
@@ -204,7 +284,8 @@ class MCPBaseClient(ABC):
204
284
  tool_name=tool.name,
205
285
  tool_description=tool.description,
206
286
  tool_input_schema=tool.inputSchema,
207
- parent_client=self)
287
+ parent_client=self,
288
+ tool_call_timeout=self._tool_call_timeout)
208
289
  for tool in response.tools
209
290
  }
210
291
 
@@ -235,11 +316,12 @@ class MCPBaseClient(ABC):
235
316
 
236
317
  @mcp_exception_handler
237
318
  async def call_tool(self, tool_name: str, tool_args: dict | None):
238
- if not self._session:
239
- raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
240
319
 
241
- result = await self._session.call_tool(tool_name, tool_args)
242
- return result
320
+ async def _call_tool():
321
+ session = self._session
322
+ return await session.call_tool(tool_name, tool_args, read_timeout_seconds=self._tool_call_timeout)
323
+
324
+ return await self._with_reconnect(_call_tool)
243
325
 
244
326
 
245
327
  class MCPSSEClient(MCPBaseClient):
@@ -250,8 +332,19 @@ class MCPSSEClient(MCPBaseClient):
250
332
  url (str): The url of the MCP server
251
333
  """
252
334
 
253
- def __init__(self, url: str):
254
- super().__init__("sse")
335
+ def __init__(self,
336
+ url: str,
337
+ tool_call_timeout: timedelta = timedelta(seconds=5),
338
+ reconnect_enabled: bool = True,
339
+ reconnect_max_attempts: int = 2,
340
+ reconnect_initial_backoff: float = 0.5,
341
+ reconnect_max_backoff: float = 50.0):
342
+ super().__init__("sse",
343
+ tool_call_timeout=tool_call_timeout,
344
+ reconnect_enabled=reconnect_enabled,
345
+ reconnect_max_attempts=reconnect_max_attempts,
346
+ reconnect_initial_backoff=reconnect_initial_backoff,
347
+ reconnect_max_backoff=reconnect_max_backoff)
255
348
  self._url = url
256
349
 
257
350
  @property
@@ -286,8 +379,21 @@ class MCPStdioClient(MCPBaseClient):
286
379
  env (dict[str, str] | None): Environment variables to set for the process
287
380
  """
288
381
 
289
- def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
290
- super().__init__("stdio")
382
+ def __init__(self,
383
+ command: str,
384
+ args: list[str] | None = None,
385
+ env: dict[str, str] | None = None,
386
+ tool_call_timeout: timedelta = timedelta(seconds=5),
387
+ reconnect_enabled: bool = True,
388
+ reconnect_max_attempts: int = 2,
389
+ reconnect_initial_backoff: float = 0.5,
390
+ reconnect_max_backoff: float = 50.0):
391
+ super().__init__("stdio",
392
+ tool_call_timeout=tool_call_timeout,
393
+ reconnect_enabled=reconnect_enabled,
394
+ reconnect_max_attempts=reconnect_max_attempts,
395
+ reconnect_initial_backoff=reconnect_initial_backoff,
396
+ reconnect_max_backoff=reconnect_max_backoff)
291
397
  self._command = command
292
398
  self._args = args
293
399
  self._env = env
@@ -331,8 +437,21 @@ class MCPStreamableHTTPClient(MCPBaseClient):
331
437
  auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
332
438
  """
333
439
 
334
- def __init__(self, url: str, auth_provider: AuthProviderBase | None = None):
335
- super().__init__("streamable-http", auth_provider=auth_provider)
440
+ def __init__(self,
441
+ url: str,
442
+ auth_provider: AuthProviderBase | None = None,
443
+ tool_call_timeout: timedelta = timedelta(seconds=5),
444
+ reconnect_enabled: bool = True,
445
+ reconnect_max_attempts: int = 2,
446
+ reconnect_initial_backoff: float = 0.5,
447
+ reconnect_max_backoff: float = 50.0):
448
+ super().__init__("streamable-http",
449
+ auth_provider=auth_provider,
450
+ tool_call_timeout=tool_call_timeout,
451
+ reconnect_enabled=reconnect_enabled,
452
+ reconnect_max_attempts=reconnect_max_attempts,
453
+ reconnect_initial_backoff=reconnect_initial_backoff,
454
+ reconnect_max_backoff=reconnect_max_backoff)
336
455
  self._url = url
337
456
 
338
457
  @property
@@ -371,15 +490,20 @@ class MCPToolClient:
371
490
 
372
491
  def __init__(self,
373
492
  session: ClientSession,
493
+ parent_client: "MCPBaseClient",
374
494
  tool_name: str,
375
495
  tool_description: str | None,
376
496
  tool_input_schema: dict | None = None,
377
- parent_client: "MCPBaseClient | None" = None):
497
+ tool_call_timeout: timedelta = timedelta(seconds=5)):
378
498
  self._session = session
379
499
  self._tool_name = tool_name
380
500
  self._tool_description = tool_description
381
501
  self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
382
502
  self._parent_client = parent_client
503
+ self._tool_call_timeout = tool_call_timeout
504
+
505
+ if self._parent_client is None:
506
+ raise RuntimeError("MCPToolClient initialized without a parent client.")
383
507
 
384
508
  @property
385
509
  def name(self):
@@ -415,22 +539,25 @@ class MCPToolClient:
415
539
  Args:
416
540
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
417
541
  """
418
- if self._session is None:
419
- raise RuntimeError("No session available for tool call")
420
542
  logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
421
- result = await self._session.call_tool(self._tool_name, tool_args)
422
-
423
- output = []
424
-
425
- for res in result.content:
426
- if isinstance(res, TextContent):
427
- output.append(res.text)
428
- else:
429
- # Log non-text content for now
430
- logger.warning("Got not-text output from %s of type %s", self.name, type(res))
431
- result_str = "\n".join(output)
432
-
433
- if result.isError:
434
- raise RuntimeError(result_str)
543
+ try:
544
+ result = await self._parent_client.call_tool(self._tool_name, tool_args)
545
+
546
+ output = []
547
+ for res in result.content:
548
+ if isinstance(res, TextContent):
549
+ output.append(res.text)
550
+ else:
551
+ # Log non-text content for now
552
+ logger.warning("Got not-text output from %s of type %s", self.name, type(res))
553
+ result_str = "\n".join(output)
554
+
555
+ if result.isError:
556
+ mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
557
+ raise mcp_error
558
+
559
+ except MCPError as e:
560
+ format_mcp_error(e, include_traceback=False)
561
+ result_str = "MCPToolClient tool call failed: %s" % e.original_exception
435
562
 
436
563
  return result_str
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ from datetime import timedelta
17
18
  from typing import Literal
18
19
 
19
20
  from pydantic import BaseModel
@@ -90,6 +91,19 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
90
91
  Configuration for connecting to an MCP server as a client and exposing selected tools.
91
92
  """
92
93
  server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
94
+ tool_call_timeout: timedelta = Field(
95
+ default=timedelta(seconds=5), description="Timeout (in seconds) for the MCP tool call. Defaults to 5 seconds.")
96
+ reconnect_enabled: bool = Field(
97
+ default=True,
98
+ description="Whether to enable reconnecting to the MCP server if the connection is lost. \
99
+ Defaults to True.")
100
+ reconnect_max_attempts: int = Field(default=2,
101
+ ge=0,
102
+ description="Maximum number of reconnect attempts. Defaults to 2.")
103
+ reconnect_initial_backoff: float = Field(
104
+ default=0.5, ge=0.0, description="Initial backoff time for reconnect attempts. Defaults to 0.5 seconds.")
105
+ reconnect_max_backoff: float = Field(
106
+ default=50.0, ge=0.0, description="Maximum backoff time for reconnect attempts. Defaults to 50 seconds.")
93
107
  tool_overrides: dict[str, MCPToolOverrideConfig] | None = Field(
94
108
  default=None,
95
109
  description="""Optional tool name overrides and description changes.
@@ -102,6 +116,13 @@ class MCPClientConfig(FunctionGroupBaseConfig, name="mcp_client"):
102
116
  description: "Multiply two numbers" # alias defaults to original name
103
117
  """)
104
118
 
119
+ @model_validator(mode="after")
120
+ def _validate_reconnect_backoff(self) -> "MCPClientConfig":
121
+ """Validate reconnect backoff values."""
122
+ if self.reconnect_max_backoff < self.reconnect_initial_backoff:
123
+ raise ValueError("reconnect_max_backoff must be greater than or equal to reconnect_initial_backoff")
124
+ return self
125
+
105
126
 
106
127
  @register_function_group(config_type=MCPClientConfig)
107
128
  async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
@@ -126,11 +147,29 @@ async def mcp_client_function_group(config: MCPClientConfig, _builder: Builder):
126
147
  if config.server.transport == "stdio":
127
148
  if not config.server.command:
128
149
  raise ValueError("command is required for stdio transport")
129
- client = MCPStdioClient(config.server.command, config.server.args, config.server.env)
150
+ client = MCPStdioClient(config.server.command,
151
+ config.server.args,
152
+ config.server.env,
153
+ config.tool_call_timeout,
154
+ reconnect_enabled=config.reconnect_enabled,
155
+ reconnect_max_attempts=config.reconnect_max_attempts,
156
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
157
+ reconnect_max_backoff=config.reconnect_max_backoff)
130
158
  elif config.server.transport == "sse":
131
- client = MCPSSEClient(str(config.server.url))
159
+ client = MCPSSEClient(str(config.server.url),
160
+ tool_call_timeout=config.tool_call_timeout,
161
+ reconnect_enabled=config.reconnect_enabled,
162
+ reconnect_max_attempts=config.reconnect_max_attempts,
163
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
164
+ reconnect_max_backoff=config.reconnect_max_backoff)
132
165
  elif config.server.transport == "streamable-http":
133
- client = MCPStreamableHTTPClient(str(config.server.url), auth_provider=auth_provider)
166
+ client = MCPStreamableHTTPClient(str(config.server.url),
167
+ auth_provider=auth_provider,
168
+ tool_call_timeout=config.tool_call_timeout,
169
+ reconnect_enabled=config.reconnect_enabled,
170
+ reconnect_max_attempts=config.reconnect_max_attempts,
171
+ reconnect_initial_backoff=config.reconnect_initial_backoff,
172
+ reconnect_max_backoff=config.reconnect_max_backoff)
134
173
  else:
135
174
  raise ValueError(f"Unsupported transport: {config.server.transport}")
136
175
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-mcp
3
- Version: 1.3.0a20250925
3
+ Version: 1.3.0a20250926
4
4
  Summary: Subpackage for MCP client integration in NeMo Agent toolkit
5
5
  Keywords: ai,rag,agents,mcp
6
6
  Classifier: Programming Language :: Python
@@ -9,7 +9,7 @@ Classifier: Programming Language :: Python :: 3.12
9
9
  Classifier: Programming Language :: Python :: 3.13
10
10
  Requires-Python: <3.14,>=3.11
11
11
  Description-Content-Type: text/markdown
12
- Requires-Dist: nvidia-nat==v1.3.0a20250925
12
+ Requires-Dist: nvidia-nat==v1.3.0a20250926
13
13
  Requires-Dist: mcp~=1.14
14
14
 
15
15
  <!--
@@ -1,7 +1,7 @@
1
1
  nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
2
2
  nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
3
- nat/plugins/mcp/client_base.py,sha256=eWIHhZIaX8KGJMS7BTQ2szU2Nm4h57zCZ7uR5CmbYiY,15585
4
- nat/plugins/mcp/client_impl.py,sha256=A1rSxVz1K29ZlqY-7BRrMNRAkCVZyUg7MS6vU0stYZc,8067
3
+ nat/plugins/mcp/client_base.py,sha256=uOPn1EE7iL6yvHeBFAdQlNoNmtCaBplX5oyu7AeO2lQ,21356
4
+ nat/plugins/mcp/client_impl.py,sha256=5QdA6nt3xQcA_g-YsGZ5BwBsBVoIlk9ObI4mdJNiuYU,10685
5
5
  nat/plugins/mcp/exception_handler.py,sha256=JdPdZG1NgWpdRnIz7JTGHiJASS5wot9nJiD3SRWV4Kw,7649
6
6
  nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
7
7
  nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
@@ -11,8 +11,8 @@ nat/plugins/mcp/auth/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aM
11
11
  nat/plugins/mcp/auth/auth_provider.py,sha256=GOmM9vCfVd0QiyD_hBj7zCfkiimHa1WDfTTWWfMsr_k,17466
12
12
  nat/plugins/mcp/auth/auth_provider_config.py,sha256=bE6IKV0yveo98KXr0xynrH5BMwPhRv8xbaMBwYu42YQ,3587
13
13
  nat/plugins/mcp/auth/register.py,sha256=yzphsn1I4a5G39_IacbuX0ZQqGM8fevvTUM_B94UXKE,1211
14
- nvidia_nat_mcp-1.3.0a20250925.dist-info/METADATA,sha256=cZEodxrhN0o9EhFmJASwebVXNsCH8GAZEe7AujtEpao,1997
15
- nvidia_nat_mcp-1.3.0a20250925.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- nvidia_nat_mcp-1.3.0a20250925.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
17
- nvidia_nat_mcp-1.3.0a20250925.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
- nvidia_nat_mcp-1.3.0a20250925.dist-info/RECORD,,
14
+ nvidia_nat_mcp-1.3.0a20250926.dist-info/METADATA,sha256=Ax4EGQZ3KdGWPm9-MyZ6x6DPm5-6SBPmuPYCFWXD2Sw,1997
15
+ nvidia_nat_mcp-1.3.0a20250926.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ nvidia_nat_mcp-1.3.0a20250926.dist-info/entry_points.txt,sha256=rYvUp4i-klBr3bVNh7zYOPXret704vTjvCk1qd7FooI,97
17
+ nvidia_nat_mcp-1.3.0a20250926.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
+ nvidia_nat_mcp-1.3.0a20250926.dist-info/RECORD,,