nvidia-nat 1.3.0a20250828__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. nat/agent/base.py +6 -1
  2. nat/agent/react_agent/agent.py +46 -38
  3. nat/agent/react_agent/register.py +7 -2
  4. nat/agent/rewoo_agent/agent.py +16 -30
  5. nat/agent/rewoo_agent/register.py +3 -3
  6. nat/agent/tool_calling_agent/agent.py +9 -19
  7. nat/agent/tool_calling_agent/register.py +2 -2
  8. nat/builder/eval_builder.py +2 -2
  9. nat/builder/function.py +8 -8
  10. nat/builder/workflow.py +6 -2
  11. nat/builder/workflow_builder.py +21 -24
  12. nat/cli/cli_utils/config_override.py +1 -1
  13. nat/cli/commands/info/list_channels.py +1 -1
  14. nat/cli/commands/info/list_mcp.py +183 -47
  15. nat/cli/commands/registry/publish.py +2 -2
  16. nat/cli/commands/registry/pull.py +2 -2
  17. nat/cli/commands/registry/remove.py +2 -2
  18. nat/cli/commands/registry/search.py +1 -1
  19. nat/cli/commands/start.py +15 -3
  20. nat/cli/commands/uninstall.py +1 -1
  21. nat/cli/commands/workflow/workflow_commands.py +4 -4
  22. nat/data_models/discovery_metadata.py +4 -4
  23. nat/data_models/thinking_mixin.py +27 -8
  24. nat/eval/evaluate.py +6 -6
  25. nat/eval/intermediate_step_adapter.py +1 -1
  26. nat/eval/rag_evaluator/evaluate.py +2 -2
  27. nat/eval/rag_evaluator/register.py +1 -1
  28. nat/eval/remote_workflow.py +3 -3
  29. nat/eval/swe_bench_evaluator/evaluate.py +5 -5
  30. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  31. nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
  32. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
  33. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  34. nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
  35. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  36. nat/front_ends/fastapi/message_handler.py +2 -2
  37. nat/front_ends/fastapi/message_validator.py +8 -10
  38. nat/front_ends/fastapi/response_helpers.py +4 -4
  39. nat/front_ends/fastapi/step_adaptor.py +1 -1
  40. nat/front_ends/mcp/mcp_front_end_config.py +5 -0
  41. nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
  42. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
  43. nat/front_ends/mcp/tool_converter.py +40 -13
  44. nat/observability/exporter/base_exporter.py +1 -1
  45. nat/observability/exporter/processing_exporter.py +8 -9
  46. nat/observability/exporter_manager.py +5 -5
  47. nat/observability/mixin/file_mixin.py +7 -7
  48. nat/observability/processor/batching_processor.py +4 -6
  49. nat/observability/register.py +3 -1
  50. nat/profiler/calc/calc_runner.py +3 -4
  51. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  52. nat/profiler/callbacks/langchain_callback_handler.py +5 -5
  53. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  54. nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
  55. nat/profiler/profile_runner.py +1 -1
  56. nat/profiler/utils.py +1 -1
  57. nat/registry_handlers/local/local_handler.py +2 -2
  58. nat/registry_handlers/package_utils.py +1 -1
  59. nat/registry_handlers/pypi/pypi_handler.py +3 -3
  60. nat/registry_handlers/rest/rest_handler.py +4 -4
  61. nat/retriever/milvus/retriever.py +1 -1
  62. nat/retriever/nemo_retriever/retriever.py +1 -1
  63. nat/runtime/loader.py +1 -1
  64. nat/runtime/runner.py +2 -2
  65. nat/settings/global_settings.py +1 -1
  66. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  67. nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
  68. nat/tool/mcp/mcp_client_impl.py +229 -0
  69. nat/tool/mcp/mcp_tool.py +79 -42
  70. nat/tool/nvidia_rag.py +1 -1
  71. nat/tool/register.py +1 -0
  72. nat/tool/retriever.py +3 -2
  73. nat/utils/io/yaml_tools.py +1 -1
  74. nat/utils/reactive/observer.py +2 -2
  75. nat/utils/settings/global_settings.py +2 -2
  76. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
  77. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +82 -81
  78. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
  79. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
  80. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  81. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
  82. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
nat/tool/mcp/{mcp_client.py → mcp_client_base.py}
@@ -16,12 +16,18 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import logging
19
+ from abc import ABC
20
+ from abc import abstractmethod
21
+ from contextlib import AsyncExitStack
19
22
  from contextlib import asynccontextmanager
20
23
  from enum import Enum
21
24
  from typing import Any
22
25
 
23
26
  from mcp import ClientSession
24
27
  from mcp.client.sse import sse_client
28
+ from mcp.client.stdio import StdioServerParameters
29
+ from mcp.client.stdio import stdio_client
30
+ from mcp.client.streamable_http import streamablehttp_client
25
31
  from mcp.types import TextContent
26
32
  from pydantic import BaseModel
27
33
  from pydantic import Field
@@ -29,6 +35,7 @@ from pydantic import create_model
29
35
 
30
36
  from nat.tool.mcp.exceptions import MCPToolNotFoundError
31
37
  from nat.utils.exception_handlers.mcp import mcp_exception_handler
38
+ from nat.utils.type_utils import override
32
39
 
33
40
  logger = logging.getLogger(__name__)
34
41
 
@@ -107,56 +114,78 @@ def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
107
114
  return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
108
115
 
109
116
 
110
- class MCPSSEClient:
117
+ class MCPBaseClient(ABC):
111
118
  """
112
- Client for creating a session and connecting to an MCP server using SSE
119
+ Base client for creating a session and connecting to an MCP server
113
120
 
114
121
  Args:
115
- url (str): The url of the MCP server
122
+ transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')
116
123
  """
117
124
 
118
- def __init__(self, url: str):
119
- self.url = url
125
+ def __init__(self, transport: str = 'streamable-http'):
126
+ self._tools = None
127
+ self._transport = transport.lower()
128
+ if self._transport not in ['sse', 'stdio', 'streamable-http']:
129
+ raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")
120
130
 
121
- @asynccontextmanager
122
- async def connect_to_sse_server(self):
123
- """
124
- Establish a session with an MCP SSE server within an aync context
125
- """
126
- async with sse_client(url=self.url) as (read, write):
127
- async with ClientSession(read, write) as session:
128
- await session.initialize()
129
- yield session
131
+ self._exit_stack: AsyncExitStack | None = None
130
132
 
133
+ self._session: ClientSession | None = None
131
134
 
132
- class MCPBuilder(MCPSSEClient):
133
- """
134
- Builder class used to connect to an MCP Server and generate ToolClients
135
+ @property
136
+ def transport(self) -> str:
137
+ return self._transport
135
138
 
136
- Args:
137
- url (str): The url of the MCP server
138
- """
139
+ async def __aenter__(self):
140
+ if self._exit_stack:
141
+ raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")
139
142
 
140
- def __init__(self, url):
141
- super().__init__(url)
142
- self._tools = None
143
+ self._exit_stack = AsyncExitStack()
144
+
145
+ self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
146
+
147
+ return self
148
+
149
+ async def __aexit__(self, exc_type, exc_value, traceback):
150
+
151
+ if not self._exit_stack:
152
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
153
+
154
+ await self._exit_stack.aclose()
155
+ self._session = None
156
+ self._exit_stack = None
157
+
158
+ @property
159
+ def server_name(self):
160
+ """
161
+ Provide server name for logging
162
+ """
163
+ return self._transport
164
+
165
+ @abstractmethod
166
+ @asynccontextmanager
167
+ async def connect_to_server(self):
168
+ """
169
+ Establish a session with an MCP server within an async context
170
+ """
171
+ pass
143
172
 
144
- @mcp_exception_handler
145
173
  async def get_tools(self):
