nvidia-nat 1.3.0a20250829__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -77,7 +77,8 @@ class ReActAgentGraph(DualNodeAgent):
77
77
  retry_agent_response_parsing_errors: bool = True,
78
78
  parse_agent_response_max_retries: int = 1,
79
79
  tool_call_max_retries: int = 1,
80
- pass_tool_call_errors_to_agent: bool = True):
80
+ pass_tool_call_errors_to_agent: bool = True,
81
+ normalize_tool_input_quotes: bool = True):
81
82
  super().__init__(llm=llm,
82
83
  tools=tools,
83
84
  callbacks=callbacks,
@@ -87,6 +88,7 @@ class ReActAgentGraph(DualNodeAgent):
87
88
  if retry_agent_response_parsing_errors else 1)
88
89
  self.tool_call_max_retries = tool_call_max_retries
89
90
  self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
91
+ self.normalize_tool_input_quotes = normalize_tool_input_quotes
90
92
  logger.debug(
91
93
  "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
92
94
  AGENT_LOG_PREFIX)
@@ -286,35 +288,45 @@ class ReActAgentGraph(DualNodeAgent):
286
288
  agent_thoughts.tool_input)
287
289
 
288
290
  # Run the tool. Try to use structured input, if possible.
291
+ tool_input_str = agent_thoughts.tool_input.strip()
292
+
289
293
  try:
290
- tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
291
- tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
294
+ tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
292
295
  logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
293
296
 
294
- tool_response = await self._call_tool(requested_tool,
295
- tool_input_dict,
296
- RunnableConfig(callbacks=self.callbacks),
297
- max_retries=self.tool_call_max_retries)
298
-
299
- if self.detailed_logs:
300
- self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
301
-
302
- except JSONDecodeError as ex:
303
- logger.debug(
304
- "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
305
- "\nParsing error: %s",
306
- AGENT_LOG_PREFIX,
307
- ex,
308
- exc_info=True)
309
- tool_input_str = str(agent_thoughts.tool_input)
310
-
311
- tool_response = await self._call_tool(requested_tool,
312
- tool_input_str,
313
- RunnableConfig(callbacks=self.callbacks),
314
- max_retries=self.tool_call_max_retries)
297
+ except JSONDecodeError as original_ex:
298
+ if self.normalize_tool_input_quotes:
299
+ # If initial JSON parsing fails, try with quote normalization as a fallback
300
+ normalized_str = tool_input_str.replace("'", '"')
301
+ try:
302
+ tool_input = json.loads(normalized_str)
303
+ logger.debug("%s Successfully parsed structured tool input after quote normalization",
304
+ AGENT_LOG_PREFIX)
305
+ except JSONDecodeError:
306
+ # the quote normalization failed, use raw string input
307
+ logger.debug(
308
+ "%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
309
+ "\nParsing error: %s",
310
+ AGENT_LOG_PREFIX,
311
+ original_ex)
312
+ tool_input = tool_input_str
313
+ else:
314
+ # use raw string input
315
+ logger.debug(
316
+ "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
317
+ "\nParsing error: %s",
318
+ AGENT_LOG_PREFIX,
319
+ original_ex)
320
+ tool_input = tool_input_str
321
+
322
+ # Call tool once with the determined input (either parsed dict or raw string)
323
+ tool_response = await self._call_tool(requested_tool,
324
+ tool_input,
325
+ RunnableConfig(callbacks=self.callbacks),
326
+ max_retries=self.tool_call_max_retries)
315
327
 
316
328
  if self.detailed_logs:
317
- self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
329
+ self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
318
330
 
319
331
  if not self.pass_tool_call_errors_to_agent:
320
332
  if tool_response.status == "error":
@@ -62,6 +62,10 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
62
62
  include_tool_input_schema_in_tool_description: bool = Field(
63
63
  default=True, description="Specify inclusion of tool input schemas in the prompt.")
64
64
  description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
65
+ normalize_tool_input_quotes: bool = Field(
66
+ default=True,
67
+ description="Whether to replace single quotes with double quotes in the tool input. "
68
+ "This is useful for tools that expect structured json input.")
65
69
  system_prompt: str | None = Field(
66
70
  default=None,
67
71
  description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
@@ -107,7 +111,8 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
107
111
  retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
108
112
  parse_agent_response_max_retries=config.parse_agent_response_max_retries,
109
113
  tool_call_max_retries=config.tool_call_max_retries,
110
- pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()
114
+ pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
115
+ normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
111
116
 
112
117
  async def _response_fn(input_message: ChatRequest) -> ChatResponse:
113
118
  try:
nat/builder/workflow.py CHANGED
@@ -84,7 +84,11 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
84
84
  return self._entry_fn.has_single_output
85
85
 
86
86
  async def get_all_exporters(self) -> dict[str, BaseExporter]:
87
- return await self._exporter_manager.get_all_exporters()
87
+ return await self.exporter_manager.get_all_exporters()
88
+
89
+ @property
90
+ def exporter_manager(self) -> ExporterManager:
91
+ return self._exporter_manager.get()
88
92
 
89
93
  @asynccontextmanager
90
94
  async def run(self, message: InputT):
@@ -96,7 +100,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
96
100
  async with Runner(input_message=message,
97
101
  entry_fn=self._entry_fn,
98
102
  context_state=self._context_state,
99
- exporter_manager=self._exporter_manager.get()) as runner:
103
+ exporter_manager=self.exporter_manager) as runner:
100
104
 
101
105
  # The caller can `yield runner` so they can do `runner.result()` or `runner.result_stream()`
102
106
  yield runner
@@ -23,14 +23,36 @@ import click
23
23
  from pydantic import BaseModel
24
24
 
25
25
  from nat.tool.mcp.exceptions import MCPError
26
- from nat.tool.mcp.mcp_client import MCPBuilder
27
26
  from nat.utils.exception_handlers.mcp import format_mcp_error
28
27
 
29
28
  # Suppress verbose logs from mcp.client.sse and httpx
30
29
  logging.getLogger("mcp.client.sse").setLevel(logging.WARNING)
31
30
  logging.getLogger("httpx").setLevel(logging.WARNING)
32
31
 
33
- logger = logging.getLogger(__name__)
32
+
33
+ def validate_transport_cli_args(transport: str, command: str | None, args: str | None, env: str | None) -> bool:
34
+ """
35
+ Validate transport and parameter combinations, returning False if invalid.
36
+
37
+ Args:
38
+ transport: The transport type ('sse', 'stdio', or 'streamable-http')
39
+ command: Command for stdio transport
40
+ args: Arguments for stdio transport
41
+ env: Environment variables for stdio transport
42
+
43
+ Returns:
44
+ bool: True if valid, False if invalid (error message already displayed)
45
+ """
46
+ if transport == 'stdio':
47
+ if not command:
48
+ click.echo("--command is required when using stdio client type", err=True)
49
+ return False
50
+ elif transport in ['sse', 'streamable-http']:
51
+ if command or args or env:
52
+ click.echo("--command, --args, and --env are not allowed when using sse or streamable-http client type",
53
+ err=True)
54
+ return False
55
+ return True
34
56
 
35
57
 
36
58
  class MCPPingResult(BaseModel):
@@ -64,12 +86,20 @@ def format_tool(tool: Any) -> dict[str, str | None]:
64
86
  description = getattr(tool, 'description', '')
65
87
  input_schema = getattr(tool, 'input_schema', None) or getattr(tool, 'inputSchema', None)
66
88
 
67
- schema_str = None
68
- if input_schema:
69
- if hasattr(input_schema, "schema_json"):
70
- schema_str = input_schema.schema_json(indent=2)
71
- else:
72
- schema_str = str(input_schema)
89
+ # Normalize schema to JSON string
90
+ if input_schema is None:
91
+ return {
92
+ "name": name,
93
+ "description": description,
94
+ "input_schema": None,
95
+ }
96
+ elif hasattr(input_schema, "schema_json"):
97
+ schema_str = input_schema.schema_json(indent=2)
98
+ elif isinstance(input_schema, dict):
99
+ schema_str = json.dumps(input_schema, indent=2)
100
+ else:
101
+ # Final fallback: attempt to dump stringified version wrapped as JSON string
102
+ schema_str = json.dumps({"raw": str(input_schema)}, indent=2)
73
103
 
