nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. aiq/__init__.py +66 -0
  2. nat/agent/base.py +16 -0
  3. nat/agent/react_agent/agent.py +38 -13
  4. nat/agent/react_agent/prompt.py +4 -1
  5. nat/agent/react_agent/register.py +1 -1
  6. nat/agent/register.py +0 -1
  7. nat/agent/rewoo_agent/agent.py +6 -3
  8. nat/agent/rewoo_agent/prompt.py +3 -0
  9. nat/agent/rewoo_agent/register.py +4 -3
  10. nat/agent/tool_calling_agent/agent.py +92 -22
  11. nat/agent/tool_calling_agent/register.py +9 -13
  12. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  13. nat/authentication/register.py +0 -1
  14. nat/builder/builder.py +1 -1
  15. nat/builder/context.py +9 -1
  16. nat/builder/function_base.py +3 -3
  17. nat/builder/function_info.py +5 -7
  18. nat/builder/user_interaction_manager.py +2 -2
  19. nat/builder/workflow.py +3 -0
  20. nat/builder/workflow_builder.py +0 -1
  21. nat/cli/commands/evaluate.py +1 -1
  22. nat/cli/commands/info/list_components.py +7 -8
  23. nat/cli/commands/info/list_mcp.py +3 -4
  24. nat/cli/commands/registry/search.py +14 -16
  25. nat/cli/commands/start.py +0 -1
  26. nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
  27. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  28. nat/cli/commands/workflow/workflow_commands.py +0 -1
  29. nat/cli/type_registry.py +7 -9
  30. nat/data_models/config.py +1 -1
  31. nat/data_models/evaluate.py +1 -1
  32. nat/data_models/function_dependencies.py +6 -6
  33. nat/data_models/intermediate_step.py +3 -3
  34. nat/data_models/model_gated_field_mixin.py +125 -0
  35. nat/data_models/swe_bench_model.py +1 -1
  36. nat/data_models/temperature_mixin.py +36 -0
  37. nat/data_models/top_p_mixin.py +36 -0
  38. nat/embedder/azure_openai_embedder.py +46 -0
  39. nat/embedder/openai_embedder.py +1 -2
  40. nat/embedder/register.py +1 -1
  41. nat/eval/config.py +2 -0
  42. nat/eval/dataset_handler/dataset_handler.py +5 -6
  43. nat/eval/evaluate.py +64 -20
  44. nat/eval/rag_evaluator/register.py +2 -2
  45. nat/eval/register.py +0 -1
  46. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  47. nat/eval/utils/eval_trace_ctx.py +89 -0
  48. nat/eval/utils/weave_eval.py +14 -7
  49. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  50. nat/experimental/test_time_compute/register.py +0 -1
  51. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  52. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  53. nat/front_ends/fastapi/message_handler.py +13 -14
  54. nat/front_ends/fastapi/message_validator.py +4 -4
  55. nat/front_ends/fastapi/step_adaptor.py +1 -1
  56. nat/front_ends/register.py +0 -1
  57. nat/llm/aws_bedrock_llm.py +3 -3
  58. nat/llm/azure_openai_llm.py +49 -0
  59. nat/llm/nim_llm.py +4 -4
  60. nat/llm/openai_llm.py +4 -4
  61. nat/llm/register.py +1 -1
  62. nat/llm/utils/env_config_value.py +2 -3
  63. nat/meta/pypi.md +9 -9
  64. nat/object_store/models.py +2 -0
  65. nat/object_store/register.py +0 -1
  66. nat/observability/exporter/base_exporter.py +1 -1
  67. nat/observability/exporter/file_exporter.py +1 -1
  68. nat/observability/register.py +3 -3
  69. nat/profiler/callbacks/langchain_callback_handler.py +9 -2
  70. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  71. nat/profiler/data_frame_row.py +1 -1
  72. nat/profiler/decorators/framework_wrapper.py +1 -4
  73. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  74. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  75. nat/profiler/inference_optimization/data_models.py +3 -3
  76. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  77. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  78. nat/profiler/profile_runner.py +13 -8
  79. nat/registry_handlers/package_utils.py +0 -1
  80. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  81. nat/registry_handlers/register.py +3 -4
  82. nat/registry_handlers/rest/rest_handler.py +8 -9
  83. nat/retriever/register.py +0 -1
  84. nat/runtime/session.py +23 -8
  85. nat/settings/global_settings.py +13 -2
  86. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  87. nat/tool/datetime_tools.py +49 -9
  88. nat/tool/document_search.py +1 -1
  89. nat/tool/mcp/mcp_tool.py +1 -1
  90. nat/tool/register.py +0 -1
  91. nat/utils/data_models/schema_validator.py +2 -2
  92. nat/utils/exception_handlers/automatic_retries.py +0 -2
  93. nat/utils/exception_handlers/schemas.py +1 -1
  94. nat/utils/reactive/base/observable_base.py +2 -2
  95. nat/utils/reactive/base/observer_base.py +1 -1
  96. nat/utils/reactive/observable.py +2 -2
  97. nat/utils/reactive/observer.py +2 -2
  98. nat/utils/reactive/subscription.py +1 -1
  99. nat/utils/settings/global_settings.py +4 -6
  100. nat/utils/type_utils.py +4 -4
  101. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
  102. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
  103. nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  104. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
  105. nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
  106. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
  107. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
  108. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
