nvidia-nat 1.3.dev0__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. aiq/__init__.py +66 -0
  2. nat/agent/base.py +40 -14
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +96 -57
  7. nat/agent/react_agent/prompt.py +4 -1
  8. nat/agent/react_agent/register.py +41 -21
  9. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  10. nat/agent/register.py +1 -1
  11. nat/agent/rewoo_agent/agent.py +332 -150
  12. nat/agent/rewoo_agent/prompt.py +22 -22
  13. nat/agent/rewoo_agent/register.py +49 -28
  14. nat/agent/tool_calling_agent/agent.py +156 -29
  15. nat/agent/tool_calling_agent/register.py +57 -38
  16. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  17. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  18. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  19. nat/authentication/interfaces.py +5 -2
  20. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  21. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  22. nat/authentication/register.py +0 -1
  23. nat/builder/builder.py +56 -24
  24. nat/builder/component_utils.py +9 -5
  25. nat/builder/context.py +46 -11
  26. nat/builder/eval_builder.py +16 -11
  27. nat/builder/framework_enum.py +1 -0
  28. nat/builder/front_end.py +1 -1
  29. nat/builder/function.py +378 -8
  30. nat/builder/function_base.py +3 -3
  31. nat/builder/function_info.py +6 -8
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +281 -76
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  53. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  54. nat/cli/commands/workflow/workflow_commands.py +9 -13
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/register_workflow.py +38 -4
  57. nat/cli/type_registry.py +79 -10
  58. nat/control_flow/__init__.py +0 -0
  59. nat/control_flow/register.py +20 -0
  60. nat/control_flow/router_agent/__init__.py +0 -0
  61. nat/control_flow/router_agent/agent.py +329 -0
  62. nat/control_flow/router_agent/prompt.py +48 -0
  63. nat/control_flow/router_agent/register.py +91 -0
  64. nat/control_flow/sequential_executor.py +166 -0
  65. nat/data_models/agent.py +34 -0
  66. nat/data_models/api_server.py +10 -10
  67. nat/data_models/authentication.py +23 -9
  68. nat/data_models/common.py +1 -1
  69. nat/data_models/component.py +2 -0
  70. nat/data_models/component_ref.py +11 -0
  71. nat/data_models/config.py +41 -17
  72. nat/data_models/dataset_handler.py +1 -1
  73. nat/data_models/discovery_metadata.py +4 -4
  74. nat/data_models/evaluate.py +4 -1
  75. nat/data_models/function.py +34 -0
  76. nat/data_models/function_dependencies.py +14 -6
  77. nat/data_models/gated_field_mixin.py +242 -0
  78. nat/data_models/intermediate_step.py +3 -3
  79. nat/data_models/optimizable.py +119 -0
  80. nat/data_models/optimizer.py +149 -0
  81. nat/data_models/swe_bench_model.py +1 -1
  82. nat/data_models/temperature_mixin.py +44 -0
  83. nat/data_models/thinking_mixin.py +86 -0
  84. nat/data_models/top_p_mixin.py +44 -0
  85. nat/embedder/azure_openai_embedder.py +46 -0
  86. nat/embedder/nim_embedder.py +1 -1
  87. nat/embedder/openai_embedder.py +2 -3
  88. nat/embedder/register.py +1 -1
  89. nat/eval/config.py +3 -1
  90. nat/eval/dataset_handler/dataset_handler.py +71 -7
  91. nat/eval/evaluate.py +86 -31
  92. nat/eval/evaluator/base_evaluator.py +1 -1
  93. nat/eval/evaluator/evaluator_model.py +13 -0
  94. nat/eval/intermediate_step_adapter.py +1 -1
  95. nat/eval/rag_evaluator/evaluate.py +2 -2
  96. nat/eval/rag_evaluator/register.py +3 -3
  97. nat/eval/register.py +4 -1
  98. nat/eval/remote_workflow.py +3 -3
  99. nat/eval/runtime_evaluator/__init__.py +14 -0
  100. nat/eval/runtime_evaluator/evaluate.py +123 -0
  101. nat/eval/runtime_evaluator/register.py +100 -0
  102. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  103. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  104. nat/eval/trajectory_evaluator/register.py +1 -1
  105. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  106. nat/eval/utils/eval_trace_ctx.py +89 -0
  107. nat/eval/utils/weave_eval.py +18 -9
  108. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  109. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  110. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  112. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  113. nat/experimental/test_time_compute/register.py +0 -1
  114. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  115. nat/front_ends/console/authentication_flow_handler.py +82 -30
  116. nat/front_ends/console/console_front_end_plugin.py +8 -5
  117. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  118. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  119. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  120. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  121. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  123. nat/front_ends/fastapi/job_store.py +518 -99
  124. nat/front_ends/fastapi/main.py +11 -19
  125. nat/front_ends/fastapi/message_handler.py +13 -14
  126. nat/front_ends/fastapi/message_validator.py +17 -19
  127. nat/front_ends/fastapi/response_helpers.py +4 -4
  128. nat/front_ends/fastapi/step_adaptor.py +2 -2
  129. nat/front_ends/fastapi/utils.py +57 -0
  130. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  131. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  132. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  133. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  134. nat/front_ends/mcp/tool_converter.py +44 -14
  135. nat/front_ends/register.py +0 -1
  136. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  137. nat/llm/aws_bedrock_llm.py +24 -12
  138. nat/llm/azure_openai_llm.py +57 -0
  139. nat/llm/litellm_llm.py +69 -0
  140. nat/llm/nim_llm.py +20 -8
  141. nat/llm/openai_llm.py +14 -6
  142. nat/llm/register.py +5 -1
  143. nat/llm/utils/env_config_value.py +2 -3
  144. nat/llm/utils/thinking.py +215 -0
  145. nat/meta/pypi.md +9 -9
  146. nat/object_store/models.py +2 -0
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +1 -1
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  178. nat/profiler/inference_optimization/data_models.py +3 -3
  179. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  180. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  181. nat/profiler/parameter_optimization/__init__.py +0 -0
  182. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  183. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  184. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  185. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  186. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  187. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  188. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  189. nat/profiler/profile_runner.py +14 -9
  190. nat/profiler/utils.py +4 -2
  191. nat/registry_handlers/local/local_handler.py +2 -2
  192. nat/registry_handlers/package_utils.py +1 -2
  193. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  194. nat/registry_handlers/register.py +3 -4
  195. nat/registry_handlers/rest/rest_handler.py +12 -13
  196. nat/retriever/milvus/retriever.py +2 -2
  197. nat/retriever/nemo_retriever/retriever.py +1 -1
  198. nat/retriever/register.py +0 -1
  199. nat/runtime/loader.py +2 -2
  200. nat/runtime/runner.py +3 -2
  201. nat/runtime/session.py +43 -8
  202. nat/settings/global_settings.py +16 -5
  203. nat/tool/chat_completion.py +5 -2
  204. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  205. nat/tool/datetime_tools.py +49 -9
  206. nat/tool/document_search.py +2 -2
  207. nat/tool/github_tools.py +450 -0
  208. nat/tool/nvidia_rag.py +1 -1
  209. nat/tool/register.py +2 -9
  210. nat/tool/retriever.py +3 -2
  211. nat/utils/callable_utils.py +70 -0
  212. nat/utils/data_models/schema_validator.py +3 -3
  213. nat/utils/exception_handlers/automatic_retries.py +104 -51
  214. nat/utils/exception_handlers/schemas.py +1 -1
  215. nat/utils/io/yaml_tools.py +2 -2
  216. nat/utils/log_levels.py +25 -0
  217. nat/utils/reactive/base/observable_base.py +2 -2
  218. nat/utils/reactive/base/observer_base.py +1 -1
  219. nat/utils/reactive/observable.py +2 -2
  220. nat/utils/reactive/observer.py +4 -4
  221. nat/utils/reactive/subscription.py +1 -1
  222. nat/utils/settings/global_settings.py +6 -8
  223. nat/utils/type_converter.py +4 -3
  224. nat/utils/type_utils.py +9 -5
  225. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +49 -21
  226. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +233 -189
  227. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  228. nvidia_nat-1.3.0rc1.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  229. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +1 -0
  230. nat/cli/commands/info/list_mcp.py +0 -304
  231. nat/tool/github_tools/create_github_commit.py +0 -133
  232. nat/tool/github_tools/create_github_issue.py +0 -87
  233. nat/tool/github_tools/create_github_pr.py +0 -106
  234. nat/tool/github_tools/get_github_file.py +0 -106
  235. nat/tool/github_tools/get_github_issue.py +0 -166
  236. nat/tool/github_tools/get_github_pr.py +0 -256
  237. nat/tool/github_tools/update_github_issue.py +0 -100
  238. nat/tool/mcp/exceptions.py +0 -142
  239. nat/tool/mcp/mcp_client.py +0 -255
  240. nat/tool/mcp/mcp_tool.py +0 -96
  241. nat/utils/exception_handlers/mcp.py +0 -211
  242. nvidia_nat-1.3.dev0.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
  243. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  244. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  245. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  246. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
