nvidia-nat 1.4.0a20251015__py3-none-any.whl → 1.4.0a20251021__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +3 -3
- nat/agent/reasoning_agent/reasoning_agent.py +6 -6
- nat/agent/register.py +1 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/tool_calling_agent/agent.py +6 -10
- nat/builder/context.py +2 -1
- nat/builder/intermediate_step_manager.py +6 -2
- nat/data_models/api_server.py +83 -33
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +2 -1
- nat/data_models/thinking_mixin.py +2 -2
- nat/eval/evaluate.py +2 -0
- nat/eval/usage_stats.py +2 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -0
- nat/front_ends/fastapi/message_handler.py +65 -40
- nat/front_ends/fastapi/message_validator.py +1 -2
- nat/front_ends/mcp/mcp_front_end_config.py +32 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +9 -6
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/litellm_llm.py +6 -3
- nat/llm/nim_llm.py +3 -3
- nat/llm/openai_llm.py +4 -3
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/utils/exception_handlers/automatic_retries.py +205 -54
- nat/utils/responses_api.py +26 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/METADATA +4 -4
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/RECORD +37 -33
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251015.dist-info → nvidia_nat-1.4.0a20251021.dist-info}/top_level.txt +0 -0
nat/agent/base.py
CHANGED
@@ -102,11 +102,11 @@ class BaseAgent(ABC):
         AIMessage
             The LLM response
         """
-        output_message =
+        output_message = []
         async for event in runnable.astream(inputs, config=config):
-            output_message
+            output_message.append(event.content)

-        return AIMessage(content=output_message)
+        return AIMessage(content="".join(output_message))

     async def _call_llm(self, llm: Runnable, inputs: dict[str, Any], config: RunnableConfig | None = None) -> AIMessage:
         """
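The change above replaces per-chunk string concatenation with list accumulation and a single join while consuming the LLM's async stream. A minimal standalone sketch of the same idiom; the stream below is a stand-in for `runnable.astream(...)`, not the toolkit's API:

```python
import asyncio

async def fake_stream():
    # Stand-in for runnable.astream(...): yields content chunks as they arrive.
    for piece in ("Hello", ", ", "world"):
        yield piece

async def collect() -> str:
    chunks: list[str] = []
    async for piece in fake_stream():
        chunks.append(piece)   # O(1) append per chunk
    return "".join(chunks)     # one join at the end instead of repeated +=

print(asyncio.run(collect()))  # prints: Hello, world
```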
nat/agent/reasoning_agent/reasoning_agent.py
CHANGED
@@ -157,12 +157,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
         prompt = prompt.to_string()

         # Get the reasoning output from the LLM
-        reasoning_output =
+        reasoning_output = []

         async for chunk in llm.astream(prompt):
-            reasoning_output
+            reasoning_output.append(chunk.content)

-        reasoning_output = remove_r1_think_tags(reasoning_output)
+        reasoning_output = remove_r1_think_tags("".join(reasoning_output))

         output = await downstream_template.ainvoke(input={
             "input_text": input_text, "reasoning_output": reasoning_output
@@ -200,12 +200,12 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
         prompt = prompt.to_string()

         # Get the reasoning output from the LLM
-        reasoning_output =
+        reasoning_output = []

         async for chunk in llm.astream(prompt):
-            reasoning_output
+            reasoning_output.append(chunk.content)

-        reasoning_output = remove_r1_think_tags(reasoning_output)
+        reasoning_output = remove_r1_think_tags("".join(reasoning_output))

         output = await downstream_template.ainvoke(input={
             "input_text": input_text, "reasoning_output": reasoning_output
nat/agent/register.py
CHANGED
@@ -19,5 +19,6 @@
 from .prompt_optimizer import register as prompt_optimizer
 from .react_agent import register as react_agent
 from .reasoning_agent import reasoning_agent
+from .responses_api_agent import register as responses_api_agent
 from .rewoo_agent import register as rewoo_agent
 from .tool_calling_agent import register as tool_calling_agent
nat/agent/responses_api_agent/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/agent/responses_api_agent/register.py
ADDED
@@ -0,0 +1,126 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import typing
+
+from pydantic import Field
+
+from nat.agent.base import AGENT_LOG_PREFIX
+from nat.builder.builder import Builder
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.builder.function_info import FunctionInfo
+from nat.cli.register_workflow import register_function
+from nat.data_models.component_ref import FunctionRef
+from nat.data_models.component_ref import LLMRef
+from nat.data_models.function import FunctionBaseConfig
+from nat.data_models.openai_mcp import OpenAIMCPSchemaTool
+
+logger = logging.getLogger(__name__)
+
+
+class ResponsesAPIAgentWorkflowConfig(FunctionBaseConfig, name="responses_api_agent"):
+    """
+    Defines an NeMo Agent Toolkit function that uses a Responses API
+    Agent performs reasoning inbetween tool calls, and utilizes the
+    tool names and descriptions to select the optimal tool.
+    """
+
+    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
+    verbose: bool = Field(default=False, description="Set the verbosity of the agent's logging.")
+    nat_tools: list[FunctionRef] = Field(default_factory=list, description="The list of tools to provide to the agent.")
+    mcp_tools: list[OpenAIMCPSchemaTool] = Field(
+        default_factory=list,
+        description="List of MCP tools to use with the agent. If empty, no MCP tools will be used.")
+    builtin_tools: list[dict[str, typing.Any]] = Field(
+        default_factory=list,
+        description="List of built-in tools to use with the agent. If empty, no built-in tools will be used.")
+
+    max_iterations: int = Field(default=15, description="Number of tool calls before stoping the agent.")
+    description: str = Field(default="Agent Workflow", description="The description of this functions use.")
+    parallel_tool_calls: bool = Field(default=False,
+                                      description="Specify whether to allow parallel tool calls in the agent.")
+    handle_tool_errors: bool = Field(
+        default=True,
+        description="Specify ability to handle tool calling errors. If False, tool errors will raise an exception.")
+
+
+@register_function(config_type=ResponsesAPIAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
+async def responses_api_agent_workflow(config: ResponsesAPIAgentWorkflowConfig, builder: Builder):
+    from langchain_core.messages.human import HumanMessage
+    from langchain_core.runnables import Runnable
+    from langchain_openai import ChatOpenAI
+
+    from nat.agent.tool_calling_agent.agent import ToolCallAgentGraph
+    from nat.agent.tool_calling_agent.agent import ToolCallAgentGraphState
+
+    llm: ChatOpenAI = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    assert llm.use_responses_api, "Responses API Agent requires an LLM that supports the Responses API."
+
+    # Get tools
+    tools = []
+    nat_tools = await builder.get_tools(tool_names=config.nat_tools, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    tools.extend(nat_tools)
+    # MCP tools are optional, if provided they will be used by the agent
+    tools.extend([m.model_dump() for m in config.mcp_tools])
+    # Built-in tools are optional, if provided they will be used by the agent
+    tools.extend(config.builtin_tools)
+
+    # Bind tools to LLM
+    if tools:
+        llm: Runnable = llm.bind_tools(tools=tools, parallel_tool_calls=config.parallel_tool_calls, strict=True)
+
+    if config.verbose:
+        logger.info("%s Using LLM: %s with tools: %s", AGENT_LOG_PREFIX, llm.model_name, tools)
+
+    agent = ToolCallAgentGraph(
+        llm=llm,
+        tools=nat_tools,  # MCP and built-in tools are already bound to the LLM and need not be handled by graph
+        detailed_logs=config.verbose,
+        handle_tool_errors=config.handle_tool_errors)
+
+    graph = await agent.build_graph()
+
+    async def _response_fn(input_message: str) -> str:
+        try:
+            # initialize the starting state with the user query
+            input_message = HumanMessage(content=input_message)
+            state = ToolCallAgentGraphState(messages=[input_message])
+
+            # run the Tool Calling Agent Graph
+            state = await graph.ainvoke(state, config={'recursion_limit': (config.max_iterations + 1) * 2})
+            # setting recursion_limit: 4 allows 1 tool call
+            # - allows the Tool Calling Agent to perform 1 cycle / call 1 single tool,
+            # - but stops the agent when it tries to call a tool a second time
+
+            # get and return the output from the state
+            state = ToolCallAgentGraphState(**state)
+            output_message = state.messages[-1]  # pylint: disable=E1136
+            content = output_message.content[-1]['text'] if output_message.content and isinstance(
+                output_message.content[-1], dict) and 'text' in output_message.content[-1] else str(
+                    output_message.content)
+            return content
+        except Exception as ex:
+            logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
+            if config.verbose:
+                return str(ex)
+            return "I seem to be having a problem."
+
+    try:
+        yield FunctionInfo.from_fn(_response_fn, description=config.description)
+    except GeneratorExit:
+        logger.exception("%s Workflow exited early!", AGENT_LOG_PREFIX, exc_info=True)
+    finally:
+        logger.debug("%s Cleaning up react_agent workflow.", AGENT_LOG_PREFIX)
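For orientation only, a sketch of how the new function might be configured programmatically. The import paths, class, and field names come from the new files above; the LLM reference, tool names, server label, and server URL are hypothetical placeholders:

```python
# Illustrative sketch, not part of the diff. All concrete values are placeholders.
from nat.agent.responses_api_agent.register import ResponsesAPIAgentWorkflowConfig
from nat.data_models.openai_mcp import OpenAIMCPSchemaTool

config = ResponsesAPIAgentWorkflowConfig(
    llm_name="responses_llm",    # must resolve to an LLM configured for the Responses API
    nat_tools=["web_search"],    # registered NAT functions exposed to the agent
    mcp_tools=[
        OpenAIMCPSchemaTool(
            server_label="docs",
            server_url="https://example.com/mcp",
            allowed_tools=["lookup"],
        )
    ],
    max_iterations=5,
    verbose=True,
)
```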
nat/agent/tool_calling_agent/agent.py
CHANGED
@@ -233,14 +233,10 @@ def create_tool_calling_agent_prompt(config: "ToolCallAgentWorkflowConfig") -> s
     """
     # the Tool Calling Agent prompt can be customized via config option system_prompt and additional_instructions.

-
-
-
-
-
-
-        prompt_str += f" {config.additional_instructions}"
-
-    if len(prompt_str) > 0:
-        return prompt_str
+    prompt_strs = []
+    for msg in [config.system_prompt, config.additional_instructions]:
+        if msg is not None:
+            prompt_strs.append(msg)
+    if prompt_strs:
+        return " ".join(prompt_strs)
     return None
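A minimal sketch of the reworked assembly logic above: the two optional prompt pieces are collected and joined with a single space, and None is returned when both are absent. The function name and values below are illustrative, not the toolkit's API:

```python
def assemble_prompt(system_prompt: str | None, additional_instructions: str | None) -> str | None:
    # Mirrors the new create_tool_calling_agent_prompt flow in simplified form.
    prompt_strs = [msg for msg in (system_prompt, additional_instructions) if msg is not None]
    return " ".join(prompt_strs) if prompt_strs else None

assert assemble_prompt("You are a helpful agent.", "Keep answers short.") == \
    "You are a helpful agent. Keep answers short."
assert assemble_prompt(None, None) is None
```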
nat/builder/context.py
CHANGED
@@ -19,6 +19,7 @@ from collections.abc import Awaitable
 from collections.abc import Callable
 from contextlib import contextmanager
 from contextvars import ContextVar
+from functools import cached_property

 from nat.builder.intermediate_step_manager import IntermediateStepManager
 from nat.builder.user_interaction_manager import UserInteractionManager
@@ -167,7 +168,7 @@ class Context:
         """
         return UserInteractionManager(self._context_state)

-    @
+    @cached_property
     def intermediate_step_manager(self) -> IntermediateStepManager:
         """
         Retrieves the intermediate step manager instance from the current context state.
nat/builder/intermediate_step_manager.py
CHANGED
@@ -101,7 +101,10 @@ class IntermediateStepManager:
         open_step = self._outstanding_start_steps.pop(payload.UUID, None)

         if (open_step is None):
-            logger.warning(
+            logger.warning(
+                "Step id %s not found in outstanding start steps. "
+                "This may occur if the step was started in a different context or already completed.",
+                payload.UUID)
             return

         parent_step_id = open_step.step_parent_id
@@ -157,7 +160,8 @@ class IntermediateStepManager:
         if (open_step is None):
             logger.warning(
                 "Created a chunk for step %s, but no matching start step was found. "
-                "Chunks must be created with the same ID as the start step."
+                "Chunks must be created with the same ID as the start step. "
+                "This may occur if the step was started in a different context.",
                 payload.UUID)
             return

nat/data_models/api_server.py
CHANGED
@@ -121,7 +121,15 @@ class Message(BaseModel):
     role: UserMessageContentRoleType


-class
+class ChatRequest(BaseModel):
+    """
+    ChatRequest is a data model that represents a request to the NAT chat API.
+    Fully compatible with OpenAI Chat Completions API specification.
+    """
+
+    # Required fields
+    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
+
     # Optional fields (OpenAI Chat Completions API compatible)
     model: str | None = Field(default=None, description="name of the model to use")
     frequency_penalty: float | None = Field(default=0.0,
@@ -145,17 +153,6 @@ class ChatRequestOptionals(BaseModel):
     tool_choice: str | dict[str, typing.Any] | None = Field(default=None, description="Controls which tool is called")
     parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
     user: str | None = Field(default=None, description="Unique identifier representing end-user")
-
-
-class ChatRequest(ChatRequestOptionals):
-    """
-    ChatRequest is a data model that represents a request to the NAT chat API.
-    Fully compatible with OpenAI Chat Completions API specification.
-    """
-
-    # Required fields
-    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
-
     model_config = ConfigDict(extra="allow",
                               json_schema_extra={
                                   "example": {
@@ -197,39 +194,82 @@ class ChatRequest(ChatRequestOptionals):
                              top_p=top_p)


-class ChatRequestOrMessage(
+class ChatRequestOrMessage(BaseModel):
     """
-    ChatRequestOrMessage is a data model that represents either a conversation or a string input.
+    `ChatRequestOrMessage` is a data model that represents either a conversation or a string input.
     This is useful for functions that can handle either type of input.

-    `messages` is compatible with the OpenAI Chat Completions API specification.
-
-
-
+    - `messages` is compatible with the OpenAI Chat Completions API specification.
+    - `input_message` is a string input that can be used for functions that do not require a conversation.
+
+    Note: When `messages` is provided, extra fields are allowed to enable lossless round-trip
+    conversion with ChatRequest. When `input_message` is provided, no extra fields are permitted.
+    """
+    model_config = ConfigDict(
+        extra="allow",
+        json_schema_extra={
+            "examples": [
+                {
+                    "input_message": "What can you do?"
+                },
+                {
+                    "messages": [{
+                        "role": "user", "content": "What can you do?"
+                    }],
+                    "model": "nvidia/nemotron",
+                    "temperature": 0.7
+                },
+            ],
+            "oneOf": [
+                {
+                    "required": ["input_message"],
+                    "properties": {
+                        "input_message": {
+                            "type": "string"
+                        },
+                    },
+                    "additionalProperties": {
+                        "not": True, "errorMessage": 'remove additional property ${0#}'
+                    },
+                },
+                {
+                    "required": ["messages"],
+                    "properties": {
+                        "messages": {
+                            "type": "array"
+                        },
+                    },
+                    "additionalProperties": True
+                },
+            ]
+        },
+    )

     messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
-        default=None, description="
+        default=None, description="A non-empty conversation of messages to process.")

-
+    input_message: str | None = Field(
+        default=None,
+        description="A single input message to process. Useful for functions that do not require a conversation")

     @property
     def is_string(self) -> bool:
-        return self.
+        return self.input_message is not None

     @property
     def is_conversation(self) -> bool:
         return self.messages is not None

     @model_validator(mode="after")
-    def
+    def validate_model(self):
-        if self.messages is not None and self.
+        if self.messages is not None and self.input_message is not None:
-            raise ValueError("Either messages or input_message
+            raise ValueError("Either messages or input_message must be provided, not both")
-        if self.messages is None and self.
+        if self.messages is None and self.input_message is None:
-            raise ValueError("Either messages or input_message
+            raise ValueError("Either messages or input_message must be provided")
-        if self.
+        if self.input_message is not None:
-            extra_fields = self.model_dump(exclude={"
+            extra_fields = self.model_dump(exclude={"input_message"}, exclude_none=True, exclude_unset=True)
             if len(extra_fields) > 0:
-                raise ValueError("no extra fields are permitted when input_message
+                raise ValueError("no extra fields are permitted when input_message is provided")
         return self


@@ -701,9 +741,9 @@ GlobalTypeConverter.register_converter(_string_to_nat_chat_request)


 def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
-    if data.
+    if data.input_message is not None:
-        return _string_to_nat_chat_request(data.
+        return _string_to_nat_chat_request(data.input_message)
-    return ChatRequest(**data.model_dump(exclude={"
+    return ChatRequest(**data.model_dump(exclude={"input_message"}))


 GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
@@ -717,7 +757,17 @@ GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)


 def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
-
+    if data.input_message is not None:
+        return data.input_message
+    # Extract content from last message in conversation
+    if data.messages is None:
+        return ""
+    content = data.messages[-1].content
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    return str(content)


 GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
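A brief sketch of how the reshaped `ChatRequestOrMessage` behaves. The import path follows the file above; the message content and model name are placeholder values:

```python
from nat.data_models.api_server import ChatRequestOrMessage

# String form: no extra fields may accompany input_message.
simple = ChatRequestOrMessage(input_message="What can you do?")
assert simple.is_string and not simple.is_conversation

# Conversation form: extra OpenAI-style fields (model, temperature, ...) are kept
# so the payload can round-trip losslessly into ChatRequest.
chat = ChatRequestOrMessage(
    messages=[{"role": "user", "content": "What can you do?"}],
    model="nvidia/nemotron",
    temperature=0.7,
)
assert chat.is_conversation
```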
nat/data_models/intermediate_step.py
CHANGED
@@ -103,11 +103,19 @@ class ToolSchema(BaseModel):
     function: ToolDetails = Field(..., description="The function details.")


+class ServerToolUseSchema(BaseModel):
+    name: str
+    arguments: str | dict[str, typing.Any] | typing.Any
+    output: typing.Any
+
+    model_config = ConfigDict(extra="ignore")
+
+
 class TraceMetadata(BaseModel):
     chat_responses: typing.Any | None = None
     chat_inputs: typing.Any | None = None
     tool_inputs: typing.Any | None = None
-    tool_outputs: typing.Any | None = None
+    tool_outputs: list[ServerToolUseSchema] | typing.Any | None = None
     tool_info: typing.Any | None = None
     span_inputs: typing.Any | None = None
     span_outputs: typing.Any | None = None
nat/data_models/llm.py
CHANGED
@@ -14,14 +14,28 @@
 # limitations under the License.

 import typing
+from enum import Enum
+
+from pydantic import Field

 from .common import BaseModelRegistryTag
 from .common import TypedBaseModel


+class APITypeEnum(str, Enum):
+    CHAT_COMPLETION = "chat_completion"
+    RESPONSES = "responses"
+
+
 class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag):
     """Base configuration for LLM providers."""
-
+
+    api_type: APITypeEnum = Field(default=APITypeEnum.CHAT_COMPLETION,
+                                  description="The type of API to use for the LLM provider.",
+                                  json_schema_extra={
+                                      "enum": [e.value for e in APITypeEnum],
+                                      "examples": [e.value for e in APITypeEnum],
+                                  })


 LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig)
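A small illustrative check of the new enum; the values mirror the definitions above, and the comment restates the field default added to LLMBaseConfig:

```python
from nat.data_models.llm import APITypeEnum

# Every LLM provider config now carries api_type, defaulting to chat_completion;
# "responses" selects the OpenAI Responses API style.
assert APITypeEnum.CHAT_COMPLETION.value == "chat_completion"
assert APITypeEnum.RESPONSES.value == "responses"
```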
nat/data_models/openai_mcp.py
ADDED
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+
+
+class MCPApprovalRequiredEnum(str, Enum):
+    """
+    Enum to specify if approval is required for tool usage in the OpenAI MCP schema.
+    """
+    NEVER = "never"
+    ALWAYS = "always"
+    AUTO = "auto"
+
+
+class OpenAIMCPSchemaTool(BaseModel):
+    """
+    Represents a tool in the OpenAI MCP schema.
+    """
+    type: str = "mcp"
+    server_label: str = Field(description="Label for the server where the tool is hosted.")
+    server_url: str = Field(description="URL of the server hosting the tool.")
+    allowed_tools: list[str] | None = Field(default=None,
+                                            description="List of allowed tool names that can be used by the agent.")
+    require_approval: MCPApprovalRequiredEnum = Field(default=MCPApprovalRequiredEnum.NEVER,
+                                                      description="Specifies if approval is required for tool usage.")
+    headers: dict[str, str] | None = Field(default=None,
+                                           description="Optional headers to include in requests to the tool server.")
+
+    model_config = ConfigDict(use_enum_values=True)
nat/data_models/optimizable.py
CHANGED
@@ -23,6 +23,7 @@ from pydantic import BaseModel
 from pydantic import ConfigDict
 from pydantic import Field
 from pydantic import model_validator
+from pydantic_core import PydanticUndefined

 T = TypeVar("T", int, float, bool, str)

@@ -66,7 +67,7 @@ class SearchSpace(BaseModel, Generic[T]):


 def OptimizableField(
-    default: Any,
+    default: Any = PydanticUndefined,
     *,
     space: SearchSpace | None = None,
     merge_conflict: str = "overwrite",
nat/data_models/thinking_mixin.py
CHANGED
@@ -51,7 +51,7 @@ class ThinkingMixin(
         Returns the system prompt to use for thinking.
         For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
         For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
-        For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
+        For Llama Nemotron v1.0 or v1.1, returns "detailed thinking on" if enabled, else "detailed thinking off".
         If thinking is not supported on the model, returns None.

         Returns:
@@ -72,7 +72,7 @@ class ThinkingMixin(
             return "/think" if self.thinking else "/no_think"

         if model.startswith("nvidia/llama"):
-            if "v1-0" in model or "v1-1" in model:
+            if "v1-0" in model or "v1-1" in model or model.endswith("v1"):
                 return f"detailed thinking {'on' if self.thinking else 'off'}"

         if "v1-5" in model:
nat/eval/evaluate.py
CHANGED
@@ -104,6 +104,8 @@ class EvaluationRun:
                 usage_stats_per_llm[llm_name].prompt_tokens += step.token_usage.prompt_tokens
                 usage_stats_per_llm[llm_name].completion_tokens += step.token_usage.completion_tokens
                 usage_stats_per_llm[llm_name].total_tokens += step.token_usage.total_tokens
+                usage_stats_per_llm[llm_name].reasoning_tokens += step.token_usage.reasoning_tokens
+                usage_stats_per_llm[llm_name].cached_tokens += step.token_usage.cached_tokens
                 total_tokens += step.token_usage.total_tokens

                 # find min and max event timestamps
nat/eval/usage_stats.py
CHANGED
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py
CHANGED
@@ -1184,6 +1184,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 "server": client.server_name,
                 "transport": config.server.transport,
                 "session_healthy": session_healthy,
+                "protected": True if config.server.auth_provider is not None else False,
                 "tools": tools_info,
                 "total_tools": len(configured_short_names),
                 "available_tools": available_count
@@ -1196,6 +1197,7 @@
                 "server": "unknown",
                 "transport": config.server.transport if config.server else "unknown",
                 "session_healthy": False,
+                "protected": False,
                 "error": str(e),
                 "tools": [],
                 "total_tools": 0,
@@ -1226,6 +1228,7 @@
                 "server": "streamable-http:http://localhost:9901/mcp",
                 "transport": "streamable-http",
                 "session_healthy": True,
+                "protected": False,
                 "tools": [{
                     "name": "tool_a",
                     "description": "Tool A description",
|