nvidia-nat 1.3.dev0__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/base.py +40 -14
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +96 -57
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +332 -150
- nat/agent/rewoo_agent/prompt.py +22 -22
- nat/agent/rewoo_agent/register.py +49 -28
- nat/agent/tool_calling_agent/agent.py +156 -29
- nat/agent/tool_calling_agent/register.py +57 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/models.py +2 -0
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +49 -21
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +233 -189
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nvidia_nat-1.3.0rc1.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.3.dev0.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
nat/agent/react_agent/agent.py
CHANGED
|
@@ -14,20 +14,23 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
# pylint: disable=R0917
|
|
18
17
|
import logging
|
|
18
|
+
import re
|
|
19
|
+
import typing
|
|
19
20
|
from json import JSONDecodeError
|
|
20
21
|
|
|
21
22
|
from langchain_core.agents import AgentAction
|
|
22
23
|
from langchain_core.agents import AgentFinish
|
|
23
24
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
24
25
|
from langchain_core.language_models import BaseChatModel
|
|
26
|
+
from langchain_core.language_models import LanguageModelInput
|
|
25
27
|
from langchain_core.messages.ai import AIMessage
|
|
26
28
|
from langchain_core.messages.base import BaseMessage
|
|
27
29
|
from langchain_core.messages.human import HumanMessage
|
|
28
30
|
from langchain_core.messages.tool import ToolMessage
|
|
29
31
|
from langchain_core.prompts import ChatPromptTemplate
|
|
30
32
|
from langchain_core.prompts import MessagesPlaceholder
|
|
33
|
+
from langchain_core.runnables import Runnable
|
|
31
34
|
from langchain_core.runnables.config import RunnableConfig
|
|
32
35
|
from langchain_core.tools import BaseTool
|
|
33
36
|
from pydantic import BaseModel
|
|
@@ -44,7 +47,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
|
|
|
44
47
|
from nat.agent.react_agent.output_parser import ReActOutputParserException
|
|
45
48
|
from nat.agent.react_agent.prompt import SYSTEM_PROMPT
|
|
46
49
|
from nat.agent.react_agent.prompt import USER_PROMPT
|
|
47
|
-
|
|
50
|
+
|
|
51
|
+
if typing.TYPE_CHECKING:
|
|
52
|
+
from nat.agent.react_agent.register import ReActAgentWorkflowConfig
|
|
48
53
|
|
|
49
54
|
logger = logging.getLogger(__name__)
|
|
50
55
|
|
|
@@ -54,6 +59,7 @@ class ReActGraphState(BaseModel):
|
|
|
54
59
|
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReAct Agent
|
|
55
60
|
agent_scratchpad: list[AgentAction] = Field(default_factory=list) # agent thoughts / intermediate steps
|
|
56
61
|
tool_responses: list[BaseMessage] = Field(default_factory=list) # the responses from any tool calls
|
|
62
|
+
final_answer: str | None = Field(default=None) # the final answer from the ReAct Agent
|
|
57
63
|
|
|
58
64
|
|
|
59
65
|
class ReActAgentGraph(DualNodeAgent):
|
|
@@ -68,15 +74,22 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
68
74
|
use_tool_schema: bool = True,
|
|
69
75
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
70
76
|
detailed_logs: bool = False,
|
|
77
|
+
log_response_max_chars: int = 1000,
|
|
71
78
|
retry_agent_response_parsing_errors: bool = True,
|
|
72
79
|
parse_agent_response_max_retries: int = 1,
|
|
73
80
|
tool_call_max_retries: int = 1,
|
|
74
|
-
pass_tool_call_errors_to_agent: bool = True
|
|
75
|
-
|
|
81
|
+
pass_tool_call_errors_to_agent: bool = True,
|
|
82
|
+
normalize_tool_input_quotes: bool = True):
|
|
83
|
+
super().__init__(llm=llm,
|
|
84
|
+
tools=tools,
|
|
85
|
+
callbacks=callbacks,
|
|
86
|
+
detailed_logs=detailed_logs,
|
|
87
|
+
log_response_max_chars=log_response_max_chars)
|
|
76
88
|
self.parse_agent_response_max_retries = (parse_agent_response_max_retries
|
|
77
89
|
if retry_agent_response_parsing_errors else 1)
|
|
78
90
|
self.tool_call_max_retries = tool_call_max_retries
|
|
79
91
|
self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
|
|
92
|
+
self.normalize_tool_input_quotes = normalize_tool_input_quotes
|
|
80
93
|
logger.debug(
|
|
81
94
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
82
95
|
AGENT_LOG_PREFIX)
|
|
@@ -94,21 +107,33 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
94
107
|
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
95
108
|
prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
96
109
|
# construct the ReAct Agent
|
|
97
|
-
|
|
98
|
-
self.agent = prompt | bound_llm
|
|
110
|
+
self.agent = prompt | self._maybe_bind_llm_and_yield()
|
|
99
111
|
self.tools_dict = {tool.name: tool for tool in tools}
|
|
100
112
|
logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
101
113
|
|
|
114
|
+
def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
115
|
+
"""
|
|
116
|
+
Bind additional parameters to the LLM if needed
|
|
117
|
+
- if the LLM is a smart model, no need to bind any additional parameters
|
|
118
|
+
- if the LLM is a non-smart model, bind a stop sequence to the LLM
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
|
|
122
|
+
"""
|
|
123
|
+
# models that don't need (or don't support)a stop sequence
|
|
124
|
+
smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
|
|
125
|
+
if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
|
|
126
|
+
# no need to bind any additional parameters to the LLM
|
|
127
|
+
return self.llm
|
|
128
|
+
# add a stop sequence to the LLM
|
|
129
|
+
return self.llm.bind(stop=["Observation:"])
|
|
130
|
+
|
|
102
131
|
def _get_tool(self, tool_name: str):
|
|
103
132
|
try:
|
|
104
133
|
return self.tools_dict.get(tool_name)
|
|
105
134
|
except Exception as ex:
|
|
106
|
-
logger.
|
|
107
|
-
|
|
108
|
-
tool_name,
|
|
109
|
-
ex,
|
|
110
|
-
exc_info=True)
|
|
111
|
-
raise ex
|
|
135
|
+
logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
|
|
136
|
+
raise
|
|
112
137
|
|
|
113
138
|
async def agent_node(self, state: ReActGraphState):
|
|
114
139
|
try:
|
|
@@ -124,17 +149,19 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
124
149
|
if len(state.messages) == 0:
|
|
125
150
|
raise RuntimeError('No input received in state: "messages"')
|
|
126
151
|
# to check is any human input passed or not, if no input passed Agent will return the state
|
|
127
|
-
content = str(state.messages[
|
|
152
|
+
content = str(state.messages[-1].content)
|
|
128
153
|
if content.strip() == "":
|
|
129
154
|
logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
|
|
130
155
|
state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
|
|
131
156
|
return state
|
|
132
157
|
question = content
|
|
133
158
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
134
|
-
|
|
159
|
+
chat_history = self._get_chat_history(state.messages)
|
|
135
160
|
output_message = await self._stream_llm(
|
|
136
161
|
self.agent,
|
|
137
|
-
{
|
|
162
|
+
{
|
|
163
|
+
"question": question, "chat_history": chat_history
|
|
164
|
+
},
|
|
138
165
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
139
166
|
)
|
|
140
167
|
|
|
@@ -152,13 +179,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
152
179
|
tool_response = HumanMessage(content=tool_response_content)
|
|
153
180
|
agent_scratchpad.append(tool_response)
|
|
154
181
|
agent_scratchpad += working_state
|
|
155
|
-
|
|
182
|
+
chat_history = self._get_chat_history(state.messages)
|
|
183
|
+
question = str(state.messages[-1].content)
|
|
156
184
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
157
185
|
|
|
158
|
-
output_message = await self._stream_llm(
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
186
|
+
output_message = await self._stream_llm(
|
|
187
|
+
self.agent, {
|
|
188
|
+
"question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
|
|
189
|
+
},
|
|
190
|
+
RunnableConfig(callbacks=self.callbacks))
|
|
162
191
|
|
|
163
192
|
if self.detailed_logs:
|
|
164
193
|
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
@@ -176,6 +205,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
176
205
|
# this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
|
|
177
206
|
# the final answer goes in the "messages" state channel
|
|
178
207
|
state.messages += [AIMessage(content=final_answer)]
|
|
208
|
+
state.final_answer = final_answer
|
|
179
209
|
else:
|
|
180
210
|
# the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
|
|
181
211
|
agent_output.log = output_message.content
|
|
@@ -208,16 +238,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
208
238
|
working_state.append(output_message)
|
|
209
239
|
working_state.append(HumanMessage(content=str(ex.observation)))
|
|
210
240
|
except Exception as ex:
|
|
211
|
-
logger.
|
|
212
|
-
raise
|
|
241
|
+
logger.error("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
|
|
242
|
+
raise
|
|
213
243
|
|
|
214
244
|
async def conditional_edge(self, state: ReActGraphState):
|
|
215
245
|
try:
|
|
216
246
|
logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
|
|
217
|
-
if
|
|
218
|
-
# the ReAct Agent has finished executing
|
|
219
|
-
|
|
220
|
-
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
|
|
247
|
+
if state.final_answer:
|
|
248
|
+
# the ReAct Agent has finished executing
|
|
249
|
+
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.final_answer)
|
|
221
250
|
return AgentDecision.END
|
|
222
251
|
# else the agent wants to call a tool
|
|
223
252
|
agent_output = state.agent_scratchpad[-1]
|
|
@@ -227,7 +256,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
227
256
|
agent_output.tool_input)
|
|
228
257
|
return AgentDecision.TOOL
|
|
229
258
|
except Exception as ex:
|
|
230
|
-
logger.exception("Failed to determine whether agent is calling a tool: %s", ex
|
|
259
|
+
logger.exception("Failed to determine whether agent is calling a tool: %s", ex)
|
|
231
260
|
logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
|
|
232
261
|
return AgentDecision.END
|
|
233
262
|
|
|
@@ -260,35 +289,45 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
260
289
|
agent_thoughts.tool_input)
|
|
261
290
|
|
|
262
291
|
# Run the tool. Try to use structured input, if possible.
|
|
292
|
+
tool_input_str = agent_thoughts.tool_input.strip()
|
|
293
|
+
|
|
263
294
|
try:
|
|
264
|
-
|
|
265
|
-
tool_input_dict = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
295
|
+
tool_input = json.loads(tool_input_str) if tool_input_str != 'None' else tool_input_str
|
|
266
296
|
logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
|
|
267
297
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
298
|
+
except JSONDecodeError as original_ex:
|
|
299
|
+
if self.normalize_tool_input_quotes:
|
|
300
|
+
# If initial JSON parsing fails, try with quote normalization as a fallback
|
|
301
|
+
normalized_str = tool_input_str.replace("'", '"')
|
|
302
|
+
try:
|
|
303
|
+
tool_input = json.loads(normalized_str)
|
|
304
|
+
logger.debug("%s Successfully parsed structured tool input after quote normalization",
|
|
305
|
+
AGENT_LOG_PREFIX)
|
|
306
|
+
except JSONDecodeError:
|
|
307
|
+
# the quote normalization failed, use raw string input
|
|
308
|
+
logger.debug(
|
|
309
|
+
"%s Unable to parse structured tool input after quote normalization. Using Action Input as is."
|
|
310
|
+
"\nParsing error: %s",
|
|
311
|
+
AGENT_LOG_PREFIX,
|
|
312
|
+
original_ex)
|
|
313
|
+
tool_input = tool_input_str
|
|
314
|
+
else:
|
|
315
|
+
# use raw string input
|
|
316
|
+
logger.debug(
|
|
317
|
+
"%s Unable to parse structured tool input from Action Input. Using Action Input as is."
|
|
318
|
+
"\nParsing error: %s",
|
|
319
|
+
AGENT_LOG_PREFIX,
|
|
320
|
+
original_ex)
|
|
321
|
+
tool_input = tool_input_str
|
|
322
|
+
|
|
323
|
+
# Call tool once with the determined input (either parsed dict or raw string)
|
|
324
|
+
tool_response = await self._call_tool(requested_tool,
|
|
325
|
+
tool_input,
|
|
326
|
+
RunnableConfig(callbacks=self.callbacks),
|
|
327
|
+
max_retries=self.tool_call_max_retries)
|
|
289
328
|
|
|
290
329
|
if self.detailed_logs:
|
|
291
|
-
self._log_tool_response(requested_tool.name,
|
|
330
|
+
self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))
|
|
292
331
|
|
|
293
332
|
if not self.pass_tool_call_errors_to_agent:
|
|
294
333
|
if tool_response.status == "error":
|
|
@@ -304,8 +343,8 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
304
343
|
logger.debug("%s ReAct Graph built and compiled successfully", AGENT_LOG_PREFIX)
|
|
305
344
|
return self.graph
|
|
306
345
|
except Exception as ex:
|
|
307
|
-
logger.
|
|
308
|
-
raise
|
|
346
|
+
logger.error("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, ex)
|
|
347
|
+
raise
|
|
309
348
|
|
|
310
349
|
@staticmethod
|
|
311
350
|
def validate_system_prompt(system_prompt: str) -> bool:
|
|
@@ -321,12 +360,12 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
321
360
|
errors.append(error_message)
|
|
322
361
|
if errors:
|
|
323
362
|
error_text = "\n".join(errors)
|
|
324
|
-
logger.
|
|
325
|
-
|
|
363
|
+
logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
|
|
364
|
+
return False
|
|
326
365
|
return True
|
|
327
366
|
|
|
328
367
|
|
|
329
|
-
def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
|
|
368
|
+
def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
|
|
330
369
|
"""
|
|
331
370
|
Create a ReAct Agent prompt from the config.
|
|
332
371
|
|
|
@@ -348,7 +387,7 @@ def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTem
|
|
|
348
387
|
|
|
349
388
|
valid_prompt = ReActAgentGraph.validate_system_prompt(prompt_str)
|
|
350
389
|
if not valid_prompt:
|
|
351
|
-
logger.
|
|
390
|
+
logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
|
|
352
391
|
raise ValueError("Invalid system_prompt")
|
|
353
392
|
prompt = ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
|
|
354
393
|
MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
|
nat/agent/react_agent/prompt.py
CHANGED
|
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
|
|
|
26
26
|
Question: the input question you must answer
|
|
27
27
|
Thought: you should always think about what to do
|
|
28
28
|
Action: the action to take, should be one of [{tool_names}]
|
|
29
|
-
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
29
|
+
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
30
30
|
Observation: wait for the human to respond with the result from the tool, do not assume the response
|
|
31
31
|
|
|
32
32
|
... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
|
|
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
|
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
USER_PROMPT = """
|
|
40
|
+
Previous conversation history:
|
|
41
|
+
{chat_history}
|
|
42
|
+
|
|
40
43
|
Question: {question}
|
|
41
44
|
"""
|
|
@@ -22,26 +22,27 @@ from nat.builder.builder import Builder
|
|
|
22
22
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
23
|
from nat.builder.function_info import FunctionInfo
|
|
24
24
|
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.agent import AgentBaseConfig
|
|
25
26
|
from nat.data_models.api_server import ChatRequest
|
|
26
27
|
from nat.data_models.api_server import ChatResponse
|
|
28
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
27
29
|
from nat.data_models.component_ref import FunctionRef
|
|
28
|
-
from nat.data_models.
|
|
29
|
-
from nat.data_models.
|
|
30
|
+
from nat.data_models.optimizable import OptimizableField
|
|
31
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
32
|
+
from nat.data_models.optimizable import SearchSpace
|
|
30
33
|
from nat.utils.type_converter import GlobalTypeConverter
|
|
31
34
|
|
|
32
35
|
logger = logging.getLogger(__name__)
|
|
33
36
|
|
|
34
37
|
|
|
35
|
-
class ReActAgentWorkflowConfig(
|
|
38
|
+
class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
|
|
36
39
|
"""
|
|
37
40
|
Defines a NAT function that uses a ReAct Agent performs reasoning inbetween tool calls, and utilizes the
|
|
38
41
|
tool names and descriptions to select the optimal tool.
|
|
39
42
|
"""
|
|
40
|
-
|
|
41
|
-
tool_names: list[FunctionRef] = Field(
|
|
42
|
-
|
|
43
|
-
llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
|
|
44
|
-
verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
|
|
43
|
+
description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
|
|
44
|
+
tool_names: list[FunctionRef | FunctionGroupRef] = Field(
|
|
45
|
+
default_factory=list, description="The list of tools to provide to the react agent.")
|
|
45
46
|
retry_agent_response_parsing_errors: bool = Field(
|
|
46
47
|
default=True,
|
|
47
48
|
validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
|
|
@@ -60,7 +61,10 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
60
61
|
description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
|
|
61
62
|
include_tool_input_schema_in_tool_description: bool = Field(
|
|
62
63
|
default=True, description="Specify inclusion of tool input schemas in the prompt.")
|
|
63
|
-
|
|
64
|
+
normalize_tool_input_quotes: bool = Field(
|
|
65
|
+
default=True,
|
|
66
|
+
description="Whether to replace single quotes with double quotes in the tool input. "
|
|
67
|
+
"This is useful for tools that expect structured json input.")
|
|
64
68
|
system_prompt: str | None = Field(
|
|
65
69
|
default=None,
|
|
66
70
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
@@ -68,15 +72,21 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
68
72
|
use_openai_api: bool = Field(default=False,
|
|
69
73
|
description=("Use OpenAI API for the input/output types to the function. "
|
|
70
74
|
"If False, strings will be used."))
|
|
71
|
-
additional_instructions: str | None =
|
|
72
|
-
default=None,
|
|
75
|
+
additional_instructions: str | None = OptimizableField(
|
|
76
|
+
default=None,
|
|
77
|
+
description="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
78
|
+
space=SearchSpace(
|
|
79
|
+
is_prompt=True,
|
|
80
|
+
prompt="No additional instructions.",
|
|
81
|
+
prompt_purpose="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
82
|
+
))
|
|
73
83
|
|
|
74
84
|
|
|
75
85
|
@register_function(config_type=ReActAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
76
86
|
async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
|
|
77
87
|
from langchain.schema import BaseMessage
|
|
78
88
|
from langchain_core.messages import trim_messages
|
|
79
|
-
from langgraph.graph.
|
|
89
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
80
90
|
|
|
81
91
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
82
92
|
from nat.agent.react_agent.agent import ReActAgentGraph
|
|
@@ -89,23 +99,36 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
89
99
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
90
100
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
91
101
|
# the sample tool provided can easily be copied or changed
|
|
92
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
102
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
93
103
|
if not tools:
|
|
94
104
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
95
105
|
# configure callbacks, for sending intermediate steps
|
|
96
106
|
# construct the ReAct Agent Graph from the configured llm, prompt, and tools
|
|
97
|
-
graph:
|
|
107
|
+
graph: CompiledStateGraph = await ReActAgentGraph(
|
|
98
108
|
llm=llm,
|
|
99
109
|
prompt=prompt,
|
|
100
110
|
tools=tools,
|
|
101
111
|
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
102
112
|
detailed_logs=config.verbose,
|
|
113
|
+
log_response_max_chars=config.log_response_max_chars,
|
|
103
114
|
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
104
115
|
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
105
116
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
106
|
-
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent
|
|
117
|
+
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
118
|
+
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
107
119
|
|
|
108
120
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
121
|
+
"""
|
|
122
|
+
Main workflow entry function for the ReAct Agent.
|
|
123
|
+
|
|
124
|
+
This function invokes the ReAct Agent Graph and returns the response.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
input_message (ChatRequest): The input message to process
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
ChatResponse: The response from the agent or error message
|
|
131
|
+
"""
|
|
109
132
|
try:
|
|
110
133
|
# initialize the starting state with the user query
|
|
111
134
|
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
@@ -125,15 +148,12 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
125
148
|
|
|
126
149
|
# get and return the output from the state
|
|
127
150
|
state = ReActGraphState(**state)
|
|
128
|
-
output_message = state.messages[-1]
|
|
151
|
+
output_message = state.messages[-1]
|
|
129
152
|
return ChatResponse.from_string(str(output_message.content))
|
|
130
153
|
|
|
131
154
|
except Exception as ex:
|
|
132
|
-
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex
|
|
133
|
-
|
|
134
|
-
if config.verbose:
|
|
135
|
-
return ChatResponse.from_string(str(ex))
|
|
136
|
-
return ChatResponse.from_string("I seem to be having a problem.")
|
|
155
|
+
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
156
|
+
raise RuntimeError
|
|
137
157
|
|
|
138
158
|
if (config.use_openai_api):
|
|
139
159
|
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -23,25 +23,22 @@ from nat.builder.builder import Builder
|
|
|
23
23
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
24
24
|
from nat.builder.function_info import FunctionInfo
|
|
25
25
|
from nat.cli.register_workflow import register_function
|
|
26
|
+
from nat.data_models.agent import AgentBaseConfig
|
|
26
27
|
from nat.data_models.api_server import ChatRequest
|
|
27
28
|
from nat.data_models.component_ref import FunctionRef
|
|
28
|
-
from nat.data_models.component_ref import LLMRef
|
|
29
|
-
from nat.data_models.function import FunctionBaseConfig
|
|
30
29
|
|
|
31
30
|
logger = logging.getLogger(__name__)
|
|
32
31
|
|
|
33
32
|
|
|
34
|
-
class ReasoningFunctionConfig(FunctionBaseConfig, name="reasoning_agent"):
|
|
33
|
+
class ReasoningFunctionConfig(AgentBaseConfig, name="reasoning_agent"):
|
|
35
34
|
"""
|
|
36
35
|
Defines a NAT function that performs reasoning on the input data.
|
|
37
36
|
Output is passed to the next function in the workflow.
|
|
38
37
|
|
|
39
38
|
Designed to be used with an InterceptingFunction.
|
|
40
39
|
"""
|
|
41
|
-
|
|
42
|
-
llm_name: LLMRef = Field(description="The name of the LLM to use for reasoning.")
|
|
40
|
+
description: str = Field(default="Reasoning Agent", description="The description of this function's use.")
|
|
43
41
|
augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")
|
|
44
|
-
verbose: bool = Field(default=False, description="Whether to log detailed information.")
|
|
45
42
|
reasoning_prompt_template: str = Field(
|
|
46
43
|
default=("You are an expert reasoning model task with creating a detailed execution plan"
|
|
47
44
|
" for a system that has the following description:\n\n"
|
|
@@ -102,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
102
99
|
llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
103
100
|
|
|
104
101
|
# Get the augmented function's description
|
|
105
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
102
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
106
103
|
|
|
107
104
|
# For now, we rely on runtime checking for type conversion
|
|
108
105
|
|
|
@@ -113,11 +110,16 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
113
110
|
f"function without a description.")
|
|
114
111
|
|
|
115
112
|
# Get the function dependencies of the augmented function
|
|
116
|
-
|
|
113
|
+
function_dependencies = builder.get_function_dependencies(config.augmented_fn)
|
|
114
|
+
function_used_tools = set()
|
|
115
|
+
function_used_tools.update(function_dependencies.functions)
|
|
116
|
+
for function_group in function_dependencies.function_groups:
|
|
117
|
+
function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
|
|
118
|
+
|
|
117
119
|
tool_names_with_desc: list[tuple[str, str]] = []
|
|
118
120
|
|
|
119
121
|
for tool in function_used_tools:
|
|
120
|
-
tool_impl = builder.get_function(tool)
|
|
122
|
+
tool_impl = await builder.get_function(tool)
|
|
121
123
|
tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
|
|
122
124
|
|
|
123
125
|
# Draft the reasoning prompt for the augmented function
|
nat/agent/register.py
CHANGED
|
@@ -13,10 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
|
|
19
18
|
# Import any workflows which need to be automatically registered here
|
|
19
|
+
from .prompt_optimizer import register as prompt_optimizer
|
|
20
20
|
from .react_agent import register as react_agent
|
|
21
21
|
from .reasoning_agent import reasoning_agent
|
|
22
22
|
from .rewoo_agent import register as rewoo_agent
|