@@ -14,20 +14,23 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- # pylint: disable=R0917
18
17
  import logging
18
+ import re
19
+ import typing
19
20
  from json import JSONDecodeError
20
21
 
21
22
  from langchain_core.agents import AgentAction
22
23
  from langchain_core.agents import AgentFinish
23
24
  from langchain_core.callbacks.base import AsyncCallbackHandler
24
25
  from langchain_core.language_models import BaseChatModel
26
+ from langchain_core.language_models import LanguageModelInput
25
27
  from langchain_core.messages.ai import AIMessage
26
28
  from langchain_core.messages.base import BaseMessage
27
29
  from langchain_core.messages.human import HumanMessage
28
30
  from langchain_core.messages.tool import ToolMessage
29
31
  from langchain_core.prompts import ChatPromptTemplate
30
32
  from langchain_core.prompts import MessagesPlaceholder
33
+ from langchain_core.runnables import Runnable
31
34
  from langchain_core.runnables.config import RunnableConfig
32
35
  from langchain_core.tools import BaseTool
33
36
  from pydantic import BaseModel
@@ -44,7 +47,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
44
47
  from nat.agent.react_agent.output_parser import ReActOutputParserException
45
48
  from nat.agent.react_agent.prompt import SYSTEM_PROMPT
46
49
  from nat.agent.react_agent.prompt import USER_PROMPT
