nvidia-nat 1.3.0a20250906__py3-none-any.whl → 1.3.0a20250909__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,406 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import logging
19
- from abc import ABC
20
- from abc import abstractmethod
21
- from contextlib import AsyncExitStack
22
- from contextlib import asynccontextmanager
23
- from enum import Enum
24
- from typing import Any
25
-
26
- from mcp import ClientSession
27
- from mcp.client.sse import sse_client
28
- from mcp.client.stdio import StdioServerParameters
29
- from mcp.client.stdio import stdio_client
30
- from mcp.client.streamable_http import streamablehttp_client
31
- from mcp.types import TextContent
32
- from pydantic import BaseModel
33
- from pydantic import Field
34
- from pydantic import create_model
35
-
36
- from nat.tool.mcp.exceptions import MCPToolNotFoundError
37
- from nat.utils.exception_handlers.mcp import mcp_exception_handler
38
- from nat.utils.type_utils import override
39
-
40
# Module-level logger; MCPToolClient.acall uses it to warn about non-text tool output.
logger = logging.getLogger(__name__)
41
-
42
-
43
def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
    """
    Create a pydantic model from the (JSON-schema style) input schema of an MCP tool.

    Args:
        name: Tool (or nested field) name. It is converted to PascalCase and suffixed with
            ``InputSchema`` to form the generated model's class name.
        mcp_input_schema: Schema dict. At this level only ``properties`` and ``required`` are
            read. Each property may use ``type`` (a string or a list of strings), ``enum``,
            ``properties`` (nested object), ``items`` (array element schema), ``default``,
            ``nullable`` and ``description``.

    Returns:
        A dynamically created pydantic model class with one field per property.
    """
    # JSON-schema type name -> Python type. "null" maps to NoneType (not the value None) so it
    # can take part in ``|`` unions; ``None | None`` would raise TypeError.
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        "null": type(None),
        "object": dict,
    }

    properties = mcp_input_schema.get("properties", {})
    required_fields = set(mcp_input_schema.get("required", []))
    schema_dict = {}

    def _generate_valid_classname(class_name: str):
        # e.g. "my_tool-name" -> "MyToolName"
        return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')

    def _union_of(json_types: list):
        # Union of every listed JSON type, keeping "null" wherever it appears in the list.
        # An empty list yields None (used as the annotation, as before).
        union_type = None
        for json_type_name in json_types:
            mapped = _type_map.get(json_type_name, Any)
            union_type = mapped if union_type is None else union_type | mapped
        return union_type

    def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
        json_type = field_properties.get("type", "string")
        enum_vals = field_properties.get("enum")

        if enum_vals:
            enum_name = f"{field_name.capitalize()}Enum"
            # Enum member names must be strings; non-string values (e.g. integers) are named
            # by their str() while keeping the original value.
            field_type = Enum(enum_name, {str(item): item for item in enum_vals})
        elif isinstance(json_type, list):
            field_type = _union_of(json_type)
        elif json_type == "object" and "properties" in field_properties:
            field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
        elif json_type == "array" and "items" in field_properties:
            item_properties = field_properties.get("items", {})
            item_json_type = item_properties.get("type", "string")
            if item_json_type == "object":
                item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
            elif isinstance(item_json_type, list):
                item_type = _union_of(item_json_type)
            else:
                item_type = _type_map.get(item_json_type, Any)
            field_type = list[item_type]
        else:
            field_type = _type_map.get(json_type, Any)

        # A multi-type field that lists "null" accepts null and, without an explicit default,
        # defaults to None even when it is listed as required.
        allows_null = isinstance(json_type, list) and "null" in json_type
        has_default = "default" in field_properties

        if has_default:
            default_value = field_properties["default"]
        elif field_name in required_fields and not allows_null:
            # Required with no default -> pydantic required field.
            default_value = ...
        else:
            # Optional (or null-accepting) with no default -> None, so the annotation must admit None.
            default_value = None
            if field_type is not None:
                field_type = field_type | None

        if field_properties.get("nullable", False) and field_type is not None:
            field_type = field_type | None

        return field_type, Field(default=default_value, description=field_properties.get("description", ""))

    for field_name, field_props in properties.items():
        schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
    return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
