nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/profiler/callbacks/llama_index_callback_handler.py

@@ -30,6 +30,7 @@ from nat.builder.context import Context
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.data_models.intermediate_step import IntermediateStepPayload
 from nat.data_models.intermediate_step import IntermediateStepType
+from nat.data_models.intermediate_step import ServerToolUseSchema
 from nat.data_models.intermediate_step import StreamEventData
 from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.intermediate_step import UsageInfo
@@ -64,6 +65,26 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
         self._run_id_to_tool_input = {}
         self._run_id_to_timestamp = {}

+    @staticmethod
+    def _extract_token_usage(response: ChatResponse) -> TokenUsageBaseModel:
+        token_usage = TokenUsageBaseModel()
+        try:
+            if response and response.additional_kwargs and "usage" in response.additional_kwargs:
+                usage = response.additional_kwargs["usage"] if "usage" in response.additional_kwargs else {}
+                token_usage.prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
+                token_usage.completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0
+
+                if hasattr(usage, "input_tokens_details") and hasattr(usage.input_tokens_details, "cached_tokens"):
+                    token_usage.cached_tokens = usage.input_tokens_details.cached_tokens
+
+                if hasattr(usage, "output_tokens_details") and hasattr(usage.output_tokens_details, "reasoning_tokens"):
+                    token_usage.reasoning_tokens = usage.output_tokens_details.reasoning_tokens
+
+        except Exception as e:
+            logger.debug("Error extracting token usage: %s", e, exc_info=True)
+
+        return token_usage
+
     def on_event_start(
         self,
         event_type: CBEventType,
@@ -167,6 +188,18 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
            except Exception as e:
                logger.exception("Error getting model name: %s", e)

+            # Append usage data to NAT usage stats
+            tool_outputs_list = []
+            # Check if message.additional_kwargs has tool_outputs, indicative of server-side tool calling
+            if response and response.additional_kwargs and "built_in_tool_calls" in response.additional_kwargs:
+                tools_outputs = response.additional_kwargs["built_in_tool_calls"]
+                if isinstance(tools_outputs, list):
+                    for tool in tools_outputs:
+                        try:
+                            tool_outputs_list.append(ServerToolUseSchema(**tool.model_dump()))
+                        except Exception:
+                            pass
+
             # Append usage data to NAT usage stats
             with self._lock:
                 stats = IntermediateStepPayload(
@@ -176,8 +209,9 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
                     name=model_name,
                     UUID=event_id,
                     data=StreamEventData(input=self._run_id_to_llm_input.get(event_id), output=llm_text_output),
-                    metadata=TraceMetadata(chat_responses=response.message if response.message else None),
-                    usage_info=UsageInfo(token_usage=TokenUsageBaseModel(**response.additional_kwargs)))
+                    metadata=TraceMetadata(chat_responses=response.message if response.message else None,
+                                           tool_outputs=tool_outputs_list if tool_outputs_list else []),
+                    usage_info=UsageInfo(token_usage=self._extract_token_usage(response)))
                 self.step_manager.push_intermediate_step(stats)

         elif event_type == CBEventType.FUNCTION_CALL and payload:
nat/profiler/callbacks/token_usage_base_model.py

@@ -24,4 +24,6 @@ class TokenUsageBaseModel(BaseModel):

     prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.")
     completion_tokens: int = Field(default=0, description="Number of tokens in the completion.")
+    cached_tokens: int = Field(default=0, description="Number of tokens read from cache.")
+    reasoning_tokens: int = Field(default=0, description="Number of tokens used for reasoning.")
     total_tokens: int = Field(default=0, description="Number of tokens total.")
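The new `cached_tokens` and `reasoning_tokens` fields are populated by `_extract_token_usage` above, which reads an OpenAI-Responses-style usage object defensively: every attribute access is hasattr-guarded, so a provider that omits a detail block simply reports 0. A minimal, self-contained sketch of that mapping (the `TokenUsage` model and `extract` helper are illustrative stand-ins, not the package API):

from types import SimpleNamespace

from pydantic import BaseModel, Field


class TokenUsage(BaseModel):
    # Mirrors the fields on TokenUsageBaseModel, including the two added above.
    prompt_tokens: int = Field(default=0)
    completion_tokens: int = Field(default=0)
    cached_tokens: int = Field(default=0)
    reasoning_tokens: int = Field(default=0)


def extract(usage: object) -> TokenUsage:
    # Same guarded pattern as _extract_token_usage: every field is optional.
    tu = TokenUsage()
    tu.prompt_tokens = getattr(usage, "input_tokens", 0)
    tu.completion_tokens = getattr(usage, "output_tokens", 0)
    in_details = getattr(usage, "input_tokens_details", None)
    if in_details is not None and hasattr(in_details, "cached_tokens"):
        tu.cached_tokens = in_details.cached_tokens
    out_details = getattr(usage, "output_tokens_details", None)
    if out_details is not None and hasattr(out_details, "reasoning_tokens"):
        tu.reasoning_tokens = out_details.reasoning_tokens
    return tu


# A fake usage object shaped like additional_kwargs["usage"].
usage = SimpleNamespace(
    input_tokens=1200,
    output_tokens=350,
    input_tokens_details=SimpleNamespace(cached_tokens=800),
    output_tokens_details=SimpleNamespace(reasoning_tokens=120),
)
print(extract(usage))  # prompt_tokens=1200 completion_tokens=350 cached_tokens=800 reasoning_tokens=120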
nat/profiler/decorators/framework_wrapper.py

@@ -17,6 +17,7 @@ from __future__ import annotations

 import functools
 import logging
+from collections.abc import AsyncIterator
 from collections.abc import Callable
 from contextlib import AbstractAsyncContextManager as AsyncContextManager
 from contextlib import asynccontextmanager
@@ -32,35 +33,55 @@ _library_instrumented = {
     "crewai": False,
     "semantic_kernel": False,
     "agno": False,
+    "adk": False,
 }

 callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)


 def set_framework_profiler_handler(
-    workflow_llms: dict = None,
-    frameworks: list[LLMFrameworkEnum] = None,
+    workflow_llms: dict | None = None,
+    frameworks: list[LLMFrameworkEnum] | None = None,
 ) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
     """
     Decorator that wraps an async context manager function to set up framework-specific profiling.
+
+    Args:
+        workflow_llms (dict | None): A dictionary of workflow LLM configurations.
+        frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
+
+    Returns:
+        Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
+            A decorator that wraps the original function with profiling setup.
     """

     def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
+        """The actual decorator that wraps the function.
+
+        Args:
+            func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
+
+        Returns:
+            Callable[..., AsyncContextManager[Any]]: The wrapped function.
+        """

         @functools.wraps(func)
         @asynccontextmanager
         async def wrapper(workflow_config, builder):

-            if LLMFrameworkEnum.LANGCHAIN in frameworks and not _library_instrumented["langchain"]:
-                from langchain_core.tracers.context import register_configure_hook
-
+            if LLMFrameworkEnum.LANGCHAIN in frameworks:
+                # Always set a fresh handler in the current context so callbacks
+                # route to the active run. Only register the hook once globally.
                 from nat.profiler.callbacks.langchain_callback_handler import LangchainProfilerHandler

                 handler = LangchainProfilerHandler()
                 callback_handler_var.set(handler)
-                register_configure_hook(callback_handler_var, inheritable=True)
-                _library_instrumented["langchain"] = True
-                logger.debug("LangChain/LangGraph callback handler registered")
+
+                if not _library_instrumented["langchain"]:
+                    from langchain_core.tracers.context import register_configure_hook
+                    register_configure_hook(callback_handler_var, inheritable=True)
+                    _library_instrumented["langchain"] = True
+                    logger.debug("LangChain/LangGraph callback hook registered")

             if LLMFrameworkEnum.LLAMA_INDEX in frameworks:
                 from llama_index.core import Settings
@@ -96,6 +117,20 @@ def set_framework_profiler_handler(
                 _library_instrumented["agno"] = True
                 logger.info("Agno callback handler registered")

+            if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
+                try:
+                    from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
+                except ImportError as e:
+                    logger.warning(
+                        "ADK profiler not available. " +
+                        "Install NAT with ADK extras: pip install \"nvidia-nat[adk]\". Error: %s",
+                        e)
+                else:
+                    handler = ADKProfilerHandler()
+                    handler.instrument()
+                    _library_instrumented["adk"] = True
+                    logger.debug("ADK callback handler registered")
+
             # IMPORTANT: actually call the wrapped function as an async context manager
             async with func(workflow_config, builder) as result:
                 yield result
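The LangChain branch above decouples two concerns that were previously tied together: the ContextVar is refreshed with a fresh handler on every entry, so concurrent workflow runs each route callbacks to their own run, while `register_configure_hook` stays a one-time, process-global side effect. A runnable sketch of that pattern, with hypothetical names in place of NAT's internals:

import asyncio
from contextvars import ContextVar

# Stand-ins for callback_handler_var and register_configure_hook.
handler_var: ContextVar[str | None] = ContextVar("handler_var", default=None)
_hook_registered = False


def register_hook_once() -> None:
    # Global, idempotent side effect: must only happen once per process.
    global _hook_registered
    if not _hook_registered:
        _hook_registered = True
        print("hook registered (once)")


async def run(run_id: int) -> None:
    # Each task created by gather() copies the current context, so this
    # set() is invisible to the other concurrently running task.
    handler_var.set(f"handler-{run_id}")
    register_hook_once()
    await asyncio.sleep(0)
    assert handler_var.get() == f"handler-{run_id}"


async def main() -> None:
    await asyncio.gather(run(1), run(2))


asyncio.run(main())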
@@ -114,11 +149,28 @@ def chain_wrapped_build_fn(
     Convert an original build function into an async context manager that
     wraps it with a single call to set_framework_profiler_handler, passing
     all frameworks at once.
+
+    Args:
+        original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
+        workflow_llms (dict): A dictionary of workflow LLM configurations.
+        function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
+
+    Returns:
+        Callable[..., AsyncContextManager]: The wrapped build function.
     """

     # Define a base async context manager that simply calls the original build function.
     @asynccontextmanager
-    async def base_fn(*args, **kwargs):
+    async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
+        """Base async context manager that calls the original build function.
+
+        Args:
+            *args: Positional arguments to pass to the original build function.
+            **kwargs: Keyword arguments to pass to the original build function.
+
+        Yields:
+            The result of the original build function.
+        """
         async with original_build_fn(*args, **kwargs) as w:
             yield w
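In the wrapper above, `async with func(workflow_config, builder) as result: yield result` is the load-bearing line: the decorator must actually enter the wrapped async context manager rather than merely call it, so its setup and teardown nest correctly around the workflow's lifetime. A small runnable sketch of the same wrapping shape (all names hypothetical):

import asyncio
import functools
from contextlib import asynccontextmanager


def with_setup(func):
    # Same shape as set_framework_profiler_handler's inner decorator.
    @functools.wraps(func)
    @asynccontextmanager
    async def wrapper(*args, **kwargs):
        print("setup (register profiler handlers here)")
        async with func(*args, **kwargs) as result:
            yield result
        print("teardown")

    return wrapper


@with_setup
@asynccontextmanager
async def build_workflow(name: str):
    yield f"workflow:{name}"


async def main() -> None:
    async with build_workflow("demo") as wf:
        print(wf)  # prints after "setup", before "teardown"


asyncio.run(main())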
nat/profiler/decorators/function_tracking.py

@@ -18,7 +18,9 @@ import inspect
 import uuid
 from collections.abc import Callable
 from typing import Any
+from typing import TypeVar
 from typing import cast
+from typing import overload

 from pydantic import BaseModel

@@ -38,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:

     if isinstance(obj, dict):
         return {str(k): _serialize_data(v) for k, v in obj.items()}
-    if isinstance(obj, (list, tuple, set)):
+    if isinstance(obj, list | tuple | set):
         return [_serialize_data(item) for item in obj]

-    if isinstance(obj, (str, int, float, bool, type(None))):
+    if isinstance(obj, str | int | float | bool | type(None)):
         return obj

     # Fallback
@@ -77,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
     step_manager.push_intermediate_step(payload)


-def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
+# Type variable for overloads
+F = TypeVar('F', bound=Callable[..., Any])
+
+
+# Overloads for different function types
+@overload
+def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
+    """Overload for when a function is passed directly."""
+    ...
+
+
+@overload
+def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+    """Overload for decorator factory usage (when called with parentheses)."""
+    ...
+
+
+def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
     """
     Decorator that can wrap any type of function (sync, async, generator,
     async generator) and executes "tracking logic" around it.
@@ -256,6 +275,19 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
         return sync_wrapper


+# Overloads for track_unregistered_function
+@overload
+def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
+    """Overload for when a function is passed directly."""
+    ...
+
+
+@overload
+def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+    """Overload for decorator factory usage (when called with parentheses)."""
+    ...
+
+
 def track_unregistered_function(func: Callable[..., Any] | None = None,
                                 *,
                                 name: str | None = None,
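The `@overload` declarations exist purely for static type checkers: the bare form `@track_function` must resolve to `F -> F` so the wrapped function keeps its exact signature, while the factory form `@track_function(metadata=...)` resolves to a decorator. A stripped-down sketch of the same dual-use shape (the no-op bodies stand in for the real wrapping logic):

from collections.abc import Callable
from typing import Any, TypeVar, overload

F = TypeVar("F", bound=Callable[..., Any])


@overload
def track(func: F) -> F: ...


@overload
def track(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]: ...


def track(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
    if func is not None:
        # Bare usage: @track received the function directly.
        return func  # a real implementation returns a wrapper here

    def deco(f: Any) -> Any:
        # Factory usage: @track(metadata=...) returned this decorator.
        return f

    return deco


@track
def add_one(x: int) -> int:
    return x + 1


@track(metadata={"kind": "demo"})
def greet(name: str) -> str:
    return f"hello {name}"


print(add_one(1), greet("nat"))  # both keep their static types for checkers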
nat/profiler/forecasting/models/linear_model.py

@@ -36,7 +36,7 @@ class LinearModel(ForecastingBaseModel):
         except ImportError:
             logger.error(
                 "scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
-                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")

             raise

nat/profiler/forecasting/models/random_forest_regressor.py

@@ -36,7 +36,7 @@ class RandomForestModel(ForecastingBaseModel):
         except ImportError:
             logger.error(
                 "scikit-learn is not installed. Please install scikit-learn to use the RandomForest "
-                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                "profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")

             raise

nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py

@@ -304,7 +304,7 @@ def save_gantt_chart(all_nodes: list[CallNode], output_path: str) -> None:
         import matplotlib.pyplot as plt
     except ImportError:
         logger.error("matplotlib is not installed. Please install matplotlib to use generate plots for the profiler "
-                     "or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                     "or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")

         raise

nat/profiler/inference_optimization/experimental/prefix_span_analysis.py

@@ -212,7 +212,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
         from prefixspan import PrefixSpan
     except ImportError:
         logger.error("prefixspan is not installed. Please install prefixspan to run the prefix analysis in the "
-                     "profiler or install `nvidia-nat[profiler]` to install all necessary profiling packages.")
+                     "profiler or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")

         raise

nat/profiler/parameter_optimization/__init__.py (new, empty file: no content to display)
nat/profiler/parameter_optimization/optimizable_utils.py (new file)

@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import get_args
+from typing import get_origin
+
+from pydantic import BaseModel
+
+from nat.data_models.optimizable import SearchSpace
+
+logger = logging.getLogger(__name__)
+
+
+def walk_optimizables(obj: BaseModel, path: str = "") -> dict[str, SearchSpace]:
+    """
+    Recursively build ``{flattened.path: SearchSpace}`` for every optimizable
+    field inside *obj*.
+
+    * Honors ``optimizable_params`` on any model that mixes in
+      ``OptimizableMixin`` – only listed fields are kept.
+    * If a model contains optimizable fields **but** omits
+      ``optimizable_params``, we emit a warning and skip them.
+    """
+    spaces: dict[str, SearchSpace] = {}
+
+    allowed_params_raw = getattr(obj, "optimizable_params", None)
+    allowed_params = set(allowed_params_raw) if allowed_params_raw is not None else None
+    overrides = getattr(obj, "search_space", {}) or {}
+    has_optimizable_flag = False
+
+    for name, fld in obj.model_fields.items():
+        full = f"{path}.{name}" if path else name
+        extra = fld.json_schema_extra or {}
+
+        is_field_optimizable = extra.get("optimizable", False) or name in overrides
+        has_optimizable_flag = has_optimizable_flag or is_field_optimizable
+
+        # honour allow-list
+        if allowed_params is not None and name not in allowed_params:
+            continue
+
+        # 1. plain optimizable field or override from config
+        if is_field_optimizable:
+            space = overrides.get(name, extra.get("search_space"))
+            if space is None:
+                logger.error(
+                    "Field %s is marked optimizable but no search space was provided.",
+                    full,
+                )
+                raise ValueError(f"Field {full} is marked optimizable but no search space was provided")
+            spaces[full] = space
+
+        value = getattr(obj, name, None)
+
+        # 2. nested BaseModel
+        if isinstance(value, BaseModel):
+            spaces.update(walk_optimizables(value, full))
+
+        # 3. dict[str, BaseModel] container
+        elif isinstance(value, dict):
+            for key, subval in value.items():
+                if isinstance(subval, BaseModel):
+                    spaces.update(walk_optimizables(subval, f"{full}.{key}"))
+
+        # 4. static-type fallback for class-level annotations
+        elif isinstance(obj, type):
+            ann = fld.annotation
+            if get_origin(ann) in (dict, dict):
+                _, val_t = get_args(ann) or (None, None)
+                if isinstance(val_t, type) and issubclass(val_t, BaseModel):
+                    if allowed_params is None or name in allowed_params:
+                        spaces[f"{full}.*"] = SearchSpace(low=None, high=None)  # sentinel
+
+    if allowed_params is None and has_optimizable_flag:
+        logger.warning(
+            "Model %s contains optimizable fields but no `optimizable_params` "
+            "were defined; these fields will be ignored.",
+            obj.__class__.__name__,
+        )
+    return spaces
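To see what `walk_optimizables` produces, here is a self-contained sketch of the same dotted-path traversal over hypothetical config models. It collects the current values of `optimizable`-flagged fields instead of real `SearchSpace` objects, and omits the allow-list and error handling shown above:

from pydantic import BaseModel, Field


class LLMConfig(BaseModel):
    temperature: float = Field(0.7, json_schema_extra={"optimizable": True})
    model: str = "meta/llama-3.1-70b-instruct"


class WorkflowConfig(BaseModel):
    llms: dict[str, LLMConfig] = {}
    max_retries: int = Field(3, json_schema_extra={"optimizable": True})


def flatten_optimizable(obj: BaseModel, path: str = "") -> dict[str, object]:
    # Same traversal shape: dotted paths, recursion into nested models
    # and into dict[str, BaseModel] containers.
    found: dict[str, object] = {}
    for name, fld in type(obj).model_fields.items():
        full = f"{path}.{name}" if path else name
        extra = fld.json_schema_extra or {}
        if isinstance(extra, dict) and extra.get("optimizable"):
            found[full] = getattr(obj, name)
        value = getattr(obj, name, None)
        if isinstance(value, BaseModel):
            found.update(flatten_optimizable(value, full))
        elif isinstance(value, dict):
            for key, sub in value.items():
                if isinstance(sub, BaseModel):
                    found.update(flatten_optimizable(sub, f"{full}.{key}"))
    return found


cfg = WorkflowConfig(llms={"chat": LLMConfig()})
print(flatten_optimizable(cfg))
# {'llms.chat.temperature': 0.7, 'max_retries': 3}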
nat/profiler/parameter_optimization/optimizer_runtime.py (new file)

@@ -0,0 +1,67 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from pydantic import BaseModel
+
+from nat.data_models.optimizer import OptimizerRunConfig
+from nat.experimental.decorators.experimental_warning_decorator import experimental
+from nat.profiler.parameter_optimization.optimizable_utils import walk_optimizables
+from nat.profiler.parameter_optimization.parameter_optimizer import optimize_parameters
+from nat.profiler.parameter_optimization.prompt_optimizer import optimize_prompts
+from nat.runtime.loader import load_config
+
+logger = logging.getLogger(__name__)
+
+
+@experimental(feature_name="Optimizer")
+async def optimize_config(opt_run_config: OptimizerRunConfig):
+    """Entry-point called by the CLI or runtime."""
+    # ---------------- 1. load / normalise ---------------- #
+    if not isinstance(opt_run_config.config_file, BaseModel):
+        from nat.data_models.config import Config  # guarded import
+        base_cfg: Config = load_config(config_file=opt_run_config.config_file)
+    else:
+        base_cfg = opt_run_config.config_file  # already validated
+
+    # ---------------- 2. discover search space ----------- #
+    full_space = walk_optimizables(base_cfg)
+    if not full_space:
+        logger.warning("No optimizable parameters found in the configuration. "
+                       "Skipping optimization.")
+        return base_cfg
+
+    # ---------------- 3. numeric / enum tuning ----------- #
+    tuned_cfg = base_cfg
+    if base_cfg.optimizer.numeric.enabled:
+        tuned_cfg = optimize_parameters(
+            base_cfg=base_cfg,
+            full_space=full_space,
+            optimizer_config=base_cfg.optimizer,
+            opt_run_config=opt_run_config,
+        )
+
+    # ---------------- 4. prompt optimization ------------- #
+    if base_cfg.optimizer.prompt.enabled:
+        await optimize_prompts(
+            base_cfg=tuned_cfg,
+            full_space=full_space,
+            optimizer_config=base_cfg.optimizer,
+            opt_run_config=opt_run_config,
+        )
+
+    logger.info("All optimization phases complete.")
+    return tuned_cfg
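Calling the entry point directly might look like the sketch below. The field names mirror what the optimizer modules read from `opt_run_config` (`config_file`, `dataset`, `result_json_path`, `endpoint`, `endpoint_timeout`); the concrete values are placeholders, and `OptimizerRunConfig`'s actual required fields and defaults may differ:

import asyncio

from nat.data_models.optimizer import OptimizerRunConfig
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config

# Placeholder paths; config_file may also be an already-validated Config model.
run_cfg = OptimizerRunConfig(
    config_file="configs/workflow.yml",
    dataset="data/eval_set.json",
    result_json_path="$",   # assumed JSONPath-style default, as in evaluation runs
    endpoint=None,          # run locally rather than against a served endpoint
    endpoint_timeout=300,
)

tuned = asyncio.run(optimize_config(run_cfg))  # returns the tuned Config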
nat/profiler/parameter_optimization/parameter_optimizer.py (new file)

@@ -0,0 +1,189 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+from collections.abc import Mapping as Dict
+
+import optuna
+import yaml
+
+from nat.data_models.config import Config
+from nat.data_models.optimizable import SearchSpace
+from nat.data_models.optimizer import OptimizerConfig
+from nat.data_models.optimizer import OptimizerRunConfig
+from nat.data_models.optimizer import SamplerType
+from nat.eval.evaluate import EvaluationRun
+from nat.eval.evaluate import EvaluationRunConfig
+from nat.experimental.decorators.experimental_warning_decorator import experimental
+from nat.profiler.parameter_optimization.parameter_selection import pick_trial
+from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
+
+logger = logging.getLogger(__name__)
+
+
+@experimental(feature_name="Optimizer")
+def optimize_parameters(
+    *,
+    base_cfg: Config,
+    full_space: Dict[str, SearchSpace],
+    optimizer_config: OptimizerConfig,
+    opt_run_config: OptimizerRunConfig,
+) -> Config:
+    """Tune all *non-prompt* hyper-parameters and persist the best config."""
+    space = {k: v for k, v in full_space.items() if not v.is_prompt}
+
+    # Ensure output_path is not None
+    if optimizer_config.output_path is None:
+        raise ValueError("optimizer_config.output_path cannot be None")
+    out_dir = optimizer_config.output_path
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    # Ensure eval_metrics is not None
+    if optimizer_config.eval_metrics is None:
+        raise ValueError("optimizer_config.eval_metrics cannot be None")
+
+    metric_cfg = optimizer_config.eval_metrics
+    directions = [v.direction for v in metric_cfg.values()]
+    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
+    weights = [v.weight for v in metric_cfg.values()]
+
+    # Create appropriate sampler based on configuration
+    sampler_type = optimizer_config.numeric.sampler
+
+    if sampler_type == SamplerType.GRID:
+        # For grid search, convert the existing space to value sequences
+        grid_search_space = {param_name: search_space.to_grid_values() for param_name, search_space in space.items()}
+        sampler = optuna.samplers.GridSampler(grid_search_space)
+        logger.info("Using Grid sampler for numeric optimization")
+    else:
+        # None or BAYESIAN: let Optuna choose defaults
+        sampler = None
+        logger.info(
+            "Using Optuna default sampler types: TPESampler for single-objective, NSGAIISampler for multi-objective")
+
+    study = optuna.create_study(directions=directions, sampler=sampler)
+
+    # Create output directory for intermediate files
+    out_dir = optimizer_config.output_path
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    async def _run_eval(runner: EvaluationRun):
+        return await runner.run_and_evaluate()
+
+    def _objective(trial: optuna.Trial):
+        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))
+
+        # build trial config
+        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
+        cfg_trial = apply_suggestions(base_cfg, suggestions)
+
+        async def _single_eval(trial_idx: int) -> list[float]:  # noqa: ARG001
+            eval_cfg = EvaluationRunConfig(
+                config_file=cfg_trial,
+                dataset=opt_run_config.dataset,
+                result_json_path=opt_run_config.result_json_path,
+                endpoint=opt_run_config.endpoint,
+                endpoint_timeout=opt_run_config.endpoint_timeout,
+            )
+            scores = await _run_eval(EvaluationRun(config=eval_cfg))
+            values = []
+            for metric_name in eval_metrics:
+                metric = next(r[1] for r in scores.evaluation_results if r[0] == metric_name)
+                values.append(metric.average_score)
+
+            return values
+
+        # Create tasks for all evaluations
+        async def _run_all_evals():
+            tasks = [_single_eval(i) for i in range(reps)]
+            return await asyncio.gather(*tasks)
+
+        # Calculate padding width based on total number of trials
+        trial_id_width = len(str(max(0, optimizer_config.numeric.n_trials - 1)))
+        trial_id_padded = f"{trial.number:0{trial_id_width}d}"
+        with (out_dir / f"config_numeric_trial_{trial_id_padded}.yml").open("w") as fh:
+            yaml.dump(cfg_trial.model_dump(), fh)
+
+        all_scores = asyncio.run(_run_all_evals())
+        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
+        trial.set_user_attr("rep_scores", all_scores)
+        return [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]
+
+    logger.info("Starting numeric / enum parameter optimization...")
+    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
+    logger.info("Numeric optimization finished")
+
+    best_params = pick_trial(
+        study=study,
+        mode=optimizer_config.multi_objective_combination_mode,
+        weights=weights,
+    ).params
+    tuned_cfg = apply_suggestions(base_cfg, best_params)
+
+    # Save final results (out_dir already created and defined above)
+    with (out_dir / "optimized_config.yml").open("w") as fh:
+        yaml.dump(tuned_cfg.model_dump(mode='json'), fh)
+    with (out_dir / "trials_dataframe_params.csv").open("w") as fh:
+        # Export full trials DataFrame (values, params, timings, etc.).
+        df = study.trials_dataframe()
+
+        # Rename values_X columns to actual metric names
+        metric_names = list(metric_cfg.keys())
+        rename_mapping = {}
+        for i, metric_name in enumerate(metric_names):
+            old_col = f"values_{i}"
+            if old_col in df.columns:
+                rename_mapping[old_col] = f"values_{metric_name}"
+        if rename_mapping:
+            df = df.rename(columns=rename_mapping)
+
+        # Normalise rep_scores column naming for convenience.
+        if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
+            df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
+        elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
+            # Some Optuna versions return a dict in a single user_attrs column.
+            df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
+            df = df.drop(columns=["user_attrs"])
+
+        # Get Pareto optimal trial numbers from Optuna study
+        pareto_trials = study.best_trials
+        pareto_trial_numbers = {trial.number for trial in pareto_trials}
+        # Add boolean column indicating if trial is Pareto optimal
+        df["pareto_optimal"] = df["number"].isin(pareto_trial_numbers)
+
+        df.to_csv(fh, index=False)
+
+    # Generate Pareto front visualizations
+    try:
+        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
+        logger.info("Generating Pareto front visualizations...")
+        create_pareto_visualization(
+            data_source=study,
+            metric_names=eval_metrics,
+            directions=directions,
+            output_dir=out_dir / "plots",
+            title_prefix="Parameter Optimization",
+            show_plots=False  # Don't show plots in automated runs
+        )
+        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
+    except ImportError as ie:
+        logger.warning("Could not import visualization dependencies: %s. "
+                       "Have you installed nvidia-nat-profiling?",
+                       ie)
+    except Exception as e:
+        logger.warning("Failed to generate visualizations: %s", e)
+
+    return tuned_cfg
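The CSV post-processing above leans on two standard Optuna behaviors: a multi-objective study exposes its Pareto front as `study.best_trials`, and `trials_dataframe()` names objective columns `values_0`, `values_1`, and so on. A toy, self-contained demonstration (the surrogate objective is made up; in `optimize_parameters` every trial runs a full NAT evaluation instead):

import optuna


def objective(trial: optuna.Trial) -> tuple[float, float]:
    # Two objectives with opposite directions, like the metric_cfg-driven
    # `directions` list above (e.g. maximize accuracy, minimize latency).
    temperature = trial.suggest_float("temperature", 0.0, 1.0)
    accuracy = 1.0 - (temperature - 0.3) ** 2  # toy surrogate, not a real eval
    latency = 1.0 + temperature
    return accuracy, latency


study = optuna.create_study(directions=["maximize", "minimize"])
study.optimize(objective, n_trials=20)

# study.best_trials is the Pareto front behind the `pareto_optimal` column;
# trials_dataframe() is the same export written to CSV above.
pareto_numbers = {t.number for t in study.best_trials}
df = study.trials_dataframe()
df["pareto_optimal"] = df["number"].isin(pareto_numbers)
print(df[["number", "values_0", "values_1", "pareto_optimal"]].head())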