nvidia-nat 1.3.dev0__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/base.py +40 -14
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +96 -57
  7. nat/agent/react_agent/prompt.py +4 -1
  8. nat/agent/react_agent/register.py +41 -21
  9. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  10. nat/agent/register.py +1 -1
  11. nat/agent/rewoo_agent/agent.py +332 -150
  12. nat/agent/rewoo_agent/prompt.py +22 -22
  13. nat/agent/rewoo_agent/register.py +49 -28
  14. nat/agent/tool_calling_agent/agent.py +156 -29
  15. nat/agent/tool_calling_agent/register.py +57 -38
  16. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  17. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  18. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  19. nat/authentication/interfaces.py +5 -2
  20. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  21. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  22. nat/authentication/register.py +0 -1
  23. nat/builder/builder.py +56 -24
  24. nat/builder/component_utils.py +9 -5
  25. nat/builder/context.py +46 -11
  26. nat/builder/eval_builder.py +16 -11
  27. nat/builder/framework_enum.py +1 -0
  28. nat/builder/front_end.py +1 -1
  29. nat/builder/function.py +378 -8
  30. nat/builder/function_base.py +3 -3
  31. nat/builder/function_info.py +6 -8
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +281 -76
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  53. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  54. nat/cli/commands/workflow/workflow_commands.py +9 -13
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/register_workflow.py +38 -4
  57. nat/cli/type_registry.py +79 -10
  58. nat/control_flow/__init__.py +0 -0
  59. nat/control_flow/register.py +20 -0
  60. nat/control_flow/router_agent/__init__.py +0 -0
  61. nat/control_flow/router_agent/agent.py +329 -0
  62. nat/control_flow/router_agent/prompt.py +48 -0
  63. nat/control_flow/router_agent/register.py +91 -0
  64. nat/control_flow/sequential_executor.py +166 -0
  65. nat/data_models/agent.py +34 -0
  66. nat/data_models/api_server.py +10 -10
  67. nat/data_models/authentication.py +23 -9
  68. nat/data_models/common.py +1 -1
  69. nat/data_models/component.py +2 -0
  70. nat/data_models/component_ref.py +11 -0
  71. nat/data_models/config.py +41 -17
  72. nat/data_models/dataset_handler.py +1 -1
  73. nat/data_models/discovery_metadata.py +4 -4
  74. nat/data_models/evaluate.py +4 -1
  75. nat/data_models/function.py +34 -0
  76. nat/data_models/function_dependencies.py +14 -6
  77. nat/data_models/gated_field_mixin.py +242 -0
  78. nat/data_models/intermediate_step.py +3 -3
  79. nat/data_models/optimizable.py +119 -0
  80. nat/data_models/optimizer.py +149 -0
  81. nat/data_models/swe_bench_model.py +1 -1
  82. nat/data_models/temperature_mixin.py +44 -0
  83. nat/data_models/thinking_mixin.py +86 -0
  84. nat/data_models/top_p_mixin.py +44 -0
  85. nat/embedder/azure_openai_embedder.py +46 -0
  86. nat/embedder/nim_embedder.py +1 -1
  87. nat/embedder/openai_embedder.py +2 -3
  88. nat/embedder/register.py +1 -1
  89. nat/eval/config.py +3 -1
  90. nat/eval/dataset_handler/dataset_handler.py +71 -7
  91. nat/eval/evaluate.py +86 -31
  92. nat/eval/evaluator/base_evaluator.py +1 -1
  93. nat/eval/evaluator/evaluator_model.py +13 -0
  94. nat/eval/intermediate_step_adapter.py +1 -1
  95. nat/eval/rag_evaluator/evaluate.py +2 -2
  96. nat/eval/rag_evaluator/register.py +3 -3
  97. nat/eval/register.py +4 -1
  98. nat/eval/remote_workflow.py +3 -3
  99. nat/eval/runtime_evaluator/__init__.py +14 -0
  100. nat/eval/runtime_evaluator/evaluate.py +123 -0
  101. nat/eval/runtime_evaluator/register.py +100 -0
  102. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  103. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  104. nat/eval/trajectory_evaluator/register.py +1 -1
  105. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  106. nat/eval/utils/eval_trace_ctx.py +89 -0
  107. nat/eval/utils/weave_eval.py +18 -9
  108. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  109. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  110. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  112. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  113. nat/experimental/test_time_compute/register.py +0 -1
  114. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  115. nat/front_ends/console/authentication_flow_handler.py +82 -30
  116. nat/front_ends/console/console_front_end_plugin.py +8 -5
  117. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  118. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  119. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  120. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  121. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  123. nat/front_ends/fastapi/job_store.py +518 -99
  124. nat/front_ends/fastapi/main.py +11 -19
  125. nat/front_ends/fastapi/message_handler.py +13 -14
  126. nat/front_ends/fastapi/message_validator.py +17 -19
  127. nat/front_ends/fastapi/response_helpers.py +4 -4
  128. nat/front_ends/fastapi/step_adaptor.py +2 -2
  129. nat/front_ends/fastapi/utils.py +57 -0
  130. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  131. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  132. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  133. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  134. nat/front_ends/mcp/tool_converter.py +44 -14
  135. nat/front_ends/register.py +0 -1
  136. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  137. nat/llm/aws_bedrock_llm.py +24 -12
  138. nat/llm/azure_openai_llm.py +57 -0
  139. nat/llm/litellm_llm.py +69 -0
  140. nat/llm/nim_llm.py +20 -8
  141. nat/llm/openai_llm.py +14 -6
  142. nat/llm/register.py +5 -1
  143. nat/llm/utils/env_config_value.py +2 -3
  144. nat/llm/utils/thinking.py +215 -0
  145. nat/meta/pypi.md +9 -9
  146. nat/object_store/models.py +2 -0
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +1 -1
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  178. nat/profiler/inference_optimization/data_models.py +3 -3
  179. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  180. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  181. nat/profiler/parameter_optimization/__init__.py +0 -0
  182. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  183. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  184. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  185. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  186. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  187. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  188. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  189. nat/profiler/profile_runner.py +14 -9
  190. nat/profiler/utils.py +4 -2
  191. nat/registry_handlers/local/local_handler.py +2 -2
  192. nat/registry_handlers/package_utils.py +1 -2
  193. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  194. nat/registry_handlers/register.py +3 -4
  195. nat/registry_handlers/rest/rest_handler.py +12 -13
  196. nat/retriever/milvus/retriever.py +2 -2
  197. nat/retriever/nemo_retriever/retriever.py +1 -1
  198. nat/retriever/register.py +0 -1
  199. nat/runtime/loader.py +2 -2
  200. nat/runtime/runner.py +3 -2
  201. nat/runtime/session.py +43 -8
  202. nat/settings/global_settings.py +16 -5
  203. nat/tool/chat_completion.py +5 -2
  204. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  205. nat/tool/datetime_tools.py +49 -9
  206. nat/tool/document_search.py +2 -2
  207. nat/tool/github_tools.py +450 -0
  208. nat/tool/nvidia_rag.py +1 -1
  209. nat/tool/register.py +2 -9
  210. nat/tool/retriever.py +3 -2
  211. nat/utils/callable_utils.py +70 -0
  212. nat/utils/data_models/schema_validator.py +3 -3
  213. nat/utils/exception_handlers/automatic_retries.py +104 -51
  214. nat/utils/exception_handlers/schemas.py +1 -1
  215. nat/utils/io/yaml_tools.py +2 -2
  216. nat/utils/log_levels.py +25 -0
  217. nat/utils/reactive/base/observable_base.py +2 -2
  218. nat/utils/reactive/base/observer_base.py +1 -1
  219. nat/utils/reactive/observable.py +2 -2
  220. nat/utils/reactive/observer.py +4 -4
  221. nat/utils/reactive/subscription.py +1 -1
  222. nat/utils/settings/global_settings.py +6 -8
  223. nat/utils/type_converter.py +4 -3
  224. nat/utils/type_utils.py +9 -5
  225. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +49 -21
  226. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +233 -189
  227. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  228. nvidia_nat-1.3.0rc1.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  229. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +1 -0
  230. nat/cli/commands/info/list_mcp.py +0 -304
  231. nat/tool/github_tools/create_github_commit.py +0 -133
  232. nat/tool/github_tools/create_github_issue.py +0 -87
  233. nat/tool/github_tools/create_github_pr.py +0 -106
  234. nat/tool/github_tools/get_github_file.py +0 -106
  235. nat/tool/github_tools/get_github_issue.py +0 -166
  236. nat/tool/github_tools/get_github_pr.py +0 -256
  237. nat/tool/github_tools/update_github_issue.py +0 -100
  238. nat/tool/mcp/exceptions.py +0 -142
  239. nat/tool/mcp/mcp_client.py +0 -255
  240. nat/tool/mcp/mcp_tool.py +0 -96
  241. nat/utils/exception_handlers/mcp.py +0 -211
  242. nvidia_nat-1.3.dev0.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
  243. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  244. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  245. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  246. {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
@@ -13,20 +13,24 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import json
17
- # pylint: disable=R0917
18
18
  import logging
19
+ import re
19
20
  from json import JSONDecodeError
21
+ from typing import Any
20
22
 
21
23
  from langchain_core.callbacks.base import AsyncCallbackHandler
22
24
  from langchain_core.language_models import BaseChatModel
23
25
  from langchain_core.messages.ai import AIMessage
26
+ from langchain_core.messages.base import BaseMessage
24
27
  from langchain_core.messages.human import HumanMessage
25
28
  from langchain_core.messages.tool import ToolMessage
26
29
  from langchain_core.prompts.chat import ChatPromptTemplate
27
30
  from langchain_core.runnables.config import RunnableConfig
28
31
  from langchain_core.tools import BaseTool
29
32
  from langgraph.graph import StateGraph
33
+ from langgraph.graph.state import CompiledStateGraph
30
34
  from pydantic import BaseModel
31
35
  from pydantic import Field
32
36
 
@@ -41,22 +45,40 @@ from nat.agent.base import BaseAgent
41
45
  logger = logging.getLogger(__name__)
42
46
 
43
47
 
48
+ class ReWOOEvidence(BaseModel):
49
+ placeholder: str
50
+ tool: str
51
+ tool_input: Any
52
+
53
+
54
+ class ReWOOPlanStep(BaseModel):
55
+ plan: str
56
+ evidence: ReWOOEvidence
57
+
58
+
44
59
  class ReWOOGraphState(BaseModel):
45
60
  """State schema for the ReWOO Agent Graph"""
61
+ messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
46
62
  task: HumanMessage = Field(default_factory=lambda: HumanMessage(content="")) # the task provided by user
47
63
  plan: AIMessage = Field(
48
64
  default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
49
65
  steps: AIMessage = Field(
50
66
  default_factory=lambda: AIMessage(content="")) # the steps to solve the task, parsed from the plan
67
+ # New fields for parallel execution support
68
+ evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict) # mapping from placeholders to step info
69
+ execution_levels: list[list[str]] = Field(default_factory=list) # levels for parallel execution
70
+ current_level: int = Field(default=0) # current execution level
51
71
  intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict) # the intermediate results of each step
52
72
  result: AIMessage = Field(
53
73
  default_factory=lambda: AIMessage(content="")) # the final result of the task, generated by the solver
54
74
 
55
75
 
56
76
  class ReWOOAgentGraph(BaseAgent):
57
- """Configurable LangGraph ReWOO Agent. A ReWOO Agent performs reasoning by interacting with other objects or tools
58
- and utilizes their outputs to make decisions. Supports retrying on output parsing errors. Argument
59
- "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
77
+ """Configurable ReWOO Agent.
78
+
79
+ Args:
80
+ detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
81
+ """
60
82
 
61
83
  def __init__(self,
62
84
  llm: BaseChatModel,
@@ -65,28 +87,34 @@ class ReWOOAgentGraph(BaseAgent):
65
87
  tools: list[BaseTool],
66
88
  use_tool_schema: bool = True,
67
89
  callbacks: list[AsyncCallbackHandler] | None = None,
68
- detailed_logs: bool = False):
69
- super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
90
+ detailed_logs: bool = False,
91
+ log_response_max_chars: int = 1000,
92
+ tool_call_max_retries: int = 3,
93
+ raise_tool_call_error: bool = True):
94
+ super().__init__(llm=llm,
95
+ tools=tools,
96
+ callbacks=callbacks,
97
+ detailed_logs=detailed_logs,
98
+ log_response_max_chars=log_response_max_chars)
70
99
 
71
100
  logger.debug(
72
101
  "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
73
102
  AGENT_LOG_PREFIX)
74
- tool_names = ",".join([tool.name for tool in tools[:-1]]) + ',' + tools[-1].name # prevent trailing ","
75
- if not use_tool_schema:
76
- tool_names_and_descriptions = "\n".join(
77
- [f"{tool.name}: {tool.description}"
78
- for tool in tools[:-1]]) + "\n" + f"{tools[-1].name}: {tools[-1].description}" # prevent trailing "\n"
79
- else:
80
- logger.debug("%s Adding the tools' input schema to the tools' description", AGENT_LOG_PREFIX)
81
- tool_names_and_descriptions = "\n".join([
82
- f"{tool.name}: {tool.description}. {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
83
- for tool in tools[:-1]
84
- ]) + "\n" + (f"{tools[-1].name}: {tools[-1].description}. "
85
- f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
103
+
104
+ def describe_tool(tool: BaseTool) -> str:
105
+ description = f"{tool.name}: {tool.description}"
106
+ if use_tool_schema:
107
+ description += f". {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
108
+ return description
109
+
110
+ tool_names = ",".join(tool.name for tool in tools)
111
+ tool_names_and_descriptions = "\n".join(describe_tool(tool) for tool in tools)
86
112
 
87
113
  self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
88
114
  self.solver_prompt = solver_prompt
89
115
  self.tools_dict = {tool.name: tool for tool in tools}
116
+ self.tool_call_max_retries = tool_call_max_retries
117
+ self.raise_tool_call_error = raise_tool_call_error
90
118
 
91
119
  logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)
92
120
 
@@ -94,34 +122,91 @@ class ReWOOAgentGraph(BaseAgent):
94
122
  try:
95
123
  return self.tools_dict.get(tool_name)
96
124
  except Exception as ex:
97
- logger.exception("%s Unable to find tool with the name %s\n%s",
98
- AGENT_LOG_PREFIX,
99
- tool_name,
100
- ex,
101
- exc_info=True)
102
- raise ex
125
+ logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
126
+ raise
103
127
 
104
128
  @staticmethod
105
- def _get_current_step(state: ReWOOGraphState) -> int:
106
- steps = state.steps.content
107
- if len(steps) == 0:
108
- raise RuntimeError('No steps received in ReWOOGraphState')
129
+ def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
130
+ """
131
+ Get the current execution level and whether it's complete.
132
+
133
+ Args:
134
+ state: The ReWOO graph state.
109
135
 
