nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/register.py CHANGED
@@ -22,26 +22,29 @@ from nat.builder.builder import Builder
22
22
  from nat.builder.framework_enum import LLMFrameworkEnum
23
23
  from nat.builder.function_info import FunctionInfo
24
24
  from nat.cli.register_workflow import register_function
25
+ from nat.data_models.agent import AgentBaseConfig
25
26
  from nat.data_models.api_server import ChatRequest
27
+ from nat.data_models.api_server import ChatRequestOrMessage
26
28
  from nat.data_models.api_server import ChatResponse
29
+ from nat.data_models.api_server import Usage
30
+ from nat.data_models.component_ref import FunctionGroupRef
27
31
  from nat.data_models.component_ref import FunctionRef
28
- from nat.data_models.component_ref import LLMRef
29
- from nat.data_models.function import FunctionBaseConfig
32
+ from nat.data_models.optimizable import OptimizableField
33
+ from nat.data_models.optimizable import OptimizableMixin
34
+ from nat.data_models.optimizable import SearchSpace
30
35
  from nat.utils.type_converter import GlobalTypeConverter
31
36
 
32
37
  logger = logging.getLogger(__name__)
33
38
 
34
39
 
35
- class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
40
+ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
36
41
  """
37
42
  Defines a NAT function that uses a ReAct Agent performs reasoning inbetween tool calls, and utilizes the
38
43
  tool names and descriptions to select the optimal tool.
39
44
  """
