nvidia-nat 1.3.0a20250928__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +1 -1
- nat/agent/rewoo_agent/agent.py +298 -118
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +4 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
- nat/builder/builder.py +1 -1
- nat/builder/context.py +2 -2
- nat/builder/front_end.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/mcp/mcp.py +2 -2
- nat/cli/commands/start.py +1 -1
- nat/cli/type_registry.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/data_models/api_server.py +9 -9
- nat/data_models/authentication.py +3 -9
- nat/data_models/dataset_handler.py +1 -1
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
- nat/front_ends/fastapi/job_store.py +2 -2
- nat/front_ends/fastapi/message_handler.py +4 -4
- nat/front_ends/fastapi/message_validator.py +5 -5
- nat/front_ends/mcp/tool_converter.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +2 -2
- nat/observability/processor/batching_processor.py +1 -1
- nat/profiler/decorators/function_tracking.py +2 -2
- nat/profiler/parameter_optimization/parameter_selection.py +3 -4
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/settings/global_settings.py +2 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +1 -1
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +52 -52
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
nat/agent/base.py
CHANGED
|
@@ -192,7 +192,7 @@ class BaseAgent(ABC):
|
|
|
192
192
|
await asyncio.sleep(sleep_time)
|
|
193
193
|
|
|
194
194
|
# All retries exhausted, return error message
|
|
195
|
-
error_content = "Tool call failed after all retry attempts. Last error:
|
|
195
|
+
error_content = f"Tool call failed after all retry attempts. Last error: {str(last_exception)}"
|
|
196
196
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content, exc_info=True)
|
|
197
197
|
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
|
|
198
198
|
|
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -13,9 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import json
|
|
17
18
|
import logging
|
|
19
|
+
import re
|
|
18
20
|
from json import JSONDecodeError
|
|
21
|
+
from typing import Any
|
|
19
22
|
|
|
20
23
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
21
24
|
from langchain_core.language_models import BaseChatModel
|
|
@@ -42,6 +45,17 @@ from nat.agent.base import BaseAgent
|
|
|
42
45
|
logger = logging.getLogger(__name__)
|
|
43
46
|
|
|
44
47
|
|
|
48
|
+
class ReWOOEvidence(BaseModel):
|
|
49
|
+
placeholder: str
|
|
50
|
+
tool: str
|
|
51
|
+
tool_input: Any
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ReWOOPlanStep(BaseModel):
|
|
55
|
+
plan: str
|
|
56
|
+
evidence: ReWOOEvidence
|
|
57
|
+
|
|
58
|
+
|
|
45
59
|
class ReWOOGraphState(BaseModel):
|
|
46
60
|
"""State schema for the ReWOO Agent Graph"""
|
|
47
61
|
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
|
|
@@ -50,15 +64,21 @@ class ReWOOGraphState(BaseModel):
|
|
|
50
64
|
default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
|
|
51
65
|
steps: AIMessage = Field(
|
|
52
66
|
default_factory=lambda: AIMessage(content="")) # the steps to solve the task, parsed from the plan
|
|
67
|
+
# New fields for parallel execution support
|
|
68
|
+
evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict) # mapping from placeholders to step info
|
|
69
|
+
execution_levels: list[list[str]] = Field(default_factory=list) # levels for parallel execution
|
|
70
|
+
current_level: int = Field(default=0) # current execution level
|
|
53
71
|
intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict) # the intermediate results of each step
|
|
54
72
|
result: AIMessage = Field(
|
|
55
73
|
default_factory=lambda: AIMessage(content="")) # the final result of the task, generated by the solver
|
|
56
74
|
|
|
57
75
|
|
|
58
76
|
class ReWOOAgentGraph(BaseAgent):
|
|
59
|
-
"""Configurable
|
|
60
|
-
|
|
61
|
-
|
|
77
|
+
"""Configurable ReWOO Agent.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
|
|
81
|
+
"""
|
|
62
82
|
|
|
63
83
|
def __init__(self,
|
|
64
84
|
llm: BaseChatModel,
|
|
@@ -80,18 +100,15 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
80
100
|
logger.debug(
|
|
81
101
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
82
102
|
AGENT_LOG_PREFIX)
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
for tool in tools[:-1]
|
|
93
|
-
]) + "\n" + (f"{tools[-1].name}: {tools[-1].description}. "
|
|
94
|
-
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
103
|
+
|
|
104
|
+
def describe_tool(tool: BaseTool) -> str:
|
|
105
|
+
description = f"{tool.name}: {tool.description}"
|
|
106
|
+
if use_tool_schema:
|
|
107
|
+
description += f". {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
|
|
108
|
+
return description
|
|
109
|
+
|
|
110
|
+
tool_names = ",".join(tool.name for tool in tools)
|
|
111
|
+
tool_names_and_descriptions = "\n".join(describe_tool(tool) for tool in tools)
|
|
95
112
|
|
|
96
113
|
self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
97
114
|
self.solver_prompt = solver_prompt
|
|
@@ -109,26 +126,87 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
109
126
|
raise
|
|
110
127
|
|
|
111
128
|
@staticmethod
|
|
112
|
-
def
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
129
|
+
def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
|
|
130
|
+
"""
|
|
131
|
+
Get the current execution level and whether it's complete.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
state: The ReWOO graph state.
|
|
116
135
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
136
|
+
Returns:
|
|
137
|
+
tuple of (current_level, is_complete). Level -1 means all execution is complete.
|
|
138
|
+
"""
|
|
139
|
+
if not state.execution_levels:
|
|
140
|
+
return -1, True
|
|
120
141
|
|
|
121
|
-
|
|
142
|
+
current_level = state.current_level
|
|
143
|
+
|
|
144
|
+
# Check if we've completed all levels
|
|
145
|
+
if current_level >= len(state.execution_levels):
|
|
146
|
+
return -1, True
|
|
147
|
+
|
|
148
|
+
# Check if current level is complete
|
|
149
|
+
current_level_placeholders = state.execution_levels[current_level]
|
|
150
|
+
level_complete = all(placeholder in state.intermediate_results for placeholder in current_level_placeholders)
|
|
151
|
+
|
|
152
|
+
return current_level, level_complete
|
|
122
153
|
|
|
123
154
|
@staticmethod
|
|
124
|
-
def _parse_planner_output(planner_output: str) ->
|
|
155
|
+
def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:
|
|
125
156
|
|
|
126
157
|
try:
|
|
127
|
-
|
|
128
|
-
except
|
|
158
|
+
return [ReWOOPlanStep(**step) for step in json.loads(planner_output)]
|
|
159
|
+
except Exception as ex:
|
|
129
160
|
raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex
|
|
130
161
|
|
|
131
|
-
|
|
162
|
+
@staticmethod
|
|
163
|
+
def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
|
|
164
|
+
"""
|
|
165
|
+
Parse planner steps to identify dependencies and create execution levels for parallel processing.
|
|
166
|
+
This creates a dependency map and identifies which evidence placeholders can be executed in parallel.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
steps: list of plan steps from the planner.
|
|
170
|
+
|
|
171
|
+
Returns:
|
|
172
|
+
A mapping from evidence placeholders to step info and execution levels for parallel processing.
|
|
173
|
+
"""
|
|
174
|
+
# First pass: collect all evidence placeholders and their info
|
|
175
|
+
evidences: dict[str, ReWOOPlanStep] = {
|
|
176
|
+
step.evidence.placeholder: step
|
|
177
|
+
for step in steps if step.evidence and step.evidence.placeholder
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
# Second pass: find dependencies now that we have all placeholders
|
|
181
|
+
dependencies = {
|
|
182
|
+
step.evidence.placeholder: [
|
|
183
|
+
var for var in re.findall(r"#E\d+", str(step.evidence.tool_input))
|
|
184
|
+
if var in evidences and var != step.evidence.placeholder
|
|
185
|
+
]
|
|
186
|
+
for step in steps if step.evidence and step.evidence.placeholder
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
# Create execution levels using topological sort
|
|
190
|
+
levels: list[list[str]] = []
|
|
191
|
+
remaining = dict(dependencies)
|
|
192
|
+
|
|
193
|
+
while remaining:
|
|
194
|
+
# Find items with no dependencies (can be executed in parallel)
|
|
195
|
+
ready = [placeholder for placeholder, deps in remaining.items() if not deps]
|
|
196
|
+
|
|
197
|
+
if not ready:
|
|
198
|
+
raise ValueError("Circular dependency detected in planner output")
|
|
199
|
+
|
|
200
|
+
levels.append(ready)
|
|
201
|
+
|
|
202
|
+
# Remove completed items from remaining
|
|
203
|
+
for placeholder in ready:
|
|
204
|
+
remaining.pop(placeholder)
|
|
205
|
+
|
|
206
|
+
# Remove completed items from other dependencies
|
|
207
|
+
for ph, deps in list(remaining.items()):
|
|
208
|
+
remaining[ph] = list(set(deps) - set(ready))
|
|
209
|
+
return evidences, levels
|
|
132
210
|
|
|
133
211
|
@staticmethod
|
|
134
212
|
def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:
|
|
@@ -148,6 +226,7 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
148
226
|
|
|
149
227
|
else:
|
|
150
228
|
assert False, f"Unexpected type for tool_input: {type(tool_input)}"
|
|
229
|
+
|
|
151
230
|
return tool_input
|
|
152
231
|
|
|
153
232
|
@staticmethod
|
|
@@ -201,119 +280,206 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
201
280
|
|
|
202
281
|
steps = self._parse_planner_output(str(plan.content))
|
|
203
282
|
|
|
283
|
+
# Parse dependencies and create execution levels for parallel processing
|
|
284
|
+
evidence_map, execution_levels = self._parse_planner_dependencies(steps)
|
|
285
|
+
|
|
204
286
|
if self.detailed_logs:
|
|
205
287
|
agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
|
|
206
288
|
logger.info("ReWOO agent planner output: %s", agent_response_log_message)
|
|
289
|
+
logger.info("ReWOO agent execution levels: %s", execution_levels)
|
|
207
290
|
|
|
208
|
-
return {
|
|
291
|
+
return {
|
|
292
|
+
"plan": plan,
|
|
293
|
+
"evidence_map": evidence_map,
|
|
294
|
+
"execution_levels": execution_levels,
|
|
295
|
+
"current_level": 0,
|
|
296
|
+
}
|
|
209
297
|
|
|
210
298
|
except Exception as ex:
|
|
211
299
|
logger.error("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
|
|
212
300
|
raise
|
|
213
301
|
|
|
214
302
|
async def executor_node(self, state: ReWOOGraphState):
|
|
303
|
+
"""
|
|
304
|
+
Execute tools in parallel for the current dependency level.
|
|
305
|
+
|
|
306
|
+
This replaces the sequential execution with parallel execution of tools
|
|
307
|
+
that have no dependencies between them.
|
|
308
|
+
"""
|
|
215
309
|
try:
|
|
216
310
|
logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)
|
|
217
311
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
if
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
312
|
+
current_level, level_complete = self._get_current_level_status(state)
|
|
313
|
+
|
|
314
|
+
# Should not be invoked if all levels are complete
|
|
315
|
+
if current_level < 0:
|
|
316
|
+
logger.error("%s ReWOO Executor invoked after all levels complete", AGENT_LOG_PREFIX)
|
|
317
|
+
raise RuntimeError("ReWOO Executor invoked after all levels complete")
|
|
318
|
+
|
|
319
|
+
# If current level is already complete, move to next level
|
|
320
|
+
if level_complete:
|
|
321
|
+
new_level = current_level + 1
|
|
322
|
+
logger.debug("%s Level %s complete, moving to level %s", AGENT_LOG_PREFIX, current_level, new_level)
|
|
323
|
+
return {"current_level": new_level}
|
|
324
|
+
|
|
325
|
+
# Get placeholders for current level
|
|
326
|
+
current_level_placeholders = state.execution_levels[current_level]
|
|
327
|
+
|
|
328
|
+
# Filter to only placeholders not yet completed
|
|
329
|
+
pending_placeholders = list(set(current_level_placeholders) - set(state.intermediate_results.keys()))
|
|
330
|
+
|
|
331
|
+
if not pending_placeholders:
|
|
332
|
+
# All placeholders in this level are done, move to next level
|
|
333
|
+
new_level = current_level + 1
|
|
334
|
+
return {"current_level": new_level}
|
|
335
|
+
|
|
336
|
+
logger.debug("%s Executing level %s with %s tools in parallel: %s",
|
|
337
|
+
AGENT_LOG_PREFIX,
|
|
338
|
+
current_level,
|
|
339
|
+
len(pending_placeholders),
|
|
340
|
+
pending_placeholders)
|
|
341
|
+
|
|
342
|
+
# Execute all tools in current level in parallel
|
|
343
|
+
tasks = []
|
|
344
|
+
for placeholder in pending_placeholders:
|
|
345
|
+
step_info = state.evidence_map[placeholder]
|
|
346
|
+
task = self._execute_single_tool(placeholder, step_info, state.intermediate_results)
|
|
347
|
+
tasks.append(task)
|
|
348
|
+
|
|
349
|
+
# Wait for all tasks in current level to complete
|
|
350
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
351
|
+
|
|
352
|
+
# Process results and update intermediate_results
|
|
353
|
+
updated_intermediate_results = dict(state.intermediate_results)
|
|
354
|
+
|
|
355
|
+
for placeholder, result in zip(pending_placeholders, results):
|
|
356
|
+
if isinstance(result, BaseException):
|
|
357
|
+
logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
|
|
358
|
+
# Create error tool message
|
|
359
|
+
error_message = f"Tool execution failed: {str(result)}"
|
|
360
|
+
updated_intermediate_results[placeholder] = ToolMessage(content=error_message,
|
|
361
|
+
tool_call_id=placeholder)
|
|
362
|
+
if self.raise_tool_call_error:
|
|
363
|
+
raise result
|
|
234
364
|
else:
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
intermediate_results = state.intermediate_results
|
|
242
|
-
|
|
243
|
-
# Replace the placeholder in the tool input with the previous tool output
|
|
244
|
-
for _placeholder, _tool_output in intermediate_results.items():
|
|
245
|
-
_tool_output = _tool_output.content
|
|
246
|
-
# If the content is a list, get the first element which should be a dict
|
|
247
|
-
if isinstance(_tool_output, list):
|
|
248
|
-
_tool_output = _tool_output[0]
|
|
249
|
-
assert isinstance(_tool_output, dict)
|
|
250
|
-
|
|
251
|
-
tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
|
|
252
|
-
|
|
253
|
-
requested_tool = self._get_tool(tool)
|
|
254
|
-
if not requested_tool:
|
|
255
|
-
configured_tool_names = list(self.tools_dict.keys())
|
|
256
|
-
logger.warning(
|
|
257
|
-
"%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
|
|
258
|
-
"there is no tool with that name: %s",
|
|
259
|
-
AGENT_LOG_PREFIX,
|
|
260
|
-
tool,
|
|
261
|
-
configured_tool_names)
|
|
262
|
-
|
|
263
|
-
intermediate_results[placeholder] = ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(
|
|
264
|
-
tool_name=tool, tools=configured_tool_names),
|
|
265
|
-
tool_call_id=tool)
|
|
266
|
-
return {"intermediate_results": intermediate_results}
|
|
365
|
+
updated_intermediate_results[placeholder] = result
|
|
366
|
+
# Check if the ToolMessage has error status and raise_tool_call_error is True
|
|
367
|
+
if (isinstance(result, ToolMessage) and hasattr(result, 'status') and result.status == "error"
|
|
368
|
+
and self.raise_tool_call_error):
|
|
369
|
+
logger.error("%s Tool call failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result.content)
|
|
370
|
+
raise RuntimeError(f"Tool call failed: {result.content}")
|
|
267
371
|
|
|
268
372
|
if self.detailed_logs:
|
|
269
|
-
logger.
|
|
373
|
+
logger.info("%s Completed level %s with %s tools",
|
|
374
|
+
AGENT_LOG_PREFIX,
|
|
375
|
+
current_level,
|
|
376
|
+
len(pending_placeholders))
|
|
270
377
|
|
|
271
|
-
|
|
272
|
-
tool_input_parsed = self._parse_tool_input(tool_input)
|
|
273
|
-
tool_response = await self._call_tool(requested_tool,
|
|
274
|
-
tool_input_parsed,
|
|
275
|
-
RunnableConfig(callbacks=self.callbacks),
|
|
276
|
-
max_retries=self.tool_call_max_retries)
|
|
277
|
-
|
|
278
|
-
if self.detailed_logs:
|
|
279
|
-
self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
|
|
280
|
-
|
|
281
|
-
if self.raise_tool_call_error and tool_response.status == "error":
|
|
282
|
-
raise RuntimeError(f"Tool call failed: {tool_response.content}")
|
|
283
|
-
|
|
284
|
-
intermediate_results[placeholder] = tool_response
|
|
285
|
-
return {"intermediate_results": intermediate_results}
|
|
378
|
+
return {"intermediate_results": updated_intermediate_results}
|
|
286
379
|
|
|
287
380
|
except Exception as ex:
|
|
288
381
|
logger.error("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
|
|
289
382
|
raise
|
|
290
383
|
|
|
384
|
+
async def _execute_single_tool(self,
|
|
385
|
+
placeholder: str,
|
|
386
|
+
step_info: ReWOOPlanStep,
|
|
387
|
+
intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
|
|
388
|
+
"""
|
|
389
|
+
Execute a single tool with proper placeholder replacement.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
placeholder: The evidence placeholder (e.g., "#E1").
|
|
393
|
+
step_info: Step information containing tool and tool_input.
|
|
394
|
+
intermediate_results: Current intermediate results for placeholder replacement.
|
|
395
|
+
|
|
396
|
+
Returns:
|
|
397
|
+
ToolMessage with the tool execution result.
|
|
398
|
+
"""
|
|
399
|
+
evidence_info = step_info.evidence
|
|
400
|
+
tool_name = evidence_info.tool
|
|
401
|
+
tool_input = evidence_info.tool_input
|
|
402
|
+
|
|
403
|
+
# Replace placeholders in tool input with previous results
|
|
404
|
+
for ph_key, tool_output in intermediate_results.items():
|
|
405
|
+
tool_output_content = tool_output.content
|
|
406
|
+
# If the content is a list, get the first element which should be a dict
|
|
407
|
+
if isinstance(tool_output_content, list):
|
|
408
|
+
tool_output_content = tool_output_content[0]
|
|
409
|
+
assert isinstance(tool_output_content, dict)
|
|
410
|
+
|
|
411
|
+
tool_input = self._replace_placeholder(ph_key, tool_input, tool_output_content)
|
|
412
|
+
|
|
413
|
+
# Get the requested tool
|
|
414
|
+
requested_tool = self._get_tool(tool_name)
|
|
415
|
+
if not requested_tool:
|
|
416
|
+
configured_tool_names = list(self.tools_dict.keys())
|
|
417
|
+
logger.warning(
|
|
418
|
+
"%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file,"
|
|
419
|
+
"there is no tool with that name: %s",
|
|
420
|
+
AGENT_LOG_PREFIX,
|
|
421
|
+
tool_name,
|
|
422
|
+
configured_tool_names)
|
|
423
|
+
|
|
424
|
+
return ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=tool_name,
|
|
425
|
+
tools=configured_tool_names),
|
|
426
|
+
tool_call_id=placeholder)
|
|
427
|
+
|
|
428
|
+
if self.detailed_logs:
|
|
429
|
+
logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)
|
|
430
|
+
|
|
431
|
+
# Parse and execute the tool
|
|
432
|
+
tool_input_parsed = self._parse_tool_input(tool_input)
|
|
433
|
+
tool_response = await self._call_tool(
|
|
434
|
+
requested_tool,
|
|
435
|
+
tool_input_parsed,
|
|
436
|
+
RunnableConfig(callbacks=self.callbacks), # type: ignore
|
|
437
|
+
max_retries=self.tool_call_max_retries)
|
|
438
|
+
|
|
439
|
+
if self.detailed_logs:
|
|
440
|
+
self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))
|
|
441
|
+
|
|
442
|
+
return tool_response
|
|
443
|
+
|
|
291
444
|
async def solver_node(self, state: ReWOOGraphState):
|
|
292
445
|
try:
|
|
293
446
|
logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)
|
|
294
447
|
|
|
295
448
|
plan = ""
|
|
296
|
-
# Add the tool outputs of each step to the plan
|
|
297
|
-
for
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
449
|
+
# Add the tool outputs of each step to the plan using evidence_map
|
|
450
|
+
for placeholder, step_info in state.evidence_map.items():
|
|
451
|
+
evidence_info = step_info.evidence
|
|
452
|
+
original_tool_input = evidence_info.tool_input
|
|
453
|
+
tool_name = evidence_info.tool
|
|
454
|
+
|
|
455
|
+
# Replace placeholders in tool input with actual results
|
|
456
|
+
final_tool_input = original_tool_input
|
|
457
|
+
for ph_key, tool_output in state.intermediate_results.items():
|
|
458
|
+
tool_output_content = tool_output.content
|
|
305
459
|
# If the content is a list, get the first element which should be a dict
|
|
306
|
-
if isinstance(
|
|
307
|
-
|
|
308
|
-
assert isinstance(
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
460
|
+
if isinstance(tool_output_content, list):
|
|
461
|
+
tool_output_content = tool_output_content[0]
|
|
462
|
+
assert isinstance(tool_output_content, dict)
|
|
463
|
+
|
|
464
|
+
final_tool_input = self._replace_placeholder(ph_key, final_tool_input, tool_output_content)
|
|
465
|
+
|
|
466
|
+
# Get the final result for this placeholder
|
|
467
|
+
final_result = ""
|
|
468
|
+
if placeholder in state.intermediate_results:
|
|
469
|
+
result_content = state.intermediate_results[placeholder].content
|
|
470
|
+
if isinstance(result_content, list):
|
|
471
|
+
result_content = result_content[0]
|
|
472
|
+
if isinstance(result_content, dict):
|
|
473
|
+
final_result = str(result_content)
|
|
474
|
+
else:
|
|
475
|
+
final_result = str(result_content)
|
|
476
|
+
|
|
477
|
+
step_plan = step_info.plan
|
|
478
|
+
plan += '\n'.join([
|
|
479
|
+
f"Plan: {step_plan}",
|
|
480
|
+
f"{placeholder} = {tool_name}[{final_tool_input}",
|
|
481
|
+
f"Result: {final_result}\n\n"
|
|
482
|
+
])
|
|
317
483
|
|
|
318
484
|
task = str(state.task.content)
|
|
319
485
|
solver_prompt = self.solver_prompt.partial(plan=plan)
|
|
@@ -336,12 +502,24 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
336
502
|
try:
|
|
337
503
|
logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)
|
|
338
504
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
505
|
+
current_level, level_complete = self._get_current_level_status(state)
|
|
506
|
+
|
|
507
|
+
# If all levels are complete, move to solver
|
|
508
|
+
if current_level == -1:
|
|
509
|
+
logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
|
|
342
510
|
return AgentDecision.END
|
|
343
511
|
|
|
344
|
-
|
|
512
|
+
# If current level is complete, check if there are more levels
|
|
513
|
+
if level_complete:
|
|
514
|
+
next_level = current_level + 1
|
|
515
|
+
if next_level >= len(state.execution_levels):
|
|
516
|
+
logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
|
|
517
|
+
return AgentDecision.END
|
|
518
|
+
|
|
519
|
+
logger.debug("%s Continuing with executor (level %s, complete: %s)",
|
|
520
|
+
AGENT_LOG_PREFIX,
|
|
521
|
+
current_level,
|
|
522
|
+
level_complete)
|
|
345
523
|
return AgentDecision.TOOL
|
|
346
524
|
|
|
347
525
|
except Exception as ex:
|
|
@@ -359,8 +537,10 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
359
537
|
graph.add_node("solver", self.solver_node)
|
|
360
538
|
|
|
361
539
|
graph.add_edge("planner", "executor")
|
|
362
|
-
|
|
363
|
-
|
|
540
|
+
graph.add_conditional_edges("executor",
|
|
541
|
+
self.conditional_edge, {
|
|
542
|
+
AgentDecision.TOOL: "executor", AgentDecision.END: "solver"
|
|
543
|
+
})
|
|
364
544
|
|
|
365
545
|
graph.set_entry_point("planner")
|
|
366
546
|
graph.set_finish_point("solver")
|
nat/agent/rewoo_agent/prompt.py
CHANGED
|
@@ -18,33 +18,29 @@ For the following task, make plans that can solve the problem step by step. For
|
|
|
18
18
|
which external tool together with tool input to retrieve evidence. You can store the evidence into a \
|
|
19
19
|
placeholder #E that can be called by later tools. (Plan, #E1, Plan, #E2, Plan, ...)
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
The following tools and respective requirements are available to you:
|
|
22
22
|
|
|
23
23
|
{tools}
|
|
24
24
|
|
|
25
|
-
The
|
|
25
|
+
The tool calls you make should be one of the following: [{tool_names}]
|
|
26
26
|
|
|
27
27
|
You are not required to use all the tools listed. Choose only the ones that best fit the needs of each plan step.
|
|
28
28
|
|
|
29
|
-
Your output must be a JSON array where each element represents one planning step. Each step must be an object with
|
|
30
|
-
|
|
29
|
+
Your output must be a JSON array where each element represents one planning step. Each step must be an object with \
|
|
31
30
|
exactly two keys:
|
|
32
31
|
|
|
33
32
|
1. "plan": A string that describes in detail the action or reasoning for that step.
|
|
34
33
|
|
|
35
|
-
2. "evidence": An object representing the external tool call associated with that plan step. This object must have the
|
|
34
|
+
2. "evidence": An object representing the external tool call associated with that plan step. This object must have the \
|
|
36
35
|
following keys:
|
|
37
36
|
|
|
38
|
-
-"placeholder": A string that identifies the evidence placeholder (
|
|
39
|
-
|
|
37
|
+
-"placeholder": A string that identifies the evidence placeholder ("#E1", "#E2", ...). The numbering should \
|
|
38
|
+
be sequential based on the order of steps.
|
|
40
39
|
|
|
41
40
|
-"tool": A string specifying the name of the external tool used.
|
|
42
41
|
|
|
43
|
-
-"tool_input": The input to the tool. This can be a string, array, or object, depending on the requirements of the
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
Do not include any additional keys or characters in your output, and do not wrap your response with markdown formatting.
|
|
47
|
-
Your output must be strictly valid JSON.
|
|
42
|
+
-"tool_input": The input to the tool. This can be a string, array, or object, depending on the requirements of the \
|
|
43
|
+
tool. Be careful about type assumptions because the output of former tools might contain noise.
|
|
48
44
|
|
|
49
45
|
Important instructions:
|
|
50
46
|
|
|
@@ -58,27 +54,28 @@ Here is an example of how a valid JSON output should look:
|
|
|
58
54
|
|
|
59
55
|
[
|
|
60
56
|
\'{{
|
|
61
|
-
"plan": "
|
|
57
|
+
"plan": "Find Alex's schedule on Sep 25, 2025",
|
|
62
58
|
"evidence": \'{{
|
|
63
59
|
"placeholder": "#E1",
|
|
64
|
-
"tool": "
|
|
65
|
-
"tool_input":
|
|
60
|
+
"tool": "search_calendar",
|
|
61
|
+
"tool_input": ("Alex", "09/25/2025")
|
|
66
62
|
}}\'
|
|
67
63
|
}}\',
|
|
68
64
|
\'{{
|
|
69
|
-
"plan": "
|
|
65
|
+
"plan": "Find Bill's schedule on sep 25, 2025",
|
|
70
66
|
"evidence": \'{{
|
|
71
67
|
"placeholder": "#E2",
|
|
72
|
-
"tool": "
|
|
73
|
-
"tool_input": "
|
|
68
|
+
"tool": "search_calendar",
|
|
69
|
+
"tool_input": ("Bill", "09/25/2025")
|
|
74
70
|
}}\'
|
|
75
71
|
}}\',
|
|
76
72
|
\'{{
|
|
77
|
-
"plan": "
|
|
73
|
+
"plan": "Suggest a time for 1-hour meeting given Alex's and Bill's schedule.",
|
|
78
74
|
"evidence": \'{{
|
|
79
75
|
"placeholder": "#E3",
|
|
80
|
-
"tool": "
|
|
81
|
-
"tool_input": "
|
|
76
|
+
"tool": "llm_chat",
|
|
77
|
+
"tool_input": "Find a common 1-hour time slot for Alex and Bill given their schedules. \
|
|
78
|
+
Alex's schedule: #E1; Bill's schedule: #E2?"
|
|
82
79
|
}}\'
|
|
83
80
|
}}\'
|
|
84
81
|
]
|
|
@@ -94,7 +91,7 @@ task: {task}
|
|
|
94
91
|
"""
|
|
95
92
|
|
|
96
93
|
SOLVER_SYSTEM_PROMPT = """
|
|
97
|
-
Solve the following task or problem. To solve the problem, we have made
|
|
94
|
+
Solve the following task or problem. To solve the problem, we have made some Plans ahead and \
|
|
98
95
|
retrieved corresponding Evidence to each Plan. Use them with caution since long evidence might \
|
|
99
96
|
contain irrelevant information.
|
|
100
97
|
|
|
@@ -71,8 +71,8 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
|
|
|
71
71
|
|
|
72
72
|
@register_function(config_type=ReWOOAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
73
73
|
async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builder):
|
|
74
|
-
from langchain.schema import BaseMessage
|
|
75
74
|
from langchain_core.messages import trim_messages
|
|
75
|
+
from langchain_core.messages.base import BaseMessage
|
|
76
76
|
from langchain_core.messages.human import HumanMessage
|
|
77
77
|
from langchain_core.prompts import ChatPromptTemplate
|
|
78
78
|
from langgraph.graph.state import CompiledStateGraph
|
|
@@ -154,6 +154,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
154
154
|
# get and return the output from the state
|
|
155
155
|
state = ReWOOGraphState(**state)
|
|
156
156
|
output_message = state.result.content
|
|
157
|
+
# Ensure output_message is a string
|
|
158
|
+
if isinstance(output_message, list | dict):
|
|
159
|
+
output_message = str(output_message)
|
|
157
160
|
return ChatResponse.from_string(output_message)
|
|
158
161
|
|
|
159
162
|
except Exception as ex:
|