nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,329 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+
19
+ from langchain_core.callbacks.base import AsyncCallbackHandler
20
+ from langchain_core.language_models import BaseChatModel
21
+ from langchain_core.messages.base import BaseMessage
22
+ from langchain_core.messages.human import HumanMessage
23
+ from langchain_core.prompts.chat import ChatPromptTemplate
24
+ from langchain_core.tools import BaseTool
25
+ from langgraph.graph import StateGraph
26
+ from pydantic import BaseModel
27
+ from pydantic import Field
28
+
29
+ from nat.agent.base import AGENT_CALL_LOG_MESSAGE
30
+ from nat.agent.base import AGENT_LOG_PREFIX
31
+ from nat.agent.base import BaseAgent
32
+
33
+ if typing.TYPE_CHECKING:
34
+ from nat.control_flow.router_agent.register import RouterAgentWorkflowConfig
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
class RouterAgentGraphState(BaseModel):
    """State carried between the nodes of the Router Agent graph.

    The agent node appends the router LLM's reply to ``messages`` and records the
    selected branch name in ``chosen_branch``. The branch node then sends the content
    of ``forward_message`` to that branch and appends the branch's response to
    ``messages``.

    Attributes:
        messages: Conversation history accumulated while the graph runs.
        forward_message: The message whose content is handed to the chosen branch.
        chosen_branch: Name of the branch picked by the router; empty until one is chosen.
    """
    # Each state instance gets its own fresh list (no shared mutable default).
    messages: list[BaseMessage] = Field(default_factory=list)
    # Defaults to an empty human message when the caller does not supply one.
    forward_message: BaseMessage = Field(default_factory=lambda: HumanMessage(content=""))
    # "" is the sentinel the graph nodes use for "no branch selected yet".
    chosen_branch: str = ""
53
+
54
+
55
class RouterAgentGraph(BaseAgent):
    """Configurable Router Agent for routing requests to different branches.

    A Router Agent analyzes incoming requests and routes them to one of the
    configured branches based on the content and context. It makes a single
    routing decision and executes only the selected branch before returning.

    This agent is useful for creating multi-path workflows where different
    types of requests need to be handled by specialized sub-agents or tools.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        branches: list[BaseTool],
        prompt: ChatPromptTemplate,
        max_router_retries: int = 3,
        callbacks: list[AsyncCallbackHandler] | None = None,
        detailed_logs: bool = False,
        log_response_max_chars: int = 1000,
    ):
        """Initialize the Router Agent.

        Args:
            llm: The language model to use for routing decisions.
            branches: List of tools/branches that the agent can route to.
            prompt: The chat prompt template for the routing agent. Its ``branches`` and
                ``branch_names`` variables are filled in here via ``partial``.
            max_router_retries: Maximum number of LLM attempts made to select a branch.
                Must be at least 1.
            callbacks: Optional list of async callback handlers.
            detailed_logs: Whether to enable detailed logging.
            log_response_max_chars: Maximum characters to log in responses.

        Raises:
            ValueError: If ``max_router_retries`` is less than 1 (the router would
                otherwise never call the LLM and every request would fail).
        """
        if max_router_retries < 1:
            raise ValueError(f"max_router_retries must be at least 1, got {max_router_retries}")

        super().__init__(llm=llm,
                         tools=branches,
                         callbacks=callbacks,
                         detailed_logs=detailed_logs,
                         log_response_max_chars=log_response_max_chars)

        self._branches = branches
        # Name -> tool lookup used by the branch node to resolve the chosen branch.
        self._branches_dict = {branch.name: branch for branch in branches}
        branch_names = ",".join(branch.name for branch in branches)
        branch_names_and_descriptions = "\n".join(f"{branch.name}: {branch.description}" for branch in branches)

        # Bake the branch catalogue into the prompt once; each call then only supplies
        # the request and the chat history.
        prompt = prompt.partial(branches=branch_names_and_descriptions, branch_names=branch_names)
        self.agent = prompt | self.llm

        self.max_router_retries = max_router_retries

    def _get_branch(self, branch_name: str) -> BaseTool | None:
        """Return the configured branch called ``branch_name``, or None if there is none."""
        return self._branches_dict.get(branch_name, None)

    def _match_branch(self, response_text: str) -> str | None:
        """Map the router LLM's reply onto a configured branch name.

        A case-insensitive exact match (after stripping surrounding whitespace) wins.
        Otherwise the longest branch name contained in the reply is chosen. This prevents a
        reply of ``web_search`` from being routed to a branch named ``search`` just because
        ``search`` happens to be configured first.

        Args:
            response_text: The text content of the router LLM's reply.

        Returns:
            The matched branch name, or None if no configured branch name occurs in the reply.
        """
        normalized = response_text.strip().lower()
        for branch in self._branches:
            if branch.name.lower() == normalized:
                return branch.name

        contained = [branch.name for branch in self._branches if branch.name.lower() in normalized]
        if not contained:
            return None
        # max() returns the first of equally long names, so ties keep configuration order.
        return max(contained, key=len)

    async def agent_node(self, state: RouterAgentGraphState):
        """Execute the agent node to select a branch for routing.

        This method asks the configured LLM which branch should handle the forwarded
        request. It retries up to ``max_router_retries`` times when the reply names
        no configured branch. Every LLM reply is appended to ``state.messages``.

        Args:
            state: The current state of the router agent graph.

        Returns:
            RouterAgentGraphState: Updated state with the chosen branch.

        Raises:
            RuntimeError: If the agent fails to choose a branch after max retries.
        """
        logger.debug("%s Starting the Router Agent Node", AGENT_LOG_PREFIX)
        chat_history = self._get_chat_history(state.messages)
        request = state.forward_message.content
        for attempt in range(1, self.max_router_retries + 1):
            try:
                agent_response = await self._call_llm(self.agent, {"request": request, "chat_history": chat_history})
            except Exception as ex:
                logger.error("%s Router Agent failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
                raise

            if self.detailed_logs:
                logger.info(AGENT_CALL_LOG_MESSAGE, request, agent_response)

            state.messages += [agent_response]

            # Decide from this reply alone; a pre-populated chosen_branch must not cause
            # the loop to keep calling the LLM without ever selecting anything.
            chosen = self._match_branch(str(agent_response.content))
            if chosen is not None:
                state.chosen_branch = chosen
                if self.detailed_logs:
                    logger.debug("%s Router Agent has chosen branch: %s", AGENT_LOG_PREFIX, chosen)
                return state

            # The agent failed to choose a branch
            if attempt == self.max_router_retries:
                logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
                raise RuntimeError("Router Agent failed to choose a branch")
            logger.warning("%s Router Agent failed to choose a branch, retrying %d out of %d",
                           AGENT_LOG_PREFIX,
                           attempt,
                           self.max_router_retries)

        # Only reachable if max_router_retries was changed to < 1 after construction; the
        # branch node then reports the missing branch.
        return state

    async def branch_node(self, state: RouterAgentGraphState):
        """Execute the selected branch with the forwarded message.

        This method calls the tool/branch that was selected by the agent node
        and appends the branch's response to ``state.messages``.

        Args:
            state: The current state containing the chosen branch and message.

        Returns:
            RouterAgentGraphState: Updated state with the branch response.

        Raises:
            RuntimeError: If no branch was chosen or branch execution fails.
            ValueError: If the requested tool is not found in the configuration.
        """
        logger.debug("%s Starting Router Agent Tool Node", AGENT_LOG_PREFIX)
        try:
            if state.chosen_branch == "":
                logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
                raise RuntimeError("Router Agent failed to choose a branch")
            requested_branch = self._get_branch(state.chosen_branch)
            if not requested_branch:
                logger.error("%s Router Agent wants to call tool %s but it is not in the config file",
                             AGENT_LOG_PREFIX,
                             state.chosen_branch)
                raise ValueError("Tool not found in config file")

            branch_input = state.forward_message.content
            branch_response = await self._call_tool(requested_branch, branch_input)
            state.messages += [branch_response]
            if self.detailed_logs:
                self._log_tool_response(requested_branch.name, branch_input, branch_response.content)

            return state

        except Exception as ex:
            logger.error("%s Router Agent throws exception during branch node execution: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def _build_graph(self, state_schema):
        """Build and compile the two-node graph (agent -> branch) and store it on ``self.graph``."""
        logger.debug("%s Building and compiling the Router Agent Graph", AGENT_LOG_PREFIX)

        graph = StateGraph(state_schema)
        graph.add_node("agent", self.agent_node)
        graph.add_node("branch", self.branch_node)
        graph.add_edge("agent", "branch")
        graph.set_entry_point("agent")

        self.graph = graph.compile()
        logger.debug("%s Router Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)

        return self.graph

    async def build_graph(self):
        """Build and compile the router agent execution graph.

        Creates a state graph with agent and branch nodes, configures the
        execution flow, and compiles the graph for execution.

        Returns:
            The compiled execution graph.

        Raises:
            Exception: If graph building or compilation fails.
        """
        try:
            await self._build_graph(state_schema=RouterAgentGraphState)
            return self.graph
        except Exception as ex:
            logger.error("%s Router Agent failed to build graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    @staticmethod
    def validate_system_prompt(system_prompt: str) -> bool:
        """Validate that the system prompt contains required variables.

        Checks that the system prompt includes the ``{branches}`` and
        ``{branch_names}`` template variables the router agent fills in.

        Args:
            system_prompt: The system prompt string to validate.

        Returns:
            True if the prompt is valid, False otherwise (the problems are logged).
        """
        errors = []
        required_prompt_variables = {
            "{branches}": "The system prompt must contain {branches} so the agent knows about configured branches.",
            "{branch_names}": "The system prompt must contain {branch_names} so the agent knows branch names."
        }
        for variable_name, error_message in required_prompt_variables.items():
            if variable_name not in system_prompt:
                errors.append(error_message)
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            return False
        return True

    @staticmethod
    def validate_user_prompt(user_prompt: str) -> bool:
        """Validate that the user prompt contains required variables.

        Checks that the user prompt is non-empty and includes the
        ``{chat_history}`` and ``{request}`` template variables.

        Args:
            user_prompt: The user prompt string to validate.

        Returns:
            True if the prompt is valid, False otherwise (the problems are logged).
        """
        errors = []
        if not user_prompt:
            errors.append("The user prompt cannot be empty.")
        else:
            required_prompt_variables = {
                "{chat_history}":
                    "The user prompt must contain {chat_history} so the agent knows about the conversation history.",
                "{request}":
                    "The user prompt must contain {request} so the agent sees the current request.",
            }
            for variable_name, error_message in required_prompt_variables.items():
                if variable_name not in user_prompt:
                    errors.append(error_message)
        if errors:
            error_text = "\n".join(errors)
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            return False
        return True
289
+
290
+
291
def create_router_agent_prompt(config: "RouterAgentWorkflowConfig") -> ChatPromptTemplate:
    """Assemble the Router Agent chat prompt from its workflow configuration.

    Non-empty ``config.system_prompt`` / ``config.user_prompt`` values override the
    package defaults; otherwise the defaults from
    ``nat.control_flow.router_agent.prompt`` are used. The system prompt is validated
    first, then the user prompt, before the template is built.

    Args:
        config: Router agent workflow configuration supplying optional prompt overrides.

    Returns:
        A ``ChatPromptTemplate`` consisting of a system message followed by a user message.

    Raises:
        ValueError: If the system prompt or the user prompt fails validation.
    """
    from nat.control_flow.router_agent import prompt as default_prompts

    # Falsy overrides (None or "") fall back to the shipped defaults.
    system_prompt = config.system_prompt or default_prompts.SYSTEM_PROMPT
    user_prompt = config.user_prompt or default_prompts.USER_PROMPT

    system_ok = RouterAgentGraph.validate_system_prompt(system_prompt)
    if not system_ok:
        logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid system_prompt")

    user_ok = RouterAgentGraph.validate_user_prompt(user_prompt)
    if not user_ok:
        logger.error("%s Invalid user_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid user_prompt")

    messages = [("system", system_prompt), ("user", user_prompt)]
    return ChatPromptTemplate(messages)
@@ -0,0 +1,48 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
# Default system prompt for the Router Agent's branch-selection LLM call.
# Template variables (both must be present; the router's validate_system_prompt checks for them):
#   {branches}     - one "name: description" line per configured branch
#   {branch_names} - the configured branch names joined with ","
# RouterAgentGraph fills both in once via ChatPromptTemplate.partial. The router then matches
# the LLM reply against the branch names, which is why the prompt insists on a bare branch name.
SYSTEM_PROMPT = """
You are a Router Agent responsible for analyzing incoming requests and routing them to the most appropriate branch.