115
-
116
-
117
class MCPBaseClient(ABC):
    """
    Base client for creating a session and connecting to an MCP server.

    Use as an async context manager. Entering calls ``connect_to_server()`` (implemented by
    subclasses) and stores the resulting session; exiting closes it again.

    Args:
        transport (str): The type of client to use ('sse', 'stdio', or 'streamable-http')

    Raises:
        ValueError: If ``transport`` (compared case-insensitively) is not a supported value.
    """

    def __init__(self, transport: str = 'streamable-http'):
        # Lazily filled by get_tool(); the wrappers are bound to the current session, so the
        # cache is reset when the client exits.
        self._tools: dict[str, MCPToolClient] | None = None
        self._transport = transport.lower()
        if self._transport not in ['sse', 'stdio', 'streamable-http']:
            raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'")

        self._exit_stack: AsyncExitStack | None = None

        self._session: ClientSession | None = None

    @property
    def transport(self) -> str:
        return self._transport

    async def __aenter__(self):
        if self._exit_stack:
            raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")

        exit_stack = AsyncExitStack()
        self._session = await exit_stack.enter_async_context(self.connect_to_server())
        # Only mark the client as initialized once connecting succeeded. A failed connect then
        # leaves it uninitialized and ``async with`` can be retried.
        self._exit_stack = exit_stack

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):

        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        try:
            await self._exit_stack.aclose()
        finally:
            # Reset even if closing fails, so the client can be re-entered. Cached tools are
            # bound to the session that was just closed, so drop them too.
            self._session = None
            self._exit_stack = None
            self._tools = None

    @property
    def server_name(self):
        """
        Provide server name for logging
        """
        return self._transport

    @abstractmethod
    @asynccontextmanager
    async def connect_to_server(self):
        """
        Establish a session with an MCP server within an async context
        """
        pass

    async def get_tools(self):
        """
        Retrieve a dictionary of all tools served by the MCP server.

        Returns:
            dict mapping tool name -> MCPToolClient bound to the current session.

        Raises:
            RuntimeError: If called outside ``async with``.
        """

        if not self._session:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        response = await self._session.list_tools()

        return {
            tool.name:
            MCPToolClient(session=self._session,
                          tool_name=tool.name,
                          tool_description=tool.description,
                          tool_input_schema=tool.inputSchema)
            for tool in response.tools
        }

    @mcp_exception_handler
    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            MCPToolNotFoundError: If no tool is available with that name.
        """
        if not self._exit_stack:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        if self._tools is None:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if tool is None:
            # Only URL-based subclasses define ``url``; stdio clients fall back to server_name.
            raise MCPToolNotFoundError(tool_name, getattr(self, "url", self.server_name))
        return tool

    @mcp_exception_handler
    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """Invoke ``tool_name`` on the open session and return the raw MCP result."""
        if not self._session:
            raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")

        result = await self._session.call_tool(tool_name, tool_args)
        return result
224
-
225
-
226
class MCPSSEClient(MCPBaseClient):
    """
    MCP client that connects to a server over Server-Sent Events (SSE).

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        super().__init__("sse")
        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        # Identifier for log messages: transport prefix followed by the endpoint URL.
        return f"sse:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open the SSE transport to ``self.url``, start an MCP session on it, and yield the
        session once it has been initialized.
        """
        async with (
            sse_client(url=self._url) as (read_stream, write_stream),
            ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()
            yield session
256
-
257
-
258
class MCPStdioClient(MCPBaseClient):
    """
    MCP client that talks to a server process over stdio.

    Args:
        command (str): The command to run
        args (list[str] | None): Additional arguments for the command
        env (dict[str, str] | None): Environment variables to set for the process
    """

    def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
        super().__init__("stdio")
        self._command, self._args, self._env = command, args, env

    @property
    def command(self) -> str:
        return self._command

    @property
    def args(self) -> list[str] | None:
        return self._args

    @property
    def env(self) -> dict[str, str] | None:
        return self._env

    @property
    def server_name(self):
        # Only the command is used as the log identifier; args and env are left out.
        return f"stdio:{self._command}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Open a stdio transport for the configured command/args/env and yield an initialized
        MCP session running over it.
        """
        params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
        async with (
            stdio_client(params) as (read_stream, write_stream),
            ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()
            yield session
302
-
303
-
304
class MCPStreamableHTTPClient(MCPBaseClient):
    """
    Client for creating a session and connecting to an MCP server using streamable-http

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        super().__init__("streamable-http")

        self._url = url

    @property
    def url(self) -> str:
        return self._url

    @property
    def server_name(self):
        """Identifier used in log messages: ``streamable-http:<url>``."""
        return f"streamable-http:{self._url}"

    @asynccontextmanager
    @override
    async def connect_to_server(self):
        """
        Establish a session with an MCP server via streamable-http within an async context
        """
        # The transport also yields a third value (a session-id getter) that this client does not use.
        async with streamablehttp_client(url=self._url) as (read, write, _get_session_id):
            async with ClientSession(read, write) as session:
                await session.initialize()
                yield session
334
-
335
-
336
class MCPToolClient:
    """
    Client wrapper used to call an MCP tool.

    Args:
        session (ClientSession): Initialized MCP session used to invoke the tool. The wrapper
            never closes it; calls only work while the session is still open.
        tool_name (str): The name of the tool to wrap
        tool_description (str | None): The description of the tool provided by the MCP server.
        tool_input_schema (dict | None): The tool's input schema. When given, it is converted
            to a pydantic model with ``model_from_mcp_schema``.
    """

    def __init__(self,
                 session: ClientSession,
                 tool_name: str,
                 tool_description: str | None,
                 tool_input_schema: dict | None = None):
        self._session = session
        self._tool_name = tool_name
        self._tool_description = tool_description
        self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)

    @property
    def name(self):
        """Returns the name of the tool."""
        return self._tool_name

    @property
    def description(self):
        """
        Returns the tool's description. If none (or an empty one) was provided, returns
        ``"MCP Tool <name>"`` instead.
        """
        if not self._tool_description:
            return f"MCP Tool {self._tool_name}"
        return self._tool_description

    @property
    def input_schema(self):
        """
        Returns the pydantic model generated from the tool's input schema, or None if no
        schema was provided.
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Manually define the tool's description using the provided string.
        """
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Call the MCP tool with the provided arguments.

        Args:
            tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.

        Returns:
            The text parts of the result joined with newlines. Non-text parts are logged and skipped.

        Raises:
            RuntimeError: If the server marks the result as an error; the message is the joined text.
        """
        result = await self._session.call_tool(self._tool_name, tool_args)

        output = []

        for res in result.content:
            if isinstance(res, TextContent):
                output.append(res.text)
            else:
                # Log non-text content for now
                logger.warning("Got non-text output from %s of type %s", self.name, type(res))
        result_str = "\n".join(output)

        if result.isError:
            raise RuntimeError(result_str)

        return result_str
@@ -1,229 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
- from typing import Literal
18
-
19
- from pydantic import BaseModel
20
- from pydantic import Field
21
- from pydantic import HttpUrl
22
- from pydantic import model_validator
23
-
24
- from nat.builder.builder import Builder
25
- from nat.builder.function_info import FunctionInfo
26
- from nat.cli.register_workflow import register_function
27
- from nat.data_models.function import FunctionBaseConfig
28
- from nat.experimental.decorators.experimental_warning_decorator import experimental
29
- from nat.tool.mcp.mcp_client_base import MCPBaseClient
30
-
31
# Module-level logger used by mcp_single_tool, mcp_client_function_handler and
# _filter_and_configure_tools.
logger = logging.getLogger(__name__)

# All functions in this file are experimental; the callables they expose are wrapped with
# @experimental(feature_name="mcp_client").
34
-
35
-
36
class ToolOverrideConfig(BaseModel):
    """
    Configuration for overriding tool properties when exposing from MCP server.
    """
    # When None/empty, the MCP tool's own name is used as the workflow function name.
    alias: str | None = Field(default=None, description="Override the tool name (function name in the workflow)")
    # When None/empty, the description reported by the MCP server is kept.
    description: str | None = Field(default=None, description="Override the tool description")
42
-
43
-
44
- class MCPServerConfig(BaseModel):
45
- """
46
- Server connection details for MCP client.
47
- Supports stdio, sse, and streamable-http transports.
48
- streamable-http is the recommended default for HTTP-based connections.
49
- """
50
- transport: Literal["stdio", "sse", "streamable-http"] = Field(
51
- ..., description="Transport type to connect to the MCP server (stdio, sse, or streamable-http)")
52
- url: HttpUrl | None = Field(default=None,
53
- description="URL of the MCP server (for sse or streamable-http transport)")
54
- command: str | None = Field(default=None,
55
- description="Command to run for stdio transport (e.g. 'python' or 'docker')")
56
- args: list[str] | None = Field(default=None, description="Arguments for the stdio command")
57
- env: dict[str, str] | None = Field(default=None, description="Environment variables for the stdio process")
58
-
59
- @model_validator(mode="after")
60
- def validate_model(self):
61
- """Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
62
- if self.transport == "stdio":
63
- if self.url is not None:
64
- raise ValueError("url should not be set when using stdio transport")
65
- if not self.command:
66
- raise ValueError("command is required when using stdio transport")
67
- elif self.transport in ("sse", "streamable-http"):
68
- if self.command is not None or self.args is not None or self.env is not None:
69
- raise ValueError("command, args, and env should not be set when using sse or streamable-http transport")
70
- if not self.url:
71
- raise ValueError("url is required when using sse or streamable-http transport")
72
- return self
73
-
74
-
75
class MCPClientConfig(FunctionBaseConfig, name="mcp_client"):
    """
    Configuration for connecting to an MCP server as a client and exposing selected tools.
    """
    # Connection settings; cross-field consistency is checked by MCPServerConfig.validate_model.
    server: MCPServerConfig = Field(..., description="Server connection details (transport, url/command, etc.)")
    # None exposes every tool the server lists. A list keeps only the named tools. A dict also
    # applies per-tool alias/description overrides (see _filter_and_configure_tools).
    tool_filter: dict[str, ToolOverrideConfig] | list[str] | None = Field(
        default=None,
        description="""Filter or map tools to expose from the server (list or dict).
Can be:
- A list of tool names to expose: ['tool1', 'tool2']
- A dict mapping tool names to override configs:
{'tool1': {'alias': 'new_name', 'description': 'New desc'}}
{'tool2': {'description': 'Override description only'}} # alias defaults to 'tool2'
""")
89
-
90
-
91
class MCPSingleToolConfig(FunctionBaseConfig, name="mcp_single_tool"):
    """
    Configuration for wrapping a single tool from an MCP server as a NeMo Agent toolkit function.
    """
    # Already-connected client; mcp_client_function_handler passes the client it opened with
    # ``async with``, so all wrapped tools share that one connection.
    client: MCPBaseClient = Field(..., description="MCP client to use for the tool")
    # Tool name as reported by the MCP server (not the workflow alias).
    tool_name: str = Field(..., description="Name of the tool to use")
    # Optional description override; when falsy, the server-provided description is kept.
    tool_description: str | None = Field(default=None, description="Description of the tool")

    # Required because MCPBaseClient is a plain Python class, not a pydantic type.
    model_config = {"arbitrary_types_allowed": True}
100
-
101
-
102
- def _get_server_name_safe(client: MCPBaseClient) -> str:
103
-
104
- # Avoid leaking env secrets from stdio client in logs.
105
- if client.transport == "stdio":
106
- safe_server = f"stdio: {client.command}"
107
- else:
108
- safe_server = f"{client.transport}: {client.url}"
109
-
110
- return safe_server
111
-
112
-
113
@register_function(config_type=MCPSingleToolConfig)
async def mcp_single_tool(config: MCPSingleToolConfig, builder: Builder):
    """
    Wrap a single tool from an MCP server as a NeMo Agent toolkit function.

    Looks up ``config.tool_name`` on ``config.client`` and optionally overrides its description.
    It then yields a FunctionInfo that forwards calls to the tool. Exceptions raised while
    calling the tool are logged and returned as the string result rather than raised.
    """
    tool = await config.client.get_tool(config.tool_name)
    if config.tool_description:
        tool.set_description(description=config.tool_description)
    input_schema = tool.input_schema

    logger.info("Configured to use tool: %s from MCP server at %s", tool.name, _get_server_name_safe(config.client))

    def _convert_from_str(input_str: str) -> BaseModel:
        # Lets a JSON string be supplied where the tool's input model is expected.
        return input_schema.model_validate_json(input_str)

    @experimental(feature_name="mcp_client")
    async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
        try:
            if tool_input is not None:
                return await tool.acall(tool_input.model_dump())
            # Validate keyword arguments only when the tool declares an input schema; tools
            # without one have nothing to validate against.
            if input_schema is not None:
                input_schema.model_validate(kwargs)
            return await tool.acall(kwargs)
        except Exception as e:
            # Deliberately returned as the tool output so the caller sees the failure text,
            # but logged so the failure is not silent.
            logger.warning("Error calling MCP tool %s: %s", tool.name, e, exc_info=True)
            return str(e)

    fn = FunctionInfo.create(single_fn=_response_fn,
                             description=tool.description,
                             input_schema=input_schema,
                             converters=[_convert_from_str])
    yield fn
143
-
144
-
145
- @register_function(MCPClientConfig)
146
- async def mcp_client_function_handler(config: MCPClientConfig, builder: Builder):
147
- """
148
- Connect to an MCP server, discover tools, and register them as functions in the workflow.
149
-
150
- Note:
151
- - Uses builder's exit stack to manage client lifecycle
152
- - Applies tool filters if provided
153
- """
154
- from nat.tool.mcp.mcp_client_base import MCPSSEClient
155
- from nat.tool.mcp.mcp_client_base import MCPStdioClient
156
- from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
157
-
158
- # Build the appropriate client
159
- client_cls = {
160
- "stdio": lambda: MCPStdioClient(config.server.command, config.server.args, config.server.env),
161
- "sse": lambda: MCPSSEClient(str(config.server.url)),
162
- "streamable-http": lambda: MCPStreamableHTTPClient(str(config.server.url)),
163
- }.get(config.server.transport)
164
-
165
- if not client_cls:
166
- raise ValueError(f"Unsupported transport: {config.server.transport}")
167
-
168
- client = client_cls()
169
- logger.info("Configured to use MCP server at %s", _get_server_name_safe(client))
170
-
171
- # client aenter connects to the server and stores the client in the exit stack
172
- # so it's cleaned up when the workflow is done
173
- async with client:
174
- all_tools = await client.get_tools()
175
- tool_configs = _filter_and_configure_tools(all_tools, config.tool_filter)
176
-
177
- for tool_name, tool_cfg in tool_configs.items():
178
- await builder.add_function(
179
- tool_cfg["function_name"],
180
- MCPSingleToolConfig(
181
- client=client,
182
- tool_name=tool_name,
183
- tool_description=tool_cfg["description"],
184
- ))
185
-
186
- @experimental(feature_name="mcp_client")
187
- async def idle_fn(text: str) -> str:
188
- # This function is a placeholder and will be removed when function groups are used
189
- return f"MCP client connected: {text}"
190
-
191
- yield FunctionInfo.create(single_fn=idle_fn, description="MCP client")
192
-
193
-
194
- def _filter_and_configure_tools(all_tools: dict, tool_filter) -> dict[str, dict]:
195
- """
196
- Apply tool filtering and optional aliasing/description overrides.
197
-
198
- Returns:
199
- Dict[str, dict] where each value has:
200
- - function_name
201
- - description
202
- """
203
- if tool_filter is None:
204
- return {name: {"function_name": name, "description": tool.description} for name, tool in all_tools.items()}
205
-
206
- if isinstance(tool_filter, list):
207
- return {
208
- name: {
209
- "function_name": name, "description": all_tools[name].description
210
- }
211
- for name in tool_filter if name in all_tools
212
- }
213
-
214
- if isinstance(tool_filter, dict):
215
- result = {}
216
- for name, override in tool_filter.items():
217
- tool = all_tools.get(name)
218
- if not tool:
219
- logger.warning("Tool '%s' specified in tool_filter not found in MCP server", name)
220
- continue
221
-
222
- if isinstance(override, ToolOverrideConfig):
223
- result[name] = {
224
- "function_name": override.alias or name, "description": override.description or tool.description
225
- }
226
- else:
227
- logger.warning("Unsupported override type for '%s': %s", name, type(override))
228
- result[name] = {"function_name": name, "description": tool.description}
229
- return result