146
174
  """
147
175
  Retrieve a dictionary of all tools served by the MCP server.
176
+ """
148
177
 
149
- Returns:
150
- Dict of tool name to MCPToolClient
178
+ if not self._session:
179
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
151
180
 
152
- Raises:
153
- MCPError: If connection or tool retrieval fails
154
- """
155
- async with self.connect_to_sse_server() as session:
156
- response = await session.list_tools()
181
+ response = await self._session.list_tools()
157
182
 
158
183
  return {
159
- tool.name: MCPToolClient(self.url, tool.name, tool.description, tool_input_schema=tool.inputSchema)
184
+ tool.name:
185
+ MCPToolClient(session=self._session,
186
+ tool_name=tool.name,
187
+ tool_description=tool.description,
188
+ tool_input_schema=tool.inputSchema)
160
189
  for tool in response.tools
161
190
  }
162
191
 
@@ -172,9 +201,11 @@ class MCPBuilder(MCPSSEClient):
172
201
  MCPToolClient for the configured tool.
173
202
 
174
203
  Raises:
175
- MCPToolNotFoundError: If no tool is available with that name
176
- MCPError: If connection fails
204
+ MCPToolNotFoundError: If no tool is available with that name.
177
205
  """
206
+ if not self._exit_stack:
207
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
208
+
178
209
  if not self._tools:
179
210
  self._tools = await self.get_tools()
180
211
 
@@ -185,27 +216,143 @@ class MCPBuilder(MCPSSEClient):
185
216
 
186
217
  @mcp_exception_handler
187
218
  async def call_tool(self, tool_name: str, tool_args: dict | None):
188
- async with self.connect_to_sse_server() as session:
189
- result = await session.call_tool(tool_name, tool_args)
190
- return result
219
+ if not self._session:
220
+ raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
221
+
222
+ result = await self._session.call_tool(tool_name, tool_args)
223
+ return result
224
+
225
+
226
+ class MCPSSEClient(MCPBaseClient):
227
+ """
228
+ Client for creating a session and connecting to an MCP server using SSE
229
+
230
+ Args:
231
+ url (str): The url of the MCP server
232
+ """
233
+
234
+ def __init__(self, url: str):
235
+ super().__init__("sse")
236
+ self._url = url
237
+
238
+ @property
239
+ def url(self) -> str:
240
+ return self._url
241
+
242
+ @property
243
+ def server_name(self):
244
+ return f"sse:{self._url}"
245
+
246
+ @asynccontextmanager
247
+ @override
248
+ async def connect_to_server(self):
249
+ """
250
+ Establish a session with an MCP SSE server within an async context
251
+ """
252
+ async with sse_client(url=self._url) as (read, write):
253
+ async with ClientSession(read, write) as session:
254
+ await session.initialize()
255
+ yield session
256
+
257
+
258
+ class MCPStdioClient(MCPBaseClient):
259
+ """
260
+ Client for creating a session and connecting to an MCP server using stdio
261
+
262
+ Args:
263
+ command (str): The command to run
264
+ args (list[str] | None): Additional arguments for the command
265
+ env (dict[str, str] | None): Environment variables to set for the process
266
+ """
267
+
268
+ def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
269
+ super().__init__("stdio")
270
+ self._command = command
271
+ self._args = args
272
+ self._env = env
273
+
274
+ @property
275
+ def command(self) -> str:
276
+ return self._command
277
+
278
+ @property
279
+ def server_name(self):
280
+ return f"stdio:{self._command}"
281
+
282
+ @property
283
+ def args(self) -> list[str] | None:
284
+ return self._args
285
+
286
+ @property
287
+ def env(self) -> dict[str, str] | None:
288
+ return self._env
289
+
290
+ @asynccontextmanager
291
+ @override
292
+ async def connect_to_server(self):
293
+ """
294
+ Establish a session with an MCP server via stdio within an async context
295
+ """
296
+
297
+ server_params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
298
+ async with stdio_client(server_params) as (read, write):
299
+ async with ClientSession(read, write) as session:
300
+ await session.initialize()
301
+ yield session
302
+
303
+
304
+ class MCPStreamableHTTPClient(MCPBaseClient):
305
+ """
306
+ Client for creating a session and connecting to an MCP server using streamable-http
307
+
308
+ Args:
309
+ url (str): The url of the MCP server
310
+ """
311
+
312
+ def __init__(self, url: str):
313
+ super().__init__("streamable-http")
314
+
315
+ self._url = url
316
+
317
+ @property
318
+ def url(self) -> str:
319
+ return self._url
191
320
 
