nvidia-nat 1.3.0a20250926__py3-none-any.whl → 1.3.0a20250929__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,9 +13,14 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import json
17
18
  import logging
19
+ import re
18
20
  from json import JSONDecodeError
21
+ from typing import Dict
22
+ from typing import List
23
+ from typing import Tuple
19
24
 
20
25
  from langchain_core.callbacks.base import AsyncCallbackHandler
21
26
  from langchain_core.language_models import BaseChatModel
@@ -50,15 +55,21 @@ class ReWOOGraphState(BaseModel):
50
55
  default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
51
56
  steps: AIMessage = Field(
52
57
  default_factory=lambda: AIMessage(content="")) # the steps to solve the task, parsed from the plan
58
+ # New fields for parallel execution support
59
+ evidence_map: Dict[str, Dict] = Field(default_factory=dict) # mapping from placeholders to step info
60
+ execution_levels: List[List[str]] = Field(default_factory=list) # levels for parallel execution
61
+ current_level: int = Field(default=0) # current execution level
53
62
  intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict) # the intermediate results of each step
54
63
  result: AIMessage = Field(
55
64
  default_factory=lambda: AIMessage(content="")) # the final result of the task, generated by the solver
56
65
 
57
66
 
58
67
  class ReWOOAgentGraph(BaseAgent):