110
- if len(state.intermediate_results) == len(steps):
111
- # all steps are done
112
- return -1
136
+ Returns:
137
+ tuple of (current_level, is_complete). Level -1 means all execution is complete.
138
+ """
139
+ if not state.execution_levels:
140
+ return -1, True
113
141
 
114
- return len(state.intermediate_results)
142
+ current_level = state.current_level
143
+
144
+ # Check if we've completed all levels
145
+ if current_level >= len(state.execution_levels):
146
+ return -1, True
147
+
148
+ # Check if current level is complete
149
+ current_level_placeholders = state.execution_levels[current_level]
150
+ level_complete = all(placeholder in state.intermediate_results for placeholder in current_level_placeholders)
151
+
152
+ return current_level, level_complete
115
153
 
116
154
  @staticmethod
117
- def _parse_planner_output(planner_output: str) -> AIMessage:
155
+ def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:
118
156
 
119
157
  try:
120
- steps = json.loads(planner_output)
121
- except json.JSONDecodeError as ex:
158
+ return [ReWOOPlanStep(**step) for step in json.loads(planner_output)]
159
+ except Exception as ex:
122
160
  raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex
123
161
 
124
- return AIMessage(content=steps)
162
+ @staticmethod
163
+ def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
164
+ """
165
+ Parse planner steps to identify dependencies and create execution levels for parallel processing.
166
+ This creates a dependency map and identifies which evidence placeholders can be executed in parallel.
167
+
168
+ Args:
169
+ steps: list of plan steps from the planner.
170
+
171
+ Returns:
172
+ A mapping from evidence placeholders to step info and execution levels for parallel processing.
173
+ """
174
+ # First pass: collect all evidence placeholders and their info
175
+ evidences: dict[str, ReWOOPlanStep] = {
176
+ step.evidence.placeholder: step
177
+ for step in steps if step.evidence and step.evidence.placeholder
178
+ }
179
+
180
+ # Second pass: find dependencies now that we have all placeholders
181
+ dependencies = {
182
+ step.evidence.placeholder: [
183
+ var for var in re.findall(r"#E\d+", str(step.evidence.tool_input))
184
+ if var in evidences and var != step.evidence.placeholder
185
+ ]
186
+ for step in steps if step.evidence and step.evidence.placeholder
187
+ }
188
+
189
+ # Create execution levels using topological sort
190
+ levels: list[list[str]] = []
191
+ remaining = dict(dependencies)
192
+
193
+ while remaining:
194
+ # Find items with no dependencies (can be executed in parallel)
195
+ ready = [placeholder for placeholder, deps in remaining.items() if not deps]
196
+
197
+ if not ready:
198
+ raise ValueError("Circular dependency detected in planner output")
199
+
200
+ levels.append(ready)
201
+
202
+ # Remove completed items from remaining
203
+ for placeholder in ready:
204
+ remaining.pop(placeholder)
205
+
206
+ # Remove completed items from other dependencies
207
+ for ph, deps in list(remaining.items()):
208
+ remaining[ph] = list(set(deps) - set(ready))
209
+ return evidences, levels
125
210
 
126
211
  @staticmethod
127
212
  def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:
@@ -141,6 +226,7 @@ class ReWOOAgentGraph(BaseAgent):
141
226
 
142
227
  else:
143
228
  assert False, f"Unexpected type for tool_input: {type(tool_input)}"
229
+
144
230
  return tool_input
145
231
 
146
232
  @staticmethod
@@ -183,132 +269,217 @@ class ReWOOAgentGraph(BaseAgent):
183
269
  if not task:
184
270
  logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
185
271
  return {"result": NO_INPUT_ERROR_MESSAGE}
186
-
272
+ chat_history = self._get_chat_history(state.messages)
187
273
  plan = await self._stream_llm(
188
274
  planner,
189
- {"task": task},
275
+ {
276
+ "task": task, "chat_history": chat_history
277
+ },
190
278
  RunnableConfig(callbacks=self.callbacks) # type: ignore
191
279
  )
192
280
 
193
281
  steps = self._parse_planner_output(str(plan.content))
194
282
 
283
+ # Parse dependencies and create execution levels for parallel processing
284
+ evidence_map, execution_levels = self._parse_planner_dependencies(steps)
285
+
195
286
  if self.detailed_logs:
196
287
  agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
197
288
  logger.info("ReWOO agent planner output: %s", agent_response_log_message)
289
+ logger.info("ReWOO agent execution levels: %s", execution_levels)
198
290
 
199
- return {"plan": plan, "steps": steps}
291
+ return {
292
+ "plan": plan,
293
+ "evidence_map": evidence_map,
294
+ "execution_levels": execution_levels,
295
+ "current_level": 0,
296
+ }
200
297
 
201
298
  except Exception as ex:
202
- logger.exception("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
203
- raise ex
299
+ logger.error("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
300
+ raise
204
301
 
205
302
  async def executor_node(self, state: ReWOOGraphState):
303
+ """
304
+ Execute tools in parallel for the current dependency level.
305
+
306
+ This replaces the sequential execution with parallel execution of tools
307
+ that have no dependencies between them.
308
+ """
206
309
  try:
207
310
  logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)
208
311
 
209
- current_step = self._get_current_step(state)
210
- # The executor node should not be invoked after all steps are finished
211
- if current_step < 0:
212
- logger.error("%s ReWOO Executor is invoked with an invalid step number: %s",
213
- AGENT_LOG_PREFIX,
214
- current_step)
215
- raise RuntimeError(f"ReWOO Executor is invoked with an invalid step number: {current_step}")
216
-
217
- steps_content = state.steps.content
218
- if isinstance(steps_content, list) and current_step < len(steps_content):
219
- step = steps_content[current_step]
220
- if isinstance(step, dict) and "evidence" in step:
221
- step_info = step["evidence"]
222
- placeholder = step_info.get("placeholder", "")
223
- tool = step_info.get("tool", "")
224
- tool_input = step_info.get("tool_input", "")
312
+ current_level, level_complete = self._get_current_level_status(state)
313
+
314
+ # Should not be invoked if all levels are complete
315
+ if current_level < 0:
316
+ logger.error("%s ReWOO Executor invoked after all levels complete", AGENT_LOG_PREFIX)
317
+ raise RuntimeError("ReWOO Executor invoked after all levels complete")
318
+
319
+ # If current level is already complete, move to next level
320
+ if level_complete:
321
+ new_level = current_level + 1
322
+ logger.debug("%s Level %s complete, moving to level %s", AGENT_LOG_PREFIX, current_level, new_level)
323
+ return {"current_level": new_level}
324
+
325
+ # Get placeholders for current level
326
+ current_level_placeholders = state.execution_levels[current_level]
327
+
328
+ # Filter to only placeholders not yet completed
329
+ pending_placeholders = list(set(current_level_placeholders) - set(state.intermediate_results.keys()))
330
+
331
+ if not pending_placeholders:
332
+ # All placeholders in this level are done, move to next level
333
+ new_level = current_level + 1
334
+ return {"current_level": new_level}
335
+
336
+ logger.debug("%s Executing level %s with %s tools in parallel: %s",
337
+ AGENT_LOG_PREFIX,
338
+ current_level,
339
+ len(pending_placeholders),
340
+ pending_placeholders)
341
+
342
+ # Execute all tools in current level in parallel
343
+ tasks = []
344
+ for placeholder in pending_placeholders:
345
+ step_info = state.evidence_map[placeholder]
346
+ task = self._execute_single_tool(placeholder, step_info, state.intermediate_results)
347
+ tasks.append(task)
348
+
349
+ # Wait for all tasks in current level to complete
350
+ results = await asyncio.gather(*tasks, return_exceptions=True)
351
+
352
+ # Process results and update intermediate_results
353
+ updated_intermediate_results = dict(state.intermediate_results)
354
+
355
+ for placeholder, result in zip(pending_placeholders, results):
356
+ if isinstance(result, BaseException):
357
+ logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
358
+ # Create error tool message
359
+ error_message = f"Tool execution failed: {str(result)}"
360
+ updated_intermediate_results[placeholder] = ToolMessage(content=error_message,
361
+ tool_call_id=placeholder)
362
+ if self.raise_tool_call_error:
363
+ raise result
225
364
  else:
226
- logger.error("%s Invalid step format at index %s", AGENT_LOG_PREFIX, current_step)
227
- return {"intermediate_results": state.intermediate_results}
228
- else:
229
- logger.error("%s Invalid steps content or index %s", AGENT_LOG_PREFIX, current_step)
230
- return {"intermediate_results": state.intermediate_results}
231
-
232
- intermediate_results = state.intermediate_results
233
-
234
- # Replace the placeholder in the tool input with the previous tool output
235
- for _placeholder, _tool_output in intermediate_results.items():
236
- _tool_output = _tool_output.content
237
- # If the content is a list, get the first element which should be a dict
238
- if isinstance(_tool_output, list):
239
- _tool_output = _tool_output[0]
240
- assert isinstance(_tool_output, dict)
241
-
242
- tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
243
-
244
- requested_tool = self._get_tool(tool)
245
- if not requested_tool:
246
- configured_tool_names = list(self.tools_dict.keys())
247
- logger.warning(
248
- "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
249
- "there is no tool with that name: %s",
250
- AGENT_LOG_PREFIX,
251
- tool,
252
- configured_tool_names)
253
-
254
- intermediate_results[placeholder] = ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(
255
- tool_name=tool, tools=configured_tool_names),
256
- tool_call_id=tool)
257
- return {"intermediate_results": intermediate_results}
365
+ updated_intermediate_results[placeholder] = result
366
+ # Check if the ToolMessage has error status and raise_tool_call_error is True
367
+ if (isinstance(result, ToolMessage) and hasattr(result, 'status') and result.status == "error"
368
+ and self.raise_tool_call_error):
369
+ logger.error("%s Tool call failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result.content)
370
+ raise RuntimeError(f"Tool call failed: {result.content}")
258
371
 
259
372
  if self.detailed_logs:
260
- logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)
373
+ logger.info("%s Completed level %s with %s tools",
374
+ AGENT_LOG_PREFIX,
375
+ current_level,
376
+ len(pending_placeholders))
261
377
 
262
- # Run the tool. Try to use structured input, if possible
263
- tool_input_parsed = self._parse_tool_input(tool_input)
264
- tool_response = await self._call_tool(requested_tool,
265
- tool_input_parsed,
266
- RunnableConfig(callbacks=self.callbacks),
267
- max_retries=3)
268
-
269
- # ToolMessage only accepts str or list[str | dict] as content.
270
- # Convert into list if the response is a dict.
271
- if isinstance(tool_response, dict):
272
- tool_response = [tool_response]
273
-
274
- tool_response_message = ToolMessage(name=tool, tool_call_id=tool, content=tool_response)
275
-
276
- if self.detailed_logs:
277
- self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
278
-
279
- intermediate_results[placeholder] = tool_response_message
280
- return {"intermediate_results": intermediate_results}
378
+ return {"intermediate_results": updated_intermediate_results}
281
379
 
282
380
  except Exception as ex:
283
- logger.exception("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
284
- raise ex
381
+ logger.error("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
382
+ raise
383
+
384
+ async def _execute_single_tool(self,
385
+ placeholder: str,
386
+ step_info: ReWOOPlanStep,
387
+ intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
388
+ """
389
+ Execute a single tool with proper placeholder replacement.
390
+
391
+ Args:
392
+ placeholder: The evidence placeholder (e.g., "#E1").
393
+ step_info: Step information containing tool and tool_input.
394
+ intermediate_results: Current intermediate results for placeholder replacement.
395
+
396
+ Returns:
397
+ ToolMessage with the tool execution result.
398
+ """
399
+ evidence_info = step_info.evidence
400
+ tool_name = evidence_info.tool
401
+ tool_input = evidence_info.tool_input
402
+
403
+ # Replace placeholders in tool input with previous results
404
+ for ph_key, tool_output in intermediate_results.items():
405
+ tool_output_content = tool_output.content
406
+ # If the content is a list, get the first element which should be a dict
407
+ if isinstance(tool_output_content, list):
408
+ tool_output_content = tool_output_content[0]
409
+ assert isinstance(tool_output_content, dict)
410
+
411
+ tool_input = self._replace_placeholder(ph_key, tool_input, tool_output_content)
412
+
413
+ # Get the requested tool
414
+ requested_tool = self._get_tool(tool_name)
415
+ if not requested_tool:
416
+ configured_tool_names = list(self.tools_dict.keys())
417
+ logger.warning(
418
+ "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
419
+ "there is no tool with that name: %s",
420
+ AGENT_LOG_PREFIX,
421
+ tool_name,
422
+ configured_tool_names)
423
+
424
+ return ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=tool_name,
425
+ tools=configured_tool_names),
426
+ tool_call_id=placeholder)
427
+
428
+ if self.detailed_logs:
429
+ logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)
430
+
431
+ # Parse and execute the tool
432
+ tool_input_parsed = self._parse_tool_input(tool_input)
433
+ tool_response = await self._call_tool(
434
+ requested_tool,
435
+ tool_input_parsed,
436
+ RunnableConfig(callbacks=self.callbacks), # type: ignore
437
+ max_retries=self.tool_call_max_retries)
438
+
439
+ if self.detailed_logs:
440
+ self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
441
+
442
+ return tool_response
285
443
 
286
444
  async def solver_node(self, state: ReWOOGraphState):
287
445
  try:
288
446
  logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)
289
447
 
290
448
  plan = ""
291
- # Add the tool outputs of each step to the plan
292
- for step in state.steps.content:
293
- step_info = step["evidence"]
294
- placeholder = step_info.get("placeholder", "")
295
- tool_input = step_info.get("tool_input", "")
296
-
297
- intermediate_results = state.intermediate_results
298
- for _placeholder, _tool_output in intermediate_results.items():
299
- _tool_output = _tool_output.content
449
+ # Add the tool outputs of each step to the plan using evidence_map
450
+ for placeholder, step_info in state.evidence_map.items():
451
+ evidence_info = step_info.evidence
452
+ original_tool_input = evidence_info.tool_input
453
+ tool_name = evidence_info.tool
454
+
455
+ # Replace placeholders in tool input with actual results
456
+ final_tool_input = original_tool_input
457
+ for ph_key, tool_output in state.intermediate_results.items():
458
+ tool_output_content = tool_output.content
300
459
  # If the content is a list, get the first element which should be a dict
301
- if isinstance(_tool_output, list):
302
- _tool_output = _tool_output[0]
303
- assert isinstance(_tool_output, dict)
304
-
305
- tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
306
-
307
- placeholder = placeholder.replace(_placeholder, str(_tool_output))
308
-
309
- _plan = step.get("plan")
310
- tool = step_info.get("tool")
311
- plan += f"Plan: {_plan}\n{placeholder} = {tool}[{tool_input}]"
460
+ if isinstance(tool_output_content, list):
461
+ tool_output_content = tool_output_content[0]
462
+ assert isinstance(tool_output_content, dict)
463
+
464
+ final_tool_input = self._replace_placeholder(ph_key, final_tool_input, tool_output_content)
465
+
466
+ # Get the final result for this placeholder
467
+ final_result = ""
468
+ if placeholder in state.intermediate_results:
469
+ result_content = state.intermediate_results[placeholder].content
470
+ if isinstance(result_content, list):
471
+ result_content = result_content[0]
472
+ if isinstance(result_content, dict):
473
+ final_result = str(result_content)
474
+ else:
475
+ final_result = str(result_content)
476
+
477
+ step_plan = step_info.plan
478
+ plan += '\n'.join([
479
+ f"Plan: {step_plan}",
480
+ f"{placeholder} = {tool_name}[{final_tool_input}",
481
+ f"Result: {final_result}\n\n"
482
+ ])
312
483
 
313
484
  task = str(state.task.content)
314
485
  solver_prompt = self.solver_prompt.partial(plan=plan)
@@ -324,30 +495,39 @@ class ReWOOAgentGraph(BaseAgent):
324
495
  return {"result": output_message}
325
496
 
326
497
  except Exception as ex:
327
- logger.exception("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
328
- raise ex
498
+ logger.error("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex)
499
+ raise
329
500
 
330
501
  async def conditional_edge(self, state: ReWOOGraphState):
331
502
  try:
332
503
  logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)
333
504
 
334
- current_step = self._get_current_step(state)
335
- if current_step == -1:
336
- logger.debug("%s The ReWOO Executor has finished its task", AGENT_LOG_PREFIX)
505
+ current_level, level_complete = self._get_current_level_status(state)
506
+
507
+ # If all levels are complete, move to solver
508
+ if current_level == -1:
509
+ logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
337
510
  return AgentDecision.END
338
511
 
339
- logger.debug("%s The ReWOO Executor is still working on the task", AGENT_LOG_PREFIX)
512
+ # If current level is complete, check if there are more levels
513
+ if level_complete:
514
+ next_level = current_level + 1
515
+ if next_level >= len(state.execution_levels):
516
+ logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
517
+ return AgentDecision.END
518
+
519
+ logger.debug("%s Continuing with executor (level %s, complete: %s)",
520
+ AGENT_LOG_PREFIX,
521
+ current_level,
522
+ level_complete)
340
523
  return AgentDecision.TOOL
341
524
 
342
525
  except Exception as ex:
343
- logger.exception("%s Failed to determine whether agent is calling a tool: %s",
344
- AGENT_LOG_PREFIX,
345
- ex,
346
- exc_info=True)
526
+ logger.exception("%s Failed to determine whether agent is calling a tool: %s", AGENT_LOG_PREFIX, ex)
347
527
  logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
348
528
  return AgentDecision.END
349
529
 
350
- async def _build_graph(self, state_schema):
530
+ async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
351
531
  try:
352
532
  logger.debug("%s Building and compiling the ReWOO Graph", AGENT_LOG_PREFIX)
353
533
 
@@ -357,8 +537,10 @@ class ReWOOAgentGraph(BaseAgent):
357
537
  graph.add_node("solver", self.solver_node)
358
538
 
359
539
  graph.add_edge("planner", "executor")
360
- conditional_edge_possible_outputs = {AgentDecision.TOOL: "executor", AgentDecision.END: "solver"}
361
- graph.add_conditional_edges("executor", self.conditional_edge, conditional_edge_possible_outputs)
540
+ graph.add_conditional_edges("executor",
541
+ self.conditional_edge, {
542
+ AgentDecision.TOOL: "executor", AgentDecision.END: "solver"
543
+ })
362
544
 
363
545
  graph.set_entry_point("planner")
364
546
  graph.set_finish_point("solver")
@@ -369,8 +551,8 @@ class ReWOOAgentGraph(BaseAgent):
369
551
  return self.graph
370
552
 
371
553
  except Exception as ex:
372
- logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
373
- raise ex
554
+ logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
555
+ raise
374
556
 
375
557
  async def build_graph(self):
376
558
  try:
@@ -378,8 +560,8 @@ class ReWOOAgentGraph(BaseAgent):
378
560
  logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)
379
561
  return self.graph
380
562
  except Exception as ex:
381
- logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
382
- raise ex
563
+ logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
564
+ raise
383
565
 
384
566
  @staticmethod
385
567
  def validate_planner_prompt(planner_prompt: str) -> bool:
@@ -395,7 +577,7 @@ class ReWOOAgentGraph(BaseAgent):
395
577
  errors.append(error_message)
396
578
  if errors:
397
579
  error_text = "\n".join(errors)
398
- logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
580
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
399
581
  raise ValueError(error_text)
400
582
  return True
401
583
 
@@ -406,6 +588,6 @@ class ReWOOAgentGraph(BaseAgent):
406
588
  errors.append("The solver prompt cannot be empty.")
407
589
  if errors:
408
590
  error_text = "\n".join(errors)
409
- logger.exception("%s %s", AGENT_LOG_PREFIX, error_text)
591
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
410
592
  raise ValueError(error_text)
411
593
  return True