nvidia-nat 1.4.0a20251010__py3-none-any.whl → 1.4.0a20251011__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +15 -24
- nat/agent/rewoo_agent/register.py +15 -24
- nat/agent/tool_calling_agent/register.py +9 -5
- nat/builder/function.py +4 -4
- nat/builder/intermediate_step_manager.py +32 -0
- nat/data_models/api_server.py +78 -9
- nat/front_ends/mcp/mcp_front_end_config.py +13 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +18 -1
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +21 -2
- nat/runtime/runner.py +1 -2
- nat/utils/type_converter.py +8 -0
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/METADATA +3 -3
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/RECORD +19 -18
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/top_level.txt +0 -0
|
@@ -24,6 +24,7 @@ from nat.builder.function_info import FunctionInfo
|
|
|
24
24
|
from nat.cli.register_workflow import register_function
|
|
25
25
|
from nat.data_models.agent import AgentBaseConfig
|
|
26
26
|
from nat.data_models.api_server import ChatRequest
|
|
27
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
27
28
|
from nat.data_models.api_server import ChatResponse
|
|
28
29
|
from nat.data_models.api_server import Usage
|
|
29
30
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
@@ -70,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
|
|
|
70
71
|
default=None,
|
|
71
72
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
72
73
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
73
|
-
use_openai_api: bool = Field(default=False,
|
|
74
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
75
|
-
"If False, strings will be used."))
|
|
76
74
|
additional_instructions: str | None = OptimizableField(
|
|
77
75
|
default=None,
|
|
78
76
|
description="Additional instructions to provide to the agent in addition to the base prompt.",
|
|
@@ -118,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
118
116
|
pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
|
|
119
117
|
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
120
118
|
|
|
121
|
-
async def _response_fn(
|
|
119
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
122
120
|
"""
|
|
123
121
|
Main workflow entry function for the ReAct Agent.
|
|
124
122
|
|
|
125
123
|
This function invokes the ReAct Agent Graph and returns the response.
|
|
126
124
|
|
|
127
125
|
Args:
|
|
128
|
-
|
|
126
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
129
127
|
|
|
130
128
|
Returns:
|
|
131
|
-
ChatResponse: The response from the agent or error message
|
|
129
|
+
ChatResponse | str: The response from the agent or error message
|
|
132
130
|
"""
|
|
133
131
|
try:
|
|
132
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
133
|
+
|
|
134
134
|
# initialize the starting state with the user query
|
|
135
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
135
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
136
136
|
max_tokens=config.max_history,
|
|
137
137
|
strategy="last",
|
|
138
138
|
token_counter=len,
|
|
@@ -153,25 +153,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
153
153
|
content = str(output_message.content)
|
|
154
154
|
|
|
155
155
|
# Create usage statistics for the response
|
|
156
|
-
prompt_tokens = sum(len(str(msg.content).split()) for msg in
|
|
156
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
157
157
|
completion_tokens = len(content.split()) if content else 0
|
|
158
158
|
total_tokens = prompt_tokens + completion_tokens
|
|
159
159
|
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
160
|
-
|
|
161
|
-
|
|
160
|
+
response = ChatResponse.from_string(content, usage=usage)
|
|
161
|
+
if chat_request_or_message.is_string:
|
|
162
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
163
|
+
return response
|
|
162
164
|
except Exception as ex:
|
|
163
|
-
logger.
|
|
164
|
-
raise
|
|
165
|
-
|
|
166
|
-
if (config.use_openai_api):
|
|
167
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
168
|
-
else:
|
|
169
|
-
|
|
170
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
171
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
172
|
-
|
|
173
|
-
oai_output = await _response_fn(oai_input)
|
|
174
|
-
|
|
175
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
165
|
+
logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
166
|
+
raise
|
|
176
167
|
|
|
177
|
-
|
|
168
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -25,6 +25,7 @@ from nat.builder.function_info import FunctionInfo
|
|
|
25
25
|
from nat.cli.register_workflow import register_function
|
|
26
26
|
from nat.data_models.agent import AgentBaseConfig
|
|
27
27
|
from nat.data_models.api_server import ChatRequest
|
|
28
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
28
29
|
from nat.data_models.api_server import ChatResponse
|
|
29
30
|
from nat.data_models.api_server import Usage
|
|
30
31
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
@@ -54,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
|
|
|
54
55
|
description="The number of retries before raising a tool call error.",
|
|
55
56
|
ge=1)
|
|
56
57
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
57
|
-
use_openai_api: bool = Field(default=False,
|
|
58
|
-
description=("Use OpenAI API for the input/output types to the function. "
|
|
59
|
-
"If False, strings will be used."))
|
|
60
58
|
additional_planner_instructions: str | None = Field(
|
|
61
59
|
default=None,
|
|
62
60
|
validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),
|
|
@@ -125,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
125
123
|
tool_call_max_retries=config.tool_call_max_retries,
|
|
126
124
|
raise_tool_call_error=config.raise_tool_call_error).build_graph()
|
|
127
125
|
|
|
128
|
-
async def _response_fn(
|
|
126
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
|
|
129
127
|
"""
|
|
130
128
|
Main workflow entry function for the ReWOO Agent.
|
|
131
129
|
|
|
132
130
|
This function invokes the ReWOO Agent Graph and returns the response.
|
|
133
131
|
|
|
134
132
|
Args:
|
|
135
|
-
|
|
133
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
136
134
|
|
|
137
135
|
Returns:
|
|
138
|
-
ChatResponse: The response from the agent or error message
|
|
136
|
+
ChatResponse | str: The response from the agent or error message
|
|
139
137
|
"""
|
|
140
138
|
try:
|
|
139
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
140
|
+
|
|
141
141
|
# initialize the starting state with the user query
|
|
142
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
142
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
143
143
|
max_tokens=config.max_history,
|
|
144
144
|
strategy="last",
|
|
145
145
|
token_counter=len,
|
|
@@ -160,25 +160,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
160
160
|
output_message = str(output_message)
|
|
161
161
|
|
|
162
162
|
# Create usage statistics for the response
|
|
163
|
-
prompt_tokens = sum(len(str(msg.content).split()) for msg in
|
|
163
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
|
164
164
|
completion_tokens = len(output_message.split()) if output_message else 0
|
|
165
165
|
total_tokens = prompt_tokens + completion_tokens
|
|
166
166
|
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
167
|
-
|
|
168
|
-
|
|
167
|
+
response = ChatResponse.from_string(output_message, usage=usage)
|
|
168
|
+
if chat_request_or_message.is_string:
|
|
169
|
+
return GlobalTypeConverter.get().convert(response, to_type=str)
|
|
170
|
+
return response
|
|
169
171
|
except Exception as ex:
|
|
170
|
-
logger.
|
|
171
|
-
raise
|
|
172
|
-
|
|
173
|
-
if (config.use_openai_api):
|
|
174
|
-
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
175
|
-
|
|
176
|
-
else:
|
|
177
|
-
|
|
178
|
-
async def _str_api_fn(input_message: str) -> str:
|
|
179
|
-
oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
|
|
180
|
-
oai_output = await _response_fn(oai_input)
|
|
181
|
-
|
|
182
|
-
return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
|
|
172
|
+
logger.error("ReWOO Agent failed with exception: %s", ex)
|
|
173
|
+
raise
|
|
183
174
|
|
|
184
|
-
|
|
175
|
+
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
|
|
|
23
23
|
from nat.cli.register_workflow import register_function
|
|
24
24
|
from nat.data_models.agent import AgentBaseConfig
|
|
25
25
|
from nat.data_models.api_server import ChatRequest
|
|
26
|
+
from nat.data_models.api_server import ChatRequestOrMessage
|
|
26
27
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
27
28
|
from nat.data_models.component_ref import FunctionRef
|
|
29
|
+
from nat.utils.type_converter import GlobalTypeConverter
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
81
83
|
handle_tool_errors=config.handle_tool_errors,
|
|
82
84
|
return_direct=return_direct_tools).build_graph()
|
|
83
85
|
|
|
84
|
-
async def _response_fn(
|
|
86
|
+
async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
|
|
85
87
|
"""
|
|
86
88
|
Main workflow entry function for the Tool Calling Agent.
|
|
87
89
|
|
|
88
90
|
This function invokes the Tool Calling Agent Graph and returns the response.
|
|
89
91
|
|
|
90
92
|
Args:
|
|
91
|
-
|
|
93
|
+
chat_request_or_message (ChatRequestOrMessage): The input message to process
|
|
92
94
|
|
|
93
95
|
Returns:
|
|
94
96
|
str: The response from the agent or error message
|
|
95
97
|
"""
|
|
96
98
|
try:
|
|
99
|
+
message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
|
|
100
|
+
|
|
97
101
|
# initialize the starting state with the user query
|
|
98
|
-
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in
|
|
102
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
|
|
99
103
|
max_tokens=config.max_history,
|
|
100
104
|
strategy="last",
|
|
101
105
|
token_counter=len,
|
|
@@ -114,8 +118,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
114
118
|
output_message = state.messages[-1]
|
|
115
119
|
return str(output_message.content)
|
|
116
120
|
except Exception as ex:
|
|
117
|
-
logger.
|
|
118
|
-
raise
|
|
121
|
+
logger.error("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
|
|
122
|
+
raise
|
|
119
123
|
|
|
120
124
|
try:
|
|
121
125
|
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
nat/builder/function.py
CHANGED
|
@@ -159,8 +159,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
159
159
|
|
|
160
160
|
return result
|
|
161
161
|
except Exception as e:
|
|
162
|
-
|
|
163
|
-
logger.error("Error with ainvoke in function with input: %s. %s", value, err_msg)
|
|
162
|
+
logger.error("Error with ainvoke in function with input: %s. Error: %s", value, e)
|
|
164
163
|
raise
|
|
165
164
|
|
|
166
165
|
@typing.final
|
|
@@ -416,8 +415,9 @@ class FunctionGroup:
|
|
|
416
415
|
"""
|
|
417
416
|
if not name.strip():
|
|
418
417
|
raise ValueError("Function name cannot be empty or blank")
|
|
419
|
-
if not re.match(r"^[a-zA-Z0-9_
|
|
420
|
-
raise ValueError(
|
|
418
|
+
if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
|
|
419
|
+
raise ValueError(
|
|
420
|
+
f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
|
|
421
421
|
if name in self._functions:
|
|
422
422
|
raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
|
|
423
423
|
|
|
@@ -16,6 +16,8 @@
|
|
|
16
16
|
import dataclasses
|
|
17
17
|
import logging
|
|
18
18
|
import typing
|
|
19
|
+
import weakref
|
|
20
|
+
from typing import ClassVar
|
|
19
21
|
|
|
20
22
|
from nat.data_models.intermediate_step import IntermediateStep
|
|
21
23
|
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
@@ -46,11 +48,19 @@ class IntermediateStepManager:
|
|
|
46
48
|
Manages updates to the NAT Event Stream for intermediate steps
|
|
47
49
|
"""
|
|
48
50
|
|
|
51
|
+
# Class-level tracking for debugging and monitoring
|
|
52
|
+
_instance_count: ClassVar[int] = 0
|
|
53
|
+
_active_instances: ClassVar[set[weakref.ref]] = set()
|
|
54
|
+
|
|
49
55
|
def __init__(self, context_state: "ContextState"): # noqa: F821
|
|
50
56
|
self._context_state = context_state
|
|
51
57
|
|
|
52
58
|
self._outstanding_start_steps: dict[str, OpenStep] = {}
|
|
53
59
|
|
|
60
|
+
# Track instance creation
|
|
61
|
+
IntermediateStepManager._instance_count += 1
|
|
62
|
+
IntermediateStepManager._active_instances.add(weakref.ref(self, self._cleanup_instance_tracking))
|
|
63
|
+
|
|
54
64
|
def push_intermediate_step(self, payload: IntermediateStepPayload) -> None:
|
|
55
65
|
"""
|
|
56
66
|
Pushes an intermediate step to the NAT Event Stream
|
|
@@ -172,3 +182,25 @@ class IntermediateStepManager:
|
|
|
172
182
|
"""
|
|
173
183
|
|
|
174
184
|
return self._context_state.event_stream.get().subscribe(on_next, on_error, on_complete)
|
|
185
|
+
|
|
186
|
+
@classmethod
|
|
187
|
+
def _cleanup_instance_tracking(cls, ref: weakref.ref) -> None:
|
|
188
|
+
"""Cleanup callback for weakref when instance is garbage collected."""
|
|
189
|
+
cls._active_instances.discard(ref)
|
|
190
|
+
|
|
191
|
+
@classmethod
|
|
192
|
+
def get_active_instance_count(cls) -> int:
|
|
193
|
+
"""Get the number of active IntermediateStepManager instances.
|
|
194
|
+
|
|
195
|
+
Returns:
|
|
196
|
+
int: Number of active instances (cleaned up automatically via weakref)
|
|
197
|
+
"""
|
|
198
|
+
return len(cls._active_instances)
|
|
199
|
+
|
|
200
|
+
def get_outstanding_step_count(self) -> int:
|
|
201
|
+
"""Get the number of outstanding (started but not ended) steps.
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
int: Number of steps that have been started but not yet ended
|
|
205
|
+
"""
|
|
206
|
+
return len(self._outstanding_start_steps)
|
nat/data_models/api_server.py
CHANGED
|
@@ -28,6 +28,7 @@ from pydantic import HttpUrl
|
|
|
28
28
|
from pydantic import conlist
|
|
29
29
|
from pydantic import field_serializer
|
|
30
30
|
from pydantic import field_validator
|
|
31
|
+
from pydantic import model_validator
|
|
31
32
|
from pydantic_core.core_schema import ValidationInfo
|
|
32
33
|
|
|
33
34
|
from nat.data_models.interactive import HumanPrompt
|
|
@@ -120,15 +121,7 @@ class Message(BaseModel):
|
|
|
120
121
|
role: UserMessageContentRoleType
|
|
121
122
|
|
|
122
123
|
|
|
123
|
-
class
|
|
124
|
-
"""
|
|
125
|
-
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
126
|
-
Fully compatible with OpenAI Chat Completions API specification.
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
# Required fields
|
|
130
|
-
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
131
|
-
|
|
124
|
+
class ChatRequestOptionals(BaseModel):
|
|
132
125
|
# Optional fields (OpenAI Chat Completions API compatible)
|
|
133
126
|
model: str | None = Field(default=None, description="name of the model to use")
|
|
134
127
|
frequency_penalty: float | None = Field(default=0.0,
|
|
@@ -153,6 +146,16 @@ class ChatRequest(BaseModel):
|
|
|
153
146
|
parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
|
|
154
147
|
user: str | None = Field(default=None, description="Unique identifier representing end-user")
|
|
155
148
|
|
|
149
|
+
|
|
150
|
+
class ChatRequest(ChatRequestOptionals):
|
|
151
|
+
"""
|
|
152
|
+
ChatRequest is a data model that represents a request to the NAT chat API.
|
|
153
|
+
Fully compatible with OpenAI Chat Completions API specification.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
# Required fields
|
|
157
|
+
messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
|
|
158
|
+
|
|
156
159
|
model_config = ConfigDict(extra="allow",
|
|
157
160
|
json_schema_extra={
|
|
158
161
|
"example": {
|
|
@@ -194,6 +197,42 @@ class ChatRequest(BaseModel):
|
|
|
194
197
|
top_p=top_p)
|
|
195
198
|
|
|
196
199
|
|
|
200
|
+
class ChatRequestOrMessage(ChatRequestOptionals):
|
|
201
|
+
"""
|
|
202
|
+
ChatRequestOrMessage is a data model that represents either a conversation or a string input.
|
|
203
|
+
This is useful for functions that can handle either type of input.
|
|
204
|
+
|
|
205
|
+
`messages` is compatible with the OpenAI Chat Completions API specification.
|
|
206
|
+
|
|
207
|
+
`input_string` is a string input that can be used for functions that do not require a conversation.
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
|
|
211
|
+
default=None, description="The conversation messages to process.")
|
|
212
|
+
|
|
213
|
+
input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
|
|
214
|
+
|
|
215
|
+
@property
|
|
216
|
+
def is_string(self) -> bool:
|
|
217
|
+
return self.input_string is not None
|
|
218
|
+
|
|
219
|
+
@property
|
|
220
|
+
def is_conversation(self) -> bool:
|
|
221
|
+
return self.messages is not None
|
|
222
|
+
|
|
223
|
+
@model_validator(mode="after")
|
|
224
|
+
def validate_messages_or_input_string(self):
|
|
225
|
+
if self.messages is not None and self.input_string is not None:
|
|
226
|
+
raise ValueError("Either messages or input_message/input_string must be provided, not both")
|
|
227
|
+
if self.messages is None and self.input_string is None:
|
|
228
|
+
raise ValueError("Either messages or input_message/input_string must be provided")
|
|
229
|
+
if self.input_string is not None:
|
|
230
|
+
extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
|
|
231
|
+
if len(extra_fields) > 0:
|
|
232
|
+
raise ValueError("no extra fields are permitted when input_message/input_string is provided")
|
|
233
|
+
return self
|
|
234
|
+
|
|
235
|
+
|
|
197
236
|
class ChoiceMessage(BaseModel):
|
|
198
237
|
content: str | None = None
|
|
199
238
|
role: UserMessageContentRoleType | None = None
|
|
@@ -661,6 +700,36 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
|
|
|
661
700
|
GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
|
|
662
701
|
|
|
663
702
|
|
|
703
|
+
def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
|
|
704
|
+
if data.input_string is not None:
|
|
705
|
+
return _string_to_nat_chat_request(data.input_string)
|
|
706
|
+
return ChatRequest(**data.model_dump(exclude={"input_string"}))
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
|
|
713
|
+
return ChatRequestOrMessage(**data.model_dump(by_alias=True))
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
|
|
720
|
+
return data.input_string or ""
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
|
|
727
|
+
return ChatRequestOrMessage(input_message=data)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
|
|
731
|
+
|
|
732
|
+
|
|
664
733
|
# ======== ChatResponse Converters ========
|
|
665
734
|
def _nat_chat_response_to_string(data: ChatResponse) -> str:
|
|
666
735
|
if data.choices and data.choices[0].message:
|
|
@@ -43,3 +43,16 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
43
43
|
|
|
44
44
|
server_auth: OAuth2ResourceServerConfig | None = Field(
|
|
45
45
|
default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
|
|
46
|
+
|
|
47
|
+
# Memory profiling configuration
|
|
48
|
+
enable_memory_profiling: bool = Field(default=False,
|
|
49
|
+
description="Enable memory profiling and diagnostics (default: False)")
|
|
50
|
+
memory_profile_interval: int = Field(default=50,
|
|
51
|
+
description="Log memory stats every N requests (default: 50)",
|
|
52
|
+
ge=1)
|
|
53
|
+
memory_profile_top_n: int = Field(default=10,
|
|
54
|
+
description="Number of top memory allocations to log (default: 10)",
|
|
55
|
+
ge=1,
|
|
56
|
+
le=50)
|
|
57
|
+
memory_profile_log_level: str = Field(default="DEBUG",
|
|
58
|
+
description="Log level for memory profiling output (default: DEBUG)")
|
|
@@ -29,6 +29,7 @@ from nat.builder.workflow import Workflow
|
|
|
29
29
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
30
30
|
from nat.data_models.config import Config
|
|
31
31
|
from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
|
|
32
|
+
from nat.front_ends.mcp.memory_profiler import MemoryProfiler
|
|
32
33
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
@@ -45,6 +46,12 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
45
46
|
self.full_config = config
|
|
46
47
|
self.front_end_config: MCPFrontEndConfig = config.general.front_end
|
|
47
48
|
|
|
49
|
+
# Initialize memory profiler if enabled
|
|
50
|
+
self.memory_profiler = MemoryProfiler(enabled=self.front_end_config.enable_memory_profiling,
|
|
51
|
+
log_interval=self.front_end_config.memory_profile_interval,
|
|
52
|
+
top_n=self.front_end_config.memory_profile_top_n,
|
|
53
|
+
log_level=self.front_end_config.memory_profile_log_level)
|
|
54
|
+
|
|
48
55
|
def _setup_health_endpoint(self, mcp: FastMCP):
|
|
49
56
|
"""Set up the HTTP health endpoint that exercises MCP ping handler."""
|
|
50
57
|
|
|
@@ -115,6 +122,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
115
122
|
Exposes:
|
|
116
123
|
- GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
|
|
117
124
|
selects a subset and returns details for those tools.
|
|
125
|
+
- GET /debug/memory/stats: Get current memory profiling statistics (read-only)
|
|
118
126
|
"""
|
|
119
127
|
|
|
120
128
|
@mcp.custom_route("/debug/tools/list", methods=["GET"])
|
|
@@ -206,6 +214,15 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
206
214
|
return JSONResponse(
|
|
207
215
|
_build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
|
|
208
216
|
|
|
217
|
+
# Memory profiling endpoint (read-only)
|
|
218
|
+
@mcp.custom_route("/debug/memory/stats", methods=["GET"])
|
|
219
|
+
async def get_memory_stats(_request: Request):
|
|
220
|
+
"""Get current memory profiling statistics."""
|
|
221
|
+
from starlette.responses import JSONResponse
|
|
222
|
+
|
|
223
|
+
stats = self.memory_profiler.get_stats()
|
|
224
|
+
return JSONResponse(stats)
|
|
225
|
+
|
|
209
226
|
|
|
210
227
|
class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
211
228
|
"""Default MCP front end plugin worker implementation."""
|
|
@@ -241,7 +258,7 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
241
258
|
|
|
242
259
|
# Register each function with MCP, passing workflow context for observability
|
|
243
260
|
for function_name, function in functions.items():
|
|
244
|
-
register_function_with_mcp(mcp, function_name, function, workflow)
|
|
261
|
+
register_function_with_mcp(mcp, function_name, function, workflow, self.memory_profiler)
|
|
245
262
|
|
|
246
263
|
# Add a simple fallback function if no functions were found
|
|
247
264
|
if not functions:
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Memory profiling utilities for MCP frontend."""
|
|
16
|
+
|
|
17
|
+
import gc
|
|
18
|
+
import logging
|
|
19
|
+
import tracemalloc
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MemoryProfiler:
|
|
26
|
+
"""Memory profiler for tracking memory usage and potential leaks."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, enabled: bool = False, log_interval: int = 50, top_n: int = 10, log_level: str = "DEBUG"):
|
|
29
|
+
"""Initialize the memory profiler.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
enabled: Whether memory profiling is enabled
|
|
33
|
+
log_interval: Log stats every N requests
|
|
34
|
+
top_n: Number of top allocations to log
|
|
35
|
+
log_level: Log level for memory profiling output (e.g., "DEBUG", "INFO")
|
|
36
|
+
"""
|
|
37
|
+
self.enabled = enabled
|
|
38
|
+
# normalize interval to avoid modulo-by-zero
|
|
39
|
+
self.log_interval = max(1, int(log_interval))
|
|
40
|
+
self.top_n = top_n
|
|
41
|
+
self.log_level = getattr(logging, log_level.upper(), logging.DEBUG)
|
|
42
|
+
self.request_count = 0
|
|
43
|
+
self.baseline_snapshot = None
|
|
44
|
+
|
|
45
|
+
# Track whether this instance started tracemalloc (to avoid resetting external tracing)
|
|
46
|
+
self._we_started_tracemalloc = False
|
|
47
|
+
|
|
48
|
+
if self.enabled:
|
|
49
|
+
logger.info("Memory profiling ENABLED (interval=%d, top_n=%d, log_level=%s)",
|
|
50
|
+
self.log_interval,
|
|
51
|
+
top_n,
|
|
52
|
+
log_level)
|
|
53
|
+
try:
|
|
54
|
+
if not tracemalloc.is_tracing():
|
|
55
|
+
tracemalloc.start()
|
|
56
|
+
self._we_started_tracemalloc = True
|
|
57
|
+
# Take baseline snapshot
|
|
58
|
+
gc.collect()
|
|
59
|
+
self.baseline_snapshot = tracemalloc.take_snapshot()
|
|
60
|
+
except RuntimeError as e:
|
|
61
|
+
logger.warning("tracemalloc unavailable or not tracing: %s", e)
|
|
62
|
+
else:
|
|
63
|
+
logger.info("Memory profiling DISABLED")
|
|
64
|
+
|
|
65
|
+
def _log(self, message: str, *args: Any) -> None:
|
|
66
|
+
"""Log a message at the configured log level.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
message: Log message format string
|
|
70
|
+
args: Arguments for the format string
|
|
71
|
+
"""
|
|
72
|
+
logger.log(self.log_level, message, *args)
|
|
73
|
+
|
|
74
|
+
def on_request_complete(self) -> None:
|
|
75
|
+
"""Called after each request completes."""
|
|
76
|
+
if not self.enabled:
|
|
77
|
+
return
|
|
78
|
+
self.request_count += 1
|
|
79
|
+
if self.request_count % self.log_interval == 0:
|
|
80
|
+
self.log_memory_stats()
|
|
81
|
+
|
|
82
|
+
def _ensure_tracing(self) -> bool:
|
|
83
|
+
"""Ensure tracemalloc is running if we started it originally.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
True if tracemalloc is active, False otherwise
|
|
87
|
+
"""
|
|
88
|
+
if tracemalloc.is_tracing():
|
|
89
|
+
return True
|
|
90
|
+
|
|
91
|
+
# Only restart if we started it originally (respect external control)
|
|
92
|
+
if not self._we_started_tracemalloc:
|
|
93
|
+
return False
|
|
94
|
+
|
|
95
|
+
# Attempt to restart
|
|
96
|
+
try:
|
|
97
|
+
logger.warning("tracemalloc was stopped externally; restarting (we started it originally)")
|
|
98
|
+
tracemalloc.start()
|
|
99
|
+
|
|
100
|
+
# Reset baseline since old tracking data is lost
|
|
101
|
+
gc.collect()
|
|
102
|
+
self.baseline_snapshot = tracemalloc.take_snapshot()
|
|
103
|
+
logger.info("Baseline snapshot reset after tracemalloc restart")
|
|
104
|
+
|
|
105
|
+
return True
|
|
106
|
+
except RuntimeError as e:
|
|
107
|
+
logger.error("Failed to restart tracemalloc: %s", e)
|
|
108
|
+
return False
|
|
109
|
+
|
|
110
|
+
def _safe_traced_memory(self) -> tuple[float, float] | None:
|
|
111
|
+
"""Return (current, peak usage in MB) if tracemalloc is active, else None."""
|
|
112
|
+
if not self._ensure_tracing():
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
try:
|
|
116
|
+
current, peak = tracemalloc.get_traced_memory()
|
|
117
|
+
megabyte = (1 << 20)
|
|
118
|
+
return (current / megabyte, peak / megabyte)
|
|
119
|
+
except RuntimeError:
|
|
120
|
+
return None
|
|
121
|
+
|
|
122
|
+
def _safe_snapshot(self) -> tracemalloc.Snapshot | None:
|
|
123
|
+
"""Return a tracemalloc Snapshot if available, else None."""
|
|
124
|
+
if not self._ensure_tracing():
|
|
125
|
+
return None
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
return tracemalloc.take_snapshot()
|
|
129
|
+
except RuntimeError:
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
def log_memory_stats(self) -> dict[str, Any]:
|
|
133
|
+
"""Log current memory statistics and return them."""
|
|
134
|
+
if not self.enabled:
|
|
135
|
+
return {}
|
|
136
|
+
|
|
137
|
+
# Force garbage collection first
|
|
138
|
+
gc.collect()
|
|
139
|
+
|
|
140
|
+
# Get current memory usage
|
|
141
|
+
mem = self._safe_traced_memory()
|
|
142
|
+
if mem is None:
|
|
143
|
+
logger.info("tracemalloc is not active; cannot collect memory stats.")
|
|
144
|
+
# still return structural fields
|
|
145
|
+
stats = {
|
|
146
|
+
"request_count": self.request_count,
|
|
147
|
+
"current_memory_mb": None,
|
|
148
|
+
"peak_memory_mb": None,
|
|
149
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
150
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
151
|
+
"active_exporters": self._safe_exporter_count(),
|
|
152
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
153
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
154
|
+
}
|
|
155
|
+
return stats
|
|
156
|
+
|
|
157
|
+
current_mb, peak_mb = mem
|
|
158
|
+
|
|
159
|
+
# Take snapshot and compare to baseline
|
|
160
|
+
snapshot = self._safe_snapshot()
|
|
161
|
+
|
|
162
|
+
# Track BaseExporter instances (observability layer)
|
|
163
|
+
exporter_count = self._safe_exporter_count()
|
|
164
|
+
isolated_exporter_count = self._safe_isolated_exporter_count()
|
|
165
|
+
|
|
166
|
+
# Track Subject instances (event streams)
|
|
167
|
+
subject_count = self._count_instances_of_type("Subject")
|
|
168
|
+
|
|
169
|
+
stats = {
|
|
170
|
+
"request_count": self.request_count,
|
|
171
|
+
"current_memory_mb": round(current_mb, 2),
|
|
172
|
+
"peak_memory_mb": round(peak_mb, 2),
|
|
173
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
174
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
175
|
+
"active_exporters": exporter_count,
|
|
176
|
+
"isolated_exporters": isolated_exporter_count,
|
|
177
|
+
"subject_instances": subject_count,
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
self._log("=" * 80)
|
|
181
|
+
self._log("MEMORY PROFILE AFTER %d REQUESTS:", self.request_count)
|
|
182
|
+
self._log(" Current Memory: %.2f MB", current_mb)
|
|
183
|
+
self._log(" Peak Memory: %.2f MB", peak_mb)
|
|
184
|
+
self._log("")
|
|
185
|
+
self._log("NAT COMPONENT INSTANCES:")
|
|
186
|
+
self._log(" IntermediateStepManagers: %d active (%d outstanding steps)",
|
|
187
|
+
stats["active_intermediate_managers"],
|
|
188
|
+
stats["outstanding_steps"])
|
|
189
|
+
self._log(" BaseExporters: %d active (%d isolated)", stats["active_exporters"], stats["isolated_exporters"])
|
|
190
|
+
self._log(" Subject (event streams): %d instances", stats["subject_instances"])
|
|
191
|
+
|
|
192
|
+
# Show top allocations
|
|
193
|
+
if snapshot is None:
|
|
194
|
+
self._log("tracemalloc snapshot unavailable.")
|
|
195
|
+
else:
|
|
196
|
+
if self.baseline_snapshot:
|
|
197
|
+
self._log("TOP %d MEMORY GROWTH SINCE BASELINE:", self.top_n)
|
|
198
|
+
top_stats = snapshot.compare_to(self.baseline_snapshot, 'lineno')
|
|
199
|
+
else:
|
|
200
|
+
self._log("TOP %d MEMORY ALLOCATIONS:", self.top_n)
|
|
201
|
+
top_stats = snapshot.statistics('lineno')
|
|
202
|
+
|
|
203
|
+
for i, stat in enumerate(top_stats[:self.top_n], 1):
|
|
204
|
+
self._log(" #%d: %s", i, stat)
|
|
205
|
+
|
|
206
|
+
self._log("=" * 80)
|
|
207
|
+
|
|
208
|
+
return stats
|
|
209
|
+
|
|
210
|
+
def _count_instances_of_type(self, type_name: str) -> int:
|
|
211
|
+
"""Count instances of a specific type in memory."""
|
|
212
|
+
count = 0
|
|
213
|
+
for obj in gc.get_objects():
|
|
214
|
+
try:
|
|
215
|
+
if type(obj).__name__ == type_name:
|
|
216
|
+
count += 1
|
|
217
|
+
except Exception:
|
|
218
|
+
pass
|
|
219
|
+
return count
|
|
220
|
+
|
|
221
|
+
def _safe_exporter_count(self) -> int:
|
|
222
|
+
try:
|
|
223
|
+
from nat.observability.exporter.base_exporter import BaseExporter
|
|
224
|
+
return BaseExporter.get_active_instance_count()
|
|
225
|
+
except Exception as e:
|
|
226
|
+
logger.debug("Could not get BaseExporter stats: %s", e)
|
|
227
|
+
return 0
|
|
228
|
+
|
|
229
|
+
def _safe_isolated_exporter_count(self) -> int:
|
|
230
|
+
try:
|
|
231
|
+
from nat.observability.exporter.base_exporter import BaseExporter
|
|
232
|
+
return BaseExporter.get_isolated_instance_count()
|
|
233
|
+
except Exception:
|
|
234
|
+
return 0
|
|
235
|
+
|
|
236
|
+
def _safe_intermediate_step_manager_count(self) -> int:
|
|
237
|
+
try:
|
|
238
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
239
|
+
# len() is atomic in CPython, but catch RuntimeError just in case
|
|
240
|
+
try:
|
|
241
|
+
return IntermediateStepManager.get_active_instance_count()
|
|
242
|
+
except RuntimeError:
|
|
243
|
+
# Set was modified during len() - very rare
|
|
244
|
+
logger.debug("Set changed during count, returning 0")
|
|
245
|
+
return 0
|
|
246
|
+
except Exception as e:
|
|
247
|
+
logger.debug("Could not get IntermediateStepManager stats: %s", e)
|
|
248
|
+
return 0
|
|
249
|
+
|
|
250
|
+
def _safe_outstanding_step_count(self) -> int:
|
|
251
|
+
"""Get total outstanding steps across all active IntermediateStepManager instances."""
|
|
252
|
+
try:
|
|
253
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
254
|
+
|
|
255
|
+
# Make a snapshot to avoid "Set changed size during iteration" if GC runs
|
|
256
|
+
try:
|
|
257
|
+
instances_snapshot = list(IntermediateStepManager._active_instances)
|
|
258
|
+
except RuntimeError:
|
|
259
|
+
# Set changed during list() call - rare but possible
|
|
260
|
+
logger.debug("Set changed during snapshot, returning 0 for outstanding steps")
|
|
261
|
+
return 0
|
|
262
|
+
|
|
263
|
+
total_outstanding = 0
|
|
264
|
+
# Iterate through snapshot safely
|
|
265
|
+
for ref in instances_snapshot:
|
|
266
|
+
try:
|
|
267
|
+
manager = ref()
|
|
268
|
+
if manager is not None:
|
|
269
|
+
total_outstanding += manager.get_outstanding_step_count()
|
|
270
|
+
except (ReferenceError, AttributeError):
|
|
271
|
+
# Manager was GC'd or in invalid state - skip it
|
|
272
|
+
continue
|
|
273
|
+
return total_outstanding
|
|
274
|
+
except Exception as e:
|
|
275
|
+
logger.debug("Could not get outstanding step count: %s", e)
|
|
276
|
+
return 0
|
|
277
|
+
|
|
278
|
+
def get_stats(self) -> dict[str, Any]:
|
|
279
|
+
"""Get current memory statistics without logging."""
|
|
280
|
+
if not self.enabled:
|
|
281
|
+
return {"enabled": False}
|
|
282
|
+
|
|
283
|
+
mem = self._safe_traced_memory()
|
|
284
|
+
if mem is None:
|
|
285
|
+
return {
|
|
286
|
+
"enabled": True,
|
|
287
|
+
"request_count": self.request_count,
|
|
288
|
+
"current_memory_mb": None,
|
|
289
|
+
"peak_memory_mb": None,
|
|
290
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
291
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
292
|
+
"active_exporters": self._safe_exporter_count(),
|
|
293
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
294
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
current_mb, peak_mb = mem
|
|
298
|
+
return {
|
|
299
|
+
"enabled": True,
|
|
300
|
+
"request_count": self.request_count,
|
|
301
|
+
"current_memory_mb": round(current_mb, 2),
|
|
302
|
+
"peak_memory_mb": round(peak_mb, 2),
|
|
303
|
+
"active_intermediate_managers": self._safe_intermediate_step_manager_count(),
|
|
304
|
+
"outstanding_steps": self._safe_outstanding_step_count(),
|
|
305
|
+
"active_exporters": self._safe_exporter_count(),
|
|
306
|
+
"isolated_exporters": self._safe_isolated_exporter_count(),
|
|
307
|
+
"subject_instances": self._count_instances_of_type("Subject"),
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
def reset_baseline(self) -> None:
|
|
311
|
+
"""Reset the baseline snapshot to current state."""
|
|
312
|
+
if not self.enabled:
|
|
313
|
+
return
|
|
314
|
+
gc.collect()
|
|
315
|
+
snap = self._safe_snapshot()
|
|
316
|
+
if snap is None:
|
|
317
|
+
logger.info("Cannot reset baseline: tracemalloc is not active.")
|
|
318
|
+
return
|
|
319
|
+
self.baseline_snapshot = snap
|
|
320
|
+
logger.info("Memory profiling baseline reset at request %d", self.request_count)
|
|
@@ -28,6 +28,7 @@ from nat.builder.function_base import FunctionBase
|
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
30
|
from nat.builder.workflow import Workflow
|
|
31
|
+
from nat.front_ends.mcp.memory_profiler import MemoryProfiler
|
|
31
32
|
|
|
32
33
|
logger = logging.getLogger(__name__)
|
|
33
34
|
|
|
@@ -38,6 +39,7 @@ def create_function_wrapper(
|
|
|
38
39
|
schema: type[BaseModel],
|
|
39
40
|
is_workflow: bool = False,
|
|
40
41
|
workflow: 'Workflow | None' = None,
|
|
42
|
+
memory_profiler: 'MemoryProfiler | None' = None,
|
|
41
43
|
):
|
|
42
44
|
"""Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
|
|
43
45
|
|
|
@@ -47,6 +49,7 @@ def create_function_wrapper(
|
|
|
47
49
|
schema (type[BaseModel]): The input schema of the function
|
|
48
50
|
is_workflow (bool): Whether the function is a Workflow
|
|
49
51
|
workflow (Workflow | None): The parent workflow for observability context
|
|
52
|
+
memory_profiler: Optional memory profiler to track requests
|
|
50
53
|
|
|
51
54
|
Returns:
|
|
52
55
|
A wrapper function suitable for registration with MCP
|
|
@@ -172,6 +175,10 @@ def create_function_wrapper(
|
|
|
172
175
|
if ctx:
|
|
173
176
|
await ctx.report_progress(100, 100)
|
|
174
177
|
|
|
178
|
+
# Track request completion for memory profiling
|
|
179
|
+
if memory_profiler:
|
|
180
|
+
memory_profiler.on_request_complete()
|
|
181
|
+
|
|
175
182
|
# Handle different result types for proper formatting
|
|
176
183
|
if isinstance(result, str):
|
|
177
184
|
return result
|
|
@@ -181,6 +188,11 @@ def create_function_wrapper(
|
|
|
181
188
|
except Exception as e:
|
|
182
189
|
if ctx:
|
|
183
190
|
ctx.error("Error calling function %s: %s", function_name, str(e))
|
|
191
|
+
|
|
192
|
+
# Track request completion even on error
|
|
193
|
+
if memory_profiler:
|
|
194
|
+
memory_profiler.on_request_complete()
|
|
195
|
+
|
|
184
196
|
raise
|
|
185
197
|
|
|
186
198
|
return wrapper_with_ctx
|
|
@@ -242,7 +254,8 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
242
254
|
def register_function_with_mcp(mcp: FastMCP,
|
|
243
255
|
function_name: str,
|
|
244
256
|
function: FunctionBase,
|
|
245
|
-
workflow: 'Workflow | None' = None
|
|
257
|
+
workflow: 'Workflow | None' = None,
|
|
258
|
+
memory_profiler: 'MemoryProfiler | None' = None) -> None:
|
|
246
259
|
"""Register a NAT Function as an MCP tool.
|
|
247
260
|
|
|
248
261
|
Args:
|
|
@@ -250,6 +263,7 @@ def register_function_with_mcp(mcp: FastMCP,
|
|
|
250
263
|
function_name: The name to register the function under
|
|
251
264
|
function: The NAT Function to register
|
|
252
265
|
workflow: The parent workflow for observability context (if available)
|
|
266
|
+
memory_profiler: Optional memory profiler to track requests
|
|
253
267
|
"""
|
|
254
268
|
logger.info("Registering function %s with MCP", function_name)
|
|
255
269
|
|
|
@@ -267,5 +281,10 @@ def register_function_with_mcp(mcp: FastMCP,
|
|
|
267
281
|
function_description = get_function_description(function)
|
|
268
282
|
|
|
269
283
|
# Create and register the wrapper function with MCP
|
|
270
|
-
wrapper_func = create_function_wrapper(function_name,
|
|
284
|
+
wrapper_func = create_function_wrapper(function_name,
|
|
285
|
+
function,
|
|
286
|
+
input_schema,
|
|
287
|
+
is_workflow,
|
|
288
|
+
workflow,
|
|
289
|
+
memory_profiler)
|
|
271
290
|
mcp.tool(name=function_name, description=function_description)(wrapper_func)
|
nat/runtime/runner.py
CHANGED
|
@@ -196,8 +196,7 @@ class Runner:
|
|
|
196
196
|
|
|
197
197
|
return result
|
|
198
198
|
except Exception as e:
|
|
199
|
-
|
|
200
|
-
logger.error("Error running workflow%s", err_msg)
|
|
199
|
+
logger.error("Error running workflow: %s", e)
|
|
201
200
|
event_stream = self._context_state.event_stream.get()
|
|
202
201
|
if event_stream:
|
|
203
202
|
event_stream.on_complete()
|
nat/utils/type_converter.py
CHANGED
|
@@ -93,6 +93,14 @@ class TypeConverter:
|
|
|
93
93
|
if to_type is None or decomposed.is_instance(data):
|
|
94
94
|
return data
|
|
95
95
|
|
|
96
|
+
# 2) If data is a union type, try to convert to each type in the union
|
|
97
|
+
if decomposed.is_union:
|
|
98
|
+
for union_type in decomposed.args:
|
|
99
|
+
result = self._convert(data, union_type)
|
|
100
|
+
if result is not None:
|
|
101
|
+
return result
|
|
102
|
+
return None
|
|
103
|
+
|
|
96
104
|
root = decomposed.root
|
|
97
105
|
|
|
98
106
|
# 2) Attempt direct in *this* converter
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.0a20251011
|
|
4
4
|
Summary: NVIDIA NeMo Agent toolkit
|
|
5
5
|
Author: NVIDIA Corporation
|
|
6
6
|
Maintainer: NVIDIA Corporation
|
|
7
|
-
License
|
|
7
|
+
License: Apache-2.0
|
|
8
8
|
Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
|
|
9
9
|
Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
|
|
10
10
|
Keywords: ai,rag,agents
|
|
@@ -14,8 +14,8 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.13
|
|
15
15
|
Requires-Python: <3.14,>=3.11
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
|
-
License-File: LICENSE.md
|
|
18
17
|
License-File: LICENSE-3rd-party.txt
|
|
18
|
+
License-File: LICENSE.md
|
|
19
19
|
Requires-Dist: aioboto3>=11.0.0
|
|
20
20
|
Requires-Dist: authlib~=1.5
|
|
21
21
|
Requires-Dist: click~=8.1
|
|
@@ -10,16 +10,16 @@ nat/agent/react_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
|
|
|
10
10
|
nat/agent/react_agent/agent.py,sha256=sWrg9WrglTKQQyG3EcjNm2JTEchCPEo9li-Po7TJKss,21294
|
|
11
11
|
nat/agent/react_agent/output_parser.py,sha256=m7K6wRwtckBBpAHqOf3BZ9mqZLwrP13Kxz5fvNxbyZE,4219
|
|
12
12
|
nat/agent/react_agent/prompt.py,sha256=N47JJrT6xwYQCv1jedHhlul2AE7EfKsSYfAbgJwWRew,1758
|
|
13
|
-
nat/agent/react_agent/register.py,sha256=
|
|
13
|
+
nat/agent/react_agent/register.py,sha256=qkPaK6AvXjolL-q_Z3waVobXDz24GMfuqGqCn-2un2Q,8991
|
|
14
14
|
nat/agent/reasoning_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
15
|
nat/agent/reasoning_agent/reasoning_agent.py,sha256=k_0wEDqACQn1Rn1MAKxoXyqOKsthHCQ1gt990YYUqHU,9575
|
|
16
16
|
nat/agent/rewoo_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
17
|
nat/agent/rewoo_agent/agent.py,sha256=XXgVXY9xwkyxnr093KXUtfgyNxAQbyGAecoGqN5mMLY,26199
|
|
18
18
|
nat/agent/rewoo_agent/prompt.py,sha256=B0JeL1xDX4VKcShlkkviEcAsOKAwzSlX8NcAQdmUUPw,3645
|
|
19
|
-
nat/agent/rewoo_agent/register.py,sha256=
|
|
19
|
+
nat/agent/rewoo_agent/register.py,sha256=XArlOR37QOBtAvsdKJUjRok5qTmx39S2mJHSteOwU58,9283
|
|
20
20
|
nat/agent/tool_calling_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
21
21
|
nat/agent/tool_calling_agent/agent.py,sha256=4SIp29I56oznPRQu7B3HCoX53Ri3_o3BRRYNJjeBkF8,11006
|
|
22
|
-
nat/agent/tool_calling_agent/register.py,sha256=
|
|
22
|
+
nat/agent/tool_calling_agent/register.py,sha256=OucceyELA2xZL3KdANWK9w12fnVP75eVbZgzOnmXHys,7057
|
|
23
23
|
nat/authentication/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
24
24
|
nat/authentication/interfaces.py,sha256=1J2CWEJ_n6CLA3_HD3XV28CSbyfxrPAHzr7Q4kKDFdc,3511
|
|
25
25
|
nat/authentication/register.py,sha256=lFhswYUk9iZ53mq33fClR9UfjJPdjGIivGGNHQeWiYo,915
|
|
@@ -48,10 +48,10 @@ nat/builder/eval_builder.py,sha256=I-ScvupmorClYoVBIs_PhSsB7Xf9e2nGWe0rCZp3txo,6
|
|
|
48
48
|
nat/builder/evaluator.py,sha256=xWHMND2vcAUkdFP7FU3jnVki1rUHeTa0-9saFh2hWKs,1162
|
|
49
49
|
nat/builder/framework_enum.py,sha256=n7IaTQBxhFozIQqRMcX9kXntw28JhFzCj82jJ0C5tNU,901
|
|
50
50
|
nat/builder/front_end.py,sha256=FCJ87NSshVVuTg8zZrq3YAr_u0RaYVZVcibnqlRFy-M,2173
|
|
51
|
-
nat/builder/function.py,sha256=
|
|
51
|
+
nat/builder/function.py,sha256=eZZWLwhphgQTnPvbga8sGleX7HCP46usZPIegE7zFzs,27725
|
|
52
52
|
nat/builder/function_base.py,sha256=0Eg8RtjWhEU3Yme0CVxcRutobA0Qo8-YHZLI6L2qAgM,13116
|
|
53
53
|
nat/builder/function_info.py,sha256=7Rmrn-gOFrT2TIJklJwA_O-ycx_oimwZ0-qMYpbuZrU,25161
|
|
54
|
-
nat/builder/intermediate_step_manager.py,sha256=
|
|
54
|
+
nat/builder/intermediate_step_manager.py,sha256=E4syoUNn0BGHnNqqmTYn2oMXKSHkf8GCmTpVeJX3zTY,8764
|
|
55
55
|
nat/builder/llm.py,sha256=DW-2q64A06VChsXNEL5PfBjH3DcsnTKVoCEWDuP7MF4,951
|
|
56
56
|
nat/builder/retriever.py,sha256=ZyEqc7pFK31t_yr6Jaxa34c-tRas2edKqJZCNiVh9-0,970
|
|
57
57
|
nat/builder/user_interaction_manager.py,sha256=-Z2qbQes7a2cuXgT7KEbWeuok0HcCnRdw9WB8Ghyl9k,3081
|
|
@@ -112,7 +112,7 @@ nat/control_flow/router_agent/prompt.py,sha256=fIAiNsAs1zXRAatButR76zSpHJNxSkXXK
|
|
|
112
112
|
nat/control_flow/router_agent/register.py,sha256=4RGmS9sy-QtIMmvh8mfMcR1VqxFPLpG4RckWCIExh40,4144
|
|
113
113
|
nat/data_models/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
114
114
|
nat/data_models/agent.py,sha256=IwDyb9Zc3R4Zd5rFeqt7q0EQswczAl5focxV9KozIzs,1625
|
|
115
|
-
nat/data_models/api_server.py,sha256=
|
|
115
|
+
nat/data_models/api_server.py,sha256=NWT1ChN2qaakD2DgyYCy_7MhfzvEBQX15qnUXnpCQmk,28883
|
|
116
116
|
nat/data_models/authentication.py,sha256=XPu9W8nh4XRSuxPv3HxO-FMQ_JtTEoK6Y02JwnzDwTg,8457
|
|
117
117
|
nat/data_models/common.py,sha256=nXXfGrjpxebzBUa55mLdmzePLt7VFHvTAc6Znj3yEv0,5875
|
|
118
118
|
nat/data_models/component.py,sha256=b_hXOA8Gm5UNvlFkAhsR6kEvf33ST50MKtr5kWf75Ao,1894
|
|
@@ -259,11 +259,12 @@ nat/front_ends/fastapi/html_snippets/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv
|
|
|
259
259
|
nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py,sha256=BNpWwzmA58UM0GK4kZXG4PHJy_5K9ihaVHu8SgCs5JA,1131
|
|
260
260
|
nat/front_ends/mcp/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
261
261
|
nat/front_ends/mcp/introspection_token_verifier.py,sha256=s7Q4Q6rWZJ0ZVujSxxpvVI6Bnhkg1LJQ3RLkvhzFIGE,2836
|
|
262
|
-
nat/front_ends/mcp/mcp_front_end_config.py,sha256=
|
|
262
|
+
nat/front_ends/mcp/mcp_front_end_config.py,sha256=aDgNAyzl_09GfMCWKRGPV8_-16PAov5N40UMMD4yg8c,3143
|
|
263
263
|
nat/front_ends/mcp/mcp_front_end_plugin.py,sha256=NiIIgApk1X2yAEwtG9tHaY6SexQMbZrd6Drs7uIJix8,5055
|
|
264
|
-
nat/front_ends/mcp/mcp_front_end_plugin_worker.py,sha256=
|
|
264
|
+
nat/front_ends/mcp/mcp_front_end_plugin_worker.py,sha256=NUYu2FFrHat_U4VBAm3c9YMpWSyVr2pKUYyseSFw2pM,11208
|
|
265
|
+
nat/front_ends/mcp/memory_profiler.py,sha256=OpcpLBAGCdQwYSFZbtAqdfncrnGYVjDcMpWydB71hjY,12811
|
|
265
266
|
nat/front_ends/mcp/register.py,sha256=3aJtgG5VaiqujoeU1-Eq7Hl5pWslIlIwGFU2ASLTXgM,1173
|
|
266
|
-
nat/front_ends/mcp/tool_converter.py,sha256=
|
|
267
|
+
nat/front_ends/mcp/tool_converter.py,sha256=14NweQN3cPFBw7ZNiGyUHO4VhMGHrtfLGgvu4_H38oU,12426
|
|
267
268
|
nat/front_ends/simple_base/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
268
269
|
nat/front_ends/simple_base/simple_front_end_plugin_base.py,sha256=py_yA9XAw-yHfK5cQJLM8ElnubEEM2ac8M0bvz-ScWs,1801
|
|
269
270
|
nat/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -406,7 +407,7 @@ nat/retriever/nemo_retriever/register.py,sha256=3XdrvEJzX2Zc8wpdm__4YYlEWBW-FK3t
|
|
|
406
407
|
nat/retriever/nemo_retriever/retriever.py,sha256=gi3_qJFqE-iqRh3of_cmJg-SwzaQ3z24zA9LwY_MSLY,6930
|
|
407
408
|
nat/runtime/__init__.py,sha256=Xs1JQ16L9btwreh4pdGKwskffAw1YFO48jKrU4ib_7c,685
|
|
408
409
|
nat/runtime/loader.py,sha256=obUdAgZVYCPGC0R8u3wcoKFJzzSPQgJvrbU4OWygtog,7953
|
|
409
|
-
nat/runtime/runner.py,sha256=
|
|
410
|
+
nat/runtime/runner.py,sha256=qa_AqtmB8TUHX6nVJ0TLEYCKUsm2L99kq5O72AuL3yc,11736
|
|
410
411
|
nat/runtime/session.py,sha256=E8RTbnAhPbY5KCoSfiHzOJksmBh7xWjsoX0BC7Rn1ck,9101
|
|
411
412
|
nat/runtime/user_metadata.py,sha256=ce37NRYJWnMOWk6A7VAQ1GQztjMmkhMOq-uYf2gNCwo,3692
|
|
412
413
|
nat/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -448,7 +449,7 @@ nat/utils/metadata_utils.py,sha256=BSsiB6jIWd8oEuEynJi55qCG762UuTYFaiUH0OT9HdY,2
|
|
|
448
449
|
nat/utils/optional_imports.py,sha256=jQSVBc2fBSRw-2d6r8cEwvh5-di2EUUPakuuo9QbbwA,4039
|
|
449
450
|
nat/utils/producer_consumer_queue.py,sha256=AcSYkAMBxLx06A5Xdy960PP3AJ7YaSPGJ7rbN_hJsjI,6599
|
|
450
451
|
nat/utils/string_utils.py,sha256=71HuIzGx7rF8ocTmeoUBpnCi1Qf1yynYlNLLIKP4BVs,1415
|
|
451
|
-
nat/utils/type_converter.py,sha256
|
|
452
|
+
nat/utils/type_converter.py,sha256=vDZzrZ9ycWgZJdkWB1sHB2ivZX-E8fPfkrB-vAAxroI,10968
|
|
452
453
|
nat/utils/type_utils.py,sha256=SMo5hM4dKf2G3U_0J0wvdFX6-lzMVSh8vd-W34Oixow,14836
|
|
453
454
|
nat/utils/url_utils.py,sha256=UzDP_xaS6brWTu7vAws0B4jZyrITIK9Si3U6pZBZqDE,1028
|
|
454
455
|
nat/utils/data_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -470,10 +471,10 @@ nat/utils/reactive/base/observer_base.py,sha256=6BiQfx26EMumotJ3KoVcdmFBYR_fnAss
|
|
|
470
471
|
nat/utils/reactive/base/subject_base.py,sha256=UQOxlkZTIeeyYmG5qLtDpNf_63Y7p-doEeUA08_R8ME,2521
|
|
471
472
|
nat/utils/settings/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
472
473
|
nat/utils/settings/global_settings.py,sha256=9JaO6pxKT_Pjw6rxJRsRlFCXdVKCl_xUKU2QHZQWWNM,7294
|
|
473
|
-
nvidia_nat-1.4.
|
|
474
|
-
nvidia_nat-1.4.
|
|
475
|
-
nvidia_nat-1.4.
|
|
476
|
-
nvidia_nat-1.4.
|
|
477
|
-
nvidia_nat-1.4.
|
|
478
|
-
nvidia_nat-1.4.
|
|
479
|
-
nvidia_nat-1.4.
|
|
474
|
+
nvidia_nat-1.4.0a20251011.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
|
|
475
|
+
nvidia_nat-1.4.0a20251011.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
476
|
+
nvidia_nat-1.4.0a20251011.dist-info/METADATA,sha256=kBMZinQbnPKYYzxF1s7BvCytUn0HncnbGG6yKrtCIqo,10228
|
|
477
|
+
nvidia_nat-1.4.0a20251011.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
478
|
+
nvidia_nat-1.4.0a20251011.dist-info/entry_points.txt,sha256=4jCqjyETMpyoWbCBf4GalZU8I_wbstpzwQNezdAVbbo,698
|
|
479
|
+
nvidia_nat-1.4.0a20251011.dist-info/top_level.txt,sha256=lgJWLkigiVZuZ_O1nxVnD_ziYBwgpE2OStdaCduMEGc,8
|
|
480
|
+
nvidia_nat-1.4.0a20251011.dist-info/RECORD,,
|
|
File without changes
|
{nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
|
File without changes
|
{nvidia_nat-1.4.0a20251010.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md
RENAMED
|
File without changes
|
|
File without changes
|