nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
aiq/__init__.py CHANGED
@@ -13,10 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import sys
17
16
  import importlib
18
17
  import importlib.abc
19
18
  import importlib.util
19
+ import sys
20
20
  import warnings
21
21
 
22
22
 
@@ -26,7 +26,7 @@ class CompatFinder(importlib.abc.MetaPathFinder):
26
26
  self.alias_prefix = alias_prefix
27
27
  self.target_prefix = target_prefix
28
28
 
29
- def find_spec(self, fullname, path, target=None): # pylint: disable=unused-argument
29
+ def find_spec(self, fullname, path, target=None):
30
30
  if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
31
31
  # Map aiq.something -> nat.something
32
32
  target_name = self.target_prefix + fullname[len(self.alias_prefix):]
nat/agent/base.py CHANGED
@@ -27,9 +27,10 @@ from langchain_core.language_models import BaseChatModel
27
27
  from langchain_core.messages import AIMessage
28
28
  from langchain_core.messages import BaseMessage
29
29
  from langchain_core.messages import ToolMessage
30
+ from langchain_core.runnables import Runnable
30
31
  from langchain_core.runnables import RunnableConfig
31
32
  from langchain_core.tools import BaseTool
32
- from langgraph.graph.graph import CompiledGraph
33
+ from langgraph.graph.state import CompiledStateGraph
33
34
 
34
35
  logger = logging.getLogger(__name__)
35
36
 
@@ -70,12 +71,14 @@ class BaseAgent(ABC):
70
71
  llm: BaseChatModel,
71
72
  tools: list[BaseTool],
72
73
  callbacks: list[AsyncCallbackHandler] | None = None,
73
- detailed_logs: bool = False) -> None:
74
+ detailed_logs: bool = False,
75
+ log_response_max_chars: int = 1000) -> None:
74
76
  logger.debug("Initializing Agent Graph")
75
77
  self.llm = llm
76
78
  self.tools = tools
77
79
  self.callbacks = callbacks or []
78
80
  self.detailed_logs = detailed_logs
81
+ self.log_response_max_chars = log_response_max_chars
79
82
  self.graph = None
80
83
 
81
84
  async def _stream_llm(self,
@@ -105,21 +108,25 @@ class BaseAgent(ABC):
105
108
 
106
109
  return AIMessage(content=output_message)
107
110
 
108
- async def _call_llm(self, messages: list[BaseMessage]) -> AIMessage:
111
+ async def _call_llm(self, llm: Runnable, inputs: dict[str, Any], config: RunnableConfig | None = None) -> AIMessage:
109
112
  """
110
113
  Call the LLM directly. Retry logic is handled automatically by the underlying LLM client.
111
114
 
112
115
  Parameters
113
116
  ----------
114
- messages : list[BaseMessage]
115
- The messages to send to the LLM
117
+ llm : Runnable
118
+ The LLM runnable (prompt | llm or similar)
119
+ inputs : dict[str, Any]
120
+ The inputs to pass to the runnable
121
+ config : RunnableConfig | None
122
+ The config to pass to the runnable (should include callbacks)
116
123
 
117
124
  Returns
118
125
  -------
119
126
  AIMessage
120
127
  The LLM response
121
128
  """
122
- response = await self.llm.ainvoke(messages)
129
+ response = await llm.ainvoke(inputs, config=config)
123
130
  return AIMessage(content=str(response.content))
124
131
 
125
132
  async def _call_tool(self,
@@ -158,6 +165,11 @@ class BaseAgent(ABC):
158
165
  tool_call_id=tool.name,
159
166
  content=f"The tool {tool.name} provided an empty response.")
160
167
 
168
+ # ToolMessage only accepts str or list[str | dict] as content.
169
+ # Convert into list if the response is a dict.
170
+ if isinstance(response, dict):
171
+ response = [response]
172
+
161
173
  return ToolMessage(name=tool.name, tool_call_id=tool.name, content=response)
162
174
 
163
175
  except Exception as e:
@@ -179,13 +191,12 @@ class BaseAgent(ABC):
179
191
  logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
180
192
  await asyncio.sleep(sleep_time)
181
193
 
182
- # pylint: disable=C0209
183
194
  # All retries exhausted, return error message
184
- error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
185
- logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
195
+ error_content = f"Tool call failed after all retry attempts. Last error: {str(last_exception)}"
196
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_content, exc_info=True)
186
197
  return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
