nvidia-nat 1.4.0a20251008__py3-none-any.whl → 1.4.0a20251011__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. nat/agent/react_agent/register.py +15 -24
  2. nat/agent/rewoo_agent/register.py +15 -24
  3. nat/agent/tool_calling_agent/register.py +9 -5
  4. nat/builder/component_utils.py +1 -1
  5. nat/builder/function.py +4 -4
  6. nat/builder/intermediate_step_manager.py +32 -0
  7. nat/builder/workflow_builder.py +46 -3
  8. nat/cli/entrypoint.py +9 -1
  9. nat/data_models/api_server.py +78 -9
  10. nat/data_models/config.py +1 -1
  11. nat/front_ends/console/console_front_end_plugin.py +11 -2
  12. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  13. nat/front_ends/mcp/mcp_front_end_config.py +13 -0
  14. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +18 -1
  15. nat/front_ends/mcp/memory_profiler.py +320 -0
  16. nat/front_ends/mcp/tool_converter.py +21 -2
  17. nat/observability/register.py +16 -0
  18. nat/runtime/runner.py +1 -2
  19. nat/runtime/session.py +1 -1
  20. nat/tool/memory_tools/add_memory_tool.py +3 -3
  21. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  22. nat/tool/memory_tools/get_memory_tool.py +3 -3
  23. nat/utils/type_converter.py +8 -0
  24. nvidia_nat-1.4.0a20251011.dist-info/METADATA +195 -0
  25. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/RECORD +30 -29
  26. nvidia_nat-1.4.0a20251008.dist-info/METADATA +0 -389
  27. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/WHEEL +0 -0
  28. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/entry_points.txt +0 -0
  29. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  30. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/licenses/LICENSE.md +0 -0
  31. {nvidia_nat-1.4.0a20251008.dist-info → nvidia_nat-1.4.0a20251011.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/register.py CHANGED
@@ -24,6 +24,7 @@ from nat.builder.function_info import FunctionInfo
24
24
  from nat.cli.register_workflow import register_function
25
25
  from nat.data_models.agent import AgentBaseConfig
26
26
  from nat.data_models.api_server import ChatRequest
27
+ from nat.data_models.api_server import ChatRequestOrMessage
27
28
  from nat.data_models.api_server import ChatResponse
28
29
  from nat.data_models.api_server import Usage
29
30
  from nat.data_models.component_ref import FunctionGroupRef
@@ -70,9 +71,6 @@ class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_ag
70
71
  default=None,
71
72
  description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
72
73
  max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
73
- use_openai_api: bool = Field(default=False,
74
- description=("Use OpenAI API for the input/output types to the function. "
75
- "If False, strings will be used."))
76
74
  additional_instructions: str | None = OptimizableField(
77
75
  default=None,
78
76
  description="Additional instructions to provide to the agent in addition to the base prompt.",
@@ -118,21 +116,23 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
118
116
  pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent,
119
117
  normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
120
118
 
121
- async def _response_fn(input_message: ChatRequest) -> ChatResponse:
119
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
122
120
  """
123
121
  Main workflow entry function for the ReAct Agent.
124
122
 
125
123
  This function invokes the ReAct Agent Graph and returns the response.
126
124
 
127
125
  Args:
128
- input_message (ChatRequest): The input message to process
126
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
129
127
 
130
128
  Returns:
131
- ChatResponse: The response from the agent or error message
129
+ ChatResponse | str: The response from the agent or error message
132
130
  """
133
131
  try:
132
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
133
+
134
134
  # initialize the starting state with the user query
135
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
135
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
136
136
  max_tokens=config.max_history,
137
137
  strategy="last",
138
138
  token_counter=len,
@@ -153,25 +153,16 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
153
153
  content = str(output_message.content)
154
154
 
155
155
  # Create usage statistics for the response
156
- prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
156
+ prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
157
157
  completion_tokens = len(content.split()) if content else 0
158
158
  total_tokens = prompt_tokens + completion_tokens
159
159
  usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
160
- return ChatResponse.from_string(content, usage=usage)
161
-
160
+ response = ChatResponse.from_string(content, usage=usage)
161
+ if chat_request_or_message.is_string:
162
+ return GlobalTypeConverter.get().convert(response, to_type=str)
163
+ return response
162
164
  except Exception as ex:
163
- logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
164
- raise RuntimeError
165
-
166
- if (config.use_openai_api):
167
- yield FunctionInfo.from_fn(_response_fn, description=config.description)
168
- else:
169
-
170
- async def _str_api_fn(input_message: str) -> str:
171
- oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
172
-
173
- oai_output = await _response_fn(oai_input)
174
-
175
- return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
165
+ logger.error("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
166
+ raise
176
167
 
177
- yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
168
+ yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/rewoo_agent/register.py CHANGED
@@ -25,6 +25,7 @@ from nat.builder.function_info import FunctionInfo
25
25
  from nat.cli.register_workflow import register_function
26
26
  from nat.data_models.agent import AgentBaseConfig
27
27
  from nat.data_models.api_server import ChatRequest
28
+ from nat.data_models.api_server import ChatRequestOrMessage
28
29
  from nat.data_models.api_server import ChatResponse
29
30
  from nat.data_models.api_server import Usage
30
31
  from nat.data_models.component_ref import FunctionGroupRef
@@ -54,9 +55,6 @@ class ReWOOAgentWorkflowConfig(AgentBaseConfig, name="rewoo_agent"):
54
55
  description="The number of retries before raising a tool call error.",
55
56
  ge=1)
56
57
  max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
57
- use_openai_api: bool = Field(default=False,
58
- description=("Use OpenAI API for the input/output types to the function. "
59
- "If False, strings will be used."))
60
58
  additional_planner_instructions: str | None = Field(
61
59
  default=None,
62
60
  validation_alias=AliasChoices("additional_planner_instructions", "additional_instructions"),
@@ -125,21 +123,23 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
125
123
  tool_call_max_retries=config.tool_call_max_retries,
126
124
  raise_tool_call_error=config.raise_tool_call_error).build_graph()
127
125
 
128
- async def _response_fn(input_message: ChatRequest) -> ChatResponse:
126
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
129
127
  """
130
128
  Main workflow entry function for the ReWOO Agent.
131
129
 
132
130
  This function invokes the ReWOO Agent Graph and returns the response.
133
131
 
134
132
  Args:
135
- input_message (ChatRequest): The input message to process
133
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
136
134
 
137
135
  Returns:
138
- ChatResponse: The response from the agent or error message
136
+ ChatResponse | str: The response from the agent or error message
139
137
  """
140
138
  try:
139
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
140
+
141
141
  # initialize the starting state with the user query
142
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
142
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
143
143
  max_tokens=config.max_history,
144
144
  strategy="last",
145
145
  token_counter=len,
@@ -160,25 +160,16 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
160
160
  output_message = str(output_message)
161
161
 
162
162
  # Create usage statistics for the response
163
- prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
163
+ prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
164
164
  completion_tokens = len(output_message.split()) if output_message else 0
165
165
  total_tokens = prompt_tokens + completion_tokens
166
166
  usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
167
- return ChatResponse.from_string(output_message, usage=usage)
168
-
167
+ response = ChatResponse.from_string(output_message, usage=usage)
168
+ if chat_request_or_message.is_string:
169
+ return GlobalTypeConverter.get().convert(response, to_type=str)
170
+ return response
169
171
  except Exception as ex:
170
- logger.exception("ReWOO Agent failed with exception: %s", ex)
171
- raise RuntimeError
172
-
173
- if (config.use_openai_api):
174
- yield FunctionInfo.from_fn(_response_fn, description=config.description)
175
-
176
- else:
177
-
178
- async def _str_api_fn(input_message: str) -> str:
179
- oai_input = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
180
- oai_output = await _response_fn(oai_input)
181
-
182
- return GlobalTypeConverter.get().try_convert(oai_output, to_type=str)
172
+ logger.error("ReWOO Agent failed with exception: %s", ex)
173
+ raise
183
174
 
184
- yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
175
+ yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/agent/tool_calling_agent/register.py CHANGED
@@ -23,8 +23,10 @@ from nat.builder.function_info import FunctionInfo
23
23
  from nat.cli.register_workflow import register_function
24
24
  from nat.data_models.agent import AgentBaseConfig
25
25
  from nat.data_models.api_server import ChatRequest
26
+ from nat.data_models.api_server import ChatRequestOrMessage
26
27
  from nat.data_models.component_ref import FunctionGroupRef
27
28
  from nat.data_models.component_ref import FunctionRef
29
+ from nat.utils.type_converter import GlobalTypeConverter
28
30
 
29
31
  logger = logging.getLogger(__name__)
30
32
 
@@ -81,21 +83,23 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
81
83
  handle_tool_errors=config.handle_tool_errors,
82
84
  return_direct=return_direct_tools).build_graph()
83
85
 
84
- async def _response_fn(input_message: ChatRequest) -> str:
86
+ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> str:
85
87
  """
86
88
  Main workflow entry function for the Tool Calling Agent.
87
89
 
88
90
  This function invokes the Tool Calling Agent Graph and returns the response.
89
91
 
90
92
  Args:
91
- input_message (ChatRequest): The input message to process
93
+ chat_request_or_message (ChatRequestOrMessage): The input message to process
92
94
 
93
95
  Returns:
94
96
  str: The response from the agent or error message
95
97
  """
96
98
  try:
99
+ message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
100
+
97
101
  # initialize the starting state with the user query
98
- messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
102
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in message.messages],
99
103
  max_tokens=config.max_history,
