nvidia-nat 1.3.0a20250829__py3-none-any.whl → 1.3.0a20250831__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/agent.py +37 -25
- nat/agent/react_agent/register.py +6 -1
- nat/builder/workflow.py +6 -2
- nat/cli/commands/info/list_mcp.py +183 -47
- nat/cli/commands/start.py +14 -2
- nat/data_models/thinking_mixin.py +27 -8
- nat/front_ends/mcp/mcp_front_end_config.py +5 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
- nat/front_ends/mcp/tool_converter.py +40 -13
- nat/observability/register.py +3 -1
- nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
- nat/tool/mcp/mcp_client_impl.py +229 -0
- nat/tool/mcp/mcp_tool.py +79 -42
- nat/tool/register.py +1 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/METADATA +3 -3
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/RECORD +22 -21
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250829.dist-info → nvidia_nat-1.3.0a20250831.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/agent.py
CHANGED
|
@@ -77,7 +77,8 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
77
77
|
retry_agent_response_parsing_errors: bool = True,
|
|
78
78
|
parse_agent_response_max_retries: int = 1,
|
|
79
79
|
tool_call_max_retries: int = 1,
|
|
80
|
-
pass_tool_call_errors_to_agent: bool = True
|
|
80
|
+
pass_tool_call_errors_to_agent: bool = True,
|
|
81
|
+
normalize_tool_input_quotes: bool = True):
|
|
81
82
|
super().__init__(llm=llm,
|
|
82
83
|
tools=tools,
|
|
83
84
|
callbacks=callbacks,
|
|
@@ -87,6 +88,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
87
88
|
if retry_agent_response_parsing_errors else 1)
|
|
88
89
|
self.tool_call_max_retries = tool_call_max_retries
|
|
89
90
|
self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
|
|
91
|
+
self.normalize_tool_input_quotes = normalize_tool_input_quotes
|
|
90
92
|
logger.debug(
|
|
91
93
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
92
94
|
AGENT_LOG_PREFIX)
|
|
@@ -286,35 +288,45 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
286
288
|
agent_thoughts.tool_input)
|
|
287
289
|
|
|
288
290
|
# Run the tool. Try to use structured input, if possible.
|
|
291
|
+
tool_input_str = agent_thoughts.tool_input.strip()
|
|
292
|
+
|
|
289
293
|
try:
|
|
290
|
-
|
|
291
|
-
tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
294
|
+
tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
292
295
|
logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
|
|
293
296
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
297
|
+
except JSONDecodeError as original_ex:
|
|
298
|
+
if self.normalize_tool_input_quotes:
|
|
299
|
+
# If initial JSON parsing fails, try with quote normalization as a fallback
|
|
300
|
+
normalized_str = tool_input_str.replace("'", '"')
|
|
301
|
+
try:
|
|
302
|
+
tool_input = json.loads(normalized_str)
|
|
303
|
+
logger.debug("%s Successfully parsed structured tool input after quote normalization",
|
|
304
|
+
AGENT_LOG_PREFIX)
|
|
305
|
+
except JSONDecodeError:
|
|
306
|
+
# the quote normalization failed, use raw string input
|
|
307
|
+
logger.debug(
|
|
308
|
+
"%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
|
|
309
|
+
"\nParsing error: %s",
|
|
310
|
+
AGENT_LOG_PREFIX,
|
|
311
|
+
original_ex)
|
|
312
|
+
tool_input = tool_input_str
|
|
313
|
+
else:
|
|
314
|
+
# use raw string input
|
|
315
|
+
logger.debug(
|
|
316
|
+
"%s Unable to parse structured tool input from Action Input. Using Action Input as is."
|
|
317
|
+
"\nParsing error: %s",
|
|
318
|
+
AGENT_LOG_PREFIX,
|
|
319
|
+
original_ex)
|
|
320
|
+
tool_input = tool_input_str
|
|
321
|
+
|
|
322
|
+
# Call tool once with the determined input (either parsed dict or raw string)
|
|
323
|
+
tool_response = await self._call_tool(requested_tool,
|
|
324
|
+
tool_input,
|
|
325
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
326
|
+
max_retries=self.tool_call_max_retries)
|
|
315
327
|
|
|
316
328
|
if self.detailed_logs:
|
|
317
|
-
self._log_tool_response(requested_tool.name,
|
|
329
|
+
self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
|
|
318
330
|
|
|
319
331
|
if not self.pass_tool_call_errors_to_agent:
|
|
320
332
|
if tool_response.status == "error":
|
|
@@ -62,6 +62,10 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
62
62
|
include_tool_input_schema_in_tool_description: bool = Field(
|
|
63
63
|
default=True, description="Specify inclusion of tool input schemas in the prompt.")
|
|
64
64
|
description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
|
|
65
|
+
normalize_tool_input_quotes: bool = Field(
|
|
66
|
+
default=True,
|
|
67
|
+
description="Whether to replace single quotes with double quotes in the tool input. "
|
|
68
|
+
"This is useful for tools that expect structured json input.")
|
|
65
69
|
system_prompt: str | None = Field(
|
|
66
70
|
default=None,
|
|
67
71
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
@@ -107,7 +111,8 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
107
111
|
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
108
112
|
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
109
113
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
110
|
-
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent
|
|
114
|
+
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
115
|
+
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
111
116
|
|
|
112
117
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
113
118
|
try:
|
nat/builder/workflow.py
CHANGED
|
@@ -84,7 +84,11 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
84
84
|
return self._entry_fn.has_single_output
|
|
85
85
|
|
|
86
86
|
async def get_all_exporters(self) -> dict[str, BaseExporter]:
|
|
87
|
-
return await self.
|
|
87
|
+
return await self.exporter_manager.get_all_exporters()
|
|
88
|
+
|
|
89
|
+
@property
|
|
90
|
+
def exporter_manager(self) -> ExporterManager:
|
|
91
|
+
return self._exporter_manager.get()
|
|
88
92
|
|
|
89
93
|
@asynccontextmanager
|
|
90
94
|
async def run(self, message: InputT):
|
|
@@ -96,7 +100,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
96
100
|
async with Runner(input_message=message,
|
|
97
101
|
entry_fn=self._entry_fn,
|
|
98
102
|
context_state=self._context_state,
|
|
99
|
-
exporter_manager=self.
|
|
103
|
+
exporter_manager=self.exporter_manager) as runner:
|
|
100
104
|
|
|
101
105
|
# The caller can `yield runner` so they can do `runner.result()` or `runner.result_stream()`
|
|
102
106
|
yield runner
|
|
@@ -23,14 +23,36 @@ import click
|
|
|
23
23
|
from pydantic import BaseModel
|
|
24
24
|
|
|
25
25
|
from nat.tool.mcp.exceptions import MCPError
|
|
26
|
-
from nat.tool.mcp.mcp_client import MCPBuilder
|
|
27
26
|
from nat.utils.exception_handlers.mcp import format_mcp_error
|
|
28
27
|
|
|
29
28
|
# Suppress verbose logs from mcp.client.sse and httpx
|
|
30
29
|
logging.getLogger("mcp.client.sse").setLevel(logging.WARNING)
|
|
31
30
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
|
32
31
|
|
|
33
|
-
|
|
32
|
+
|
|
33
|
+
def validate_transport_cli_args(transport: str, command: str | None, args: str | None, env: str | None) -> bool:
|
|
34
|
+
"""
|
|
35
|
+
Validate transport and parameter combinations, returning False if invalid.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
transport: The transport type ('sse', 'stdio', or 'streamable-http')
|
|
39
|
+
command: Command for stdio transport
|
|
40
|
+
args: Arguments for stdio transport
|
|
41
|
+
env: Environment variables for stdio transport
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
bool: True if valid, False if invalid (error message already displayed)
|
|
45
|
+
"""
|
|
46
|
+
if transport == 'stdio':
|
|
47
|
+
if not command:
|
|
48
|
+
click.echo("--command is required when using stdio client type", err=True)
|
|
49
|
+
return False
|
|
50
|
+
elif transport in ['sse', 'streamable-http']:
|
|
51
|
+
if command or args or env:
|
|
52
|
+
click.echo("--command, --args, and --env are not allowed when using sse or streamable-http client type",
|
|
53
|
+
err=True)
|
|
54
|
+
return False
|
|
55
|
+
return True
|
|
34
56
|
|
|
35
57
|
|
|
36
58
|
class MCPPingResult(BaseModel):
|
|
@@ -64,12 +86,20 @@ def format_tool(tool: Any) -> dict[str, str | None]:
|
|
|
64
86
|
description = getattr(tool, 'description', '')
|
|
65
87
|
input_schema = getattr(tool, 'input_schema', None) or getattr(tool, 'inputSchema', None)
|
|
66
88
|
|
|
67
|
-
|
|
68
|
-
if input_schema:
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
89
|
+
# Normalize schema to JSON string
|
|
90
|
+
if input_schema is None:
|
|
91
|
+
return {
|
|
92
|
+
"name": name,
|
|
93
|
+
"description": description,
|
|
94
|
+
"input_schema": None,
|
|
95
|
+
}
|
|
96
|
+
elif hasattr(input_schema, "schema_json"):
|
|
97
|
+
schema_str = input_schema.schema_json(indent=2)
|
|
98
|
+
elif isinstance(input_schema, dict):
|
|
99
|
+
schema_str = json.dumps(input_schema, indent=2)
|
|
100
|
+
else:
|
|
101
|
+
# Final fallback: attempt to dump stringified version wrapped as JSON string
|
|
102
|
+
schema_str = json.dumps({"raw": str(input_schema)}, indent=2)
|
|
73
103
|
|
|
74
104
|
return {
|
|
75
105
|
"name": name,
|
|
@@ -100,8 +130,8 @@ def print_tool(tool_dict: dict[str, str | None], detail: bool = False) -> None:
|
|
|
100
130
|
click.echo("-" * 60)
|
|
101
131
|
|
|
102
132
|
|
|
103
|
-
async def list_tools_and_schemas(url
|
|
104
|
-
"""List MCP tools using
|
|
133
|
+
async def list_tools_and_schemas(command, url, tool_name=None, transport='sse', args=None, env=None):
|
|
134
|
+
"""List MCP tools using NAT MCPClient with structured exception handling.
|
|
105
135
|
|
|
106
136
|
Args:
|
|
107
137
|
url (str): MCP server URL to connect to
|
|
@@ -115,20 +145,35 @@ async def list_tools_and_schemas(url: str, tool_name: str | None = None) -> list
|
|
|
115
145
|
Raises:
|
|
116
146
|
MCPError: Caught internally and logged, returns empty list instead
|
|
117
147
|
"""
|
|
118
|
-
|
|
148
|
+
from nat.tool.mcp.mcp_client_base import MCPSSEClient
|
|
149
|
+
from nat.tool.mcp.mcp_client_base import MCPStdioClient
|
|
150
|
+
from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
|
|
151
|
+
|
|
152
|
+
if args is None:
|
|
153
|
+
args = []
|
|
154
|
+
|
|
119
155
|
try:
|
|
120
|
-
if
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
156
|
+
if transport == 'stdio':
|
|
157
|
+
client = MCPStdioClient(command=command, args=args, env=env)
|
|
158
|
+
elif transport == 'streamable-http':
|
|
159
|
+
client = MCPStreamableHTTPClient(url=url)
|
|
160
|
+
else: # sse
|
|
161
|
+
client = MCPSSEClient(url=url)
|
|
162
|
+
|
|
163
|
+
async with client:
|
|
164
|
+
if tool_name:
|
|
165
|
+
tool = await client.get_tool(tool_name)
|
|
166
|
+
return [format_tool(tool)]
|
|
167
|
+
else:
|
|
168
|
+
tools = await client.get_tools()
|
|
169
|
+
return [format_tool(tool) for tool in tools.values()]
|
|
125
170
|
except MCPError as e:
|
|
126
171
|
format_mcp_error(e, include_traceback=False)
|
|
127
172
|
return []
|
|
128
173
|
|
|
129
174
|
|
|
130
|
-
async def list_tools_direct(url
|
|
131
|
-
"""List MCP tools using direct MCP protocol with exception
|
|
175
|
+
async def list_tools_direct(command, url, tool_name=None, transport='sse', args=None, env=None):
|
|
176
|
+
"""List MCP tools using direct MCP protocol with structured exception handling.
|
|
132
177
|
|
|
133
178
|
Bypasses MCPBuilder and uses raw MCP ClientSession and SSE client directly.
|
|
134
179
|
Converts raw exceptions to structured MCPErrors for consistent user experience.
|
|
@@ -147,25 +192,51 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
|
|
|
147
192
|
This function handles ExceptionGroup by extracting the most relevant exception
|
|
148
193
|
and converting it to MCPError for consistent error reporting.
|
|
149
194
|
"""
|
|
195
|
+
if args is None:
|
|
196
|
+
args = []
|
|
150
197
|
from mcp import ClientSession
|
|
151
198
|
from mcp.client.sse import sse_client
|
|
199
|
+
from mcp.client.stdio import StdioServerParameters
|
|
200
|
+
from mcp.client.stdio import stdio_client
|
|
201
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
152
202
|
|
|
153
203
|
try:
|
|
154
|
-
|
|
204
|
+
if transport == 'stdio':
|
|
205
|
+
|
|
206
|
+
def get_stdio_client():
|
|
207
|
+
return stdio_client(server=StdioServerParameters(command=command, args=args, env=env))
|
|
208
|
+
|
|
209
|
+
client = get_stdio_client
|
|
210
|
+
elif transport == 'streamable-http':
|
|
211
|
+
|
|
212
|
+
def get_streamable_http_client():
|
|
213
|
+
return streamablehttp_client(url=url)
|
|
214
|
+
|
|
215
|
+
client = get_streamable_http_client
|
|
216
|
+
else:
|
|
217
|
+
|
|
218
|
+
def get_sse_client():
|
|
219
|
+
return sse_client(url=url)
|
|
220
|
+
|
|
221
|
+
client = get_sse_client
|
|
222
|
+
|
|
223
|
+
async with client() as ctx:
|
|
224
|
+
read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
|
|
155
225
|
async with ClientSession(read, write) as session:
|
|
156
226
|
await session.initialize()
|
|
157
227
|
response = await session.list_tools()
|
|
158
228
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
229
|
+
tools = []
|
|
230
|
+
for tool in response.tools:
|
|
231
|
+
if tool_name:
|
|
232
|
+
if tool.name == tool_name:
|
|
233
|
+
tools.append(format_tool(tool))
|
|
234
|
+
else:
|
|
235
|
+
tools.append(format_tool(tool))
|
|
236
|
+
|
|
237
|
+
if tool_name and not tools:
|
|
238
|
+
click.echo(f"[INFO] Tool '{tool_name}' not found.")
|
|
239
|
+
return tools
|
|
169
240
|
except Exception as e:
|
|
170
241
|
# Convert raw exceptions to structured MCPError for consistency
|
|
171
242
|
from nat.utils.exception_handlers.mcp import convert_to_mcp_error
|
|
@@ -181,7 +252,12 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
|
|
|
181
252
|
return []
|
|
182
253
|
|
|
183
254
|
|
|
184
|
-
async def ping_mcp_server(url: str,
|
|
255
|
+
async def ping_mcp_server(url: str,
|
|
256
|
+
timeout: int,
|
|
257
|
+
transport: str = 'streamable-http',
|
|
258
|
+
command: str | None = None,
|
|
259
|
+
args: list[str] | None = None,
|
|
260
|
+
env: dict[str, str] | None = None) -> MCPPingResult:
|
|
185
261
|
"""Ping an MCP server to check if it's responsive.
|
|
186
262
|
|
|
187
263
|
Args:
|
|
@@ -193,18 +269,29 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
|
|
|
193
269
|
"""
|
|
194
270
|
from mcp.client.session import ClientSession
|
|
195
271
|
from mcp.client.sse import sse_client
|
|
272
|
+
from mcp.client.stdio import StdioServerParameters
|
|
273
|
+
from mcp.client.stdio import stdio_client
|
|
274
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
196
275
|
|
|
197
276
|
async def _ping_operation():
|
|
198
|
-
|
|
277
|
+
# Select transport
|
|
278
|
+
if transport == 'stdio':
|
|
279
|
+
stdio_args_local: list[str] = args or []
|
|
280
|
+
if not command:
|
|
281
|
+
raise RuntimeError("--command is required for stdio transport")
|
|
282
|
+
client_ctx = stdio_client(server=StdioServerParameters(command=command, args=stdio_args_local, env=env))
|
|
283
|
+
elif transport == 'sse':
|
|
284
|
+
client_ctx = sse_client(url)
|
|
285
|
+
else: # streamable-http
|
|
286
|
+
client_ctx = streamablehttp_client(url=url)
|
|
287
|
+
|
|
288
|
+
async with client_ctx as ctx:
|
|
289
|
+
read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
|
|
199
290
|
async with ClientSession(read, write) as session:
|
|
200
|
-
# Initialize the session
|
|
201
291
|
await session.initialize()
|
|
202
292
|
|
|
203
|
-
# Record start time just before ping
|
|
204
293
|
start_time = time.time()
|
|
205
|
-
# Send ping request
|
|
206
294
|
await session.send_ping()
|
|
207
|
-
|
|
208
295
|
end_time = time.time()
|
|
209
296
|
response_time_ms = round((end_time - start_time) * 1000, 2)
|
|
210
297
|
|
|
@@ -226,12 +313,24 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
|
|
|
226
313
|
|
|
227
314
|
@click.group(invoke_without_command=True, help="List tool names (default), or show details with --detail or --tool.")
|
|
228
315
|
@click.option('--direct', is_flag=True, help='Bypass MCPBuilder and use direct MCP protocol')
|
|
229
|
-
@click.option(
|
|
316
|
+
@click.option(
|
|
317
|
+
'--url',
|
|
318
|
+
default='http://localhost:9901/mcp',
|
|
319
|
+
show_default=True,
|
|
320
|
+
help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
|
|
321
|
+
@click.option('--transport',
|
|
322
|
+
type=click.Choice(['sse', 'stdio', 'streamable-http']),
|
|
323
|
+
default='streamable-http',
|
|
324
|
+
show_default=True,
|
|
325
|
+
help='Type of client to use (default: streamable-http, backwards compatible with sse)')
|
|
326
|
+
@click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
|
|
327
|
+
@click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
|
|
328
|
+
@click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
|
|
230
329
|
@click.option('--tool', default=None, help='Get details for a specific tool by name')
|
|
231
330
|
@click.option('--detail', is_flag=True, help='Show full details for all tools')
|
|
232
331
|
@click.option('--json-output', is_flag=True, help='Output tool metadata in JSON format')
|
|
233
332
|
@click.pass_context
|
|
234
|
-
def list_mcp(ctx
|
|
333
|
+
def list_mcp(ctx, direct, url, transport, command, args, env, tool, detail, json_output):
|
|
235
334
|
"""List MCP tool names (default) or show detailed tool information.
|
|
236
335
|
|
|
237
336
|
Use --detail for full output including descriptions and input schemas.
|
|
@@ -242,7 +341,7 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
|
|
|
242
341
|
Args:
|
|
243
342
|
ctx (click.Context): Click context object for command invocation
|
|
244
343
|
direct (bool): Whether to bypass MCPBuilder and use direct MCP protocol
|
|
245
|
-
url (str): MCP server URL to connect to (default: http://localhost:9901/
|
|
344
|
+
url (str): MCP server URL to connect to (default: http://localhost:9901/mcp)
|
|
246
345
|
tool (str | None): Optional specific tool name to retrieve detailed info for
|
|
247
346
|
detail (bool): Whether to show full details (description + schema) for all tools
|
|
248
347
|
json_output (bool): Whether to output tool metadata in JSON format instead of text
|
|
@@ -256,44 +355,81 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
|
|
|
256
355
|
"""
|
|
257
356
|
if ctx.invoked_subcommand is not None:
|
|
258
357
|
return
|
|
358
|
+
|
|
359
|
+
if not validate_transport_cli_args(transport, command, args, env):
|
|
360
|
+
return
|
|
361
|
+
|
|
362
|
+
if transport in ['sse', 'streamable-http']:
|
|
363
|
+
if not url:
|
|
364
|
+
click.echo("[ERROR] --url is required when using sse or streamable-http client type", err=True)
|
|
365
|
+
return
|
|
366
|
+
|
|
367
|
+
stdio_args = args.split() if args else []
|
|
368
|
+
stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
|
|
369
|
+
|
|
259
370
|
fetcher = list_tools_direct if direct else list_tools_and_schemas
|
|
260
|
-
tools = asyncio.run(fetcher(url, tool))
|
|
371
|
+
tools = asyncio.run(fetcher(command, url, tool, transport, stdio_args, stdio_env))
|
|
261
372
|
|
|
262
373
|
if json_output:
|
|
263
374
|
click.echo(json.dumps(tools, indent=2))
|
|
264
375
|
elif tool:
|
|
265
|
-
for tool_dict in tools:
|
|
376
|
+
for tool_dict in (tools or []):
|
|
266
377
|
print_tool(tool_dict, detail=True)
|
|
267
378
|
elif detail:
|
|
268
|
-
for tool_dict in tools:
|
|
379
|
+
for tool_dict in (tools or []):
|
|
269
380
|
print_tool(tool_dict, detail=True)
|
|
270
381
|
else:
|
|
271
|
-
for tool_dict in tools:
|
|
382
|
+
for tool_dict in (tools or []):
|
|
272
383
|
click.echo(tool_dict.get('name', 'Unknown tool'))
|
|
273
384
|
|
|
274
385
|
|
|
275
386
|
@list_mcp.command()
|
|
276
|
-
@click.option(
|
|
387
|
+
@click.option(
|
|
388
|
+
'--url',
|
|
389
|
+
default='http://localhost:9901/mcp',
|
|
390
|
+
show_default=True,
|
|
391
|
+
help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
|
|
392
|
+
@click.option('--transport',
|
|
393
|
+
type=click.Choice(['sse', 'stdio', 'streamable-http']),
|
|
394
|
+
default='streamable-http',
|
|
395
|
+
show_default=True,
|
|
396
|
+
help='Type of client to use for ping')
|
|
397
|
+
@click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
|
|
398
|
+
@click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
|
|
399
|
+
@click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
|
|
277
400
|
@click.option('--timeout', default=60, show_default=True, help='Timeout in seconds for ping request')
|
|
278
401
|
@click.option('--json-output', is_flag=True, help='Output ping result in JSON format')
|
|
279
|
-
def ping(url: str,
|
|
402
|
+
def ping(url: str,
|
|
403
|
+
transport: str,
|
|
404
|
+
command: str | None,
|
|
405
|
+
args: str | None,
|
|
406
|
+
env: str | None,
|
|
407
|
+
timeout: int,
|
|
408
|
+
json_output: bool) -> None:
|
|
280
409
|
"""Ping an MCP server to check if it's responsive.
|
|
281
410
|
|
|
282
411
|
This command sends a ping request to the MCP server and measures the response time.
|
|
283
412
|
It's useful for health checks and monitoring server availability.
|
|
284
413
|
|
|
285
414
|
Args:
|
|
286
|
-
url (str): MCP server URL to ping (default: http://localhost:9901/
|
|
415
|
+
url (str): MCP server URL to ping (default: http://localhost:9901/mcp)
|
|
287
416
|
timeout (int): Timeout in seconds for the ping request (default: 60)
|
|
288
417
|
json_output (bool): Whether to output the result in JSON format
|
|
289
418
|
|
|
290
419
|
Examples:
|
|
291
420
|
nat info mcp ping # Ping default server
|
|
292
|
-
nat info mcp ping --url http://custom-server:9901/
|
|
421
|
+
nat info mcp ping --url http://custom-server:9901/mcp # Ping custom server
|
|
293
422
|
nat info mcp ping --timeout 10 # Use 10 second timeout
|
|
294
423
|
nat info mcp ping --json-output # Get JSON format output
|
|
295
424
|
"""
|
|
296
|
-
|
|
425
|
+
# Validate combinations similar to parent command
|
|
426
|
+
if not validate_transport_cli_args(transport, command, args, env):
|
|
427
|
+
return
|
|
428
|
+
|
|
429
|
+
stdio_args = args.split() if args else []
|
|
430
|
+
stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
|
|
431
|
+
|
|
432
|
+
result = asyncio.run(ping_mcp_server(url, timeout, transport, command, stdio_args, stdio_env))
|
|
297
433
|
|
|
298
434
|
if json_output:
|
|
299
435
|
click.echo(result.model_dump_json(indent=2))
|
nat/cli/commands/start.py
CHANGED
|
@@ -102,12 +102,24 @@ class StartCommandGroup(click.Group):
|
|
|
102
102
|
raise ValueError(f"Invalid field '{name}'.Unions are only supported for optional parameters.")
|
|
103
103
|
|
|
104
104
|
# Handle the types
|
|
105
|
-
|
|
105
|
+
# Literal[...] -> map to click.Choice([...])
|
|
106
|
+
if (decomposed_type.origin is typing.Literal):
|
|
107
|
+
# typing.get_args returns the literal values; ensure they are strings for Click
|
|
108
|
+
literal_values = [str(v) for v in decomposed_type.args]
|
|
109
|
+
param_type = click.Choice(literal_values)
|
|
110
|
+
|
|
111
|
+
elif (issubclass(decomposed_type.root, Path)):
|
|
106
112
|
param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
|
|
107
113
|
|
|
108
114
|
elif (issubclass(decomposed_type.root, (list, tuple, set))):
|
|
109
115
|
if (len(decomposed_type.args) == 1):
|
|
110
|
-
|
|
116
|
+
inner = DecomposedType(decomposed_type.args[0])
|
|
117
|
+
# Support containers of Literal values -> multiple Choice
|
|
118
|
+
if (inner.origin is typing.Literal):
|
|
119
|
+
literal_values = [str(v) for v in inner.args]
|
|
120
|
+
param_type = click.Choice(literal_values)
|
|
121
|
+
else:
|
|
122
|
+
param_type = inner.root
|
|
111
123
|
else:
|
|
112
124
|
param_type = None
|
|
113
125
|
|
|
@@ -22,8 +22,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
|
22
22
|
|
|
23
23
|
# The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
|
|
24
24
|
# regex patterns
|
|
25
|
-
|
|
26
|
-
_LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
|
|
25
|
+
_NEMOTRON_REGEX = re.compile(r"^nvidia/(llama|nvidia).*nemotron", re.IGNORECASE)
|
|
27
26
|
_MODEL_KEYS = ("model_name", "model", "azure_deployment")
|
|
28
27
|
|
|
29
28
|
|
|
@@ -33,7 +32,7 @@ class ThinkingMixin(
|
|
|
33
32
|
field_name="thinking",
|
|
34
33
|
default_if_supported=None,
|
|
35
34
|
keys=_MODEL_KEYS,
|
|
36
|
-
supported=(
|
|
35
|
+
supported=(_NEMOTRON_REGEX, ),
|
|
37
36
|
):
|
|
38
37
|
"""
|
|
39
38
|
Mixin class for thinking configuration. Only supported on Nemotron models.
|
|
@@ -52,7 +51,8 @@ class ThinkingMixin(
|
|
|
52
51
|
"""
|
|
53
52
|
Returns the system prompt to use for thinking.
|
|
54
53
|
For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
|
|
55
|
-
For Llama Nemotron, returns "
|
|
54
|
+
For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
|
|
55
|
+
For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
|
|
56
56
|
If thinking is not supported on the model, returns None.
|
|
57
57
|
|
|
58
58
|
Returns:
|
|
@@ -60,9 +60,28 @@ class ThinkingMixin(
|
|
|
60
60
|
"""
|
|
61
61
|
if self.thinking is None:
|
|
62
62
|
return None
|
|
63
|
+
|
|
63
64
|
for key in _MODEL_KEYS:
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
65
|
+
model = getattr(self, key, None)
|
|
66
|
+
if not isinstance(model, str) or model is None:
|
|
67
|
+
continue
|
|
68
|
+
|
|
69
|
+
# Normalize name to reduce checks
|
|
70
|
+
model = model.lower().translate(str.maketrans("_.", "--"))
|
|
71
|
+
|
|
72
|
+
if model.startswith("nvidia/nvidia"):
|
|
73
|
+
return "/think" if self.thinking else "/no_think"
|
|
74
|
+
|
|
75
|
+
if model.startswith("nvidia/llama"):
|
|
76
|
+
if "v1-0" in model or "v1-1" in model:
|
|
68
77
|
return f"detailed thinking {'on' if self.thinking else 'off'}"
|
|
78
|
+
|
|
79
|
+
if "v1-5" in model:
|
|
80
|
+
# v1.5 models are updated to use the /think and /no_think system prompts
|
|
81
|
+
return "/think" if self.thinking else "/no_think"
|
|
82
|
+
|
|
83
|
+
# Assume any other model is a newer model that uses the /think and /no_think system prompts
|
|
84
|
+
return "/think" if self.thinking else "/no_think"
|
|
85
|
+
|
|
86
|
+
# Unknown model
|
|
87
|
+
return None
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
16
18
|
from pydantic import Field
|
|
17
19
|
|
|
18
20
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
@@ -32,5 +34,8 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
32
34
|
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
33
35
|
tool_names: list[str] = Field(default_factory=list,
|
|
34
36
|
description="The list of tools MCP server will expose (default: all tools)")
|
|
37
|
+
transport: Literal["sse", "streamable-http"] = Field(
|
|
38
|
+
default="streamable-http",
|
|
39
|
+
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
35
40
|
runner_class: str | None = Field(
|
|
36
41
|
default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
|
|
@@ -77,5 +77,11 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
|
|
|
77
77
|
# Add routes through the worker (includes health endpoint and function registration)
|
|
78
78
|
await worker.add_routes(mcp, builder)
|
|
79
79
|
|
|
80
|
-
# Start the MCP server
|
|
81
|
-
|
|
80
|
+
# Start the MCP server with configurable transport
|
|
81
|
+
# streamable-http is the default, but users can choose sse if preferred
|
|
82
|
+
if self.front_end_config.transport == "sse":
|
|
83
|
+
logger.info("Starting MCP server with SSE endpoint at /sse")
|
|
84
|
+
await mcp.run_sse_async()
|
|
85
|
+
else: # streamable-http
|
|
86
|
+
logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
|
|
87
|
+
await mcp.run_streamable_http_async()
|
|
@@ -134,9 +134,9 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
134
134
|
logger.debug("Skipping function %s as it's not in tool_names", function_name)
|
|
135
135
|
functions = filtered_functions
|
|
136
136
|
|
|
137
|
-
# Register each function with MCP
|
|
137
|
+
# Register each function with MCP, passing workflow context for observability
|
|
138
138
|
for function_name, function in functions.items():
|
|
139
|
-
register_function_with_mcp(mcp, function_name, function)
|
|
139
|
+
register_function_with_mcp(mcp, function_name, function, workflow)
|
|
140
140
|
|
|
141
141
|
# Add a simple fallback function if no functions were found
|
|
142
142
|
if not functions:
|