40
-
41
- tool_names: list[FunctionRef] = Field(default_factory=list,
42
- description="The list of tools to provide to the react agent.")
43
- llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
44
- verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
45
+ description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
46
+ tool_names: list[FunctionRef | FunctionGroupRef] = Field(
47
+ default_factory=list, description="The list of tools to provide to the react agent.")
45
48
  retry_agent_response_parsing_errors: bool = Field(
46
49
  default=True,
47
50
  validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
@@ -60,23 +63,29 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
60
63
  description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
61
64
  include_tool_input_schema_in_tool_description: bool = Field(
62
65
  default=True, description="Specify inclusion of tool input schemas in the prompt.")
63
- description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
66
+ normalize_tool_input_quotes: bool = Field(
67
+ default=True,
68
+ description="Whether to replace single quotes with double quotes in the tool input. "
69
+ "This is useful for tools that expect structured json input.")
64
70
  system_prompt: str | None = Field(
65
71
  default=None,
66
72
  description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
67
73
  max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
68
- use_openai_api: bool = Field(default=False,
69
- description=("Use OpenAI API for the input/output types to the function. "
70
- "If False, strings will be used."))
71
- additional_instructions: str | None = Field(
72
- default=None, description="Additional instructions to provide to the agent in addition to the base prompt.")
74
+ additional_instructions: str | None = OptimizableField(
75
+ default=None,
76
+ description="Additional instructions to provide to the agent in addition to the base prompt.",
77
+ space=SearchSpace(
78
+ is_prompt=True,
79
+ prompt="No additional instructions.",
80
+ prompt_purpose="Additional instructions to provide to the agent in addition to the base prompt.",
81
+ ))
73
82
 
74
83
 
75
84
  @register_function(config_type=ReActAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
76
85
  async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
77
86
  from langchain.schema import BaseMessage
78
87
  from langchain_core.messages import trim_messages
79
- from langgraph.graph.graph import CompiledGraph
88
+ from langgraph.graph.state import CompiledStateGraph
80
89
 
81
90
  from nat.agent.base import AGENT_LOG_PREFIX
82
91
  from nat.agent.react_agent.agent import ReActAgentGraph
@@ -89,26 +98,41 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
89
98
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
90
99
  # the agent can run any installed tool, simply install the tool and add it to the config file
91
100
  # the sample tool provided can easily be copied or changed
92
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
101
+ tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
93
102
  if not tools:
94
103
  raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
95
104
  # configure callbacks, for sending intermediate steps
96
105
  # construct the ReAct Agent Graph from the configured llm, prompt, and tools
97
- graph: CompiledGraph = await ReActAgentGraph(
106
+ graph: CompiledStateGraph = await ReActAgentGraph(
98
107
  llm=llm,
99
108
  prompt=prompt,
100
109
  tools=tools,
101
110
  use_tool_schema=config.include_tool_input_schema_in_tool_description,
102
111
  detailed_logs=config.verbose,
112
+ log_response_max_chars=config.log_response_max_chars,
103
113
  retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
104
114
  parse_agent_response_max_retries=config.parse_agent_response_max_retries,
105
115
  tool_call_max_retries=config.tool_call_max_retries,
106
- pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent).build_graph()
116
+ pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
117
+ normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
118
+
119
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
120
+ """
121
+ Main workflow entry function for the ReAct Agent.
122
+
123
+ This function invokes the ReAct Agent Graph and returns the response.
107
124
 
108
- async def _response_fn(input_message: ChatRequest) -> ChatResponse:
125
+ Args:
126
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
127
+
128
+ Returns:
129
+ ChatResponse | str: The response from the agent or error message
130
+ """
109
131
  try:
132
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
133
+
110
134
  # initialize the starting state with the user query
111
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
135
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
112
136
  max_tokens=config.max_history,
113
137
  strategy="last",
114
138
  token_counter=len,
@@ -125,25 +149,20 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
125
149
 
126
150
  # get and return the output from the state
127
151
  state = ReActGraphState(**state)
128
- output_message = state.messages[-1] # pylint: disable=E1136
129
- return ChatResponse.from_string(str(output_message.content))
130
-
152
+ output_message = state.messages[-1]
153
+ content = str(output_message.content)
154
+
155
+ # Create usage statistics for the response
156
+ prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
157
+ completion_tokens = len(content.split()) if content else 0
158
+ total_tokens = prompt_tokens + completion_tokens
159
+ usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
160
+ response = ChatResponse.from_string(content, usage=usage)
161
+ if chat_request_or_message.is_string:
162
+ return GlobalTypeConverter.get().convert(response, to_type=str)
163
+ return response
131
164
  except Exception as ex:
132
- logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
133
- # here, we can implement custom error messages
134
- if config.verbose:
135
- return ChatResponse.from_string(str(ex))
136
- return ChatResponse.from_string("I seem to be having a problem.")
137
-
138
- if (config.use_openai_api):
139
- yield FunctionInfo.from_fn(_response_fn, description=config.description)
140
- else:
141
-
142
- async def _str_api_fn(input_message: str) -> str:
143
- oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
144
-
145
- oai_output = await _response_fn(oai_input)
146
-
147
- return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
165
+ logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
166
+ raise
148
167
 
149
- yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
168
+ yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/reasoning_agent/reasoning_agent.py CHANGED
@@ -23,25 +23,22 @@ from nat.builder.builder import Builder
23
23
  from nat.builder.framework_enum import LLMFrameworkEnum
24
24
  from nat.builder.function_info import FunctionInfo
25
25
  from nat.cli.register_workflow import register_function
26
+ from nat.data_models.agent import AgentBaseConfig
26
27
  from nat.data_models.api_server import ChatRequest
27
28
  from nat.data_models.component_ref import FunctionRef
28
- from nat.data_models.component_ref import LLMRef
29
- from nat.data_models.function import FunctionBaseConfig
30
29
 
31
30
  logger = logging.getLogger(__name__)
32
31
 
33
32
 
34
- class ReasoningFunctionConfig(FunctionBaseConfig, name="reasoning_agent"):
33
+ class ReasoningFunctionConfig(AgentBaseConfig, name="reasoning_agent"):
35
34
  """
36
35
  Defines a NAT function that performs reasoning on the input data.
37
36
  Output is passed to the next function in the workflow.
38
37
 
39
38
  Designed to be used with an InterceptingFunction.
40
39
  """
41
-
42
- llm_name: LLMRef = Field(description="The name of the LLM to use for reasoning.")
40
+ description: str = Field(default="Reasoning Agent", description="The description of this function's use.")
43
41
  augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")
44
- verbose: bool = Field(default=False, description="Whether to log detailed information.")
45
42
  reasoning_prompt_template: str = Field(
46
43
  default=("You are an expert reasoning model task with creating a detailed execution plan"
47
44
  " for a system that has the following description:\n\n"
@@ -102,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
102
99
  llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
103
100
 
104
101
  # Get the augmented function's description
105
- augmented_function = builder.get_function(config.augmented_fn)
102
+ augmented_function = await builder.get_function(config.augmented_fn)
106
103
 
107
104
  # For now, we rely on runtime checking for type conversion
108
105
 
@@ -113,11 +110,16 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
113
110
  f"function without a description.")
114
111
 
115
112
  # Get the function dependencies of the augmented function
116
- function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
113
+ function_dependencies = builder.get_function_dependencies(config.augmented_fn)
114
+ function_used_tools = set()
115
+ function_used_tools.update(function_dependencies.functions)
116
+ for function_group in function_dependencies.function_groups:
117
+ function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
118
+
117
119
  tool_names_with_desc: list[tuple[str, str]] = []
118
120
 
119
121
  for tool in function_used_tools:
120
- tool_impl = builder.get_function(tool)
122
+ tool_impl = await builder.get_function(tool)
121
123
  tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
122
124
 
123
125
  # Draft the reasoning prompt for the augmented function
@@ -155,12 +157,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
155
157
  prompt = prompt.to_string()
156
158
 
157
159
  # Get the reasoning output from the LLM
158
- reasoning_output = ""
160
+ reasoning_output = []
159
161
 
160
162
  async for chunk in llm.astream(prompt):
161
- reasoning_output += chunk.content
163
+ reasoning_output.append(chunk.content)
162
164
 
163
- reasoning_output = remove_r1_think_tags(reasoning_output)
165
+ reasoning_output = remove_r1_think_tags("".join(reasoning_output))
164
166
 
165
167
  output = await downstream_template.ainvoke(input={
166
168
  "input_text": input_text, "reasoning_output": reasoning_output
@@ -198,12 +200,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
198
200
  prompt = prompt.to_string()
199
201
 
200
202
  # Get the reasoning output from the LLM
201
- reasoning_output = ""
203
+ reasoning_output = []
202
204
 
203
205
  async for chunk in llm.astream(prompt):
204
- reasoning_output += chunk.content
206
+ reasoning_output.append(chunk.content)
205
207
 
206
- reasoning_output = remove_r1_think_tags(reasoning_output)
208
+ reasoning_output = remove_r1_think_tags("".join(reasoning_output))
207
209
 
208
210
  output = await downstream_template.ainvoke(input={
209
211
  "input_text": input_text, "reasoning_output": reasoning_output
nat/agent/register.py CHANGED
@@ -13,10 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  # Import any workflows which need to be automatically registered here
19
+ from .prompt_optimizer import register as prompt_optimizer
20
20
  from .react_agent import register as react_agent
21
21
  from .reasoning_agent import reasoning_agent
22
22
  from .rewoo_agent import register as rewoo_agent