47
- from nat.agent.react_agent.register import ReActAgentWorkflowConfig
50
+
51
+ if typing.TYPE_CHECKING:
52
+ from nat.agent.react_agent.register import ReActAgentWorkflowConfig
48
53
 
49
54
  logger = logging.getLogger(__name__)
50
55
 
@@ -54,6 +59,7 @@ class ReActGraphState(BaseModel):
54
59
  messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReAct Agent
55
60
  agent_scratchpad: list[AgentAction] = Field(default_factory=list) # agent thoughts / intermediate steps
56
61
  tool_responses: list[BaseMessage] = Field(default_factory=list) # the responses from any tool calls
62
+ final_answer: str | None = Field(default=None) # the final answer from the ReAct Agent
57
63
 
58
64
 
59
65
  class ReActAgentGraph(DualNodeAgent):
@@ -68,15 +74,22 @@ class ReActAgentGraph(DualNodeAgent):
68
74
  use_tool_schema: bool = True,
69
75
  callbacks: list[AsyncCallbackHandler] | None = None,
70
76
  detailed_logs: bool = False,
77
+ log_response_max_chars: int = 1000,
71
78
  retry_agent_response_parsing_errors: bool = True,
72
79
  parse_agent_response_max_retries: int = 1,
73
80
  tool_call_max_retries: int = 1,