Available branches:
{branches}

CRITICAL INSTRUCTIONS:
- Analyze the user's request carefully
- Select exactly ONE branch that best handles the request from: [{branch_names}]
- Respond with ONLY the exact branch name, nothing else
- Be decisive - choose the single best match, if the request could fit multiple branches,
choose the most specific/specialized one
- If no branch perfectly fits, choose the closest match

Your response MUST contain ONLY the branch name. Do not include any explanations, reasoning, or additional text.

Examples:
User: "How do I calculate 15 + 25?"
Response: calculator_tool

User: "What's the weather like today?"
Response: weather_service

User: "Send an email to John"
Response: email_tool"""

# Default user prompt for the branch-selection call.
# Template variables (both required by the router's validate_user_prompt):
#   {chat_history} - prior conversation, supplied on each call
#   {request}      - content of the message being routed
USER_PROMPT = """
Previous conversation history:
{chat_history}

To respond to the request: {request}, which branch should be chosen?

Respond with only the branch name."""
@@ -0,0 +1,91 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from pydantic import Field
19
+
20
+ from nat.builder.builder import Builder
21
+ from nat.builder.framework_enum import LLMFrameworkEnum
22
+ from nat.builder.function_info import FunctionInfo
23
+ from nat.cli.register_workflow import register_function
24
+ from nat.data_models.agent import AgentBaseConfig
25
+ from nat.data_models.component_ref import FunctionRef
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class RouterAgentWorkflowConfig(AgentBaseConfig, name="router_agent"):
    """
    A router agent takes in the incoming message, combines it with a prompt and the list of branches,
    and asks an LLM which branch to take.
    """
    description: str = Field(default="Router Agent Workflow", description="Description of this functions use.")
    # Function references resolved into the tools the router can dispatch to.
    branches: list[FunctionRef] = Field(default_factory=list,
                                        description="The list of branches to provide to the router agent.")
    # None (or an empty string) means "use the packaged default prompt".
    system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
    user_prompt: str | None = Field(default=None, description="Provides the prompt to use with the agent.")
    # Number of LLM attempts to pick a branch. Values below 1 would mean the LLM is never
    # consulted and every request fails, so they are rejected at config-validation time.
    max_router_retries: int = Field(
        default=3, ge=1, description="Maximum number of retries if the router agent fails to choose a branch.")
42
+
43
+
44
@register_function(config_type=RouterAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Builder):
    """Build the Router Agent graph and expose it as a NAT function.

    The prompt is created from ``config``. The LLM and the branch tools are resolved
    through ``builder`` with the LangChain wrapper. The routing graph is compiled once,
    and a ``FunctionInfo`` wrapping a string-in/string-out response function is yielded.

    Args:
        config: Router agent workflow configuration.
        builder: Builder used to resolve the LLM and branch tools.

    Yields:
        FunctionInfo for the router's response function.

    Raises:
        ValueError: If the configured prompts are invalid or no branches are resolved.
    """
    from langchain_core.messages.human import HumanMessage
    from langgraph.graph.state import CompiledStateGraph

    from nat.agent.base import AGENT_LOG_PREFIX
    from nat.control_flow.router_agent.agent import RouterAgentGraph
    from nat.control_flow.router_agent.agent import RouterAgentGraphState
    from nat.control_flow.router_agent.agent import create_router_agent_prompt

    router_prompt = create_router_agent_prompt(config)
    router_llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    branch_tools = await builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    if not branch_tools:
        raise ValueError(f"No branches specified for Router Agent '{config.llm_name}'")

    router = RouterAgentGraph(
        llm=router_llm,
        branches=branch_tools,
        prompt=router_prompt,
        max_router_retries=config.max_router_retries,
        detailed_logs=config.verbose,
        log_response_max_chars=config.log_response_max_chars,
    )
    compiled_graph: CompiledStateGraph = await router.build_graph()

    async def _response_fn(input_message: str) -> str:
        # Failures are logged and returned as text instead of being raised to the caller.
        try:
            initial_state = RouterAgentGraphState(forward_message=HumanMessage(content=input_message))
            final_state = RouterAgentGraphState(**(await compiled_graph.ainvoke(initial_state)))
            # The branch node appends the selected branch's response last.
            return str(final_state.messages[-1].content)
        except Exception as ex:
            logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
            return str(ex) if config.verbose else f"Router agent failed with exception: {ex}"

    try:
        yield FunctionInfo.from_fn(_response_fn, description=config.description)
    except GeneratorExit:
        logger.exception("%s Workflow exited early!", AGENT_LOG_PREFIX)
    finally:
        logger.debug("%s Cleaning up router_agent workflow.", AGENT_LOG_PREFIX)
@@ -0,0 +1,166 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+
19
+ from langchain_core.tools.base import BaseTool
20
+ from pydantic import BaseModel
21
+ from pydantic import Field
22
+
23
+ from nat.builder.builder import Builder
24
+ from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.builder.function import Function
26
+ from nat.builder.function_info import FunctionInfo
27
+ from nat.cli.register_workflow import register_function
28
+ from nat.data_models.component_ref import FunctionRef
29
+ from nat.data_models.function import FunctionBaseConfig
30
+ from nat.utils.type_utils import DecomposedType
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class ToolExecutionConfig(BaseModel):
    """Per-tool settings the sequential executor applies when it invokes a tool."""

    # When True, the executor reads the tool through its streaming interface instead of a single call.
    use_streaming: bool = Field(False, description="Whether to use streaming output for the tool.")
39
+
40
+
41
class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
    """Configuration for sequential execution of a list of functions."""

    # Functions run in list order; each one receives the previous one's output as its input.
    tool_list: list[FunctionRef] = Field(default_factory=list,
                                         description="A list of functions to execute sequentially.")
    # Optional per-tool overrides, keyed by the tool names used in ``tool_list``.
    # The adjacent literals below include their separating spaces so the rendered description reads correctly.
    tool_execution_config: dict[str, ToolExecutionConfig] = Field(default_factory=dict,
                                                                  description="Optional configuration for each "
                                                                  "tool in the sequential execution tool list. "
                                                                  "Keys must match the tool names from the "
                                                                  "tool_list.")
    raise_type_incompatibility: bool = Field(
        default=False,
        description="Default to False. Check if the adjacent tools are type compatible, "
        "which means the output type of the previous function is compatible with the input type of the next "
        "function. If set to True, any incompatibility will raise an exception. If set to False, the "
        "incompatibility will only generate a warning message and the sequential execution will continue.")
57
+
58
+
59
def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
    """Return the output type ``function`` produces inside the sequential executor.

    The streaming output type applies only when ``tool_execution_config`` has an entry for the function's
    instance name with ``use_streaming`` enabled. In every other case the single output type is returned.
    """
    exec_cfg = tool_execution_config.get(function.instance_name)
    if exec_cfg and exec_cfg.use_streaming:
        return function.streaming_output_type
    return function.single_output_type
65
+
66
+
67
def _validate_function_type_compatibility(src_fn: Function,
                                          target_fn: Function,
                                          tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
    """Check that the output of ``src_fn`` can be passed as the input of ``target_fn``.

    Args:
        src_fn: The function whose output is handed on.
        target_fn: The function that receives ``src_fn``'s output as its input.
        tool_execution_config: Per-tool execution settings; they decide whether ``src_fn``'s streaming or
            single output type is compared.

    Raises:
        ValueError: If ``DecomposedType.is_type_compatible`` reports the two types as incompatible.
    """
    src_output_type = _get_function_output_type(src_fn, tool_execution_config)
    target_input_type = target_fn.input_type
    # Log the pair under test. The verdict is not known yet, so do not claim incompatibility here.
    logger.debug("Checking type compatibility: the output type of the %s function is %s, "
                 "the input type of the %s function is %s",
                 src_fn.instance_name,
                 src_output_type,
                 target_fn.instance_name,
                 target_input_type)

    if not DecomposedType.is_type_compatible(src_output_type, target_input_type):
        raise ValueError(
            f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible "
            f"with the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
81
+
82
+
83
async def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
                                                 builder: Builder) -> tuple[type, type]:
    """Verify that every adjacent pair of functions in the tool list is type compatible.

    Args:
        sequential_executor_config: Supplies the tool list and the per-tool execution settings.
        builder: Builder used to resolve the tool list into ``Function`` objects.

    Returns:
        A ``(input_type, output_type)`` tuple: the first function's input type and the last function's
        output type (streaming or single, per its execution settings).

    Raises:
        RuntimeError: If the tool list resolves to no functions.
        ValueError: If any adjacent pair is incompatible. The original error is chained as the cause.
    """
    tool_execution_config = sequential_executor_config.tool_execution_config

    function_list = await builder.get_functions(sequential_executor_config.tool_list)
    if not function_list:
        raise RuntimeError("The function list is empty")
    input_type = function_list[0].input_type

    # Pair each function with its successor. With a single function, zip yields nothing, so no length
    # guard is needed.
    for src_fn, target_fn in zip(function_list, function_list[1:]):
        try:
            _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
        except ValueError as e:
            raise ValueError(f"The sequential tool list has incompatible types: {e}") from e

    output_type = _get_function_output_type(function_list[-1], tool_execution_config)
    logger.debug("The input type of the sequential executor tool list is %s, the output type is %s",
                 input_type,
                 output_type)

    return input_type, output_type
105
+
106
+
107
@register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
    """Expose ``config.tool_list`` as one function that runs the tools as a pipeline.

    The tools are fetched as LangChain ``BaseTool`` wrappers and invoked in list order. Each tool
    receives the previous tool's response as its input, and the last tool's response is returned.

    Adjacent tools are type-checked once, up front. If they are incompatible and
    ``config.raise_type_incompatibility`` is set, the error propagates. Otherwise a warning is logged
    and the function's input/return annotations fall back to ``typing.Any``.

    Args:
        config: The sequential executor configuration.
        builder: Builder used to resolve the configured tools and functions.

    Yields:
        A ``FunctionInfo`` wrapping the sequential execution coroutine.

    Raises:
        ValueError: On incompatible types (when ``raise_type_incompatibility`` is set), or on any other
            error raised while validating the tool list.
    """
    logger.debug("Initializing sequential executor with tool list: %s", config.tool_list)

    tools: list[BaseTool] = await builder.get_tools(tool_names=config.tool_list,
                                                    wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}

    try:
        input_type, output_type = await _validate_tool_list_type_compatibility(config, builder)
    except ValueError as e:
        if config.raise_type_incompatibility:
            logger.error("The sequential executor tool list has incompatible types: %s", e)
            raise
        logger.warning("The sequential executor tool list has incompatible types: %s", e)
        # Without a consistent chain of types, fall back to untyped annotations.
        input_type = typing.Any
        output_type = typing.Any
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise ValueError(f"Error with the sequential executor tool list: {e}") from e

    async def _run_tool(tool: BaseTool, tool_input, use_streaming: bool):
        """Invoke one tool, either with a single call or by draining its stream into one string."""
        if not use_streaming:
            return await tool.ainvoke(tool_input)
        # NOTE(review): message chunks expose ``.content``, but a chunk without it (presumably e.g. a plain
        # value yielded by a tool's default ``astream`` — confirm) is stringified instead of raising
        # AttributeError.
        parts: list[str] = []
        async for chunk in tool.astream(tool_input):
            parts.append(str(getattr(chunk, "content", chunk)))
        return "".join(parts)

    # The type annotation of _sequential_function_execution is dynamically set according to the tool list
    async def _sequential_function_execution(initial_tool_input):
        logger.debug("Executing sequential executor with tool list: %s", config.tool_list)

        tool_input = initial_tool_input
        tool_response = None

        for tool_name in config.tool_list:
            tool = tools_dict[tool_name]
            exec_config = config.tool_execution_config.get(tool_name)
            use_streaming = bool(exec_config) and exec_config.use_streaming
            logger.debug("Executing tool %s with input: %s", tool_name, tool_input)
            try:
                tool_response = await _run_tool(tool, tool_input, use_streaming)
            except Exception as e:
                logger.error("Error with tool %s: %s", tool_name, e)
                raise

            # The input of the next tool is the response of the previous tool
            tool_input = tool_response

        return tool_response

    # Dynamically set the annotations for the function
    _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
    logger.debug("Sequential executor function annotations: %s", _sequential_function_execution.__annotations__)

    yield FunctionInfo.from_fn(_sequential_function_execution,
                               description="Executes a list of functions sequentially. "
                               "The input of the next tool is the response of the previous tool.")
@@ -0,0 +1,34 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+ from pydantic import PositiveInt
18
+
19
+ from nat.data_models.component_ref import LLMRef
20
+ from nat.data_models.function import FunctionBaseConfig
21
+
22
+
23
class AgentBaseConfig(FunctionBaseConfig):
    """Common configuration fields shared by NAT agent configurations."""

    # Optional alternate name, used when the agent is run as a workflow but must appear as a tool.
    workflow_alias: str | None = Field(
        None,
        description="The alias of the workflow. Useful when the agent is configured as a workflow "
        "and needs to expose a customized name as a tool.",
    )
    # Required reference to the LLM that drives the agent.
    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
    verbose: bool = Field(False, description="Set the verbosity of the agent's logging.")
    # Required human-readable description of the function.
    description: str = Field(description="The description of this function's use.")
    # Logged responses are cut to this many characters; must be a positive integer.
    log_response_max_chars: PositiveInt = Field(
        1000,
        description="Maximum number of characters to display in logs when logging responses.",
    )