59
- """Configurable LangGraph ReWOO Agent. A ReWOO Agent performs reasoning by interacting with other objects or tools
60
- and utilizes their outputs to make decisions. Supports retrying on output parsing errors. Argument
61
- "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
68
+ """Configurable ReWOO Agent.
69
+
70
+ Args:
71
+ detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
72
+ """
62
73
 
63
74
  def __init__(self,
64
75
  llm: BaseChatModel,
@@ -109,16 +120,27 @@ class ReWOOAgentGraph(BaseAgent):
109
120
  raise
110
121
 
111
122
  @staticmethod
112
- def _get_current_step(state: ReWOOGraphState) -> int:
113
- steps = state.steps.content
114
- if len(steps) == 0:
115
- raise RuntimeError('No steps received in ReWOOGraphState')
123
+ def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
124
+ """
125
+ Get the current execution level and whether it's complete.
126
+ :param state: The ReWOO graph state.
127
+ :return: Tuple of (current_level, is_complete). Level -1 means all execution is complete.
128
+ :rtype: tuple[int, bool]
129
+ """
130
+ if not state.execution_levels:
131
+ return -1, True
132
+
133
+ current_level = state.current_level
116
134
 
117
- if len(state.intermediate_results) == len(steps):
118
- # all steps are done
119
- return -1
135
+ # Check if we've completed all levels
136
+ if current_level >= len(state.execution_levels):
137
+ return -1, True
120
138
 
121
- return len(state.intermediate_results)
139
+ # Check if current level is complete
140
+ current_level_placeholders = state.execution_levels[current_level]
141
+ level_complete = all(placeholder in state.intermediate_results for placeholder in current_level_placeholders)
142
+
143
+ return current_level, level_complete
122
144
 
123
145
  @staticmethod
124
146
  def _parse_planner_output(planner_output: str) -> AIMessage:
@@ -130,6 +152,75 @@ class ReWOOAgentGraph(BaseAgent):
130
152
 
131
153
  return AIMessage(content=steps)
132
154
 
155
+ @staticmethod
156
+ def _parse_planner_dependencies(steps: List[Dict]) -> Tuple[Dict[str, Dict], List[List[str]]]:
157
+ """
158
+ Parse planner steps to identify dependencies and create execution levels for parallel processing.
159
+ This creates a dependency map and identifies which evidence placeholders can be executed in parallel.
160
+
161
+ :param steps: List of plan steps from the planner.
162
+ :type steps: List[Dict]
163
+ :return: A mapping from evidence placeholders to step info and execution levels for parallel processing.
164
+ :rtype: Tuple[Dict[str, Dict], List[List[str]]]
165
+ """
166
+ evidences = {}
167
+ dependence = {}
168
+
169
+ # First pass: collect all evidence placeholders and their info
170
+ for step in steps:
171
+ if "evidence" not in step:
172
+ continue
173
+
174
+ evidence_info = step["evidence"]
175
+ placeholder = evidence_info.get("placeholder", "")
176
+
177
+ if placeholder:
178
+ # Store the complete step info for this evidence
179
+ evidences[placeholder] = {"plan": step.get("plan", ""), "evidence": evidence_info}
180
+
181
+ # Second pass: find dependencies now that we have all placeholders
182
+ for step in steps:
183
+ if "evidence" not in step:
184
+ continue
185
+
186
+ evidence_info = step["evidence"]
187
+ placeholder = evidence_info.get("placeholder", "")
188
+ tool_input = evidence_info.get("tool_input", "")
189
+
190
+ if placeholder:
191
+ # Find dependencies by looking for other placeholders in tool_input
192
+ dependence[placeholder] = []
193
+
194
+ # Convert tool_input to string to search for placeholders
195
+ tool_input_str = str(tool_input)
196
+ for var in re.findall(r"#E\d+", tool_input_str):
197
+ if var in evidences and var != placeholder:
198
+ dependence[placeholder].append(var)
199
+
200
+ # Create execution levels using topological sort
201
+ levels = []
202
+ remaining = dict(dependence)
203
+
204
+ while remaining:
205
+ # Find items with no dependencies (can be executed in parallel)
206
+ ready = [placeholder for placeholder, deps in remaining.items() if not deps]
207
+
208
+ if not ready:
209
+ raise ValueError("Circular dependency detected in planner output")
210
+
211
+ levels.append(ready)
212
+
213
+ # Remove completed items from remaining
214
+ for placeholder in ready:
215
+ remaining.pop(placeholder)
216
+
217
+ # Remove completed items from other dependencies
218
+ # for placeholder in remaining.items():
219
+ # remaining[placeholder] = [dep for dep in remaining[placeholder] if dep not in ready]
220
+ for ph, deps in list(remaining.items()):
221
+ remaining[ph] = [dep for dep in deps if dep not in ready]
222
+ return evidences, levels
223
+
133
224
  @staticmethod
134
225
  def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:
135
226
 
@@ -201,119 +292,204 @@ class ReWOOAgentGraph(BaseAgent):
201
292
 
202
293
  steps = self._parse_planner_output(str(plan.content))
203
294
 
295
+ # Parse dependencies and create execution levels for parallel processing
296
+ evidence_map, execution_levels = self._parse_planner_dependencies(steps.content)
297
+
204
298
  if self.detailed_logs:
205
299
  agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
206
300
  logger.info("ReWOO agent planner output: %s", agent_response_log_message)
301
+ logger.info("ReWOO agent execution levels: %s", execution_levels)
207
302
 
208
- return {"plan": plan, "steps": steps}
303
+ return {
304
+ "plan": plan,
305
+ "steps": steps,
306
+ "evidence_map": evidence_map,
307
+ "execution_levels": execution_levels,
308
+ "current_level": 0
309
+ }
209
310
 
210
311
  except Exception as ex:
211
312
  logger.error("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
212
313
  raise
213
314
 
214
315
  async def executor_node(self, state: ReWOOGraphState):
316
+ """
317
+ Execute tools in parallel for the current dependency level.
318
+
319
+ This replaces the sequential execution with parallel execution of tools
320
+ that have no dependencies between them.
321
+ """
215
322
  try:
216
323
  logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)
217
324
 
218
- current_step = self._get_current_step(state)
219
- # The executor node should not be invoked after all steps are finished
220
- if current_step < 0:
221
- logger.error("%s ReWOO Executor is invoked with an invalid step number: %s",
222
- AGENT_LOG_PREFIX,
223
- current_step)
224
- raise RuntimeError(f"ReWOO Executor is invoked with an invalid step number: {current_step}")
225
-
226
- steps_content = state.steps.content
227
- if isinstance(steps_content, list) and current_step < len(steps_content):
228
- step = steps_content[current_step]
229
- if isinstance(step, dict) and "evidence" in step:
230
- step_info = step["evidence"]
231
- placeholder = step_info.get("placeholder", "")
232
- tool = step_info.get("tool", "")
233
- tool_input = step_info.get("tool_input", "")
325
+ current_level, level_complete = self._get_current_level_status(state)
326
+
327
+ # Should not be invoked if all levels are complete
328
+ if current_level < 0:
329
+ logger.error("%s ReWOO Executor invoked after all levels complete", AGENT_LOG_PREFIX)
330
+ raise RuntimeError("ReWOO Executor invoked after all levels complete")
331
+
332
+ # If current level is already complete, move to next level
333
+ if level_complete:
334
+ new_level = current_level + 1
335
+ logger.debug("%s Level %s complete, moving to level %s", AGENT_LOG_PREFIX, current_level, new_level)
336
+ return {"current_level": new_level}
337
+
338
+ # Get placeholders for current level
339
+ current_level_placeholders = state.execution_levels[current_level]
340
+
341
+ # Filter to only placeholders not yet completed
342
+ pending_placeholders = [p for p in current_level_placeholders if p not in state.intermediate_results]
343
+
344
+ if not pending_placeholders:
345
+ # All placeholders in this level are done, move to next level
346
+ new_level = current_level + 1
347
+ return {"current_level": new_level}
348
+
349
+ logger.debug("%s Executing level %s with %s tools in parallel: %s",
350
+ AGENT_LOG_PREFIX,
351
+ current_level,
352
+ len(pending_placeholders),
353
+ pending_placeholders)
354
+
355
+ # Execute all tools in current level in parallel
356
+ tasks = []
357
+ for placeholder in pending_placeholders:
358
+ step_info = state.evidence_map[placeholder]
359
+ task = self._execute_single_tool(placeholder, step_info, state.intermediate_results)
360
+ tasks.append(task)
361
+
362
+ # Wait for all tasks in current level to complete
363
+ results = await asyncio.gather(*tasks, return_exceptions=True)
364
+
365
+ # Process results and update intermediate_results
366
+ updated_intermediate_results = dict(state.intermediate_results)
367
+
368
+ for i, result in enumerate(results):
369
+ placeholder = pending_placeholders[i]
370
+
371
+ if isinstance(result, Exception):
372
+ logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
373
+ # Create error tool message
374
+ error_message = f"Tool execution failed: {str(result)}"
375
+ updated_intermediate_results[placeholder] = ToolMessage(content=error_message,
376
+ tool_call_id=placeholder)
377
+ if self.raise_tool_call_error:
378
+ raise result
234
379
  else:
235
- logger.error("%s Invalid step format at index %s", AGENT_LOG_PREFIX, current_step)
236
- return {"intermediate_results": state.intermediate_results}
237
- else:
238
- logger.error("%s Invalid steps content or index %s", AGENT_LOG_PREFIX, current_step)
239
- return {"intermediate_results": state.intermediate_results}
240
-
241
- intermediate_results = state.intermediate_results
242
-
243
- # Replace the placeholder in the tool input with the previous tool output
244
- for _placeholder, _tool_output in intermediate_results.items():
245
- _tool_output = _tool_output.content
246
- # If the content is a list, get the first element which should be a dict
247
- if isinstance(_tool_output, list):
248
- _tool_output = _tool_output[0]
249
- assert isinstance(_tool_output, dict)
250
-
251
- tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
252
-
253
- requested_tool = self._get_tool(tool)
254
- if not requested_tool:
255
- configured_tool_names = list(self.tools_dict.keys())
256
- logger.warning(
257
- "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
258
- "there is no tool with that name: %s",
259
- AGENT_LOG_PREFIX,
260
- tool,
261
- configured_tool_names)
262
-
263
- intermediate_results[placeholder] = ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(
264
- tool_name=tool, tools=configured_tool_names),
265
- tool_call_id=tool)
266
- return {"intermediate_results": intermediate_results}
267
-
268
- if self.detailed_logs:
269
- logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)
270
-
271
- # Run the tool. Try to use structured input, if possible
272
- tool_input_parsed = self._parse_tool_input(tool_input)
273
- tool_response = await self._call_tool(requested_tool,
274
- tool_input_parsed,
275
- RunnableConfig(callbacks=self.callbacks),
276
- max_retries=self.tool_call_max_retries)
380
+ updated_intermediate_results[placeholder] = result
381
+ # Check if the ToolMessage has error status and raise_tool_call_error is True
382
+ if (isinstance(result, ToolMessage) and hasattr(result, 'status') and result.status == "error"
383
+ and self.raise_tool_call_error):
384
+ logger.error("%s Tool call failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result.content)
385
+ raise RuntimeError(f"Tool call failed: {result.content}")
277
386
 
278
387
  if self.detailed_logs:
279
- self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
388
+ logger.info("%s Completed level %s with %s tools",
389
+ AGENT_LOG_PREFIX,
390
+ current_level,
391
+ len(pending_placeholders))
280
392
 
281
- if self.raise_tool_call_error and tool_response.status == "error":
282
- raise RuntimeError(f"Tool call failed: {tool_response.content}")
283
-
284
- intermediate_results[placeholder] = tool_response
285
- return {"intermediate_results": intermediate_results}
393
+ return {"intermediate_results": updated_intermediate_results}
286
394
 
287
395
  except Exception as ex:
288
396
  logger.error("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
289
397
  raise
290
398
 
399
+ async def _execute_single_tool(self,
400
+ placeholder: str,
401
+ step_info: Dict,
402
+ intermediate_results: Dict[str, ToolMessage]) -> ToolMessage:
403
+ """
404
+ Execute a single tool with proper placeholder replacement.
405
+
406
+ :param placeholder: The evidence placeholder (e.g., "#E1").
407
+ :param step_info: Step information containing tool and tool_input.
408
+ :param intermediate_results: Current intermediate results for placeholder replacement.
409
+ :return: ToolMessage with the tool execution result.
410
+ """
411
+ evidence_info = step_info["evidence"]
412
+ tool_name = evidence_info.get("tool", "")
413
+ tool_input = evidence_info.get("tool_input", "")
414
+
415
+ # Replace placeholders in tool input with previous results
416
+ for _placeholder, _tool_output in intermediate_results.items():
417
+ _tool_output_content = _tool_output.content
418
+ # If the content is a list, get the first element which should be a dict
419
+ if isinstance(_tool_output_content, list):
420
+ _tool_output_content = _tool_output_content[0]
421
+ assert isinstance(_tool_output_content, dict)
422
+
423
+ tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output_content)
424
+
425
+ # Get the requested tool
426
+ requested_tool = self._get_tool(tool_name)
427
+ if not requested_tool:
428
+ configured_tool_names = list(self.tools_dict.keys())
429
+ logger.warning(
430
+ "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
431
+ "there is no tool with that name: %s",
432
+ AGENT_LOG_PREFIX,
433
+ tool_name,
434
+ configured_tool_names)
435
+
436
+ return ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=tool_name,
437
+ tools=configured_tool_names),
438
+ tool_call_id=placeholder)
439
+
440
+ if self.detailed_logs:
441
+ logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)
442
+
443
+ # Parse and execute the tool
444
+ tool_input_parsed = self._parse_tool_input(tool_input)
445
+ tool_response = await self._call_tool(requested_tool,
446
+ tool_input_parsed,
447
+ RunnableConfig(callbacks=self.callbacks),
448
+ max_retries=self.tool_call_max_retries)
449
+
450
+ if self.detailed_logs:
451
+ self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
452
+
453
+ return tool_response
454
+
291
455
  async def solver_node(self, state: ReWOOGraphState):
292
456
  try:
293
457
  logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)
294
458
 
295
459
  plan = ""
296
- # Add the tool outputs of each step to the plan
297
- for step in state.steps.content:
298
- step_info = step["evidence"]
299
- placeholder = step_info.get("placeholder", "")
300
- tool_input = step_info.get("tool_input", "")
301
-
302
- intermediate_results = state.intermediate_results
303
- for _placeholder, _tool_output in intermediate_results.items():
304
- _tool_output = _tool_output.content
460
+ # Add the tool outputs of each step to the plan using evidence_map
461
+ for placeholder, step_info in state.evidence_map.items():
462
+ evidence_info = step_info["evidence"]
463
+ original_tool_input = evidence_info.get("tool_input", "")
464
+ tool_name = evidence_info.get("tool", "")
465
+
466
+ # Replace placeholders in tool input with actual results
467
+ final_tool_input = original_tool_input
468
+ for _placeholder, _tool_output in state.intermediate_results.items():
469
+ _tool_output_content = _tool_output.content
305
470
  # If the content is a list, get the first element which should be a dict
306
- if isinstance(_tool_output, list):
307
- _tool_output = _tool_output[0]
308
- assert isinstance(_tool_output, dict)
309
-
310
- tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
311
-
312
- placeholder = placeholder.replace(_placeholder, str(_tool_output))
313
-
314
- _plan = step.get("plan")
315
- tool = step_info.get("tool")
316
- plan += f"Plan: {_plan}\n{placeholder} = {tool}[{tool_input}]"
471
+ if isinstance(_tool_output_content, list):
472
+ _tool_output_content = _tool_output_content[0]
473
+ assert isinstance(_tool_output_content, dict)
474
+
475
+ final_tool_input = self._replace_placeholder(_placeholder, final_tool_input, _tool_output_content)
476
+
477
+ # Get the final result for this placeholder
478
+ final_result = ""
479
+ if placeholder in state.intermediate_results:
480
+ result_content = state.intermediate_results[placeholder].content
481
+ if isinstance(result_content, list):
482
+ result_content = result_content[0]
483
+ if isinstance(result_content, dict):
484
+ final_result = str(result_content)
485
+ else:
486
+ final_result = str(result_content)
487
+ else:
488
+ final_result = str(result_content)
489
+
490
+ step_plan = step_info.get("plan", "")
491
+ plan += f"Plan: {step_plan}\n{placeholder} = \
492
+ {tool_name}[{final_tool_input}]\nResult: {final_result}\n\n"
317
493
 
318
494
  task = str(state.task.content)
319
495
  solver_prompt = self.solver_prompt.partial(plan=plan)
@@ -336,12 +512,24 @@ class ReWOOAgentGraph(BaseAgent):
336
512
  try:
337
513
  logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)
338
514
 
339
- current_step = self._get_current_step(state)
340
- if current_step == -1:
341
- logger.debug("%s The ReWOO Executor has finished its task", AGENT_LOG_PREFIX)
515
+ current_level, level_complete = self._get_current_level_status(state)
516
+
517
+ # If all levels are complete, move to solver
518
+ if current_level == -1:
519
+ logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
342
520
  return AgentDecision.END
343
521
 
344
- logger.debug("%s The ReWOO Executor is still working on the task", AGENT_LOG_PREFIX)
522
+ # If current level is complete, check if there are more levels
523
+ if level_complete:
524
+ next_level = current_level + 1
525
+ if next_level >= len(state.execution_levels):
526
+ logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
527
+ return AgentDecision.END
528
+
529
+ logger.debug("%s Continuing with executor (level %s, complete: %s)",
530
+ AGENT_LOG_PREFIX,
531
+ current_level,
532
+ level_complete)
345
533
  return AgentDecision.TOOL
346
534
 
347
535
  except Exception as ex:
@@ -18,33 +18,29 @@ For the following task, make plans that can solve the problem step by step. For
18
18
  which external tool together with tool input to retrieve evidence. You can store the evidence into a \
19
19
  placeholder #E that can be called by later tools. (Plan, #E1, Plan, #E2, Plan, ...)
20
20
 
21
- You may ask the human to the following tools:
21
+ The following tools and respective requirements are available to you:
22
22
 
23
23
  {tools}
24
24
 
25
- The tools should be one of the following: [{tool_names}]
25
+ The tool calls you make should be one of the following: [{tool_names}]
26
26
 
27
27
  You are not required to use all the tools listed. Choose only the ones that best fit the needs of each plan step.
28
28
 
29
- Your output must be a JSON array where each element represents one planning step. Each step must be an object with
30
-
29
+ Your output must be a JSON array where each element represents one planning step. Each step must be an object with \
31
30
  exactly two keys:
32
31
 
33
32
  1. "plan": A string that describes in detail the action or reasoning for that step.
34
33
 
35
- 2. "evidence": An object representing the external tool call associated with that plan step. This object must have the
34
+ 2. "evidence": An object representing the external tool call associated with that plan step. This object must have the \
36
35
  following keys:
37
36
 
38
- -"placeholder": A string that identifies the evidence placeholder (e.g., "#E1", "#E2", etc.). The numbering should
39
- be sequential based on the order of steps.
37
+ -"placeholder": A string that identifies the evidence placeholder ("#E1", "#E2", ...). The numbering should \
38
+ be sequential based on the order of steps.
40
39
 
41
40
  -"tool": A string specifying the name of the external tool used.
42
41
 
43
- -"tool_input": The input to the tool. This can be a string, array, or object, depending on the requirements of the
44
- tool.
45
-
46
- Do not include any additional keys or characters in your output, and do not wrap your response with markdown formatting.
47
- Your output must be strictly valid JSON.
42
+ -"tool_input": The input to the tool. This can be a string, array, or object, depending on the requirements of the \
43
+ tool. Be careful about type assumptions because the output of former tools might contain noise.
48
44
 
49
45
  Important instructions:
50
46
 
@@ -58,27 +54,28 @@ Here is an example of how a valid JSON output should look:
58
54
 
59
55
  [
60
56
  \'{{
61
- "plan": "Calculate the result of 2023 minus 25.",
57
+ "plan": "Find Alex's schedule on Sep 25, 2025",
62
58
  "evidence": \'{{
63
59
  "placeholder": "#E1",
64
- "tool": "calculator_subtract",
65
- "tool_input": [2023, 25]
60
+ "tool": "search_calendar",
61
+ "tool_input": ("Alex", "09/25/2025")
66
62
  }}\'
67
63
  }}\',
68
64
  \'{{
69
- "plan": "Retrieve the year represented by the result stored in #E1.",
65
+ "plan": "Find Bill's schedule on sep 25, 2025",
70
66
  "evidence": \'{{
71
67
  "placeholder": "#E2",
72
- "tool": "haystack_chitchat_agent",
73
- "tool_input": "Response with the result number contained in #E1"
68
+ "tool": "search_calendar",
69
+ "tool_input": ("Bill", "09/25/2025")
74
70
  }}\'
75
71
  }}\',
76
72
  \'{{
77
- "plan": "Search for the CEO of Golden State Warriors in the year stored in #E2.",
73
+ "plan": "Suggest a time for 1-hour meeting given Alex's and Bill's schedule.",
78
74
  "evidence": \'{{
79
75
  "placeholder": "#E3",
80
- "tool": "internet_search",
81
- "tool_input": "Who was the CEO of Golden State Warriors in the year #E2?"
76
+ "tool": "llm_chat",
77
+ "tool_input": "Find a common 1-hour time slot for Alex and Bill given their schedules. \
78
+ Alex's schedule: #E1; Bill's schedule: #E2?"
82
79
  }}\'
83
80
  }}\'
84
81
  ]
@@ -94,7 +91,7 @@ task: {task}
94
91
  """
