nvidia-nat 1.3.0rc2__py3-none-any.whl → 1.3.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +15 -24
- nat/agent/rewoo_agent/register.py +15 -24
- nat/agent/tool_calling_agent/register.py +9 -5
- nat/builder/component_utils.py +1 -1
- nat/builder/function.py +4 -4
- nat/builder/workflow_builder.py +46 -3
- nat/cli/entrypoint.py +9 -1
- nat/data_models/api_server.py +120 -1
- nat/data_models/config.py +1 -1
- nat/data_models/thinking_mixin.py +2 -2
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/message_handler.py +65 -40
- nat/front_ends/fastapi/message_validator.py +1 -2
- nat/front_ends/mcp/mcp_front_end_config.py +32 -0
- nat/observability/register.py +16 -0
- nat/runtime/runner.py +1 -2
- nat/runtime/session.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +3 -3
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.3.0rc4.dist-info/METADATA +195 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/RECORD +31 -31
- nvidia_nat-1.3.0rc2.dist-info/METADATA +0 -389
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc4.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/register.py
CHANGED

@@ -24,6 +24,7 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.api_server import Usage
 from nat.data_models.component_ref import FunctionGroupRef

@@ -70,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
         default=None,
         description="Provides the SYSTEM_PROMPT to use with the agent")  # defaults to SYSTEM_PROMPT in prompt.py
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
-    use_openai_api: bool = Field(default=False,
-                                 description=("Use OpenAI API for the input/output types to the function. "
-                                              "If False, strings will be used."))
     additional_instructions: str | None = OptimizableField(
         default=None,
         description="Additional instructions to provide to the agent in addition to the base prompt.",

@@ -118,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
         pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
         normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()

-    async def _response_fn(…
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
         """
         Main workflow entry function for the ReAct Agent.

         This function invokes the ReAct Agent Graph and returns the response.

         Args:
-            …
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
-            ChatResponse: The response from the agent or error message
+            ChatResponse | str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in …
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -153,25 +153,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
             content = str(output_message.content)

             # Create usage statistics for the response
-            prompt_tokens = sum(len(str(msg.content).split()) for msg in …
+            prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
             completion_tokens = len(content.split()) if content else 0
             total_tokens = prompt_tokens + completion_tokens
             usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
-            …
-            …
+            response = ChatResponse.from_string(content, usage=usage)
+            if chat_request_or_message.is_string:
+                return GlobalTypeConverter.get().convert(response, to_type=str)
+            return response
         except Exception as ex:
-            logger.…
-            raise
-
-    if (config.use_openai_api):
-        yield FunctionInfo.from_fn(_response_fn, description=config.description)
-    else:
-
-        async def _str_api_fn(input_message: str) -> str:
-            oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
-
-            oai_output = await _response_fn(oai_input)
-
-            return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
+            logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
+            raise

-
+    yield FunctionInfo.from_fn(_response_fn, description=config.description)
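Note: the same refactor is applied to all three agent registrations in this diff. The `use_openai_api` flag and the separate `_str_api_fn` wrapper are replaced by a single entry point that accepts a `ChatRequestOrMessage` and answers in the caller's own type. A minimal self-contained sketch of that dispatch pattern (toy types standing in for NAT's models; not the actual package code):

    from dataclasses import dataclass

    @dataclass
    class EitherInput:
        # Toy stand-in for ChatRequestOrMessage: exactly one field is set.
        input_message: str | None = None
        messages: list[dict] | None = None

        @property
        def is_string(self) -> bool:
            return self.input_message is not None

    def respond(req: EitherInput) -> str | list[dict]:
        # Normalize to a conversation, run the "agent", then answer in kind.
        convo = req.messages or [{"role": "user", "content": req.input_message}]
        answer = f"echo: {convo[-1]['content']}"
        # Callers that sent a plain string get a plain string back.
        return answer if req.is_string else convo + [{"role": "assistant", "content": answer}]

    print(respond(EitherInput(input_message="hi")))  # 'echo: hi'
    print(respond(EitherInput(messages=[{"role": "user", "content": "hi"}])))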
nat/agent/rewoo_agent/register.py
CHANGED

@@ -25,6 +25,7 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.api_server import Usage
 from nat.data_models.component_ref import FunctionGroupRef

@@ -54,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
         description="The number of retries before raising a tool call error.",
         ge=1)
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
-    use_openai_api: bool = Field(default=False,
-                                 description=("Use OpenAI API for the input/output types to the function. "
-                                              "If False, strings will be used."))
     additional_planner_instructions: str | None = Field(
         default=None,
         validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),

@@ -125,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
         tool_call_max_retries=config.tool_call_max_retries,
         raise_tool_call_error=config.raise_tool_call_error).build_graph()

-    async def _response_fn(…
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
         """
         Main workflow entry function for the ReWOO Agent.

         This function invokes the ReWOO Agent Graph and returns the response.

         Args:
-            …
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
-            ChatResponse: The response from the agent or error message
+            ChatResponse | str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in …
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -160,25 +160,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
             output_message = str(output_message)

             # Create usage statistics for the response
-            prompt_tokens = sum(len(str(msg.content).split()) for msg in …
+            prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
             completion_tokens = len(output_message.split()) if output_message else 0
             total_tokens = prompt_tokens + completion_tokens
             usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
-            …
-            …
+            response = ChatResponse.from_string(output_message, usage=usage)
+            if chat_request_or_message.is_string:
+                return GlobalTypeConverter.get().convert(response, to_type=str)
+            return response
         except Exception as ex:
-            logger.…
-            raise
-
-    if (config.use_openai_api):
-        yield FunctionInfo.from_fn(_response_fn, description=config.description)
-
-    else:
-
-        async def _str_api_fn(input_message: str) -> str:
-            oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
-            oai_output = await _response_fn(oai_input)
-
-            return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
+            logger.error("ReWOO Agent failed with exception: %s", ex)
+            raise

-
+    yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/tool_calling_agent/register.py
CHANGED

@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.component_ref import FunctionGroupRef
 from nat.data_models.component_ref import FunctionRef
+from nat.utils.type_converter import GlobalTypeConverter

 logger = logging.getLogger(__name__)

@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
         handle_tool_errors=config.handle_tool_errors,
         return_direct=return_direct_tools).build_graph()

-    async def _response_fn(…
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
         """
         Main workflow entry function for the Tool Calling Agent.

         This function invokes the Tool Calling Agent Graph and returns the response.

         Args:
-            …
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
             str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in …
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -114,8 +118,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
             output_message = state.messages[-1]
             return str(output_message.content)
         except Exception as ex:
-            logger.…
-            raise
+            logger.error("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
+            raise

     try:
         yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/builder/component_utils.py
CHANGED

@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
         for v in value.values():
             yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
     elif (issubclass(type(value), BaseModel)):
-        for field, field_info in value.model_fields.items():
+        for field, field_info in type(value).model_fields.items():
             field_data = getattr(value, field)
             yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
     if (decomposed_type.is_union):
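Context for the one-line change above: `model_fields` is class-level metadata in Pydantic 2, and recent releases deprecate reading it through an instance, so the lookup moves to `type(value)`. A small illustration, assuming Pydantic v2 is installed:

    from pydantic import BaseModel

    class Point(BaseModel):
        x: int
        y: int = 0

    p = Point(x=1)
    # Class-level access is the supported spelling; instance access
    # (p.model_fields) emits a deprecation warning on recent Pydantic 2 releases.
    for name, info in type(p).model_fields.items():
        print(name, info.annotation, getattr(p, name))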
nat/builder/function.py
CHANGED

@@ -159,8 +159,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):

             return result
         except Exception as e:
-            …
-            logger.error("Error with ainvoke in function with input: %s. %s", value, err_msg)
+            logger.error("Error with ainvoke in function with input: %s. Error: %s", value, e)
             raise

     @typing.final

@@ -416,8 +415,9 @@ class FunctionGroup:
         """
         if not name.strip():
             raise ValueError("Function name cannot be empty or blank")
-        if not re.match(r"^[a-zA-Z0-9_…
-            raise ValueError(…
+        if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
+            raise ValueError(
+                f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
         if name in self._functions:
             raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
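The repaired validation above tightens the allowed character set for function names. A quick standalone check of what the new pattern accepts and rejects:

    import re

    NAME_RE = re.compile(r"^[a-zA-Z0-9_.-]+$")

    for name in ("search_tool", "my.tool-v2", "bad name", "", "tool/x"):
        print(repr(name), bool(NAME_RE.match(name)))
    # Only 'search_tool' and 'my.tool-v2' pass; spaces, empty strings,
    # and slashes are rejected.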
nat/builder/workflow_builder.py
CHANGED

@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
         self._registry = registry

         self._logging_handlers: dict[str, logging.Handler] = {}
+        self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
         self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}

         self._functions: dict[str, ConfiguredFunction] = {}

@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
         # Get the telemetry info from the config
         telemetry_config = self.general_config.telemetry

+        # If we have logging configuration, we need to manage the root logger properly
+        root_logger = logging.getLogger()
+
+        # Collect configured handler types to determine if we need to adjust existing handlers
+        # This is somewhat of a hack by inspecting the class name of the config object
+        has_console_handler = any(
+            hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
+            for config in telemetry_config.logging.values())
+
         for key, logging_config in telemetry_config.logging.items():
             # Use the same pattern as tracing, but for logging
             logging_info = self._registry.get_logging_method(type(logging_config))

@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
             self._logging_handlers[key] = handler

             # Now attach to NAT's root logger
-            …
+            root_logger.addHandler(handler)
+
+        # If we added logging handlers, manage existing handlers appropriately
+        if self._logging_handlers:
+            min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
+
+            # Ensure the root logger level allows messages through
+            root_logger.level = max(root_logger.level, min_handler_level)
+
+            # If a console handler is configured, adjust or remove default CLI handlers
+            # to avoid duplicate output while preserving workflow visibility
+            if has_console_handler:
+                # Remove existing StreamHandlers that are not the newly configured ones
+                for handler in root_logger.handlers[:]:
+                    if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
+                        self._removed_root_handlers.append((handler, handler.level))
+                        root_logger.removeHandler(handler)
+            else:
+                # No console handler configured, but adjust existing handler levels
+                # to respect the minimum configured level for file/other handlers
+                for handler in root_logger.handlers[:]:
+                    if type(handler) is logging.StreamHandler:
+                        old_level = handler.level
+                        handler.setLevel(min_handler_level)
+                        self._removed_root_handlers.append((handler, old_level))

         # Add the telemetry exporters
         for key, telemetry_exporter_config in telemetry_config.tracing.items():

@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):

         assert self._exit_stack is not None, "Exit stack not initialized"

-        …
-        …
+        root_logger = logging.getLogger()
+
+        # Remove custom logging handlers
+        for handler in self._logging_handlers.values():
+            root_logger.removeHandler(handler)
+
+        # Restore original handlers and their levels
+        for handler, old_level in self._removed_root_handlers:
+            if handler not in root_logger.handlers:
+                root_logger.addHandler(handler)
+            handler.setLevel(old_level)

         await self._exit_stack.__aexit__(*exc_details)
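The builder now records every root-logger handler it removes or re-levels so that `__aexit__` can undo the changes. The save/restore idea in isolation, as a minimal sketch using only the standard library (not the NAT builder itself):

    import logging

    root = logging.getLogger()
    saved: list[tuple[logging.Handler, int]] = []

    # Take over console output: remember each StreamHandler and its level,
    # then detach it while our own handler is installed.
    for h in root.handlers[:]:
        if type(h) is logging.StreamHandler:
            saved.append((h, h.level))
            root.removeHandler(h)

    try:
        ...  # run with custom handlers attached
    finally:
        # Put everything back exactly as it was.
        for h, level in saved:
            if h not in root.handlers:
                root.addHandler(h)
            h.setLevel(level)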
nat/cli/entrypoint.py
CHANGED

@@ -29,6 +29,7 @@ import time

 import click
 import nest_asyncio
+from dotenv import load_dotenv

 from nat.utils.log_levels import LOG_LEVELS

@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
 from .commands.validate import validate_command
 from .commands.workflow.workflow import workflow_command

+# Load environment variables from .env file, if it exists
+load_dotenv()
+
 # Apply at the beginning of the file to avoid issues with asyncio
 nest_asyncio.apply()

@@ -52,7 +56,11 @@ nest_asyncio.apply()
 def setup_logging(log_level: str):
     """Configure logging with the specified level"""
     numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
-    logging.basicConfig(…
+    logging.basicConfig(
+        level=numeric_level,
+        format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
     return numeric_level
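Worth noting about the `load_dotenv()` call added above: by default `python-dotenv` looks for a `.env` file starting from the working directory and does not override variables already present in the environment, so exported shell variables still win. A short illustration (the variable name is just an example):

    import os
    from dotenv import load_dotenv

    # Suppose .env in the working directory contains:
    #   NVIDIA_API_KEY=from-dotenv-file
    os.environ["NVIDIA_API_KEY"] = "from-shell"
    load_dotenv()  # existing values are kept unless override=True is passed
    print(os.environ["NVIDIA_API_KEY"])  # -> "from-shell"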
nat/data_models/api_server.py
CHANGED

@@ -28,6 +28,7 @@ from pydantic import HttpUrl
 from pydantic import conlist
 from pydantic import field_serializer
 from pydantic import field_validator
+from pydantic import model_validator
 from pydantic_core.core_schema import ValidationInfo

 from nat.data_models.interactive import HumanPrompt

@@ -152,7 +153,6 @@ class ChatRequest(BaseModel):
     tool_choice: str | dict[str, typing.Any] | None = Field(default=None, description="Controls which tool is called")
     parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
     user: str | None = Field(default=None, description="Unique identifier representing end-user")
-
     model_config = ConfigDict(extra="allow",
                               json_schema_extra={
                                   "example": {
@@ -194,6 +194,85 @@ class ChatRequest(BaseModel):
                           top_p=top_p)


+class ChatRequestOrMessage(BaseModel):
+    """
+    `ChatRequestOrMessage` is a data model that represents either a conversation or a string input.
+    This is useful for functions that can handle either type of input.
+
+    - `messages` is compatible with the OpenAI Chat Completions API specification.
+    - `input_message` is a string input that can be used for functions that do not require a conversation.
+
+    Note: When `messages` is provided, extra fields are allowed to enable lossless round-trip
+    conversion with ChatRequest. When `input_message` is provided, no extra fields are permitted.
+    """
+    model_config = ConfigDict(
+        extra="allow",
+        json_schema_extra={
+            "examples": [
+                {
+                    "input_message": "What can you do?"
+                },
+                {
+                    "messages": [{
+                        "role": "user", "content": "What can you do?"
+                    }],
+                    "model": "nvidia/nemotron",
+                    "temperature": 0.7
+                },
+            ],
+            "oneOf": [
+                {
+                    "required": ["input_message"],
+                    "properties": {
+                        "input_message": {
+                            "type": "string"
+                        },
+                    },
+                    "additionalProperties": {
+                        "not": True, "errorMessage": 'remove additional property ${0#}'
+                    },
+                },
+                {
+                    "required": ["messages"],
+                    "properties": {
+                        "messages": {
+                            "type": "array"
+                        },
+                    },
+                    "additionalProperties": True
+                },
+            ]
+        },
+    )
+
+    messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
+        default=None, description="A non-empty conversation of messages to process.")
+
+    input_message: str | None = Field(
+        default=None,
+        description="A single input message to process. Useful for functions that do not require a conversation")
+
+    @property
+    def is_string(self) -> bool:
+        return self.input_message is not None
+
+    @property
+    def is_conversation(self) -> bool:
+        return self.messages is not None
+
+    @model_validator(mode="after")
+    def validate_model(self):
+        if self.messages is not None and self.input_message is not None:
+            raise ValueError("Either messages or input_message must be provided, not both")
+        if self.messages is None and self.input_message is None:
+            raise ValueError("Either messages or input_message must be provided")
+        if self.input_message is not None:
+            extra_fields = self.model_dump(exclude={"input_message"}, exclude_none=True, exclude_unset=True)
+            if len(extra_fields) > 0:
+                raise ValueError("no extra fields are permitted when input_message is provided")
+        return self
+
+
 class ChoiceMessage(BaseModel):
     content: str | None = None
     role: UserMessageContentRoleType | None = None
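Given the validator above, the model accepts exactly one of the two fields, and extra fields only alongside `messages`. A short usage sketch of the expected behavior (assuming the class is importable as shown in the diff; untested against the package):

    from pydantic import ValidationError
    from nat.data_models.api_server import ChatRequestOrMessage

    ok_str = ChatRequestOrMessage(input_message="hello")
    assert ok_str.is_string and not ok_str.is_conversation

    ok_convo = ChatRequestOrMessage(messages=[{"role": "user", "content": "hello"}],
                                    model="nvidia/nemotron")  # extras allowed with messages

    for bad in ({},  # neither field
                {"input_message": "hi", "messages": [{"role": "user", "content": "x"}]},  # both
                {"input_message": "hi", "model": "m"}):  # extras forbidden with input_message
        try:
            ChatRequestOrMessage(**bad)
        except ValidationError as e:
            print("rejected:", e.errors()[0]["msg"])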
@@ -661,6 +740,46 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
 GlobalTypeConverter.register_converter(_string_to_nat_chat_request)


+def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
+    if data.input_message is not None:
+        return _string_to_nat_chat_request(data.input_message)
+    return ChatRequest(**data.model_dump(exclude={"input_message"}))
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
+
+
+def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(**data.model_dump(by_alias=True))
+
+
+GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
+
+
+def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
+    if data.input_message is not None:
+        return data.input_message
+    # Extract content from last message in conversation
+    if data.messages is None:
+        return ""
+    content = data.messages[-1].content
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    return str(content)
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
+
+
+def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(input_message=data)
+
+
+GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
+
+
 # ======== ChatResponse Converters ========
 def _nat_chat_response_to_string(data: ChatResponse) -> str:
     if data.choices and data.choices[0].message:
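With these four converters registered, a workflow function typed on `ChatRequestOrMessage` can be driven by a plain string and still answer with a plain string. A hedged sketch of the intended round trip (names as in the diff; untested against the package):

    from nat.data_models.api_server import ChatRequest, ChatRequestOrMessage
    from nat.utils.type_converter import GlobalTypeConverter

    converter = GlobalTypeConverter.get()

    # str -> ChatRequestOrMessage -> ChatRequest
    req_or_msg = converter.convert("What can you do?", to_type=ChatRequestOrMessage)
    chat_request = converter.convert(req_or_msg, to_type=ChatRequest)
    assert chat_request.messages[-1].content == "What can you do?"

    # ...and back down to a plain string
    assert converter.convert(req_or_msg, to_type=str) == "What can you do?"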
nat/data_models/config.py
CHANGED

@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):

 class GeneralConfig(BaseModel):

-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

     use_uvloop: bool | None = Field(
         default=None,
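The `extra="forbid"` addition means unknown keys in the general config now fail validation instead of being silently ignored, which catches typos in config files. Illustration with a plain Pydantic model:

    from pydantic import BaseModel, ConfigDict, ValidationError

    class StrictConfig(BaseModel):
        model_config = ConfigDict(extra="forbid")
        use_uvloop: bool | None = None

    try:
        StrictConfig(use_uvlop=True)  # typo: "use_uvlop"
    except ValidationError as e:
        print(e.errors()[0]["type"])  # -> "extra_forbidden"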
nat/data_models/thinking_mixin.py
CHANGED

@@ -51,7 +51,7 @@ class ThinkingMixin(
     Returns the system prompt to use for thinking.
     For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
     For Llama Nemotron v1.5, returns "/think" if enabled, else "/no_think".
-    For Llama Nemotron v1.0, returns "detailed thinking on" if enabled, else "detailed thinking off".
+    For Llama Nemotron v1.0 or v1.1, returns "detailed thinking on" if enabled, else "detailed thinking off".
     If thinking is not supported on the model, returns None.

     Returns:

@@ -72,7 +72,7 @@ class ThinkingMixin(
         return "/think" if self.thinking else "/no_think"

     if model.startswith("nvidia/llama"):
-        if "v1-0" in model or "v1-1" in model:
+        if "v1-0" in model or "v1-1" in model or model.endswith("v1"):
             return f"detailed thinking {'on' if self.thinking else 'off'}"

     if "v1-5" in model:
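The widened check means a model id ending in plain `v1` (not just containing `v1-0` or `v1-1`) now receives the `detailed thinking` prompt. A standalone sketch of that branch, simplified for illustration and using example model ids:

    def thinking_prompt(model: str, thinking: bool) -> str | None:
        # Mirrors the branch in ThinkingMixin, simplified for illustration.
        if model.startswith("nvidia/llama"):
            if "v1-0" in model or "v1-1" in model or model.endswith("v1"):
                return f"detailed thinking {'on' if thinking else 'off'}"
            if "v1-5" in model:
                return "/think" if thinking else "/no_think"
        return None

    print(thinking_prompt("nvidia/llama-3.3-nemotron-super-49b-v1", True))    # detailed thinking on
    print(thinking_prompt("nvidia/llama-3.3-nemotron-super-49b-v1-5", False)) # /no_think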
nat/experimental/test_time_compute/functions/execute_score_select_function.py
CHANGED

@@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig

     from pydantic import BaseModel

-    executable_fn: Function = builder.get_function(name=config.augmented_fn)
+    executable_fn: Function = await builder.get_function(name=config.augmented_fn)

     if config.scorer:
         scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py
CHANGED

@@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function(

     augmented_function_desc = config.tool_description

-    fn_input_schema: BaseModel = augmented_function.input_schema
-    fn_output_schema: BaseModel = augmented_function.single_output_schema
+    fn_input_schema: type[BaseModel] = augmented_function.input_schema
+    fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema

     runnable_llm = input_llm.with_structured_output(schema=fn_input_schema)
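The annotation fix above reflects that `input_schema` is a model class, not an instance: `with_structured_output(schema=...)` receives the class itself. The distinction in miniature:

    from pydantic import BaseModel

    class ToolInput(BaseModel):
        query: str

    schema: type[BaseModel] = ToolInput          # the class object
    instance: BaseModel = ToolInput(query="x")   # an instance of it

    # A structured-output helper wants the class so it can build a parser:
    assert isinstance(schema, type) and issubclass(schema, BaseModel)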
nat/front_ends/console/console_front_end_plugin.py
CHANGED

@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
         else:
             assert False, "Should not reach here. Should have been caught by pre_run"

-        …
-        …
+        line = f"{'-' * 50}"
+        prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
+        suffix = f"{Fore.RESET}\n{line}"
+
+        logger.info(f"{prefix}%s{suffix}", runner_outputs)
+
+        # (handler is a stream handler) => (level > INFO)
+        effective_level_too_high = all(
+            type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
+        if effective_level_too_high:
+            print(f"{prefix}{runner_outputs}{suffix}")
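The `all(...)` expression above encodes the implication stated in its comment, "is a StreamHandler implies level above INFO", in the standard `not A or B` form. When it holds, no console handler would have shown the result, so the plugin falls back to `print`. The logic in isolation:

    import logging

    def result_would_be_invisible(handlers: list[logging.Handler]) -> bool:
        # (handler is a stream handler) => (level > INFO), i.e. "not A or B"
        return all(type(h) is not logging.StreamHandler or h.level > logging.INFO
                   for h in handlers)

    quiet = logging.StreamHandler(); quiet.setLevel(logging.WARNING)
    chatty = logging.StreamHandler(); chatty.setLevel(logging.INFO)

    print(result_would_be_invisible([quiet]))          # True  -> fall back to print()
    print(result_would_be_invisible([quiet, chatty]))  # False -> logger already shows it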
nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py
CHANGED

@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
     async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:

         raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
-                                  f" Do you have…
+                                  f" Do you have WebSockets enabled?")