nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/front_ends/mcp/memory_profiler.py
@@ -0,0 +1,320 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Memory profiling utilities for MCP frontend."""
+
+ import gc
+ import logging
+ import tracemalloc
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ class MemoryProfiler:
+     """Memory profiler for tracking memory usage and potential leaks."""
+
+     def __init__(self, enabled: bool = False, log_interval: int = 50, top_n: int = 10, log_level: str = "DEBUG"):
+         """Initialize the memory profiler.
+
+         Args:
+             enabled: Whether memory profiling is enabled
+             log_interval: Log stats every N requests
+             top_n: Number of top allocations to log
+             log_level: Log level for memory profiling output (e.g., "DEBUG", "INFO")
+         """
+         self.enabled = enabled
+         # normalize interval to avoid modulo-by-zero
+         self.log_interval = max(1, int(log_interval))
+         self.top_n = top_n
+         self.log_level = getattr(logging, log_level.upper(), logging.DEBUG)
+         self.request_count = 0
+         self.baseline_snapshot = None
+
+         # Track whether this instance started tracemalloc (to avoid resetting external tracing)
+         self._we_started_tracemalloc = False
+
+         if self.enabled:
+             logger.info("Memory profiling ENABLED (interval=%d, top_n=%d, log_level=%s)",
+                         self.log_interval,
+                         top_n,
+                         log_level)
+             try:
+                 if not tracemalloc.is_tracing():
+                     tracemalloc.start()
+                     self._we_started_tracemalloc = True
+                 # Take baseline snapshot
+                 gc.collect()
+                 self.baseline_snapshot = tracemalloc.take_snapshot()
+             except RuntimeError as e:
+                 logger.warning("tracemalloc unavailable or not tracing: %s", e)
+         else:
+             logger.info("Memory profiling DISABLED")
+
+     def _log(self, message: str, *args: Any) -> None:
+         """Log a message at the configured log level.
+
+         Args:
+             message: Log message format string
+             args: Arguments for the format string
+         """
+         logger.log(self.log_level, message, *args)
+
+     def on_request_complete(self) -> None:
+         """Called after each request completes."""
+         if not self.enabled:
+             return
+         self.request_count += 1
+         if self.request_count % self.log_interval == 0:
+             self.log_memory_stats()
+
+     def _ensure_tracing(self) -> bool:
+         """Ensure tracemalloc is running if we started it originally.
+
+         Returns:
+             True if tracemalloc is active, False otherwise
+         """
+         if tracemalloc.is_tracing():
+             return True
+
+         # Only restart if we started it originally (respect external control)
+         if not self._we_started_tracemalloc:
+             return False
+
+         # Attempt to restart
+         try:
+             logger.warning("tracemalloc was stopped externally; restarting (we started it originally)")
+             tracemalloc.start()
+
+             # Reset baseline since old tracking data is lost
+             gc.collect()
+             self.baseline_snapshot = tracemalloc.take_snapshot()
+             logger.info("Baseline snapshot reset after tracemalloc restart")
+
+             return True
+         except RuntimeError as e:
+             logger.error("Failed to restart tracemalloc: %s", e)
+             return False
+
+     def _safe_traced_memory(self) -> tuple[float, float] | None:
+         """Return (current, peak usage in MB) if tracemalloc is active, else None."""
+         if not self._ensure_tracing():
+             return None
+
+         try:
+             current, peak = tracemalloc.get_traced_memory()
+             megabyte = (1 << 20)
+             return (current / megabyte, peak / megabyte)
+         except RuntimeError:
+             return None
+
+     def _safe_snapshot(self) -> tracemalloc.Snapshot | None:
+         """Return a tracemalloc Snapshot if available, else None."""
+         if not self._ensure_tracing():
+             return None
+
+         try:
+             return tracemalloc.take_snapshot()
+         except RuntimeError:
+             return None
+
+     def log_memory_stats(self) -> dict[str, Any]:
+         """Log current memory statistics and return them."""
+         if not self.enabled:
+             return {}
+
+         # Force garbage collection first
+         gc.collect()
+
+         # Get current memory usage
+         mem = self._safe_traced_memory()
+         if mem is None:
+             logger.info("tracemalloc is not active; cannot collect memory stats.")
+             # still return structural fields
+             stats = {
+                 "request_count": self.request_count,
+                 "current_memory_mb": None,
+                 "peak_memory_mb": None,
+                 "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+                 "outstanding_steps": self._safe_outstanding_step_count(),
+                 "active_exporters": self._safe_exporter_count(),
+                 "isolated_exporters": self._safe_isolated_exporter_count(),
+                 "subject_instances": self._count_instances_of_type("Subject"),
+             }
+             return stats
+
+         current_mb, peak_mb = mem
+
+         # Take snapshot and compare to baseline
+         snapshot = self._safe_snapshot()
+
+         # Track BaseExporter instances (observability layer)
+         exporter_count = self._safe_exporter_count()
+         isolated_exporter_count = self._safe_isolated_exporter_count()
+
+         # Track Subject instances (event streams)
+         subject_count = self._count_instances_of_type("Subject")
+
+         stats = {
+             "request_count": self.request_count,
+             "current_memory_mb": round(current_mb, 2),
+             "peak_memory_mb": round(peak_mb, 2),
+             "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+             "outstanding_steps": self._safe_outstanding_step_count(),
+             "active_exporters": exporter_count,
+             "isolated_exporters": isolated_exporter_count,
+             "subject_instances": subject_count,
+         }
+
+         self._log("=" * 80)
+         self._log("MEMORY PROFILE AFTER %d REQUESTS:", self.request_count)
+         self._log("  Current Memory: %.2f MB", current_mb)
+         self._log("  Peak Memory: %.2f MB", peak_mb)
+         self._log("")
+         self._log("NAT COMPONENT INSTANCES:")
+         self._log("  IntermediateStepManagers: %d active (%d outstanding steps)",
+                   stats["active_intermediate_managers"],
+                   stats["outstanding_steps"])
+         self._log("  BaseExporters: %d active (%d isolated)", stats["active_exporters"], stats["isolated_exporters"])
+         self._log("  Subject (event streams): %d instances", stats["subject_instances"])
+
+         # Show top allocations
+         if snapshot is None:
+             self._log("tracemalloc snapshot unavailable.")
+         else:
+             if self.baseline_snapshot:
+                 self._log("TOP %d MEMORY GROWTH SINCE BASELINE:", self.top_n)
+                 top_stats = snapshot.compare_to(self.baseline_snapshot, 'lineno')
+             else:
+                 self._log("TOP %d MEMORY ALLOCATIONS:", self.top_n)
+                 top_stats = snapshot.statistics('lineno')
+
+             for i, stat in enumerate(top_stats[:self.top_n], 1):
+                 self._log("  #%d: %s", i, stat)
+
+         self._log("=" * 80)
+
+         return stats
+
+     def _count_instances_of_type(self, type_name: str) -> int:
+         """Count instances of a specific type in memory."""
+         count = 0
+         for obj in gc.get_objects():
+             try:
+                 if type(obj).__name__ == type_name:
+                     count += 1
+             except Exception:
+                 pass
+         return count
+
+     def _safe_exporter_count(self) -> int:
+         try:
+             from nat.observability.exporter.base_exporter import BaseExporter
+             return BaseExporter.get_active_instance_count()
+         except Exception as e:
+             logger.debug("Could not get BaseExporter stats: %s", e)
+             return 0
+
+     def _safe_isolated_exporter_count(self) -> int:
+         try:
+             from nat.observability.exporter.base_exporter import BaseExporter
+             return BaseExporter.get_isolated_instance_count()
+         except Exception:
+             return 0
+
+     def _safe_intermediate_step_manager_count(self) -> int:
+         try:
+             from nat.builder.intermediate_step_manager import IntermediateStepManager
+             # len() is atomic in CPython, but catch RuntimeError just in case
+             try:
+                 return IntermediateStepManager.get_active_instance_count()
+             except RuntimeError:
+                 # Set was modified during len() - very rare
+                 logger.debug("Set changed during count, returning 0")
+                 return 0
+         except Exception as e:
+             logger.debug("Could not get IntermediateStepManager stats: %s", e)
+             return 0
+
+     def _safe_outstanding_step_count(self) -> int:
+         """Get total outstanding steps across all active IntermediateStepManager instances."""
+         try:
+             from nat.builder.intermediate_step_manager import IntermediateStepManager
+
+             # Make a snapshot to avoid "Set changed size during iteration" if GC runs
+             try:
+                 instances_snapshot = list(IntermediateStepManager._active_instances)
+             except RuntimeError:
+                 # Set changed during list() call - rare but possible
+                 logger.debug("Set changed during snapshot, returning 0 for outstanding steps")
+                 return 0
+
+             total_outstanding = 0
+             # Iterate through snapshot safely
+             for ref in instances_snapshot:
+                 try:
+                     manager = ref()
+                     if manager is not None:
+                         total_outstanding += manager.get_outstanding_step_count()
+                 except (ReferenceError, AttributeError):
+                     # Manager was GC'd or in invalid state - skip it
+                     continue
+             return total_outstanding
+         except Exception as e:
+             logger.debug("Could not get outstanding step count: %s", e)
+             return 0
+
+     def get_stats(self) -> dict[str, Any]:
+         """Get current memory statistics without logging."""
+         if not self.enabled:
+             return {"enabled": False}
+
+         mem = self._safe_traced_memory()
+         if mem is None:
+             return {
+                 "enabled": True,
+                 "request_count": self.request_count,
+                 "current_memory_mb": None,
+                 "peak_memory_mb": None,
+                 "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+                 "outstanding_steps": self._safe_outstanding_step_count(),
+                 "active_exporters": self._safe_exporter_count(),
+                 "isolated_exporters": self._safe_isolated_exporter_count(),
+                 "subject_instances": self._count_instances_of_type("Subject"),
+             }
+
+         current_mb, peak_mb = mem
+         return {
+             "enabled": True,
+             "request_count": self.request_count,
+             "current_memory_mb": round(current_mb, 2),
+             "peak_memory_mb": round(peak_mb, 2),
+             "active_intermediate_managers": self._safe_intermediate_step_manager_count(),
+             "outstanding_steps": self._safe_outstanding_step_count(),
+             "active_exporters": self._safe_exporter_count(),
+             "isolated_exporters": self._safe_isolated_exporter_count(),
+             "subject_instances": self._count_instances_of_type("Subject"),
+         }
+
+     def reset_baseline(self) -> None:
+         """Reset the baseline snapshot to current state."""
+         if not self.enabled:
+             return
+         gc.collect()
+         snap = self._safe_snapshot()
+         if snap is None:
+             logger.info("Cannot reset baseline: tracemalloc is not active.")
+             return
+         self.baseline_snapshot = snap
+         logger.info("Memory profiling baseline reset at request %d", self.request_count)
nat/front_ends/mcp/tool_converter.py
@@ -18,9 +18,12 @@ import logging
  from inspect import Parameter
  from inspect import Signature
  from typing import TYPE_CHECKING
+ from typing import Any
 
  from mcp.server.fastmcp import FastMCP
  from pydantic import BaseModel
+ from pydantic.fields import FieldInfo
+ from pydantic_core import PydanticUndefined
 
  from nat.builder.context import ContextState
  from nat.builder.function import Function
@@ -28,9 +31,45 @@ from nat.builder.function_base import FunctionBase
 
  if TYPE_CHECKING:
      from nat.builder.workflow import Workflow
+     from nat.front_ends.mcp.memory_profiler import MemoryProfiler
 
  logger = logging.getLogger(__name__)
 
+ # Sentinel: marks "optional; let Pydantic supply default/factory"
+ _USE_PYDANTIC_DEFAULT = object()
+
+
+ def is_field_optional(field: FieldInfo) -> tuple[bool, Any]:
+     """Determine if a Pydantic field is optional and extract its default value for MCP signatures.
+
+     For MCP tool signatures, we need to distinguish:
+     - Required fields: marked with Parameter.empty
+     - Optional with concrete default: use that default
+     - Optional with factory: use sentinel so Pydantic can apply the factory later
+
+     Args:
+         field: The Pydantic FieldInfo to check
+
+     Returns:
+         A tuple of (is_optional, default_value):
+         - (False, Parameter.empty) for required fields
+         - (True, actual_default) for optional fields with explicit defaults
+         - (True, _USE_PYDANTIC_DEFAULT) for optional fields with default_factory
+     """
+     if field.is_required():
+         return False, Parameter.empty
+
+     # Field is optional - has either default or factory
+     if field.default is not PydanticUndefined:
+         return True, field.default
+
+     # Factory case: mark optional in signature but don't fabricate a value
+     if field.default_factory is not None:
+         return True, _USE_PYDANTIC_DEFAULT
+
+     # Rare corner case: non-required yet no default surfaced
+     return True, _USE_PYDANTIC_DEFAULT
+
 
  def create_function_wrapper(
      function_name: str,
@@ -38,6 +77,7 @@ def create_function_wrapper(
      schema: type[BaseModel],
      is_workflow: bool = False,
      workflow: 'Workflow | None' = None,
+     memory_profiler: 'MemoryProfiler | None' = None,
  ):
      """Create a wrapper function that exposes the actual parameters of a NAT Function as an MCP tool.
 
@@ -47,6 +87,7 @@
          schema (type[BaseModel]): The input schema of the function
          is_workflow (bool): Whether the function is a Workflow
          workflow (Workflow | None): The parent workflow for observability context
+         memory_profiler: Optional memory profiler to track requests
 
      Returns:
          A wrapper function suitable for registration with MCP
@@ -76,12 +117,15 @@
          # Get the field type and convert to appropriate Python type
          field_type = field.annotation
 
+         # Check if field is optional and get its default value
+         _is_optional, param_default = is_field_optional(field)
+
          # Add the parameter to our list
          parameters.append(
              Parameter(
                  name=name,
                  kind=Parameter.KEYWORD_ONLY,
-                 default=Parameter.empty if field.is_required else None,
+                 default=param_default,
                  annotation=field_type,
              ))
 
@@ -140,47 +184,46 @@
                  result = await call_with_observability(lambda: function.ainvoke(chat_request, to_type=str))
              else:
                  # Regular handling
-                 # Handle complex input schema - if we extracted fields from a nested schema,
-                 # we need to reconstruct the input
-                 if len(schema.model_fields) == 1 and len(parameters) > 1:
-                     # Get the field name from the original schema
-                     field_name = next(iter(schema.model_fields.keys()))
-                     field_type = schema.model_fields[field_name].annotation
-
-                     # If it's a pydantic model, we need to create an instance
-                     if field_type and hasattr(field_type, "model_validate"):
-                         # Create the nested object
-                         nested_obj = field_type.model_validate(kwargs)
-                         # Call with the nested object
-                         kwargs = {field_name: nested_obj}
+                 # Strip sentinel values so Pydantic can apply defaults/factories
+                 cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not _USE_PYDANTIC_DEFAULT}
+
+                 # Always validate with the declared schema
+                 # This handles defaults, factories, nested models, validators, etc.
+                 model_input = schema.model_validate(cleaned_kwargs)
 
                  # Call the NAT function with the parameters - special handling for Workflow
                  if is_workflow:
-                     # For workflow with regular input, we'll assume the first parameter is the input
-                     input_value = list(kwargs.values())[0] if kwargs else ""
-
-                     # Workflows have a run method that is an async context manager
-                     # that returns a Runner
-                     async with function.run(input_value) as runner:
+                     # Workflows expect the model instance directly
+                     async with function.run(model_input) as runner:
                          # Get the result from the runner
                          result = await runner.result(to_type=str)
                  else:
-                     # Regular function call
-                     result = await call_with_observability(lambda: function.acall_invoke(**kwargs))
+                     # Regular function call - unpack the validated model
+                     result = await call_with_observability(
+                         lambda: function.acall_invoke(**model_input.model_dump()))
 
              # Report completion
              if ctx:
                  await ctx.report_progress(100, 100)
 
+             # Track request completion for memory profiling
+             if memory_profiler:
+                 memory_profiler.on_request_complete()
+
              # Handle different result types for proper formatting
              if isinstance(result, str):
                  return result
-             if isinstance(result, (dict, list)):
+             if isinstance(result, dict | list):
                  return json.dumps(result, default=str)
              return str(result)
          except Exception as e:
              if ctx:
                  ctx.error("Error calling function %s: %s", function_name, str(e))
+
+             # Track request completion even on error
+             if memory_profiler:
+                 memory_profiler.on_request_complete()
+
              raise
 
      return wrapper_with_ctx
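The net effect of this rewrite: arguments the caller omitted arrive as the sentinel, are stripped, and Pydantic supplies defaults and factories during validation. A condensed sketch of the data flow (reusing the hypothetical Example model from above):

    kwargs = {"query": "hi", "limit": _USE_PYDANTIC_DEFAULT, "tags": _USE_PYDANTIC_DEFAULT}

    cleaned = {k: v for k, v in kwargs.items() if v is not _USE_PYDANTIC_DEFAULT}
    model_input = Example.model_validate(cleaned)  # limit=10, tags=[] filled in by Pydantic
    payload = model_input.model_dump()             # {"query": "hi", "limit": 10, "tags": []}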
@@ -229,6 +272,9 @@ def get_function_description(function: FunctionBase) -> str:
          # Try to get anything that might be a description
          elif hasattr(config, "topic") and config.topic:
              function_description = config.topic
+         # Try to get description from the workflow config
+         elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
+             function_description = config.workflow.description
 
      elif isinstance(function, Function):
          function_description = function.description
@@ -239,7 +285,8 @@
  def register_function_with_mcp(mcp: FastMCP,
                                 function_name: str,
                                 function: FunctionBase,
-                                workflow: 'Workflow | None' = None) -> None:
+                                workflow: 'Workflow | None' = None,
+                                memory_profiler: 'MemoryProfiler | None' = None) -> None:
      """Register a NAT Function as an MCP tool.
 
      Args:
@@ -247,6 +294,7 @@ def register_function_with_mcp(mcp: FastMCP,
          function_name: The name to register the function under
          function: The NAT Function to register
          workflow: The parent workflow for observability context (if available)
+         memory_profiler: Optional memory profiler to track requests
      """
      logger.info("Registering function %s with MCP", function_name)
 
@@ -264,5 +312,10 @@
      function_description = get_function_description(function)
 
      # Create and register the wrapper function with MCP
-     wrapper_func = create_function_wrapper(function_name, function, input_schema, is_workflow, workflow)
+     wrapper_func = create_function_wrapper(function_name,
+                                            function,
+                                            input_schema,
+                                            is_workflow,
+                                            workflow,
+                                            memory_profiler)
      mcp.tool(name=function_name, description=function_description)(wrapper_func)
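Taken together, this plumbing lets a caller thread one profiler through every registered tool. A sketch against the signature above (assumes only what this diff shows; fn and wf are placeholders):

    mcp = FastMCP("nat")
    profiler = MemoryProfiler(enabled=True, log_interval=50)
    register_function_with_mcp(mcp, "my_tool", fn, workflow=wf, memory_profiler=profiler)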
nat/front_ends/simple_base/simple_front_end_plugin_base.py
@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
 
      async def run(self):
 
+         await self.pre_run()
+
          # Must yield the workflow function otherwise it cleans up
          async with WorkflowBuilder.from_config(config=self.full_config) as builder:
 
@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
 
              click.echo(stream.getvalue())
 
-             workflow = builder.build()
+             workflow = await builder.build()
              session_manager = SessionManager(workflow)
              await self.run_workflow(session_manager)
 
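The one-line change above implies WorkflowBuilder.build() is now a coroutine (consistent with the +306/-54 churn in nat/builder/workflow_builder.py in the file list), so call sites follow this pattern (a sketch, not code from the wheel):

    async with WorkflowBuilder.from_config(config=full_config) as builder:
        workflow = await builder.build()  # build() must now be awaited
        session_manager = SessionManager(workflow)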
nat/llm/aws_bedrock_llm.py
@@ -21,22 +21,25 @@ from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import OptimizableMixin
+ from nat.data_models.optimizable import SearchSpace
  from nat.data_models.retry_mixin import RetryMixin
- from nat.data_models.temperature_mixin import TemperatureMixin
  from nat.data_models.thinking_mixin import ThinkingMixin
- from nat.data_models.top_p_mixin import TopPMixin
 
 
- class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="aws_bedrock"):
+ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, name="aws_bedrock"):
      """An AWS Bedrock llm provider to be used with an LLM client."""
 
-     model_config = ConfigDict(protected_namespaces=())
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
      # Completion parameters
-     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
-                             serialization_alias="model",
-                             description="The model name for the hosted AWS Bedrock.")
-     max_tokens: int | None = Field(default=1024, gt=0, description="Maximum number of tokens to generate.")
+     model_name: str = OptimizableField(validation_alias=AliasChoices("model_name", "model"),
+                                        serialization_alias="model",
+                                        description="The model name for the hosted AWS Bedrock.")
+     max_tokens: int = OptimizableField(default=300,
+                                        description="Maximum number of tokens to generate.",
+                                        space=SearchSpace(high=2176, low=128, step=512))
      context_size: int | None = Field(
          default=1024,
          gt=0,
@@ -50,6 +53,16 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMix
          default=None, description="Bedrock endpoint to use. Needed if you don't want to default to us-east-1 endpoint.")
      credentials_profile_name: str | None = Field(
          default=None, description="The name of the profile in the ~/.aws/credentials or ~/.aws/config files.")
+     temperature: float | None = OptimizableField(
+         default=None,
+         ge=0.0,
+         description="Sampling temperature to control randomness in the output.",
+         space=SearchSpace(high=0.9, low=0.1, step=0.2))
+     top_p: float | None = OptimizableField(default=None,
+                                            ge=0.0,
+                                            le=1.0,
+                                            description="Top-p for distribution sampling.",
+                                            space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
  @register_llm_provider(config_type=AWSBedrockModelConfig)
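The deleted TemperatureMixin/TopPMixin fields are redeclared inline so each can carry a SearchSpace for the new parameter optimizer (nat/profiler/parameter_optimization/, entries 151-158 in the file list). A sketch of the declaration pattern, assuming only the OptimizableField signature visible in this diff (MyLLMConfig is hypothetical; azure_openai_llm.py below applies the same pattern):

    from nat.data_models.optimizable import OptimizableField, OptimizableMixin, SearchSpace

    class MyLLMConfig(LLMBaseConfig, OptimizableMixin, name="my_llm"):
        # Ordinary Field kwargs pass through; `space` bounds what the optimizer may try
        temperature: float | None = OptimizableField(
            default=None,
            ge=0.0,
            description="Sampling temperature.",
            space=SearchSpace(high=0.9, low=0.1, step=0.2))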
nat/llm/azure_openai_llm.py
@@ -20,18 +20,17 @@ from pydantic import Field
  from nat.builder.builder import Builder
  from nat.builder.llm import LLMProviderInfo
  from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.common import OptionalSecretStr
  from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.optimizable import OptimizableField
+ from nat.data_models.optimizable import SearchSpace
  from nat.data_models.retry_mixin import RetryMixin
- from nat.data_models.temperature_mixin import TemperatureMixin
  from nat.data_models.thinking_mixin import ThinkingMixin
- from nat.data_models.top_p_mixin import TopPMixin
 
 
  class AzureOpenAIModelConfig(
      LLMBaseConfig,
      RetryMixin,
-     TemperatureMixin,
-     TopPMixin,
      ThinkingMixin,
      name="azure_openai",
  ):
@@ -39,7 +38,7 @@ class AzureOpenAIModelConfig(
 
      model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
-     api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+     api_key: OptionalSecretStr = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
      api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
      azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
                                         serialization_alias="azure_endpoint",
@@ -49,6 +48,16 @@ class AzureOpenAIModelConfig(
                                         serialization_alias="azure_deployment",
                                         description="The Azure OpenAI hosted model/deployment name.")
      seed: int | None = Field(default=None, description="Random seed to set for generation.")
+     temperature: float | None = OptimizableField(
+         default=None,
+         ge=0.0,
+         description="Sampling temperature to control randomness in the output.",
+         space=SearchSpace(high=0.9, low=0.1, step=0.2))
+     top_p: float | None = OptimizableField(default=None,
+                                            ge=0.0,
+                                            le=1.0,
+                                            description="Top-p for distribution sampling.",
+                                            space=SearchSpace(high=1.0, low=0.5, step=0.1))
 
 
  @register_llm_provider(config_type=AzureOpenAIModelConfig)