95
92
 
96
93
  SOLVER_SYSTEM_PROMPT = """
97
- Solve the following task or problem. To solve the problem, we have made step-by-step Plan and \
94
+ Solve the following task or problem. To solve the problem, we have made some Plans ahead and \
98
95
  retrieved corresponding Evidence to each Plan. Use them with caution since long evidence might \
99
96
  contain irrelevant information.
100
97
 
@@ -80,7 +80,7 @@ class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
80
80
 
81
81
  raise ValueError(f"Unsupported header auth scheme: {header_auth_scheme}")
82
82
 
83
- async def authenticate(self, user_id: str | None = None) -> AuthResult | None:
83
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult | None:
84
84
  """
85
85
  Authenticate the user using the API key credentials.
86
86
 
@@ -38,7 +38,7 @@ class HTTPBasicAuthProvider(AuthProviderBase):
38
38
 
39
39
  self._authenticated_tokens: dict[str, AuthResult] = {}
40
40
 
41
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
41
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
42
42
  """
43
43
  Performs simple HTTP Authentication using the provided user ID.
44
44
  """
@@ -54,7 +54,7 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
54
54
  return self._config
55
55
 
56
56
  @abstractmethod
57
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
57
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
58
58
  """
59
59
  Perform the authentication process for the client.
60
60
 
@@ -62,6 +62,9 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
62
62
  target API service, which may include obtaining tokens, refreshing credentials,
63
63
  or completing multi-step authentication flows.
64
64
 
65
+ Args:
66
+ user_id: Optional user identifier for authentication
67
+ kwargs: Additional authentication parameters for example: http response (typically from a 401)
65
68
  Raises:
66
69
  NotImplementedError: Must be implemented by subclasses.
67
70
  """
@@ -71,7 +74,7 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
71
74
 
72
75
  class FlowHandlerBase(ABC):
73
76
  """
74
- Handles front-end specifc flows for authentication clients.
77
+ Handles front-end specific flows for authentication clients.
75
78
 
76
79
  Each front end will define a FlowHandler that will implement the authenticate method.
77
80
 
@@ -15,6 +15,7 @@
15
15
 
16
16
  from datetime import datetime
17
17
  from datetime import timezone
18
+ from typing import Callable
18
19
 
19
20
  from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
20
21
  from pydantic import SecretStr
@@ -22,6 +23,7 @@ from pydantic import SecretStr
22
23
  from nat.authentication.interfaces import AuthProviderBase
23
24
  from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
24
25
  from nat.builder.context import Context
26
+ from nat.data_models.authentication import AuthenticatedContext
25
27
  from nat.data_models.authentication import AuthFlowType
26
28
  from nat.data_models.authentication import AuthResult
27
29
  from nat.data_models.authentication import BearerTokenCred
@@ -32,7 +34,7 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
32
34
  def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
33
35
  super().__init__(config)
34
36
  self._authenticated_tokens: dict[str, AuthResult] = {}
35
- self._context = Context.get()
37
+ self._auth_callback = None
36
38
 
37
39
  async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
38
40
  refresh_token = auth_result.raw.get("refresh_token")
@@ -62,7 +64,12 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
62
64
 
63
65
  return new_auth_result
64
66
 
65
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
67
+ def _set_custom_auth_callback(self,
68
+ auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
69
+ AuthenticatedContext]):
70
+ self._auth_callback = auth_callback
71
+
72
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
66
73
  if user_id is None and hasattr(Context.get(), "metadata") and hasattr(
67
74
  Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None:
68
75
  session_id = Context.get().metadata.cookies.get("nat-session", None)
@@ -80,7 +87,12 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
80
87
  if refreshed_auth_result:
81
88
  return refreshed_auth_result
82
89
 
83
- auth_callback = self._context.user_auth_callback
90
+ # Try getting callback from the context if that's not set, use the default callback
91
+ try:
92
+ auth_callback = Context.get().user_auth_callback
93
+ except RuntimeError:
94
+ auth_callback = self._auth_callback
95
+
84
96
  if not auth_callback:
85
97
  raise RuntimeError("Authentication callback not set on Context.")
86
98
 
@@ -157,6 +157,66 @@ def print_tool(tool_dict: dict[str, str | None], detail: bool = False) -> None:
157
157
  click.echo("-" * 60)
158
158
 
159
159
 
160
+ def _set_auth_defaults(auth: bool,
161
+ url: str | None,
162
+ auth_redirect_uri: str | None,
163
+ auth_user_id: str | None,
164
+ auth_scopes: str | None) -> tuple[str | None, str | None, list[str] | None]:
165
+ """Set default auth values when --auth flag is used.
166
+
167
+ Args:
168
+ auth: Whether --auth flag was used
169
+ url: MCP server URL
170
+ auth_redirect_uri: OAuth2 redirect URI
171
+ auth_user_id: User ID for authentication
172
+ auth_scopes: OAuth2 scopes (comma-separated string)
173
+
174
+ Returns:
175
+ Tuple of (auth_redirect_uri, auth_user_id, auth_scopes_list) with defaults applied
176
+ """
177
+ if auth:
178
+ auth_redirect_uri = auth_redirect_uri or "http://localhost:8000/auth/redirect"
179
+ auth_user_id = auth_user_id or url
180
+ auth_scopes = auth_scopes or ""
181
+
182
+ # Convert comma-separated string to list, stripping whitespace
183
+ auth_scopes_list = [scope.strip() for scope in auth_scopes.split(',')] if auth_scopes else None
184
+
185
+ return auth_redirect_uri, auth_user_id, auth_scopes_list
186
+
187
+
188
+ async def _create_mcp_client_config(
189
+ builder,
190
+ server_cfg,
191
+ url: str | None,
192
+ transport: str,
193
+ auth_redirect_uri: str | None,
194
+ auth_user_id: str | None,
195
+ auth_scopes: list[str] | None,
196
+ ):
197
+ from nat.plugins.mcp.client_impl import MCPClientConfig
198
+
199
+ if url and transport == "streamable-http" and auth_redirect_uri:
200
+ try:
201
+ from nat.plugins.mcp.auth.auth_provider_config import MCPOAuth2ProviderConfig
202
+ auth_config = MCPOAuth2ProviderConfig(
203
+ server_url=url,
204
+ redirect_uri=auth_redirect_uri,
205
+ default_user_id=auth_user_id or url,
206
+ scopes=auth_scopes or [],
207
+ )
208
+ auth_provider_name = "mcp_oauth2_cli"
209
+ await builder.add_auth_provider(auth_provider_name, auth_config)
210
+ server_cfg.auth_provider = auth_provider_name
211
+ except ImportError:
212
+ click.echo(
213
+ "[WARNING] MCP OAuth2 authentication requires nvidia-nat-mcp package.",
214
+ err=True,
215
+ )
216
+
217
+ return MCPClientConfig(server=server_cfg)
218
+
219
+
160
220
  async def list_tools_via_function_group(
161
221
  command: str | None,
162
222
  url: str | None,
@@ -164,6 +224,9 @@ async def list_tools_via_function_group(
164
224
  transport: str = 'sse',
165
225
  args: list[str] | None = None,
166
226
  env: dict[str, str] | None = None,
227
+ auth_redirect_uri: str | None = None,
228
+ auth_user_id: str | None = None,
229
+ auth_scopes: list[str] | None = None,
167
230
  ) -> list[dict[str, str | None]]:
168
231
  """List tools by constructing the mcp_client function group and introspecting functions.
169
232
 
@@ -192,15 +255,24 @@ async def list_tools_via_function_group(
192
255
  args=args if transport == 'stdio' else None,
193
256
  env=env if transport == 'stdio' else None,
194
257
  )
258
+
195
259
  group_cfg = MCPClientConfig(server=server_cfg)
196
260
 
197
261
  tools: list[dict[str, str | None]] = []
198
262
 
199
263
  async with WorkflowBuilder() as builder: # type: ignore
264
+ # Add auth provider if url is provided and auth_redirect_uri is given (only for streamable-http)
265
+ group_cfg = await _create_mcp_client_config(builder,
266
+ server_cfg,
267
+ url,
268
+ transport,
269
+ auth_redirect_uri,
270
+ auth_user_id,
271
+ auth_scopes)
200
272
  group = await builder.add_function_group("mcp_client", group_cfg)
201
273
 
202
274
  # Access functions exposed by the group
203
- fns = group.get_accessible_functions()
275
+ fns = await group.get_accessible_functions()
204
276
 
205
277
  def to_tool_entry(full_name: str, fn_obj) -> dict[str, str | None]:
206
278
  # full_name like "mcp_client.<tool>"
@@ -324,7 +396,10 @@ async def ping_mcp_server(url: str,
324
396
  transport: str = 'streamable-http',
325
397
  command: str | None = None,
326
398
  args: list[str] | None = None,
327
- env: dict[str, str] | None = None) -> MCPPingResult:
399
+ env: dict[str, str] | None = None,
400
+ auth_redirect_uri: str | None = None,
401
+ auth_user_id: str | None = None,
402
+ auth_scopes: list[str] | None = None) -> MCPPingResult:
328
403
  """Ping an MCP server to check if it's responsive.
329
404
 
330
405
  Args:
@@ -383,7 +458,13 @@ def mcp_client_command():
383
458
  """
384
459
  MCP client commands.
385
460
  """
386
- return None
461
+ try:
462
+ from nat.runtime.loader import PluginTypes
463
+ from nat.runtime.loader import discover_and_register_plugins
464
+ discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)
465
+ except ImportError:
466
+ click.echo("[WARNING] MCP client functionality requires nvidia-nat-mcp package.", err=True)
467
+ pass
387
468
 
388
469
 
389
470
  @mcp_client_command.group(name="tool", invoke_without_command=False, help="Inspect and call MCP tools.")
@@ -412,14 +493,36 @@ def mcp_client_tool_group():
412
493
  @click.option('--tool', default=None, help='Get details for a specific tool by name')
413
494
  @click.option('--detail', is_flag=True, help='Show full details for all tools')
414
495
  @click.option('--json-output', is_flag=True, help='Output tool metadata in JSON format')
496
+ @click.option('--auth',
497
+ is_flag=True,
498
+ help='Enable OAuth2 authentication with default settings (streamable-http only, not with --direct)')
499
+ @click.option('--auth-redirect-uri',
500
+ help='OAuth2 redirect URI for authentication (streamable-http only, not with --direct)')
501
+ @click.option('--auth-user-id', help='User ID for authentication (streamable-http only, not with --direct)')
502
+ @click.option('--auth-scopes', help='OAuth2 scopes (comma-separated, streamable-http only, not with --direct)')
415
503
  @click.pass_context
416
- def mcp_client_tool_list(ctx, direct, url, transport, command, args, env, tool, detail, json_output):
504
+ def mcp_client_tool_list(ctx,
505
+ direct,
506
+ url,
507
+ transport,
508
+ command,
509
+ args,
510
+ env,
511
+ tool,
512
+ detail,
513
+ json_output,
514
+ auth,
515
+ auth_redirect_uri,
516
+ auth_user_id,
517
+ auth_scopes):
417
518
  """List MCP tool names (default) or show detailed tool information.
418
519
 
419
520
  Use --detail for full output including descriptions and input schemas.
420
521
  If --tool is provided, always shows full output for that specific tool.
421
522
  Use --direct to bypass MCPBuilder and use raw MCP protocol.
422
523
  Use --json-output to get structured JSON data instead of formatted text.
524
+ Use --auth to enable auth with default settings (streamable-http only, not with --direct).
525
+ Use --auth-redirect-uri to enable auth for protected MCP servers (streamable-http only, not with --direct).
423
526
 
424
527
  Args:
425
528
  ctx (click.Context): Click context object for command invocation
@@ -428,13 +531,22 @@ def mcp_client_tool_list(ctx, direct, url, transport, command, args, env, tool,
428
531
  tool (str | None): Optional specific tool name to retrieve detailed info for
429
532
  detail (bool): Whether to show full details (description + schema) for all tools
430
533
  json_output (bool): Whether to output tool metadata in JSON format instead of text
534
+ auth (bool): Whether to enable OAuth2 authentication (streamable-http only, not with --direct)
535
+ auth_redirect_uri (str | None): redirect URI for auth (streamable-http only, not with --direct)
536
+ auth_user_id (str | None): User ID for authentication (streamable-http only, not with --direct)
537
+ auth_scopes (str | None): OAuth2 scopes (comma-separated, streamable-http only, not with --direct)
431
538
 
432
539
  Examples:
433
540
  nat mcp client tool list # List tool names only
434
541
  nat mcp client tool list --detail # Show all tools with full details
435
542
  nat mcp client tool list --tool my_tool # Show details for specific tool
436
543
  nat mcp client tool list --json-output # Get JSON format output
437
- nat mcp client tool list --direct --url http://... # Use direct protocol with custom URL
544
+ nat mcp client tool list --direct --url http://... # Use direct protocol with custom URL (no auth)
545
+ nat mcp client tool list --url https://example.com/mcp/ --auth # With auth using defaults
546
+ nat mcp client tool list --url https://example.com/mcp/ --transport streamable-http \
547
+ --auth-redirect-uri http://localhost:8000/auth/redirect # With custom auth settings
548
+ nat mcp client tool list --url https://example.com/mcp/ --transport streamable-http \
549
+ --auth-redirect-uri http://localhost:8000/auth/redirect --auth-user-id myuser # With auth and user ID
438
550
  """
439
551
  if ctx.invoked_subcommand is not None:
440
552
  return
@@ -447,11 +559,28 @@ def mcp_client_tool_list(ctx, direct, url, transport, command, args, env, tool,
447
559
  click.echo("[ERROR] --url is required when using sse or streamable-http client type", err=True)
448
560
  return
449
561
 
562
+ # Set auth defaults if --auth flag is used
563
+ auth_redirect_uri, auth_user_id, auth_scopes_list = _set_auth_defaults(
564
+ auth, url, auth_redirect_uri, auth_user_id, auth_scopes
565
+ )
566
+
450
567
  stdio_args = args.split() if args else []
451
568
  stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
452
569
 
453
- fetcher = list_tools_direct if direct else list_tools_via_function_group
454
- tools = asyncio.run(fetcher(command, url, tool, transport, stdio_args, stdio_env))
570
+ if direct:
571
+ tools = asyncio.run(
572
+ list_tools_direct(command, url, tool_name=tool, transport=transport, args=stdio_args, env=stdio_env))
573
+ else:
574
+ tools = asyncio.run(
575
+ list_tools_via_function_group(command,
576
+ url,
577
+ tool_name=tool,
578
+ transport=transport,
579
+ args=stdio_args,
580
+ env=stdio_env,
581
+ auth_redirect_uri=auth_redirect_uri,
582
+ auth_user_id=auth_user_id,
583
+ auth_scopes=auth_scopes_list))
455
584
 
456
585
  if json_output:
457
586
  click.echo(json.dumps(tools, indent=2))
@@ -482,13 +611,20 @@ def mcp_client_tool_list(ctx, direct, url, transport, command, args, env, tool,
482
611
  @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
483
612
  @click.option('--timeout', default=60, show_default=True, help='Timeout in seconds for ping request')
484
613
  @click.option('--json-output', is_flag=True, help='Output ping result in JSON format')
614
+ @click.option('--auth-redirect-uri',
615
+ help='OAuth2 redirect URI for authentication (streamable-http only, not with --direct)')
616
+ @click.option('--auth-user-id', help='User ID for authentication (streamable-http only, not with --direct)')
617
+ @click.option('--auth-scopes', help='OAuth2 scopes (comma-separated, streamable-http only, not with --direct)')
485
618
  def mcp_client_ping(url: str,
486
619
  transport: str,
487
620
  command: str | None,
488
621
  args: str | None,
489
622
  env: str | None,
490
623
  timeout: int,
491
- json_output: bool) -> None:
624
+ json_output: bool,
625
+ auth_redirect_uri: str | None,
626
+ auth_user_id: str | None,
627
+ auth_scopes: str | None) -> None:
492
628
  """Ping an MCP server to check if it's responsive.
493
629
 
494
630
  This command sends a ping request to the MCP server and measures the response time.
@@ -498,12 +634,16 @@ def mcp_client_ping(url: str,
498
634
  url (str): MCP server URL to ping (default: http://localhost:9901/mcp)
499
635
  timeout (int): Timeout in seconds for the ping request (default: 60)
500
636
  json_output (bool): Whether to output the result in JSON format
637
+ auth_redirect_uri (str | None): redirect URI for auth (streamable-http only, not with --direct)
638
+ auth_user_id (str | None): User ID for auth (streamable-http only, not with --direct)
639
+ auth_scopes (str | None): OAuth2 scopes (comma-separated, streamable-http only, not with --direct)
501
640
 
502
641
  Examples:
503
642
  nat mcp client ping # Ping default server
504
643
  nat mcp client ping --url http://custom-server:9901/mcp # Ping custom server
505
644
  nat mcp client ping --timeout 10 # Use 10 second timeout
506
645
  nat mcp client ping --json-output # Get JSON format output
646
+ nat mcp client ping --url https://example.com/mcp/ --transport streamable-http --auth # With auth
507
647
  """
508
648
  # Validate combinations similar to list command
509
649
  if not validate_transport_cli_args(transport, command, args, env):
@@ -512,7 +652,24 @@ def mcp_client_ping(url: str,
512
652
  stdio_args = args.split() if args else []
513
653
  stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
514
654
 
515
- result = asyncio.run(ping_mcp_server(url, timeout, transport, command, stdio_args, stdio_env))
655
+ # Auth validation: if user_id or scopes provided, require redirect_uri
656
+ if (auth_user_id or auth_scopes) and not auth_redirect_uri:
657
+ click.echo("[ERROR] --auth-redirect-uri is required when using --auth-user-id or --auth-scopes", err=True)
658
+ return
659
+
660
+ # Parse auth scopes, stripping whitespace
661
+ auth_scopes_list = [scope.strip() for scope in auth_scopes.split(',')] if auth_scopes else None
662
+
663
+ result = asyncio.run(
664
+ ping_mcp_server(url,
665
+ timeout,
666
+ transport,
667
+ command,
668
+ stdio_args,
669
+ stdio_env,
670
+ auth_redirect_uri,
671
+ auth_user_id,
672
+ auth_scopes_list))
516
673
 
517
674
  if json_output:
518
675
  click.echo(result.model_dump_json(indent=2))
@@ -635,7 +792,10 @@ async def call_tool_and_print(command: str | None,
635
792
  args: list[str] | None,
636
793
  env: dict[str, str] | None,
637
794
  tool_args: dict[str, Any] | None,
638
- direct: bool) -> str:
795
+ direct: bool,
796
+ auth_redirect_uri: str | None = None,
797
+ auth_user_id: str | None = None,
798
+ auth_scopes: list[str] | None = None) -> str:
639
799
  """Call an MCP tool either directly or via the function group and return output.
640
800
 
641
801
  When ``direct`` is True, uses the raw MCP protocol client (bypassing the
@@ -681,11 +841,25 @@ async def call_tool_and_print(command: str | None,
681
841
  args=args if transport == 'stdio' else None,
682
842
  env=env if transport == 'stdio' else None,
683
843
  )
844
+
684
845
  group_cfg = MCPClientConfig(server=server_cfg)
685
846
 
686
847
  async with WorkflowBuilder() as builder: # type: ignore
848
+ # Add auth provider if url is provided and auth_redirect_uri is given (only for streamable-http)
849
+ if url and transport == 'streamable-http' and auth_redirect_uri:
850
+ try:
851
+ group_cfg = await _create_mcp_client_config(builder,
852
+ server_cfg,
853
+ url,
854
+ transport,
855
+ auth_redirect_uri,
856
+ auth_user_id,
857
+ auth_scopes)
858
+ except ImportError:
859
+ click.echo("[WARNING] MCP OAuth2 authentication requires nvidia-nat-mcp package.", err=True)
860
+
687
861
  group = await builder.add_function_group("mcp_client", group_cfg)
688
- fns = group.get_accessible_functions()
862
+ fns = await group.get_accessible_functions()
689
863
  full = f"mcp_client.{tool_name}"
690
864
  fn = fns.get(full)
691
865
  if fn is None:
@@ -713,6 +887,13 @@ async def call_tool_and_print(command: str | None,
713
887
  @click.option('--args', help='For stdio: Additional arguments for the command (space-separated)')
714
888
  @click.option('--env', help='For stdio: Environment variables in KEY=VALUE format (space-separated)')
715
889
  @click.option('--json-args', default=None, help='Pass tool args as a JSON object string')
890
+ @click.option('--auth',
891
+ is_flag=True,
892
+ help='Enable OAuth2 authentication with default settings (streamable-http only, not with --direct)')
893
+ @click.option('--auth-redirect-uri',
894
+ help='OAuth2 redirect URI for authentication (streamable-http only, not with --direct)')
895
+ @click.option('--auth-user-id', help='User ID for authentication (streamable-http only, not with --direct)')
896
+ @click.option('--auth-scopes', help='OAuth2 scopes (comma-separated, streamable-http only, not with --direct)')
716
897
  def mcp_client_tool_call(tool_name: str,
717
898
  direct: bool,
718
899
  url: str | None,
@@ -720,7 +901,11 @@ def mcp_client_tool_call(tool_name: str,
720
901
  command: str | None,
721
902
  args: str | None,
722
903
  env: str | None,
723
- json_args: str | None) -> None:
904
+ json_args: str | None,
905
+ auth: bool,
906
+ auth_redirect_uri: str | None,
907
+ auth_user_id: str | None,
908
+ auth_scopes: str | None) -> None:
724
909
  """Call an MCP tool by name with optional JSON arguments.
725
910
 
726
911
  Validates transport parameters, parses ``--json-args`` into a dictionary,
@@ -737,13 +922,20 @@ def mcp_client_tool_call(tool_name: str,
737
922
  args (str | None): For ``stdio`` transport, space-separated command arguments.
738
923
  env (str | None): For ``stdio`` transport, space-separated ``KEY=VALUE`` pairs.
739
924
  json_args (str | None): JSON object string with tool arguments (e.g. '{"q": "hello"}').
925
+ auth_redirect_uri (str | None): redirect URI for auth (streamable-http only, not with --direct)
926
+ auth_user_id (str | None): User ID for authentication (streamable-http only, not with --direct)
927
+ auth_scopes (str | None): OAuth2 scopes (comma-separated, streamable-http only, not with --direct)
740
928
 
741
929
  Examples:
742
930
  nat mcp client tool call echo --json-args '{"text": "Hello"}'
743
931
  nat mcp client tool call search --direct --url http://localhost:9901/mcp \
744
- --json-args '{"query": "NVIDIA"}'
932
+ --json-args '{"query": "NVIDIA"}' # Direct mode (no auth)
745
933
  nat mcp client tool call run --transport stdio --command mcp-server \
746
934
  --args "--flag1 --flag2" --env "ENV1=V1 ENV2=V2" --json-args '{}'
935
+ nat mcp client tool call search --url https://example.com/mcp/ --auth \
936
+ --json-args '{"query": "test"}' # With auth using defaults
937
+ nat mcp client tool call search --url https://example.com/mcp/ \
938
+ --transport streamable-http --json-args '{"query": "test"}' --auth
747
939
  """
748
940
  # Validate transport args
749
941
  if not validate_transport_cli_args(transport, command, args, env):
@@ -753,6 +945,11 @@ def mcp_client_tool_call(tool_name: str,
753
945
  stdio_args = args.split() if args else []
754
946
  stdio_env = dict(var.split('=', 1) for var in env.split()) if env else None
755
947
 
948
+ # Set auth defaults if --auth flag is used
949
+ auth_redirect_uri, auth_user_id, auth_scopes_list = _set_auth_defaults(
950
+ auth, url, auth_redirect_uri, auth_user_id, auth_scopes
951
+ )
952
+
756
953
  # Parse tool args
757
954
  arg_obj: dict[str, Any] = {}
758
955
  if json_args:
@@ -777,6 +974,9 @@ def mcp_client_tool_call(tool_name: str,
777
974
  env=stdio_env,
778
975
  tool_args=arg_obj,
779
976
  direct=direct,
977
+ auth_redirect_uri=auth_redirect_uri,
978
+ auth_user_id=auth_user_id,
979
+ auth_scopes=auth_scopes_list,
780
980
  ))
781
981
  if output:
782
982
  click.echo(output)
@@ -249,21 +249,3 @@ class AuthResult(BaseModel):
249
249
  target_kwargs.setdefault(k, {}).update(v)
250
250
  else:
251
251
  target_kwargs[k] = v
252
-
253
-
254
- class AuthReason(str, Enum):
255
- """
256
- Why the caller is asking for auth now.
257
- """
258
- NORMAL = "normal"
259
- RETRY_AFTER_401 = "retry_after_401"
260
-
261
-
262
- class AuthRequest(BaseModel):
263
- """
264
- Authentication request payload for provider.authenticate(...).
265
- """
266
- model_config = ConfigDict(extra="forbid")
267
-
268
- reason: AuthReason = Field(default=AuthReason.NORMAL, description="Purpose of this auth attempt.")
269
- www_authenticate: str | None = Field(default=None, description="Raw WWW-Authenticate header from a 401 response.")
nat/runtime/session.py CHANGED
@@ -111,7 +111,7 @@ class SessionManager:
111
111
  token_user_authentication = self._context_state.user_auth_callback.set(user_authentication_callback)
112
112
 
113
113
  if isinstance(http_connection, WebSocket):
114
- self.set_metadata_from_websocket(user_message_id, conversation_id)
114
+ self.set_metadata_from_websocket(http_connection, user_message_id, conversation_id)
115
115
 
116
116
  if isinstance(http_connection, Request):
117
117
  self.set_metadata_from_http_request(http_connection)
@@ -161,11 +161,31 @@ class SessionManager:
161
161
  if request.headers.get("user-message-id"):
162
162
  self._context_state.user_message_id.set(request.headers["user-message-id"])
163
163
 
164
- def set_metadata_from_websocket(self, user_message_id: str | None, conversation_id: str | None) -> None:
164
+ def set_metadata_from_websocket(self,
165
+ websocket: WebSocket,
166
+ user_message_id: str | None,
167
+ conversation_id: str | None) -> None:
165
168
  """
166
169
  Extracts and sets user metadata for Websocket connections.
167
170
  """
168
171
 
172
+ # Extract cookies from WebSocket headers (similar to HTTP request)
173
+ if websocket and hasattr(websocket, 'scope') and 'headers' in websocket.scope:
174
+ cookies = {}
175
+ for header_name, header_value in websocket.scope.get('headers', []):
176
+ if header_name == b'cookie':
177
+ cookie_header = header_value.decode('utf-8')
178
+ # Parse cookie header: "name1=value1; name2=value2"
179
+ for cookie in cookie_header.split(';'):
180
+ cookie = cookie.strip()
181
+ if '=' in cookie:
182
+ name, value = cookie.split('=', 1)
183
+ cookies[name.strip()] = value.strip()
184
+
185
+ # Set cookies in metadata (same as HTTP request)
186
+ self._context.metadata._request.cookies = cookies
187
+ self._context_state.metadata.set(self._context.metadata)
188
+
169
189
  if conversation_id is not None:
170
190
  self._context_state.conversation_id.set(conversation_id)
171
191
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat
3
- Version: 1.3.0a20250926
3
+ Version: 1.3.0a20250929
4
4
  Summary: NVIDIA NeMo Agent toolkit
5
5
  Author: NVIDIA Corporation
6
6
  Maintainer: NVIDIA Corporation
@@ -305,6 +305,7 @@ Requires-Dist: nat_router_agent; extra == "examples"
305
305
  Requires-Dist: nat_semantic_kernel_demo; extra == "examples"
306
306
  Requires-Dist: nat_sequential_executor; extra == "examples"
307
307
  Requires-Dist: nat_simple_auth; extra == "examples"
308
+ Requires-Dist: nat_simple_auth_mcp; extra == "examples"
308
309
  Requires-Dist: nat_simple_web_query; extra == "examples"
309
310
  Requires-Dist: nat_simple_web_query_eval; extra == "examples"
310
311
  Requires-Dist: nat_simple_calculator; extra == "examples"
@@ -14,17 +14,17 @@ nat/agent/react_agent/register.py,sha256=b97dfNtA0I3bNBOGdr9_akQ89UDwPHPPb7LqpsZ
14
14
  nat/agent/reasoning_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  nat/agent/reasoning_agent/reasoning_agent.py,sha256=k_0wEDqACQn1Rn1MAKxoXyqOKsthHCQ1gt990YYUqHU,9575
16
16
  nat/agent/rewoo_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- nat/agent/rewoo_agent/agent.py,sha256=ogTtUsNqbAjEprBKp2ZEoYfuHvgxnjV7NgygvSgMz7M,19264
18
- nat/agent/rewoo_agent/prompt.py,sha256=nFMav3Zl_vmKPLzAIhbQHlldWnurPJb1GlwnekUuxDs,3720
17
+ nat/agent/rewoo_agent/agent.py,sha256=920IYuVBq9Kg4h3pbe_p4Gpz2mBjpGJxdLivYYo3ce8,27594
18
+ nat/agent/rewoo_agent/prompt.py,sha256=B0JeL1xDX4VKcShlkkviEcAsOKAwzSlX8NcAQdmUUPw,3645
19
19
  nat/agent/rewoo_agent/register.py,sha256=668zAag6eqajX_PIfh6c-0I0UQN5D-lRiz_mNKHXXjM,8954
20
20
  nat/agent/tool_calling_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
21
  nat/agent/tool_calling_agent/agent.py,sha256=4SIp29I56oznPRQu7B3HCoX53Ri3_o3BRRYNJjeBkF8,11006
22
22
  nat/agent/tool_calling_agent/register.py,sha256=ijiRfgDVtt2p7_q1YbIQZmUVV8-jf3yT18HwtKyReUI,6822
23
23
  nat/authentication/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
24
- nat/authentication/interfaces.py,sha256=FAYM-QXVUn3a_8bmAZ7kP-lmN_BrLW8mo6krZJ3e0ME,3314
24
+ nat/authentication/interfaces.py,sha256=1J2CWEJ_n6CLA3_HD3XV28CSbyfxrPAHzr7Q4kKDFdc,3511
25
25
  nat/authentication/register.py,sha256=lFhswYUk9iZ53mq33fClR9UfjJPdjGIivGGNHQeWiYo,915
26
26
  nat/authentication/api_key/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
27
- nat/authentication/api_key/api_key_auth_provider.py,sha256=DoN1MoeB_uOWQMHUz2OiyhwbFxQkj0FOifN2TFr4sds,4022
27
+ nat/authentication/api_key/api_key_auth_provider.py,sha256=QGRZZilD2GryhRJoODfKqypH54IwQp0I-Cck40Dc7dM,4032
28
28
  nat/authentication/api_key/api_key_auth_provider_config.py,sha256=zfkxH3yvUSKKldRf1K4PPm0rJLXGH0GDH8xj7anPYGQ,5472
29
29
  nat/authentication/api_key/register.py,sha256=Mhv3WyZ9H7C2JN8VuPvwlsJEZrwXJCLXCIokkN9RrP0,1147
30
30
  nat/authentication/credential_validator/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
@@ -32,10 +32,10 @@ nat/authentication/credential_validator/bearer_token_validator.py,sha256=cwGENd_
32
32
  nat/authentication/exceptions/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
33
33
  nat/authentication/exceptions/api_key_exceptions.py,sha256=6wnz951BI77rFYuHxoHOthz-y5oE08uxsuM6G5EvOyM,1545
34
34
  nat/authentication/http_basic_auth/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
- nat/authentication/http_basic_auth/http_basic_auth_provider.py,sha256=vmukbeI32aK6Ulc-yoaFHszr6BqmCvZHIpKr4-arMhs,3442
35
+ nat/authentication/http_basic_auth/http_basic_auth_provider.py,sha256=OXr5TV87SiZtzSK9i_E6WXWyVhWq2MfqO_SS1aZ3p6U,3452
36
36
  nat/authentication/http_basic_auth/register.py,sha256=N2VD0vw7cYABsLxsGXl5yw0htc8adkrB0Y_EMxKwFfk,1235
37
37
  nat/authentication/oauth2/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
38
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py,sha256=iFp_kImFwDrweFHGTg4n9RGybH3-wRSJkyh-FPOxwDA,4538
38
+ nat/authentication/oauth2/oauth2_auth_code_flow_provider.py,sha256=TDZvGbpCLW6EDCI-a4Z9KnQyUEcRaJ5hBCPVyIuy-KY,5099
39
39
  nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py,sha256=e165ysd2pX2WTbV3_FQKEjEaa4TAXkJ7B98WUGbqnGE,2204
40
40
  nat/authentication/oauth2/oauth2_resource_server_config.py,sha256=ltcNp8Dwb2Q4tlwMN5Cl0B5pouTLtXRoV-QopfqV45M,5314
41
41
  nat/authentication/oauth2/register.py,sha256=7rXhf-ilgSS_bUJsd9pOOCotL1FM8dKUt3ke1TllKkQ,1228
@@ -83,7 +83,7 @@ nat/cli/commands/info/info.py,sha256=BGqshIEDpNRH9hM-06k-Gq-QX-qNddPICSWCN-ReC-g
83
83
  nat/cli/commands/info/list_channels.py,sha256=K97TE6wtikgImY-wAbFNi0HHUGtkvIFd2woaG06VkT0,1277
84
84
  nat/cli/commands/info/list_components.py,sha256=QlAJVONBA77xW8Lx6Autw5NTAZNy_VrJGr1GL9MfnHM,4532
85
85
  nat/cli/commands/mcp/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
86
- nat/cli/commands/mcp/mcp.py,sha256=CrpOOZZGq1nSsuFU_6h_VBL5ZE92Yr30WV9c5iirtzY,33332
86
+ nat/cli/commands/mcp/mcp.py,sha256=Phtxegf0Ww89NHrwj4dEZh3vaHxjm7RV7NZMByUGhpM,43860
87
87
  nat/cli/commands/object_store/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
88
88
  nat/cli/commands/object_store/object_store.py,sha256=_ivB-R30a-66fNy-fUzi58HQ0Ay0gYsGz7T1xXoRa3Y,8576
89
89
  nat/cli/commands/registry/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
@@ -113,7 +113,7 @@ nat/control_flow/router_agent/register.py,sha256=p15Jy05PjJhXRrjqWiEgy47zmc-CFCu
113
113
  nat/data_models/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
114
114
  nat/data_models/agent.py,sha256=IwDyb9Zc3R4Zd5rFeqt7q0EQswczAl5focxV9KozIzs,1625
115
115
  nat/data_models/api_server.py,sha256=Zl0eAd-yV9PD8vUH8eWRvXFcUBdY2tENKe73q-Uxxgg,25699
116
- nat/data_models/authentication.py,sha256=t9-2CQaIfk74DM3CLaJU5NAYGs7m6BjdFFm1hj8WwtY,9064
116
+ nat/data_models/authentication.py,sha256=wpH8tpy_wt1aA3eKIgbc8HfYdr_g6AfM7-NfWBZh-KA,8528
117
117
  nat/data_models/common.py,sha256=nXXfGrjpxebzBUa55mLdmzePLt7VFHvTAc6Znj3yEv0,5875
118
118
  nat/data_models/component.py,sha256=b_hXOA8Gm5UNvlFkAhsR6kEvf33ST50MKtr5kWf75Ao,1894
119
119
  nat/data_models/component_ref.py,sha256=KFDWFVCcvJCfBBcXTh9f3R802EVHBtHXh9OdbRqFmdM,4747
@@ -407,7 +407,7 @@ nat/retriever/nemo_retriever/retriever.py,sha256=gi3_qJFqE-iqRh3of_cmJg-SwzaQ3z2
407
407
  nat/runtime/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
408
408
  nat/runtime/loader.py,sha256=obUdAgZVYCPGC0R8u3wcoKFJzzSPQgJvrbU4OWygtog,7953
409
409
  nat/runtime/runner.py,sha256=Kzm5GRrGUFMQ_fbLOCJumYc4R-JXdTm5tUw2yMMDJpE,6450
410
- nat/runtime/session.py,sha256=U3UHQpdCBkCiJetsWdq9r6wUEVDBa2gv1VQedE64kY8,6959
410
+ nat/runtime/session.py,sha256=DG4cpVg6GCVFY0cGzZnz55eLj0LoK5Q9Vg3NgbTqOHM,8029
411
411
  nat/runtime/user_metadata.py,sha256=ce37NRYJWnMOWk6A7VAQ1GQztjMmkhMOq-uYf2gNCwo,3692
412
412
  nat/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
413
413
  nat/settings/global_settings.py,sha256=JSlYnW3CeGJYGxlaXLz5F9xXf72I5TUz-w6qngG5di0,12476
@@ -469,10 +469,10 @@ nat/utils/reactive/base/observer_base.py,sha256=6BiQfx26EMumotJ3KoVcdmFBYR_fnAss
469
469
  nat/utils/reactive/base/subject_base.py,sha256=UQOxlkZTIeeyYmG5qLtDpNf_63Y7p-doEeUA08_R8ME,2521
470
470
  nat/utils/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
471
471
  nat/utils/settings/global_settings.py,sha256=9JaO6pxKT_Pjw6rxJRsRlFCXdVKCl_xUKU2QHZQWWNM,7294
472
- nvidia_nat-1.3.0a20250926.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
473
- nvidia_nat-1.3.0a20250926.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
474
- nvidia_nat-1.3.0a20250926.dist-info/METADATA,sha256=AsBqos4EXJEwd-G-QSIF4iXItfZ7ULCpgwh0QCVd1lQ,22862
475
- nvidia_nat-1.3.0a20250926.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
476
- nvidia_nat-1.3.0a20250926.dist-info/entry_points.txt,sha256=4jCqjyETMpyoWbCBf4GalZU8I_wbstpzwQNezdAVbbo,698
477
- nvidia_nat-1.3.0a20250926.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
478
- nvidia_nat-1.3.0a20250926.dist-info/RECORD,,
472
+ nvidia_nat-1.3.0a20250929.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
473
+ nvidia_nat-1.3.0a20250929.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
474
+ nvidia_nat-1.3.0a20250929.dist-info/METADATA,sha256=3m2Kp0KWB9oXqVUBoOmGesvEgobUxMEenPMV6W1llDo,22918
475
+ nvidia_nat-1.3.0a20250929.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
476
+ nvidia_nat-1.3.0a20250929.dist-info/entry_points.txt,sha256=4jCqjyETMpyoWbCBf4GalZU8I_wbstpzwQNezdAVbbo,698
477
+ nvidia_nat-1.3.0a20250929.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
478
+ nvidia_nat-1.3.0a20250929.dist-info/RECORD,,