aiq/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import importlib.abc
18
+ import importlib.util
19
+ import sys
20
+ import warnings
21
+
22
+
23
+ class CompatFinder(importlib.abc.MetaPathFinder):
24
+
25
+ def __init__(self, alias_prefix, target_prefix):
26
+ self.alias_prefix = alias_prefix
27
+ self.target_prefix = target_prefix
28
+
29
+ def find_spec(self, fullname, path, target=None):
30
+ if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
31
+ # Map aiq.something -> nat.something
32
+ target_name = self.target_prefix + fullname[len(self.alias_prefix):]
33
+ spec = importlib.util.find_spec(target_name)
34
+ if spec is None:
35
+ return None
36
+ # Wrap the loader so it loads under the alias name
37
+ return importlib.util.spec_from_loader(fullname, CompatLoader(fullname, target_name))
38
+ return None
39
+
40
+
41
+ class CompatLoader(importlib.abc.Loader):
42
+
43
+ def __init__(self, alias_name, target_name):
44
+ self.alias_name = alias_name
45
+ self.target_name = target_name
46
+
47
+ def create_module(self, spec):
48
+ # Reuse the actual module so there's only one instance
49
+ target_module = importlib.import_module(self.target_name)
50
+ sys.modules[self.alias_name] = target_module
51
+ return target_module
52
+
53
+ def exec_module(self, module):
54
+ # Nothing to execute since the target is already loaded
55
+ pass
56
+
57
+
58
+ # Register the compatibility finder
59
+ sys.meta_path.insert(0, CompatFinder("aiq", "nat"))
60
+
61
+ warnings.warn(
62
+ "!!! The 'aiq' namespace is deprecated and will be removed in a future release. "
63
+ "Please use the 'nat' namespace instead.",
64
+ DeprecationWarning,
65
+ stacklevel=2,
66
+ )
nat/agent/base.py CHANGED
@@ -234,6 +234,22 @@ class BaseAgent(ABC):
234
234
  logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
235
235
  return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}
236
236
 
237
+ def _get_chat_history(self, messages: list[BaseMessage]) -> str:
238
+ """
239
+ Get the chat history excluding the last message.
240
+
241
+ Parameters
242
+ ----------
243
+ messages : list[BaseMessage]
244
+ The messages to get the chat history from
245
+
246
+ Returns
247
+ -------
248
+ str
249
+ The chat history excluding the last message
250
+ """
251
+ return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
252
+
237
253
  @abstractmethod
238
254
  async def _build_graph(self, state_schema: type) -> CompiledGraph:
239
255
  pass
@@ -14,20 +14,23 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- # pylint: disable=R0917
18
17
  import logging
18
+ import re
19
+ import typing
19
20
  from json import JSONDecodeError
20
21
 
21
22
  from langchain_core.agents import AgentAction
22
23
  from langchain_core.agents import AgentFinish
23
24
  from langchain_core.callbacks.base import AsyncCallbackHandler
24
25
  from langchain_core.language_models import BaseChatModel
26
+ from langchain_core.language_models import LanguageModelInput
25
27
  from langchain_core.messages.ai import AIMessage
26
28
  from langchain_core.messages.base import BaseMessage
