nvidia-nat 1.3.0a20250828__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +6 -1
- nat/agent/react_agent/agent.py +46 -38
- nat/agent/react_agent/register.py +7 -2
- nat/agent/rewoo_agent/agent.py +16 -30
- nat/agent/rewoo_agent/register.py +3 -3
- nat/agent/tool_calling_agent/agent.py +9 -19
- nat/agent/tool_calling_agent/register.py +2 -2
- nat/builder/eval_builder.py +2 -2
- nat/builder/function.py +8 -8
- nat/builder/workflow.py +6 -2
- nat/builder/workflow_builder.py +21 -24
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_mcp.py +183 -47
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/start.py +15 -3
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -4
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/thinking_mixin.py +27 -8
- nat/eval/evaluate.py +6 -6
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/swe_bench_evaluator/evaluate.py +5 -5
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/fastapi/message_handler.py +2 -2
- nat/front_ends/fastapi/message_validator.py +8 -10
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/mcp/mcp_front_end_config.py +5 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
- nat/front_ends/mcp/tool_converter.py +40 -13
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +8 -9
- nat/observability/exporter_manager.py +5 -5
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/processor/batching_processor.py +4 -6
- nat/observability/register.py +3 -1
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +5 -5
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +3 -3
- nat/registry_handlers/rest/rest_handler.py +4 -4
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +2 -2
- nat/settings/global_settings.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
- nat/tool/mcp/mcp_client_impl.py +229 -0
- nat/tool/mcp/mcp_tool.py +79 -42
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +1 -0
- nat/tool/retriever.py +3 -2
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/reactive/observer.py +2 -2
- nat/utils/settings/global_settings.py +2 -2
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +82 -81
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
|
@@ -16,12 +16,18 @@
|
|
|
16
16
|
from __future__ import annotations
|
|
17
17
|
|
|
18
18
|
import logging
|
|
19
|
+
from abc import ABC
|
|
20
|
+
from abc import abstractmethod
|
|
21
|
+
from contextlib import AsyncExitStack
|
|
19
22
|
from contextlib import asynccontextmanager
|
|
20
23
|
from enum import Enum
|
|
21
24
|
from typing import Any
|
|
22
25
|
|
|
23
26
|
from mcp import ClientSession
|
|
24
27
|
from mcp.client.sse import sse_client
|
|
28
|
+
from mcp.client.stdio import StdioServerParameters
|
|
29
|
+
from mcp.client.stdio import stdio_client
|
|
30
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
25
31
|
from mcp.types import TextContent
|
|
26
32
|
from pydantic import BaseModel
|
|
27
33
|
from pydantic import Field
|
|
@@ -29,6 +35,7 @@ from pydantic import create_model
|
|
|
29
35
|
|
|
30
36
|
from nat.tool.mcp.exceptions import MCPToolNotFoundError
|
|
31
37
|
from nat.utils.exception_handlers.mcp import mcp_exception_handler
|
|
38
|
+
from nat.utils.type_utils import override
|
|
32
39
|
|
|
33
40
|
logger = logging.getLogger(__name__)
|
|
34
41
|
|
|
@@ -107,56 +114,78 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
|
|
|
107
114
|
return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
|
108
115
|
|
|
109
116
|
|
|
110
|
-
class
|
|
117
|
+
class MCPBaseClient(ABC):
|
|
111
118
|
"""
|
|
112
|
-
|
|
119
|
+
Base client for creating a session and connecting to an MCP server
|
|
113
120
|
|
|
114
121
|
Args:
|
|
115
|
-
|
|
122
|
+
transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
|
|
116
123
|
"""
|
|
117
124
|
|
|
118
|
-
def __init__(self,
|
|
119
|
-
self.
|
|
125
|
+
def __init__(self, transport: str = 'streamable-http'):
|
|
126
|
+
self._tools = None
|
|
127
|
+
self._transport = transport.lower()
|
|
128
|
+
if self._transport not in ['sse', 'stdio', 'streamable-http']:
|
|
129
|
+
raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
|
|
120
130
|
|
|
121
|
-
|
|
122
|
-
async def connect_to_sse_server(self):
|
|
123
|
-
"""
|
|
124
|
-
Establish a session with an MCP SSE server within an aync context
|
|
125
|
-
"""
|
|
126
|
-
async with sse_client(url=self.url) as (read, write):
|
|
127
|
-
async with ClientSession(read, write) as session:
|
|
128
|
-
await session.initialize()
|
|
129
|
-
yield session
|
|
131
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
130
132
|
|
|
133
|
+
self._session: ClientSession | None = None
|
|
131
134
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
+
@property
|
|
136
|
+
def transport(self) -> str:
|
|
137
|
+
return self._transport
|
|
135
138
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
+
async def __aenter__(self):
|
|
140
|
+
if self._exit_stack:
|
|
141
|
+
raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")
|
|
139
142
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
self.
|
|
143
|
+
self._exit_stack = AsyncExitStack()
|
|
144
|
+
|
|
145
|
+
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
|
|
146
|
+
|
|
147
|
+
return self
|
|
148
|
+
|
|
149
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
150
|
+
|
|
151
|
+
if not self._exit_stack:
|
|
152
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
153
|
+
|
|
154
|
+
await self._exit_stack.aclose()
|
|
155
|
+
self._session = None
|
|
156
|
+
self._exit_stack = None
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def server_name(self):
|
|
160
|
+
"""
|
|
161
|
+
Provide server name for logging
|
|
162
|
+
"""
|
|
163
|
+
return self._transport
|
|
164
|
+
|
|
165
|
+
@abstractmethod
|
|
166
|
+
@asynccontextmanager
|
|
167
|
+
async def connect_to_server(self):
|
|
168
|
+
"""
|
|
169
|
+
Establish a session with an MCP server within an async context
|
|
170
|
+
"""
|
|
171
|
+
pass
|
|
143
172
|
|
|
144
|
-
@mcp_exception_handler
|
|
145
173
|
async def get_tools(self):
|
|
146
174
|
"""
|
|
147
175
|
Retrieve a dictionary of all tools served by the MCP server.
|
|
176
|
+
"""
|
|
148
177
|
|
|
149
|
-
|
|
150
|
-
|
|
178
|
+
if not self._session:
|
|
179
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
151
180
|
|
|
152
|
-
|
|
153
|
-
MCPError: If connection or tool retrieval fails
|
|
154
|
-
"""
|
|
155
|
-
async with self.connect_to_sse_server() as session:
|
|
156
|
-
response = await session.list_tools()
|
|
181
|
+
response = await self._session.list_tools()
|
|
157
182
|
|
|
158
183
|
return {
|
|
159
|
-
tool.name:
|
|
184
|
+
tool.name:
|
|
185
|
+
MCPToolClient(session=self._session,
|
|
186
|
+
tool_name=tool.name,
|
|
187
|
+
tool_description=tool.description,
|
|
188
|
+
tool_input_schema=tool.inputSchema)
|
|
160
189
|
for tool in response.tools
|
|
161
190
|
}
|
|
162
191
|
|
|
@@ -172,9 +201,11 @@ class MCPBuilder(MCPSSEClient):
|
|
|
172
201
|
MCPToolClient for the configured tool.
|
|
173
202
|
|
|
174
203
|
Raises:
|
|
175
|
-
MCPToolNotFoundError: If no tool is available with that name
|
|
176
|
-
MCPError: If connection fails
|
|
204
|
+
MCPToolNotFoundError: If no tool is available with that name.
|
|
177
205
|
"""
|
|
206
|
+
if not self._exit_stack:
|
|
207
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
208
|
+
|
|
178
209
|
if not self._tools:
|
|
179
210
|
self._tools = await self.get_tools()
|
|
180
211
|
|
|
@@ -185,27 +216,143 @@ class MCPBuilder(MCPSSEClient):
|
|
|
185
216
|
|
|
186
217
|
@mcp_exception_handler
|
|
187
218
|
async def call_tool(self, tool_name: str, tool_args: dict | None):
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
219
|
+
if not self._session:
|
|
220
|
+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
|
|
221
|
+
|
|
222
|
+
result = await self._session.call_tool(tool_name, tool_args)
|
|
223
|
+
return result
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class MCPSSEClient(MCPBaseClient):
|
|
227
|
+
"""
|
|
228
|
+
Client for creating a session and connecting to an MCP server using SSE
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
url (str): The url of the MCP server
|
|
232
|
+
"""
|
|
233
|
+
|
|
234
|
+
def __init__(self, url: str):
|
|
235
|
+
super().__init__("sse")
|
|
236
|
+
self._url = url
|
|
237
|
+
|
|
238
|
+
@property
|
|
239
|
+
def url(self) -> str:
|
|
240
|
+
return self._url
|
|
241
|
+
|
|
242
|
+
@property
|
|
243
|
+
def server_name(self):
|
|
244
|
+
return f"sse:{self._url}"
|
|
245
|
+
|
|
246
|
+
@asynccontextmanager
|
|
247
|
+
@override
|
|
248
|
+
async def connect_to_server(self):
|
|
249
|
+
"""
|
|
250
|
+
Establish a session with an MCP SSE server within an async context
|
|
251
|
+
"""
|
|
252
|
+
async with sse_client(url=self._url) as (read, write):
|
|
253
|
+
async with ClientSession(read, write) as session:
|
|
254
|
+
await session.initialize()
|
|
255
|
+
yield session
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class MCPStdioClient(MCPBaseClient):
|
|
259
|
+
"""
|
|
260
|
+
Client for creating a session and connecting to an MCP server using stdio
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
command (str): The command to run
|
|
264
|
+
args (list[str] | None): Additional arguments for the command
|
|
265
|
+
env (dict[str, str] | None): Environment variables to set for the process
|
|
266
|
+
"""
|
|
267
|
+
|
|
268
|
+
def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
|
|
269
|
+
super().__init__("stdio")
|
|
270
|
+
self._command = command
|
|
271
|
+
self._args = args
|
|
272
|
+
self._env = env
|
|
273
|
+
|
|
274
|
+
@property
|
|
275
|
+
def command(self) -> str:
|
|
276
|
+
return self._command
|
|
277
|
+
|
|
278
|
+
@property
|
|
279
|
+
def server_name(self):
|
|
280
|
+
return f"stdio:{self._command}"
|
|
281
|
+
|
|
282
|
+
@property
|
|
283
|
+
def args(self) -> list[str] | None:
|
|
284
|
+
return self._args
|
|
285
|
+
|
|
286
|
+
@property
|
|
287
|
+
def env(self) -> dict[str, str] | None:
|
|
288
|
+
return self._env
|
|
289
|
+
|
|
290
|
+
@asynccontextmanager
|
|
291
|
+
@override
|
|
292
|
+
async def connect_to_server(self):
|
|
293
|
+
"""
|
|
294
|
+
Establish a session with an MCP server via stdio within an async context
|
|
295
|
+
"""
|
|
296
|
+
|
|
297
|
+
server_params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
|
|
298
|
+
async with stdio_client(server_params) as (read, write):
|
|
299
|
+
async with ClientSession(read, write) as session:
|
|
300
|
+
await session.initialize()
|
|
301
|
+
yield session
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class MCPStreamableHTTPClient(MCPBaseClient):
|
|
305
|
+
"""
|
|
306
|
+
Client for creating a session and connecting to an MCP server using streamable-http
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
url (str): The url of the MCP server
|
|
310
|
+
"""
|
|
311
|
+
|
|
312
|
+
def __init__(self, url: str):
|
|
313
|
+
super().__init__("streamable-http")
|
|
314
|
+
|
|
315
|
+
self._url = url
|
|
316
|
+
|
|
317
|
+
@property
|
|
318
|
+
def url(self) -> str:
|
|
319
|
+
return self._url
|
|
191
320
|
|
|
321
|
+
@property
|
|
322
|
+
def server_name(self):
|
|
323
|
+
return f"streamable-http:{self._url}"
|
|
192
324
|
|
|
193
|
-
|
|
325
|
+
@asynccontextmanager
|
|
326
|
+
async def connect_to_server(self):
|
|
327
|
+
"""
|
|
328
|
+
Establish a session with an MCP server via streamable-http within an async context
|
|
329
|
+
"""
|
|
330
|
+
async with streamablehttp_client(url=self._url) as (read, write, get_session_id):
|
|
331
|
+
async with ClientSession(read, write) as session:
|
|
332
|
+
await session.initialize()
|
|
333
|
+
yield session
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class MCPToolClient:
|
|
194
337
|
"""
|
|
195
338
|
Client wrapper used to call an MCP tool.
|
|
196
339
|
|
|
197
340
|
Args:
|
|
198
|
-
|
|
341
|
+
connect_fn (callable): Function that returns an async context manager for connecting to the server
|
|
199
342
|
tool_name (str): The name of the tool to wrap
|
|
200
343
|
tool_description (str): The description of the tool provided by the MCP server.
|
|
201
344
|
tool_input_schema (dict): The input schema for the tool.
|
|
202
345
|
"""
|
|
203
346
|
|
|
204
|
-
def __init__(self,
|
|
205
|
-
|
|
347
|
+
def __init__(self,
|
|
348
|
+
session: ClientSession,
|
|
349
|
+
tool_name: str,
|
|
350
|
+
tool_description: str | None,
|
|
351
|
+
tool_input_schema: dict | None = None):
|
|
352
|
+
self._session = session
|
|
206
353
|
self._tool_name = tool_name
|
|
207
354
|
self._tool_description = tool_description
|
|
208
|
-
self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None
|
|
355
|
+
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
|
|
209
356
|
|
|
210
357
|
@property
|
|
211
358
|
def name(self):
|
|
@@ -234,7 +381,6 @@ class MCPToolClient(MCPSSEClient):
|
|
|
234
381
|
"""
|
|
235
382
|
self._tool_description = description
|
|
236
383
|
|
|
237
|
-
@mcp_exception_handler
|
|
238
384
|
async def acall(self, tool_args: dict) -> str:
|
|
239
385
|
"""
|
|
240
386
|
Call the MCP tool with the provided arguments.
|
|
@@ -242,14 +388,19 @@ class MCPToolClient(MCPSSEClient):
|
|
|
242
388
|
Args:
|
|
243
389
|
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
|
|
244
390
|
"""
|
|
245
|
-
|
|
246
|
-
result = await session.call_tool(self._tool_name, tool_args)
|
|
391
|
+
result = await self._session.call_tool(self._tool_name, tool_args)
|
|
247
392
|
|
|
248
393
|
output = []
|
|
394
|
+
|
|
249
395
|
for res in result.content:
|
|
250
396
|
if isinstance(res, TextContent):
|
|
251
397
|
output.append(res.text)
|
|
252
398
|
else:
|
|
253
399
|
# Log non-text content for now
|
|
254
400
|
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
|
|
255
|
-
|
|
401
|
+
result_str = "\n".join(output)
|
|
402
|
+
|
|
403
|
+
if result.isError:
|
|
404
|
+
raise RuntimeError(result_str)
|
|
405
|
+
|
|
406
|
+
return result_str
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
from pydantic import HttpUrl
|
|
22
|
+
from pydantic import model_validator
|
|
23
|
+
|
|
24
|
+
from nat.builder.builder import Builder
|
|
25
|
+
from nat.builder.function_info import FunctionInfo
|
|
26
|
+
from nat.cli.register_workflow import register_function
|
|
27
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
28
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
29
|
+
from nat.tool.mcp.mcp_client_base import MCPBaseClient
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
# All functions in this file are experimental
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ToolOverrideConfig(BaseModel):
|
|
37
|
+
"""
|
|
38
|
+
Configuration for overriding tool properties when exposing from MCP server.
|
|
39
|
+
"""
|
|
40
|
+
alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)")
|
|
41
|
+
description: str | None = Field(default=None, description="Override the tool description")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MCPServerConfig(BaseModel):
|
|
45
|
+
"""
|
|
46
|
+
Server connection details for MCP client.
|
|
47
|
+
Supports stdio, sse, and streamable-http transports.
|
|
48
|
+
streamable-http is the recommended default for HTTP-based connections.
|
|
49
|
+
"""
|
|
50
|
+
transport: Literal["stdio", "sse", "streamable-http"] = Field(
|
|
51
|
+
..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)")
|
|
52
|
+
url: HttpUrl | None = Field(default=None,
|
|
53
|
+
description="URL of the MCP server (for sse or streamable-http transport)")
|
|
54
|
+
command: str | None = Field(default=None,
|
|
55
|
+
description="Command to run for stdio transport (e.g. 'python' or 'docker')")
|
|
56
|
+
args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
|
|
57
|
+
env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
|
|
58
|
+
|
|
59
|
+
@model_validator(mode="after")
|
|
60
|
+
def validate_model(self):
|
|
61
|
+
"""Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
|
|
62
|
+
if self.transport == "stdio":
|
|
63
|
+
if self.url is not None:
|
|
64
|
+
raise ValueError("url should not be set when using stdio transport")
|
|
65
|
+
if not self.command:
|
|
66
|
+
raise ValueError("command is required when using stdio transport")
|
|
67
|
+
elif self.transport in ("sse", "streamable-http"):
|
|
68
|
+
if self.command is not None or self.args is not None or self.env is not None:
|
|
69
|
+
raise ValueError("command, args, and env should not be set when using sse or streamable-http transport")
|
|
70
|
+
if not self.url:
|
|
71
|
+
raise ValueError("url is required when using sse or streamable-http transport")
|
|
72
|
+
return self
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class MCPClientConfig(FunctionBaseConfig, name="mcp_client"):
|
|
76
|
+
"""
|
|
77
|
+
Configuration for connecting to an MCP server as a client and exposing selected tools.
|
|
78
|
+
"""
|
|
79
|
+
server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
|
|
80
|
+
tool_filter: dict[str, ToolOverrideConfig] | list[str] | None = Field(
|
|
81
|
+
default=None,
|
|
82
|
+
description="""Filter or map tools to expose from the server (list or dict).
|
|
83
|
+
Can be:
|
|
84
|
+
- A list of tool names to expose: ['tool1', 'tool2']
|
|
85
|
+
- A dict mapping tool names to override configs:
|
|
86
|
+
{'tool1': {'alias': 'new_name', 'description': 'New desc'}}
|
|
87
|
+
{'tool2': {'description': 'Override description only'}} # alias defaults to 'tool2'
|
|
88
|
+
""")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class MCPSingleToolConfig(FunctionBaseConfig, name="mcp_single_tool"):
|
|
92
|
+
"""
|
|
93
|
+
Configuration for wrapping a single tool from an MCP server as a NeMo Agent toolkit function.
|
|
94
|
+
"""
|
|
95
|
+
client: MCPBaseClient = Field(..., description="MCP client to use for the tool")
|
|
96
|
+
tool_name: str = Field(..., description="Name of the tool to use")
|
|
97
|
+
tool_description: str | None = Field(default=None, description="Description of the tool")
|
|
98
|
+
|
|
99
|
+
model_config = {"arbitrary_types_allowed": True}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _get_server_name_safe(client: MCPBaseClient) -> str:
|
|
103
|
+
|
|
104
|
+
# Avoid leaking env secrets from stdio client in logs.
|
|
105
|
+
if client.transport == "stdio":
|
|
106
|
+
safe_server = f"stdio: {client.command}"
|
|
107
|
+
else:
|
|
108
|
+
safe_server = f"{client.transport}: {client.url}"
|
|
109
|
+
|
|
110
|
+
return safe_server
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@register_function(config_type=MCPSingleToolConfig)
|
|
114
|
+
async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
|
|
115
|
+
"""
|
|
116
|
+
Wrap a single tool from an MCP server as a NeMo Agent toolkit function.
|
|
117
|
+
"""
|
|
118
|
+
tool = await config.client.get_tool(config.tool_name)
|
|
119
|
+
if config.tool_description:
|
|
120
|
+
tool.set_description(description=config.tool_description)
|
|
121
|
+
input_schema = tool.input_schema
|
|
122
|
+
|
|
123
|
+
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))
|
|
124
|
+
|
|
125
|
+
def _convert_from_str(input_str: str) -> BaseModel:
|
|
126
|
+
return input_schema.model_validate_json(input_str)
|
|
127
|
+
|
|
128
|
+
@experimental(feature_name="mcp_client")
|
|
129
|
+
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
|
|
130
|
+
try:
|
|
131
|
+
if tool_input:
|
|
132
|
+
return await tool.acall(tool_input.model_dump())
|
|
133
|
+
_ = input_schema.model_validate(kwargs)
|
|
134
|
+
return await tool.acall(kwargs)
|
|
135
|
+
except Exception as e:
|
|
136
|
+
return str(e)
|
|
137
|
+
|
|
138
|
+
fn = FunctionInfo.create(single_fn=_response_fn,
|
|
139
|
+
description=tool.description,
|
|
140
|
+
input_schema=input_schema,
|
|
141
|
+
converters=[_convert_from_str])
|
|
142
|
+
yield fn
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@register_function(MCPClientConfig)
|
|
146
|
+
async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
|
|
147
|
+
"""
|
|
148
|
+
Connect to an MCP server, discover tools, and register them as functions in the workflow.
|
|
149
|
+
|
|
150
|
+
Note:
|
|
151
|
+
- Uses builder's exit stack to manage client lifecycle
|
|
152
|
+
- Applies tool filters if provided
|
|
153
|
+
"""
|
|
154
|
+
from nat.tool.mcp.mcp_client_base import MCPSSEClient
|
|
155
|
+
from nat.tool.mcp.mcp_client_base import MCPStdioClient
|
|
156
|
+
from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
|
|
157
|
+
|
|
158
|
+
# Build the appropriate client
|
|
159
|
+
client_cls = {
|
|
160
|
+
"stdio": lambda: MCPStdioClient(config.server.command, config.server.args, config.server.env),
|
|
161
|
+
"sse": lambda: MCPSSEClient(str(config.server.url)),
|
|
162
|
+
"streamable-http": lambda: MCPStreamableHTTPClient(str(config.server.url)),
|
|
163
|
+
}.get(config.server.transport)
|
|
164
|
+
|
|
165
|
+
if not client_cls:
|
|
166
|
+
raise ValueError(f"Unsupported transport: {config.server.transport}")
|
|
167
|
+
|
|
168
|
+
client = client_cls()
|
|
169
|
+
logger.info("Configured to use MCP server at %s", _get_server_name_safe(client))
|
|
170
|
+
|
|
171
|
+
# client aenter connects to the server and stores the client in the exit stack
|
|
172
|
+
# so it's cleaned up when the workflow is done
|
|
173
|
+
async with client:
|
|
174
|
+
all_tools = await client.get_tools()
|
|
175
|
+
tool_configs = _filter_and_configure_tools(all_tools, config.tool_filter)
|
|
176
|
+
|
|
177
|
+
for tool_name, tool_cfg in tool_configs.items():
|
|
178
|
+
await builder.add_function(
|
|
179
|
+
tool_cfg["function_name"],
|
|
180
|
+
MCPSingleToolConfig(
|
|
181
|
+
client=client,
|
|
182
|
+
tool_name=tool_name,
|
|
183
|
+
tool_description=tool_cfg["description"],
|
|
184
|
+
))
|
|
185
|
+
|
|
186
|
+
@experimental(feature_name="mcp_client")
|
|
187
|
+
async def idle_fn(text: str) -> str:
|
|
188
|
+
# This function is a placeholder and will be removed when function groups are used
|
|
189
|
+
return f"MCP client connected: {text}"
|
|
190
|
+
|
|
191
|
+
yield FunctionInfo.create(single_fn=idle_fn, description="MCP client")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _filter_and_configure_tools(all_tools: dict, tool_filter) -> dict[str, dict]:
|
|
195
|
+
"""
|
|
196
|
+
Apply tool filtering and optional aliasing/description overrides.
|
|
197
|
+
|
|
198
|
+
Returns:
|
|
199
|
+
Dict[str, dict] where each value has:
|
|
200
|
+
- function_name
|
|
201
|
+
- description
|
|
202
|
+
"""
|
|
203
|
+
if tool_filter is None:
|
|
204
|
+
return {name: {"function_name": name, "description": tool.description} for name, tool in all_tools.items()}
|
|
205
|
+
|
|
206
|
+
if isinstance(tool_filter, list):
|
|
207
|
+
return {
|
|
208
|
+
name: {
|
|
209
|
+
"function_name": name, "description": all_tools[name].description
|
|
210
|
+
}
|
|
211
|
+
for name in tool_filter if name in all_tools
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
if isinstance(tool_filter, dict):
|
|
215
|
+
result = {}
|
|
216
|
+
for name, override in tool_filter.items():
|
|
217
|
+
tool = all_tools.get(name)
|
|
218
|
+
if not tool:
|
|
219
|
+
logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
|
|
220
|
+
continue
|
|
221
|
+
|
|
222
|
+
if isinstance(override, ToolOverrideConfig):
|
|
223
|
+
result[name] = {
|
|
224
|
+
"function_name": override.alias or name, "description": override.description or tool.description
|
|
225
|
+
}
|
|
226
|
+
else:
|
|
227
|
+
logger.warning("Unsupported override type for '%s': %s", name, type(override))
|
|
228
|
+
result[name] = {"function_name": name, "description": tool.description}
|
|
229
|
+
return result
|