nvidia-nat 1.3.dev0__py3-none-any.whl → 1.3.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/base.py +17 -0
- nat/agent/react_agent/agent.py +17 -10
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/rewoo_agent/agent.py +6 -2
- nat/agent/rewoo_agent/prompt.py +3 -0
- nat/agent/rewoo_agent/register.py +3 -2
- nat/agent/tool_calling_agent/agent.py +92 -21
- nat/agent/tool_calling_agent/register.py +8 -12
- nat/cli/type_registry.py +4 -4
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/openai_embedder.py +1 -2
- nat/embedder/register.py +1 -0
- nat/llm/azure_openai_llm.py +50 -0
- nat/llm/register.py +1 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/models.py +2 -0
- nat/profiler/callbacks/langchain_callback_handler.py +8 -1
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/METADATA +17 -15
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/RECORD +25 -22
- nvidia_nat-1.3.0.dev2.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/top_level.txt +1 -0
- nvidia_nat-1.3.dev0.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.dev0.dist-info → nvidia_nat-1.3.0.dev2.dist-info}/licenses/LICENSE.md +0 -0
aiq/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import sys
|
|
17
|
+
import importlib
|
|
18
|
+
import importlib.abc
|
|
19
|
+
import importlib.util
|
|
20
|
+
import warnings
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CompatFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that serves ``<alias_prefix>.*`` imports from ``<target_prefix>.*``.

    This module registers one instance mapping ``aiq`` -> ``nat``, so a legacy
    ``import aiq.foo`` yields the very same module object as ``import nat.foo``.
    """

    def __init__(self, alias_prefix, target_prefix):
        self.alias_prefix = alias_prefix
        self.target_prefix = target_prefix

    def find_spec(self, fullname, path, target=None):  # pylint: disable=unused-argument
        """Return a spec for an aliased name, or ``None`` to let other finders handle it."""
        if fullname != self.alias_prefix and not fullname.startswith(self.alias_prefix + "."):
            return None
        # Map aiq.something -> nat.something
        target_name = self.target_prefix + fullname[len(self.alias_prefix):]
        target_spec = importlib.util.find_spec(target_name)
        if target_spec is None:
            return None
        # Wrap the loader so it loads under the alias name. Mirror the target's
        # package-ness so the alias spec does not misreport a package as a plain module.
        return importlib.util.spec_from_loader(
            fullname,
            CompatLoader(fullname, target_name),
            is_package=target_spec.submodule_search_locations is not None,
        )


class CompatLoader(importlib.abc.Loader):
    """Loader that hands back the real target module for an alias module name."""

    def __init__(self, alias_name, target_name):
        self.alias_name = alias_name
        self.target_name = target_name
        # Spec of the real target module, captured in create_module().
        self._target_spec = None

    def create_module(self, spec):
        # Reuse the actual module so there's only one instance
        target_module = importlib.import_module(self.target_name)
        # importlib's module_from_spec() unconditionally assigns the alias spec to
        # module.__spec__ right after this method returns. Remember the real spec so
        # exec_module() can restore it; otherwise the shared module would claim to be
        # "aiq.*" and importlib.reload() on it would route through this loader and
        # silently do nothing.
        self._target_spec = getattr(target_module, "__spec__", None)
        sys.modules[self.alias_name] = target_module
        return target_module

    def exec_module(self, module):
        # Nothing to execute since the target is already loaded; only undo the
        # __spec__ overwrite so the shared module keeps its canonical identity.
        if self._target_spec is not None:
            module.__spec__ = self._target_spec
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Put the compatibility finder ahead of every other meta-path finder so that
# "aiq.*" imports are redirected to "nat.*" before the normal search runs.
sys.meta_path.insert(0, CompatFinder(alias_prefix="aiq", target_prefix="nat"))

# Importing the legacy package announces its deprecation.
warnings.warn(
    "!!! The 'aiq' namespace is deprecated and will be removed in a future release. "
    "Please use the 'nat' namespace instead.",
    category=DeprecationWarning,
    stacklevel=2,
)
|
nat/agent/base.py
CHANGED
|
@@ -179,6 +179,7 @@ class BaseAgent(ABC):
|
|
|
179
179
|
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
|
180
180
|
await asyncio.sleep(sleep_time)
|
|
181
181
|
|
|
182
|
+
# pylint: disable=C0209
|
|
182
183
|
# All retries exhausted, return error message
|
|
183
184
|
error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
|
|
184
185
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
|
@@ -234,6 +235,22 @@ class BaseAgent(ABC):
|
|
|
234
235
|
logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
|
|
235
236
|
return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
|
|
236
237
|
|
|
238
|
+
def _get_chat_history(self, messages: list[BaseMessage]) -> str:
    """
    Render all but the last message as ``"<type>: <content>"`` lines.

    Parameters
    ----------
    messages : list[BaseMessage]
        The conversation messages; the final entry is left out of the history.

    Returns
    -------
    str
        Newline-separated history lines; an empty string when there is at most one message.
    """
    earlier = messages[:-1]
    rendered = []
    for msg in earlier:
        rendered.append(f"{msg.type}: {msg.content}")
    return "\n".join(rendered)
|
|
253
|
+
|
|
237
254
|
@abstractmethod
|
|
238
255
|
async def _build_graph(self, state_schema: type) -> CompiledGraph:
|
|
239
256
|
pass
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import json
|
|
17
17
|
# pylint: disable=R0917
|
|
18
18
|
import logging
|
|
19
|
+
import typing
|
|
19
20
|
from json import JSONDecodeError
|
|
20
21
|
|
|
21
22
|
from langchain_core.agents import AgentAction
|
|
@@ -44,7 +45,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
|
|
|
44
45
|
from nat.agent.react_agent.output_parser import ReActOutputParserException
|
|
45
46
|
from nat.agent.react_agent.prompt import SYSTEM_PROMPT
|
|
46
47
|
from nat.agent.react_agent.prompt import USER_PROMPT
|
|
47
|
-
|
|
48
|
+
|
|
49
|
+
if typing.TYPE_CHECKING:
|
|
50
|
+
from nat.agent.react_agent.register import ReActAgentWorkflowConfig
|
|
48
51
|
|
|
49
52
|
logger = logging.getLogger(__name__)
|
|
50
53
|
|
|
@@ -124,17 +127,19 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
124
127
|
if len(state.messages) == 0:
|
|
125
128
|
raise RuntimeError('No input received in state: "messages"')
|
|
126
129
|
# to check is any human input passed or not, if no input passed Agent will return the state
|
|
127
|
-
content = str(state.messages[
|
|
130
|
+
content = str(state.messages[-1].content)
|
|
128
131
|
if content.strip() == "":
|
|
129
132
|
logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
|
|
130
133
|
state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
|
|
131
134
|
return state
|
|
132
135
|
question = content
|
|
133
136
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
134
|
-
|
|
137
|
+
chat_history = self._get_chat_history(state.messages)
|
|
135
138
|
output_message = await self._stream_llm(
|
|
136
139
|
self.agent,
|
|
137
|
-
{
|
|
140
|
+
{
|
|
141
|
+
"question": question, "chat_history": chat_history
|
|
142
|
+
},
|
|
138
143
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
139
144
|
)
|
|
140
145
|
|
|
@@ -152,13 +157,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
152
157
|
tool_response = HumanMessage(content=tool_response_content)
|
|
153
158
|
agent_scratchpad.append(tool_response)
|
|
154
159
|
agent_scratchpad += working_state
|
|
155
|
-
|
|
160
|
+
chat_history = self._get_chat_history(state.messages)
|
|
161
|
+
question = str(state.messages[-1].content)
|
|
156
162
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
157
163
|
|
|
158
|
-
output_message = await self._stream_llm(
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
164
|
+
output_message = await self._stream_llm(
|
|
165
|
+
self.agent, {
|
|
166
|
+
"question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
|
|
167
|
+
},
|
|
168
|
+
RunnableConfig(callbacks=self.callbacks))
|
|
162
169
|
|
|
163
170
|
if self.detailed_logs:
|
|
164
171
|
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
@@ -326,7 +333,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
326
333
|
return True
|
|
327
334
|
|
|
328
335
|
|
|
329
|
-
def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
|
|
336
|
+
def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
|
|
330
337
|
"""
|
|
331
338
|
Create a ReAct Agent prompt from the config.
|
|
332
339
|
|
nat/agent/react_agent/prompt.py
CHANGED
|
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
|
|
|
26
26
|
Question: the input question you must answer
|
|
27
27
|
Thought: you should always think about what to do
|
|
28
28
|
Action: the action to take, should be one of [{tool_names}]
|
|
29
|
-
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
29
|
+
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
30
30
|
Observation: wait for the human to respond with the result from the tool, do not assume the response
|
|
31
31
|
|
|
32
32
|
... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
|
|
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
|
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
USER_PROMPT = """
|
|
40
|
+
Previous conversation history:
|
|
41
|
+
{chat_history}
|
|
42
|
+
|
|
40
43
|
Question: {question}
|
|
41
44
|
"""
|
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -21,6 +21,7 @@ from json import JSONDecodeError
|
|
|
21
21
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
22
22
|
from langchain_core.language_models import BaseChatModel
|
|
23
23
|
from langchain_core.messages.ai import AIMessage
|
|
24
|
+
from langchain_core.messages.base import BaseMessage
|
|
24
25
|
from langchain_core.messages.human import HumanMessage
|
|
25
26
|
from langchain_core.messages.tool import ToolMessage
|
|
26
27
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
@@ -43,6 +44,7 @@ logger = logging.getLogger(__name__)
|
|
|
43
44
|
|
|
44
45
|
class ReWOOGraphState(BaseModel):
|
|
45
46
|
"""State schema for the ReWOO Agent Graph"""
|
|
47
|
+
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
|
|
46
48
|
task: HumanMessage = Field(default_factory=lambda: HumanMessage(content="")) # the task provided by user
|
|
47
49
|
plan: AIMessage = Field(
|
|
48
50
|
default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
|
|
@@ -183,10 +185,12 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
183
185
|
if not task:
|
|
184
186
|
logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
|
|
185
187
|
return {"result": NO_INPUT_ERROR_MESSAGE}
|
|
186
|
-
|
|
188
|
+
chat_history = self._get_chat_history(state.messages)
|
|
187
189
|
plan = await self._stream_llm(
|
|
188
190
|
planner,
|
|
189
|
-
{
|
|
191
|
+
{
|
|
192
|
+
"task": task, "chat_history": chat_history
|
|
193
|
+
},
|
|
190
194
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
191
195
|
)
|
|
192
196
|
|
nat/agent/rewoo_agent/prompt.py
CHANGED
|
@@ -124,8 +124,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
124
124
|
token_counter=len,
|
|
125
125
|
start_on="human",
|
|
126
126
|
include_system=True)
|
|
127
|
-
|
|
128
|
-
|
|
127
|
+
|
|
128
|
+
task = HumanMessage(content=messages[-1].content)
|
|
129
|
+
state = ReWOOGraphState(messages=messages, task=task)
|
|
129
130
|
|
|
130
131
|
# run the ReWOO Agent Graph
|
|
131
132
|
state = await graph.ainvoke(state)
|
|
@@ -15,11 +15,14 @@
|
|
|
15
15
|
|
|
16
16
|
# pylint: disable=R0917
|
|
17
17
|
import logging
|
|
18
|
+
import typing
|
|
18
19
|
|
|
19
20
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
20
21
|
from langchain_core.language_models import BaseChatModel
|
|
22
|
+
from langchain_core.messages import SystemMessage
|
|
21
23
|
from langchain_core.messages.base import BaseMessage
|
|
22
|
-
from langchain_core.runnables import
|
|
24
|
+
from langchain_core.runnables import RunnableLambda
|
|
25
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
23
26
|
from langchain_core.tools import BaseTool
|
|
24
27
|
from langgraph.prebuilt import ToolNode
|
|
25
28
|
from pydantic import BaseModel
|
|
@@ -30,6 +33,9 @@ from nat.agent.base import AGENT_LOG_PREFIX
|
|
|
30
33
|
from nat.agent.base import AgentDecision
|
|
31
34
|
from nat.agent.dual_node import DualNodeAgent
|
|
32
35
|
|
|
36
|
+
if typing.TYPE_CHECKING:
|
|
37
|
+
from nat.agent.tool_calling_agent.register import ToolCallAgentWorkflowConfig
|
|
38
|
+
|
|
33
39
|
logger = logging.getLogger(__name__)
|
|
34
40
|
|
|
35
41
|
|
|
@@ -43,22 +49,51 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
43
49
|
A tool Calling Agent utilizes the tool input parameters to select the optimal tool. Supports handling tool errors.
|
|
44
50
|
Argument "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
|
|
45
51
|
|
|
46
|
-
def __init__(
    self,
    llm: BaseChatModel,
    tools: list[BaseTool],
    prompt: str | None = None,
    callbacks: list[AsyncCallbackHandler] | None = None,
    detailed_logs: bool = False,
    handle_tool_errors: bool = True,
):
    """
    Build the Tool Calling Agent runnable and its tool-execution node.

    Parameters
    ----------
    llm : BaseChatModel
        Chat model the tools are bound to via ``bind_tools``.
    tools : list[BaseTool]
        Tools bound to the LLM and executed by the tool node.
    prompt : str | None
        Optional system prompt; when given it is prepended as a ``SystemMessage``
        to the state's messages on every agent call.
    callbacks : list[AsyncCallbackHandler] | None
        Callback handlers forwarded to the base agent.
    detailed_logs : bool
        Toggles logging of inputs, outputs, and intermediate steps.
    handle_tool_errors : bool
        Passed through to ``ToolNode(handle_tool_errors=...)``.

    Raises
    ------
    NotImplementedError
        If the LLM does not support tool binding.
    """
    super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
    # some LLMs support tool calling
    # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
    try:
        # in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
        self.bound_llm = llm.bind_tools(tools)
    except NotImplementedError as ex:
        logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
        raise

    if prompt is not None:
        system_prompt = SystemMessage(content=prompt)
        prompt_runnable = RunnableLambda(
            lambda state: [system_prompt] + state.get("messages", []),
            name="SystemPrompt",
        )
    else:
        prompt_runnable = RunnableLambda(
            lambda state: state.get("messages", []),
            name="PromptPassthrough",
        )

    # The agent maps the graph state ({"messages": [...]}) to the message list and feeds it to the tool-bound LLM.
    self.agent = prompt_runnable | self.bound_llm

    self.tool_caller = ToolNode(tools, handle_tool_errors=handle_tool_errors)
    logger.debug("%s Initialized Tool Calling Agent Graph", AGENT_LOG_PREFIX)
|
|
55
87
|
|
|
56
88
|
async def agent_node(self, state: ToolCallAgentGraphState):
|
|
57
89
|
try:
|
|
58
|
-
logger.debug(
|
|
90
|
+
logger.debug("%s Starting the Tool Calling Agent Node", AGENT_LOG_PREFIX)
|
|
59
91
|
if len(state.messages) == 0:
|
|
60
92
|
raise RuntimeError('No input received in state: "messages"')
|
|
61
|
-
response = await self.
|
|
93
|
+
response = await self.agent.ainvoke(
|
|
94
|
+
{"messages": state.messages},
|
|
95
|
+
config=RunnableConfig(callbacks=self.callbacks),
|
|
96
|
+
)
|
|
62
97
|
if self.detailed_logs:
|
|
63
98
|
agent_input = "\n".join(str(message.content) for message in state.messages)
|
|
64
99
|
logger.info(AGENT_CALL_LOG_MESSAGE, agent_input, response)
|
|
@@ -75,16 +110,18 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
75
110
|
last_message = state.messages[-1]
|
|
76
111
|
if last_message.tool_calls:
|
|
77
112
|
# the agent wants to call a tool
|
|
78
|
-
logger.debug(
|
|
113
|
+
logger.debug("%s Agent is calling a tool", AGENT_LOG_PREFIX)
|
|
79
114
|
return AgentDecision.TOOL
|
|
80
115
|
if self.detailed_logs:
|
|
81
116
|
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
|
|
82
117
|
return AgentDecision.END
|
|
83
118
|
except Exception as ex:
|
|
84
|
-
logger.exception(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
119
|
+
logger.exception(
|
|
120
|
+
"%s Failed to determine whether agent is calling a tool: %s",
|
|
121
|
+
AGENT_LOG_PREFIX,
|
|
122
|
+
ex,
|
|
123
|
+
exc_info=True,
|
|
124
|
+
)
|
|
88
125
|
logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
|
|
89
126
|
return AgentDecision.END
|
|
90
127
|
|
|
@@ -92,14 +129,15 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
92
129
|
try:
|
|
93
130
|
logger.debug("%s Starting Tool Node", AGENT_LOG_PREFIX)
|
|
94
131
|
tool_calls = state.messages[-1].tool_calls
|
|
95
|
-
tools = [tool.get(
|
|
132
|
+
tools = [tool.get("name") for tool in tool_calls]
|
|
96
133
|
tool_input = state.messages[-1]
|
|
97
|
-
tool_response = await self.tool_caller.ainvoke(
|
|
98
|
-
|
|
99
|
-
|
|
134
|
+
tool_response = await self.tool_caller.ainvoke(
|
|
135
|
+
input={"messages": [tool_input]},
|
|
136
|
+
config=RunnableConfig(callbacks=self.callbacks, configurable={}),
|
|
137
|
+
)
|
|
100
138
|
# this configurable = {} argument is needed due to a bug in LangGraph PreBuilt ToolNode ^
|
|
101
139
|
|
|
102
|
-
for response in tool_response.get(
|
|
140
|
+
for response in tool_response.get("messages"):
|
|
103
141
|
if self.detailed_logs:
|
|
104
142
|
self._log_tool_response(str(tools), str(tool_input), response.content)
|
|
105
143
|
state.messages += [response]
|
|
@@ -112,8 +150,41 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
112
150
|
async def build_graph(self):
    """
    Compile the agent graph with ``ToolCallAgentGraphState`` as its state schema.

    Returns
    -------
    The compiled graph held in ``self.graph``.

    Raises
    ------
    Exception
        Any error raised while building is logged and re-raised.
    """
    try:
        await super()._build_graph(state_schema=ToolCallAgentGraphState)
        logger.debug("%s Tool Calling Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)
        return self.graph
    except Exception as err:
        logger.exception("%s Failed to build Tool Calling Agent Graph: %s", AGENT_LOG_PREFIX, err, exc_info=err)
        raise err
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def create_tool_calling_agent_prompt(config: "ToolCallAgentWorkflowConfig") -> str | None:
|
|
169
|
+
"""
|
|
170
|
+
Create a Tool Calling Agent prompt from the config.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
config (ToolCallAgentWorkflowConfig): The config to use for the prompt.
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
ChatPromptTemplate: The Tool Calling Agent prompt.
|
|
177
|
+
"""
|
|
178
|
+
# the Tool Calling Agent prompt can be customized via config option system_prompt and additional_instructions.
|
|
179
|
+
|
|
180
|
+
if config.system_prompt:
|
|
181
|
+
prompt_str = config.system_prompt
|
|
182
|
+
else:
|
|
183
|
+
prompt_str = ""
|
|
184
|
+
|
|
185
|
+
if config.additional_instructions:
|
|
186
|
+
prompt_str += f" {config.additional_instructions}"
|
|
187
|
+
|
|
188
|
+
if len(prompt_str) > 0:
|
|
189
|
+
return prompt_str
|
|
190
|
+
return None
|
|
@@ -41,6 +41,9 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
|
|
|
41
41
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
42
42
|
description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
|
|
43
43
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
44
|
+
system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
|
|
45
|
+
additional_instructions: str | None = Field(default=None,
|
|
46
|
+
description="Additional instructions appended to the system prompt.")
|
|
44
47
|
|
|
45
48
|
|
|
46
49
|
@register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
@@ -49,10 +52,11 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
49
52
|
from langgraph.graph.graph import CompiledGraph
|
|
50
53
|
|
|
51
54
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
55
|
+
from nat.agent.tool_calling_agent.agent import ToolCallAgentGraph
|
|
56
|
+
from nat.agent.tool_calling_agent.agent import ToolCallAgentGraphState
|
|
57
|
+
from nat.agent.tool_calling_agent.agent import create_tool_calling_agent_prompt
|
|
52
58
|
|
|
53
|
-
|
|
54
|
-
from .agent import ToolCallAgentGraphState
|
|
55
|
-
|
|
59
|
+
prompt = create_tool_calling_agent_prompt(config)
|
|
56
60
|
# we can choose an LLM for the ReAct agent in the config file
|
|
57
61
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
58
62
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
@@ -61,18 +65,10 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
61
65
|
if not tools:
|
|
62
66
|
raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
|
|
63
67
|
|
|
64
|
-
# some LLMs support tool calling
|
|
65
|
-
# these models accept the tool's input schema and decide when to use a tool based on the input's relevance
|
|
66
|
-
try:
|
|
67
|
-
# in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
|
|
68
|
-
llm = llm.bind_tools(tools)
|
|
69
|
-
except NotImplementedError as ex:
|
|
70
|
-
logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
|
|
71
|
-
raise ex
|
|
72
|
-
|
|
73
68
|
# construct the Tool Calling Agent Graph from the configured llm, and tools
|
|
74
69
|
graph: CompiledGraph = await ToolCallAgentGraph(llm=llm,
|
|
75
70
|
tools=tools,
|
|
71
|
+
prompt=prompt,
|
|
76
72
|
detailed_logs=config.verbose,
|
|
77
73
|
handle_tool_errors=config.handle_tool_errors).build_graph()
|
|
78
74
|
|
nat/cli/type_registry.py
CHANGED
|
@@ -588,8 +588,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
588
588
|
except KeyError as err:
|
|
589
589
|
raise KeyError(
|
|
590
590
|
f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
|
|
591
|
-
"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
|
|
592
|
-
"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
|
|
591
|
+
f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
|
|
592
|
+
f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
|
|
593
593
|
"Please provide an Embedder configuration from one of the following providers: "
|
|
594
594
|
f"{set(self._embedder_client_provider_to_framework.keys())}") from err
|
|
595
595
|
|
|
@@ -703,8 +703,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
703
703
|
except KeyError as err:
|
|
704
704
|
raise KeyError(
|
|
705
705
|
f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
|
|
706
|
-
"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
|
|
707
|
-
"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
|
|
706
|
+
f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
|
|
707
|
+
f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
|
|
708
708
|
"Please provide a Retriever configuration from one of the following providers: "
|
|
709
709
|
f"{set(self._retriever_client_provider_to_framework.keys())}") from err
|
|
710
710
|
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import AliasChoices
|
|
17
|
+
from pydantic import ConfigDict
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.builder.builder import Builder
|
|
21
|
+
from nat.builder.embedder import EmbedderProviderInfo
|
|
22
|
+
from nat.cli.register_workflow import register_embedder_provider
|
|
23
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
24
|
+
from nat.data_models.retry_mixin import RetryMixin
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AzureOpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="azure_openai"):
    """Configuration for an Azure OpenAI embedder provider used with an embedder client."""

    # Unknown keys are kept (extra="allow"); pydantic's protected namespaces are disabled.
    model_config = ConfigDict(extra="allow", protected_namespaces=())

    api_key: str | None = Field(
        default=None,
        description="Azure OpenAI API key to interact with hosted model.",
    )
    api_version: str = Field(
        default="2025-04-01-preview",
        description="Azure OpenAI API version.",
    )
    # Input may use "azure_endpoint" or "base_url"; the serialization alias is "azure_endpoint".
    azure_endpoint: str | None = Field(
        default=None,
        validation_alias=AliasChoices("azure_endpoint", "base_url"),
        serialization_alias="azure_endpoint",
        description="Base URL for the hosted model.",
    )
    # Required (no default); input may use "azure_deployment", "model_name" or "model".
    azure_deployment: str = Field(
        validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
        serialization_alias="azure_deployment",
        description="The Azure OpenAI hosted model/deployment name.",
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@register_embedder_provider(config_type=AzureOpenAIEmbedderModelConfig)
async def azure_openai_embedder_model(config: AzureOpenAIEmbedderModelConfig, _builder: Builder):
    """Provider factory: yield the ``EmbedderProviderInfo`` wrapping the Azure OpenAI embedder config."""
    provider_info = EmbedderProviderInfo(config=config,
                                         description="An Azure OpenAI model for use with an Embedder client.")
    yield provider_info
|
nat/embedder/openai_embedder.py
CHANGED
|
@@ -34,10 +34,9 @@ class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
|
|
|
34
34
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
35
35
|
serialization_alias="model",
|
|
36
36
|
description="The OpenAI hosted model name.")
|
|
37
|
-
max_retries: int = Field(default=2, description="The max number of retries for the request.")
|
|
38
37
|
|
|
39
38
|
|
|
40
39
|
@register_embedder_provider(config_type=OpenAIEmbedderModelConfig)
async def openai_embedder_model(config: OpenAIEmbedderModelConfig, _builder: Builder):
    """Provider factory: yield the ``EmbedderProviderInfo`` wrapping the OpenAI embedder config."""
    provider_info = EmbedderProviderInfo(config=config,
                                         description="An OpenAI model for use with an Embedder client.")
    yield provider_info
|
nat/embedder/register.py
CHANGED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import AliasChoices
|
|
17
|
+
from pydantic import ConfigDict
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.builder.builder import Builder
|
|
21
|
+
from nat.builder.llm import LLMProviderInfo
|
|
22
|
+
from nat.cli.register_workflow import register_llm_provider
|
|
23
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
24
|
+
from nat.data_models.retry_mixin import RetryMixin
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
    """An Azure OpenAI LLM provider to be used with an LLM client."""

    # Unknown keys are kept (extra="allow"); pydantic's protected namespaces are disabled.
    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
    api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
    # Input may use "azure_endpoint" or "base_url"; the serialization alias is "azure_endpoint".
    azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
                                       serialization_alias="azure_endpoint",
                                       default=None,
                                       description="Base URL for the hosted model.")
    # Required (no default); input may use "azure_deployment", "model_name" or "model".
    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                  serialization_alias="azure_deployment",
                                  description="The Azure OpenAI hosted model/deployment name.")
    # The Azure OpenAI chat/completions APIs accept temperature values between 0 and 2.
    temperature: float = Field(default=0.0, description="Sampling temperature in [0, 2].")
    top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
    seed: int | None = Field(default=None, description="Random seed to set for generation.")
    max_retries: int = Field(default=10, description="The max number of retries for the request.")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@register_llm_provider(config_type=AzureOpenAIModelConfig)
async def azure_openai_llm(config: AzureOpenAIModelConfig, _builder: Builder):
    """Provider factory: yield the ``LLMProviderInfo`` wrapping the Azure OpenAI LLM config."""
    provider_info = LLMProviderInfo(config=config, description="An Azure OpenAI model for use with an LLM client.")
    yield provider_info
|
nat/llm/register.py
CHANGED
nat/meta/pypi.md
CHANGED
|
@@ -23,19 +23,19 @@ NeMo Agent toolkit is a flexible library designed to seamlessly integrate your e
|
|
|
23
23
|
|
|
24
24
|
## Key Features
|
|
25
25
|
|
|
26
|
-
- [**Framework Agnostic:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
27
|
-
- [**Reusability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
28
|
-
- [**Rapid Development:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
29
|
-
- [**Profiling:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
30
|
-
- [**Observability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
31
|
-
- [**Evaluation System:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
32
|
-
- [**User Interface:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
33
|
-
- [**MCP Compatibility**](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
26
|
+
- [**Framework Agnostic:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/extend/plugins.html) Works with any agentic framework, so you can use your current technology stack without replatforming.
|
|
27
|
+
- [**Reusability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/extend/sharing-components.html) Every agent, tool, or workflow can be combined and repurposed, allowing developers to leverage existing work in new scenarios.
|
|
28
|
+
- [**Rapid Development:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/tutorials/index.html) Start with a pre-built agent, tool, or workflow, and customize it to your needs.
|
|
29
|
+
- [**Profiling:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/profiler.html) Profile entire workflows down to the tool and agent level, track input/output tokens and timings, and identify bottlenecks.
|
|
30
|
+
- [**Observability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-phoenix.html) Monitor and debug your workflows with any OpenTelemetry-compatible observability tool, with examples using [Phoenix](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-phoenix.html) and [W&B Weave](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-weave.html).
|
|
31
|
+
- [**Evaluation System:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/evaluate.html) Validate and maintain accuracy of agentic workflows with built-in evaluation tools.
|
|
32
|
+
- [**User Interface:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/quick-start/launching-ui.html) Use the NeMo Agent toolkit UI chat interface to interact with your agents, visualize output, and debug workflows.
|
|
33
|
+
- [**MCP Compatibility**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/mcp/mcp-client.html) Compatible with Model Context Protocol (MCP), allowing tools served by MCP Servers to be used as NeMo Agent toolkit functions.
|
|
34
34
|
|
|
35
35
|
With NeMo Agent toolkit, you can move quickly, experiment freely, and ensure reliability across all your agent-driven projects.
|
|
36
36
|
|
|
37
37
|
## Links
|
|
38
|
-
* [Documentation](https://docs.nvidia.com/nemo/agent-toolkit/1.2
|
|
38
|
+
* [Documentation](https://docs.nvidia.com/nemo/agent-toolkit/1.2/index.html): Explore the full documentation for NeMo Agent toolkit.
|
|
39
39
|
|
|
40
40
|
## First time user?
|
|
41
41
|
If this is your first time using NeMo Agent toolkit, it is recommended to install the latest version from the [source repository](https://github.com/NVIDIA/NeMo-Agent-Toolkit?tab=readme-ov-file#quick-start) on GitHub. This package is intended for users who are familiar with NeMo Agent toolkit applications and need to add NeMo Agent toolkit as a dependency to their project.
|