74
- pass_tool_call_errors_to_agent: bool = True):
75
- super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
81
+ pass_tool_call_errors_to_agent: bool = True,
82
+ normalize_tool_input_quotes: bool = True):
83
+ super().__init__(llm=llm,
84
+ tools=tools,
85
+ callbacks=callbacks,
86
+ detailed_logs=detailed_logs,
87
+ log_response_max_chars=log_response_max_chars)
76
88
  self.parse_agent_response_max_retries = (parse_agent_response_max_retries
77
89
  if retry_agent_response_parsing_errors else 1)
78
90
  self.tool_call_max_retries = tool_call_max_retries
79
91
  self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
92
+ self.normalize_tool_input_quotes = normalize_tool_input_quotes
80
93
  logger.debug(
81
94
  "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
82
95
  AGENT_LOG_PREFIX)
@@ -94,21 +107,33 @@ class ReActAgentGraph(DualNodeAgent):
94
107
  f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
95
108
  prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
96
109
  # construct the ReAct Agent
97
- bound_llm = llm.bind(stop=["Observation:"]) # type: ignore
98
- self.agent = prompt | bound_llm
110
+ self.agent = prompt | self._maybe_bind_llm_and_yield()
99
111
  self.tools_dict = {tool.name: tool for tool in tools}
100
112
  logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
101
113
 
114
+ def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
115
+ """
116
+ Bind additional parameters to the LLM if needed
117
+ - if the LLM is a smart model, no need to bind any additional parameters
118
+ - if the LLM is a non-smart model, bind a stop sequence to the LLM
119
+
120
+ Returns:
121
+ Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
122
+ """
123
+ # models that don't need (or don't support)a stop sequence
124
+ smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
125
+ if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
126
+ # no need to bind any additional parameters to the LLM
127
+ return self.llm
128
+ # add a stop sequence to the LLM
129
+ return self.llm.bind(stop=["Observation:"])
130
+
102
131
  def _get_tool(self, tool_name: str):
103
132
  try:
104
133
  return self.tools_dict.get(tool_name)
105
134
  except Exception as ex:
106
- logger.exception("%s Unable to find tool with the name %s\n%s",
107
- AGENT_LOG_PREFIX,
108
- tool_name,
109
- ex,
110
- exc_info=True)
111
- raise ex
135
+ logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
136
+ raise
112
137
 
113
138
  async def agent_node(self, state: ReActGraphState):
114
139
  try:
@@ -124,17 +149,19 @@ class ReActAgentGraph(DualNodeAgent):
124
149
  if len(state.messages) == 0:
125
150
  raise RuntimeError('No input received in state: "messages"')
126
151
  # to check is any human input passed or not, if no input passed Agent will return the state
127
- content = str(state.messages[0].content)
152
+ content = str(state.messages[-1].content)
128
153
  if content.strip() == "":
129
154
  logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
130
155
  state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
131
156
  return state
132
157
  question = content
133
158
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
134
-
159
+ chat_history = self._get_chat_history(state.messages)
135
160
  output_message = await self._stream_llm(
136
161
  self.agent,
137
- {"question": question},
162
+ {
163
+ "question": question, "chat_history": chat_history
164
+ },
138
165
  RunnableConfig(callbacks=self.callbacks) # type: ignore
139
166
  )
140
167
 
@@ -152,13 +179,15 @@ class ReActAgentGraph(DualNodeAgent):
152
179
  tool_response = HumanMessage(content=tool_response_content)
153
180
  agent_scratchpad.append(tool_response)
154
181
  agent_scratchpad += working_state
155
- question = str(state.messages[0].content)
182
+ chat_history = self._get_chat_history(state.messages)
183
+ question = str(state.messages[-1].content)
156
184
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
157
185
 
158
- output_message = await self._stream_llm(self.agent, {
159
- "question": question, "agent_scratchpad": agent_scratchpad
160
- },
161
- RunnableConfig(callbacks=self.callbacks))
186
+ output_message = await self._stream_llm(
187
+ self.agent, {
188
+ "question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
189
+ },
190
+ RunnableConfig(callbacks=self.callbacks))
162
191
 
163
192
  if self.detailed_logs:
164
193
  logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
@@ -176,6 +205,7 @@ class ReActAgentGraph(DualNodeAgent):
176
205
  # this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
177
206
  # the final answer goes in the "messages" state channel
178
207
  state.messages += [AIMessage(content=final_answer)]
208
+ state.final_answer = final_answer
179
209
  else:
180
210
  # the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
181
211
  agent_output.log = output_message.content
@@ -208,16 +238,15 @@ class ReActAgentGraph(DualNodeAgent):
208
238
  working_state.append(output_message)
209
239
  working_state.append(HumanMessage(content=str(ex.observation)))
210
240
  except Exception as ex:
211
- logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
212
- raise ex
241
+ logger.error("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
242
+ raise
213
243
 
214
244
  async def conditional_edge(self, state: ReActGraphState):
215
245
  try:
216
246
  logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
217
- if len(state.messages) > 1:
218
- # the ReAct Agent has finished executing, the last agent output was AgentFinish
219
- last_message_content = str(state.messages[-1].content)
220
- logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
247
+ if state.final_answer:
248
+ # the ReAct Agent has finished executing
249
+ logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.final_answer)
221
250
  return AgentDecision.END
222
251
  # else the agent wants to call a tool
223
252
  agent_output = state.agent_scratchpad[-1]
@@ -227,7 +256,7 @@ class ReActAgentGraph(DualNodeAgent):
227
256
  agent_output.tool_input)
228
257
  return AgentDecision.TOOL
229
258
  except Exception as ex:
230
- logger.exception("Failed to determine whether agent is calling a tool: %s", ex, exc_info=True)
259
+ logger.exception("Failed to determine whether agent is calling a tool: %s", ex)
231
260
  logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
232
261
  return AgentDecision.END
233
262
 
@@ -260,35 +289,45 @@ class ReActAgentGraph(DualNodeAgent):
260
289
  agent_thoughts.tool_input)
261
290
 
262
291
  # Run the tool. Try to use structured input, if possible.
292
+ tool_input_str = agent_thoughts.tool_input.strip()
293
+
263
294
  try:
264
- tool_input_str = str(agent_thoughts.tool_input).strip().replace("'", '"')
265
- tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
295
+ tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
266
296
  logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
267
297
 
268
- tool_response = await self._call_tool(requested_tool,
269
- tool_input_dict,
270
- RunnableConfig(callbacks=self.callbacks),
271
- max_retries=self.tool_call_max_retries)
272
-
273
- if self.detailed_logs:
274
- self._log_tool_response(requested_tool.name, tool_input_dict, str(tool_response.content))
275
-
276
- except JSONDecodeError as ex:
277
- logger.debug(
278
- "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
279
- "\nParsing error: %s",
280
- AGENT_LOG_PREFIX,
281
- ex,
282
- exc_info=True)
283
- tool_input_str = str(agent_thoughts.tool_input)
284
-
285
- tool_response = await self._call_tool(requested_tool,
286
- tool_input_str,
287
- RunnableConfig(callbacks=self.callbacks),
288
- max_retries=self.tool_call_max_retries)
298
+ except JSONDecodeError as original_ex:
299
+ if self.normalize_tool_input_quotes:
300
+ # If initial JSON parsing fails, try with quote normalization as a fallback
301
+ normalized_str = tool_input_str.replace("'", '"')
302
+ try:
303
+ tool_input = json.loads(normalized_str)
304
+ logger.debug("%s Successfully parsed structured tool input after quote normalization",
305
+ AGENT_LOG_PREFIX)
306
+ except JSONDecodeError:
307
+ # the quote normalization failed, use raw string input
308
+ logger.debug(
309
+ "%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
310
+ "\nParsing error: %s",
311
+ AGENT_LOG_PREFIX,
312
+ original_ex)
313
+ tool_input = tool_input_str
314
+ else:
315
+ # use raw string input
316
+ logger.debug(
317
+ "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
318
+ "\nParsing error: %s",
319
+ AGENT_LOG_PREFIX,
320
+ original_ex)
321
+ tool_input = tool_input_str
322
+
323
+ # Call tool once with the determined input (either parsed dict or raw string)
324
+ tool_response = await self._call_tool(requested_tool,
325
+ tool_input,
326
+ RunnableConfig(callbacks=self.callbacks),
327
+ max_retries=self.tool_call_max_retries)
289
328
 
