nvidia-nat 1.2rc9__py3-none-any.whl → 1.2.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +17 -0
- nat/agent/react_agent/agent.py +18 -10
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/rewoo_agent/agent.py +6 -2
- nat/agent/rewoo_agent/prompt.py +3 -0
- nat/agent/rewoo_agent/register.py +3 -2
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/METADATA +1 -1
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/RECORD +13 -13
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/top_level.txt +0 -0
nat/agent/base.py
CHANGED
|
@@ -179,6 +179,7 @@ class BaseAgent(ABC):
|
|
|
179
179
|
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
|
180
180
|
await asyncio.sleep(sleep_time)
|
|
181
181
|
|
|
182
|
+
# pylint: disable=C0209
|
|
182
183
|
# All retries exhausted, return error message
|
|
183
184
|
error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
|
|
184
185
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
|
@@ -234,6 +235,22 @@ class BaseAgent(ABC):
|
|
|
234
235
|
logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
|
|
235
236
|
return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
|
|
236
237
|
|
|
238
|
+
def _get_chat_history(self, messages: list[BaseMessage]) -> str:
|
|
239
|
+
"""
|
|
240
|
+
Get the chat history excluding the last message.
|
|
241
|
+
|
|
242
|
+
Parameters
|
|
243
|
+
----------
|
|
244
|
+
messages : list[BaseMessage]
|
|
245
|
+
The messages to get the chat history from
|
|
246
|
+
|
|
247
|
+
Returns
|
|
248
|
+
-------
|
|
249
|
+
str
|
|
250
|
+
The chat history excluding the last message
|
|
251
|
+
"""
|
|
252
|
+
return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
|
|
253
|
+
|
|
237
254
|
@abstractmethod
|
|
238
255
|
async def _build_graph(self, state_schema: type) -> CompiledGraph:
|
|
239
256
|
pass
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -17,6 +17,7 @@ import json
|
|
|
17
17
|
# pylint: disable=R0917
|
|
18
18
|
import logging
|
|
19
19
|
from json import JSONDecodeError
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
21
22
|
from langchain_core.agents import AgentAction
|
|
22
23
|
from langchain_core.agents import AgentFinish
|
|
@@ -44,7 +45,10 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
|
|
|
44
45
|
from nat.agent.react_agent.output_parser import ReActOutputParserException
|
|
45
46
|
from nat.agent.react_agent.prompt import SYSTEM_PROMPT
|
|
46
47
|
from nat.agent.react_agent.prompt import USER_PROMPT
|
|
47
|
-
|
|
48
|
+
|
|
49
|
+
# To avoid circular imports
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from nat.agent.react_agent.register import ReActAgentWorkflowConfig
|
|
48
52
|
|
|
49
53
|
logger = logging.getLogger(__name__)
|
|
50
54
|
|
|
@@ -124,17 +128,19 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
124
128
|
if len(state.messages) == 0:
|
|
125
129
|
raise RuntimeError('No input received in state: "messages"')
|
|
126
130
|
# to check is any human input passed or not, if no input passed Agent will return the state
|
|
127
|
-
content = str(state.messages[
|
|
131
|
+
content = str(state.messages[-1].content)
|
|
128
132
|
if content.strip() == "":
|
|
129
133
|
logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
|
|
130
134
|
state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
|
|
131
135
|
return state
|
|
132
136
|
question = content
|
|
133
137
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
134
|
-
|
|
138
|
+
chat_history = self._get_chat_history(state.messages)
|
|
135
139
|
output_message = await self._stream_llm(
|
|
136
140
|
self.agent,
|
|
137
|
-
{
|
|
141
|
+
{
|
|
142
|
+
"question": question, "chat_history": chat_history
|
|
143
|
+
},
|
|
138
144
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
139
145
|
)
|
|
140
146
|
|
|
@@ -152,13 +158,15 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
152
158
|
tool_response = HumanMessage(content=tool_response_content)
|
|
153
159
|
agent_scratchpad.append(tool_response)
|
|
154
160
|
agent_scratchpad += working_state
|
|
155
|
-
|
|
161
|
+
chat_history = self._get_chat_history(state.messages)
|
|
162
|
+
question = str(state.messages[-1].content)
|
|
156
163
|
logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
|
|
157
164
|
|
|
158
|
-
output_message = await self._stream_llm(
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
165
|
+
output_message = await self._stream_llm(
|
|
166
|
+
self.agent, {
|
|
167
|
+
"question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
|
|
168
|
+
},
|
|
169
|
+
RunnableConfig(callbacks=self.callbacks))
|
|
162
170
|
|
|
163
171
|
if self.detailed_logs:
|
|
164
172
|
logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
|
|
@@ -326,7 +334,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
326
334
|
return True
|
|
327
335
|
|
|
328
336
|
|
|
329
|
-
def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
|
|
337
|
+
def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
|
|
330
338
|
"""
|
|
331
339
|
Create a ReAct Agent prompt from the config.
|
|
332
340
|
|
nat/agent/react_agent/prompt.py
CHANGED
|
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
|
|
|
26
26
|
Question: the input question you must answer
|
|
27
27
|
Thought: you should always think about what to do
|
|
28
28
|
Action: the action to take, should be one of [{tool_names}]
|
|
29
|
-
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
29
|
+
Action Input: the input to the action (if there is no required input, include "Action Input: None")
|
|
30
30
|
Observation: wait for the human to respond with the result from the tool, do not assume the response
|
|
31
31
|
|
|
32
32
|
... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
|
|
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
|
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
USER_PROMPT = """
|
|
40
|
+
Previous conversation history:
|
|
41
|
+
{chat_history}
|
|
42
|
+
|
|
40
43
|
Question: {question}
|
|
41
44
|
"""
|
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -21,6 +21,7 @@ from json import JSONDecodeError
|
|
|
21
21
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
22
22
|
from langchain_core.language_models import BaseChatModel
|
|
23
23
|
from langchain_core.messages.ai import AIMessage
|
|
24
|
+
from langchain_core.messages.base import BaseMessage
|
|
24
25
|
from langchain_core.messages.human import HumanMessage
|
|
25
26
|
from langchain_core.messages.tool import ToolMessage
|
|
26
27
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
@@ -43,6 +44,7 @@ logger = logging.getLogger(__name__)
|
|
|
43
44
|
|
|
44
45
|
class ReWOOGraphState(BaseModel):
|
|
45
46
|
"""State schema for the ReWOO Agent Graph"""
|
|
47
|
+
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
|
|
46
48
|
task: HumanMessage = Field(default_factory=lambda: HumanMessage(content="")) # the task provided by user
|
|
47
49
|
plan: AIMessage = Field(
|
|
48
50
|
default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
|
|
@@ -183,10 +185,12 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
183
185
|
if not task:
|
|
184
186
|
logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
|
|
185
187
|
return {"result": NO_INPUT_ERROR_MESSAGE}
|
|
186
|
-
|
|
188
|
+
chat_history = self._get_chat_history(state.messages)
|
|
187
189
|
plan = await self._stream_llm(
|
|
188
190
|
planner,
|
|
189
|
-
{
|
|
191
|
+
{
|
|
192
|
+
"task": task, "chat_history": chat_history
|
|
193
|
+
},
|
|
190
194
|
RunnableConfig(callbacks=self.callbacks) # type: ignore
|
|
191
195
|
)
|
|
192
196
|
|
nat/agent/rewoo_agent/register.py
CHANGED
|
@@ -124,8 +124,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
124
124
|
token_counter=len,
|
|
125
125
|
start_on="human",
|
|
126
126
|
include_system=True)
|
|
127
|
-
|
|
128
|
-
|
|
127
|
+
|
|
128
|
+
task = HumanMessage(content=messages[-1].content)
|
|
129
|
+
state = ReWOOGraphState(messages=messages, task=task)
|
|
129
130
|
|
|
130
131
|
# run the ReWOO Agent Graph
|
|
131
132
|
state = await graph.ainvoke(state)
|
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
aiq/__init__.py,sha256=E9vuQX0dCZIALhOg360sRLO53f6NXPgMTv3X1sh8WAM,2376
|
|
2
2
|
nat/agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
nat/agent/base.py,sha256=
|
|
3
|
+
nat/agent/base.py,sha256=DVSUq1VBhrgeTxnDgcESMFcJlEhnBa3PU9I1Rl5t3B8,9602
|
|
4
4
|
nat/agent/dual_node.py,sha256=EOYpYzhaY-m1t2W3eiQrBjSfNjYMDttAwtzEEEcYP4s,2353
|
|
5
5
|
nat/agent/register.py,sha256=EATlFFl7ov5HNGySLcPv1T7jzV-Jy-jPVkUzSXDT-7s,1005
|
|
6
6
|
nat/agent/react_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
-
nat/agent/react_agent/agent.py,sha256=
|
|
7
|
+
nat/agent/react_agent/agent.py,sha256=s968IEwb3toBRDfeBFTOQlseC_5zULu_3IQaS5li9xY,19770
|
|
8
8
|
nat/agent/react_agent/output_parser.py,sha256=m7K6wRwtckBBpAHqOf3BZ9mqZLwrP13Kxz5fvNxbyZE,4219
|
|
9
|
-
nat/agent/react_agent/prompt.py,sha256=
|
|
9
|
+
nat/agent/react_agent/prompt.py,sha256=N47JJrT6xwYQCv1jedHhlul2AE7EfKsSYfAbgJwWRew,1758
|
|
10
10
|
nat/agent/react_agent/register.py,sha256=g3xkVWqr1p26Gjk4OJF_kolV6WBtS5GJkLvFHG7e7-I,8099
|
|
11
11
|
nat/agent/reasoning_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
12
|
nat/agent/reasoning_agent/reasoning_agent.py,sha256=2NDDHeesM2s2PnJfRsv2OTYjeajR1rYUVDvJZLzWGAQ,9434
|
|
13
13
|
nat/agent/rewoo_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
nat/agent/rewoo_agent/agent.py,sha256=
|
|
15
|
-
nat/agent/rewoo_agent/prompt.py,sha256=
|
|
16
|
-
nat/agent/rewoo_agent/register.py,sha256=
|
|
14
|
+
nat/agent/rewoo_agent/agent.py,sha256=MEyBaFjxWaTwAtaKHNwHR0q4fR_-SpO85U_6d2kGZUg,19296
|
|
15
|
+
nat/agent/rewoo_agent/prompt.py,sha256=nFMav3Zl_vmKPLzAIhbQHlldWnurPJb1GlwnekUuxDs,3720
|
|
16
|
+
nat/agent/rewoo_agent/register.py,sha256=v04nBg6608HoQL1lxuOSZjX1gr2DpscjyQwMCRWF94Y,8432
|
|
17
17
|
nat/agent/tool_calling_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
18
|
nat/agent/tool_calling_agent/agent.py,sha256=tKcVIV1EI0D9qlSDUJ7mJOhWCAotFBka2Ii0f4mukoU,5718
|
|
19
19
|
nat/agent/tool_calling_agent/register.py,sha256=iUZ53Ki-KfNwJ-R17r7a3JqmKdKbZCNZReUK0uPJfbU,5413
|
|
@@ -426,10 +426,10 @@ nat/utils/reactive/base/observer_base.py,sha256=UAlyAY_ky4q2t0P81RVFo2Bs_R7z5Nde
|
|
|
426
426
|
nat/utils/reactive/base/subject_base.py,sha256=UQOxlkZTIeeyYmG5qLtDpNf_63Y7p-doEeUA08_R8ME,2521
|
|
427
427
|
nat/utils/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
428
428
|
nat/utils/settings/global_settings.py,sha256=CYHQX7F6MDR18vsVFTrEySpS9cBufuVGTUqZm9lREFs,7446
|
|
429
|
-
nvidia_nat-1.
|
|
430
|
-
nvidia_nat-1.
|
|
431
|
-
nvidia_nat-1.
|
|
432
|
-
nvidia_nat-1.
|
|
433
|
-
nvidia_nat-1.
|
|
434
|
-
nvidia_nat-1.
|
|
435
|
-
nvidia_nat-1.
|
|
429
|
+
nvidia_nat-1.2.0rc10.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
|
|
430
|
+
nvidia_nat-1.2.0rc10.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
431
|
+
nvidia_nat-1.2.0rc10.dist-info/METADATA,sha256=JfvYkcvhML1Ku6UjiXJJlagwMu1ZcOJQ7Vjk2jqhNBw,21779
|
|
432
|
+
nvidia_nat-1.2.0rc10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
433
|
+
nvidia_nat-1.2.0rc10.dist-info/entry_points.txt,sha256=FNh4pZVSe_61s29zdks66lmXBPtsnko8KSZ4ffv7WVE,653
|
|
434
|
+
nvidia_nat-1.2.0rc10.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
|
|
435
|
+
nvidia_nat-1.2.0rc10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
{nvidia_nat-1.2rc9.dist-info → nvidia_nat-1.2.0rc10.dist-info}/licenses/LICENSE-3rd-party.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|