nvidia-nat 1.3.0a20250828__py3-none-any.whl → 1.3.0a20250830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. nat/agent/base.py +6 -1
  2. nat/agent/react_agent/agent.py +46 -38
  3. nat/agent/react_agent/register.py +7 -2
  4. nat/agent/rewoo_agent/agent.py +16 -30
  5. nat/agent/rewoo_agent/register.py +3 -3
  6. nat/agent/tool_calling_agent/agent.py +9 -19
  7. nat/agent/tool_calling_agent/register.py +2 -2
  8. nat/builder/eval_builder.py +2 -2
  9. nat/builder/function.py +8 -8
  10. nat/builder/workflow.py +6 -2
  11. nat/builder/workflow_builder.py +21 -24
  12. nat/cli/cli_utils/config_override.py +1 -1
  13. nat/cli/commands/info/list_channels.py +1 -1
  14. nat/cli/commands/info/list_mcp.py +183 -47
  15. nat/cli/commands/registry/publish.py +2 -2
  16. nat/cli/commands/registry/pull.py +2 -2
  17. nat/cli/commands/registry/remove.py +2 -2
  18. nat/cli/commands/registry/search.py +1 -1
  19. nat/cli/commands/start.py +15 -3
  20. nat/cli/commands/uninstall.py +1 -1
  21. nat/cli/commands/workflow/workflow_commands.py +4 -4
  22. nat/data_models/discovery_metadata.py +4 -4
  23. nat/data_models/thinking_mixin.py +27 -8
  24. nat/eval/evaluate.py +6 -6
  25. nat/eval/intermediate_step_adapter.py +1 -1
  26. nat/eval/rag_evaluator/evaluate.py +2 -2
  27. nat/eval/rag_evaluator/register.py +1 -1
  28. nat/eval/remote_workflow.py +3 -3
  29. nat/eval/swe_bench_evaluator/evaluate.py +5 -5
  30. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  31. nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
  32. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
  33. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  34. nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
  35. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  36. nat/front_ends/fastapi/message_handler.py +2 -2
  37. nat/front_ends/fastapi/message_validator.py +8 -10
  38. nat/front_ends/fastapi/response_helpers.py +4 -4
  39. nat/front_ends/fastapi/step_adaptor.py +1 -1
  40. nat/front_ends/mcp/mcp_front_end_config.py +5 -0
  41. nat/front_ends/mcp/mcp_front_end_plugin.py +8 -2
  42. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +2 -2
  43. nat/front_ends/mcp/tool_converter.py +40 -13
  44. nat/observability/exporter/base_exporter.py +1 -1
  45. nat/observability/exporter/processing_exporter.py +8 -9
  46. nat/observability/exporter_manager.py +5 -5
  47. nat/observability/mixin/file_mixin.py +7 -7
  48. nat/observability/processor/batching_processor.py +4 -6
  49. nat/observability/register.py +3 -1
  50. nat/profiler/calc/calc_runner.py +3 -4
  51. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  52. nat/profiler/callbacks/langchain_callback_handler.py +5 -5
  53. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  54. nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
  55. nat/profiler/profile_runner.py +1 -1
  56. nat/profiler/utils.py +1 -1
  57. nat/registry_handlers/local/local_handler.py +2 -2
  58. nat/registry_handlers/package_utils.py +1 -1
  59. nat/registry_handlers/pypi/pypi_handler.py +3 -3
  60. nat/registry_handlers/rest/rest_handler.py +4 -4
  61. nat/retriever/milvus/retriever.py +1 -1
  62. nat/retriever/nemo_retriever/retriever.py +1 -1
  63. nat/runtime/loader.py +1 -1
  64. nat/runtime/runner.py +2 -2
  65. nat/settings/global_settings.py +1 -1
  66. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  67. nat/tool/mcp/{mcp_client.py → mcp_client_base.py} +197 -46
  68. nat/tool/mcp/mcp_client_impl.py +229 -0
  69. nat/tool/mcp/mcp_tool.py +79 -42
  70. nat/tool/nvidia_rag.py +1 -1
  71. nat/tool/register.py +1 -0
  72. nat/tool/retriever.py +3 -2
  73. nat/utils/io/yaml_tools.py +1 -1
  74. nat/utils/reactive/observer.py +2 -2
  75. nat/utils/settings/global_settings.py +2 -2
  76. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/METADATA +3 -3
  77. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/RECORD +82 -81
  78. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/WHEEL +0 -0
  79. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/entry_points.txt +0 -0
  80. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  81. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/licenses/LICENSE.md +0 -0
  82. {nvidia_nat-1.3.0a20250828.dist-info → nvidia_nat-1.3.0a20250830.dist-info}/top_level.txt +0 -0
@@ -23,14 +23,36 @@ import click
23
23
  from pydantic import BaseModel
24
24
 
25
25
  from nat.tool.mcp.exceptions import MCPError
26
- from nat.tool.mcp.mcp_client import MCPBuilder
27
26
  from nat.utils.exception_handlers.mcp import format_mcp_error
28
27
 
29
28
  # Suppress verbose logs from mcp.client.sse and httpx
30
29
  logging.getLogger("mcp.client.sse").setLevel(logging.WARNING)
31
30
  logging.getLogger("httpx").setLevel(logging.WARNING)
32
31
 
33
- logger = logging.getLogger(__name__)
32
+
33
+ def validate_transport_cli_args(transport: str, command: str | None, args: str | None, env: str | None) -> bool:
34
+ """
35
+ Validate transport and parameter combinations, returning False if invalid.
36
+
37
+ Args:
38
+ transport: The transport type ('sse', 'stdio', or 'streamable-http')
39
+ command: Command for stdio transport
40
+ args: Arguments for stdio transport
41
+ env: Environment variables for stdio transport
42
+
43
+ Returns:
44
+ bool: True if valid, False if invalid (error message already displayed)
45
+ """
46
+ if transport == 'stdio':
47
+ if not command:
48
+ click.echo("--command is required when using stdio client type", err=True)
49
+ return False
50
+ elif transport in ['sse', 'streamable-http']:
51
+ if command or args or env:
52
+ click.echo("--command, --args, and --env are not allowed when using sse or streamable-http client type",
53
+ err=True)
54
+ return False
55
+ return True
34
56
 
35
57
 
36
58
  class MCPPingResult(BaseModel):
@@ -64,12 +86,20 @@ def format_tool(tool: Any) -> dict[str, str | None]:
64
86
  description = getattr(tool, 'description', '')
65
87
  input_schema = getattr(tool, 'input_schema', None) or getattr(tool, 'inputSchema', None)
66
88
 
67
- schema_str = None
68
- if input_schema:
69
- if hasattr(input_schema, "schema_json"):
70
- schema_str = input_schema.schema_json(indent=2)
71
- else:
72
- schema_str = str(input_schema)
89
+ # Normalize schema to JSON string
90
+ if input_schema is None:
91
+ return {
92
+ "name": name,
93
+ "description": description,
94
+ "input_schema": None,
95
+ }
96
+ elif hasattr(input_schema, "schema_json"):
97
+ schema_str = input_schema.schema_json(indent=2)
98
+ elif isinstance(input_schema, dict):
99
+ schema_str = json.dumps(input_schema, indent=2)
100
+ else:
101
+ # Final fallback: attempt to dump stringified version wrapped as JSON string
102
+ schema_str = json.dumps({"raw": str(input_schema)}, indent=2)
73
103
 
74
104
  return {
75
105
  "name": name,
@@ -100,8 +130,8 @@ def print_tool(tool_dict: dict[str, str | None], detail: bool = False) -> None:
100
130
  click.echo("-" * 60)
101
131
 
102
132
 
103
- async def list_tools_and_schemas(url: str, tool_name: str | None = None) -> list[dict[str, str | None]]:
104
- """List MCP tools using MCPBuilder with structured exception handling.
133
+ async def list_tools_and_schemas(command, url, tool_name=None, transport='sse', args=None, env=None):
134
+ """List MCP tools using NAT MCPClient with structured exception handling.
105
135
 
106
136
  Args:
107
137
  url (str): MCP server URL to connect to
@@ -115,20 +145,35 @@ async def list_tools_and_schemas(url: str, tool_name: str | None = None) -> list
115
145
  Raises:
116
146
  MCPError: Caught internally and logged, returns empty list instead
117
147
  """
118
- builder = MCPBuilder(url=url)
148
+ from nat.tool.mcp.mcp_client_base import MCPSSEClient
149
+ from nat.tool.mcp.mcp_client_base import MCPStdioClient
150
+ from nat.tool.mcp.mcp_client_base import MCPStreamableHTTPClient
151
+
152
+ if args is None:
153
+ args = []
154
+
119
155
  try:
120
- if tool_name:
121
- tool = await builder.get_tool(tool_name)
122
- return [format_tool(tool)]
123
- tools = await builder.get_tools()
124
- return [format_tool(tool) for tool in tools.values()]
156
+ if transport == 'stdio':
157
+ client = MCPStdioClient(command=command, args=args, env=env)
158
+ elif transport == 'streamable-http':
159
+ client = MCPStreamableHTTPClient(url=url)
160
+ else: # sse
161
+ client = MCPSSEClient(url=url)
162
+
163
+ async with client:
164
+ if tool_name:
165
+ tool = await client.get_tool(tool_name)
166
+ return [format_tool(tool)]
167
+ else:
168
+ tools = await client.get_tools()
169
+ return [format_tool(tool) for tool in tools.values()]
125
170
  except MCPError as e:
126
171
  format_mcp_error(e, include_traceback=False)
127
172
  return []
128
173
 
129
174
 
130
- async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict[str, str | None]]:
131
- """List MCP tools using direct MCP protocol with exception conversion.
175
+ async def list_tools_direct(command, url, tool_name=None, transport='sse', args=None, env=None):
176
+ """List MCP tools using direct MCP protocol with structured exception handling.
132
177
 
133
178
  Bypasses MCPBuilder and uses raw MCP ClientSession and SSE client directly.
134
179
  Converts raw exceptions to structured MCPErrors for consistent user experience.
@@ -147,25 +192,51 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
147
192
  This function handles ExceptionGroup by extracting the most relevant exception
148
193
  and converting it to MCPError for consistent error reporting.
149
194
  """
195
+ if args is None:
196
+ args = []
150
197
  from mcp import ClientSession
151
198
  from mcp.client.sse import sse_client
199
+ from mcp.client.stdio import StdioServerParameters
200
+ from mcp.client.stdio import stdio_client
201
+ from mcp.client.streamable_http import streamablehttp_client
152
202
 
153
203
  try:
154
- async with sse_client(url=url) as (read, write):
204
+ if transport == 'stdio':
205
+
206
+ def get_stdio_client():
207
+ return stdio_client(server=StdioServerParameters(command=command, args=args, env=env))
208
+
209
+ client = get_stdio_client
210
+ elif transport == 'streamable-http':
211
+
212
+ def get_streamable_http_client():
213
+ return streamablehttp_client(url=url)
214
+
215
+ client = get_streamable_http_client
216
+ else:
217
+
218
+ def get_sse_client():
219
+ return sse_client(url=url)
220
+
221
+ client = get_sse_client
222
+
223
+ async with client() as ctx:
224
+ read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
155
225
  async with ClientSession(read, write) as session:
156
226
  await session.initialize()
157
227
  response = await session.list_tools()
158
228
 
159
- tools = []
160
- for tool in response.tools:
161
- if tool_name:
162
- if tool.name == tool_name:
163
- return [format_tool(tool)]
164
- else:
165
- tools.append(format_tool(tool))
166
- if tool_name and not tools:
167
- click.echo(f"[INFO] Tool '{tool_name}' not found.")
168
- return tools
229
+ tools = []
230
+ for tool in response.tools:
231
+ if tool_name:
232
+ if tool.name == tool_name:
233
+ tools.append(format_tool(tool))
234
+ else:
235
+ tools.append(format_tool(tool))
236
+
237
+ if tool_name and not tools:
238
+ click.echo(f"[INFO] Tool '{tool_name}' not found.")
239
+ return tools
169
240
  except Exception as e:
170
241
  # Convert raw exceptions to structured MCPError for consistency
171
242
  from nat.utils.exception_handlers.mcp import convert_to_mcp_error
@@ -181,7 +252,12 @@ async def list_tools_direct(url: str, tool_name: str | None = None) -> list[dict
181
252
  return []
182
253
 
183
254
 
184
- async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
255
+ async def ping_mcp_server(url: str,
256
+ timeout: int,
257
+ transport: str = 'streamable-http',
258
+ command: str | None = None,
259
+ args: list[str] | None = None,
260
+ env: dict[str, str] | None = None) -> MCPPingResult:
185
261
  """Ping an MCP server to check if it's responsive.
186
262
 
187
263
  Args:
@@ -193,18 +269,29 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
193
269
  """
194
270
  from mcp.client.session import ClientSession
195
271
  from mcp.client.sse import sse_client
272
+ from mcp.client.stdio import StdioServerParameters
273
+ from mcp.client.stdio import stdio_client
274
+ from mcp.client.streamable_http import streamablehttp_client
196
275
 
197
276
  async def _ping_operation():
198
- async with sse_client(url) as (read, write):
277
+ # Select transport
278
+ if transport == 'stdio':
279
+ stdio_args_local: list[str] = args or []
280
+ if not command:
281
+ raise RuntimeError("--command is required for stdio transport")
282
+ client_ctx = stdio_client(server=StdioServerParameters(command=command, args=stdio_args_local, env=env))
283
+ elif transport == 'sse':
284
+ client_ctx = sse_client(url)
285
+ else: # streamable-http
286
+ client_ctx = streamablehttp_client(url=url)
287
+
288
+ async with client_ctx as ctx:
289
+ read, write = (ctx[0], ctx[1]) if isinstance(ctx, tuple) else ctx
199
290
  async with ClientSession(read, write) as session:
200
- # Initialize the session
201
291
  await session.initialize()
202
292
 
203
- # Record start time just before ping
204
293
  start_time = time.time()
205
- # Send ping request
206
294
  await session.send_ping()
207
-
208
295
  end_time = time.time()
209
296
  response_time_ms = round((end_time - start_time) * 1000, 2)
210
297
 
@@ -226,12 +313,24 @@ async def ping_mcp_server(url: str, timeout: int) -> MCPPingResult:
226
313
 
227
314
  @click.group(invoke_without_command=True, help="List tool names (default), or show details with --detail or --tool.")
228
315
  @click.option('--direct', is_flag=True, help='Bypass MCPBuilder and use direct MCP protocol')
229
- @click.option('--url', default='http://localhost:9901/sse', show_default=True, help='MCP server URL')
316
+ @click.option(
317
+ '--url',
318
+ default='http://localhost:9901/mcp',
319
+ show_default=True,
320
+ help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
321
+ @click.option('--transport',
322
+ type=click.Choice(['sse', 'stdio', 'streamable-http']),
323
+ default='streamable-http',
324
+ show_default=True,
325
+ help='Type of client to use (default: streamable-http, backwards compatible with sse)')
326
+ @click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
327
+ @click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
328
+ @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
230
329
  @click.option('--tool', default=None, help='Get details for a specific tool by name')
231
330
  @click.option('--detail', is_flag=True, help='Show full details for all tools')
232
331
  @click.option('--json-output', is_flag=True, help='Output tool metadata in JSON format')
233
332
  @click.pass_context
234
- def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detail: bool, json_output: bool) -> None:
333
+ def list_mcp(ctx, direct, url, transport, command, args, env, tool, detail, json_output):
235
334
  """List MCP tool names (default) or show detailed tool information.
236
335
 
237
336
  Use --detail for full output including descriptions and input schemas.
@@ -242,7 +341,7 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
242
341
  Args:
243
342
  ctx (click.Context): Click context object for command invocation
244
343
  direct (bool): Whether to bypass MCPBuilder and use direct MCP protocol
245
- url (str): MCP server URL to connect to (default: http://localhost:9901/sse)
344
+ url (str): MCP server URL to connect to (default: http://localhost:9901/mcp)
246
345
  tool (str | None): Optional specific tool name to retrieve detailed info for
247
346
  detail (bool): Whether to show full details (description + schema) for all tools
248
347
  json_output (bool): Whether to output tool metadata in JSON format instead of text
@@ -256,44 +355,81 @@ def list_mcp(ctx: click.Context, direct: bool, url: str, tool: str | None, detai
256
355
  """
257
356
  if ctx.invoked_subcommand is not None:
258
357
  return
358
+
359
+ if not validate_transport_cli_args(transport, command, args, env):
360
+ return
361
+
362
+ if transport in ['sse', 'streamable-http']:
363
+ if not url:
364
+ click.echo("[ERROR] --url is required when using sse or streamable-http client type", err=True)
365
+ return
366
+
367
+ stdio_args = args.split() if args else []
368
+ stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
369
+
259
370
  fetcher = list_tools_direct if direct else list_tools_and_schemas
260
- tools = asyncio.run(fetcher(url, tool))
371
+ tools = asyncio.run(fetcher(command, url, tool, transport, stdio_args, stdio_env))
261
372
 
262
373
  if json_output:
263
374
  click.echo(json.dumps(tools, indent=2))
264
375
  elif tool:
265
- for tool_dict in tools:
376
+ for tool_dict in (tools or []):
266
377
  print_tool(tool_dict, detail=True)
267
378
  elif detail:
268
- for tool_dict in tools:
379
+ for tool_dict in (tools or []):
269
380
  print_tool(tool_dict, detail=True)
270
381
  else:
271
- for tool_dict in tools:
382
+ for tool_dict in (tools or []):
272
383
  click.echo(tool_dict.get('name', 'Unknown tool'))
273
384
 
274
385
 
275
386
  @list_mcp.command()
276
- @click.option('--url', default='http://localhost:9901/sse', show_default=True, help='MCP server URL')
387
+ @click.option(
388
+ '--url',
389
+ default='http://localhost:9901/mcp',
390
+ show_default=True,
391
+ help='MCP server URL (e.g. http://localhost:8080/mcp for streamable-http, http://localhost:8080/sse for sse)')
392
+ @click.option('--transport',
393
+ type=click.Choice(['sse', 'stdio', 'streamable-http']),
394
+ default='streamable-http',
395
+ show_default=True,
396
+ help='Type of client to use for ping')
397
+ @click.option('--command', help='For stdio: The command to run (e.g. mcp-server)')
398
+ @click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
399
+ @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
277
400
  @click.option('--timeout', default=60, show_default=True, help='Timeout in seconds for ping request')
278
401
  @click.option('--json-output', is_flag=True, help='Output ping result in JSON format')
279
- def ping(url: str, timeout: int, json_output: bool) -> None:
402
+ def ping(url: str,
403
+ transport: str,
404
+ command: str | None,
405
+ args: str | None,
406
+ env: str | None,
407
+ timeout: int,
408
+ json_output: bool) -> None:
280
409
  """Ping an MCP server to check if it's responsive.
281
410
 
282
411
  This command sends a ping request to the MCP server and measures the response time.
283
412
  It's useful for health checks and monitoring server availability.
284
413
 
285
414
  Args:
286
- url (str): MCP server URL to ping (default: http://localhost:9901/sse)
415
+ url (str): MCP server URL to ping (default: http://localhost:9901/mcp)
287
416
  timeout (int): Timeout in seconds for the ping request (default: 60)
288
417
  json_output (bool): Whether to output the result in JSON format
289
418
 
290
419
  Examples:
291
420
  nat info mcp ping # Ping default server
292
- nat info mcp ping --url http://custom-server:9901/sse # Ping custom server
421
+ nat info mcp ping --url http://custom-server:9901/mcp # Ping custom server
293
422
  nat info mcp ping --timeout 10 # Use 10 second timeout
294
423
  nat info mcp ping --json-output # Get JSON format output
295
424
  """
296
- result = asyncio.run(ping_mcp_server(url, timeout))
425
+ # Validate combinations similar to parent command
426
+ if not validate_transport_cli_args(transport, command, args, env):
427
+ return
428
+
429
+ stdio_args = args.split() if args else []
430
+ stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
431
+
432
+ result = asyncio.run(ping_mcp_server(url, timeout, transport, command, stdio_args, stdio_env))
297
433
 
298
434
  if json_output:
299
435
  click.echo(result.model_dump_json(indent=2))
@@ -40,7 +40,7 @@ async def publish_artifact(registry_handler_config: RegistryHandlerBaseConfig, p
40
40
  try:
41
41
  artifact = build_artifact(package_root=package_root)
42
42
  except Exception as e:
43
- logger.exception("Error building artifact: %s", e, exc_info=True)
43
+ logger.exception("Error building artifact: %s", e)
44
44
  return
45
45
  await stack.enter_async_context(registry_handler.publish(artifact=artifact))
46
46
 
@@ -82,7 +82,7 @@ def publish(channel: str, config_file: str, package_root: str) -> None:
82
82
  logger.error("Publish channel '%s' has not been configured.", channel)
83
83
  return
84
84
  except Exception as e:
85
- logger.exception("Error loading user settings: %s", e, exc_info=True)
85
+ logger.exception("Error loading user settings: %s", e)
86
86
  return
87
87
 
88
88
  asyncio.run(publish_artifact(registry_handler_config=publish_channel_config, package_root=package_root))
@@ -66,7 +66,7 @@ async def pull_artifact(registry_handler_config: RegistryHandlerBaseConfig, pack
66
66
  validated_packages = PullRequestPackages(packages=package_list)
67
67
 
68
68
  except Exception as e:
69
- logger.exception("Error processing package names: %s", e, exc_info=True)
69
+ logger.exception("Error processing package names: %s", e)
70
70
  return
71
71
 
72
72
  await stack.enter_async_context(registry_handler.pull(packages=validated_packages))
@@ -112,7 +112,7 @@ def pull(channel: str, config_file: str, packages: str) -> None:
112
112
  logger.error("Pull channel '%s' has not been configured.", channel)
113
113
  return
114
114
  except Exception as e:
115
- logger.exception("Error loading user settings: %s", e, exc_info=True)
115
+ logger.exception("Error loading user settings: %s", e)
116
116
  return
117
117
 
118
118
  asyncio.run(pull_artifact(pull_channel_config, packages))
@@ -41,7 +41,7 @@ async def remove_artifact(registry_handler_config: RegistryHandlerBaseConfig, pa
41
41
  try:
42
42
  package_name_list = PackageNameVersionList(**{"packages": packages})
43
43
  except Exception as e:
44
- logger.exception("Invalid package format: '%s'", e, exc_info=True)
44
+ logger.exception("Invalid package format: '%s'", e)
45
45
 
46
46
  await stack.enter_async_context(registry_handler.remove(packages=package_name_list))
47
47
 
@@ -102,7 +102,7 @@ def remove(channel: str, config_file: str, packages: str) -> None:
102
102
  logger.error("Remove channel '%s' has not been configured.", channel)
103
103
  return
104
104
  except Exception as e:
105
- logger.exception("Error loading user settings: %s", e, exc_info=True)
105
+ logger.exception("Error loading user settings: %s", e)
106
106
  return
107
107
 
108
108
  asyncio.run(remove_artifact(registry_handler_config=remove_channel_config, packages=packages_versions))
@@ -140,7 +140,7 @@ def search(config_file: str,
140
140
  logger.error("Search channel '%s' has not been configured.", channel)
141
141
  return
142
142
  except Exception as e:
143
- logger.exception("Error loading user settings: %s", e, exc_info=True)
143
+ logger.exception("Error loading user settings: %s", e)
144
144
  return
145
145
 
146
146
  asyncio.run(
nat/cli/commands/start.py CHANGED
@@ -102,12 +102,24 @@ class StartCommandGroup(click.Group):
102
102
  raise ValueError(f"Invalid field '{name}'.Unions are only supported for optional parameters.")
103
103
 
104
104
  # Handle the types
105
- if (issubclass(decomposed_type.root, Path)):
105
+ # Literal[...] -> map to click.Choice([...])
106
+ if (decomposed_type.origin is typing.Literal):
107
+ # typing.get_args returns the literal values; ensure they are strings for Click
108
+ literal_values = [str(v) for v in decomposed_type.args]
109
+ param_type = click.Choice(literal_values)
110
+
111
+ elif (issubclass(decomposed_type.root, Path)):
106
112
  param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
107
113
 
108
114
  elif (issubclass(decomposed_type.root, (list, tuple, set))):
109
115
  if (len(decomposed_type.args) == 1):
110
- param_type = decomposed_type.args[0]
116
+ inner = DecomposedType(decomposed_type.args[0])
117
+ # Support containers of Literal values -> multiple Choice
118
+ if (inner.origin is typing.Literal):
119
+ literal_values = [str(v) for v in inner.args]
120
+ param_type = click.Choice(literal_values)
121
+ else:
122
+ param_type = inner.root
111
123
  else:
112
124
  param_type = None
113
125
 
@@ -224,7 +236,7 @@ class StartCommandGroup(click.Group):
224
236
  return asyncio.run(run_plugin())
225
237
 
226
238
  except Exception as e:
227
- logger.error("Failed to initialize workflow", exc_info=True)
239
+ logger.error("Failed to initialize workflow")
228
240
  raise click.ClickException(str(e)) from e
229
241
 
230
242
  def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
@@ -44,7 +44,7 @@ async def uninstall_packages(packages: list[dict[str, str]]) -> None:
44
44
  try:
45
45
  package_name_list = PackageNameVersionList(**{"packages": packages})
46
46
  except Exception as e:
47
- logger.exception("Error validating package format: %s", e, exc_info=True)
47
+ logger.exception("Error validating package format: %s", e)
48
48
  return
49
49
 
50
50
  async with AsyncExitStack() as stack:
@@ -97,7 +97,7 @@ def find_package_root(package_name: str) -> Path | None:
97
97
  try:
98
98
  info = json.loads(direct_url)
99
99
  except json.JSONDecodeError:
100
- logger.error("Malformed direct_url.json for package: %s", package_name)
100
+ logger.exception("Malformed direct_url.json for package: %s", package_name)
101
101
  return None
102
102
 
103
103
  if not info.get("dir_info", {}).get("editable"):
@@ -271,7 +271,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
271
271
 
272
272
  click.echo(f"Workflow '{workflow_name}' created successfully in '{new_workflow_dir}'.")
273
273
  except Exception as e:
274
- logger.exception("An error occurred while creating the workflow: %s", e, exc_info=True)
274
+ logger.exception("An error occurred while creating the workflow: %s", e)
275
275
  click.echo(f"An error occurred while creating the workflow: {e}")
276
276
 
277
277
 
@@ -307,7 +307,7 @@ def reinstall_command(workflow_name):
307
307
 
308
308
  click.echo(f"Workflow '{workflow_name}' reinstalled successfully.")
309
309
  except Exception as e:
310
- logger.exception("An error occurred while reinstalling the workflow: %s", e, exc_info=True)
310
+ logger.exception("An error occurred while reinstalling the workflow: %s", e)
311
311
  click.echo(f"An error occurred while reinstalling the workflow: {e}")
312
312
 
313
313
 
@@ -354,7 +354,7 @@ def delete_command(workflow_name: str):
354
354
 
355
355
  click.echo(f"Workflow '{workflow_name}' deleted successfully.")
356
356
  except Exception as e:
357
- logger.exception("An error occurred while deleting the workflow: %s", e, exc_info=True)
357
+ logger.exception("An error occurred while deleting the workflow: %s", e)
358
358
  click.echo(f"An error occurred while deleting the workflow: {e}")
359
359
 
360
360
 
@@ -177,7 +177,7 @@ class DiscoveryMetadata(BaseModel):
177
177
  logger.warning("Package metadata not found for %s", distro_name)
178
178
  version = ""
179
179
  except Exception as e:
180
- logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e, exc_info=True)
180
+ logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
181
181
  return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
182
182
 
183
183
  description = generate_config_type_docs(config_type=config_type)
@@ -217,7 +217,7 @@ class DiscoveryMetadata(BaseModel):
217
217
  logger.warning("Package metadata not found for %s", distro_name)
218
218
  version = ""
219
219
  except Exception as e:
220
- logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e, exc_info=True)
220
+ logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e)
221
221
  return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
222
222
 
223
223
  if isinstance(wrapper_type, LLMFrameworkEnum):
@@ -252,7 +252,7 @@ class DiscoveryMetadata(BaseModel):
252
252
  description = ""
253
253
  package_version = package_version or ""
254
254
  except Exception as e:
255
- logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e, exc_info=True)
255
+ logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e)
256
256
  return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
257
257
 
258
258
  return DiscoveryMetadata(package=package_name,
@@ -290,7 +290,7 @@ class DiscoveryMetadata(BaseModel):
290
290
  logger.warning("Package metadata not found for %s", distro_name)
291
291
  version = ""
292
292
  except Exception as e:
293
- logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e, exc_info=True)
293
+ logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
294
294
  return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
295
295
 
296
296
  wrapper_type = wrapper_type.value if isinstance(wrapper_type, LLMFrameworkEnum) else wrapper_type
@@ -22,8 +22,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
22
22
 
23
23
  # The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
24
24
  # regex patterns
25
- _NVIDIA_NEMOTRON_REGEX = re.compile(r"^nvidia/nvidia.*nemotron", re.IGNORECASE)
26
- _LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
25
+ _NEMOTRON_REGEX = re.compile(r"^nvidia/(llama|nvidia).*nemotron", re.IGNORECASE)
27
26
  _MODEL_KEYS = ("model_name", "model", "azure_deployment")
28
27
 
29
28
 
@@ -33,7 +32,7 @@ class ThinkingMixin(
33
32
  field_name="thinking",
34
33
  default_if_supported=None,
35
34
  keys=_MODEL_KEYS,
36
- supported=(_NVIDIA_NEMOTRON_REGEX, _LLAMA_NEMOTRON_REGEX),
35
+ supported=(_NEMOTRON_REGEX, ),
37
36
  ):
38
37
  """
39
38
  Mixin class for thinking configuration. Only supported on Nemotron models.
@@ -52,7 +51,8 @@ class ThinkingMixin(
52
51
  """
53
52
  Returns the system prompt to use for thinking.
54
53
  For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
55
- For Llama Nemotron, returns "detailed thinking on" if enabled, else "detailed thinking off".
54
+ For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
55
+ For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
56
56
  If thinking is not supported on the model, returns None.
57
57
 
58
58
  Returns:
@@ -60,9 +60,28 @@ class ThinkingMixin(
60
60
  """
61
61
  if self.thinking is None:
62
62
  return None
63
+
63
64
  for key in _MODEL_KEYS:
64
- if hasattr(self, key):
65
- if _NVIDIA_NEMOTRON_REGEX.match(getattr(self, key)):
66
- return "/think" if self.thinking else "/no_think"
67
- elif _LLAMA_NEMOTRON_REGEX.match(getattr(self, key)):
65
+ model = getattr(self, key, None)
66
+ if not isinstance(model, str) or model is None:
67
+ continue
68
+
69
+ # Normalize name to reduce checks
70
+ model = model.lower().translate(str.maketrans("_.", "--"))
71
+
72
+ if model.startswith("nvidia/nvidia"):
73
+ return "/think" if self.thinking else "/no_think"
74
+
75
+ if model.startswith("nvidia/llama"):
76
+ if "v1-0" in model or "v1-1" in model:
68
77
  return f"detailed thinking {'on' if self.thinking else 'off'}"
78
+
79
+ if "v1-5" in model:
80
+ # v1.5 models are updated to use the /think and /no_think system prompts
81
+ return "/think" if self.thinking else "/no_think"
82
+
83
+ # Assume any other model is a newer model that uses the /think and /no_think system prompts
84
+ return "/think" if self.thinking else "/no_think"
85
+
86
+ # Unknown model
87
+ return None
nat/eval/evaluate.py CHANGED
@@ -168,17 +168,17 @@ class EvaluationRun:
168
168
  intermediate_future = None
169
169
 
170
170
  try:
171
-
172
171
  # Start usage stats and intermediate steps collection in parallel
173
172
  intermediate_future = pull_intermediate()
174
173
  runner_result = runner.result()
175
174
  base_output = await runner_result
176
175
  intermediate_steps = await intermediate_future
177
176
  except NotImplementedError as e:
177
+ logger.error("Failed to run the workflow: %s", e)
178
178
  # raise original error
179
- raise e
179
+ raise
180
180
  except Exception as e:
181
- logger.exception("Failed to run the workflow: %s", e, exc_info=True)
181
+ logger.exception("Failed to run the workflow: %s", e)
182
182
  # stop processing if a workflow error occurs
183
183
  self.workflow_interrupted = True
184
184
 
@@ -317,7 +317,7 @@ class EvaluationRun:
317
317
  logger.info("Deleting old job directory: %s", dir_to_delete)
318
318
  shutil.rmtree(dir_to_delete)
319
319
  except Exception as e:
320
- logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
320
+ logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)
321
321
 
322
322
  def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
323
323
  workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
@@ -367,7 +367,7 @@ class EvaluationRun:
367
367
 
368
368
  await self.weave_eval.alog_score(eval_output, evaluator_name)
369
369
  except Exception as e:
370
- logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e, exc_info=True)
370
+ logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)
371
371
 
372
372
  async def run_evaluators(self, evaluators: dict[str, Any]):
373
373
  """Run all configured evaluators asynchronously."""
@@ -380,7 +380,7 @@ class EvaluationRun:
380
380
  try:
381
381
  await asyncio.gather(*tasks)
382
382
  except Exception as e:
383
- logger.exception("An error occurred while running evaluators: %s", e, exc_info=True)
383
+ logger.error("An error occurred while running evaluators: %s", e)
384
384
  raise
385
385
  finally:
386
386
  # Finish prediction loggers in Weave