nvidia-nat-mcp 1.3.0a20250909__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +32 -0
- nat/plugins/mcp/__init__.py +14 -0
- nat/plugins/mcp/client_base.py +406 -0
- nat/plugins/mcp/client_impl.py +229 -0
- nat/plugins/mcp/exception_handler.py +211 -0
- nat/plugins/mcp/exceptions.py +142 -0
- nat/plugins/mcp/register.py +22 -0
- nat/plugins/mcp/tool.py +133 -0
- nvidia_nat_mcp-1.3.0a20250909.dist-info/METADATA +46 -0
- nvidia_nat_mcp-1.3.0a20250909.dist-info/RECORD +13 -0
- nvidia_nat_mcp-1.3.0a20250909.dist-info/WHEEL +5 -0
- nvidia_nat_mcp-1.3.0a20250909.dist-info/entry_points.txt +2 -0
- nvidia_nat_mcp-1.3.0a20250909.dist-info/top_level.txt +1 -0
nat/meta/pypi.md
ADDED
@@ -0,0 +1,32 @@
|
|
1
|
+
<!--
|
2
|
+
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3
|
+
SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
6
|
+
you may not use this file except in compliance with the License.
|
7
|
+
You may obtain a copy of the License at
|
8
|
+
|
9
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
|
+
See the License for the specific language governing permissions and
|
15
|
+
limitations under the License.
|
16
|
+
-->
|
17
|
+
|
18
|
+

|
19
|
+
|
20
|
+
|
21
|
+
# NVIDIA NeMo Agent Toolkit MCP Subpackage
|
22
|
+
Subpackage for MCP client integration in the NeMo Agent toolkit.
|
23
|
+
|
24
|
+
This package provides MCP (Model Context Protocol) client functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
|
25
|
+
|
26
|
+
## Features
|
27
|
+
|
28
|
+
- Connect to MCP servers via streamable-http, SSE, or stdio transports
|
29
|
+
- Wrap individual MCP tools as NeMo Agent toolkit functions
|
30
|
+
- Dynamically discover the tools that a connected MCP server makes available
|
31
|
+
|
32
|
+
For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
@@ -0,0 +1,406 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import logging
|
19
|
+
from abc import ABC
|
20
|
+
from abc import abstractmethod
|
21
|
+
from contextlib import AsyncExitStack
|
22
|
+
from contextlib import asynccontextmanager
|
23
|
+
from enum import Enum
|
24
|
+
from typing import Any
|
25
|
+
|
26
|
+
from pydantic import BaseModel
|
27
|
+
from pydantic import Field
|
28
|
+
from pydantic import create_model
|
29
|
+
|
30
|
+
from mcp import ClientSession
|
31
|
+
from mcp.client.sse import sse_client
|
32
|
+
from mcp.client.stdio import StdioServerParameters
|
33
|
+
from mcp.client.stdio import stdio_client
|
34
|
+
from mcp.client.streamable_http import streamablehttp_client
|
35
|
+
from mcp.types import TextContent
|
36
|
+
from nat.plugins.mcp.exception_handler import mcp_exception_handler
|
37
|
+
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
38
|
+
from nat.utils.type_utils import override
|
39
|
+
|
40
|
+
logger = logging.getLogger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
    """
    Create a pydantic model from the input schema of the MCP tool

    Args:
        name (str): Name used to derive the generated class name (converted to PascalCase and suffixed
            with ``InputSchema``).
        mcp_input_schema (dict): JSON-schema style dictionary. Only its ``properties`` and ``required``
            entries are read by this function.

    Returns:
        type[BaseModel]: A dynamically created model with one field per entry in ``properties``.
    """
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        # Use the NoneType class instead of the None singleton: it composes with `|`
        # (`None | None` raises TypeError) and cannot be mistaken for an "unset" marker below.
        "null": type(None),
        "object": dict,
    }

    properties = mcp_input_schema.get("properties", {})
    required_fields = set(mcp_input_schema.get("required", []))
    schema_dict = {}

    def _generate_valid_classname(class_name: str):
        # "my_tool-name" -> "MyToolName"
        return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')

    def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
        """Return the ``(annotation, Field)`` pair for a single schema property."""
        json_type = field_properties.get("type", "string")
        enum_vals = field_properties.get("enum")

        if enum_vals:
            enum_name = f"{field_name.capitalize()}Enum"
            # Enum member names must be strings; stringify them so numeric/boolean enum values work too.
            field_type = Enum(enum_name, {str(item): item for item in enum_vals})

        elif json_type == "object" and "properties" in field_properties:
            field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
        elif json_type == "array" and "items" in field_properties:
            item_properties = field_properties.get("items", {})
            item_json_type = item_properties.get("type", "string")
            if item_json_type == "object":
                item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
            elif isinstance(item_json_type, str):
                item_type = _type_map.get(item_json_type, Any)
            else:
                # A list of types is unhashable and cannot be looked up in the map; accept anything.
                item_type = Any
            field_type = list[item_type]
        elif isinstance(json_type, list):
            # Build the union of all listed types. None marks "no member yet"; this is unambiguous
            # because "null" maps to NoneType rather than to None.
            field_type = None
            for t in json_type:
                mapped = _type_map.get(t, Any)
                field_type = mapped if field_type is None else field_type | mapped
            if field_type is None:
                # Empty type list: no constraint was expressed.
                field_type = Any

            return field_type, Field(
                default=field_properties.get("default", None if "null" in json_type else ...),
                description=field_properties.get("description", "")
            )
        else:
            field_type = _type_map.get(json_type, Any)

        # Determine the default value based on whether the field is required
        if field_name in required_fields:
            # Field is required - use explicit default if provided, otherwise make it required
            default_value = field_properties.get("default", ...)
        else:
            # Field is optional - use explicit default if provided, otherwise None
            default_value = field_properties.get("default", None)
            # Make the type optional if no default was provided
            if "default" not in field_properties:
                field_type = field_type | None

        nullable = field_properties.get("nullable", False)
        description = field_properties.get("description", "")

        field_type = field_type | None if nullable else field_type

        return field_type, Field(default=default_value, description=description)

    for field_name, field_props in properties.items():
        schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
    return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
115
|
+
|
116
|
+
|
117
|
+
class MCPBaseClient(ABC):
    """
    Base client for creating a session and connecting to an MCP server

    Use as an async context manager: entering opens the session through ``connect_to_server``,
    exiting closes it and drops all per-connection state.

    Args:
        transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')

    Raises:
        ValueError: If ``transport`` is not one of the supported values.
    """

    def __init__(self, transport: str = 'streamable-http'):
        # Cache filled lazily by get_tool(); holds tool wrappers bound to the current session.
        self._tools = None
        self._transport = transport.lower()
        if self._transport not in ['sse', 'stdio', 'streamable-http']:
            raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")

        self._exit_stack: AsyncExitStack | None = None

        self._session: ClientSession | None = None

    @property
    def transport(self) -> str:
        return self._transport

    async def __aenter__(self):
        if self._exit_stack:
            raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")

        self._exit_stack = AsyncExitStack()

        try:
            self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
        except BaseException:
            # Nothing was registered on the stack when entering fails. Reset the marker so the
            # client does not report "already initialized" on the next attempt.
            self._exit_stack = None
            raise

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):

        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        try:
            await self._exit_stack.aclose()
        finally:
            # Always drop connection state, even if closing raised. The tool cache must go too:
            # its entries hold a reference to the session that was just closed.
            self._session = None
            self._exit_stack = None
            self._tools = None

    @property
    def server_name(self):
        """
        Provide server name for logging
        """
        return self._transport

    @abstractmethod
    @asynccontextmanager
    async def connect_to_server(self):
        """
        Establish a session with an MCP server within an async context
        """
        pass

    async def get_tools(self):
        """
        Retrieve a dictionary of all tools served by the MCP server.

        Returns:
            dict mapping each tool name to an ``MCPToolClient`` bound to the current session.

        Raises:
            RuntimeError: If the client has not been entered with ``async with``.
        """

        if not self._session:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        response = await self._session.list_tools()

        return {
            tool.name:
                MCPToolClient(session=self._session,
                              tool_name=tool.name,
                              tool_description=tool.description,
                              tool_input_schema=tool.inputSchema)
            for tool in response.tools
        }

    @mcp_exception_handler
    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            RuntimeError: If the client has not been entered with ``async with``.
            MCPToolNotFoundError: If no tool is available with that name.
        """
        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        # The tool listing is fetched once and reused for later lookups on this connection.
        if not self._tools:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if not tool:
            raise MCPToolNotFoundError(tool_name, self.server_name)
        return tool

    @mcp_exception_handler
    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """
        Call a tool on the connected server and return the session's raw result.

        Args:
            tool_name (str): Name of the tool to call.
            tool_args (dict | None): Arguments forwarded unchanged to the session.

        Raises:
            RuntimeError: If the client has not been entered with ``async with``.
        """
        if not self._session:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        result = await self._session.call_tool(tool_name, tool_args)
        return result
|
224
|
+
|
225
|
+
|
226
|
+
class MCPSSEClient(MCPBaseClient):
    """
    MCP client that reaches the server over the SSE transport.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        super().__init__("sse")
        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        # Transport-prefixed identifier used in log and error messages.
        return f"sse:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open the SSE streams, start an initialized ClientSession on them and yield it.
        """
        async with sse_client(url=self._url) as (reader, writer), ClientSession(reader, writer) as mcp_session:
            await mcp_session.initialize()
            yield mcp_session
|
256
|
+
|
257
|
+
|
258
|
+
class MCPStdioClient(MCPBaseClient):
    """
    MCP client that launches a command and talks to it over stdio.

    Args:
        command (str): The command to run
        args (list[str] | None): Additional arguments for the command
        env (dict[str, str] | None): Environment variables to set for the process
    """

    def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
        super().__init__("stdio")
        self._command, self._args, self._env = command, args, env

    @property
    def command(self) -> str:
        return self._command

    @property
    def server_name(self):
        # Only the command is included; args and env are left out of the name.
        return f"stdio:{self._command}"

    @property
    def args(self) -> list[str] | None:
        return self._args

    @property
    def env(self) -> dict[str, str] | None:
        return self._env

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Start the stdio server process, run an initialized ClientSession on its streams and yield it.
        """
        # A missing/empty args value is passed on as an empty list.
        launch_args = self._args or []
        params = StdioServerParameters(command=self._command, args=launch_args, env=self._env)

        async with stdio_client(params) as streams:
            reader, writer = streams
            async with ClientSession(reader, writer) as mcp_session:
                await mcp_session.initialize()
                yield mcp_session
|
302
|
+
|
303
|
+
|
304
|
+
class MCPStreamableHTTPClient(MCPBaseClient):
    """
    Client for creating a session and connecting to an MCP server using streamable-http

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        super().__init__("streamable-http")

        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        # Transport-prefixed identifier used in log and error messages.
        return f"streamable-http:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Establish a session with an MCP server via streamable-http within an async context
        """
        # The transport yields a third item (a session-id getter) that this client does not use.
        async with streamablehttp_client(url=self._url) as (read, write, _get_session_id):
            async with ClientSession(read, write) as session:
                await session.initialize()
                yield session