321
+ @property
322
+ def server_name(self):
323
+ return f"streamable-http:{self._url}"
192
324
 
193
- class MCPToolClient(MCPSSEClient):
325
+ @asynccontextmanager
326
+ async def connect_to_server(self):
327
+ """
328
+ Establish a session with an MCP server via streamable-http within an async context
329
+ """
330
+ async with streamablehttp_client(url=self._url) as (read, write, get_session_id):
331
+ async with ClientSession(read, write) as session:
332
+ await session.initialize()
333
+ yield session
334
+
335
+
336
+ class MCPToolClient:
194
337
  """
195
338
  Client wrapper used to call an MCP tool.
196
339
 
197
340
  Args:
198
- url (str): The url of the MCP server
341
+ connect_fn (callable): Function that returns an async context manager for connecting to the server
199
342
  tool_name (str): The name of the tool to wrap
200
343
  tool_description (str): The description of the tool provided by the MCP server.
201
344
  tool_input_schema (dict): The input schema for the tool.
202
345
  """
203
346
 
204
- def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
205
- super().__init__(url)
347
+ def __init__(self,
348
+ session: ClientSession,
349
+ tool_name: str,
350
+ tool_description: str | None,
351
+ tool_input_schema: dict | None = None):
352
+ self._session = session
206
353
  self._tool_name = tool_name
207
354
  self._tool_description = tool_description
208
- self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None
355
+ self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
209
356
 
210
357
  @property
211
358
  def name(self):
@@ -234,7 +381,6 @@ class MCPToolClient(MCPSSEClient):
234
381
  """
235
382
  self._tool_description = description
236
383
 
237
- @mcp_exception_handler
238
384
  async def acall(self, tool_args: dict) -> str:
239
385
  """
240
386
  Call the MCP tool with the provided arguments.
@@ -242,14 +388,19 @@ class MCPToolClient(MCPSSEClient):
242
388
  Args:
243
389
  tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