27
29
  from langchain_core.messages.human import HumanMessage
28
30
  from langchain_core.messages.tool import ToolMessage
29
31
  from langchain_core.prompts import ChatPromptTemplate
30
32
  from langchain_core.prompts import MessagesPlaceholder
33
+ from langchain_core.runnables import Runnable
31
34
  from langchain_core.runnables.config import RunnableConfig
32
35
  from langchain_core.tools import BaseTool
33
36
  from pydantic import BaseModel
@@ -44,7 +47,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
44
47
  from nat.agent.react_agent.output_parser import ReActOutputParserException
45
48
  from nat.agent.react_agent.prompt import SYSTEM_PROMPT
46
49
  from nat.agent.react_agent.prompt import USER_PROMPT
47
- from nat.agent.react_agent.register import ReActAgentWorkflowConfig
50
+
51
+ if typing.TYPE_CHECKING:
52
+ from nat.agent.react_agent.register import ReActAgentWorkflowConfig
48
53
 
49
54
  logger = logging.getLogger(__name__)
50
55
 
@@ -94,11 +99,27 @@ class ReActAgentGraph(DualNodeAgent):
94
99
  f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
95
100
  prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
96
101
  # construct the ReAct Agent
97
- bound_llm = llm.bind(stop=["Observation:"]) # type: ignore
98
- self.agent = prompt | bound_llm
102
+ self.agent = prompt | self._maybe_bind_llm_and_yield()
99
103
  self.tools_dict = {tool.name: tool for tool in tools}
100
104
  logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
101
105
 
106
+ def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
107
+ """
108
+ Bind additional parameters to the LLM if needed
109
+ - if the LLM is a smart model, no need to bind any additional parameters
110
+ - if the LLM is a non-smart model, bind a stop sequence to the LLM
111
+
112
+ Returns:
113
+ Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
114
+ """
115
+ # models that don't need (or don't support) a stop sequence
116
+ smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
117
+ if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
118
+ # no need to bind any additional parameters to the LLM
119
+ return self.llm
120
+ # add a stop sequence to the LLM
121
+ return self.llm.bind(stop=["Observation:"])
122
+
102
123
  def _get_tool(self, tool_name: str):
103
124
  try:
104
125
  return self.tools_dict.get(tool_name)
@@ -124,17 +145,19 @@ class ReActAgentGraph(DualNodeAgent):
124
145
  if len(state.messages) == 0:
125
146
  raise RuntimeError('No input received in state: "messages"')
126
147
  # check whether any human input was passed; if not, the agent appends an error message and returns the state
127
- content = str(state.messages[0].content)
148
+ content = str(state.messages[-1].content)
128
149
  if content.strip() == "":
129
150
  logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
130
151
  state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
131
152
  return state
132
153
  question = content
133
154
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
134
-
155
+ chat_history = self._get_chat_history(state.messages)
135
156
  output_message = await self._stream_llm(
136
157
  self.agent,
137
- {"question": question},
158
+ {
159
+ "question": question, "chat_history": chat_history
160
+ },
138
161
  RunnableConfig(callbacks=self.callbacks) # type: ignore
139
162
  )
140
163
 
@@ -152,13 +175,15 @@ class ReActAgentGraph(DualNodeAgent):
152
175
  tool_response = HumanMessage(content=tool_response_content)
153
176
  agent_scratchpad.append(tool_response)
154
177
  agent_scratchpad += working_state
155
- question = str(state.messages[0].content)
178
+ chat_history = self._get_chat_history(state.messages)
179
+ question = str(state.messages[-1].content)
156
180
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
157
181
 
158
- output_message = await self._stream_llm(self.agent, {
159
- "question": question, "agent_scratchpad": agent_scratchpad
160
- },
161
- RunnableConfig(callbacks=self.callbacks))
182
+ output_message = await self._stream_llm(
183
+ self.agent, {
184
+ "question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
185
+ },
186
+ RunnableConfig(callbacks=self.callbacks))
162
187
 
163
188
  if self.detailed_logs:
164
189
  logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
@@ -326,7 +351,7 @@ class ReActAgentGraph(DualNodeAgent):
326
351
  return True
327
352
 
328
353
 