187
198
 
188
- def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
199
+ def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str) -> None:
189
200
  """
190
201
  Log tool response with consistent formatting and length limits.
191
202
 
@@ -197,13 +208,11 @@ class BaseAgent(ABC):
197
208
  The input that was passed to the tool
198
209
  tool_response : str
199
210
  The response from the tool
200
- max_chars : int
201
- Maximum number of characters to log (default: 1000)
202
211
  """
203
212
  if self.detailed_logs:
204
213
  # Truncate tool response if too long
205
- display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
206
- tool_response) > max_chars else tool_response
214
+ display_response = tool_response[:self.log_response_max_chars] + "...(rest of response truncated)" if len(
215
+ tool_response) > self.log_response_max_chars else tool_response
207
216
 
208
217
  # Format the tool input for display
209
218
  tool_input_str = str(tool_input)
@@ -252,5 +261,5 @@ class BaseAgent(ABC):
252
261
  return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
253
262
 
254
263
  @abstractmethod
255
- async def _build_graph(self, state_schema: type) -> CompiledGraph:
264
+ async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
256
265
  pass
nat/agent/dual_node.py CHANGED
@@ -20,7 +20,7 @@ from langchain_core.callbacks import AsyncCallbackHandler
20
20
  from langchain_core.language_models import BaseChatModel
21
21
  from langchain_core.tools import BaseTool
22
22
  from langgraph.graph import StateGraph
23
- from langgraph.graph.graph import CompiledGraph
23
+ from langgraph.graph.state import CompiledStateGraph
24
24
  from pydantic import BaseModel
25
25
 
26
26
  from .base import AgentDecision
@@ -35,8 +35,13 @@ class DualNodeAgent(BaseAgent):
35
35
  llm: BaseChatModel,
36
36
  tools: list[BaseTool],
37
37
  callbacks: list[AsyncCallbackHandler] | None = None,
38
- detailed_logs: bool = False):
39
- super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
38
+ detailed_logs: bool = False,
39
+ log_response_max_chars: int = 1000):
40
+ super().__init__(llm=llm,
41
+ tools=tools,
42
+ callbacks=callbacks,
43
+ detailed_logs=detailed_logs,
44
+ log_response_max_chars=log_response_max_chars)
40
45
 
41
46
  @abstractmethod
42
47
  async def agent_node(self, state: BaseModel) -> BaseModel:
@@ -50,7 +55,7 @@ class DualNodeAgent(BaseAgent):
50
55
  async def conditional_edge(self, state: BaseModel) -> str:
51
56
  pass
52
57
 
53
- async def _build_graph(self, state_schema) -> CompiledGraph:
58
+ async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
54
59
  log.debug("Building and compiling the Agent Graph")
55
60
 
56
61
  graph = StateGraph(state_schema)
nat/agent/prompt_optimizer/prompt.py ADDED
@@ -0,0 +1,68 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa W291
16
+
17
+ mutator_prompt = """
18
+
19
+ ## CORE DIRECTIVES
20
+ - **Preserve the original objective and task.** Do not change what the prompt is meant to accomplish.
21
+ - **Keep the intent intact.** The improved prompt must solve the same problem as the original.
22
+ - **Do not invent new goals.** Only improve clarity, structure, constraints, and usability.
23
+ - **Do not drop critical instructions.** Everything essential from the original prompt must remain.
24
+ - **Return only the mutated prompt text.** No rationale, no diffs, no explanations.
25
+ - **Be Creative within bounds.** You may rephrase, reorganize, and enhance, but not alter meaning.
26
+ - **DO NOT use curly braces in your prompt** for anything other than existing variables in the prompt as the string
27
+ will be treated as an f-string.
28
+ - **Examples are a good idea** if the original prompt lacks them. They help clarify expected output.
29
+
30
+ ---
31
+
32
+ ## IMPROVEMENT HINTS
33
+ When modifying, apply these principles:
34
+ 1. **Clarity & Precision** – remove vague language, strengthen directives.
35
+ 2. **Structure & Flow** – order sections as: *Objective → Constraints → Tools → Steps → Output Schema → Examples*.
36
+ 3. **Schema Adherence** – enforce a single canonical output schema (JSON/XML) with `schema_version`.
37
+ 4. **Tool Governance** – clarify when/how tools are used, their inputs/outputs, and fallback behavior.
38
+ 5. **Error Handling** – specify behavior if tools fail or inputs are insufficient.
39
+ 6. **Budget Awareness** – minimize verbosity, respect token/latency limits.
40
+ 7. **Safety** – include refusals for unsafe requests, enforce compliance with rules.
41
+ 8. **Consistency** – avoid format drift; always maintain the same schema.
42
+ 9. **Integrity** – confirm the task, objective, and intent are preserved.
43
+
44
+ ---
45
+
46
+ ## MUTATION OPERATORS
47
+ You may:
48
+ - **Tighten** (remove fluff, redundancies)
49
+ - **Reorder** (improve logical flow)
50
+ - **Constrain** (add explicit rules/limits)
51
+ - **Harden** (improve error handling/fallbacks)
52
+ - **Defuse** (replace ambiguous verbs with measurable actions)
53
+ - **Format-lock** (wrap outputs in JSON/XML fenced blocks)
54
+ - **Example-ify** (add examples if missing or weak)
55
+
56
+ ---
57
+
58
+ ## INPUT
59
+ Here is the prompt to mutate:
60
+ {original_prompt}
61
+
62
+ ## OBJECTIVE
63
+ The prompt must acheive the following objective:
64
+ {objective}
65
+
66
+ The modified prompt is: \n
67
+
68
+ """
nat/agent/prompt_optimizer/register.py ADDED
@@ -0,0 +1,149 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+
18
+ from nat.builder.builder import Builder
19
+ from nat.builder.framework_enum import LLMFrameworkEnum
20
+ from nat.builder.function_info import FunctionInfo
21
+ from nat.cli.register_workflow import register_function
22
+ from nat.data_models.component_ref import LLMRef
23
+ from nat.data_models.function import FunctionBaseConfig
24
+ from nat.profiler.parameter_optimization.prompt_optimizer import PromptOptimizerInputSchema
25
+
26
+
27
+ class PromptOptimizerConfig(FunctionBaseConfig, name="prompt_init"):
28
+
29
+ optimizer_llm: LLMRef = Field(description="LLM to use for prompt optimization")
30
+ optimizer_prompt: str = Field(
31
+ description="Prompt template for the optimizer",
32
+ default=(
33
+ "You are an expert at optimizing prompts for LLMs. "
34
+ "Your task is to take a given prompt and suggest an optimized version of it. "
35
+ "Note that the prompt might be a template with variables and curly braces. Remember to always keep the "
36
+ "variables and curly braces in the prompt the same. Only modify the instructions in the prompt that are"
37
+ "not variables. The system is meant to achieve the following objective\n"
38
+ "{system_objective}\n Of which, the prompt is one part. The details of the prompt and context as below.\n"))
39
+ system_objective: str = Field(description="Objective of the workflow")
40
+
41
+
42
+ @register_function(config_type=PromptOptimizerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
43
+ async def prompt_optimizer_function(config: PromptOptimizerConfig, builder: Builder):
44
+ """
45
+ Function to optimize prompts for LLMs.
46
+ """
47
+
48
+ try:
49
+ from langchain_core.prompts import PromptTemplate
50
+
51
+ from .prompt import mutator_prompt
52
+ except ImportError as exc:
53
+ raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
54
+ "This error can be resolve by installing nvidia-nat[langchain]") from exc
55
+
56
+ llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
57
+
58
+ template = PromptTemplate(template=config.optimizer_prompt,
59
+ input_variables=["system_objective"],
60
+ validate_template=True)
61
+
62
+ base_prompt: str = (await template.ainvoke(input={"system_objective": config.system_objective})).to_string()
63
+ prompt_extension_template = PromptTemplate(template=mutator_prompt,
64
+ input_variables=["original_prompt", "objective"],
65
+ validate_template=True)
66
+
67
+ async def _inner(input_message: PromptOptimizerInputSchema) -> str:
68
+ """
69
+ Optimize the prompt using the provided LLM.
70
+ """
71
+
72
+ original_prompt = input_message.original_prompt
73
+ prompt_objective = input_message.objective
74
+
75
+ prompt_extension = (await prompt_extension_template.ainvoke(input={
76
+ "original_prompt": original_prompt,
77
+ "objective": prompt_objective,
78
+ })).to_string()
79
+
80
+ prompt = f"{base_prompt}\n\n{prompt_extension}"
81
+
82
+ optimized_prompt = await llm.ainvoke(prompt)
83
+ return optimized_prompt.content
84
+
85
+ yield FunctionInfo.from_fn(
86
+ fn=_inner,
87
+ description="Optimize prompts for LLMs using a feedback LLM.",
88
+ )
89
+
90
+
91
+ class PromptRecombinerConfig(FunctionBaseConfig, name="prompt_recombiner"):
92
+
93
+ optimizer_llm: LLMRef = Field(description="LLM to use for prompt recombination")
94
+ optimizer_prompt: str = Field(
95
+ description="Prompt template for the recombiner",
96
+ default=("You are an expert at combining prompt instructions for LLMs. "
97
+ "Your task is to merge two prompts for the same objective into a single, stronger prompt. "
98
+ "Do not introduce new variables or modify existing placeholders."),
99
+ )
100
+ system_objective: str = Field(description="Objective of the workflow")
101
+
102
+
103
+ @register_function(config_type=PromptRecombinerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
104
+ async def prompt_recombiner_function(config: PromptRecombinerConfig, builder: Builder):
105
+ """
106
+ Function to recombine two parent prompts into a child prompt using the optimizer LLM.
107
+ Uses the same base template and objective instructions.
108
+ """
109
+
110
+ try:
111
+ from langchain_core.prompts import PromptTemplate
112
+ except ImportError as exc:
113
+ raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
114
+ "This error can be resolve by installing nvidia-nat[langchain].") from exc
115
+
116
+ llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
117
+
118
+ template = PromptTemplate(template=config.optimizer_prompt,
119
+ input_variables=["system_objective"],
120
+ validate_template=True)
121
+
122
+ base_prompt: str = (await template.ainvoke(input={"system_objective": config.system_objective})).to_string()
123
+
124
+ class RecombineSchema(PromptOptimizerInputSchema):
125
+ parent_b: str | None = None
126
+
127
+ async def _inner(input_message: RecombineSchema) -> str:
128
+ parent_a = input_message.original_prompt
129
+ parent_b = input_message.parent_b or ""
130
+ prompt_objective = input_message.objective
131
+
132
+ prompt = (
133
+ f"{base_prompt}\n\n"
134
+ "We are performing genetic recombination between two prompts that satisfy the same objective.\n"
135
+ f"Objective: {prompt_objective}\n\n"
136
+ f"Parent A:\n{parent_a}\n\n"
137
+ f"Parent B:\n{parent_b}\n\n"
138
+ "Combine the strongest instructions and phrasing from both parents to produce a single, coherent child "
139
+ "prompt.\n"
140
+ "Maintain variables and placeholders unchanged.\n"
141
+ "Return only the child prompt text, with no additional commentary.")
142
+
143
+ child_prompt = await llm.ainvoke(prompt)
144
+ return child_prompt.content
145
+
146
+ yield FunctionInfo.from_fn(
147
+ fn=_inner,
148
+ description="Recombine two prompts into a stronger child prompt.",
149
+ )
nat/agent/react_agent/agent.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- # pylint: disable=R0917
18
17
  import logging
18
+ import re
19
19
  import typing
20
20
  from json import JSONDecodeError
21
21
 
@@ -23,12 +23,14 @@ from langchain_core.agents import AgentAction
23
23
  from langchain_core.agents import AgentFinish
24
24
  from langchain_core.callbacks.base import AsyncCallbackHandler
25
25
  from langchain_core.language_models import BaseChatModel
26
+ from langchain_core.language_models import LanguageModelInput
26
27
  from langchain_core.messages.ai import AIMessage
27
28
  from langchain_core.messages.base import BaseMessage
28
29
  from langchain_core.messages.human import HumanMessage
29
30
  from langchain_core.messages.tool import ToolMessage
30
31
  from langchain_core.prompts import ChatPromptTemplate
31
32
  from langchain_core.prompts import MessagesPlaceholder
33
+ from langchain_core.runnables import Runnable
32
34
  from langchain_core.runnables.config import RunnableConfig
33
35
  from langchain_core.tools import BaseTool
34
36
  from pydantic import BaseModel
@@ -57,6 +59,7 @@ class ReActGraphState(BaseModel):
57
59
  messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReAct Agent
58
60
  agent_scratchpad: list[AgentAction] = Field(default_factory=list) # agent thoughts / intermediate steps
59
61
  tool_responses: list[BaseMessage] = Field(default_factory=list) # the responses from any tool calls