244
390
  """
245
- async with self.connect_to_sse_server() as session:
246
- result = await session.call_tool(self._tool_name, tool_args)
391
+ result = await self._session.call_tool(self._tool_name, tool_args)
247
392
 
248
393
  output = []
394
+
249
395
  for res in result.content:
250
396
  if isinstance(res, TextContent):
251
397
  output.append(res.text)
252
398
  else:
253
399
  # Log non-text content for now
254
400
  logger.warning("Got not-text output from %s of type %s", self.name, type(res))
255
- return "\n".join(output)
401
+ result_str = "\n".join(output)
402
+
403
+ if result.isError:
404
+ raise RuntimeError(result_str)
405
+
406
+ return result_str
nat/tool/mcp/mcp_client_impl.py
@@ -0,0 +1,229 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from typing import Literal
18
+
19
+ from pydantic import BaseModel
20
+ from pydantic import Field
21
+ from pydantic import HttpUrl
22
+ from pydantic import model_validator
23
+
24
+ from nat.builder.builder import Builder
25
+ from nat.builder.function_info import FunctionInfo
26
+ from nat.cli.register_workflow import register_function
27
+ from nat.data_models.function import FunctionBaseConfig
28
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
29
+ from nat.tool.mcp.mcp_client_base import MCPBaseClient
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # All functions in this file are experimental
34
+
35
+
36
+ class ToolOverrideConfig(BaseModel):
37
+ """
38
+ Configuration for overriding tool properties when exposing from MCP server.
39
+ """
40
+ alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)")
41
+ description: str | None = Field(default=None, description="Override the tool description")
42
+
43
+
44
+ class MCPServerConfig(BaseModel):
45
+ """
46
+ Server connection details for MCP client.
47
+ Supports stdio, sse, and streamable-http transports.
48
+ streamable-http is the recommended default for HTTP-based connections.
49
+ """
50
+ transport: Literal["stdio", "sse", "streamable-http"] = Field(
51
+ ..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)")
52
+ url: HttpUrl | None = Field(default=None,
53
+ description="URL of the MCP server (for sse or streamable-http transport)")
54
+ command: str | None = Field(default=None,
55
+ description="Command to run for stdio transport (e.g. 'python' or 'docker')")
56
+ args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
57
+ env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
58
+
59
+ @model_validator(mode="after")
60
+ def validate_model(self):
61
+ """Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
62
+ if self.transport == "stdio":
63
+ if self.url is not None:
64
+ raise ValueError("url should not be set when using stdio transport")
65
+ if not self.command:
66
+ raise ValueError("command is required when using stdio transport")
67
+ elif self.transport in ("sse", "streamable-http"):
68
+ if self.command is not None or self.args is not None or self.env is not None:
69
+ raise ValueError("command, args, and env should not be set when using sse or streamable-http transport")
70
+ if not self.url:
71
+ raise ValueError("url is required when using sse or streamable-http transport")
72
+ return self
73
+
74
+
75
+ class MCPClientConfig(FunctionBaseConfig, name="mcp_client"):
76
+ """
77
+ Configuration for connecting to an MCP server as a client and exposing selected tools.
78
+ """
79
+ server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
80
+ tool_filter: dict[str, ToolOverrideConfig] | list[str] | None = Field(
81
+ default=None,
82
+ description="""Filter or map tools to expose from the server (list or dict).
83
+ Can be:
84
+ - A list of tool names to expose: ['tool1', 'tool2']
85
+ - A dict mapping tool names to override configs:
86
+ {'tool1': {'alias': 'new_name', 'description': 'New desc'}}
87
+ {'tool2': {'description': 'Override description only'}} # alias defaults to 'tool2'
88
+ """)
89
+
90
+
91
+ class MCPSingleToolConfig(FunctionBaseConfig, name="mcp_single_tool"):
92
+ """
93
+ Configuration for wrapping a single tool from an MCP server as a NeMo Agent toolkit function.
94
+ """
95
+ client: MCPBaseClient = Field(..., description="MCP client to use for the tool")
96
+ tool_name: str = Field(..., description="Name of the tool to use")
97
+ tool_description: str | None = Field(default=None, description="Description of the tool")
98
+
99
+ model_config = {"arbitrary_types_allowed": True}
100
+
101
+
102
+ def _get_server_name_safe(client: MCPBaseClient) -> str:
103
+
104
+ # Avoid leaking env secrets from stdio client in logs.
105
+ if client.transport == "stdio":
106
+ safe_server = f"stdio: {client.command}"
107
+ else:
108
+ safe_server = f"{client.transport}: {client.url}"
109
+
110
+ return safe_server
111
+
112
+
113
+ @register_function(config_type=MCPSingleToolConfig)
114
+ async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
115
+ """
116
+ Wrap a single tool from an MCP server as a NeMo Agent toolkit function.
117
+ """
118
+ tool = await config.client.get_tool(config.tool_name)
119
+ if config.tool_description:
120
+ tool.set_description(description=config.tool_description)
121
+ input_schema = tool.input_schema
122
+
123
+ logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))
124
+
125
+ def _convert_from_str(input_str: str) -> BaseModel:
126
+ return input_schema.model_validate_json(input_str)
127
+
128
+ @experimental(feature_name="mcp_client")
129
+ async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
130
+ try:
131
+ if tool_input:
132
+ return await tool.acall(tool_input.model_dump())
133
+ _ = input_schema.model_validate(kwargs)
134
+ return await tool.acall(kwargs)
135
+ except Exception as e:
136
+ return str(e)
137
+
138
+ fn = FunctionInfo.create(single_fn=_response_fn,
139
+ description=tool.description,
140
+ input_schema=input_schema,
141
+ converters=[_convert_from_str])
142
+ yield fn
143
+
144
+
145
+ @register_function(MCPClientConfig)
146
+ async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
147
+ """
148
+ Connect to an MCP server, discover tools, and register them as functions in the workflow.
149
+
150
+ Note:
151
+ - Uses builder's exit stack to manage client lifecycle
152
+ - Applies tool filters if provided
153
+ """
154
+ from nat.tool.mcp.mcp_client_base import MCPSSEClient
155
+ from nat.tool.mcp.mcp_client_base import MCPStdioClient
156
+ from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
157
+
158
+ # Build the appropriate client
159
+ client_cls = {
160
+ "stdio": lambda: MCPStdioClient(config.server.command, config.server.args, config.server.env),
161
+ "sse": lambda: MCPSSEClient(str(config.server.url)),
162
+ "streamable-http": lambda: MCPStreamableHTTPClient(str(config.server.url)),
163
+ }.get(config.server.transport)
164
+
165
+ if not client_cls:
166
+ raise ValueError(f"Unsupported transport: {config.server.transport}")
167
+
168
+ client = client_cls()
169
+ logger.info("Configured to use MCP server at %s", _get_server_name_safe(client))
170
+
171
+ # client aenter connects to the server and stores the client in the exit stack
172
+ # so it's cleaned up when the workflow is done
173
+ async with client:
174
+ all_tools = await client.get_tools()
175
+ tool_configs = _filter_and_configure_tools(all_tools, config.tool_filter)
176
+
177
+ for tool_name, tool_cfg in tool_configs.items():
178
+ await builder.add_function(
179
+ tool_cfg["function_name"],
180
+ MCPSingleToolConfig(
181
+ client=client,
182
+ tool_name=tool_name,
183
+ tool_description=tool_cfg["description"],
184
+ ))
185
+
186
+ @experimental(feature_name="mcp_client")
187
+ async def idle_fn(text: str) -> str:
188
+ # This function is a placeholder and will be removed when function groups are used
189
+ return f"MCP client connected: {text}"
190
+
191
+ yield FunctionInfo.create(single_fn=idle_fn, description="MCP client")
192
+
193
+
194
+ def _filter_and_configure_tools(all_tools: dict, tool_filter) -> dict[str, dict]:
195
+ """
196
+ Apply tool filtering and optional aliasing/description overrides.
197
+
198
+ Returns:
199
+ Dict[str, dict] where each value has:
200
+ - function_name
201
+ - description
202
+ """
203
+ if tool_filter is None:
204
+ return {name: {"function_name": name, "description": tool.description} for name, tool in all_tools.items()}
205
+
206
+ if isinstance(tool_filter, list):
207
+ return {
208
+ name: {
209
+ "function_name": name, "description": all_tools[name].description
210
+ }
211
+ for name in tool_filter if name in all_tools
212
+ }
213
+
214
+ if isinstance(tool_filter, dict):
215
+ result = {}
216
+ for name, override in tool_filter.items():
217
+ tool = all_tools.get(name)
218
+ if not tool:
219
+ logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
220
+ continue
221
+
222
+ if isinstance(override, ToolOverrideConfig):
223
+ result[name] = {
224
+ "function_name": override.alias or name, "description": override.description or tool.description
225
+ }
226
+ else:
227
+ logger.warning("Unsupported override type for '%s': %s", name, type(override))
228
+ result[name] = {"function_name": name, "description": tool.description}
229
+ return result