|
334
|
+
|
335
|
+
|
336
|
+
class MCPToolClient:
    """
    Client wrapper used to call an MCP tool.

    Args:
        session (ClientSession): Open MCP session used to invoke the tool.
        tool_name (str): The name of the tool to wrap
        tool_description (str | None): The description of the tool provided by the MCP server.
        tool_input_schema (dict | None): The input schema for the tool. When given, it is converted to a
            pydantic model; otherwise ``input_schema`` is None.
    """

    def __init__(self,
                 session: ClientSession,
                 tool_name: str,
                 tool_description: str | None,
                 tool_input_schema: dict | None = None):
        self._session = session
        self._tool_name = tool_name
        self._tool_description = tool_description
        self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)

    @property
    def name(self):
        """Returns the name of the tool."""
        return self._tool_name

    @property
    def description(self):
        """
        Returns the tool's description. If none was provided, returns a simple description built from the tool's name.
        """
        if not self._tool_description:
            return f"MCP Tool {self._tool_name}"
        return self._tool_description

    @property
    def input_schema(self):
        """
        Returns the tool's input_schema (a pydantic model class, or None when no schema was provided).
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Manually define the tool's description using the provided string.
        """
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Call the MCP tool with the provided arguments.

        Args:
            tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.

        Returns:
            str: The text items of the result joined by newlines. Non-text items are logged and skipped.

        Raises:
            RuntimeError: If the result is flagged as an error. The message is the joined text, or a
                generic message naming the tool when the result carried no text.
        """
        result = await self._session.call_tool(self._tool_name, tool_args)

        output = []

        for res in result.content:
            if isinstance(res, TextContent):
                output.append(res.text)
            else:
                # Log non-text content for now
                logger.warning("Got not-text output from %s of type %s", self.name, type(res))
        result_str = "\n".join(output)

        if result.isError:
            # Never raise with an empty message: callers surface str(exc) to the user.
            raise RuntimeError(result_str or f"MCP tool {self._tool_name} returned an error without text content")

        return result_str
|
@@ -0,0 +1,229 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from typing import Literal
|
18
|
+
|
19
|
+
from pydantic import BaseModel
|
20
|
+
from pydantic import Field
|
21
|
+
from pydantic import HttpUrl
|
22
|
+
from pydantic import model_validator
|
23
|
+
|
24
|
+
from nat.builder.builder import Builder
|
25
|
+
from nat.builder.function_info import FunctionInfo
|
26
|
+
from nat.cli.register_workflow import register_function
|
27
|
+
from nat.data_models.function import FunctionBaseConfig
|
28
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
29
|
+
from nat.plugins.mcp.client_base import MCPBaseClient
|
30
|
+
|
31
|
+
logger = logging.getLogger(__name__)
|
32
|
+
|
33
|
+
|
34
|
+
class MCPToolOverrideConfig(BaseModel):
    """
    Per-tool overrides applied when a tool from an MCP server is exposed in a workflow.

    Both settings are optional and default to None.
    """
    alias: str | None = Field(
        default=None,
        description="Override the tool name (function name in the workflow)",
    )
    description: str | None = Field(
        default=None,
        description="Override the tool description",
    )
|
40
|
+
|
41
|
+
|
42
|
+
class MCPServerConfig(BaseModel):
    """
    Connection settings for the MCP server a client talks to.

    The stdio transport is configured with ``command`` (plus optional ``args`` / ``env``);
    the sse and streamable-http transports are configured with ``url``.
    streamable-http is the recommended default for HTTP-based connections.
    """
    transport: Literal["stdio", "sse", "streamable-http"] = Field(
        ...,
        description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)",
    )
    url: HttpUrl | None = Field(
        default=None,
        description="URL of the MCP server (for sse or streamable-http transport)",
    )
    command: str | None = Field(
        default=None,
        description="Command to run for stdio transport (e.g. 'python' or 'docker')",
    )
    args: list[str] | None = Field(
        default=None,
        description="Arguments for the stdio command",
    )
    env: dict[str, str] | None = Field(
        default=None,
        description="Environment variables for the stdio process",
    )

    @model_validator(mode="after")
    def validate_model(self):
        """Reject configurations that mix stdio settings with sse/streamable-http settings."""
        chosen = self.transport

        if chosen == "stdio":
            # stdio: a command is mandatory and a url is not allowed.
            if self.url is not None:
                raise ValueError("url should not be set when using stdio transport")
            if not self.command:
                raise ValueError("command is required when using stdio transport")
        elif chosen in ("sse", "streamable-http"):
            # HTTP-based transports: a url is mandatory and no process settings are allowed.
            process_settings = (self.command, self.args, self.env)
            if any(setting is not None for setting in process_settings):
                raise ValueError("command, args, and env should not be set when using sse or streamable-http transport")
            if not self.url:
                raise ValueError("url is required when using sse or streamable-http transport")

        return self
|
71
|
+
|
72
|
+
|
73
|
+
class MCPClientConfig(FunctionBaseConfig, name="mcp_client"):
    """
    Configuration for connecting to an MCP server as a client and exposing selected tools.
    """
    # Required: how to reach the server (see MCPServerConfig for the per-transport rules).
    server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
    # Optional: None exposes every tool; a list selects tools by name; a dict selects tools and
    # may carry per-tool alias/description overrides.
    tool_filter: dict[str, MCPToolOverrideConfig] | list[str] | None = Field(
        default=None,
        description="""Filter or map tools to expose from the server (list or dict).