62
+ final_answer: str | None = Field(default=None) # the final answer from the ReAct Agent
60
63
 
61
64
 
62
65
  class ReActAgentGraph(DualNodeAgent):
@@ -71,15 +74,22 @@ class ReActAgentGraph(DualNodeAgent):
71
74
  use_tool_schema: bool = True,
72
75
  callbacks: list[AsyncCallbackHandler] | None = None,
73
76
  detailed_logs: bool = False,
77
+ log_response_max_chars: int = 1000,
74
78
  retry_agent_response_parsing_errors: bool = True,
75
79
  parse_agent_response_max_retries: int = 1,
76
80
  tool_call_max_retries: int = 1,
77
- pass_tool_call_errors_to_agent: bool = True):
78
- super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
81
+ pass_tool_call_errors_to_agent: bool = True,
82
+ normalize_tool_input_quotes: bool = True):
83
+ super().__init__(llm=llm,
84
+ tools=tools,
85
+ callbacks=callbacks,
86
+ detailed_logs=detailed_logs,
87
+ log_response_max_chars=log_response_max_chars)
79
88
  self.parse_agent_response_max_retries = (parse_agent_response_max_retries
80
89
  if retry_agent_response_parsing_errors else 1)
81
90
  self.tool_call_max_retries = tool_call_max_retries
82
91
  self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
92
+ self.normalize_tool_input_quotes = normalize_tool_input_quotes
83
93
  logger.debug(
84
94
  "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
85
95
  AGENT_LOG_PREFIX)
@@ -97,21 +107,33 @@ class ReActAgentGraph(DualNodeAgent):
97
107
  f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
98
108
  prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
99
109
  # construct the ReAct Agent
100
- bound_llm = llm.bind(stop=["Observation:"]) # type: ignore
101
- self.agent = prompt | bound_llm
110
+ self.agent = prompt | self._maybe_bind_llm_and_yield()
102
111
  self.tools_dict = {tool.name: tool for tool in tools}
103
112
  logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
104
113
 
114
+ def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
115
+ """
116
+ Bind additional parameters to the LLM if needed
117
+ - if the LLM is a smart model, no need to bind any additional parameters
118
+ - if the LLM is a non-smart model, bind a stop sequence to the LLM
119
+
120
+ Returns:
121
+ Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
122
+ """
123
+ # models that don't need (or don't support)a stop sequence
124
+ smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
125
+ if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
126
+ # no need to bind any additional parameters to the LLM
127
+ return self.llm
128
+ # add a stop sequence to the LLM
129
+ return self.llm.bind(stop=["Observation:"])
130
+
105
131
  def _get_tool(self, tool_name: str):
106
132
  try:
107
133
  return self.tools_dict.get(tool_name)
108
134
  except Exception as ex:
109
- logger.exception("%s Unable to find tool with the name %s\n%s",
110
- AGENT_LOG_PREFIX,
111
- tool_name,
112
- ex,
113
- exc_info=True)
114
- raise ex
135
+ logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
136
+ raise
115
137
 
116
138
  async def agent_node(self, state: ReActGraphState):
117
139
  try:
@@ -183,6 +205,7 @@ class ReActAgentGraph(DualNodeAgent):
183
205
  # this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
184
206
  # the final answer goes in the "messages" state channel
185
207
  state.messages += [AIMessage(content=final_answer)]
208
+ state.final_answer = final_answer
186
209
  else:
187
210
  # the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
188
211
  agent_output.log = output_message.content
@@ -215,16 +238,15 @@ class ReActAgentGraph(DualNodeAgent):
215
238
  working_state.append(output_message)
216
239
  working_state.append(HumanMessage(content=str(ex.observation)))
217
240
  except Exception as ex:
218
- logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
219
- raise ex
241
+ logger.error("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
242
+ raise
220
243
 
221
244
  async def conditional_edge(self, state: ReActGraphState):
222
245
  try:
223
246
  logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
224
- if len(state.messages) > 1:
225
- # the ReAct Agent has finished executing, the last agent output was AgentFinish
226
- last_message_content = str(state.messages[-1].content)
227
- logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
247
+ if state.final_answer:
248
+ # the ReAct Agent has finished executing
249
+ logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.final_answer)
228
250
  return AgentDecision.END
229
251
  # else the agent wants to call a tool
230
252
  agent_output = state.agent_scratchpad[-1]
@@ -234,7 +256,7 @@ class ReActAgentGraph(DualNodeAgent):
234
256
  agent_output.tool_input)
235
257
  return AgentDecision.TOOL
236
258
  except Exception as ex:
237
- logger.exception("Failed to determine whether agent is calling a tool: %s", ex, exc_info=True)
259
+ logger.exception("Failed to determine whether agent is calling a tool: %s", ex)
238
260
  logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
239
261
  return AgentDecision.END
240
262
 
@@ -267,35 +289,45 @@ class ReActAgentGraph(DualNodeAgent):
267
289
  agent_thoughts.tool_input)
268
290
 
269
291
  # Run the tool. Try to use structured input, if possible.
292
+ tool_input_str = agent_thoughts.tool_input.strip()
293
+
270
294
  try:
271
- tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
272
- tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
295
+ tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
273
296
  logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
274
297
 
275
- tool_response = await self._call_tool(requested_tool,
276
- tool_input_dict,
277
- RunnableConfig(callbacks=self.callbacks),
278
- max_retries=self.tool_call_max_retries)
279
-
280
- if self.detailed_logs:
281
- self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
282
-
283
- except JSONDecodeError as ex:
284
- logger.debug(
285
- "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
286
- "\nParsing error: %s",
287
- AGENT_LOG_PREFIX,
288
- ex,
289
- exc_info=True)
290
- tool_input_str = str(agent_thoughts.tool_input)
291
-
292
- tool_response = await self._call_tool(requested_tool,
293
- tool_input_str,
294
- RunnableConfig(callbacks=self.callbacks),
295
- max_retries=self.tool_call_max_retries)
298
+ except JSONDecodeError as original_ex:
299
+ if self.normalize_tool_input_quotes:
300
+ # If initial JSON parsing fails, try with quote normalization as a fallback
301
+ normalized_str = tool_input_str.replace("'", '"')
302
+ try:
303
+ tool_input = json.loads(normalized_str)
304
+ logger.debug("%s Successfully parsed structured tool input after quote normalization",
305
+ AGENT_LOG_PREFIX)
306
+ except JSONDecodeError:
307
+ # the quote normalization failed, use raw string input
308
+ logger.debug(
309
+ "%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
310
+ "\nParsing error: %s",
311
+ AGENT_LOG_PREFIX,
312
+ original_ex)
313
+ tool_input = tool_input_str
314
+ else:
315
+ # use raw string input
316
+ logger.debug(
317
+ "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
318
+ "\nParsing error: %s",
319
+ AGENT_LOG_PREFIX,
320
+ original_ex)
321
+ tool_input = tool_input_str
322
+
323
+ # Call tool once with the determined input (either parsed dict or raw string)
324
+ tool_response = await self._call_tool(requested_tool,
325
+ tool_input,
326
+ RunnableConfig(callbacks=self.callbacks),
327
+ max_retries=self.tool_call_max_retries)
296
328
 
297
329
  if self.detailed_logs:
298
- self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
330
+ self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
299
331
 
300
332
  if not self.pass_tool_call_errors_to_agent:
301
333
  if tool_response.status == "error":
@@ -311,8 +343,8 @@ class ReActAgentGraph(DualNodeAgent):
311
343
  logger.debug("%s ReAct Graph built and compiled successfully", AGENT_LOG_PREFIX)
312
344
  return self.graph
313
345
  except Exception as ex:
314
- logger.exception("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
315
- raise ex
346
+ logger.error("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex)
347
+ raise
316
348
 
317
349
  @staticmethod
318
350
  def validate_system_prompt(system_prompt: str) -> bool:
@@ -328,8 +360,8 @@ class ReActAgentGraph(DualNodeAgent):
328
360
  errors.append(error_message)
329
361
  if errors:
330
362
  error_text = "\n".join(errors)
331
- logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
332
- raise ValueError(error_text)
363
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
364
+ return False
333
365
  return True
334
366
 
335
367
 
@@ -355,7 +387,7 @@ def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptT
355
387
 
356
388
  valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
357
389
  if not valid_prompt:
358
- logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
390
+ logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
359
391
  raise ValueError("Invalid system_prompt")
360
392
  prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
361
393
  MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])