74
104
  return {
75
105
  "name": name,
@@ -100,8 +130,8 @@ def print_tool(tool_dict: dict[str, str | None], detail: bool = False) -> None:
100
130
  click.echo("-" * 60)
101
131
 
102
132
 
103
- async def list_tools_and_schemas(url: str, tool_name: str | None = None) -> list[dict[str, str | None]]:
104
- """List MCP tools using MCPBuilder with structured exception handling.
133
+ async def list_tools_and_schemas(command, url, tool_name=None, transport='sse', args=None, env=None):
134
+ """List MCP tools using NAT MCPClient with structured exception handling.
105
135
 
106
136
  Args:
107
137
  url (str): MCP server URL to connect to
@@ -115,20 +145,35 @@ async def list_tools_and_schemas(url: str, tool_name: str | None = None) -> list
115
145
  Raises:
116
146
  MCPError: Caught internally and logged, returns empty list instead
117
147
  """
118
- builder = MCPBuilder(url=url)
148
+ from nat.tool.mcp.mcp_client_base import MCPSSEClient
149
+ from nat.tool.mcp.mcp_client_base import MCPStdioClient
150
+ from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
151
+
152
+ if args is None:
153
+ args = []
154
+
119
155
  try:
120
- if tool_name:
121
- tool = await builder.get_tool(tool_name)
122
- return [format_tool(tool)]
123
- tools = await builder.get_tools()
124
- return [format_tool(tool) for tool in tools.values()]
156
+ if transport == 'stdio':
157
+ client = MCPStdioClient(command=command, args=args, env=env)
158
+ elif transport == 'streamable-http':
159
+ client = MCPStreamableHTTPClient(url=url)
160
+ else: # sse
161
+ client = MCPSSEClient(url=url)
162
+
163
+ async with client:
164
+ if tool_name:
165
+ tool = await client.get_tool(tool_name)
166
+ return [format_tool(tool)]
167
+ else:
168
+ tools = await client.get_tools()
169
+ return [format_tool(tool) for tool in tools.values()]
125
170
  except MCPError as e:
126
171
  format_mcp_error(e, include_traceback=False)
127
172
  return []
128
173
 
129
174
 
130
- async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict[str, str | None]]:
131
- """List MCP tools using direct MCP protocol with exception conversion.
175
+ async def list_tools_direct(command, url, tool_name=None, transport='sse', args=None, env=None):
176
+ """List MCP tools using direct MCP protocol with structured exception handling.
132
177
 
133
178
  Bypasses MCPBuilder and uses raw MCP ClientSession and SSE client directly.
134
179
  Converts raw exceptions to structured MCPErrors for consistent user experience.
@@ -147,25 +192,51 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
147
192
  This function handles ExceptionGroup by extracting the most relevant exception
148
193
  and converting it to MCPError for consistent error reporting.
149
194
  """
195
+ if args is None:
196
+ args = []
150
197
  from mcp import ClientSession
151
198
  from mcp.client.sse import sse_client
199
+ from mcp.client.stdio import StdioServerParameters
200
+ from mcp.client.stdio import stdio_client
201
+ from mcp.client.streamable_http import streamablehttp_client
152
202
 
153
203
  try:
154
- async with sse_client(url=url) as (read, write):
204
+ if transport == 'stdio':
205
+
206
+ def get_stdio_client():
207
+ return stdio_client(server=StdioServerParameters(command=command, args=args, env=env))
208
+
209
+ client = get_stdio_client
210
+ elif transport == 'streamable-http':
211
+
212
+ def get_streamable_http_client():
213
+ return streamablehttp_client(url=url)
214
+
215
+ client = get_streamable_http_client
216
+ else:
217
+
218
+ def get_sse_client():
219
+ return sse_client(url=url)
220
+
221
+ client = get_sse_client
222
+
223
+ async with client() as ctx:
224
+ read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
155
225
  async with ClientSession(read, write) as session:
156
226
  await session.initialize()
157
227
  response = await session.list_tools()
158
228
 
159
- tools = []
160
- for tool in response.tools:
161
- if tool_name:
162
- if tool.name == tool_name:
163
- return [format_tool(tool)]
164
- else:
165
- tools.append(format_tool(tool))
166
- if tool_name and not tools:
167
- click.echo(f"[INFO] Tool '{tool_name}' not found.")
168
- return tools
229
+ tools = []
230
+ for tool in response.tools:
231
+ if tool_name:
232
+ if tool.name == tool_name:
233
+ tools.append(format_tool(tool))
234
+ else:
235
+ tools.append(format_tool(tool))
236
+
237
+ if tool_name and not tools:
238
+ click.echo(f"[INFO] Tool '{tool_name}' not found.")
239
+ return tools
169
240
  except Exception as e:
170
241
  # Convert raw exceptions to structured MCPError for consistency
171
242
  from nat.utils.exception_handlers.mcp import convert_to_mcp_error
@@ -181,7 +252,12 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
181
252
  return []
182
253
 
183
254
 
184
- async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
255
+ async def ping_mcp_server(url: str,
256
+ timeout: int,
257
+ transport: str = 'streamable-http',
258
+ command: str | None = None,
259
+ args: list[str] | None = None,
260
+ env: dict[str, str] | None = None) -> MCPPingResult:
185
261
  """Ping an MCP server to check if it's responsive.
186
262
 
187
263
  Args:
@@ -193,18 +269,29 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
193
269
  """
194
270
  from mcp.client.session import ClientSession
195
271
  from mcp.client.sse import sse_client
272
+ from mcp.client.stdio import StdioServerParameters
273
+ from mcp.client.stdio import stdio_client
274
+ from mcp.client.streamable_http import streamablehttp_client
196
275
 
197
276
  async def _ping_operation():
198
- async with sse_client(url) as (read, write):
277
+ # Select transport
278
+ if transport == 'stdio':
279
+ stdio_args_local: list[str] = args or []
280
+ if not command:
281
+ raise RuntimeError("--command is required for stdio transport")
282
+ client_ctx = stdio_client(server=StdioServerParameters(command=command, args=stdio_args_local, env=env))
283
+ elif transport == 'sse':
284
+ client_ctx = sse_client(url)
285
+ else: # streamable-http
286
+ client_ctx = streamablehttp_client(url=url)
287
+
288
+ async with client_ctx as ctx:
289
+ read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
199
290
  async with ClientSession(read, write) as session:
200
- # Initialize the session
201
291
  await session.initialize()
202
292
 
203
- # Record start time just before ping
204
293
  start_time = time.time()
205
- # Send ping request
206
294
  await session.send_ping()
207
-
208
295
  end_time = time.time()
209
296
  response_time_ms = round((end_time - start_time) * 1000, 2)
210
297
 
@@ -226,12 +313,24 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
226
313
 
227
314
  @click.group(invoke_without_command=True, help="List tool names (default), or show details with --detail or --tool.")
228
315
  @click.option('--direct', is_flag=True, help='Bypass MCPBuilder and use direct MCP protocol')
229
- @click.option('--url', default='http://localhost:9901/sse', show_default=True, help='MCP server URL')
316
+ @click.option(
317
+ '--url',
318
+ default='http://localhost:9901/mcp',
319
+ show_default=True,
320
+ help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
321
+ @click.option('--transport',
322
+ type=click.Choice(['sse', 'stdio', 'streamable-http']),
323
+ default='streamable-http',
324
+ show_default=True,
325
+ help='Type of client to use (default: streamable-http, backwards compatible with sse)')
326
+ @click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
327
+ @click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
328
+ @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
230
329
  @click.option('--tool', default=None, help='Get details for a specific tool by name')
231
330
  @click.option('--detail', is_flag=True, help='Show full details for all tools')
232
331
  @click.option('--json-output', is_flag=True, help='Output tool metadata in JSON format')
233
332
  @click.pass_context
234
- def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detail: bool, json_output: bool) -> None:
333
+ def list_mcp(ctx, direct, url, transport, command, args, env, tool, detail, json_output):
235
334
  """List MCP tool names (default) or show detailed tool information.
236
335
 
237
336
  Use --detail for full output including descriptions and input schemas.
@@ -242,7 +341,7 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
242
341
  Args:
243
342
  ctx (click.Context): Click context object for command invocation
244
343
  direct (bool): Whether to bypass MCPBuilder and use direct MCP protocol
245
- url (str): MCP server URL to connect to (default: http://localhost:9901/sse)
344
+ url (str): MCP server URL to connect to (default: http://localhost:9901/mcp)
246
345
  tool (str | None): Optional specific tool name to retrieve detailed info for
247
346
  detail (bool): Whether to show full details (description + schema) for all tools
248
347
  json_output (bool): Whether to output tool metadata in JSON format instead of text
@@ -256,44 +355,81 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
256
355
  """
257
356
  if ctx.invoked_subcommand is not None:
258
357
  return
358
+
359
+ if not validate_transport_cli_args(transport, command, args, env):
360
+ return
361
+
362
+ if transport in ['sse', 'streamable-http']:
363
+ if not url:
364
+ click.echo("[ERROR] --url is required when using sse or streamable-http client type", err=True)
365
+ return
366
+
367
+ stdio_args = args.split() if args else []
368
+ stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
369
+
259
370
  fetcher = list_tools_direct if direct else list_tools_and_schemas
260
- tools = asyncio.run(fetcher(url, tool))
371
+ tools = asyncio.run(fetcher(command, url, tool, transport, stdio_args, stdio_env))
261
372
 
262
373
  if json_output:
263
374
  click.echo(json.dumps(tools, indent=2))
264
375
  elif tool:
265
- for tool_dict in tools:
376
+ for tool_dict in (tools or []):
266
377
  print_tool(tool_dict, detail=True)
267
378
  elif detail:
268
- for tool_dict in tools:
379
+ for tool_dict in (tools or []):
269
380
  print_tool(tool_dict, detail=True)
270
381
  else:
271
- for tool_dict in tools:
382
+ for tool_dict in (tools or []):
272
383
  click.echo(tool_dict.get('name', 'Unknown tool'))
273
384
 
274
385
 
275
386
  @list_mcp.command()
276
- @click.option('--url', default='http://localhost:9901/sse', show_default=True, help='MCP server URL')
387
+ @click.option(
388
+ '--url',
389
+ default='http://localhost:9901/mcp',
390
+ show_default=True,
391
+ help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
392
+ @click.option('--transport',
393
+ type=click.Choice(['sse', 'stdio', 'streamable-http']),
394
+ default='streamable-http',
395
+ show_default=True,
396
+ help='Type of client to use for ping')
397
+ @click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
398
+ @click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
399
+ @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
277
400
  @click.option('--timeout', default=60, show_default=True, help='Timeout in seconds for ping request')
278
401
  @click.option('--json-output', is_flag=True, help='Output ping result in JSON format')
279
- def ping(url: str, timeout: int, json_output: bool) -> None:
402
+ def ping(url: str,
403
+ transport: str,
404
+ command: str | None,
405
+ args: str | None,
406
+ env: str | None,
407
+ timeout: int,
408
+ json_output: bool) -> None:
280
409
  """Ping an MCP server to check if it's responsive.
281
410
 
282
411
  This command sends a ping request to the MCP server and measures the response time.
283
412
  It's useful for health checks and monitoring server availability.
284
413
 
285
414
  Args:
286
- url (str): MCP server URL to ping (default: http://localhost:9901/sse)
415
+ url (str): MCP server URL to ping (default: http://localhost:9901/mcp)
287
416
  timeout (int): Timeout in seconds for the ping request (default: 60)
288
417
  json_output (bool): Whether to output the result in JSON format
289
418
 
290
419
  Examples:
291
420
  nat info mcp ping # Ping default server
292
- nat info mcp ping --url http://custom-server:9901/sse # Ping custom server
421
+ nat info mcp ping --url http://custom-server:9901/mcp # Ping custom server
293
422
  nat info mcp ping --timeout 10 # Use 10 second timeout
294
423
  nat info mcp ping --json-output # Get JSON format output
295
424
  """
296
- result = asyncio.run(ping_mcp_server(url, timeout))
425
+ # Validate combinations similar to parent command
426
+ if not validate_transport_cli_args(transport, command, args, env):
427
+ return
428
+
429
+ stdio_args = args.split() if args else []
430
+ stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
431
+
432
+ result = asyncio.run(ping_mcp_server(url, timeout, transport, command, stdio_args, stdio_env))
297
433
 
298
434
  if json_output:
299
435
  click.echo(result.model_dump_json(indent=2))
nat/cli/commands/start.py CHANGED
@@ -102,12 +102,24 @@ class StartCommandGroup(click.Group):
102
102
  raise ValueError(f"Invalid field '{name}'.Unions are only supported for optional parameters.")
103
103
 
104
104
  # Handle the types
105
- if (issubclass(decomposed_type.root, Path)):
105
+ # Literal[...] -> map to click.Choice([...])
106
+ if (decomposed_type.origin is typing.Literal):
107
+ # typing.get_args returns the literal values; ensure they are strings for Click
108
+ literal_values = [str(v) for v in decomposed_type.args]
109
+ param_type = click.Choice(literal_values)
110
+
111
+ elif (issubclass(decomposed_type.root, Path)):
106
112
  param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
107
113
 
108
114
  elif (issubclass(decomposed_type.root, (list, tuple, set))):
109
115
  if (len(decomposed_type.args) == 1):
110
- param_type = decomposed_type.args[0]
116
+ inner = DecomposedType(decomposed_type.args[0])
117
+ # Support containers of Literal values -> multiple Choice
118
+ if (inner.origin is typing.Literal):
119
+ literal_values = [str(v) for v in inner.args]
120
+ param_type = click.Choice(literal_values)
121
+ else:
122
+ param_type = inner.root
111
123
  else:
112
124
  param_type = None
113
125
 
@@ -22,8 +22,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
22
22
 
23
23
  # The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
24
24
  # regex patterns
25
- _NVIDIA_NEMOTRON_REGEX = re.compile(r"^nvidia/nvidia.*nemotron", re.IGNORECASE)
26
- _LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
25
+ _NEMOTRON_REGEX = re.compile(r"^nvidia/(llama|nvidia).*nemotron", re.IGNORECASE)
27
26
  _MODEL_KEYS = ("model_name", "model", "azure_deployment")
28
27
 
29
28
 
@@ -33,7 +32,7 @@ class ThinkingMixin(
33
32
  field_name="thinking",
34
33
  default_if_supported=None,
35
34
  keys=_MODEL_KEYS,
36
- supported=(_NVIDIA_NEMOTRON_REGEX, _LLAMA_NEMOTRON_REGEX),
35
+ supported=(_NEMOTRON_REGEX, ),
37
36
  ):
38
37
  """
39
38
  Mixin class for thinking configuration. Only supported on Nemotron models.
@@ -52,7 +51,8 @@ class ThinkingMixin(
52
51
  """
53
52
  Returns the system prompt to use for thinking.
54
53
  For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
55
- For Llama Nemotron, returns "detailed thinking on" if enabled, else "detailed thinking off".
54
+ For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
55
+ For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
56
56
  If thinking is not supported on the model, returns None.
57
57
 
58
58
  Returns:
@@ -60,9 +60,28 @@ class ThinkingMixin(
60
60
  """
61
61
  if self.thinking is None:
62
62
  return None
63
+
63
64
  for key in _MODEL_KEYS:
64
- if hasattr(self, key):
65
- if _NVIDIA_NEMOTRON_REGEX.match(getattr(self, key)):
66
- return "/think" if self.thinking else "/no_think"
67
- elif _LLAMA_NEMOTRON_REGEX.match(getattr(self, key)):
65
+ model = getattr(self, key, None)
66
+ if not isinstance(model, str) or model is None:
67
+ continue
68
+
69
+ # Normalize name to reduce checks
70
+ model = model.lower().translate(str.maketrans("_.", "--"))
71
+
72
+ if model.startswith("nvidia/nvidia"):
73
+ return "/think" if self.thinking else "/no_think"
74
+
75
+ if model.startswith("nvidia/llama"):
76
+ if "v1-0" in model or "v1-1" in model:
68
77
  return f"detailed thinking {'on' if self.thinking else 'off'}"
78
+
79
+ if "v1-5" in model:
80
+ # v1.5 models are updated to use the /think and /no_think system prompts
81
+ return "/think" if self.thinking else "/no_think"
82
+
83
+ # Assume any other model is a newer model that uses the /think and /no_think system prompts
84
+ return "/think" if self.thinking else "/no_think"
85
+
86
+ # Unknown model
87
+ return None
@@ -13,6 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from typing import Literal
17
+
16
18
  from pydantic import Field
17
19
 
18
20
  from nat.data_models.front_end import FrontEndBaseConfig
@@ -32,5 +34,8 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
32
34
  log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
33
35
  tool_names: list[str] = Field(default_factory=list,
34
36
  description="The list of tools MCP server will expose (default: all tools)")
37
+ transport: Literal["sse", "streamable-http"] = Field(
38
+ default="streamable-http",
39
+ description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
35
40
  runner_class: str | None = Field(
36
41
  default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
@@ -77,5 +77,11 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
77
77
  # Add routes through the worker (includes health endpoint and function registration)
78
78
  await worker.add_routes(mcp, builder)
79
79
 
80
- # Start the MCP server
81
- await mcp.run_sse_async()
80
+ # Start the MCP server with configurable transport
81
+ # streamable-http is the default, but users can choose sse if preferred
82
+ if self.front_end_config.transport == "sse":
83
+ logger.info("Starting MCP server with SSE endpoint at /sse")
84
+ await mcp.run_sse_async()
85
+ else: # streamable-http
86
+ logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
87
+ await mcp.run_streamable_http_async()
@@ -134,9 +134,9 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
134
134
  logger.debug("Skipping function %s as it's not in tool_names", function_name)
135
135
  functions = filtered_functions
136
136
 
137
- # Register each function with MCP
137
+ # Register each function with MCP, passing workflow context for observability
138
138
  for function_name, function in functions.items():
139
- register_function_with_mcp(mcp, function_name, function)
139
+ register_function_with_mcp(mcp, function_name, function, workflow)
140
140
 
141
141
  # Add a simple fallback function if no functions were found
142
142
  if not functions: