nvidia-nat 1.4.0a20251008__py3-none-any.whl → 1.4.0a20251011__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +15 -24
- nat/agent/rewoo_agent/register.py +15 -24
- nat/agent/tool_calling_agent/register.py +9 -5
- nat/builder/component_utils.py +1 -1
- nat/builder/function.py +4 -4
- nat/builder/intermediate_step_manager.py +32 -0
- nat/builder/workflow_builder.py +46 -3
- nat/cli/entrypoint.py +9 -1
- nat/data_models/api_server.py +78 -9
- nat/data_models/config.py +1 -1
- nat/front_ends/console/console_front_end_plugin.py +11 -2
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/mcp/mcp_front_end_config.py +13 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +18 -1
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +21 -2
- nat/observability/register.py +16 -0
- nat/runtime/runner.py +1 -2
- nat/runtime/session.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +3 -3
- nat/utils/type_converter.py +8 -0
- nvidia_nat-1.4.0a20251011.dist-info/METADATA +195 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/RECORD +30 -29
- nvidia_nat-1.4.0a20251008.dist-info/METADATA +0 -389
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/register.py
CHANGED

@@ -24,6 +24,7 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.api_server import Usage
 from nat.data_models.component_ref import FunctionGroupRef

@@ -70,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
         default=None,
         description="Provides the SYSTEM_PROMPT to use with the agent")  # defaults to SYSTEM_PROMPT in prompt.py
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
-    use_openai_api: bool = Field(default=False,
-                                 description=("Use OpenAI API for the input/output types to the function. "
-                                              "If False, strings will be used."))
     additional_instructions: str | None = OptimizableField(
         default=None,
         description="Additional instructions to provide to the agent in addition to the base prompt.",

@@ -118,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
         pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
         normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()

-    async def _response_fn(
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
         """
         Main workflow entry function for the ReAct Agent.

         This function invokes the ReAct Agent Graph and returns the response.

         Args:
-
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
-            ChatResponse: The response from the agent or error message
+            ChatResponse | str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -153,25 +153,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
             content = str(output_message.content)

             # Create usage statistics for the response
-            prompt_tokens = sum(len(str(msg.content).split()) for msg in
+            prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
             completion_tokens = len(content.split()) if content else 0
             total_tokens = prompt_tokens + completion_tokens
             usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
-
-
+            response = ChatResponse.from_string(content, usage=usage)
+            if chat_request_or_message.is_string:
+                return GlobalTypeConverter.get().convert(response, to_type=str)
+            return response
         except Exception as ex:
-            logger.
-            raise
-
-    if (config.use_openai_api):
-        yield FunctionInfo.from_fn(_response_fn, description=config.description)
-    else:
-
-        async def _str_api_fn(input_message: str) -> str:
-            oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
-
-            oai_output = await _response_fn(oai_input)
-
-            return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
+            logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
+            raise

-
+    yield FunctionInfo.from_fn(_response_fn, description=config.description)
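All three agent entry points now take a single ChatRequestOrMessage argument instead of branching on use_openai_api. A minimal sketch of the two accepted input shapes, based only on the model added in nat/data_models/api_server.py below; the message values are illustrative:

# Illustrative only: the two input shapes the new _response_fn accepts.
from nat.data_models.api_server import ChatRequestOrMessage

# Plain-string input; note the "input_message" alias for input_string.
as_string = ChatRequestOrMessage(input_message="What is 2 + 2?")
assert as_string.is_string and not as_string.is_conversation

# OpenAI-style conversation input.
as_chat = ChatRequestOrMessage(messages=[{"role": "user", "content": "What is 2 + 2?"}])
assert as_chat.is_conversation and not as_chat.is_string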
nat/agent/rewoo_agent/register.py
CHANGED

@@ -25,6 +25,7 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.api_server import ChatResponse
 from nat.data_models.api_server import Usage
 from nat.data_models.component_ref import FunctionGroupRef

@@ -54,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
         description="The number of retries before raising a tool call error.",
         ge=1)
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
-    use_openai_api: bool = Field(default=False,
-                                 description=("Use OpenAI API for the input/output types to the function. "
-                                              "If False, strings will be used."))
     additional_planner_instructions: str | None = Field(
         default=None,
         validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),

@@ -125,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builder):
         tool_call_max_retries=config.tool_call_max_retries,
         raise_tool_call_error=config.raise_tool_call_error).build_graph()

-    async def _response_fn(
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
         """
         Main workflow entry function for the ReWOO Agent.

         This function invokes the ReWOO Agent Graph and returns the response.

         Args:
-
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
-            ChatResponse: The response from the agent or error message
+            ChatResponse | str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -160,25 +160,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builder):
             output_message = str(output_message)

             # Create usage statistics for the response
-            prompt_tokens = sum(len(str(msg.content).split()) for msg in
+            prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
             completion_tokens = len(output_message.split()) if output_message else 0
             total_tokens = prompt_tokens + completion_tokens
             usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
-
-
+            response = ChatResponse.from_string(output_message, usage=usage)
+            if chat_request_or_message.is_string:
+                return GlobalTypeConverter.get().convert(response, to_type=str)
+            return response
         except Exception as ex:
-            logger.
-            raise
-
-    if (config.use_openai_api):
-        yield FunctionInfo.from_fn(_response_fn, description=config.description)
-
-    else:
-
-        async def _str_api_fn(input_message: str) -> str:
-            oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
-            oai_output = await _response_fn(oai_input)
-
-            return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
+            logger.error("ReWOO Agent failed with exception: %s", ex)
+            raise

-
+    yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/tool_calling_agent/register.py
CHANGED

@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
 from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.api_server import ChatRequest
+from nat.data_models.api_server import ChatRequestOrMessage
 from nat.data_models.component_ref import FunctionGroupRef
 from nat.data_models.component_ref import FunctionRef
+from nat.utils.type_converter import GlobalTypeConverter

 logger = logging.getLogger(__name__)

@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
         handle_tool_errors=config.handle_tool_errors,
         return_direct=return_direct_tools).build_graph()

-    async def _response_fn(
+    async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
         """
         Main workflow entry function for the Tool Calling Agent.

         This function invokes the Tool Calling Agent Graph and returns the response.

         Args:
-
+            chat_request_or_message (ChatRequestOrMessage): The input message to process

         Returns:
             str: The response from the agent or error message
         """
         try:
+            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
+
             # initialize the starting state with the user query
-            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
+            messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
                                                         max_tokens=config.max_history,
                                                         strategy="last",
                                                         token_counter=len,

@@ -114,8 +118,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
             output_message = state.messages[-1]
             return str(output_message.content)
         except Exception as ex:
-            logger.
-            raise
+            logger.error("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
+            raise

     try:
         yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/builder/component_utils.py
CHANGED

@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
         for v in value.values():
             yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
     elif (issubclass(type(value), BaseModel)):
-        for field, field_info in value.model_fields.items():
+        for field, field_info in type(value).model_fields.items():
             field_data = getattr(value, field)
             yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
     if (decomposed_type.is_union):
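The one-line fix above reads model_fields from the class instead of the instance; instance access is deprecated in Pydantic 2.11+ and slated for removal. The pattern in miniature (Point is a stand-in model, not from the toolkit):

from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
for name, field_info in type(p).model_fields.items():  # not p.model_fields
    print(name, field_info.annotation, getattr(p, name))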
nat/builder/function.py
CHANGED

@@ -159,8 +159,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):

             return result
         except Exception as e:
-
-            logger.error("Error with ainvoke in function with input: %s. %s", value, err_msg)
+            logger.error("Error with ainvoke in function with input: %s. Error: %s", value, e)
             raise

     @typing.final

@@ -416,8 +415,9 @@ class FunctionGroup:
         """
         if not name.strip():
             raise ValueError("Function name cannot be empty or blank")
-        if not re.match(r"^[a-zA-Z0-9_
-            raise ValueError(
+        if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
+            raise ValueError(
+                f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
         if name in self._functions:
             raise ValueError(f"Function {name} already exists in function group {self._instance_name}")

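The widened pattern now admits periods and hyphens in function names. A quick illustration of what the new check accepts and rejects:

import re

_NAME_RE = re.compile(r"^[a-zA-Z0-9_.-]+$")  # same pattern as the diff above

assert _NAME_RE.match("my_tool.v2")    # periods now allowed
assert _NAME_RE.match("memory-tool")   # hyphens now allowed
assert not _NAME_RE.match("bad name")  # whitespace is still rejected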
nat/builder/intermediate_step_manager.py
CHANGED

@@ -16,6 +16,8 @@
 import dataclasses
 import logging
 import typing
+import weakref
+from typing import ClassVar

 from nat.data_models.intermediate_step import IntermediateStep
 from nat.data_models.intermediate_step import IntermediateStepPayload

@@ -46,11 +48,19 @@ class IntermediateStepManager:
     Manages updates to the NAT Event Stream for intermediate steps
     """

+    # Class-level tracking for debugging and monitoring
+    _instance_count: ClassVar[int] = 0
+    _active_instances: ClassVar[set[weakref.ref]] = set()
+
     def __init__(self, context_state: "ContextState"):  # noqa: F821
         self._context_state = context_state

         self._outstanding_start_steps: dict[str, OpenStep] = {}

+        # Track instance creation
+        IntermediateStepManager._instance_count += 1
+        IntermediateStepManager._active_instances.add(weakref.ref(self, self._cleanup_instance_tracking))
+
     def push_intermediate_step(self, payload: IntermediateStepPayload) -> None:
         """
         Pushes an intermediate step to the NAT Event Stream

@@ -172,3 +182,25 @@ class IntermediateStepManager:
         """

         return self._context_state.event_stream.get().subscribe(on_next, on_error, on_complete)
+
+    @classmethod
+    def _cleanup_instance_tracking(cls, ref: weakref.ref) -> None:
+        """Cleanup callback for weakref when instance is garbage collected."""
+        cls._active_instances.discard(ref)
+
+    @classmethod
+    def get_active_instance_count(cls) -> int:
+        """Get the number of active IntermediateStepManager instances.
+
+        Returns:
+            int: Number of active instances (cleaned up automatically via weakref)
+        """
+        return len(cls._active_instances)
+
+    def get_outstanding_step_count(self) -> int:
+        """Get the number of outstanding (started but not ended) steps.
+
+        Returns:
+            int: Number of steps that have been started but not yet ended
+        """
+        return len(self._outstanding_start_steps)
nat/builder/workflow_builder.py
CHANGED

@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
         self._registry = registry

         self._logging_handlers: dict[str, logging.Handler] = {}
+        self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
         self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}

         self._functions: dict[str, ConfiguredFunction] = {}

@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
         # Get the telemetry info from the config
         telemetry_config = self.general_config.telemetry

+        # If we have logging configuration, we need to manage the root logger properly
+        root_logger = logging.getLogger()
+
+        # Collect configured handler types to determine if we need to adjust existing handlers
+        # This is somewhat of a hack by inspecting the class name of the config object
+        has_console_handler = any(
+            hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
+            for config in telemetry_config.logging.values())
+
         for key, logging_config in telemetry_config.logging.items():
             # Use the same pattern as tracing, but for logging
             logging_info = self._registry.get_logging_method(type(logging_config))

@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
             self._logging_handlers[key] = handler

             # Now attach to NAT's root logger
-
+            root_logger.addHandler(handler)
+
+        # If we added logging handlers, manage existing handlers appropriately
+        if self._logging_handlers:
+            min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
+
+            # Ensure the root logger level allows messages through
+            root_logger.level = max(root_logger.level, min_handler_level)
+
+            # If a console handler is configured, adjust or remove default CLI handlers
+            # to avoid duplicate output while preserving workflow visibility
+            if has_console_handler:
+                # Remove existing StreamHandlers that are not the newly configured ones
+                for handler in root_logger.handlers[:]:
+                    if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
+                        self._removed_root_handlers.append((handler, handler.level))
+                        root_logger.removeHandler(handler)
+            else:
+                # No console handler configured, but adjust existing handler levels
+                # to respect the minimum configured level for file/other handlers
+                for handler in root_logger.handlers[:]:
+                    if type(handler) is logging.StreamHandler:
+                        old_level = handler.level
+                        handler.setLevel(min_handler_level)
+                        self._removed_root_handlers.append((handler, old_level))

         # Add the telemetry exporters
         for key, telemetry_exporter_config in telemetry_config.tracing.items():

@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):

         assert self._exit_stack is not None, "Exit stack not initialized"

-
-
+        root_logger = logging.getLogger()
+
+        # Remove custom logging handlers
+        for handler in self._logging_handlers.values():
+            root_logger.removeHandler(handler)
+
+        # Restore original handlers and their levels
+        for handler, old_level in self._removed_root_handlers:
+            if handler not in root_logger.handlers:
+                root_logger.addHandler(handler)
+            handler.setLevel(old_level)

         await self._exit_stack.__aexit__(*exc_details)

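Teardown mirrors setup: every handler that was removed or re-leveled is recorded as a (handler, level) pair and put back on exit. The save/remove/restore idiom in isolation (names here are illustrative):

import logging

root = logging.getLogger()
removed: list[tuple[logging.Handler, int]] = []

for handler in root.handlers[:]:  # iterate over a copy while mutating
    if type(handler) is logging.StreamHandler:
        removed.append((handler, handler.level))
        root.removeHandler(handler)

# ... run with only the configured handlers attached ...

for handler, old_level in removed:  # restore on teardown
    if handler not in root.handlers:
        root.addHandler(handler)
    handler.setLevel(old_level)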
nat/cli/entrypoint.py
CHANGED

@@ -29,6 +29,7 @@ import time

 import click
 import nest_asyncio
+from dotenv import load_dotenv

 from nat.utils.log_levels import LOG_LEVELS

@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
 from .commands.validate import validate_command
 from .commands.workflow.workflow import workflow_command

+# Load environment variables from .env file, if it exists
+load_dotenv()
+
 # Apply at the beginning of the file to avoid issues with asyncio
 nest_asyncio.apply()

@@ -52,7 +56,11 @@ nest_asyncio.apply()
 def setup_logging(log_level: str):
     """Configure logging with the specified level"""
     numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
-    logging.basicConfig(
+    logging.basicConfig(
+        level=numeric_level,
+        format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
     return numeric_level

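The CLI now loads a local .env file before registering commands and passes an explicit format to basicConfig. A minimal, self-contained check of what the new format produces; the printed line is illustrative:

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Emits something like:
# 2025-10-11 09:30:00 - INFO     - nat.cli.entrypoint:10 - starting workflow
logging.getLogger("nat.cli.entrypoint").info("starting workflow")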
nat/data_models/api_server.py
CHANGED

@@ -28,6 +28,7 @@ from pydantic import HttpUrl
 from pydantic import conlist
 from pydantic import field_serializer
 from pydantic import field_validator
+from pydantic import model_validator
 from pydantic_core.core_schema import ValidationInfo

 from nat.data_models.interactive import HumanPrompt

@@ -120,15 +121,7 @@ class Message(BaseModel):
     role: UserMessageContentRoleType


-class ChatRequest(BaseModel):
-    """
-    ChatRequest is a data model that represents a request to the NAT chat API.
-    Fully compatible with OpenAI Chat Completions API specification.
-    """
-
-    # Required fields
-    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
-
+class ChatRequestOptionals(BaseModel):
     # Optional fields (OpenAI Chat Completions API compatible)
     model: str | None = Field(default=None, description="name of the model to use")
     frequency_penalty: float | None = Field(default=0.0,

@@ -153,6 +146,16 @@ class ChatRequest(BaseModel):
     parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
     user: str | None = Field(default=None, description="Unique identifier representing end-user")

+
+class ChatRequest(ChatRequestOptionals):
+    """
+    ChatRequest is a data model that represents a request to the NAT chat API.
+    Fully compatible with OpenAI Chat Completions API specification.
+    """
+
+    # Required fields
+    messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
+
     model_config = ConfigDict(extra="allow",
                               json_schema_extra={
                                   "example": {

@@ -194,6 +197,42 @@ class ChatRequest(BaseModel):
             top_p=top_p)


+class ChatRequestOrMessage(ChatRequestOptionals):
+    """
+    ChatRequestOrMessage is a data model that represents either a conversation or a string input.
+    This is useful for functions that can handle either type of input.
+
+    `messages` is compatible with the OpenAI Chat Completions API specification.
+
+    `input_string` is a string input that can be used for functions that do not require a conversation.
+    """
+
+    messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
+        default=None, description="The conversation messages to process.")
+
+    input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
+
+    @property
+    def is_string(self) -> bool:
+        return self.input_string is not None
+
+    @property
+    def is_conversation(self) -> bool:
+        return self.messages is not None
+
+    @model_validator(mode="after")
+    def validate_messages_or_input_string(self):
+        if self.messages is not None and self.input_string is not None:
+            raise ValueError("Either messages or input_message/input_string must be provided, not both")
+        if self.messages is None and self.input_string is None:
+            raise ValueError("Either messages or input_message/input_string must be provided")
+        if self.input_string is not None:
+            extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
+            if len(extra_fields) > 0:
+                raise ValueError("no extra fields are permitted when input_message/input_string is provided")
+        return self
+
+
 class ChoiceMessage(BaseModel):
     content: str | None = None
     role: UserMessageContentRoleType | None = None

@@ -661,6 +700,36 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
 GlobalTypeConverter.register_converter(_string_to_nat_chat_request)


+def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
+    if data.input_string is not None:
+        return _string_to_nat_chat_request(data.input_string)
+    return ChatRequest(**data.model_dump(exclude={"input_string"}))
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
+
+
+def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(**data.model_dump(by_alias=True))
+
+
+GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
+
+
+def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
+    return data.input_string or ""
+
+
+GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
+
+
+def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
+    return ChatRequestOrMessage(input_message=data)
+
+
+GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
+
+
 # ======== ChatResponse Converters ========
 def _nat_chat_response_to_string(data: ChatResponse) -> str:
     if data.choices and data.choices[0].message:
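Together the four converters make ChatRequestOrMessage interchangeable with both str and ChatRequest. A sketch of the round trip, assuming _string_to_nat_chat_request places the string in the first message's content:

from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatRequestOrMessage
from nat.utils.type_converter import GlobalTypeConverter

converter = GlobalTypeConverter.get()

# str -> ChatRequestOrMessage -> ChatRequest
wrapped = converter.convert("hello", to_type=ChatRequestOrMessage)
request = converter.convert(wrapped, to_type=ChatRequest)
assert request.messages[0].content == "hello"

# ... and back to a plain string
assert converter.convert(wrapped, to_type=str) == "hello"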
nat/data_models/config.py
CHANGED

@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):

 class GeneralConfig(BaseModel):

-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

     use_uvloop: bool | None = Field(
         default=None,
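With extra="forbid", unrecognized keys in the general config section now fail validation instead of being silently dropped. The effect in miniature (General is a stand-in model):

from pydantic import BaseModel, ConfigDict, ValidationError

class General(BaseModel):
    model_config = ConfigDict(protected_namespaces=(), extra="forbid")
    use_uvloop: bool | None = None

General(use_uvloop=True)          # accepted
try:
    General(use_uvlop=True)       # typo is now rejected
except ValidationError as e:
    print(e.errors()[0]["type"])  # extra_forbidden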
nat/front_ends/console/console_front_end_plugin.py
CHANGED

@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
         else:
             assert False, "Should not reach here. Should have been caught by pre_run"

-
-
+        line = f"{'-' * 50}"
+        prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
+        suffix = f"{Fore.RESET}\n{line}"
+
+        logger.info(f"{prefix}%s{suffix}", runner_outputs)
+
+        # (handler is a stream handler) => (level > INFO)
+        effective_level_too_high = all(
+            type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
+        if effective_level_too_high:
+            print(f"{prefix}{runner_outputs}{suffix}")
nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py
CHANGED

@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
     async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:

         raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
-                                  f" Do you have
+                                  f" Do you have WebSockets enabled?")
nat/front_ends/mcp/mcp_front_end_config.py
CHANGED

@@ -43,3 +43,16 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):

     server_auth: OAuth2ResourceServerConfig | None = Field(
         default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
+
+    # Memory profiling configuration
+    enable_memory_profiling: bool = Field(default=False,
+                                          description="Enable memory profiling and diagnostics (default: False)")
+    memory_profile_interval: int = Field(default=50,
+                                         description="Log memory stats every N requests (default: 50)",
+                                         ge=1)
+    memory_profile_top_n: int = Field(default=10,
+                                      description="Number of top memory allocations to log (default: 10)",
+                                      ge=1,
+                                      le=50)
+    memory_profile_log_level: str = Field(default="DEBUG",
+                                          description="Log level for memory profiling output (default: DEBUG)")
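A sketch of turning the new profiling knobs on programmatically; the values are illustrative and the remaining fields are assumed to keep their defaults:

from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig

config = MCPFrontEndConfig(
    enable_memory_profiling=True,  # off by default
    memory_profile_interval=100,   # log stats every 100 requests
    memory_profile_top_n=5,        # log the 5 largest allocation sites
    memory_profile_log_level="INFO",
)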
nat/front_ends/mcp/mcp_front_end_plugin_worker.py
CHANGED

@@ -29,6 +29,7 @@ from nat.builder.workflow import Workflow
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.config import Config
 from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
+from nat.front_ends.mcp.memory_profiler import MemoryProfiler

 logger = logging.getLogger(__name__)

@@ -45,6 +46,12 @@ class MCPFrontEndPluginWorkerBase(ABC):
         self.full_config = config
         self.front_end_config: MCPFrontEndConfig = config.general.front_end

+        # Initialize memory profiler if enabled
+        self.memory_profiler = MemoryProfiler(enabled=self.front_end_config.enable_memory_profiling,
+                                              log_interval=self.front_end_config.memory_profile_interval,
+                                              top_n=self.front_end_config.memory_profile_top_n,
+                                              log_level=self.front_end_config.memory_profile_log_level)
+
     def _setup_health_endpoint(self, mcp: FastMCP):
         """Set up the HTTP health endpoint that exercises MCP ping handler."""

@@ -115,6 +122,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
         Exposes:
         - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
           selects a subset and returns details for those tools.
+        - GET /debug/memory/stats: Get current memory profiling statistics (read-only)
         """

         @mcp.custom_route("/debug/tools/list", methods=["GET"])

@@ -206,6 +214,15 @@ class MCPFrontEndPluginWorkerBase(ABC):
             return JSONResponse(
                 _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))

+        # Memory profiling endpoint (read-only)
+        @mcp.custom_route("/debug/memory/stats", methods=["GET"])
+        async def get_memory_stats(_request: Request):
+            """Get current memory profiling statistics."""
+            from starlette.responses import JSONResponse
+
+            stats = self.memory_profiler.get_stats()
+            return JSONResponse(stats)
+

 class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
     """Default MCP front end plugin worker implementation."""

@@ -241,7 +258,7 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):

         # Register each function with MCP, passing workflow context for observability
         for function_name, function in functions.items():
-            register_function_with_mcp(mcp, function_name, function, workflow)
+            register_function_with_mcp(mcp, function_name, function, workflow, self.memory_profiler)

         # Add a simple fallback function if no functions were found
         if not functions:
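Once profiling is enabled, the read-only endpoint can be polled out of process. A sketch using only the standard library; the host and port are assumptions, so substitute the address the MCP server is actually bound to:

import json
import urllib.request

# Assumed local address; adjust to your server's host/port.
with urllib.request.urlopen("http://localhost:9901/debug/memory/stats") as resp:
    stats = json.load(resp)
print(stats)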
|