329
- def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
354
+ def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
330
355
  """
331
356
  Create a ReAct Agent prompt from the config.
332
357
 
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
26
26
  Question: the input question you must answer
27
27
  Thought: you should always think about what to do
28
28
  Action: the action to take, should be one of [{tool_names}]
29
- Action Input: the input to the action (if there is no required input, include "Action Input: None")
29
+ Action Input: the input to the action (if there is no required input, include "Action Input: None")
30
30
  Observation: wait for the human to respond with the result from the tool, do not assume the response
31
31
 
32
32
  ... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
37
37
  """
38
38
 
39
39
  USER_PROMPT = """
40
+ Previous conversation history:
41
+ {chat_history}
42
+
40
43
  Question: {question}
41
44
  """
@@ -125,7 +125,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
125
125
 
126
126
  # get and return the output from the state
127
127
  state = ReActGraphState(**state)
128
- output_message = state.messages[-1] # pylint: disable=E1136
128
+ output_message = state.messages[-1]
129
129
  return ChatResponse.from_string(str(output_message.content))
130
130
 
131
131
  except Exception as ex:
nat/agent/register.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  # Import any workflows which need to be automatically registered here
@@ -14,13 +14,13 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- # pylint: disable=R0917
18
17
  import logging
19
18
  from json import JSONDecodeError
20
19
 
21
20
  from langchain_core.callbacks.base import AsyncCallbackHandler
22
21
  from langchain_core.language_models import BaseChatModel
23
22
  from langchain_core.messages.ai import AIMessage
23
+ from langchain_core.messages.base import BaseMessage
24
24
  from langchain_core.messages.human import HumanMessage
25
25
  from langchain_core.messages.tool import ToolMessage
26
26
  from langchain_core.prompts.chat import ChatPromptTemplate
@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
43
43
 
44
44
  class ReWOOGraphState(BaseModel):
45
45
  """State schema for the ReWOO Agent Graph"""
46
+ messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReWOO Agent
46
47
  task: HumanMessage = Field(default_factory=lambda: HumanMessage(content="")) # the task provided by user
47
48
  plan: AIMessage = Field(
48
49
  default_factory=lambda: AIMessage(content="")) # the plan generated by the planner to solve the task
@@ -183,10 +184,12 @@ class ReWOOAgentGraph(BaseAgent):
183
184
  if not task:
184
185
  logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
185
186
  return {"result": NO_INPUT_ERROR_MESSAGE}
186
-
187
+ chat_history = self._get_chat_history(state.messages)
187
188
  plan = await self._stream_llm(
188
189
  planner,
189
- {"task": task},
190
+ {
191
+ "task": task, "chat_history": chat_history
192
+ },
190
193
  RunnableConfig(callbacks=self.callbacks) # type: ignore
191
194
  )
192
195
 
@@ -87,6 +87,9 @@ Begin!
87
87
  """
88
88
 
89
89
  PLANNER_USER_PROMPT = """
90
+ Previous conversation history:
91
+ {chat_history}
92
+
90
93
  task: {task}