Can be:
- A list of tool names to expose: ['tool1', 'tool2']
- A dict mapping tool names to override configs:
{'tool1': {'alias': 'new_name', 'description': 'New desc'}}
{'tool2': {'description': 'Override description only'}} # alias defaults to 'tool2'
""")
|
87
|
+
|
88
|
+
|
89
|
+
class MCPSingleToolConfig(FunctionBaseConfig, name="mcp_single_tool"):
    """
    Settings for exposing one tool of an MCP server as a NeMo Agent toolkit function.
    """
    # `client` is a live client object rather than plain data, so arbitrary types must be allowed.
    model_config = {"arbitrary_types_allowed": True}

    client: MCPBaseClient = Field(
        ...,
        description="MCP client to use for the tool",
    )
    tool_name: str = Field(
        ...,
        description="Name of the tool to use",
    )
    tool_description: str | None = Field(
        default=None,
        description="Description of the tool",
    )
|
98
|
+
|
99
|
+
|
100
|
+
def _get_server_name_safe(client: MCPBaseClient) -> str:
|
101
|
+
# Avoid leaking env secrets from stdio client in logs.
|
102
|
+
if client.transport == "stdio":
|
103
|
+
safe_server = f"stdio: {client.command}"
|
104
|
+
else:
|
105
|
+
safe_server = f"{client.transport}: {client.url}"
|
106
|
+
|
107
|
+
return safe_server
|
108
|
+
|
109
|
+
|
110
|
+
@register_function(config_type=MCPSingleToolConfig)
async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
    """
    Wrap a single tool from an MCP server as a NeMo Agent toolkit function.

    Args:
        config (MCPSingleToolConfig): Supplies the connected client, the tool name and an optional
            description override.
        builder (Builder): Workflow builder (not used by this function).

    Yields:
        FunctionInfo wrapping a coroutine that calls the tool and returns its text output.
        Errors raised while calling the tool are logged and returned as their string message.
    """
    tool = await config.client.get_tool(config.tool_name)
    if config.tool_description:
        tool.set_description(description=config.tool_description)
    # May be None when the tool was created without an input schema.
    input_schema = tool.input_schema

    logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))

    def _convert_from_str(input_str: str) -> BaseModel:
        return input_schema.model_validate_json(input_str)

    @experimental(feature_name="mcp_client")
    async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
        try:
            if tool_input:
                # NOTE(review): model_dump() emits every field, so optional arguments that were
                # not supplied are sent as None - confirm that target servers accept this.
                return await tool.acall(tool_input.model_dump())
            # Validate keyword arguments only when there is a schema to validate against;
            # without this guard a schema-less tool could never be called.
            if input_schema is not None:
                _ = input_schema.model_validate(kwargs)
            return await tool.acall(kwargs)
        except Exception as e:
            # Best effort: the message is handed back to the caller instead of raising,
            # but the failure is logged so it is not lost.
            logger.warning("Call to MCP tool %s failed: %s", tool.name, e)
            return str(e)

    create_kwargs = {
        "single_fn": _response_fn,
        "description": tool.description,
        "input_schema": input_schema,
    }
    if input_schema is not None:
        # The string converter needs a schema to parse into.
        create_kwargs["converters"] = [_convert_from_str]

    fn = FunctionInfo.create(**create_kwargs)
    yield fn
|
140
|
+
|
141
|
+
|
142
|
+
@register_function(MCPClientConfig)
async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
    """
    Connect to an MCP server, discover its tools and add each selected tool to the workflow.

    Note:
        - The client stays connected for as long as this generator is suspended at its ``yield``
        - ``config.tool_filter`` decides which tools are added and under which name/description
    """
    from nat.plugins.mcp.client_base import MCPSSEClient
    from nat.plugins.mcp.client_base import MCPStdioClient
    from nat.plugins.mcp.client_base import MCPStreamableHTTPClient

    # Pick the client implementation matching the configured transport
    server = config.server
    transport = server.transport
    if transport == "stdio":
        client = MCPStdioClient(server.command, server.args, server.env)
    elif transport == "sse":
        client = MCPSSEClient(str(server.url))
    elif transport == "streamable-http":
        client = MCPStreamableHTTPClient(str(server.url))
    else:
        raise ValueError(f"Unsupported transport: {config.server.transport}")

    logger.info("Configured to use MCP server at %s", _get_server_name_safe(client))

    # Entering the client opens the connection; it is closed when this block is left.
    async with client:
        discovered = await client.get_tools()
        selected = mcp_filter_tools(discovered, config.tool_filter)

        for tool_name, tool_cfg in selected.items():
            function_name = tool_cfg["function_name"]
            single_tool_config = MCPSingleToolConfig(
                client=client,
                tool_name=tool_name,
                tool_description=tool_cfg["description"],
            )
            await builder.add_function(function_name, single_tool_config)

        @experimental(feature_name="mcp_client")
        async def idle_fn(text: str) -> str:
            # This function is a placeholder and will be removed when function groups are used
            return f"MCP client connected: {text}"

        yield FunctionInfo.create(single_fn=idle_fn, description="MCP client")
|
189
|
+
|
190
|
+
|
191
|
+
def mcp_filter_tools(all_tools: dict, tool_filter) -> dict[str, dict]:
    """
    Apply tool filtering and optional aliasing/description overrides.

    Args:
        all_tools (dict): Tool name -> tool object; only each tool's ``description`` is read.
        tool_filter: None (keep every tool), a list of tool names, or a dict mapping tool names
            to ``MCPToolOverrideConfig``.

    Returns:
        Dict[str, dict] keyed by the original tool name, where each value has:
            - function_name
            - description

    Raises:
        ValueError: If ``tool_filter`` is neither None, a list nor a dict.
    """
    if tool_filter is None:
        return {name: {"function_name": name, "description": tool.description} for name, tool in all_tools.items()}

    if isinstance(tool_filter, list):
        result = {}
        for name in tool_filter:
            if name not in all_tools:
                # Report unknown names the same way the dict form of the filter does,
                # instead of dropping them silently.
                logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
                continue
            result[name] = {"function_name": name, "description": all_tools[name].description}
        return result

    if isinstance(tool_filter, dict):
        result = {}
        for name, override in tool_filter.items():
            tool = all_tools.get(name)
            if not tool:
                logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
                continue

            if isinstance(override, MCPToolOverrideConfig):
                # Unset override values fall back to the tool's own name/description.
                result[name] = {
                    "function_name": override.alias or name, "description": override.description or tool.description
                }
            else:
                logger.warning("Unsupported override type for '%s': %s", name, type(override))
                result[name] = {"function_name": name, "description": tool.description}
        return result

    # Fallback for unsupported tool_filter types
    raise ValueError(f"Unsupported tool_filter type: {type(tool_filter)}")
|
@@ -0,0 +1,211 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
import ssl
|
18
|
+
import sys
|
19
|
+
from collections.abc import Callable
|
20
|
+
from functools import wraps
|
21
|
+
from typing import Any
|
22
|
+
|
23
|
+
import httpx
|
24
|
+
|
25
|
+
from nat.plugins.mcp.exceptions import MCPAuthenticationError
|
26
|
+
from nat.plugins.mcp.exceptions import MCPConnectionError
|
27
|
+
from nat.plugins.mcp.exceptions import MCPError
|
28
|
+
from nat.plugins.mcp.exceptions import MCPProtocolError
|
29
|
+
from nat.plugins.mcp.exceptions import MCPRequestError
|
30
|
+
from nat.plugins.mcp.exceptions import MCPSSLError
|
31
|
+
from nat.plugins.mcp.exceptions import MCPTimeoutError
|
32
|
+
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
def format_mcp_error(error: MCPError, include_traceback: bool = False) -> None:
    """Report an MCP error: log it, then print its suggestions to stderr.

    Args:
        error (MCPError): The error to report; its ``suggestions`` are iterated and printed one per line.
        include_traceback (bool, optional): Forwarded to the logger as ``exc_info``. Defaults to False.
    """
    # Log entry first, so the suggestions follow the error they belong to
    logger.error("MCP operation failed: %s", error, exc_info=include_traceback)

    # One arrow-prefixed line per suggestion
    for hint in error.suggestions:
        line = f" → {hint}"
        print(line, file=sys.stderr)
|
53
|
+
|
54
|
+
|
55
|
+
def _extract_url(args: tuple, kwargs: dict[str, Any], url_param: str, func_name: str) -> str:
|
56
|
+
"""Extract URL from function arguments using clean fallback chain.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
args: Function positional arguments
|
60
|
+
kwargs: Function keyword arguments
|
61
|
+
url_param (str): Parameter name containing the URL
|
62
|
+
func_name (str): Function name for logging
|
63
|
+
|
64
|
+
Returns:
|
65
|
+
str: URL string or "unknown" if extraction fails
|
66
|
+
"""
|
67
|
+
# Try keyword arguments first
|
68
|
+
if url_param in kwargs:
|
69
|
+
return kwargs[url_param]
|
70
|
+
|
71
|
+
# Try self attribute (e.g., self.url)
|
72
|
+
if args and hasattr(args[0], url_param):
|
73
|
+
return getattr(args[0], url_param)
|
74
|
+
|
75
|
+
# Try common case: url as second parameter after self
|
76
|
+
if len(args) > 1 and url_param == "url":
|
77
|
+
return args[1]
|
78
|
+
|
79
|
+
# Fallback with warning
|
80
|
+
logger.warning("Could not extract URL for error handling in %s", func_name)
|
81
|
+
return "unknown"
|
82
|
+
|
83
|
+
|
84
|
+
def extract_primary_exception(exceptions: list[Exception]) -> Exception:
    """Pick the exception from a group that is most useful to show the user.

    Exception kinds are ranked: connection failures first, then timeouts, then
    SSL errors. The first exception of the highest-ranked kind present is
    returned; if none of those kinds occur, the first exception in the list is used.

    Args:
        exceptions (list[Exception]): Exceptions taken from an ExceptionGroup

    Returns:
        Exception: The exception selected for user feedback
    """
    # Ordered from most to least relevant
    ranked_kinds = (
        (httpx.ConnectError, ConnectionError),
        httpx.TimeoutException,
        ssl.SSLError,
    )

    for kinds in ranked_kinds:
        found = next((candidate for candidate in exceptions if isinstance(candidate, kinds)), None)
        if found is not None:
            return found

    # No prioritized kind present: fall back to the first exception
    return exceptions[0]
|
112
|
+
|
113
|
+
|
114
|
+
def convert_to_mcp_error(exception: Exception, url: str) -> MCPError:
    """Translate a low-level exception into the matching MCPError subclass.

    Type-based checks run first, in order: connection, timeout, SSL, generic
    httpx request errors. Remaining exceptions are classified by their message text.

    Args:
        exception (Exception): Single exception to convert
        url (str): MCP server URL for context

    Returns:
        MCPError: An MCPError (or subclass) wrapping ``exception``
    """
    # Order matters: the connection and timeout checks must precede the generic request-error check
    if isinstance(exception, (httpx.ConnectError, ConnectionError)):
        return MCPConnectionError(url, exception)
    if isinstance(exception, httpx.TimeoutException):
        return MCPTimeoutError(url, exception)
    if isinstance(exception, ssl.SSLError):
        return MCPSSLError(url, exception)
    if isinstance(exception, httpx.RequestError):
        return MCPRequestError(url, exception)

    text = str(exception)

    if isinstance(exception, ValueError) and "Tool" in text and "not available" in text:
        # Recover the tool name from a "Tool <name> not available" style message when possible
        if "Tool " in text:
            tool_name = text.split("Tool ")[1].split(" not available")[0]
        else:
            tool_name = "unknown"
        return MCPToolNotFoundError(tool_name, url, exception)

    # Message produced when a task group surfaces unhandled errors
    if "unhandled errors in a TaskGroup" in text:
        return MCPProtocolError(url, "Failed to connect to MCP server", exception)

    lowered = text.lower()
    if "unauthorized" in lowered or "forbidden" in lowered:
        return MCPAuthenticationError(url, exception)

    return MCPError(f"Unexpected error: {exception}", url, original_exception=exception)
|
145
|
+
|
146
|
+
|
147
|
+
def handle_mcp_exceptions(url_param: str = "url") -> Callable[..., Any]:
    """Build a decorator that re-raises failures of an async callable as MCPErrors.

    The wrapped coroutine function is awaited unchanged. An ``MCPError`` raised
    by it propagates untouched; any other ``Exception`` is converted to an
    ``MCPError`` and raised with the original exception as its cause.

    Args:
        url_param (str): Name of the parameter or attribute containing the MCP server URL

    Returns:
        Callable[..., Any]: Decorator to apply to an async function or method

    Example:
        .. code-block:: python

            @handle_mcp_exceptions("url")
            async def get_tools(self, url: str):
                # Method implementation
                pass

            @handle_mcp_exceptions("url")  # Uses self.url
            async def get_tool(self):
                # Method implementation
                pass
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:

        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except MCPError:
                # Already structured; pass through unchanged
                raise
            except Exception as caught:
                server_url = _extract_url(args, kwargs, url_param, func.__name__)

                # For a group, convert only its most relevant member
                if isinstance(caught, ExceptionGroup):  # noqa: F821
                    relevant = extract_primary_exception(list(caught.exceptions))
                else:
                    relevant = caught

                converted = convert_to_mcp_error(relevant, server_url)
                raise converted from caught

        return wrapper

    return decorator
|
197
|
+
|
198
|
+
|
199
|
+
def mcp_exception_handler(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate ``func`` with MCP exception handling keyed on a ``url`` parameter or attribute.

    Shorthand for ``handle_mcp_exceptions("url")(func)``, intended for methods
    whose instance exposes the server URL as ``self.url``.

    Args:
        func (Callable[..., Any]): The function to decorate

    Returns:
        Callable[..., Any]: Decorated function
    """
    apply_handling = handle_mcp_exceptions("url")
    return apply_handling(func)
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
from enum import Enum
|
17
|
+
|
18
|
+
|
19
|
+
class MCPErrorCategory(str, Enum):
    """String-valued labels used to classify MCP errors.

    Members are also ``str`` instances, so they compare equal to their plain string values.
    """

    # Transport-level failures
    CONNECTION = "connection"
    TIMEOUT = "timeout"
    SSL = "ssl"

    # Failures reported after the server was reached
    AUTHENTICATION = "authentication"
    TOOL_NOT_FOUND = "tool_not_found"
    PROTOCOL = "protocol"

    # Default when no more specific category applies
    UNKNOWN = "unknown"
|
28
|
+
|
29
|
+
|
30
|
+
class MCPError(Exception):
    """Base exception for MCP-related errors.

    Besides the message, each error carries the server URL, a category, a list
    of user-facing suggestions and, optionally, the exception that caused it.
    """

    def __init__(self,
                 message: str,
                 url: str,
                 category: MCPErrorCategory = MCPErrorCategory.UNKNOWN,
                 suggestions: list[str] | None = None,
                 original_exception: Exception | None = None):
        """Initialize the error.

        Args:
            message (str): Human-readable description; becomes ``str(error)``
            url (str): MCP server URL the failing operation targeted
            category (MCPErrorCategory): Classification of the failure. Defaults to ``UNKNOWN``.
            suggestions (list[str] | None): Hints to show the user. Defaults to an empty list.
            original_exception (Exception | None): Underlying exception, if any
        """
        super().__init__(message)
        self.url = url
        self.category = category
        # Store a copy so later mutation of the caller's list (or of ours) cannot leak across objects
        self.suggestions = list(suggestions) if suggestions else []
        self.original_exception = original_exception
|
44
|
+
|
45
|
+
|
46
|
+
class MCPConnectionError(MCPError):
    """Raised when a connection to the MCP server cannot be established."""

    def __init__(self, url: str, original_exception: Exception | None = None):
        """Build the error for ``url``, optionally wrapping the causing exception."""
        hints = [
            "Please ensure the MCP server is running and accessible",
            "Check if the URL and port are correct",
        ]
        super().__init__(
            f"Unable to connect to MCP server at {url}",
            url=url,
            category=MCPErrorCategory.CONNECTION,
            suggestions=hints,
            original_exception=original_exception,
        )
|
58
|
+
|
59
|
+
|
60
|
+
class MCPTimeoutError(MCPError):
    """Raised when an operation against the MCP server times out."""

    def __init__(self, url: str, original_exception: Exception | None = None):
        """Build the error for ``url``, optionally wrapping the causing exception."""
        hints = [
            "The server may be overloaded or network is slow",
            "Try again in a moment or check network connectivity",
        ]
        super().__init__(
            f"Connection timed out to MCP server at {url}",
            url=url,
            category=MCPErrorCategory.TIMEOUT,
            suggestions=hints,
            original_exception=original_exception,
        )
|
72
|
+
|
73
|
+
|
74
|
+
class MCPSSLError(MCPError):
    """Raised when an SSL/TLS failure occurs while talking to the MCP server."""

    def __init__(self, url: str, original_exception: Exception | None = None):
        """Build the error for ``url``, optionally wrapping the causing exception."""
        hints = [
            "Check if the server requires HTTPS or has valid certificates",
            "Try using HTTP instead of HTTPS if appropriate",
        ]
        super().__init__(
            f"SSL/TLS error connecting to {url}",
            url=url,
            category=MCPErrorCategory.SSL,
            suggestions=hints,
            original_exception=original_exception,
        )
|
86
|
+
|
87
|
+
|
88
|
+
class MCPRequestError(MCPError):
    """Raised when a request to the MCP server fails.

    Reported under the ``PROTOCOL`` category.
    """

    def __init__(self, url: str, original_exception: Exception | None = None):
        """Build the error for ``url``; the cause, when truthy, is appended to the message."""
        detail = f": {original_exception}" if original_exception else ""
        super().__init__(
            f"Request failed to MCP server at {url}" + detail,
            url=url,
            category=MCPErrorCategory.PROTOCOL,
            suggestions=["Check the server URL format and network settings"],
            original_exception=original_exception,
        )
|
101
|
+
|
102
|
+
|
103
|
+
class MCPToolNotFoundError(MCPError):
    """Raised when the requested tool is not offered by the MCP server."""

    def __init__(self, tool_name: str, url: str, original_exception: Exception | None = None):
        """Build the error for ``tool_name`` at ``url``, optionally wrapping the causing exception."""
        hints = [
            "Use 'nat info mcp --detail' to see available tools",
            "Check that the tool name is spelled correctly",
        ]
        super().__init__(
            f"Tool '{tool_name}' not available at {url}",
            url=url,
            category=MCPErrorCategory.TOOL_NOT_FOUND,
            suggestions=hints,
            original_exception=original_exception,
        )
|
115
|
+
|
116
|
+
|
117
|
+
class MCPAuthenticationError(MCPError):
    """Raised when authentication with the MCP server fails."""

    def __init__(self, url: str, original_exception: Exception | None = None):
        """Build the error for ``url``, optionally wrapping the causing exception."""
        hints = [
            "Check if the server requires authentication credentials",
            "Verify that your credentials are correct and not expired",
        ]
        super().__init__(
            f"Authentication failed when connecting to MCP server at {url}",
            url=url,
            category=MCPErrorCategory.AUTHENTICATION,
            suggestions=hints,
            original_exception=original_exception,
        )
|
129
|
+
|
130
|
+
|
131
|
+
class MCPProtocolError(MCPError):
    """Raised for failures at the MCP protocol level."""

    def __init__(self, url: str, message: str = "Protocol error", original_exception: Exception | None = None):
        """Build the error; ``message`` is suffixed with the server URL to form the final text."""
        hints = [
            "Check that the MCP server is running and accessible at this URL",
            "Verify the server supports the expected MCP protocol version",
        ]
        super().__init__(
            f"{message} (MCP server at {url})",
            url=url,
            category=MCPErrorCategory.PROTOCOL,
            suggestions=hints,
            original_exception=original_exception,
        )
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# flake8: noqa
|
17
|
+
# isort:skip_file
|
18
|
+
|
19
|
+
# Import any providers which need to be automatically registered here
|
20
|
+
|
21
|
+
from . import client_impl
|
22
|
+
from . import tool
|
nat/plugins/mcp/tool.py
ADDED
@@ -0,0 +1,133 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import logging
|
17
|
+
from typing import Literal
|
18
|
+
|
19
|
+
from pydantic import BaseModel
|
20
|
+
from pydantic import Field
|
21
|
+
from pydantic import HttpUrl
|
22
|
+
from pydantic import model_validator
|
23
|
+
|
24
|
+
from nat.builder.builder import Builder
|
25
|
+
from nat.builder.function_info import FunctionInfo
|
26
|
+
from nat.cli.register_workflow import register_function
|
27
|
+
from nat.data_models.function import FunctionBaseConfig
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
    """
    Function which connects to a Model Context Protocol (MCP) server and wraps the selected tool as a NeMo Agent toolkit
    function.
    """
    # Connection settings: ``url`` applies to the HTTP-based transports, while ``command``/``args``/``env``
    # apply to the stdio transport. ``validate_model`` below enforces which of them may be set.
    url: HttpUrl | None = Field(default=None,
                                description="The URL of the MCP server (for streamable-http or sse modes)")
    mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
    transport: Literal["sse", "stdio", "streamable-http"] = Field(
        default="streamable-http",
        description="The type of transport to use (default: streamable-http, backwards compatible with sse)")
    command: str | None = Field(default=None,
                                description="The command to run for stdio mode (e.g. 'docker' or 'python')")
    args: list[str] | None = Field(default=None, description="Additional arguments for the stdio command")
    env: dict[str, str] | None = Field(default=None, description="Environment variables to set for the stdio process")
    description: str | None = Field(default=None,
                                    description="""
        Description for the tool that will override the description provided by the MCP server. Should only be used if
        the description provided by the server is poor or nonexistent
        """)
    return_exception: bool = Field(default=True,
                                   description="""
        If true, the tool will return the exception message if the tool call fails.
        If false, raise the exception.
        """)

    @model_validator(mode="after")
    def validate_model(self):
        """Check that only the settings belonging to the chosen transport are provided.

        Raises:
            ValueError: If ``url`` is set or ``command`` is missing for stdio, or if any of
                ``command``/``args``/``env`` is set or ``url`` is missing for streamable-http/sse.
        """
        if self.transport == 'stdio':
            if self.url is not None:
                raise ValueError("url should not be set when using stdio client type")
            if not self.command:
                raise ValueError("command is required when using stdio client type")
            return self

        if self.transport in ('streamable-http', 'sse'):
            stdio_settings = (self.command, self.args, self.env)
            if any(setting is not None for setting in stdio_settings):
                raise ValueError(
                    "command, args, and env should not be set when using streamable-http or sse client type")
            if not self.url:
                raise ValueError("url is required when using streamable-http or sse client type")

        return self
|
74
|
+
|
75
|
+
|
76
|
+
@register_function(config_type=MCPToolConfig)
|
77
|
+
async def mcp_tool(config: MCPToolConfig, builder: Builder):
|
78
|
+
"""
|
79
|
+
Generate a NeMo Agent Toolkit Function that wraps a tool provided by the MCP server.
|
80
|
+
"""
|
81
|
+
|
82
|
+
from nat.plugins.mcp.client_base import MCPSSEClient
|
83
|
+
from nat.plugins.mcp.client_base import MCPStdioClient
|
84
|
+
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
|
85
|
+
from nat.plugins.mcp.client_base import MCPToolClient
|
86
|
+
|
87
|
+
# Initialize the client
|
88
|
+
if config.transport == 'stdio':
|
89
|
+
client = MCPStdioClient(command=config.command, args=config.args, env=config.env)
|
90
|
+
elif config.transport == 'streamable-http':
|
91
|
+
client = MCPStreamableHTTPClient(url=str(config.url))
|
92
|
+
elif config.transport == 'sse':
|
93
|
+
client = MCPSSEClient(url=str(config.url))
|
94
|
+
else:
|
95
|
+
raise ValueError(f"Invalid transport type: {config.transport}")
|
96
|
+
|
97
|
+
async with client:
|
98
|
+
# If the tool is found create a MCPToolClient object and set the description if provided
|
99
|
+
tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
|
100
|
+
if config.description:
|
101
|
+
tool.set_description(description=config.description)
|
102
|
+
|
103
|
+
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, client.server_name)
|
104
|
+
|
105
|
+
def _convert_from_str(input_str: str) -> tool.input_schema:
|
106
|
+
return tool.input_schema.model_validate_json(input_str)
|
107
|
+
|
108
|
+
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
|
109
|
+
# Run the tool, catching any errors and sending to agent for correction
|
110
|
+
try:
|
111
|
+
if tool_input:
|
112
|
+
args = tool_input.model_dump()
|
113
|
+
return await tool.acall(args)
|
114
|
+
|
115
|
+
_ = tool.input_schema.model_validate(kwargs)
|
116
|
+
return await tool.acall(kwargs)
|
117
|
+
except Exception as e:
|
118
|
+
if config.return_exception:
|
119
|
+
if tool_input:
|
120
|
+
logger.warning("Error calling tool %s with serialized input: %s",
|
121
|
+
tool.name,
|
122
|
+
tool_input.model_dump(),
|
123
|
+
exc_info=True)
|
124
|
+
else:
|
125
|
+
logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True)
|
126
|
+
return str(e)
|
127
|
+
# If the tool call fails, raise the exception.
|
128
|
+
raise
|
129
|
+
|
130
|
+
yield FunctionInfo.create(single_fn=_response_fn,
|
131
|
+
description=tool.description,
|
132
|
+
input_schema=tool.input_schema,
|
133
|
+
converters=[_convert_from_str])
|
@@ -0,0 +1,46 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: nvidia-nat-mcp
|
3
|
+
Version: 1.3.0a20250909
|
4
|
+
Summary: Subpackage for MCP client integration in NeMo Agent toolkit
|
5
|
+
Keywords: ai,rag,agents,mcp
|
6
|
+
Classifier: Programming Language :: Python
|
7
|
+
Classifier: Programming Language :: Python :: 3.11
|
8
|
+
Classifier: Programming Language :: Python :: 3.12
|
9
|
+
Classifier: Programming Language :: Python :: 3.13
|
10
|
+
Requires-Python: <3.14,>=3.11
|
11
|
+
Description-Content-Type: text/markdown
|
12
|
+
Requires-Dist: nvidia-nat==v1.3.0a20250909
|
13
|
+
Requires-Dist: mcp~=1.13
|
14
|
+
|
15
|
+
<!--
|
16
|
+
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
17
|
+
SPDX-License-Identifier: Apache-2.0
|
18
|
+
|
19
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
20
|
+
you may not use this file except in compliance with the License.
|
21
|
+
You may obtain a copy of the License at
|
22
|
+
|
23
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
24
|
+
|
25
|
+
Unless required by applicable law or agreed to in writing, software
|
26
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
27
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
28
|
+
See the License for the specific language governing permissions and
|
29
|
+
limitations under the License.
|
30
|
+
-->
|
31
|
+
|
32
|
+

|
33
|
+
|
34
|
+
|
35
|
+
# NVIDIA NeMo Agent Toolkit MCP Subpackage
|
36
|
+
Subpackage for MCP client integration in NeMo Agent toolkit.
|
37
|
+
|
38
|
+
This package provides MCP (Model Context Protocol) client functionality, allowing NeMo Agent toolkit workflows to connect to external MCP servers and use their tools as functions.
|
39
|
+
|
40
|
+
## Features
|
41
|
+
|
42
|
+
- Connect to MCP servers via streamable-http, SSE, or stdio transports
|
43
|
+
- Wrap individual MCP tools as NeMo Agent toolkit functions
|
44
|
+
- Connect to MCP servers and dynamically discover available tools
|
45
|
+
|
46
|
+
For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
|
@@ -0,0 +1,13 @@
|
|
1
|
+
nat/meta/pypi.md,sha256=GyV4DI1d9ThgEhnYTQ0vh40Q9hPC8jN-goLnRiFDmZ8,1498
|
2
|
+
nat/plugins/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
|
3
|
+
nat/plugins/mcp/client_base.py,sha256=4vFOBFoSpLpkq7r2iXDMbi6tDj02JBidZo5RiBR167w,13424
|
4
|
+
nat/plugins/mcp/client_impl.py,sha256=6rG3LcCX4TFsiST5O0_C8eOpY9LdnoSMarfAWeR76XA,9724
|
5
|
+
nat/plugins/mcp/exception_handler.py,sha256=JdPdZG1NgWpdRnIz7JTGHiJASS5wot9nJiD3SRWV4Kw,7649
|
6
|
+
nat/plugins/mcp/exceptions.py,sha256=EGVOnYlui8xufm8dhJyPL1SUqBLnCGOTvRoeyNcmcWE,5980
|
7
|
+
nat/plugins/mcp/register.py,sha256=HOT2Wl2isGuyFc7BUTi58-BbjI5-EtZMZo7stsv5pN4,831
|
8
|
+
nat/plugins/mcp/tool.py,sha256=MSRnnr1a6OjfqVkt2SYPkfLi9lU0JqovIZuTOL1cgHQ,6378
|
9
|
+
nvidia_nat_mcp-1.3.0a20250909.dist-info/METADATA,sha256=w2nXFHRlGLbHy6RPOhRmnA3AgLKzNdnC6DFXdNPdPYs,1997
|
10
|
+
nvidia_nat_mcp-1.3.0a20250909.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
11
|
+
nvidia_nat_mcp-1.3.0a20250909.dist-info/entry_points.txt,sha256=x7dQTqek3GEdU-y9GslnygxMu0BSbEeUiOOMa2gvaaQ,52
|
12
|
+
nvidia_nat_mcp-1.3.0a20250909.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
13
|
+
nvidia_nat_mcp-1.3.0a20250909.dist-info/RECORD,,
|
@@ -0,0 +1 @@
|
|
1
|
+
nat
|