nvidia-nat 1.3a20250819-py3-none-any.whl → 1.3.0.dev2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aiq/__init__.py ADDED
@@ -0,0 +1,66 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import sys
+ import importlib
+ import importlib.abc
+ import importlib.util
+ import warnings
+
+
+ class CompatFinder(importlib.abc.MetaPathFinder):
+
+     def __init__(self, alias_prefix, target_prefix):
+         self.alias_prefix = alias_prefix
+         self.target_prefix = target_prefix
+
+     def find_spec(self, fullname, path, target=None):  # pylint: disable=unused-argument
+         if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
+             # Map aiq.something -> nat.something
+             target_name = self.target_prefix + fullname[len(self.alias_prefix):]
+             spec = importlib.util.find_spec(target_name)
+             if spec is None:
+                 return None
+             # Wrap the loader so it loads under the alias name
+             return importlib.util.spec_from_loader(fullname, CompatLoader(fullname, target_name))
+         return None
+
+
+ class CompatLoader(importlib.abc.Loader):
+
+     def __init__(self, alias_name, target_name):
+         self.alias_name = alias_name
+         self.target_name = target_name
+
+     def create_module(self, spec):
+         # Reuse the actual module so there's only one instance
+         target_module = importlib.import_module(self.target_name)
+         sys.modules[self.alias_name] = target_module
+         return target_module
+
+     def exec_module(self, module):
+         # Nothing to execute since the target is already loaded
+         pass
+
+
+ # Register the compatibility finder
+ sys.meta_path.insert(0, CompatFinder("aiq", "nat"))
+
+ warnings.warn(
+     "!!! The 'aiq' namespace is deprecated and will be removed in a future release. "
+     "Please use the 'nat' namespace instead.",
+     DeprecationWarning,
+     stacklevel=2,
+ )
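The shim makes every `aiq.*` import resolve to the corresponding `nat.*` module and share one module object. A minimal sketch of the behavior (run in a fresh interpreter with the wheel installed; `nat.agent` is one of the subpackages shipped in this wheel):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import aiq.agent  # importing aiq emits the DeprecationWarning; the
                      # submodule is intercepted by CompatFinder and
                      # loaded as nat.agent

import nat.agent

assert aiq.agent is nat.agent  # one shared module instance, two names
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```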
nat/agent/base.py CHANGED
@@ -179,6 +179,7 @@ class BaseAgent(ABC):
  logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
  await asyncio.sleep(sleep_time)

+ # pylint: disable=C0209
  # All retries exhausted, return error message
  error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
  logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
@@ -234,6 +235,22 @@ class BaseAgent(ABC):
  logger.warning("%s Unexpected error during JSON parsing: %s", AGENT_LOG_PREFIX, str(e))
  return {"error": f"Unexpected parsing error: {str(e)}", "original_string": json_string}

+ def _get_chat_history(self, messages: list[BaseMessage]) -> str:
+     """
+     Get the chat history excluding the last message.
+
+     Parameters
+     ----------
+     messages : list[BaseMessage]
+         The messages to get the chat history from
+
+     Returns
+     -------
+     str
+         The chat history excluding the last message
+     """
+     return "\n".join([f"{message.type}: {message.content}" for message in messages[:-1]])
+
  @abstractmethod
  async def _build_graph(self, state_schema: type) -> CompiledGraph:
      pass
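The new helper renders everything except the last message, which the agents below now treat as the current question (hence the `messages[0]` → `messages[-1]` changes). A quick illustration, assuming `langchain_core` is installed:

```python
from langchain_core.messages import AIMessage, HumanMessage

messages = [
    HumanMessage(content="What is 2 + 2?"),
    AIMessage(content="4"),
    HumanMessage(content="And times 3?"),  # current question, excluded from history
]
# The same join the method performs over messages[:-1]:
print("\n".join(f"{m.type}: {m.content}" for m in messages[:-1]))
# human: What is 2 + 2?
# ai: 4
```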
nat/agent/react_agent/agent.py CHANGED
@@ -16,6 +16,7 @@
  import json
  # pylint: disable=R0917
  import logging
+ import typing
  from json import JSONDecodeError

  from langchain_core.agents import AgentAction
@@ -44,7 +45,9 @@ from nat.agent.react_agent.output_parser import ReActOutputParser
  from nat.agent.react_agent.output_parser import ReActOutputParserException
  from nat.agent.react_agent.prompt import SYSTEM_PROMPT
  from nat.agent.react_agent.prompt import USER_PROMPT
- from nat.agent.react_agent.register import ReActAgentWorkflowConfig
+
+ if typing.TYPE_CHECKING:
+     from nat.agent.react_agent.register import ReActAgentWorkflowConfig

  logger = logging.getLogger(__name__)

@@ -124,17 +127,19 @@ class ReActAgentGraph(DualNodeAgent):
  if len(state.messages) == 0:
      raise RuntimeError('No input received in state: "messages"')
  # to check is any human input passed or not, if no input passed Agent will return the state
- content = str(state.messages[0].content)
+ content = str(state.messages[-1].content)
  if content.strip() == "":
      logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
      state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
      return state
  question = content
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
-
+ chat_history = self._get_chat_history(state.messages)
  output_message = await self._stream_llm(
      self.agent,
-     {"question": question},
+     {
+         "question": question, "chat_history": chat_history
+     },
      RunnableConfig(callbacks=self.callbacks)  # type: ignore
  )

@@ -152,13 +157,15 @@ class ReActAgentGraph(DualNodeAgent):
  tool_response = HumanMessage(content=tool_response_content)
  agent_scratchpad.append(tool_response)
  agent_scratchpad += working_state
- question = str(state.messages[0].content)
+ chat_history = self._get_chat_history(state.messages)
+ question = str(state.messages[-1].content)
  logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)

- output_message = await self._stream_llm(self.agent, {
-     "question": question, "agent_scratchpad": agent_scratchpad
- },
-     RunnableConfig(callbacks=self.callbacks))
+ output_message = await self._stream_llm(
+     self.agent, {
+         "question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
+     },
+     RunnableConfig(callbacks=self.callbacks))

  if self.detailed_logs:
      logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
@@ -326,7 +333,7 @@ class ReActAgentGraph(DualNodeAgent):
  return True


- def create_react_agent_prompt(config: ReActAgentWorkflowConfig) -> ChatPromptTemplate:
+ def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
      """
      Create a ReAct Agent prompt from the config.

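The import of `ReActAgentWorkflowConfig` moves behind `typing.TYPE_CHECKING` and the annotation becomes a string, the standard way to keep a type-only import from executing at runtime (typically to break an `agent.py` / `register.py` import cycle). A generic sketch of the pattern, with hypothetical module names:

```python
import typing

if typing.TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from myapp.register import WorkflowConfig  # hypothetical module

def create_prompt(config: "WorkflowConfig") -> str:
    # The quoted annotation defers evaluation, so the name need not exist
    # at runtime.
    return getattr(config, "system_prompt", "") or "You are a helpful agent."
```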
nat/agent/react_agent/prompt.py CHANGED
@@ -26,7 +26,7 @@ Use the following format exactly to ask the human to use a tool:
  Question: the input question you must answer
  Thought: you should always think about what to do
  Action: the action to take, should be one of [{tool_names}]
- Action Input: the input to the action (if there is no required input, include "Action Input: None")
+ Action Input: the input to the action (if there is no required input, include "Action Input: None")
  Observation: wait for the human to respond with the result from the tool, do not assume the response

  ... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
@@ -37,5 +37,8 @@ Final Answer: the final answer to the original input question
  """

  USER_PROMPT = """
+ Previous conversation history:
+ {chat_history}
+
  Question: {question}
  """
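With the new `{chat_history}` placeholder, the user prompt carries prior turns alongside the current question. A minimal rendering sketch (a plain `PromptTemplate` is used here for illustration; the agent itself assembles a `ChatPromptTemplate`):

```python
from langchain_core.prompts import PromptTemplate

USER_PROMPT = """
Previous conversation history:
{chat_history}

Question: {question}
"""

rendered = PromptTemplate.from_template(USER_PROMPT).format(
    chat_history="human: What is 2 + 2?\nai: 4",
    question="And times 3?",
)
print(rendered)
```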
nat/agent/rewoo_agent/agent.py CHANGED
@@ -21,6 +21,7 @@ from json import JSONDecodeError
  from langchain_core.callbacks.base import AsyncCallbackHandler
  from langchain_core.language_models import BaseChatModel
  from langchain_core.messages.ai import AIMessage
+ from langchain_core.messages.base import BaseMessage
  from langchain_core.messages.human import HumanMessage
  from langchain_core.messages.tool import ToolMessage
  from langchain_core.prompts.chat import ChatPromptTemplate
@@ -43,6 +44,7 @@ logger = logging.getLogger(__name__)

  class ReWOOGraphState(BaseModel):
      """State schema for the ReWOO Agent Graph"""
+     messages: list[BaseMessage] = Field(default_factory=list)  # input and output of the ReWOO Agent
      task: HumanMessage = Field(default_factory=lambda: HumanMessage(content=""))  # the task provided by user
      plan: AIMessage = Field(
          default_factory=lambda: AIMessage(content=""))  # the plan generated by the planner to solve the task
@@ -183,10 +185,12 @@ class ReWOOAgentGraph(BaseAgent):
  if not task:
      logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
      return {"result": NO_INPUT_ERROR_MESSAGE}
-
+ chat_history = self._get_chat_history(state.messages)
  plan = await self._stream_llm(
      planner,
-     {"task": task},
+     {
+         "task": task, "chat_history": chat_history
+     },
      RunnableConfig(callbacks=self.callbacks)  # type: ignore
  )

nat/agent/rewoo_agent/prompt.py CHANGED
@@ -87,6 +87,9 @@ Begin!
  """

  PLANNER_USER_PROMPT = """
+ Previous conversation history:
+ {chat_history}
+
  task: {task}
  """

nat/agent/rewoo_agent/register.py CHANGED
@@ -124,8 +124,9 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
  token_counter=len,
  start_on="human",
  include_system=True)
- task = HumanMessage(content=messages[0].content)
- state = ReWOOGraphState(task=task)
+
+ task = HumanMessage(content=messages[-1].content)
+ state = ReWOOGraphState(messages=messages, task=task)

  # run the ReWOO Agent Graph
  state = await graph.ainvoke(state)
nat/agent/tool_calling_agent/agent.py CHANGED
@@ -15,11 +15,14 @@

  # pylint: disable=R0917
  import logging
+ import typing

  from langchain_core.callbacks.base import AsyncCallbackHandler
  from langchain_core.language_models import BaseChatModel
+ from langchain_core.messages import SystemMessage
  from langchain_core.messages.base import BaseMessage
- from langchain_core.runnables import RunnableConfig
+ from langchain_core.runnables import RunnableLambda
+ from langchain_core.runnables.config import RunnableConfig
  from langchain_core.tools import BaseTool
  from langgraph.prebuilt import ToolNode
  from pydantic import BaseModel
@@ -30,6 +33,9 @@ from nat.agent.base import AGENT_LOG_PREFIX
  from nat.agent.base import AgentDecision
  from nat.agent.dual_node import DualNodeAgent

+ if typing.TYPE_CHECKING:
+     from nat.agent.tool_calling_agent.register import ToolCallAgentWorkflowConfig
+
  logger = logging.getLogger(__name__)


@@ -43,22 +49,51 @@ class ToolCallAgentGraph(DualNodeAgent):
  A tool Calling Agent utilizes the tool input parameters to select the optimal tool. Supports handling tool errors.
  Argument "detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""

- def __init__(self,
-              llm: BaseChatModel,
-              tools: list[BaseTool],
-              callbacks: list[AsyncCallbackHandler] = None,
-              detailed_logs: bool = False,
-              handle_tool_errors: bool = True):
+ def __init__(
+     self,
+     llm: BaseChatModel,
+     tools: list[BaseTool],
+     prompt: str | None = None,
+     callbacks: list[AsyncCallbackHandler] = None,
+     detailed_logs: bool = False,
+     handle_tool_errors: bool = True,
+ ):
      super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
+     # some LLMs support tool calling
+     # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
+     try:
+         # in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
+         self.bound_llm = llm.bind_tools(tools)
+     except NotImplementedError as ex:
+         logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
+         raise ex
+
+     if prompt is not None:
+         system_prompt = SystemMessage(content=prompt)
+         prompt_runnable = RunnableLambda(
+             lambda state: [system_prompt] + state.get("messages", []),
+             name="SystemPrompt",
+         )
+     else:
+         prompt_runnable = RunnableLambda(
+             lambda state: state.get("messages", []),
+             name="PromptPassthrough",
+         )
+
+     self.agent = prompt_runnable | self.bound_llm
+
      self.tool_caller = ToolNode(tools, handle_tool_errors=handle_tool_errors)
      logger.debug("%s Initialized Tool Calling Agent Graph", AGENT_LOG_PREFIX)

  async def agent_node(self, state: ToolCallAgentGraphState):
      try:
-         logger.debug('%s Starting the Tool Calling Agent Node', AGENT_LOG_PREFIX)
+         logger.debug("%s Starting the Tool Calling Agent Node", AGENT_LOG_PREFIX)
          if len(state.messages) == 0:
              raise RuntimeError('No input received in state: "messages"')
-         response = await self.llm.ainvoke(state.messages, config=RunnableConfig(callbacks=self.callbacks))
+         response = await self.agent.ainvoke(
+             {"messages": state.messages},
+             config=RunnableConfig(callbacks=self.callbacks),
+         )
          if self.detailed_logs:
              agent_input = "\n".join(str(message.content) for message in state.messages)
              logger.info(AGENT_CALL_LOG_MESSAGE, agent_input, response)
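The constructor now prepends an optional `SystemMessage` through a `RunnableLambda` piped into the tool-bound LLM. A self-contained sketch of that composition, with a stand-in runnable in place of the bound LLM:

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableLambda

system_prompt = SystemMessage(content="You are terse.")
prompt_runnable = RunnableLambda(
    lambda state: [system_prompt] + state.get("messages", []),
    name="SystemPrompt",
)
# Stand-in for self.bound_llm: just reports what it would receive.
fake_llm = RunnableLambda(lambda msgs: f"got {len(msgs)} messages")

agent = prompt_runnable | fake_llm
print(agent.invoke({"messages": [HumanMessage(content="hi")]}))  # got 2 messages
```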
@@ -75,16 +110,18 @@ class ToolCallAgentGraph(DualNodeAgent):
      last_message = state.messages[-1]
      if last_message.tool_calls:
          # the agent wants to call a tool
-         logger.debug('%s Agent is calling a tool', AGENT_LOG_PREFIX)
+         logger.debug("%s Agent is calling a tool", AGENT_LOG_PREFIX)
          return AgentDecision.TOOL
      if self.detailed_logs:
          logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.messages[-1].content)
      return AgentDecision.END
  except Exception as ex:
-     logger.exception("%s Failed to determine whether agent is calling a tool: %s",
-                      AGENT_LOG_PREFIX,
-                      ex,
-                      exc_info=True)
+     logger.exception(
+         "%s Failed to determine whether agent is calling a tool: %s",
+         AGENT_LOG_PREFIX,
+         ex,
+         exc_info=True,
+     )
      logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
      return AgentDecision.END

@@ -92,14 +129,15 @@ class ToolCallAgentGraph(DualNodeAgent):
  try:
      logger.debug("%s Starting Tool Node", AGENT_LOG_PREFIX)
      tool_calls = state.messages[-1].tool_calls
-     tools = [tool.get('name') for tool in tool_calls]
+     tools = [tool.get("name") for tool in tool_calls]
      tool_input = state.messages[-1]
-     tool_response = await self.tool_caller.ainvoke(input={"messages": [tool_input]},
-                                                    config=RunnableConfig(callbacks=self.callbacks,
-                                                                          configurable={}))
+     tool_response = await self.tool_caller.ainvoke(
+         input={"messages": [tool_input]},
+         config=RunnableConfig(callbacks=self.callbacks, configurable={}),
+     )
      # this configurable = {} argument is needed due to a bug in LangGraph PreBuilt ToolNode ^

-     for response in tool_response.get('messages'):
+     for response in tool_response.get("messages"):
          if self.detailed_logs:
              self._log_tool_response(str(tools), str(tool_input), response.content)
          state.messages += [response]
@@ -112,8 +150,41 @@ class ToolCallAgentGraph(DualNodeAgent):
  async def build_graph(self):
      try:
          await super()._build_graph(state_schema=ToolCallAgentGraphState)
-         logger.debug("%s Tool Calling Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)
+         logger.debug(
+             "%s Tool Calling Agent Graph built and compiled successfully",
+             AGENT_LOG_PREFIX,
+         )
          return self.graph
      except Exception as ex:
-         logger.exception("%s Failed to build Tool Calling Agent Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
+         logger.exception(
+             "%s Failed to build Tool Calling Agent Graph: %s",
+             AGENT_LOG_PREFIX,
+             ex,
+             exc_info=ex,
+         )
          raise ex
+
+
+ def create_tool_calling_agent_prompt(config: "ToolCallAgentWorkflowConfig") -> str | None:
+     """
+     Create a Tool Calling Agent prompt from the config.
+
+     Args:
+         config (ToolCallAgentWorkflowConfig): The config to use for the prompt.
+
+     Returns:
+         str | None: The Tool Calling Agent prompt, or None if no prompt is configured.
+     """
+     # the Tool Calling Agent prompt can be customized via config option system_prompt and additional_instructions.
+
+     if config.system_prompt:
+         prompt_str = config.system_prompt
+     else:
+         prompt_str = ""
+
+     if config.additional_instructions:
+         prompt_str += f" {config.additional_instructions}"
+
+     if len(prompt_str) > 0:
+         return prompt_str
+     return None
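The return contract: `system_prompt` and `additional_instructions` are concatenated with a single space, and `None` comes back when neither is set, so no `SystemMessage` is prepended. A stand-in illustration (`SimpleNamespace` replaces the real config class):

```python
from types import SimpleNamespace

def build_prompt(config) -> str | None:
    # Mirrors create_tool_calling_agent_prompt above.
    prompt_str = config.system_prompt or ""
    if config.additional_instructions:
        prompt_str += f" {config.additional_instructions}"
    return prompt_str or None

cfg = SimpleNamespace(system_prompt="You are a helpful agent.",
                      additional_instructions="Answer in one sentence.")
print(build_prompt(cfg))  # You are a helpful agent. Answer in one sentence.

empty = SimpleNamespace(system_prompt=None, additional_instructions=None)
print(build_prompt(empty))  # None
```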
nat/agent/tool_calling_agent/register.py CHANGED
@@ -41,6 +41,9 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
  handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
  description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
  max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
+ system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
+ additional_instructions: str | None = Field(default=None,
+                                             description="Additional instructions appended to the system prompt.")


  @register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
@@ -49,10 +52,11 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
  from langgraph.graph.graph import CompiledGraph

  from nat.agent.base import AGENT_LOG_PREFIX
+ from nat.agent.tool_calling_agent.agent import ToolCallAgentGraph
+ from nat.agent.tool_calling_agent.agent import ToolCallAgentGraphState
+ from nat.agent.tool_calling_agent.agent import create_tool_calling_agent_prompt

- from .agent import ToolCallAgentGraph
- from .agent import ToolCallAgentGraphState
-
+ prompt = create_tool_calling_agent_prompt(config)
  # we can choose an LLM for the ReAct agent in the config file
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
  # the agent can run any installed tool, simply install the tool and add it to the config file
@@ -61,18 +65,10 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
  if not tools:
      raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")

- # some LLMs support tool calling
- # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
- try:
-     # in tool calling agents, we bind the tools to the LLM, to pass the tools' input schemas at runtime
-     llm = llm.bind_tools(tools)
- except NotImplementedError as ex:
-     logger.error("%s Failed to bind tools: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
-     raise ex
-
  # construct the Tool Calling Agent Graph from the configured llm, and tools
  graph: CompiledGraph = await ToolCallAgentGraph(llm=llm,
                                                  tools=tools,
+                                                 prompt=prompt,
                                                  detailed_logs=config.verbose,
                                                  handle_tool_errors=config.handle_tool_errors).build_graph()

nat/cli/type_registry.py CHANGED
@@ -588,8 +588,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
  except KeyError as err:
      raise KeyError(
          f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
-         "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
-         "there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
+         f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
+         f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
          "Please provide an Embedder configuration from one of the following providers: "
          f"{set(self._embedder_client_provider_to_framework.keys())}") from err

@@ -703,8 +703,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
  except KeyError as err:
      raise KeyError(
          f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
-         "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
-         "there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
+         f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
+         f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
          "Please provide a Retriever configuration from one of the following providers: "
          f"{set(self._retriever_client_provider_to_framework.keys())}") from err

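Both hunks fix the same bug: in an implicitly concatenated string, each piece needs its own `f` prefix, otherwise its braces are emitted literally instead of being interpolated:

```python
wrapper_type = "langchain"
message = ("Wrapper: `{wrapper_type}`. "      # no f prefix: braces kept literally
           f"Framework: {wrapper_type}.")     # f prefix: interpolated
print(message)  # Wrapper: `{wrapper_type}`. Framework: langchain.
```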
nat/embedder/azure_openai_embedder.py ADDED
@@ -0,0 +1,46 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pydantic import AliasChoices
+ from pydantic import ConfigDict
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.builder.embedder import EmbedderProviderInfo
+ from nat.cli.register_workflow import register_embedder_provider
+ from nat.data_models.embedder import EmbedderBaseConfig
+ from nat.data_models.retry_mixin import RetryMixin
+
+
+ class AzureOpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="azure_openai"):
+     """An Azure OpenAI embedder provider to be used with an embedder client."""
+
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+     api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+     api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
+     azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
+                                        serialization_alias="azure_endpoint",
+                                        default=None,
+                                        description="Base URL for the hosted model.")
+     azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
+                                   serialization_alias="azure_deployment",
+                                   description="The Azure OpenAI hosted model/deployment name.")
+
+
+ @register_embedder_provider(config_type=AzureOpenAIEmbedderModelConfig)
+ async def azure_openai_embedder_model(config: AzureOpenAIEmbedderModelConfig, _builder: Builder):
+
+     yield EmbedderProviderInfo(config=config, description="An Azure OpenAI model for use with an Embedder client.")
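The `AliasChoices` fields let a config written for other providers (`base_url`, `model_name`, `model`) validate unchanged, while `serialization_alias` keeps dumps in Azure terms. A self-contained sketch that mirrors the field setup without importing `nat` (the endpoint and deployment values are illustrative):

```python
from pydantic import AliasChoices, BaseModel, ConfigDict, Field

class AzureishConfig(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    azure_endpoint: str | None = Field(default=None,
                                       validation_alias=AliasChoices("azure_endpoint", "base_url"),
                                       serialization_alias="azure_endpoint")
    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                  serialization_alias="azure_deployment")

# "base_url" and "model" are accepted on input...
cfg = AzureishConfig.model_validate({"base_url": "https://example.openai.azure.com", "model": "gpt-4o"})
assert cfg.azure_endpoint == "https://example.openai.azure.com"
assert cfg.azure_deployment == "gpt-4o"
# ...and the dump uses the Azure-specific names.
assert cfg.model_dump(by_alias=True) == {"azure_endpoint": "https://example.openai.azure.com",
                                         "azure_deployment": "gpt-4o"}
```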
nat/embedder/openai_embedder.py CHANGED
@@ -34,10 +34,9 @@ class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
  model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                          serialization_alias="model",
                          description="The OpenAI hosted model name.")
- max_retries: int = Field(default=2, description="The max number of retries for the request.")


  @register_embedder_provider(config_type=OpenAIEmbedderModelConfig)
- async def openai_llm(config: OpenAIEmbedderModelConfig, builder: Builder):
+ async def openai_embedder_model(config: OpenAIEmbedderModelConfig, _builder: Builder):

      yield EmbedderProviderInfo(config=config, description="An OpenAI model for use with an Embedder client.")
nat/embedder/register.py CHANGED
@@ -18,5 +18,6 @@
  # isort:skip_file

  # Import any providers which need to be automatically registered here
+ from . import azure_openai_embedder
  from . import nim_embedder
  from . import openai_embedder
nat/llm/azure_openai_llm.py ADDED
@@ -0,0 +1,50 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pydantic import AliasChoices
+ from pydantic import ConfigDict
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.builder.llm import LLMProviderInfo
+ from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.retry_mixin import RetryMixin
+
+
+ class AzureOpenAIModelConfig(LLMBaseConfig, RetryMixin, name="azure_openai"):
+     """An Azure OpenAI LLM provider to be used with an LLM client."""
+
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+     api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+     api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
+     azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
+                                        serialization_alias="azure_endpoint",
+                                        default=None,
+                                        description="Base URL for the hosted model.")
+     azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
+                                   serialization_alias="azure_deployment",
+                                   description="The Azure OpenAI hosted model/deployment name.")
+     temperature: float = Field(default=0.0, description="Sampling temperature in [0, 1].")
+     top_p: float = Field(default=1.0, description="Top-p for distribution sampling.")
+     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+     max_retries: int = Field(default=10, description="The max number of retries for the request.")
+
+
+ @register_llm_provider(config_type=AzureOpenAIModelConfig)
+ async def azure_openai_llm(config: AzureOpenAIModelConfig, _builder: Builder):
+
+     yield LLMProviderInfo(config=config, description="An Azure OpenAI model for use with an LLM client.")
nat/llm/register.py CHANGED
@@ -19,5 +19,6 @@

  # Import any providers which need to be automatically registered here
  from . import aws_bedrock_llm
+ from . import azure_openai_llm
  from . import nim_llm
  from . import openai_llm
nat/meta/pypi.md CHANGED
@@ -23,19 +23,19 @@ NeMo Agent toolkit is a flexible library designed to seamlessly integrate your e

  ## Key Features

- - [**Framework Agnostic:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/extend/plugins.html) Works with any agentic framework, so you can use your current technology stack without replatforming.
- - [**Reusability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/extend/sharing-components.html) Every agent, tool, or workflow can be combined and repurposed, allowing developers to leverage existing work in new scenarios.
- - [**Rapid Development:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/tutorials/index.html) Start with a pre-built agent, tool, or workflow, and customize it to your needs.
- - [**Profiling:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/profiler.html) Profile entire workflows down to the tool and agent level, track input/output tokens and timings, and identify bottlenecks.
- - [**Observability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/observe/observe-workflow-with-phoenix.html) Monitor and debug your workflows with any OpenTelemetry-compatible observability tool, with examples using [Phoenix](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/observe/observe-workflow-with-phoenix.html) and [W&B Weave](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/observe/observe-workflow-with-weave.html).
- - [**Evaluation System:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/evaluate.html) Validate and maintain accuracy of agentic workflows with built-in evaluation tools.
- - [**User Interface:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/quick-start/launching-ui.html) Use the NeMo Agent toolkit UI chat interface to interact with your agents, visualize output, and debug workflows.
- - [**MCP Compatibility**](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/workflows/mcp/mcp-client.html) Compatible with Model Context Protocol (MCP), allowing tools served by MCP Servers to be used as NeMo Agent toolkit functions.
+ - [**Framework Agnostic:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/extend/plugins.html) Works with any agentic framework, so you can use your current technology stack without replatforming.
+ - [**Reusability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/extend/sharing-components.html) Every agent, tool, or workflow can be combined and repurposed, allowing developers to leverage existing work in new scenarios.
+ - [**Rapid Development:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/tutorials/index.html) Start with a pre-built agent, tool, or workflow, and customize it to your needs.
+ - [**Profiling:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/profiler.html) Profile entire workflows down to the tool and agent level, track input/output tokens and timings, and identify bottlenecks.
+ - [**Observability:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-phoenix.html) Monitor and debug your workflows with any OpenTelemetry-compatible observability tool, with examples using [Phoenix](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-phoenix.html) and [W&B Weave](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/observe/observe-workflow-with-weave.html).
+ - [**Evaluation System:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/evaluate.html) Validate and maintain accuracy of agentic workflows with built-in evaluation tools.
+ - [**User Interface:**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/quick-start/launching-ui.html) Use the NeMo Agent toolkit UI chat interface to interact with your agents, visualize output, and debug workflows.
+ - [**MCP Compatibility**](https://docs.nvidia.com/nemo/agent-toolkit/1.2/workflows/mcp/mcp-client.html) Compatible with Model Context Protocol (MCP), allowing tools served by MCP Servers to be used as NeMo Agent toolkit functions.

  With NeMo Agent toolkit, you can move quickly, experiment freely, and ensure reliability across all your agent-driven projects.

  ## Links
- * [Documentation](https://docs.nvidia.com/nemo/agent-toolkit/1.2.0/index.html): Explore the full documentation for NeMo Agent toolkit.
+ * [Documentation](https://docs.nvidia.com/nemo/agent-toolkit/1.2/index.html): Explore the full documentation for NeMo Agent toolkit.

  ## First time user?
  If this is your first time using NeMo Agent toolkit, it is recommended to install the latest version from the [source repository](https://github.com/NVIDIA/NeMo-Agent-Toolkit?tab=readme-ov-file#quick-start) on GitHub. This package is intended for users who are familiar with NeMo Agent toolkit applications and need to add NeMo Agent toolkit as a dependency to their project.