nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
aiq/__init__.py
CHANGED
|
@@ -13,10 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import sys
|
|
17
16
|
import importlib
|
|
18
17
|
import importlib.abc
|
|
19
18
|
import importlib.util
|
|
19
|
+
import sys
|
|
20
20
|
import warnings
|
|
21
21
|
|
|
22
22
|
|
|
@@ -26,7 +26,7 @@ class CompatFinder(importlib.abc.MetaPathFinder):
|
|
|
26
26
|
self.alias_prefix = alias_prefix
|
|
27
27
|
self.target_prefix = target_prefix
|
|
28
28
|
|
|
29
|
-
def find_spec(self, fullname, path, target=None):
|
|
29
|
+
def find_spec(self, fullname, path, target=None):
|
|
30
30
|
if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
|
|
31
31
|
# Map aiq.something -> nat.something
|
|
32
32
|
target_name = self.target_prefix + fullname[len(self.alias_prefix):]
|
nat/agent/base.py
CHANGED
|
@@ -27,9 +27,10 @@ from langchain_core.language_models import BaseChatModel
|
|
|
27
27
|
from langchain_core.messages import AIMessage
|
|
28
28
|
from langchain_core.messages import BaseMessage
|
|
29
29
|
from langchain_core.messages import ToolMessage
|
|
30
|
+
from langchain_core.runnables import Runnable
|
|
30
31
|
from langchain_core.runnables import RunnableConfig
|
|
31
32
|
from langchain_core.tools import BaseTool
|
|
32
|
-
from langgraph.graph.
|
|
33
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
33
34
|
|
|
34
35
|
logger = logging.getLogger(__name__)
|
|
35
36
|
|
|
@@ -70,12 +71,14 @@ class BaseAgent(ABC):
|
|
|
70
71
|
llm: BaseChatModel,
|
|
71
72
|
tools: list[BaseTool],
|
|
72
73
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
73
|
-
detailed_logs: bool = False
|
|
74
|
+
detailed_logs: bool = False,
|
|
75
|
+
log_response_max_chars: int = 1000) -> None:
|
|
74
76
|
logger.debug("Initializing Agent Graph")
|
|
75
77
|
self.llm = llm
|
|
76
78
|
self.tools = tools
|
|
77
79
|
self.callbacks = callbacks or []
|
|
78
80
|
self.detailed_logs = detailed_logs
|
|
81
|
+
self.log_response_max_chars = log_response_max_chars
|
|
79
82
|
self.graph = None
|
|
80
83
|
|
|
81
84
|
async def _stream_llm(self,
|
|
@@ -105,21 +108,25 @@ class BaseAgent(ABC):
|
|
|
105
108
|
|
|
106
109
|
return AIMessage(content=output_message)
|
|
107
110
|
|
|
108
|
-
async def _call_llm(self,
|
|
111
|
+
async def _call_llm(self, llm: Runnable, inputs: dict[str, Any], config: RunnableConfig | None = None) -> AIMessage:
|
|
109
112
|
"""
|
|
110
113
|
Call the LLM directly. Retry logic is handled automatically by the underlying LLM client.
|
|
111
114
|
|
|
112
115
|
Parameters
|
|
113
116
|
----------
|
|
114
|
-
|
|
115
|
-
The
|
|
117
|
+
llm : Runnable
|
|
118
|
+
The LLM runnable (prompt | llm or similar)
|
|
119
|
+
inputs : dict[str, Any]
|
|
120
|
+
The inputs to pass to the runnable
|
|
121
|
+
config : RunnableConfig | None
|
|
122
|
+
The config to pass to the runnable (should include callbacks)
|
|
116
123
|
|
|
117
124
|
Returns
|
|
118
125
|
-------
|
|
119
126
|
AIMessage
|
|
120
127
|
The LLM response
|
|
121
128
|
"""
|
|
122
|
-
response = await
|
|
129
|
+
response = await llm.ainvoke(inputs, config=config)
|
|
123
130
|
return AIMessage(content=str(response.content))
|
|
124
131
|
|
|
125
132
|
async def _call_tool(self,
|
|
@@ -158,6 +165,11 @@ class BaseAgent(ABC):
|
|
|
158
165
|
tool_call_id=tool.name,
|
|
159
166
|
content=f"The tool {tool.name} provided an empty response.")
|
|
160
167
|
|
|
168
|
+
# ToolMessage only accepts str or list[str | dict] as content.
|
|
169
|
+
# Convert into list if the response is a dict.
|
|
170
|
+
if isinstance(response, dict):
|
|
171
|
+
response = [response]
|
|
172
|
+
|
|
161
173
|
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=response)
|
|
162
174
|
|
|
163
175
|
except Exception as e:
|
|
@@ -179,13 +191,12 @@ class BaseAgent(ABC):
|
|
|
179
191
|
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
|
180
192
|
await asyncio.sleep(sleep_time)
|
|
181
193
|
|
|
182
|
-
# pylint: disable=C0209
|
|
183
194
|
# All retries exhausted, return error message
|
|
184
|
-
error_content = "Tool call failed after all retry attempts. Last error:
|
|
185
|
-
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
|
195
|
+
error_content = f"Tool call failed after all retry attempts. Last error: {str(last_exception)}"
|
|
196
|
+
logger.error("%s %s", AGENT_LOG_PREFIX, error_content, exc_info=True)
|
|
186
197
|
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
|
|
187
198
|
|
|
188
|
-
def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str
|
|
199
|
+
def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str) -> None:
|
|
189
200
|
"""
|
|
190
201
|
Log tool response with consistent formatting and length limits.
|
|
191
202
|
|
|
@@ -197,13 +208,11 @@ class BaseAgent(ABC):
|
|
|
197
208
|
The input that was passed to the tool
|
|
198
209
|
tool_response : str
|
|
199
210
|
The response from the tool
|
|
200
|
-
max_chars : int
|
|
201
|
-
Maximum number of characters to log (default: 1000)
|
|
202
211
|
"""
|
|
203
212
|
if self.detailed_logs:
|
|
204
213
|
# Truncate tool response if too long
|
|
205
|
-
display_response = tool_response[:
|
|
206
|
-
tool_response) >
|
|
214
|
+
display_response = tool_response[:self.log_response_max_chars] + "...(rest of response truncated)" if len(
|
|
215
|
+
tool_response) > self.log_response_max_chars else tool_response
|
|
207
216
|
|
|
208
217
|
# Format the tool input for display
|
|
209
218
|
tool_input_str = str(tool_input)
|
|
@@ -252,5 +261,5 @@ class BaseAgent(ABC):
|
|
|
252
261
|
return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
|
|
253
262
|
|
|
254
263
|
@abstractmethod
|
|
255
|
-
async def _build_graph(self, state_schema: type) ->
|
|
264
|
+
async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
|
|
256
265
|
pass
|
nat/agent/dual_node.py
CHANGED
|
@@ -20,7 +20,7 @@ from langchain_core.callbacks import AsyncCallbackHandler
|
|
|
20
20
|
from langchain_core.language_models import BaseChatModel
|
|
21
21
|
from langchain_core.tools import BaseTool
|
|
22
22
|
from langgraph.graph import StateGraph
|
|
23
|
-
from langgraph.graph.
|
|
23
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
24
24
|
from pydantic import BaseModel
|
|
25
25
|
|
|
26
26
|
from .base import AgentDecision
|
|
@@ -35,8 +35,13 @@ class DualNodeAgent(BaseAgent):
|
|
|
35
35
|
llm: BaseChatModel,
|
|
36
36
|
tools: list[BaseTool],
|
|
37
37
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
38
|
-
detailed_logs: bool = False
|
|
39
|
-
|
|
38
|
+
detailed_logs: bool = False,
|
|
39
|
+
log_response_max_chars: int = 1000):
|
|
40
|
+
super().__init__(llm=llm,
|
|
41
|
+
tools=tools,
|
|
42
|
+
callbacks=callbacks,
|
|
43
|
+
detailed_logs=detailed_logs,
|
|
44
|
+
log_response_max_chars=log_response_max_chars)
|
|
40
45
|
|
|
41
46
|
@abstractmethod
|
|
42
47
|
async def agent_node(self, state: BaseModel) -> BaseModel:
|
|
@@ -50,7 +55,7 @@ class DualNodeAgent(BaseAgent):
|
|
|
50
55
|
async def conditional_edge(self, state: BaseModel) -> str:
|
|
51
56
|
pass
|
|
52
57
|
|
|
53
|
-
async def _build_graph(self, state_schema) ->
|
|
58
|
+
async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
|
|
54
59
|
log.debug("Building and compiling the Agent Graph")
|
|
55
60
|
|
|
56
61
|
graph = StateGraph(state_schema)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# flake8: noqa W291
|
|
16
|
+
|
|
17
|
+
mutator_prompt = """
|
|
18
|
+
|
|
19
|
+
## CORE DIRECTIVES
|
|
20
|
+
- **Preserve the original objective and task.** Do not change what the prompt is meant to accomplish.
|
|
21
|
+
- **Keep the intent intact.** The improved prompt must solve the same problem as the original.
|
|
22
|
+
- **Do not invent new goals.** Only improve clarity, structure, constraints, and usability.
|
|
23
|
+
- **Do not drop critical instructions.** Everything essential from the original prompt must remain.
|
|
24
|
+
- **Return only the mutated prompt text.** No rationale, no diffs, no explanations.
|
|
25
|
+
- **Be Creative within bounds.** You may rephrase, reorganize, and enhance, but not alter meaning.
|
|
26
|
+
- **DO NOT use curly braces in your prompt** for anything other than existing variables in the prompt as the string
|
|
27
|
+
will be treated as an f-string.
|
|
28
|
+
- **Examples are a good idea** if the original prompt lacks them. They help clarify expected output.
|
|
29
|
+
|
|
30
|
+
---
|
|
31
|
+
|
|
32
|
+
## IMPROVEMENT HINTS
|
|
33
|
+
When modifying, apply these principles:
|
|
34
|
+
1. **Clarity & Precision** – remove vague language, strengthen directives.
|
|
35
|
+
2. **Structure & Flow** – order sections as: *Objective → Constraints → Tools → Steps → Output Schema → Examples*.
|
|
36
|
+
3. **Schema Adherence** – enforce a single canonical output schema (JSON/XML) with `schema_version`.
|
|
37
|
+
4. **Tool Governance** – clarify when/how tools are used, their inputs/outputs, and fallback behavior.
|
|
38
|
+
5. **Error Handling** – specify behavior if tools fail or inputs are insufficient.
|
|
39
|
+
6. **Budget Awareness** – minimize verbosity, respect token/latency limits.
|
|
40
|
+
7. **Safety** – include refusals for unsafe requests, enforce compliance with rules.
|
|
41
|
+
8. **Consistency** – avoid format drift; always maintain the same schema.
|
|
42
|
+
9. **Integrity** – confirm the task, objective, and intent are preserved.
|
|
43
|
+
|
|
44
|
+
---
|
|
45
|
+
|
|
46
|
+
## MUTATION OPERATORS
|
|
47
|
+
You may:
|
|
48
|
+
- **Tighten** (remove fluff, redundancies)
|
|
49
|
+
- **Reorder** (improve logical flow)
|
|
50
|
+
- **Constrain** (add explicit rules/limits)
|
|
51
|
+
- **Harden** (improve error handling/fallbacks)
|
|
52
|
+
- **Defuse** (replace ambiguous verbs with measurable actions)
|
|
53
|
+
- **Format-lock** (wrap outputs in JSON/XML fenced blocks)
|
|
54
|
+
- **Example-ify** (add examples if missing or weak)
|
|
55
|
+
|
|
56
|
+
---
|
|
57
|
+
|
|
58
|
+
## INPUT
|
|
59
|
+
Here is the prompt to mutate:
|
|
60
|
+
{original_prompt}
|
|
61
|
+
|
|
62
|
+
## OBJECTIVE
|
|
63
|
+
The prompt must acheive the following objective:
|
|
64
|
+
{objective}
|
|
65
|
+
|
|
66
|
+
The modified prompt is: \n
|
|
67
|
+
|
|
68
|
+
"""
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
|
|
18
|
+
from nat.builder.builder import Builder
|
|
19
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
20
|
+
from nat.builder.function_info import FunctionInfo
|
|
21
|
+
from nat.cli.register_workflow import register_function
|
|
22
|
+
from nat.data_models.component_ref import LLMRef
|
|
23
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
24
|
+
from nat.profiler.parameter_optimization.prompt_optimizer import PromptOptimizerInputSchema
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PromptOptimizerConfig(FunctionBaseConfig, name="prompt_init"):
|
|
28
|
+
|
|
29
|
+
optimizer_llm: LLMRef = Field(description="LLM to use for prompt optimization")
|
|
30
|
+
optimizer_prompt: str = Field(
|
|
31
|
+
description="Prompt template for the optimizer",
|
|
32
|
+
default=(
|
|
33
|
+
"You are an expert at optimizing prompts for LLMs. "
|
|
34
|
+
"Your task is to take a given prompt and suggest an optimized version of it. "
|
|
35
|
+
"Note that the prompt might be a template with variables and curly braces. Remember to always keep the "
|
|
36
|
+
"variables and curly braces in the prompt the same. Only modify the instructions in the prompt that are"
|
|
37
|
+
"not variables. The system is meant to achieve the following objective\n"
|
|
38
|
+
"{system_objective}\n Of which, the prompt is one part. The details of the prompt and context as below.\n"))
|
|
39
|
+
system_objective: str = Field(description="Objective of the workflow")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@register_function(config_type=PromptOptimizerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
43
|
+
async def prompt_optimizer_function(config: PromptOptimizerConfig, builder: Builder):
|
|
44
|
+
"""
|
|
45
|
+
Function to optimize prompts for LLMs.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
try:
|
|
49
|
+
from langchain_core.prompts import PromptTemplate
|
|
50
|
+
|
|
51
|
+
from .prompt import mutator_prompt
|
|
52
|
+
except ImportError as exc:
|
|
53
|
+
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
54
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
55
|
+
|
|
56
|
+
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
57
|
+
|
|
58
|
+
template = PromptTemplate(template=config.optimizer_prompt,
|
|
59
|
+
input_variables=["system_objective"],
|
|
60
|
+
validate_template=True)
|
|
61
|
+
|
|
62
|
+
base_prompt: str = (await template.ainvoke(input={"system_objective": config.system_objective})).to_string()
|
|
63
|
+
prompt_extension_template = PromptTemplate(template=mutator_prompt,
|
|
64
|
+
input_variables=["original_prompt", "objective"],
|
|
65
|
+
validate_template=True)
|
|
66
|
+
|
|
67
|
+
async def _inner(input_message: PromptOptimizerInputSchema) -> str:
|
|
68
|
+
"""
|
|
69
|
+
Optimize the prompt using the provided LLM.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
original_prompt = input_message.original_prompt
|
|
73
|
+
prompt_objective = input_message.objective
|
|
74
|
+
|
|
75
|
+
prompt_extension = (await prompt_extension_template.ainvoke(input={
|
|
76
|
+
"original_prompt": original_prompt,
|
|
77
|
+
"objective": prompt_objective,
|
|
78
|
+
})).to_string()
|
|
79
|
+
|
|
80
|
+
prompt = f"{base_prompt}\n\n{prompt_extension}"
|
|
81
|
+
|
|
82
|
+
optimized_prompt = await llm.ainvoke(prompt)
|
|
83
|
+
return optimized_prompt.content
|
|
84
|
+
|
|
85
|
+
yield FunctionInfo.from_fn(
|
|
86
|
+
fn=_inner,
|
|
87
|
+
description="Optimize prompts for LLMs using a feedback LLM.",
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class PromptRecombinerConfig(FunctionBaseConfig, name="prompt_recombiner"):
|
|
92
|
+
|
|
93
|
+
optimizer_llm: LLMRef = Field(description="LLM to use for prompt recombination")
|
|
94
|
+
optimizer_prompt: str = Field(
|
|
95
|
+
description="Prompt template for the recombiner",
|
|
96
|
+
default=("You are an expert at combining prompt instructions for LLMs. "
|
|
97
|
+
"Your task is to merge two prompts for the same objective into a single, stronger prompt. "
|
|
98
|
+
"Do not introduce new variables or modify existing placeholders."),
|
|
99
|
+
)
|
|
100
|
+
system_objective: str = Field(description="Objective of the workflow")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@register_function(config_type=PromptRecombinerConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
104
|
+
async def prompt_recombiner_function(config: PromptRecombinerConfig, builder: Builder):
|
|
105
|
+
"""
|
|
106
|
+
Function to recombine two parent prompts into a child prompt using the optimizer LLM.
|
|
107
|
+
Uses the same base template and objective instructions.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
try:
|
|
111
|
+
from langchain_core.prompts import PromptTemplate
|
|
112
|
+
except ImportError as exc:
|
|
113
|
+
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
114
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
115
|
+
|
|
116
|
+
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
117
|
+
|
|
118
|
+
template = PromptTemplate(template=config.optimizer_prompt,
|
|
119
|
+
input_variables=["system_objective"],
|
|
120
|
+
validate_template=True)
|
|
121
|
+
|
|
122
|
+
base_prompt: str = (await template.ainvoke(input={"system_objective": config.system_objective})).to_string()
|
|
123
|
+
|
|
124
|
+
class RecombineSchema(PromptOptimizerInputSchema):
|
|
125
|
+
parent_b: str | None = None
|
|
126
|
+
|
|
127
|
+
async def _inner(input_message: RecombineSchema) -> str:
|
|
128
|
+
parent_a = input_message.original_prompt
|
|
129
|
+
parent_b = input_message.parent_b or ""
|
|
130
|
+
prompt_objective = input_message.objective
|
|
131
|
+
|
|
132
|
+
prompt = (
|
|
133
|
+
f"{base_prompt}\n\n"
|
|
134
|
+
"We are performing genetic recombination between two prompts that satisfy the same objective.\n"
|
|
135
|
+
f"Objective: {prompt_objective}\n\n"
|
|
136
|
+
f"Parent A:\n{parent_a}\n\n"
|
|
137
|
+
f"Parent B:\n{parent_b}\n\n"
|
|
138
|
+
"Combine the strongest instructions and phrasing from both parents to produce a single, coherent child "
|
|
139
|
+
"prompt.\n"
|
|
140
|
+
"Maintain variables and placeholders unchanged.\n"
|
|
141
|
+
"Return only the child prompt text, with no additional commentary.")
|
|
142
|
+
|
|
143
|
+
child_prompt = await llm.ainvoke(prompt)
|
|
144
|
+
return child_prompt.content
|
|
145
|
+
|
|
146
|
+
yield FunctionInfo.from_fn(
|
|
147
|
+
fn=_inner,
|
|
148
|
+
description="Recombine two prompts into a stronger child prompt.",
|
|
149
|
+
)
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
# pylint: disable=R0917
|
|
18
17
|
import logging
|
|
18
|
+
import re
|
|
19
19
|
import typing
|
|
20
20
|
from json import JSONDecodeError
|
|
21
21
|
|
|
@@ -23,12 +23,14 @@ from langchain_core.agents import AgentAction
|
|
|
23
23
|
from langchain_core.agents import AgentFinish
|
|
24
24
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
25
25
|
from langchain_core.language_models import BaseChatModel
|
|
26
|
+
from langchain_core.language_models import LanguageModelInput
|
|
26
27
|
from langchain_core.messages.ai import AIMessage
|
|
27
28
|
from langchain_core.messages.base import BaseMessage
|
|
28
29
|
from langchain_core.messages.human import HumanMessage
|
|
29
30
|
from langchain_core.messages.tool import ToolMessage
|
|
30
31
|
from langchain_core.prompts import ChatPromptTemplate
|
|
31
32
|
from langchain_core.prompts import MessagesPlaceholder
|
|
33
|
+
from langchain_core.runnables import Runnable
|
|
32
34
|
from langchain_core.runnables.config import RunnableConfig
|
|
33
35
|
from langchain_core.tools import BaseTool
|
|
34
36
|
from pydantic import BaseModel
|
|
@@ -57,6 +59,7 @@ class ReActGraphState(BaseModel):
|
|
|
57
59
|
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReAct Agent
|
|
58
60
|
agent_scratchpad: list[AgentAction] = Field(default_factory=list) # agent thoughts / intermediate steps
|
|
59
61
|
tool_responses: list[BaseMessage] = Field(default_factory=list) # the responses from any tool calls
|
|
62
|
+
final_answer: str | None = Field(default=None) # the final answer from the ReAct Agent
|
|
60
63
|
|
|
61
64
|
|
|
62
65
|
class ReActAgentGraph(DualNodeAgent):
|
|
@@ -71,15 +74,22 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
71
74
|
use_tool_schema: bool = True,
|
|
72
75
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
73
76
|
detailed_logs: bool = False,
|
|
77
|
+
log_response_max_chars: int = 1000,
|
|
74
78
|
retry_agent_response_parsing_errors: bool = True,
|
|
75
79
|
parse_agent_response_max_retries: int = 1,
|
|
76
80
|
tool_call_max_retries: int = 1,
|
|
77
|
-
pass_tool_call_errors_to_agent: bool = True
|
|
78
|
-
|
|
81
|
+
pass_tool_call_errors_to_agent: bool = True,
|
|
82
|
+
normalize_tool_input_quotes: bool = True):
|
|
83
|
+
super().__init__(llm=llm,
|
|
84
|
+
tools=tools,
|
|
85
|
+
callbacks=callbacks,
|
|
86
|
+
detailed_logs=detailed_logs,
|
|
87
|
+
log_response_max_chars=log_response_max_chars)
|
|
79
88
|
self.parse_agent_response_max_retries = (parse_agent_response_max_retries
|
|
80
89
|
if retry_agent_response_parsing_errors else 1)
|
|
81
90
|
self.tool_call_max_retries = tool_call_max_retries
|
|
82
91
|
self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
|
|
92
|
+
self.normalize_tool_input_quotes = normalize_tool_input_quotes
|
|
83
93
|
logger.debug(
|
|
84
94
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
85
95
|
AGENT_LOG_PREFIX)
|
|
@@ -97,21 +107,33 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
97
107
|
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
98
108
|
prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
99
109
|
# construct the ReAct Agent
|
|
100
|
-
|
|
101
|
-
self.agent = prompt | bound_llm
|
|
110
|
+
self.agent = prompt | self._maybe_bind_llm_and_yield()
|
|
102
111
|
self.tools_dict = {tool.name: tool for tool in tools}
|
|
103
112
|
logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
104
113
|
|
|
114
|
+
def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
115
|
+
"""
|
|
116
|
+
Bind additional parameters to the LLM if needed
|
|
117
|
+
- if the LLM is a smart model, no need to bind any additional parameters
|
|
118
|
+
- if the LLM is a non-smart model, bind a stop sequence to the LLM
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
|
|
122
|
+
"""
|
|
123
|
+
# models that don't need (or don't support)a stop sequence
|
|
124
|
+
smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
|
|
125
|
+
if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
|
|
126
|
+
# no need to bind any additional parameters to the LLM
|
|
127
|
+
return self.llm
|
|
128
|
+
# add a stop sequence to the LLM
|
|
129
|
+
return self.llm.bind(stop=["Observation:"])
|
|
130
|
+
|
|
105
131
|
def _get_tool(self, tool_name: str):
|
|
106
132
|
try:
|
|
107
133
|
return self.tools_dict.get(tool_name)
|
|
108
134
|
except Exception as ex:
|
|
109
|
-
logger.
|
|
110
|
-
|
|
111
|
-
tool_name,
|
|
112
|
-
ex,
|
|
113
|
-
exc_info=True)
|
|
114
|
-
raise ex
|
|
135
|
+
logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
|
|
136
|
+
raise
|
|
115
137
|
|
|
116
138
|
async def agent_node(self, state: ReActGraphState):
|
|
117
139
|
try:
|
|
@@ -183,6 +205,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
183
205
|
# this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
|
|
184
206
|
# the final answer goes in the "messages" state channel
|
|
185
207
|
state.messages += [AIMessage(content=final_answer)]
|
|
208
|
+
state.final_answer = final_answer
|
|
186
209
|
else:
|
|
187
210
|
# the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
|
|
188
211
|
agent_output.log = output_message.content
|
|
@@ -215,16 +238,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
215
238
|
working_state.append(output_message)
|
|
216
239
|
working_state.append(HumanMessage(content=str(ex.observation)))
|
|
217
240
|
except Exception as ex:
|
|
218
|
-
logger.
|
|
219
|
-
raise
|
|
241
|
+
logger.error("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
|
|
242
|
+
raise
|
|
220
243
|
|
|
221
244
|
async def conditional_edge(self, state: ReActGraphState):
|
|
222
245
|
try:
|
|
223
246
|
logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
|
|
224
|
-
if
|
|
225
|
-
# the ReAct Agent has finished executing
|
|
226
|
-
|
|
227
|
-
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
|
|
247
|
+
if state.final_answer:
|
|
248
|
+
# the ReAct Agent has finished executing
|
|
249
|
+
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.final_answer)
|
|
228
250
|
return AgentDecision.END
|
|
229
251
|
# else the agent wants to call a tool
|
|
230
252
|
agent_output = state.agent_scratchpad[-1]
|
|
@@ -234,7 +256,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
234
256
|
agent_output.tool_input)
|
|
235
257
|
return AgentDecision.TOOL
|
|
236
258
|
except Exception as ex:
|
|
237
|
-
logger.exception("Failed to determine whether agent is calling a tool: %s", ex
|
|
259
|
+
logger.exception("Failed to determine whether agent is calling a tool: %s", ex)
|
|
238
260
|
logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
|
|
239
261
|
return AgentDecision.END
|
|
240
262
|
|
|
@@ -267,35 +289,45 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
267
289
|
agent_thoughts.tool_input)
|
|
268
290
|
|
|
269
291
|
# Run the tool. Try to use structured input, if possible.
|
|
292
|
+
tool_input_str = agent_thoughts.tool_input.strip()
|
|
293
|
+
|
|
270
294
|
try:
|
|
271
|
-
|
|
272
|
-
tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
295
|
+
tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
273
296
|
logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
|
|
274
297
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
298
|
+
except JSONDecodeError as original_ex:
|
|
299
|
+
if self.normalize_tool_input_quotes:
|
|
300
|
+
# If initial JSON parsing fails, try with quote normalization as a fallback
|
|
301
|
+
normalized_str = tool_input_str.replace("'", '"')
|
|
302
|
+
try:
|
|
303
|
+
tool_input = json.loads(normalized_str)
|
|
304
|
+
logger.debug("%s Successfully parsed structured tool input after quote normalization",
|
|
305
|
+
AGENT_LOG_PREFIX)
|
|
306
|
+
except JSONDecodeError:
|
|
307
|
+
# the quote normalization failed, use raw string input
|
|
308
|
+
logger.debug(
|
|
309
|
+
"%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
|
|
310
|
+
"\nParsing error: %s",
|
|
311
|
+
AGENT_LOG_PREFIX,
|
|
312
|
+
original_ex)
|
|
313
|
+
tool_input = tool_input_str
|
|
314
|
+
else:
|
|
315
|
+
# use raw string input
|
|
316
|
+
logger.debug(
|
|
317
|
+
"%s Unable to parse structured tool input from Action Input. Using Action Input as is."
|
|
318
|
+
"\nParsing error: %s",
|
|
319
|
+
AGENT_LOG_PREFIX,
|
|
320
|
+
original_ex)
|
|
321
|
+
tool_input = tool_input_str
|
|
322
|
+
|
|
323
|
+
# Call tool once with the determined input (either parsed dict or raw string)
|
|
324
|
+
tool_response = await self._call_tool(requested_tool,
|
|
325
|
+
tool_input,
|
|
326
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
327
|
+
max_retries=self.tool_call_max_retries)
|
|
296
328
|
|
|
297
329
|
if self.detailed_logs:
|
|
298
|
-
self._log_tool_response(requested_tool.name,
|
|
330
|
+
self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
|
|
299
331
|
|
|
300
332
|
if not self.pass_tool_call_errors_to_agent:
|
|
301
333
|
if tool_response.status == "error":
|
|
@@ -311,8 +343,8 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
311
343
|
logger.debug("%s ReAct Graph built and compiled successfully", AGENT_LOG_PREFIX)
|
|
312
344
|
return self.graph
|
|
313
345
|
except Exception as ex:
|
|
314
|
-
logger.
|
|
315
|
-
raise
|
|
346
|
+
logger.error("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex)
|
|
347
|
+
raise
|
|
316
348
|
|
|
317
349
|
@staticmethod
|
|
318
350
|
def validate_system_prompt(system_prompt: str) -> bool:
|
|
@@ -328,8 +360,8 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
328
360
|
errors.append(error_message)
|
|
329
361
|
if errors:
|
|
330
362
|
error_text = "\n".join(errors)
|
|
331
|
-
logger.
|
|
332
|
-
|
|
363
|
+
logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
|
|
364
|
+
return False
|
|
333
365
|
return True
|
|
334
366
|
|
|
335
367
|
|
|
@@ -355,7 +387,7 @@ def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptT
|
|
|
355
387
|
|
|
356
388
|
valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
|
|
357
389
|
if not valid_prompt:
|
|
358
|
-
logger.
|
|
390
|
+
logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
|
|
359
391
|
raise ValueError("Invalid system_prompt")
|
|
360
392
|
prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
|
|
361
393
|
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
|