nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/base.py +16 -0
- nat/agent/react_agent/agent.py +38 -13
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +6 -3
- nat/agent/rewoo_agent/prompt.py +3 -0
- nat/agent/rewoo_agent/register.py +4 -3
- nat/agent/tool_calling_agent/agent.py +92 -22
- nat/agent/tool_calling_agent/register.py +9 -13
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +3 -0
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +7 -9
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/openai_embedder.py +1 -2
- nat/embedder/register.py +1 -1
- nat/eval/config.py +2 -0
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +64 -20
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +14 -7
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +49 -0
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +1 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/meta/pypi.md +9 -9
- nat/object_store/models.py +2 -0
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +9 -2
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +13 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
- nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
- nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
aiq/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import importlib.abc
|
|
18
|
+
import importlib.util
|
|
19
|
+
import sys
|
|
20
|
+
import warnings
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CompatFinder(importlib.abc.MetaPathFinder):
|
|
24
|
+
|
|
25
|
+
def __init__(self, alias_prefix, target_prefix):
|
|
26
|
+
self.alias_prefix = alias_prefix
|
|
27
|
+
self.target_prefix = target_prefix
|
|
28
|
+
|
|
29
|
+
def find_spec(self, fullname, path, target=None):
|
|
30
|
+
if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
|
|
31
|
+
# Map aiq.something -> nat.something
|
|
32
|
+
target_name = self.target_prefix + fullname[len(self.alias_prefix):]
|
|
33
|
+
spec = importlib.util.find_spec(target_name)
|
|
34
|
+
if spec is None:
|
|
35
|
+
return None
|
|
36
|
+
# Wrap the loader so it loads under the alias name
|
|
37
|
+
return importlib.util.spec_from_loader(fullname, CompatLoader(fullname, target_name))
|
|
38
|
+
return None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CompatLoader(importlib.abc.Loader):
|
|
42
|
+
|
|
43
|
+
def __init__(self, alias_name, target_name):
|
|
44
|
+
self.alias_name = alias_name
|
|
45
|
+
self.target_name = target_name
|
|
46
|
+
|
|
47
|
+
def create_module(self, spec):
|
|
48
|
+
# Reuse the actual module so there's only one instance
|
|
49
|
+
target_module = importlib.import_module(self.target_name)
|
|
50
|
+
sys.modules[self.alias_name] = target_module
|
|
51
|
+
return target_module
|
|
52
|
+
|
|
53
|
+
def exec_module(self, module):
|
|
54
|
+
# Nothing to execute since the target is already loaded
|
|
55
|
+
pass
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Register the compatibility finder
|
|
59
|
+
sys.meta_path.insert(0, CompatFinder("aiq", "nat"))
|
|
60
|
+
|
|
61
|
+
warnings.warn(
|
|
62
|
+
"!!! The 'aiq' namespace is deprecated and will be removed in a future release. "
|
|
63
|
+
"Please use the 'nat' namespace instead.",
|
|
64
|
+
DeprecationWarning,
|
|
65
|
+
stacklevel=2,
|
|
66
|
+
)
|
nat/agent/base.py
CHANGED
|
@@ -234,6 +234,22 @@ class BaseAgent(ABC):
|
|
|
234
234
|
logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
|
|
235
235
|
return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
|
|
236
236
|
|
|
237
|
+
def _get_chat_history(self, messages: list[BaseMessage]) -> str:
|
|
238
|
+
"""
|
|
239
|
+
Get the chat history excluding the last message.
|
|
240
|
+
|
|
241
|
+
Parameters
|
|
242
|
+
----------
|
|
243
|
+
messages : list[BaseMessage]
|
|
244
|
+
The messages to get the chat history from
|
|
245
|
+
|
|
246
|
+
Returns
|
|
247
|
+
-------
|
|
248
|
+
str
|
|
249
|
+
The chat history excluding the last message
|
|
250
|
+
"""
|
|
251
|
+
return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
|
|
252
|
+
|
|
237
253
|
@abstractmethod
|
|
238
254
|
async def _build_graph(self, state_schema: type) -> CompiledGraph:
|
|
239
255
|
pass
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -14,20 +14,23 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
# pylint: disable=R0917
|
|
18
17
|
import logging
|
|
18
|
+
import re
|
|
19
|
+
import typing
|
|
19
20
|
from json import JSONDecodeError
|
|
20
21
|
|
|
21
22
|
from langchain_core.agents import AgentAction
|
|
22
23
|
from langchain_core.agents import AgentFinish
|
|
23
24
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
24
25
|
from langchain_core.language_models import BaseChatModel
|
|
26
|
+
from langchain_core.language_models import LanguageModelInput
|
|
25
27
|
from langchain_core.messages.ai import AIMessage
|
|
26
28
|
from langchain_core.messages.base import BaseMessage
|
|
27
29
|
from langchain_core.messages.human import HumanMessage
|
|
28
30
|
from langchain_core.messages.tool import ToolMessage
|
|
29
31
|
from langchain_core.prompts import ChatPromptTemplate
|
|
30
32
|
from langchain_core.prompts import MessagesPlaceholder
|
|
33
|
+
from langchain_core.runnables import Runnable
|
|
31
34
|
from langchain_core.runnables.config import RunnableConfig
|
|
32
35
|
from langchain_core.tools import BaseTool
|
|
33
36
|
from pydantic import BaseModel
|
|
@@ -44,7 +47,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
|
|
|
44
47
|
from nat.agent.react_agent.output_parser import ReActOutputParserException
|
|
45
48
|
from nat.agent.react_agent.prompt import SYSTEM_PROMPT
|
|
46
49
|
from nat.agent.react_agent.prompt import USER_PROMPT
|
|
47
|
-
|
|
50
|
+
|
|
51
|
+
if typing.TYPE_CHECKING:
|
|
52
|
+
from nat.agent.react_agent.register import ReActAgentWorkflowConfig
|
|
48
53
|
|
|
49
54
|
logger = logging.getLogger(__name__)
|
|
50
55
|
|
|
@@ -94,11 +99,27 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
94
99
|
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
95
100
|
prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
96
101
|
# construct the ReAct Agent
|
|
97
|
-
|
|
98
|
-
self.agent = prompt | bound_llm
|
|
102
|
+
self.agent = prompt | self._maybe_bind_llm_and_yield()
|
|
99
103
|
self.tools_dict = {tool.name: tool for tool in tools}
|
|
100
104
|
logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
101
105
|
|
|
106
|
+
def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
107
|
+
"""
|
|
108
|
+
Bind additional parameters to the LLM if needed
|
|
109
|
+
- if the LLM is a smart model, no need to bind any additional parameters
|
|
110
|
+
- if the LLM is a non-smart model, bind a stop sequence to the LLM
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
|
|
114
|
+
"""
|
|
115
|
+
# models that don't need (or don't support) a stop sequence
|
|
116
|
+
smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
|
|
117
|
+
if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
|
|
118
|
+
# no need to bind any additional parameters to the LLM
|
|
119
|
+
return self.llm
|
|
120
|
+
# add a stop sequence to the LLM
|
|
121
|
+
return self.llm.bind(stop=["Observation:"])
|
|
122
|
+
|
|
102
123
|
def _get_tool(self, tool_name: str):
|
|
103
124
|
try:
|
|
104
125
|
return self.tools_dict.get(tool_name)
|
|
@@ -124,17 +145,19 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
124
145
|
if len(state.messages) == 0:
|
|
125
146
|
raise RuntimeError('No input received in state: "messages"')
|
|
126
147
|
# check whether any human input was passed; if not, the agent records an error message and returns the state
|
|
127
|
-
content = str(state.messages[
|
|
148
|
+
content = str(state.messages[-1].content)
|
|
128
149
|
if content.strip() == "":
|
|
129
150
|
logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
|
|
130
151
|
state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
|
|
131
152
|
return state
|
|
132
153
|
question = content
|
|
133
154
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
134
|
-
|
|
155
|
+
chat_history = self._get_chat_history(state.messages)
|
|
135
156
|
output_message = await self._stream_llm(
|
|
136
157
|
self.agent,
|
|
137
|
-
{
|
|
158
|
+
{
|
|
159
|
+
"question": question, "chat_history": chat_history
|
|
160
|
+
},
|
|
138
161
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
139
162
|
)
|
|
140
163
|
|
|
@@ -152,13 +175,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
152
175
|
tool_response = HumanMessage(content=tool_response_content)
|
|
153
176
|
agent_scratchpad.append(tool_response)
|
|
154
177
|
agent_scratchpad += working_state
|
|
155
|
-
|
|
178
|
+
chat_history = self._get_chat_history(state.messages)
|
|
179
|
+
question = str(state.messages[-1].content)
|
|
156
180
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
157
181
|
|
|
158
|
-
output_message = await self._stream_llm(
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
182
|
+
output_message = await self._stream_llm(
|
|
183
|
+
self.agent, {
|
|
184
|
+
"question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
|
|
185
|
+
},
|
|
186
|
+
RunnableConfig(callbacks=self.callbacks))
|
|
162
187
|
|
|
163
188
|
if self.detailed_logs:
|
|
164
189
|
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
@@ -326,7 +351,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
326
351
|
return True
|
|
327
352
|
|
|
328
353
|
|
|
329
|
-
def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
|
|
354
|
+
def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
|
|
330
355
|
"""
|
|
331
356
|
Create a ReAct Agent prompt from the config.
|
|
332
357
|
|
nat/agent/react_agent/prompt.py
CHANGED
|
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
|
|
|
26
26
|
Question: the input question you must answer
|
|
27
27
|
Thought: you should always think about what to do
|
|
28
28
|
Action: the action to take, should be one of [{tool_names}]
|
|
29
|
-
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
29
|
+
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
30
30
|
Observation: wait for the human to respond with the result from the tool, do not assume the response
|
|
31
31
|
|
|
32
32
|
... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
|
|
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
|
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
USER_PROMPT = """
|
|
40
|
+
Previous conversation history:
|
|
41
|
+
{chat_history}
|
|
42
|
+
|
|
40
43
|
Question: {question}
|
|
41
44
|
"""
|
|
@@ -125,7 +125,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
125
125
|
|
|
126
126
|
# get and return the output from the state
|
|
127
127
|
state = ReActGraphState(**state)
|
|
128
|
-
output_message = state.messages[-1]
|
|
128
|
+
output_message = state.messages[-1]
|
|
129
129
|
return ChatResponse.from_string(str(output_message.content))
|
|
130
130
|
|
|
131
131
|
except Exception as ex:
|
nat/agent/register.py
CHANGED
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -14,13 +14,13 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
# pylint: disable=R0917
|
|
18
17
|
import logging
|
|
19
18
|
from json import JSONDecodeError
|
|
20
19
|
|
|
21
20
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
22
21
|
from langchain_core.language_models import BaseChatModel
|
|
23
22
|
from langchain_core.messages.ai import AIMessage
|
|
23
|
+
from langchain_core.messages.base import BaseMessage
|
|
24
24
|
from langchain_core.messages.human import HumanMessage
|
|
25
25
|
from langchain_core.messages.tool import ToolMessage
|
|
26
26
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
|
|
|
43
43
|
|
|
44
44
|
class ReWOOGraphState(BaseModel):
|
|
45
45
|
"""State schema for the ReWOO Agent Graph"""
|
|
46
|
+
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
|
|
46
47
|
task: HumanMessage = Field(default_factory=lambda: HumanMessage(content="")) # the task provided by user
|
|
47
48
|
plan: AIMessage = Field(
|
|
48
49
|
default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
|
|
@@ -183,10 +184,12 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
183
184
|
if not task:
|
|
184
185
|
logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
|
|
185
186
|
return {"result": NO_INPUT_ERROR_MESSAGE}
|
|
186
|
-
|
|
187
|
+
chat_history = self._get_chat_history(state.messages)
|
|
187
188
|
plan = await self._stream_llm(
|
|
188
189
|
planner,
|
|
189
|
-
{
|
|
190
|
+
{
|
|
191
|
+
"task": task, "chat_history": chat_history
|
|
192
|
+
},
|
|
190
193
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
191
194
|
)
|
|
192
195
|
|
nat/agent/rewoo_agent/prompt.py
CHANGED
|
@@ -124,15 +124,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
124
124
|
token_counter=len,
|
|
125
125
|
start_on="human",
|
|
126
126
|
include_system=True)
|
|
127
|
-
|
|
128
|
-
|
|
127
|
+
|
|
128
|
+
task = HumanMessage(content=messages[-1].content)
|
|
129
|
+
state = ReWOOGraphState(messages=messages, task=task)
|
|
129
130
|
|
|
130
131
|
# run the ReWOO Agent Graph
|
|
131
132
|
state = await graph.ainvoke(state)
|
|
132
133
|
|
|
133
134
|
# get and return the output from the state
|
|
134
135
|
state = ReWOOGraphState(**state)
|
|
135
|
-
output_message = state.result.content
|
|
136
|
+
output_message = state.result.content
|
|
136
137
|
return ChatResponse.from_string(output_message)
|
|
137
138
|
|
|
138
139
|
except Exception as ex:
|
|
@@ -13,13 +13,15 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=R0917
|
|
17
16
|
import logging
|
|
17
|
+
import typing
|
|
18
18
|
|
|
19
19
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
20
20
|
from langchain_core.language_models import BaseChatModel
|
|
21
|
+
from langchain_core.messages import SystemMessage
|
|
21
22
|
from langchain_core.messages.base import BaseMessage
|
|
22
|
-
from langchain_core.runnables import
|
|
23
|
+
from langchain_core.runnables import RunnableLambda
|
|
24
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
23
25
|
from langchain_core.tools import BaseTool
|
|
24
26
|
from langgraph.prebuilt import ToolNode
|
|
25
27
|
from pydantic import BaseModel
|
|
@@ -30,6 +32,9 @@ from nat.agent.base import AGENT_LOG_PREFIX
|
|
|
30
32
|
from nat.agent.base import AgentDecision
|
|
31
33
|
from nat.agent.dual_node import DualNodeAgent
|
|
32
34
|
|
|
35
|
+
if typing.TYPE_CHECKING:
|
|
36
|
+
from nat.agent.tool_calling_agent.register import ToolCallAgentWorkflowConfig
|
|
37
|
+
|
|
33
38
|
logger = logging.getLogger(__name__)
|
|
34
39
|
|
|
35
40
|
|
|
@@ -43,22 +48,51 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
43
48
|
A tool Calling Agent utilizes the tool input parameters to select the optimal tool. Supports handling tool errors.
|
|
44
49
|
Argument "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
|
|
45
50
|
|
|
46
|
-
def __init__(
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
llm: BaseChatModel,
|
|
54
|
+
tools: list[BaseTool],
|
|
55
|
+
prompt: str | None = None,
|
|
56
|
+
callbacks: list[AsyncCallbackHandler] = None,
|
|
57
|
+
detailed_logs: bool = False,
|
|
58
|
+
handle_tool_errors: bool = True,
|
|
59
|
+
):
|
|
52
60
|
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
61
|
+
# some LLMs support tool calling
|
|
62
|
+
# these models accept the tool's input schema and decide when to use a tool based on the input's relevance
|
|
63
|
+
try:
|
|
64
|
+
# in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
|
|
65
|
+
self.bound_llm = llm.bind_tools(tools)
|
|
66
|
+
except NotImplementedError as ex:
|
|
67
|
+
logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
|
|
68
|
+
raise ex
|
|
69
|
+
|
|
70
|
+
if prompt is not None:
|
|
71
|
+
system_prompt = SystemMessage(content=prompt)
|
|
72
|
+
prompt_runnable = RunnableLambda(
|
|
73
|
+
lambda state: [system_prompt] + state.get("messages", []),
|
|
74
|
+
name="SystemPrompt",
|
|
75
|
+
)
|
|
76
|
+
else:
|
|
77
|
+
prompt_runnable = RunnableLambda(
|
|
78
|
+
lambda state: state.get("messages", []),
|
|
79
|
+
name="PromptPassthrough",
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
self.agent = prompt_runnable | self.bound_llm
|
|
83
|
+
|
|
53
84
|
self.tool_caller = ToolNode(tools, handle_tool_errors=handle_tool_errors)
|
|
54
85
|
logger.debug("%s Initialized Tool Calling Agent Graph", AGENT_LOG_PREFIX)
|
|
55
86
|
|
|
56
87
|
async def agent_node(self, state: ToolCallAgentGraphState):
|
|
57
88
|
try:
|
|
58
|
-
logger.debug(
|
|
89
|
+
logger.debug("%s Starting the Tool Calling Agent Node", AGENT_LOG_PREFIX)
|
|
59
90
|
if len(state.messages) == 0:
|
|
60
91
|
raise RuntimeError('No input received in state: "messages"')
|
|
61
|
-
response = await self.
|
|
92
|
+
response = await self.agent.ainvoke(
|
|
93
|
+
{"messages": state.messages},
|
|
94
|
+
config=RunnableConfig(callbacks=self.callbacks),
|
|
95
|
+
)
|
|
62
96
|
if self.detailed_logs:
|
|
63
97
|
agent_input = "\n".join(str(message.content) for message in state.messages)
|
|
64
98
|
logger.info(AGENT_CALL_LOG_MESSAGE, agent_input, response)
|
|
@@ -75,16 +109,18 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
75
109
|
last_message = state.messages[-1]
|
|
76
110
|
if last_message.tool_calls:
|
|
77
111
|
# the agent wants to call a tool
|
|
78
|
-
logger.debug(
|
|
112
|
+
logger.debug("%s Agent is calling a tool", AGENT_LOG_PREFIX)
|
|
79
113
|
return AgentDecision.TOOL
|
|
80
114
|
if self.detailed_logs:
|
|
81
115
|
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
|
|
82
116
|
return AgentDecision.END
|
|
83
117
|
except Exception as ex:
|
|
84
|
-
logger.exception(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
118
|
+
logger.exception(
|
|
119
|
+
"%s Failed to determine whether agent is calling a tool: %s",
|
|
120
|
+
AGENT_LOG_PREFIX,
|
|
121
|
+
ex,
|
|
122
|
+
exc_info=True,
|
|
123
|
+
)
|
|
88
124
|
logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
|
|
89
125
|
return AgentDecision.END
|
|
90
126
|
|
|
@@ -92,14 +128,15 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
92
128
|
try:
|
|
93
129
|
logger.debug("%s Starting Tool Node", AGENT_LOG_PREFIX)
|
|
94
130
|
tool_calls = state.messages[-1].tool_calls
|
|
95
|
-
tools = [tool.get(
|
|
131
|
+
tools = [tool.get("name") for tool in tool_calls]
|
|
96
132
|
tool_input = state.messages[-1]
|
|
97
|
-
tool_response = await self.tool_caller.ainvoke(
|
|
98
|
-
|
|
99
|
-
|
|
133
|
+
tool_response = await self.tool_caller.ainvoke(
|
|
134
|
+
input={"messages": [tool_input]},
|
|
135
|
+
config=RunnableConfig(callbacks=self.callbacks, configurable={}),
|
|
136
|
+
)
|
|
100
137
|
# this configurable = {} argument is needed due to a bug in LangGraph PreBuilt ToolNode ^
|
|
101
138
|
|
|
102
|
-
for response in tool_response.get(
|
|
139
|
+
for response in tool_response.get("messages"):
|
|
103
140
|
if self.detailed_logs:
|
|
104
141
|
self._log_tool_response(str(tools), str(tool_input), response.content)
|
|
105
142
|
state.messages += [response]
|
|
@@ -112,8 +149,41 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
112
149
|
async def build_graph(self):
|
|
113
150
|
try:
|
|
114
151
|
await super()._build_graph(state_schema=ToolCallAgentGraphState)
|
|
115
|
-
logger.debug(
|
|
152
|
+
logger.debug(
|
|
153
|
+
"%s Tool Calling Agent Graph built and compiled successfully",
|
|
154
|
+
AGENT_LOG_PREFIX,
|
|
155
|
+
)
|
|
116
156
|
return self.graph
|
|
117
157
|
except Exception as ex:
|
|
118
|
-
logger.exception(
|
|
158
|
+
logger.exception(
|
|
159
|
+
"%s Failed to build Tool Calling Agent Graph: %s",
|
|
160
|
+
AGENT_LOG_PREFIX,
|
|
161
|
+
ex,
|
|
162
|
+
exc_info=ex,
|
|
163
|
+
)
|
|
119
164
|
raise ex
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def create_tool_calling_agent_prompt(config: "ToolCallAgentWorkflowConfig") -> str | None:
|
|
168
|
+
"""
|
|
169
|
+
Create a Tool Calling Agent prompt from the config.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
config (ToolCallAgentWorkflowConfig): The config to use for the prompt.
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
str | None: The Tool Calling Agent prompt, or None if no prompt text is configured.
|
|
176
|
+
"""
|
|
177
|
+
# the Tool Calling Agent prompt can be customized via the config options system_prompt and additional_instructions.
|
|
178
|
+
|
|
179
|
+
if config.system_prompt:
|
|
180
|
+
prompt_str = config.system_prompt
|
|
181
|
+
else:
|
|
182
|
+
prompt_str = ""
|
|
183
|
+
|
|
184
|
+
if config.additional_instructions:
|
|
185
|
+
prompt_str += f" {config.additional_instructions}"
|
|
186
|
+
|
|
187
|
+
if len(prompt_str) > 0:
|
|
188
|
+
return prompt_str
|
|
189
|
+
return None
|
|
@@ -41,6 +41,9 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
|
|
|
41
41
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
42
42
|
description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
|
|
43
43
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
44
|
+
system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
|
|
45
|
+
additional_instructions: str | None = Field(default=None,
|
|
46
|
+
description="Additional instructions appended to the system prompt.")
|
|
44
47
|
|
|
45
48
|
|
|
46
49
|
@register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
@@ -49,10 +52,11 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
49
52
|
from langgraph.graph.graph import CompiledGraph
|
|
50
53
|
|
|
51
54
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
55
|
+
from nat.agent.tool_calling_agent.agent import ToolCallAgentGraph
|
|
56
|
+
from nat.agent.tool_calling_agent.agent import ToolCallAgentGraphState
|
|
57
|
+
from nat.agent.tool_calling_agent.agent import create_tool_calling_agent_prompt
|
|
52
58
|
|
|
53
|
-
|
|
54
|
-
from .agent import ToolCallAgentGraphState
|
|
55
|
-
|
|
59
|
+
prompt = create_tool_calling_agent_prompt(config)
|
|
56
60
|
# we can choose an LLM for the Tool Calling agent in the config file
|
|
57
61
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
58
62
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
@@ -61,18 +65,10 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
61
65
|
if not tools:
|
|
62
66
|
raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
|
|
63
67
|
|
|
64
|
-
# some LLMs support tool calling
|
|
65
|
-
# these models accept the tool's input schema and decide when to use a tool based on the input's relevance
|
|
66
|
-
try:
|
|
67
|
-
# in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
|
|
68
|
-
llm = llm.bind_tools(tools)
|
|
69
|
-
except NotImplementedError as ex:
|
|
70
|
-
logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
|
|
71
|
-
raise ex
|
|
72
|
-
|
|
73
68
|
# construct the Tool Calling Agent Graph from the configured llm, and tools
|
|
74
69
|
graph: CompiledGraph = await ToolCallAgentGraph(llm=llm,
|
|
75
70
|
tools=tools,
|
|
71
|
+
prompt=prompt,
|
|
76
72
|
detailed_logs=config.verbose,
|
|
77
73
|
handle_tool_errors=config.handle_tool_errors).build_graph()
|
|
78
74
|
|
|
@@ -90,7 +86,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
90
86
|
|
|
91
87
|
# get and return the output from the state
|
|
92
88
|
state = ToolCallAgentGraphState(**state)
|
|
93
|
-
output_message = state.messages[-1]
|
|
89
|
+
output_message = state.messages[-1]
|
|
94
90
|
return output_message.content
|
|
95
91
|
except Exception as ex:
|
|
96
92
|
logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
|
|
@@ -31,7 +31,7 @@ class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
|
|
|
31
31
|
# fmt: off
|
|
32
32
|
def __init__(self,
|
|
33
33
|
config: APIKeyAuthProviderConfig,
|
|
34
|
-
config_name: str | None = None) -> None:
|
|
34
|
+
config_name: str | None = None) -> None:
|
|
35
35
|
assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
|
|
36
36
|
super().__init__(config)
|
|
37
37
|
# fmt: on
|
nat/authentication/register.py
CHANGED
nat/builder/builder.py
CHANGED
|
@@ -58,7 +58,7 @@ class UserManagerHolder():
|
|
|
58
58
|
return self._context.user_manager.get_id()
|
|
59
59
|
|
|
60
60
|
|
|
61
|
-
class Builder(ABC):
|
|
61
|
+
class Builder(ABC):
|
|
62
62
|
|
|
63
63
|
@abstractmethod
|
|
64
64
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
nat/builder/context.py
CHANGED
|
@@ -38,7 +38,7 @@ from nat.utils.reactive.subject import Subject
|
|
|
38
38
|
|
|
39
39
|
class Singleton(type):
|
|
40
40
|
|
|
41
|
-
def __init__(cls, name, bases, dict):
|
|
41
|
+
def __init__(cls, name, bases, dict):
|
|
42
42
|
super(Singleton, cls).__init__(name, bases, dict)
|
|
43
43
|
cls.instance = None
|
|
44
44
|
|
|
@@ -65,6 +65,7 @@ class ContextState(metaclass=Singleton):
|
|
|
65
65
|
|
|
66
66
|
def __init__(self):
|
|
67
67
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
68
|
+
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
68
69
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
69
70
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
70
71
|
self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
|
|
@@ -165,6 +166,13 @@ class Context:
|
|
|
165
166
|
"""
|
|
166
167
|
return self._context_state.conversation_id.get()
|
|
167
168
|
|
|
169
|
+
@property
|
|
170
|
+
def user_message_id(self) -> str | None:
|
|
171
|
+
"""
|
|
172
|
+
This property retrieves the user message ID which is the unique identifier for the current user message.
|
|
173
|
+
"""
|
|
174
|
+
return self._context_state.user_message_id.get()
|
|
175
|
+
|
|
168
176
|
@contextmanager
|
|
169
177
|
def push_active_function(self, function_name: str, input_data: typing.Any | None):
|
|
170
178
|
"""
|
nat/builder/function_base.py
CHANGED
|
@@ -111,7 +111,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
111
111
|
ValueError
|
|
112
112
|
If the input type cannot be determined from the class definition
|
|
113
113
|
"""
|
|
114
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
114
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
115
115
|
|
|
116
116
|
base_cls_args = typing.get_args(base_cls)
|
|
117
117
|
|
|
@@ -196,7 +196,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
196
196
|
ValueError
|
|
197
197
|
If the streaming output type cannot be determined from the class definition
|
|
198
198
|
"""
|
|
199
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
199
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
200
200
|
|
|
201
201
|
base_cls_args = typing.get_args(base_cls)
|
|
202
202
|
|
|
@@ -269,7 +269,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
269
269
|
ValueError
|
|
270
270
|
If the single output type cannot be determined from the class definition
|
|
271
271
|
"""
|
|
272
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
272
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
273
273
|
|
|
274
274
|
base_cls_args = typing.get_args(base_cls)
|
|
275
275
|
|