nvidia-nat 1.3.0rc2__py3-none-any.whl → 1.3.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +13 -22
- nat/agent/rewoo_agent/register.py +13 -22
- nat/agent/tool_calling_agent/register.py +7 -3
- nat/builder/component_utils.py +1 -1
- nat/builder/function.py +3 -2
- nat/builder/workflow_builder.py +46 -3
- nat/cli/entrypoint.py +9 -1
- nat/data_models/api_server.py +78 -9
- nat/data_models/config.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/observability/register.py +16 -0
- nat/runtime/session.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +3 -3
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.3.0rc3.dist-info/METADATA +195 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/RECORD +26 -26
- nvidia_nat-1.3.0rc2.dist-info/METADATA +0 -389
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc2.dist-info → nvidia_nat-1.3.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -24,6 +24,7 @@ from nat.builder.function_info import FunctionInfo
|
|
|
24
24
|
from nat.cli.register_workflow import register_function
|
|
25
25
|
from nat.data_models.agent import AgentBaseConfig
|
|
26
26
|
from nat.data_models.api_server import ChatRequest
|
|
27
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
27
28
|
from nat.data_models.api_server import ChatResponse
|
|
28
29
|
from nat.data_models.api_server import Usage
|
|
29
30
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
@@ -70,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
|
|
|
70
71
|
default=None,
|
|
71
72
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
72
73
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
73
|
-
use_openai_api: bool = Field(default=False,
|
|
74
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
75
|
-
"If False, strings will be used."))
|
|
76
74
|
additional_instructions: str | None = OptimizableField(
|
|
77
75
|
default=None,
|
|
78
76
|
description="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
@@ -118,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
118
116
|
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
119
117
|
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
120
118
|
|
|
121
|
-
async def _response_fn(
|
|
119
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
122
120
|
"""
|
|
123
121
|
Main workflow entry function for the ReAct Agent.
|
|
124
122
|
|
|
125
123
|
This function invokes the ReAct Agent Graph and returns the response.
|
|
126
124
|
|
|
127
125
|
Args:
|
|
128
|
-
|
|
126
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
129
127
|
|
|
130
128
|
Returns:
|
|
131
|
-
ChatResponse: The response from the agent or error message
|
|
129
|
+
ChatResponse | str: The response from the agent or error message
|
|
132
130
|
"""
|
|
133
131
|
try:
|
|
132
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
133
|
+
|
|
134
134
|
# initialize the starting state with the user query
|
|
135
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
135
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
136
136
|
max_tokens=config.max_history,
|
|
137
137
|
strategy="last",
|
|
138
138
|
token_counter=len,
|
|
@@ -153,25 +153,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
153
153
|
content = str(output_message.content)
|
|
154
154
|
|
|
155
155
|
# Create usage statistics for the response
|
|
156
|
-
prompt_tokens = sum(len(str(msg.content).split()) for msg in
|
|
156
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
157
157
|
completion_tokens = len(content.split()) if content else 0
|
|
158
158
|
total_tokens = prompt_tokens + completion_tokens
|
|
159
159
|
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
160
|
-
|
|
161
|
-
|
|
160
|
+
response = ChatResponse.from_string(content, usage=usage)
|
|
161
|
+
if chat_request_or_message.is_string:
|
|
162
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
163
|
+
return response
|
|
162
164
|
except Exception as ex:
|
|
163
165
|
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
164
166
|
raise RuntimeError
|
|
165
167
|
|
|
166
|
-
|
|
167
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
168
|
-
else:
|
|
169
|
-
|
|
170
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
171
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
172
|
-
|
|
173
|
-
oai_output = await _response_fn(oai_input)
|
|
174
|
-
|
|
175
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
176
|
-
|
|
177
|
-
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
|
168
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -25,6 +25,7 @@ from nat.builder.function_info import FunctionInfo
|
|
|
25
25
|
from nat.cli.register_workflow import register_function
|
|
26
26
|
from nat.data_models.agent import AgentBaseConfig
|
|
27
27
|
from nat.data_models.api_server import ChatRequest
|
|
28
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
28
29
|
from nat.data_models.api_server import ChatResponse
|
|
29
30
|
from nat.data_models.api_server import Usage
|
|
30
31
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
@@ -54,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
|
|
|
54
55
|
description="The number of retries before raising a tool call error.",
|
|
55
56
|
ge=1)
|
|
56
57
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
57
|
-
use_openai_api: bool = Field(default=False,
|
|
58
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
59
|
-
"If False, strings will be used."))
|
|
60
58
|
additional_planner_instructions: str | None = Field(
|
|
61
59
|
default=None,
|
|
62
60
|
validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),
|
|
@@ -125,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
125
123
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
126
124
|
raise_tool_call_error=config.raise_tool_call_error).build_graph()
|
|
127
125
|
|
|
128
|
-
async def _response_fn(
|
|
126
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
129
127
|
"""
|
|
130
128
|
Main workflow entry function for the ReWOO Agent.
|
|
131
129
|
|
|
132
130
|
This function invokes the ReWOO Agent Graph and returns the response.
|
|
133
131
|
|
|
134
132
|
Args:
|
|
135
|
-
|
|
133
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
136
134
|
|
|
137
135
|
Returns:
|
|
138
|
-
ChatResponse: The response from the agent or error message
|
|
136
|
+
ChatResponse | str: The response from the agent or error message
|
|
139
137
|
"""
|
|
140
138
|
try:
|
|
139
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
140
|
+
|
|
141
141
|
# initialize the starting state with the user query
|
|
142
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
142
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
143
143
|
max_tokens=config.max_history,
|
|
144
144
|
strategy="last",
|
|
145
145
|
token_counter=len,
|
|
@@ -160,25 +160,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
160
160
|
output_message = str(output_message)
|
|
161
161
|
|
|
162
162
|
# Create usage statistics for the response
|
|
163
|
-
prompt_tokens = sum(len(str(msg.content).split()) for msg in
|
|
163
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
164
164
|
completion_tokens = len(output_message.split()) if output_message else 0
|
|
165
165
|
total_tokens = prompt_tokens + completion_tokens
|
|
166
166
|
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
167
|
-
|
|
168
|
-
|
|
167
|
+
response = ChatResponse.from_string(output_message, usage=usage)
|
|
168
|
+
if chat_request_or_message.is_string:
|
|
169
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
170
|
+
return response
|
|
169
171
|
except Exception as ex:
|
|
170
172
|
logger.exception("ReWOO Agent failed with exception: %s", ex)
|
|
171
173
|
raise RuntimeError
|
|
172
174
|
|
|
173
|
-
|
|
174
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
175
|
-
|
|
176
|
-
else:
|
|
177
|
-
|
|
178
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
179
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
180
|
-
oai_output = await _response_fn(oai_input)
|
|
181
|
-
|
|
182
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
183
|
-
|
|
184
|
-
yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
|
|
175
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
|
|
|
23
23
|
from nat.cli.register_workflow import register_function
|
|
24
24
|
from nat.data_models.agent import AgentBaseConfig
|
|
25
25
|
from nat.data_models.api_server import ChatRequest
|
|
26
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
26
27
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
27
28
|
from nat.data_models.component_ref import FunctionRef
|
|
29
|
+
from nat.utils.type_converter import GlobalTypeConverter
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
81
83
|
handle_tool_errors=config.handle_tool_errors,
|
|
82
84
|
return_direct=return_direct_tools).build_graph()
|
|
83
85
|
|
|
84
|
-
async def _response_fn(
|
|
86
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
|
|
85
87
|
"""
|
|
86
88
|
Main workflow entry function for the Tool Calling Agent.
|
|
87
89
|
|
|
88
90
|
This function invokes the Tool Calling Agent Graph and returns the response.
|
|
89
91
|
|
|
90
92
|
Args:
|
|
91
|
-
|
|
93
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
92
94
|
|
|
93
95
|
Returns:
|
|
94
96
|
str: The response from the agent or error message
|
|
95
97
|
"""
|
|
96
98
|
try:
|
|
99
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
100
|
+
|
|
97
101
|
# initialize the starting state with the user query
|
|
98
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
102
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
99
103
|
max_tokens=config.max_history,
|
|
100
104
|
strategy="last",
|
|
101
105
|
token_counter=len,
|
nat/builder/component_utils.py
CHANGED
|
@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
|
|
|
153
153
|
for v in value.values():
|
|
154
154
|
yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
|
|
155
155
|
elif (issubclass(type(value), BaseModel)):
|
|
156
|
-
for field, field_info in value.model_fields.items():
|
|
156
|
+
for field, field_info in type(value).model_fields.items():
|
|
157
157
|
field_data = getattr(value, field)
|
|
158
158
|
yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
|
|
159
159
|
if (decomposed_type.is_union):
|
nat/builder/function.py
CHANGED
|
@@ -416,8 +416,9 @@ class FunctionGroup:
|
|
|
416
416
|
"""
|
|
417
417
|
if not name.strip():
|
|
418
418
|
raise ValueError("Function name cannot be empty or blank")
|
|
419
|
-
if not re.match(r"^[a-zA-Z0-9_
|
|
420
|
-
raise ValueError(
|
|
419
|
+
if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
|
|
420
|
+
raise ValueError(
|
|
421
|
+
f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
|
|
421
422
|
if name in self._functions:
|
|
422
423
|
raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
|
|
423
424
|
|
nat/builder/workflow_builder.py
CHANGED
|
@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
156
156
|
self._registry = registry
|
|
157
157
|
|
|
158
158
|
self._logging_handlers: dict[str, logging.Handler] = {}
|
|
159
|
+
self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
|
|
159
160
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
160
161
|
|
|
161
162
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
187
188
|
# Get the telemetry info from the config
|
|
188
189
|
telemetry_config = self.general_config.telemetry
|
|
189
190
|
|
|
191
|
+
# If we have logging configuration, we need to manage the root logger properly
|
|
192
|
+
root_logger = logging.getLogger()
|
|
193
|
+
|
|
194
|
+
# Collect configured handler types to determine if we need to adjust existing handlers
|
|
195
|
+
# This is somewhat of a hack by inspecting the class name of the config object
|
|
196
|
+
has_console_handler = any(
|
|
197
|
+
hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
|
|
198
|
+
for config in telemetry_config.logging.values())
|
|
199
|
+
|
|
190
200
|
for key, logging_config in telemetry_config.logging.items():
|
|
191
201
|
# Use the same pattern as tracing, but for logging
|
|
192
202
|
logging_info = self._registry.get_logging_method(type(logging_config))
|
|
@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
200
210
|
self._logging_handlers[key] = handler
|
|
201
211
|
|
|
202
212
|
# Now attach to NAT's root logger
|
|
203
|
-
|
|
213
|
+
root_logger.addHandler(handler)
|
|
214
|
+
|
|
215
|
+
# If we added logging handlers, manage existing handlers appropriately
|
|
216
|
+
if self._logging_handlers:
|
|
217
|
+
min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
|
|
218
|
+
|
|
219
|
+
# Ensure the root logger level allows messages through
|
|
220
|
+
root_logger.level = max(root_logger.level, min_handler_level)
|
|
221
|
+
|
|
222
|
+
# If a console handler is configured, adjust or remove default CLI handlers
|
|
223
|
+
# to avoid duplicate output while preserving workflow visibility
|
|
224
|
+
if has_console_handler:
|
|
225
|
+
# Remove existing StreamHandlers that are not the newly configured ones
|
|
226
|
+
for handler in root_logger.handlers[:]:
|
|
227
|
+
if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
|
|
228
|
+
self._removed_root_handlers.append((handler, handler.level))
|
|
229
|
+
root_logger.removeHandler(handler)
|
|
230
|
+
else:
|
|
231
|
+
# No console handler configured, but adjust existing handler levels
|
|
232
|
+
# to respect the minimum configured level for file/other handlers
|
|
233
|
+
for handler in root_logger.handlers[:]:
|
|
234
|
+
if type(handler) is logging.StreamHandler:
|
|
235
|
+
old_level = handler.level
|
|
236
|
+
handler.setLevel(min_handler_level)
|
|
237
|
+
self._removed_root_handlers.append((handler, old_level))
|
|
204
238
|
|
|
205
239
|
# Add the telemetry exporters
|
|
206
240
|
for key, telemetry_exporter_config in telemetry_config.tracing.items():
|
|
@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
212
246
|
|
|
213
247
|
assert self._exit_stack is not None, "Exit stack not initialized"
|
|
214
248
|
|
|
215
|
-
|
|
216
|
-
|
|
249
|
+
root_logger = logging.getLogger()
|
|
250
|
+
|
|
251
|
+
# Remove custom logging handlers
|
|
252
|
+
for handler in self._logging_handlers.values():
|
|
253
|
+
root_logger.removeHandler(handler)
|
|
254
|
+
|
|
255
|
+
# Restore original handlers and their levels
|
|
256
|
+
for handler, old_level in self._removed_root_handlers:
|
|
257
|
+
if handler not in root_logger.handlers:
|
|
258
|
+
root_logger.addHandler(handler)
|
|
259
|
+
handler.setLevel(old_level)
|
|
217
260
|
|
|
218
261
|
await self._exit_stack.__aexit__(*exc_details)
|
|
219
262
|
|
nat/cli/entrypoint.py
CHANGED
|
@@ -29,6 +29,7 @@ import time
|
|
|
29
29
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
|
+
from dotenv import load_dotenv
|
|
32
33
|
|
|
33
34
|
from nat.utils.log_levels import LOG_LEVELS
|
|
34
35
|
|
|
@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
|
|
|
45
46
|
from .commands.validate import validate_command
|
|
46
47
|
from .commands.workflow.workflow import workflow_command
|
|
47
48
|
|
|
49
|
+
# Load environment variables from .env file, if it exists
|
|
50
|
+
load_dotenv()
|
|
51
|
+
|
|
48
52
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
49
53
|
nest_asyncio.apply()
|
|
50
54
|
|
|
@@ -52,7 +56,11 @@ nest_asyncio.apply()
|
|
|
52
56
|
def setup_logging(log_level: str):
|
|
53
57
|
"""Configure logging with the specified level"""
|
|
54
58
|
numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
|
|
55
|
-
logging.basicConfig(
|
|
59
|
+
logging.basicConfig(
|
|
60
|
+
level=numeric_level,
|
|
61
|
+
format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
62
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
63
|
+
)
|
|
56
64
|
return numeric_level
|
|
57
65
|
|
|
58
66
|
|
nat/data_models/api_server.py
CHANGED
|
@@ -28,6 +28,7 @@ from pydantic import HttpUrl
|
|
|
28
28
|
from pydantic import conlist
|
|
29
29
|
from pydantic import field_serializer
|
|
30
30
|
from pydantic import field_validator
|
|
31
|
+
from pydantic import model_validator
|
|
31
32
|
from pydantic_core.core_schema import ValidationInfo
|
|
32
33
|
|
|
33
34
|
from nat.data_models.interactive import HumanPrompt
|
|
@@ -120,15 +121,7 @@ class Message(BaseModel):
|
|
|
120
121
|
role: UserMessageContentRoleType
|
|
121
122
|
|
|
122
123
|
|
|
123
|
-
class
|
|
124
|
-
"""
|
|
125
|
-
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
126
|
-
Fully compatible with OpenAI Chat Completions API specification.
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
# Required fields
|
|
130
|
-
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
131
|
-
|
|
124
|
+
class ChatRequestOptionals(BaseModel):
|
|
132
125
|
# Optional fields (OpenAI Chat Completions API compatible)
|
|
133
126
|
model: str | None = Field(default=None, description="name of the model to use")
|
|
134
127
|
frequency_penalty: float | None = Field(default=0.0,
|
|
@@ -153,6 +146,16 @@ class ChatRequest(BaseModel):
|
|
|
153
146
|
parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
|
|
154
147
|
user: str | None = Field(default=None, description="Unique identifier representing end-user")
|
|
155
148
|
|
|
149
|
+
|
|
150
|
+
class ChatRequest(ChatRequestOptionals):
|
|
151
|
+
"""
|
|
152
|
+
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
153
|
+
Fully compatible with OpenAI Chat Completions API specification.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
# Required fields
|
|
157
|
+
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
158
|
+
|
|
156
159
|
model_config = ConfigDict(extra="allow",
|
|
157
160
|
json_schema_extra={
|
|
158
161
|
"example": {
|
|
@@ -194,6 +197,42 @@ class ChatRequest(BaseModel):
|
|
|
194
197
|
top_p=top_p)
|
|
195
198
|
|
|
196
199
|
|
|
200
|
+
class ChatRequestOrMessage(ChatRequestOptionals):
|
|
201
|
+
"""
|
|
202
|
+
ChatRequestOrMessage is a data model that represents either a conversation or a string input.
|
|
203
|
+
This is useful for functions that can handle either type of input.
|
|
204
|
+
|
|
205
|
+
`messages` is compatible with the OpenAI Chat Completions API specification.
|
|
206
|
+
|
|
207
|
+
`input_string` is a string input that can be used for functions that do not require a conversation.
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
|
|
211
|
+
default=None, description="The conversation messages to process.")
|
|
212
|
+
|
|
213
|
+
input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
|
|
214
|
+
|
|
215
|
+
@property
|
|
216
|
+
def is_string(self) -> bool:
|
|
217
|
+
return self.input_string is not None
|
|
218
|
+
|
|
219
|
+
@property
|
|
220
|
+
def is_conversation(self) -> bool:
|
|
221
|
+
return self.messages is not None
|
|
222
|
+
|
|
223
|
+
@model_validator(mode="after")
|
|
224
|
+
def validate_messages_or_input_string(self):
|
|
225
|
+
if self.messages is not None and self.input_string is not None:
|
|
226
|
+
raise ValueError("Either messages or input_message/input_string must be provided, not both")
|
|
227
|
+
if self.messages is None and self.input_string is None:
|
|
228
|
+
raise ValueError("Either messages or input_message/input_string must be provided")
|
|
229
|
+
if self.input_string is not None:
|
|
230
|
+
extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
|
|
231
|
+
if len(extra_fields) > 0:
|
|
232
|
+
raise ValueError("no extra fields are permitted when input_message/input_string is provided")
|
|
233
|
+
return self
|
|
234
|
+
|
|
235
|
+
|
|
197
236
|
class ChoiceMessage(BaseModel):
|
|
198
237
|
content: str | None = None
|
|
199
238
|
role: UserMessageContentRoleType | None = None
|
|
@@ -661,6 +700,36 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
|
|
|
661
700
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
|
|
662
701
|
|
|
663
702
|
|
|
703
|
+
def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
|
|
704
|
+
if data.input_string is not None:
|
|
705
|
+
return _string_to_nat_chat_request(data.input_string)
|
|
706
|
+
return ChatRequest(**data.model_dump(exclude={"input_string"}))
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
|
|
713
|
+
return ChatRequestOrMessage(**data.model_dump(by_alias=True))
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
|
|
720
|
+
return data.input_string or ""
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
|
|
727
|
+
return ChatRequestOrMessage(input_message=data)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
|
|
731
|
+
|
|
732
|
+
|
|
664
733
|
# ======== ChatResponse Converters ========
|
|
665
734
|
def _nat_chat_response_to_string(data: ChatResponse) -> str:
|
|
666
735
|
if data.choices and data.choices[0].message:
|
nat/data_models/config.py
CHANGED
|
@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):
|
|
|
187
187
|
|
|
188
188
|
class GeneralConfig(BaseModel):
|
|
189
189
|
|
|
190
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
190
|
+
model_config = ConfigDict(protected_namespaces=(), extra="forbid")
|
|
191
191
|
|
|
192
192
|
use_uvloop: bool | None = Field(
|
|
193
193
|
default=None,
|
|
@@ -46,7 +46,7 @@ async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig
|
|
|
46
46
|
|
|
47
47
|
from pydantic import BaseModel
|
|
48
48
|
|
|
49
|
-
executable_fn: Function = builder.get_function(name=config.augmented_fn)
|
|
49
|
+
executable_fn: Function = await builder.get_function(name=config.augmented_fn)
|
|
50
50
|
|
|
51
51
|
if config.scorer:
|
|
52
52
|
scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
|
|
@@ -98,8 +98,8 @@ async def register_ttc_tool_wrapper_function(
|
|
|
98
98
|
|
|
99
99
|
augmented_function_desc = config.tool_description
|
|
100
100
|
|
|
101
|
-
fn_input_schema: BaseModel = augmented_function.input_schema
|
|
102
|
-
fn_output_schema: BaseModel = augmented_function.single_output_schema
|
|
101
|
+
fn_input_schema: type[BaseModel] = augmented_function.input_schema
|
|
102
|
+
fn_output_schema: type[BaseModel] | type[None] = augmented_function.single_output_schema
|
|
103
103
|
|
|
104
104
|
runnable_llm = input_llm.with_structured_output(schema=fn_input_schema)
|
|
105
105
|
|
|
@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
95
95
|
else:
|
|
96
96
|
assert False, "Should not reach here. Should have been caught by pre_run"
|
|
97
97
|
|
|
98
|
-
|
|
99
|
-
|
|
98
|
+
line = f"{'-' * 50}"
|
|
99
|
+
prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
|
|
100
|
+
suffix = f"{Fore.RESET}\n{line}"
|
|
101
|
+
|
|
102
|
+
logger.info(f"{prefix}%s{suffix}", runner_outputs)
|
|
103
|
+
|
|
104
|
+
# (handler is a stream handler) => (level > INFO)
|
|
105
|
+
effective_level_too_high = all(
|
|
106
|
+
type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
|
|
107
|
+
if effective_level_too_high:
|
|
108
|
+
print(f"{prefix}{runner_outputs}{suffix}")
|
|
@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
24
24
|
async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
|
|
25
25
|
|
|
26
26
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
|
|
27
|
-
f" Do you have
|
|
27
|
+
f" Do you have WebSockets enabled?")
|
nat/observability/register.py
CHANGED
|
@@ -77,6 +77,14 @@ async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Bu
|
|
|
77
77
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
78
78
|
handler = logging.StreamHandler(stream=sys.stdout)
|
|
79
79
|
handler.setLevel(level)
|
|
80
|
+
|
|
81
|
+
# Set formatter to match the default CLI format
|
|
82
|
+
formatter = logging.Formatter(
|
|
83
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
84
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
85
|
+
)
|
|
86
|
+
handler.setFormatter(formatter)
|
|
87
|
+
|
|
80
88
|
yield handler
|
|
81
89
|
|
|
82
90
|
|
|
@@ -95,4 +103,12 @@ async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
|
95
103
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
96
104
|
handler = logging.FileHandler(filename=config.path, mode="a", encoding="utf-8")
|
|
97
105
|
handler.setLevel(level)
|
|
106
|
+
|
|
107
|
+
# Set formatter to match the default CLI format
|
|
108
|
+
formatter = logging.Formatter(
|
|
109
|
+
fmt="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
110
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
111
|
+
)
|
|
112
|
+
handler.setFormatter(formatter)
|
|
113
|
+
|
|
98
114
|
yield handler
|
nat/runtime/session.py
CHANGED
|
@@ -192,7 +192,7 @@ class SessionManager:
|
|
|
192
192
|
user_message_id: str | None,
|
|
193
193
|
conversation_id: str | None) -> None:
|
|
194
194
|
"""
|
|
195
|
-
Extracts and sets user metadata for
|
|
195
|
+
Extracts and sets user metadata for WebSocket connections.
|
|
196
196
|
"""
|
|
197
197
|
|
|
198
198
|
# Extract cookies from WebSocket headers (similar to HTTP request)
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class AddToolConfig(FunctionBaseConfig, name="add_memory"):
|
|
31
31
|
"""Function to add memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to add memory about a user's interactions to a system "
|
|
33
|
+
description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
|
|
34
34
|
"for retrieval later."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
|
|
|
46
46
|
from langchain_core.tools import ToolException
|
|
47
47
|
|
|
48
48
|
# First, retrieve the memory client
|
|
49
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
50
50
|
|
|
51
51
|
async def _arun(item: MemoryItem) -> str:
|
|
52
52
|
"""
|
|
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
|
|
31
31
|
"""Function to delete memory from a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=
|
|
34
|
-
"interactions to help answer questions in a personalized way."),
|
|
33
|
+
description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
|
|
35
34
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
35
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
36
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
37
|
"configuration object."))
|
|
39
38
|
|
|
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
|
|
|
47
46
|
from langchain_core.tools import ToolException
|
|
48
47
|
|
|
49
48
|
# First, retrieve the memory client
|
|
50
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
49
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
51
50
|
|
|
52
51
|
async def _arun(user_id: str) -> str:
|
|
53
52
|
"""
|
|
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
class GetToolConfig(FunctionBaseConfig, name="get_memory"):
|
|
31
31
|
"""Function to get memory to a hosted memory platform."""
|
|
32
32
|
|
|
33
|
-
description: str = Field(default=("Tool to retrieve memory about a user's "
|
|
33
|
+
description: str = Field(default=("Tool to retrieve a memory about a user's "
|
|
34
34
|
"interactions to help answer questions in a personalized way."),
|
|
35
35
|
description="The description of this function's use for tool calling agents.")
|
|
36
|
-
memory: MemoryRef = Field(default="saas_memory",
|
|
36
|
+
memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
|
|
37
37
|
description=("Instance name of the memory client instance from the workflow "
|
|
38
38
|
"configuration object."))
|
|
39
39
|
|
|
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
|
|
|
49
49
|
from langchain_core.tools import ToolException
|
|
50
50
|
|
|
51
51
|
# First, retrieve the memory client
|
|
52
|
-
memory_editor = builder.get_memory_client(config.memory)
|
|
52
|
+
memory_editor = await builder.get_memory_client(config.memory)
|
|
53
53
|
|
|
54
54
|
async def _arun(search_input: SearchMemoryInput) -> str:
|
|
55
55
|
"""
|
nat/utils/type_converter.py
CHANGED
|
@@ -93,6 +93,14 @@ class TypeConverter:
|
|
|
93
93
|
if to_type is None or decomposed.is_instance(data):
|
|
94
94
|
return data
|
|
95
95
|
|
|
96
|
+
# 2) If data is a union type, try to convert to each type in the union
|
|
97
|
+
if decomposed.is_union:
|
|
98
|
+
for union_type in decomposed.args:
|
|
99
|
+
result = self._convert(data, union_type)
|
|
100
|
+
if result is not None:
|
|
101
|
+
return result
|
|
102
|
+
return None
|
|
103
|
+
|
|
96
104
|
root = decomposed.root
|
|
97
105
|
|
|
98
106
|
# 2) Attempt direct in *this* converter
|