91
94
  """
92
95
 
@@ -124,15 +124,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
124
124
  token_counter=len,
125
125
  start_on="human",
126
126
  include_system=True)
127
- task = HumanMessage(content=messages[0].content)
128
- state = ReWOOGraphState(task=task)
127
+
128
+ task = HumanMessage(content=messages[-1].content)
129
+ state = ReWOOGraphState(messages=messages, task=task)
129
130
 
130
131
  # run the ReWOO Agent Graph
131
132
  state = await graph.ainvoke(state)
132
133
 
133
134
  # get and return the output from the state
134
135
  state = ReWOOGraphState(**state)
135
- output_message = state.result.content # pylint: disable=E1101
136
+ output_message = state.result.content
136
137
  return ChatResponse.from_string(output_message)
137
138
 
138
139
  except Exception as ex:
@@ -13,13 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=R0917
17
16
  import logging
17
+ import typing
18
18
 
19
19
  from langchain_core.callbacks.base import AsyncCallbackHandler
20
20
  from langchain_core.language_models import BaseChatModel
21
+ from langchain_core.messages import SystemMessage
21
22
  from langchain_core.messages.base import BaseMessage
22
- from langchain_core.runnables import RunnableConfig
23
+ from langchain_core.runnables import RunnableLambda
24
+ from langchain_core.runnables.config import RunnableConfig
23
25
  from langchain_core.tools import BaseTool
24
26
  from langgraph.prebuilt import ToolNode
25
27
  from pydantic import BaseModel
@@ -30,6 +32,9 @@ from nat.agent.base import AGENT_LOG_PREFIX
30
32
  from nat.agent.base import AgentDecision
31
33
  from nat.agent.dual_node import DualNodeAgent
32
34
 
35
+ if typing.TYPE_CHECKING:
36
+ from nat.agent.tool_calling_agent.register import ToolCallAgentWorkflowConfig
37
+
33
38
  logger = logging.getLogger(__name__)
34
39
 
35
40
 
@@ -43,22 +48,51 @@ class ToolCallAgentGraph(DualNodeAgent):
43
48
  A tool Calling Agent utilizes the tool input parameters to select the optimal tool. Supports handling tool errors.
44
49
  Argument "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
45
50
 
46
- def __init__(self,
47
- llm: BaseChatModel,
48
- tools: list[BaseTool],
49
- callbacks: list[AsyncCallbackHandler] = None,
50
- detailed_logs: bool = False,
51
- handle_tool_errors: bool = True):
51
+ def __init__(
52
+ self,
53
+ llm: BaseChatModel,
54
+ tools: list[BaseTool],
55
+ prompt: str | None = None,
56
+ callbacks: list[AsyncCallbackHandler] = None,
57
+ detailed_logs: bool = False,
58
+ handle_tool_errors: bool = True,
59
+ ):
52
60
  super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
61
+ # some LLMs support tool calling
62
+ # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
63
+ try:
64
+ # in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
65
+ self.bound_llm = llm.bind_tools(tools)
66
+ except NotImplementedError as ex:
67
+ logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
68
+ raise ex
69
+
70
+ if prompt is not None:
71
+ system_prompt = SystemMessage(content=prompt)
72
+ prompt_runnable = RunnableLambda(
73
+ lambda state: [system_prompt] + state.get("messages", []),
74
+ name="SystemPrompt",
75
+ )
76
+ else:
77
+ prompt_runnable = RunnableLambda(
78
+ lambda state: state.get("messages", []),
79
+ name="PromptPassthrough",
80
+ )
81
+
82
+ self.agent = prompt_runnable | self.bound_llm
83
+
53
84
  self.tool_caller = ToolNode(tools, handle_tool_errors=handle_tool_errors)
54
85
  logger.debug("%s Initialized Tool Calling Agent Graph", AGENT_LOG_PREFIX)
55
86
 
56
87
  async def agent_node(self, state: ToolCallAgentGraphState):
57
88
  try:
58
- logger.debug('%s Starting the Tool Calling Agent Node', AGENT_LOG_PREFIX)
89
+ logger.debug("%s Starting the Tool Calling Agent Node", AGENT_LOG_PREFIX)
59
90
  if len(state.messages) == 0:
60
91
  raise RuntimeError('No input received in state: "messages"')
61
- response = await self.llm.ainvoke(state.messages, config=RunnableConfig(callbacks=self.callbacks))
92
+ response = await self.agent.ainvoke(
93
+ {"messages": state.messages},
94
+ config=RunnableConfig(callbacks=self.callbacks),
95
+ )
62
96
  if self.detailed_logs:
63
97
  agent_input = "\n".join(str(message.content) for message in state.messages)
64
98
  logger.info(AGENT_CALL_LOG_MESSAGE, agent_input, response)
@@ -75,16 +109,18 @@ class ToolCallAgentGraph(DualNodeAgent):
75
109
  last_message = state.messages[-1]
76
110
  if last_message.tool_calls:
77
111
  # the agent wants to call a tool
78
- logger.debug('%s Agent is calling a tool', AGENT_LOG_PREFIX)
112
+ logger.debug("%s Agent is calling a tool", AGENT_LOG_PREFIX)
79
113
  return AgentDecision.TOOL
80
114
  if self.detailed_logs:
81
115
  logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
82
116
  return AgentDecision.END
83
117
  except Exception as ex:
84
- logger.exception("%s Failed to determine whether agent is calling a tool: %s",
85
- AGENT_LOG_PREFIX,
86
- ex,
87
- exc_info=True)
118
+ logger.exception(
119
+ "%s Failed to determine whether agent is calling a tool: %s",
120
+ AGENT_LOG_PREFIX,
121
+ ex,
122
+ exc_info=True,
123
+ )
88
124
  logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
89
125
  return AgentDecision.END
90
126
 
@@ -92,14 +128,15 @@ class ToolCallAgentGraph(DualNodeAgent):
92
128
  try:
93
129
  logger.debug("%s Starting Tool Node", AGENT_LOG_PREFIX)
94
130
  tool_calls = state.messages[-1].tool_calls
95
- tools = [tool.get('name') for tool in tool_calls]
131
+ tools = [tool.get("name") for tool in tool_calls]
96
132
  tool_input = state.messages[-1]
97
- tool_response = await self.tool_caller.ainvoke(input={"messages": [tool_input]},
98
- config=RunnableConfig(callbacks=self.callbacks,
99
- configurable={}))
133
+ tool_response = await self.tool_caller.ainvoke(
134
+ input={"messages": [tool_input]},
135
+ config=RunnableConfig(callbacks=self.callbacks, configurable={}),
136
+ )
100
137
  # this configurable = {} argument is needed due to a bug in LangGraph PreBuilt ToolNode ^
101
138
 
102
- for response in tool_response.get('messages'):
139
+ for response in tool_response.get("messages"):
103
140
  if self.detailed_logs:
104
141
  self._log_tool_response(str(tools), str(tool_input), response.content)
105
142
  state.messages += [response]
@@ -112,8 +149,41 @@ class ToolCallAgentGraph(DualNodeAgent):
112
149
  async def build_graph(self):
113
150
  try:
114
151
  await super()._build_graph(state_schema=ToolCallAgentGraphState)
115
- logger.debug("%s Tool Calling Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)
152
+ logger.debug(
153
+ "%s Tool Calling Agent Graph built and compiled successfully",
154
+ AGENT_LOG_PREFIX,
155
+ )
116
156
  return self.graph
117
157
  except Exception as ex:
118
- logger.exception("%s Failed to build Tool Calling Agent Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
158
+ logger.exception(
159
+ "%s Failed to build Tool Calling Agent Graph: %s",
160
+ AGENT_LOG_PREFIX,
161
+ ex,
162
+ exc_info=ex,
163
+ )
119
164
  raise ex
165
+
166
+
167
+ def create_tool_calling_agent_prompt(config: "ToolCallAgentWorkflowConfig") -> str | None:
168
+ """
169
+ Create a Tool Calling Agent prompt from the config.
170
+
171
+ Args:
172
+ config (ToolCallAgentWorkflowConfig): The config to use for the prompt.
173
+
174
+ Returns:
175
+ str | None: The Tool Calling Agent prompt, or None when neither a system prompt nor additional instructions are configured.
176
+ """
177
+ # the Tool Calling Agent prompt can be customized via the config options system_prompt and additional_instructions.
178
+
179
+ if config.system_prompt:
180
+ prompt_str = config.system_prompt
181
+ else:
182
+ prompt_str = ""
183
+
184
+ if config.additional_instructions:
185
+ prompt_str += f" {config.additional_instructions}"
186
+
187
+ if len(prompt_str) > 0:
188
+ return prompt_str
189
+ return None
@@ -41,6 +41,9 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
41
41
  handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
42
42
  description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
43
43
  max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
44
+ system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
45
+ additional_instructions: str | None = Field(default=None,
46
+ description="Additional instructions appended to the system prompt.")
44
47
 
45
48
 
46
49
  @register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
@@ -49,10 +52,11 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
49
52
  from langgraph.graph.graph import CompiledGraph
50
53
 
51
54
  from nat.agent.base import AGENT_LOG_PREFIX
55
+ from nat.agent.tool_calling_agent.agent import ToolCallAgentGraph
56
+ from nat.agent.tool_calling_agent.agent import ToolCallAgentGraphState
57
+ from nat.agent.tool_calling_agent.agent import create_tool_calling_agent_prompt
52
58
 
53
- from .agent import ToolCallAgentGraph
54
- from .agent import ToolCallAgentGraphState
55
-
59
+ prompt = create_tool_calling_agent_prompt(config)
56
60
  # we can choose an LLM for the ReAct agent in the config file
57
61
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
58
62
  # the agent can run any installed tool, simply install the tool and add it to the config file
@@ -61,18 +65,10 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
61
65
  if not tools:
62
66
  raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
63
67
 
64
- # some LLMs support tool calling
65
- # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
66
- try:
67
- # in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
68
- llm = llm.bind_tools(tools)
69
- except NotImplementedError as ex:
70
- logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
71
- raise ex
72
-
73
68
  # construct the Tool Calling Agent Graph from the configured llm, and tools
74
69
  graph: CompiledGraph = await ToolCallAgentGraph(llm=llm,
75
70
  tools=tools,
71
+ prompt=prompt,
76
72
  detailed_logs=config.verbose,
77
73
  handle_tool_errors=config.handle_tool_errors).build_graph()
78
74
 
@@ -90,7 +86,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
90
86
 
91
87
  # get and return the output from the state
92
88
  state = ToolCallAgentGraphState(**state)
93
- output_message = state.messages[-1] # pylint: disable=E1136
89
+ output_message = state.messages[-1]
94
90
  return output_message.content
95
91
  except Exception as ex:
96
92
  logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
@@ -31,7 +31,7 @@ class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
31
31
  # fmt: off
32
32
  def __init__(self,
33
33
  config: APIKeyAuthProviderConfig,
34
- config_name: str | None = None) -> None: # pylint: disable=unused-argument
34
+ config_name: str | None = None) -> None:
35
35
  assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
36
36
  super().__init__(config)
37
37
  # fmt: on
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  from nat.authentication.api_key import register as register_api_key
nat/builder/builder.py CHANGED
@@ -58,7 +58,7 @@ class UserManagerHolder():
58
58
  return self._context.user_manager.get_id()
59
59
 
60
60
 
61
- class Builder(ABC): # pylint: disable=too-many-public-methods
61
+ class Builder(ABC):
62
62
 
63
63
  @abstractmethod
64
64
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
nat/builder/context.py CHANGED
@@ -38,7 +38,7 @@ from nat.utils.reactive.subject import Subject
38
38
 
39
39
  class Singleton(type):
40
40
 
41
- def __init__(cls, name, bases, dict): # pylint: disable=W0622
41
+ def __init__(cls, name, bases, dict):
42
42
  super(Singleton, cls).__init__(name, bases, dict)
43
43
  cls.instance = None
44
44
 
@@ -65,6 +65,7 @@ class ContextState(metaclass=Singleton):
65
65
 
66
66
  def __init__(self):
67
67
  self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
68
+ self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
68
69
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
69
70
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
70
71
  self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
@@ -165,6 +166,13 @@ class Context:
165
166
  """
166
167
  return self._context_state.conversation_id.get()
167
168
 
169
+ @property
170
+ def user_message_id(self) -> str | None:
171
+ """
172
+ This property retrieves the user message ID which is the unique identifier for the current user message.
173
+ """
174
+ return self._context_state.user_message_id.get()
175
+
168
176
  @contextmanager
169
177
  def push_active_function(self, function_name: str, input_data: typing.Any | None):
170
178
  """
@@ -111,7 +111,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
111
111
  ValueError
112
112
  If the input type cannot be determined from the class definition
113
113
  """
114
- for base_cls in self.__class__.__orig_bases__: # pylint: disable=no-member # type: ignore
114
+ for base_cls in self.__class__.__orig_bases__:
115
115
 
116
116
  base_cls_args = typing.get_args(base_cls)
117
117
 
@@ -196,7 +196,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
196
196
  ValueError
197
197
  If the streaming output type cannot be determined from the class definition
198
198
  """
199
- for base_cls in self.__class__.__orig_bases__: # pylint: disable=no-member # type: ignore
199
+ for base_cls in self.__class__.__orig_bases__:
200
200
 
201
201
  base_cls_args = typing.get_args(base_cls)
202
202
 
@@ -269,7 +269,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
269
269
  ValueError
270
270
  If the single output type cannot be determined from the class definition
271
271
  """
272
- for base_cls in self.__class__.__orig_bases__: # pylint: disable=no-member # type: ignore
272
+ for base_cls in self.__class__.__orig_bases__:
273
273
 
274
274
  base_cls_args = typing.get_args(base_cls)
275
275