290
329
  if self.detailed_logs:
291
- self._log_tool_response(requested_tool.name, tool_input_str, str(tool_response.content))
330
+ self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
292
331
 
293
332
  if not self.pass_tool_call_errors_to_agent:
294
333
  if tool_response.status == "error":
@@ -304,8 +343,8 @@ class ReActAgentGraph(DualNodeAgent):
304
343
  logger.debug("%s ReAct Graph built and compiled successfully", AGENT_LOG_PREFIX)
305
344
  return self.graph
306
345
  except Exception as ex:
307
- logger.exception("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
308
- raise ex
346
+ logger.error("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex)
347
+ raise
309
348
 
310
349
  @staticmethod
311
350
  def validate_system_prompt(system_prompt: str) -> bool:
@@ -321,12 +360,12 @@ class ReActAgentGraph(DualNodeAgent):
321
360
  errors.append(error_message)
322
361
  if errors:
323
362
  error_text = "\n".join(errors)
324
- logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
325
- raise ValueError(error_text)
363
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
364
+ return False
326
365
  return True
327
366
 
328
367
 
329
- def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
368
+ def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
330
369
  """
331
370
  Create a ReAct Agent prompt from the config.
332
371
 
@@ -348,7 +387,7 @@ def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTem
348
387
 
349
388
  valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
350
389
  if not valid_prompt:
351
- logger.exception("%s Invalid system_prompt", AGENT_LOG_PREFIX)
390
+ logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
352
391
  raise ValueError("Invalid system_prompt")
353
392
  prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
354
393
  MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
26
26
  Question: the input question you must answer
27
27
  Thought: you should always think about what to do
28
28
  Action: the action to take, should be one of [{tool_names}]
29
- Action Input: the input to the action (if there is no required input, include "Action Input: None")
29
+ Action Input: the input to the action (if there is no required input, include "Action Input: None")
30
30
  Observation: wait for the human to respond with the result from the tool, do not assume the response
31
31
 
32
32
  ... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
37
37
  """
38
38
 
39
39
  USER_PROMPT = """
40
+ Previous conversation history:
41
+ {chat_history}
42
+
40
43
  Question: {question}
41
44
  """
@@ -22,26 +22,27 @@ from nat.builder.builder import Builder
22
22
  from nat.builder.framework_enum import LLMFrameworkEnum
23
23
  from nat.builder.function_info import FunctionInfo
24
24
  from nat.cli.register_workflow import register_function
25
+ from nat.data_models.agent import AgentBaseConfig
25
26
  from nat.data_models.api_server import ChatRequest
26
27
  from nat.data_models.api_server import ChatResponse
28
+ from nat.data_models.component_ref import FunctionGroupRef
27
29
  from nat.data_models.component_ref import FunctionRef
28
- from nat.data_models.component_ref import LLMRef
29
- from nat.data_models.function import FunctionBaseConfig
30
+ from nat.data_models.optimizable import OptimizableField
31
+ from nat.data_models.optimizable import OptimizableMixin
32
+ from nat.data_models.optimizable import SearchSpace
30
33
  from nat.utils.type_converter import GlobalTypeConverter
31
34
 
32
35
  logger = logging.getLogger(__name__)
33
36
 
34
37
 
35
- class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
38
+ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
36
39
  """
37
40
  Defines a NAT function that uses a ReAct Agent performs reasoning inbetween tool calls, and utilizes the
38
41
  tool names and descriptions to select the optimal tool.
39
42
  """
40
-
41
- tool_names: list[FunctionRef] = Field(default_factory=list,
42
- description="The list of tools to provide to the react agent.")
43
- llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
44
- verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
43
+ description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
44
+ tool_names: list[FunctionRef | FunctionGroupRef] = Field(
45
+ default_factory=list, description="The list of tools to provide to the react agent.")
45
46
  retry_agent_response_parsing_errors: bool = Field(
46
47
  default=True,
47
48
  validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
@@ -60,7 +61,10 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
60
61
  description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
61
62
  include_tool_input_schema_in_tool_description: bool = Field(
62
63
  default=True, description="Specify inclusion of tool input schemas in the prompt.")
63
- description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
64
+ normalize_tool_input_quotes: bool = Field(
65
+ default=True,
66
+ description="Whether to replace single quotes with double quotes in the tool input. "
67
+ "This is useful for tools that expect structured json input.")
64
68
  system_prompt: str | None = Field(
65
69
  default=None,
66
70
  description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
@@ -68,15 +72,21 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
68
72
  use_openai_api: bool = Field(default=False,
69
73
  description=("Use OpenAI API for the input/output types to the function. "
70
74
  "If False, strings will be used."))
71
- additional_instructions: str | None = Field(
72
- default=None, description="Additional instructions to provide to the agent in addition to the base prompt.")
75
+ additional_instructions: str | None = OptimizableField(
76
+ default=None,
77
+ description="Additional instructions to provide to the agent in addition to the base prompt.",
78
+ space=SearchSpace(
79
+ is_prompt=True,
80
+ prompt="No additional instructions.",
81
+ prompt_purpose="Additional instructions to provide to the agent in addition to the base prompt.",
82
+ ))
73
83
 
74
84
 
75
85
  @register_function(config_type=ReActAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
76
86
  async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
77
87
  from langchain.schema import BaseMessage
78
88
  from langchain_core.messages import trim_messages
79
- from langgraph.graph.graph import CompiledGraph
89
+ from langgraph.graph.state import CompiledStateGraph
80
90
 
81
91
  from nat.agent.base import AGENT_LOG_PREFIX
82
92
  from nat.agent.react_agent.agent import ReActAgentGraph
@@ -89,23 +99,36 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
89
99
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
90
100
  # the agent can run any installed tool, simply install the tool and add it to the config file
91
101
  # the sample tool provided can easily be copied or changed
92
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
102
+ tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
93
103
  if not tools:
94
104
  raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
95
105
  # configure callbacks, for sending intermediate steps
96
106
  # construct the ReAct Agent Graph from the configured llm, prompt, and tools
97
- graph: CompiledGraph = await ReActAgentGraph(
107
+ graph: CompiledStateGraph = await ReActAgentGraph(
98
108
  llm=llm,
99
109
  prompt=prompt,
100
110
  tools=tools,
101
111
  use_tool_schema=config.include_tool_input_schema_in_tool_description,
102
112
  detailed_logs=config.verbose,
113
+ log_response_max_chars=config.log_response_max_chars,
103
114
  retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
104
115
  parse_agent_response_max_retries=config.parse_agent_response_max_retries,
105
116
  tool_call_max_retries=config.tool_call_max_retries,
106
- pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()
117
+ pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
118
+ normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
107
119
 
108
120
  async def _response_fn(input_message: ChatRequest) -> ChatResponse:
121
+ """
122
+ Main workflow entry function for the ReAct Agent.
123
+
124
+ This function invokes the ReAct Agent Graph and returns the response.
125
+
126
+ Args:
127
+ input_message (ChatRequest): The input message to process
128
+
129
+ Returns:
130
+ ChatResponse: The response from the agent or error message
131
+ """
109
132
  try:
110
133
  # initialize the starting state with the user query
111
134
  messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
@@ -125,15 +148,12 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
125
148
 
126
149
  # get and return the output from the state
127
150
  state = ReActGraphState(**state)
128
- output_message = state.messages[-1] # pylint: disable=E1136
151
+ output_message = state.messages[-1]
129
152
  return ChatResponse.from_string(str(output_message.content))
130
153
 
131
154
  except Exception as ex:
132
- logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
133
- # here, we can implement custom error messages
134
- if config.verbose:
135
- return ChatResponse.from_string(str(ex))
136
- return ChatResponse.from_string("I seem to be having a problem.")
155
+ logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
156
+ raise RuntimeError
137
157
 
138
158
  if (config.use_openai_api):
139
159
  yield FunctionInfo.from_fn(_response_fn, description=config.description)
@@ -23,25 +23,22 @@ from nat.builder.builder import Builder
23
23
  from nat.builder.framework_enum import LLMFrameworkEnum
24
24
  from nat.builder.function_info import FunctionInfo
25
25
  from nat.cli.register_workflow import register_function
26
+ from nat.data_models.agent import AgentBaseConfig
26
27
  from nat.data_models.api_server import ChatRequest
27
28
  from nat.data_models.component_ref import FunctionRef
28
- from nat.data_models.component_ref import LLMRef
29
- from nat.data_models.function import FunctionBaseConfig
30
29
 
31
30
  logger = logging.getLogger(__name__)
32
31
 
33
32
 
34
- class ReasoningFunctionConfig(FunctionBaseConfig, name="reasoning_agent"):
33
+ class ReasoningFunctionConfig(AgentBaseConfig, name="reasoning_agent"):
35
34
  """
36
35
  Defines a NAT function that performs reasoning on the input data.
37
36
  Output is passed to the next function in the workflow.
38
37
 
39
38
  Designed to be used with an InterceptingFunction.
40
39
  """
41
-
42
- llm_name: LLMRef = Field(description="The name of the LLM to use for reasoning.")
40
+ description: str = Field(default="Reasoning Agent", description="The description of this function's use.")
43
41
  augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")
44
- verbose: bool = Field(default=False, description="Whether to log detailed information.")
45
42
  reasoning_prompt_template: str = Field(
46
43
  default=("You are an expert reasoning model task with creating a detailed execution plan"
47
44
  " for a system that has the following description:\n\n"
@@ -102,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
102
99
  llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
103
100
 
104
101
  # Get the augmented function's description
105
- augmented_function = builder.get_function(config.augmented_fn)
102
+ augmented_function = await builder.get_function(config.augmented_fn)
106
103
 
107
104
  # For now, we rely on runtime checking for type conversion
108
105
 
@@ -113,11 +110,16 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
113
110
  f"function without a description.")
114
111
 
115
112
  # Get the function dependencies of the augmented function
116
- function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
113
+ function_dependencies = builder.get_function_dependencies(config.augmented_fn)
114
+ function_used_tools = set()
115
+ function_used_tools.update(function_dependencies.functions)
116
+ for function_group in function_dependencies.function_groups:
117
+ function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
118
+
117
119
  tool_names_with_desc: list[tuple[str, str]] = []
118
120
 
119
121
  for tool in function_used_tools:
120
- tool_impl = builder.get_function(tool)
122
+ tool_impl = await builder.get_function(tool)
121
123
  tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
122
124
 
123
125
  # Draft the reasoning prompt for the augmented function
nat/agent/register.py CHANGED
@@ -13,10 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  # Import any workflows which need to be automatically registered here
19
+ from .prompt_optimizer import register as prompt_optimizer
20
20
  from .react_agent import register as react_agent
21
21
  from .reasoning_agent import reasoning_agent
22
22
  from .rewoo_agent import register as rewoo_agent