100
104
  strategy="last",
101
105
  token_counter=len,
@@ -114,8 +118,8 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
114
118
  output_message = state.messages[-1]
115
119
  return str(output_message.content)
116
120
  except Exception as ex:
117
- logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
118
- raise RuntimeError
121
+ logger.error("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
122
+ raise
119
123
 
120
124
  try:
121
125
  yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/builder/component_utils.py CHANGED
@@ -153,7 +153,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
153
153
  for v in value.values():
154
154
  yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
155
155
  elif (issubclass(type(value), BaseModel)):
156
- for field, field_info in value.model_fields.items():
156
+ for field, field_info in type(value).model_fields.items():
157
157
  field_data = getattr(value, field)
158
158
  yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
159
159
  if (decomposed_type.is_union):
nat/builder/function.py CHANGED
@@ -159,8 +159,7 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
159
159
 
160
160
  return result
161
161
  except Exception as e:
162
- err_msg = f"Error: {e}" if str(e).strip() else ""
163
- logger.error("Error with ainvoke in function with input: %s. %s", value, err_msg)
162
+ logger.error("Error with ainvoke in function with input: %s. Error: %s", value, e)
164
163
  raise
165
164
 
166
165
  @typing.final
@@ -416,8 +415,9 @@ class FunctionGroup:
416
415
  """
417
416
  if not name.strip():
418
417
  raise ValueError("Function name cannot be empty or blank")
419
- if not re.match(r"^[a-zA-Z0-9_-]+$", name):
420
- raise ValueError(f"Function name can only contain letters, numbers, underscores, and hyphens: {name}")
418
+ if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
419
+ raise ValueError(
420
+ f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
421
421
  if name in self._functions:
422
422
  raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
423
423
 
nat/builder/intermediate_step_manager.py CHANGED
@@ -16,6 +16,8 @@
16
16
  import dataclasses
17
17
  import logging
18
18
  import typing
19
+ import weakref
20
+ from typing import ClassVar
19
21
 
20
22
  from nat.data_models.intermediate_step import IntermediateStep
21
23
  from nat.data_models.intermediate_step import IntermediateStepPayload
@@ -46,11 +48,19 @@ class IntermediateStepManager:
46
48
  Manages updates to the NAT Event Stream for intermediate steps
47
49
  """
48
50
 
51
+ # Class-level tracking for debugging and monitoring
52
+ _instance_count: ClassVar[int] = 0
53
+ _active_instances: ClassVar[set[weakref.ref]] = set()
54
+
49
55
  def __init__(self, context_state: "ContextState"): # noqa: F821
50
56
  self._context_state = context_state
51
57
 
52
58
  self._outstanding_start_steps: dict[str, OpenStep] = {}
53
59
 
60
+ # Track instance creation
61
+ IntermediateStepManager._instance_count += 1
62
+ IntermediateStepManager._active_instances.add(weakref.ref(self, self._cleanup_instance_tracking))
63
+
54
64
  def push_intermediate_step(self, payload: IntermediateStepPayload) -> None:
55
65
  """
56
66
  Pushes an intermediate step to the NAT Event Stream
@@ -172,3 +182,25 @@ class IntermediateStepManager:
172
182
  """
173
183
 
174
184
  return self._context_state.event_stream.get().subscribe(on_next, on_error, on_complete)
185
+
186
+ @classmethod
187
+ def _cleanup_instance_tracking(cls, ref: weakref.ref) -> None:
188
+ """Cleanup callback for weakref when instance is garbage collected."""
189
+ cls._active_instances.discard(ref)
190
+
191
+ @classmethod
192
+ def get_active_instance_count(cls) -> int:
193
+ """Get the number of active IntermediateStepManager instances.
194
+
195
+ Returns:
196
+ int: Number of active instances (cleaned up automatically via weakref)
197
+ """
198
+ return len(cls._active_instances)
199
+
200
+ def get_outstanding_step_count(self) -> int:
201
+ """Get the number of outstanding (started but not ended) steps.
202
+
203
+ Returns:
204
+ int: Number of steps that have been started but not yet ended
205
+ """
206
+ return len(self._outstanding_start_steps)
nat/builder/workflow_builder.py CHANGED
@@ -156,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
156
156
  self._registry = registry
157
157
 
158
158
  self._logging_handlers: dict[str, logging.Handler] = {}
159
+ self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
159
160
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
160
161
 
161
162
  self._functions: dict[str, ConfiguredFunction] = {}
@@ -187,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
187
188
  # Get the telemetry info from the config
188
189
  telemetry_config = self.general_config.telemetry
189
190
 
191
+ # If we have logging configuration, we need to manage the root logger properly
192
+ root_logger = logging.getLogger()
193
+
194
+ # Collect configured handler types to determine if we need to adjust existing handlers
195
+ # This is somewhat of a hack by inspecting the class name of the config object
196
+ has_console_handler = any(
197
+ hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
198
+ for config in telemetry_config.logging.values())
199
+
190
200
  for key, logging_config in telemetry_config.logging.items():
191
201
  # Use the same pattern as tracing, but for logging
192
202
  logging_info = self._registry.get_logging_method(type(logging_config))
@@ -200,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
200
210
  self._logging_handlers[key] = handler
201
211
 
202
212
  # Now attach to NAT's root logger
203
- logging.getLogger().addHandler(handler)
213
+ root_logger.addHandler(handler)
214
+
215
+ # If we added logging handlers, manage existing handlers appropriately
216
+ if self._logging_handlers:
217
+ min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
218
+
219
+ # Ensure the root logger level allows messages through
220
+ root_logger.level = max(root_logger.level, min_handler_level)
221
+
222
+ # If a console handler is configured, adjust or remove default CLI handlers
223
+ # to avoid duplicate output while preserving workflow visibility
224
+ if has_console_handler:
225
+ # Remove existing StreamHandlers that are not the newly configured ones
226
+ for handler in root_logger.handlers[:]:
227
+ if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
228
+ self._removed_root_handlers.append((handler, handler.level))
229
+ root_logger.removeHandler(handler)
230
+ else:
231
+ # No console handler configured, but adjust existing handler levels
232
+ # to respect the minimum configured level for file/other handlers
233
+ for handler in root_logger.handlers[:]:
234
+ if type(handler) is logging.StreamHandler:
235
+ old_level = handler.level
236
+ handler.setLevel(min_handler_level)
237
+ self._removed_root_handlers.append((handler, old_level))
204
238
 
205
239
  # Add the telemetry exporters
206
240
  for key, telemetry_exporter_config in telemetry_config.tracing.items():
@@ -212,8 +246,17 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
212
246
 
213
247
  assert self._exit_stack is not None, "Exit stack not initialized"
214
248
 
215
- for _, handler in self._logging_handlers.items():
216
- logging.getLogger().removeHandler(handler)
249
+ root_logger = logging.getLogger()
250
+
251
+ # Remove custom logging handlers
252
+ for handler in self._logging_handlers.values():
253
+ root_logger.removeHandler(handler)
254
+
255
+ # Restore original handlers and their levels
256
+ for handler, old_level in self._removed_root_handlers:
257
+ if handler not in root_logger.handlers:
258
+ root_logger.addHandler(handler)
259
+ handler.setLevel(old_level)
217
260
 
218
261
  await self._exit_stack.__aexit__(*exc_details)
219
262
 
nat/cli/entrypoint.py CHANGED
@@ -29,6 +29,7 @@ import time
29
29
 
30
30
  import click
31
31
  import nest_asyncio
32
+ from dotenv import load_dotenv
32
33
 
33
34
  from nat.utils.log_levels import LOG_LEVELS
34
35
 
@@ -45,6 +46,9 @@ from .commands.uninstall import uninstall_command
45
46
  from .commands.validate import validate_command
46
47
  from .commands.workflow.workflow import workflow_command
47
48
 
49
+ # Load environment variables from .env file, if it exists
50
+ load_dotenv()
51
+
48
52
  # Apply at the beginning of the file to avoid issues with asyncio
49
53
  nest_asyncio.apply()
50
54
 
@@ -52,7 +56,11 @@ nest_asyncio.apply()
52
56
  def setup_logging(log_level: str):
53
57
  """Configure logging with the specified level"""
54
58
  numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
55
- logging.basicConfig(level=numeric_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
59
+ logging.basicConfig(
60
+ level=numeric_level,
61
+ format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
62
+ datefmt="%Y-%m-%d %H:%M:%S",
63
+ )
56
64
  return numeric_level
57
65
 
58
66
 
nat/data_models/api_server.py CHANGED
@@ -28,6 +28,7 @@ from pydantic import HttpUrl
28
28
  from pydantic import conlist
29
29
  from pydantic import field_serializer
30
30
  from pydantic import field_validator
31
+ from pydantic import model_validator
31
32
  from pydantic_core.core_schema import ValidationInfo
32
33
 
33
34
  from nat.data_models.interactive import HumanPrompt
@@ -120,15 +121,7 @@ class Message(BaseModel):
120
121
  role: UserMessageContentRoleType
121
122
 
122
123
 
123
- class ChatRequest(BaseModel):
124
- """
125
- ChatRequest is a data model that represents a request to the NAT chat API.
126
- Fully compatible with OpenAI Chat Completions API specification.
127
- """
128
-
129
- # Required fields
130
- messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
131
-
124
+ class ChatRequestOptionals(BaseModel):
132
125
  # Optional fields (OpenAI Chat Completions API compatible)
133
126
  model: str | None = Field(default=None, description="name of the model to use")
134
127
  frequency_penalty: float | None = Field(default=0.0,
@@ -153,6 +146,16 @@ class ChatRequest(BaseModel):
153
146
  parallel_tool_calls: bool | None = Field(default=True, description="Whether to enable parallel function calling")
154
147
  user: str | None = Field(default=None, description="Unique identifier representing end-user")
155
148
 
149
+
150
+ class ChatRequest(ChatRequestOptionals):
151
+ """
152
+ ChatRequest is a data model that represents a request to the NAT chat API.
153
+ Fully compatible with OpenAI Chat Completions API specification.
154
+ """
155
+
156
+ # Required fields
157
+ messages: typing.Annotated[list[Message], conlist(Message, min_length=1)]
158
+
156
159
  model_config = ConfigDict(extra="allow",
157
160
  json_schema_extra={
158
161
  "example": {
@@ -194,6 +197,42 @@ class ChatRequest(BaseModel):
194
197
  top_p=top_p)
195
198
 
196
199
 
200
+ class ChatRequestOrMessage(ChatRequestOptionals):
201
+ """
202
+ ChatRequestOrMessage is a data model that represents either a conversation or a string input.
203
+ This is useful for functions that can handle either type of input.
204
+
205
+ `messages` is compatible with the OpenAI Chat Completions API specification.
206
+
207
+ `input_string` is a string input that can be used for functions that do not require a conversation.
208
+ """
209
+
210
+ messages: typing.Annotated[list[Message] | None, conlist(Message, min_length=1)] = Field(
211
+ default=None, description="The conversation messages to process.")
212
+
213
+ input_string: str | None = Field(default=None, alias="input_message", description="The input message to process.")
214
+
215
+ @property
216
+ def is_string(self) -> bool:
217
+ return self.input_string is not None
218
+
219
+ @property
220
+ def is_conversation(self) -> bool:
221
+ return self.messages is not None
222
+
223
+ @model_validator(mode="after")
224
+ def validate_messages_or_input_string(self):
225
+ if self.messages is not None and self.input_string is not None:
226
+ raise ValueError("Either messages or input_message/input_string must be provided, not both")
227
+ if self.messages is None and self.input_string is None:
228
+ raise ValueError("Either messages or input_message/input_string must be provided")
229
+ if self.input_string is not None:
230
+ extra_fields = self.model_dump(exclude={"input_string"}, exclude_none=True, exclude_unset=True)
231
+ if len(extra_fields) > 0:
232
+ raise ValueError("no extra fields are permitted when input_message/input_string is provided")
233
+ return self
234
+
235
+
197
236
  class ChoiceMessage(BaseModel):
198
237
  content: str | None = None
199
238
  role: UserMessageContentRoleType | None = None
@@ -661,6 +700,36 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
661
700
  GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
662
701
 
663
702
 
703
+ def _chat_request_or_message_to_chat_request(data: ChatRequestOrMessage) -> ChatRequest:
704
+ if data.input_string is not None:
705
+ return _string_to_nat_chat_request(data.input_string)
706
+ return ChatRequest(**data.model_dump(exclude={"input_string"}))
707
+
708
+
709
+ GlobalTypeConverter.register_converter(_chat_request_or_message_to_chat_request)
710
+
711
+
712
+ def _chat_request_to_chat_request_or_message(data: ChatRequest) -> ChatRequestOrMessage:
713
+ return ChatRequestOrMessage(**data.model_dump(by_alias=True))
714
+
715
+
716
+ GlobalTypeConverter.register_converter(_chat_request_to_chat_request_or_message)
717
+
718
+
719
+ def _chat_request_or_message_to_string(data: ChatRequestOrMessage) -> str:
720
+ return data.input_string or ""
721
+
722
+
723
+ GlobalTypeConverter.register_converter(_chat_request_or_message_to_string)
724
+
725
+
726
+ def _string_to_chat_request_or_message(data: str) -> ChatRequestOrMessage:
727
+ return ChatRequestOrMessage(input_message=data)
728
+
729
+
730
+ GlobalTypeConverter.register_converter(_string_to_chat_request_or_message)
731
+
732
+
664
733
  # ======== ChatResponse Converters ========
665
734
  def _nat_chat_response_to_string(data: ChatResponse) -> str:
666
735
  if data.choices and data.choices[0].message:
nat/data_models/config.py CHANGED
@@ -187,7 +187,7 @@ class TelemetryConfig(BaseModel):
187
187
 
188
188
  class GeneralConfig(BaseModel):
189
189
 
190
- model_config = ConfigDict(protected_namespaces=())
190
+ model_config = ConfigDict(protected_namespaces=(), extra="forbid")
191
191
 
192
192
  use_uvloop: bool | None = Field(
193
193
  default=None,
nat/front_ends/console/console_front_end_plugin.py CHANGED
@@ -95,5 +95,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
95
95
  else:
96
96
  assert False, "Should not reach here. Should have been caught by pre_run"
97
97
 
98
- # Print result
99
- logger.info(f"\n{'-' * 50}\n{Fore.GREEN}Workflow Result:\n%s{Fore.RESET}\n{'-' * 50}", runner_outputs)
98
+ line = f"{'-' * 50}"
99
+ prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
100
+ suffix = f"{Fore.RESET}\n{line}"
101
+
102
+ logger.info(f"{prefix}%s{suffix}", runner_outputs)
103
+
104
+ # (handler is a stream handler) => (level > INFO)
105
+ effective_level_too_high = all(
106
+ type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
107
+ if effective_level_too_high:
108
+ print(f"{prefix}{runner_outputs}{suffix}")
nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py CHANGED
@@ -24,4 +24,4 @@ class HTTPAuthenticationFlowHandler(FlowHandlerBase):
24
24
  async def authenticate(self, config: AuthProviderBaseConfig, method: AuthFlowType) -> AuthenticatedContext:
25
25
 
26
26
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the HTTP frontend."
27
- f" Do you have Websockets enabled?")
27
+ f" Do you have WebSockets enabled?")
nat/front_ends/mcp/mcp_front_end_config.py CHANGED
@@ -43,3 +43,16 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
43
43
 
44
44
  server_auth: OAuth2ResourceServerConfig | None = Field(
45
45
  default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
46
+
47
+ # Memory profiling configuration
48
+ enable_memory_profiling: bool = Field(default=False,
49
+ description="Enable memory profiling and diagnostics (default: False)")
50
+ memory_profile_interval: int = Field(default=50,
51
+ description="Log memory stats every N requests (default: 50)",
52
+ ge=1)
53
+ memory_profile_top_n: int = Field(default=10,
54
+ description="Number of top memory allocations to log (default: 10)",
55
+ ge=1,
56
+ le=50)
57
+ memory_profile_log_level: str = Field(default="DEBUG",
58
+ description="Log level for memory profiling output (default: DEBUG)")
nat/front_ends/mcp/mcp_front_end_plugin_worker.py CHANGED
@@ -29,6 +29,7 @@ from nat.builder.workflow import Workflow
29
29
  from nat.builder.workflow_builder import WorkflowBuilder
30
30
  from nat.data_models.config import Config
31
31
  from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
32
+ from nat.front_ends.mcp.memory_profiler import MemoryProfiler
32
33
 
33
34
  logger = logging.getLogger(__name__)
34
35
 
@@ -45,6 +46,12 @@ class MCPFrontEndPluginWorkerBase(ABC):
45
46
  self.full_config = config
46
47
  self.front_end_config: MCPFrontEndConfig = config.general.front_end
47
48
 
49
+ # Initialize memory profiler if enabled
50
+ self.memory_profiler = MemoryProfiler(enabled=self.front_end_config.enable_memory_profiling,
51
+ log_interval=self.front_end_config.memory_profile_interval,
52
+ top_n=self.front_end_config.memory_profile_top_n,
53
+ log_level=self.front_end_config.memory_profile_log_level)
54
+
48
55
  def _setup_health_endpoint(self, mcp: FastMCP):
49
56
  """Set up the HTTP health endpoint that exercises MCP ping handler."""
50
57
 
@@ -115,6 +122,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
115
122
  Exposes:
116
123
  - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
117
124
  selects a subset and returns details for those tools.
125
+ - GET /debug/memory/stats: Get current memory profiling statistics (read-only)
118
126
  """
119
127
 
120
128
  @mcp.custom_route("/debug/tools/list", methods=["GET"])
@@ -206,6 +214,15 @@ class MCPFrontEndPluginWorkerBase(ABC):
206
214
  return JSONResponse(
207
215
  _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
208
216
 
217
+ # Memory profiling endpoint (read-only)
218
+ @mcp.custom_route("/debug/memory/stats", methods=["GET"])
219
+ async def get_memory_stats(_request: Request):
220
+ """Get current memory profiling statistics."""
221
+ from starlette.responses import JSONResponse
222
+
223
+ stats = self.memory_profiler.get_stats()
224
+ return JSONResponse(stats)
225
+
209
226
 
210
227
  class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
211
228
  """Default MCP front end plugin worker implementation."""
@@ -241,7 +258,7 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
241
258
 
242
259
  # Register each function with MCP, passing workflow context for observability
243
260
  for function_name, function in functions.items():
244
- register_function_with_mcp(mcp, function_name, function, workflow)
261
+ register_function_with_mcp(mcp, function_name, function, workflow, self.memory_profiler)
245
262
 
246
263
  # Add a simple fallback function if no functions were found
247
264
  if not functions: