nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +13 -8
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +6 -5
- nat/agent/react_agent/register.py +49 -39
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +2 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/rewoo_agent/agent.py +304 -117
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +51 -38
- nat/agent/tool_calling_agent/agent.py +75 -17
- nat/agent/tool_calling_agent/register.py +46 -23
- nat/authentication/api_key/api_key_auth_provider.py +6 -11
- nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
- nat/builder/builder.py +55 -23
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +54 -15
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +370 -0
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +38 -2
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +306 -54
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +60 -18
- nat/cli/entrypoint.py +15 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +72 -1
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +199 -69
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +47 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +4 -3
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +208 -0
- nat/data_models/optimizer.py +161 -0
- nat/data_models/span.py +41 -3
- nat/data_models/thinking_mixin.py +2 -2
- nat/embedder/azure_openai_embedder.py +2 -1
- nat/embedder/nim_embedder.py +3 -2
- nat/embedder/openai_embedder.py +3 -2
- nat/eval/config.py +1 -1
- nat/eval/dataset_handler/dataset_downloader.py +3 -2
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +10 -3
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/eval/usage_stats.py +2 -0
- nat/eval/utils/output_uploader.py +3 -2
- nat/eval/utils/weave_eval.py +17 -3
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +69 -44
- nat/front_ends/fastapi/message_validator.py +8 -7
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +71 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +78 -25
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +21 -8
- nat/llm/azure_openai_llm.py +14 -5
- nat/llm/litellm_llm.py +80 -0
- nat/llm/nim_llm.py +23 -9
- nat/llm/openai_llm.py +19 -7
- nat/llm/register.py +4 -0
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/exporter/span_exporter.py +43 -15
- nat/observability/exporter_manager.py +2 -2
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/observability/register.py +16 -0
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/profiler/decorators/framework_wrapper.py +61 -9
- nat/profiler/decorators/function_tracking.py +35 -3
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/registry_handlers/pypi/register_pypi.py +5 -3
- nat/registry_handlers/rest/register_rest.py +5 -3
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/register.py +2 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +111 -6
- nat/runtime/session.py +49 -3
- nat/settings/global_settings.py +2 -2
- nat/tool/chat_completion.py +4 -1
- nat/tool/code_execution/code_sandbox.py +3 -6
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/datetime_tools.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/register.py +2 -7
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +76 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +278 -72
- nat/utils/io/yaml_tools.py +73 -3
- nat/utils/log_levels.py +25 -0
- nat/utils/responses_api.py +26 -0
- nat/utils/string_utils.py +16 -0
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +6 -2
- nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -461
- nat/data_models/temperature_mixin.py +0 -43
- nat/data_models/top_p_mixin.py +0 -43
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
|
@@ -30,6 +30,7 @@ from nat.builder.context import Context
|
|
|
30
30
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
31
31
|
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
32
32
|
from nat.data_models.intermediate_step import IntermediateStepType
|
|
33
|
+
from nat.data_models.intermediate_step import ServerToolUseSchema
|
|
33
34
|
from nat.data_models.intermediate_step import StreamEventData
|
|
34
35
|
from nat.data_models.intermediate_step import TraceMetadata
|
|
35
36
|
from nat.data_models.intermediate_step import UsageInfo
|
|
@@ -64,6 +65,26 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
|
|
|
64
65
|
self._run_id_to_tool_input = {}
|
|
65
66
|
self._run_id_to_timestamp = {}
|
|
66
67
|
|
|
68
|
+
@staticmethod
|
|
69
|
+
def _extract_token_usage(response: ChatResponse) -> TokenUsageBaseModel:
|
|
70
|
+
token_usage = TokenUsageBaseModel()
|
|
71
|
+
try:
|
|
72
|
+
if response and response.additional_kwargs and "usage" in response.additional_kwargs:
|
|
73
|
+
usage = response.additional_kwargs["usage"] if "usage" in response.additional_kwargs else {}
|
|
74
|
+
token_usage.prompt_tokens = usage.input_tokens if hasattr(usage, "input_tokens") else 0
|
|
75
|
+
token_usage.completion_tokens = usage.output_tokens if hasattr(usage, "output_tokens") else 0
|
|
76
|
+
|
|
77
|
+
if hasattr(usage, "input_tokens_details") and hasattr(usage.input_tokens_details, "cached_tokens"):
|
|
78
|
+
token_usage.cached_tokens = usage.input_tokens_details.cached_tokens
|
|
79
|
+
|
|
80
|
+
if hasattr(usage, "output_tokens_details") and hasattr(usage.output_tokens_details, "reasoning_tokens"):
|
|
81
|
+
token_usage.reasoning_tokens = usage.output_tokens_details.reasoning_tokens
|
|
82
|
+
|
|
83
|
+
except Exception as e:
|
|
84
|
+
logger.debug("Error extracting token usage: %s", e, exc_info=True)
|
|
85
|
+
|
|
86
|
+
return token_usage
|
|
87
|
+
|
|
67
88
|
def on_event_start(
|
|
68
89
|
self,
|
|
69
90
|
event_type: CBEventType,
|
|
@@ -167,6 +188,18 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
|
|
|
167
188
|
except Exception as e:
|
|
168
189
|
logger.exception("Error getting model name: %s", e)
|
|
169
190
|
|
|
191
|
+
# Append usage data to NAT usage stats
|
|
192
|
+
tool_outputs_list = []
|
|
193
|
+
# Check if message.additional_kwargs as tool_outputs indicative of server side tool calling
|
|
194
|
+
if response and response.additional_kwargs and "built_in_tool_calls" in response.additional_kwargs:
|
|
195
|
+
tools_outputs = response.additional_kwargs["built_in_tool_calls"]
|
|
196
|
+
if isinstance(tools_outputs, list):
|
|
197
|
+
for tool in tools_outputs:
|
|
198
|
+
try:
|
|
199
|
+
tool_outputs_list.append(ServerToolUseSchema(**tool.model_dump()))
|
|
200
|
+
except Exception:
|
|
201
|
+
pass
|
|
202
|
+
|
|
170
203
|
# Append usage data to NAT usage stats
|
|
171
204
|
with self._lock:
|
|
172
205
|
stats = IntermediateStepPayload(
|
|
@@ -176,8 +209,9 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
|
|
|
176
209
|
name=model_name,
|
|
177
210
|
UUID=event_id,
|
|
178
211
|
data=StreamEventData(input=self._run_id_to_llm_input.get(event_id), output=llm_text_output),
|
|
179
|
-
metadata=TraceMetadata(chat_responses=response.message if response.message else None
|
|
180
|
-
|
|
212
|
+
metadata=TraceMetadata(chat_responses=response.message if response.message else None,
|
|
213
|
+
tool_outputs=tool_outputs_list if tool_outputs_list else []),
|
|
214
|
+
usage_info=UsageInfo(token_usage=self._extract_token_usage(response)))
|
|
181
215
|
self.step_manager.push_intermediate_step(stats)
|
|
182
216
|
|
|
183
217
|
elif event_type == CBEventType.FUNCTION_CALL and payload:
|
|
@@ -24,4 +24,6 @@ class TokenUsageBaseModel(BaseModel):
|
|
|
24
24
|
|
|
25
25
|
prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.")
|
|
26
26
|
completion_tokens: int = Field(default=0, description="Number of tokens in the completion.")
|
|
27
|
+
cached_tokens: int = Field(default=0, description="Number of tokens read from cache.")
|
|
28
|
+
reasoning_tokens: int = Field(default=0, description="Number of tokens used for reasoning.")
|
|
27
29
|
total_tokens: int = Field(default=0, description="Number of tokens total.")
|
|
@@ -17,6 +17,7 @@ from __future__ import annotations
|
|
|
17
17
|
|
|
18
18
|
import functools
|
|
19
19
|
import logging
|
|
20
|
+
from collections.abc import AsyncIterator
|
|
20
21
|
from collections.abc import Callable
|
|
21
22
|
from contextlib import AbstractAsyncContextManager as AsyncContextManager
|
|
22
23
|
from contextlib import asynccontextmanager
|
|
@@ -32,35 +33,55 @@ _library_instrumented = {
|
|
|
32
33
|
"crewai": False,
|
|
33
34
|
"semantic_kernel": False,
|
|
34
35
|
"agno": False,
|
|
36
|
+
"adk": False,
|
|
35
37
|
}
|
|
36
38
|
|
|
37
39
|
callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
|
|
38
40
|
|
|
39
41
|
|
|
40
42
|
def set_framework_profiler_handler(
|
|
41
|
-
workflow_llms: dict = None,
|
|
42
|
-
frameworks: list[LLMFrameworkEnum] = None,
|
|
43
|
+
workflow_llms: dict | None = None,
|
|
44
|
+
frameworks: list[LLMFrameworkEnum] | None = None,
|
|
43
45
|
) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
44
46
|
"""
|
|
45
47
|
Decorator that wraps an async context manager function to set up framework-specific profiling.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
workflow_llms (dict | None): A dictionary of workflow LLM configurations.
|
|
51
|
+
frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
55
|
+
A decorator that wraps the original function with profiling setup.
|
|
46
56
|
"""
|
|
47
57
|
|
|
48
58
|
def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
|
|
59
|
+
"""The actual decorator that wraps the function.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
Callable[..., AsyncContextManager[Any]]: The wrapped function.
|
|
66
|
+
"""
|
|
49
67
|
|
|
50
68
|
@functools.wraps(func)
|
|
51
69
|
@asynccontextmanager
|
|
52
70
|
async def wrapper(workflow_config, builder):
|
|
53
71
|
|
|
54
|
-
if LLMFrameworkEnum.LANGCHAIN in frameworks
|
|
55
|
-
|
|
56
|
-
|
|
72
|
+
if LLMFrameworkEnum.LANGCHAIN in frameworks:
|
|
73
|
+
# Always set a fresh handler in the current context so callbacks
|
|
74
|
+
# route to the active run. Only register the hook once globally.
|
|
57
75
|
from nat.profiler.callbacks.langchain_callback_handler import LangchainProfilerHandler
|
|
58
76
|
|
|
59
77
|
handler = LangchainProfilerHandler()
|
|
60
78
|
callback_handler_var.set(handler)
|
|
61
|
-
|
|
62
|
-
_library_instrumented["langchain"]
|
|
63
|
-
|
|
79
|
+
|
|
80
|
+
if not _library_instrumented["langchain"]:
|
|
81
|
+
from langchain_core.tracers.context import register_configure_hook
|
|
82
|
+
register_configure_hook(callback_handler_var, inheritable=True)
|
|
83
|
+
_library_instrumented["langchain"] = True
|
|
84
|
+
logger.debug("LangChain/LangGraph callback hook registered")
|
|
64
85
|
|
|
65
86
|
if LLMFrameworkEnum.LLAMA_INDEX in frameworks:
|
|
66
87
|
from llama_index.core import Settings
|
|
@@ -96,6 +117,20 @@ def set_framework_profiler_handler(
|
|
|
96
117
|
_library_instrumented["agno"] = True
|
|
97
118
|
logger.info("Agno callback handler registered")
|
|
98
119
|
|
|
120
|
+
if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
|
|
121
|
+
try:
|
|
122
|
+
from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
|
|
123
|
+
except ImportError as e:
|
|
124
|
+
logger.warning(
|
|
125
|
+
"ADK profiler not available. " +
|
|
126
|
+
"Install NAT with ADK extras: pip install \"nvidia-nat[adk]\". Error: %s",
|
|
127
|
+
e)
|
|
128
|
+
else:
|
|
129
|
+
handler = ADKProfilerHandler()
|
|
130
|
+
handler.instrument()
|
|
131
|
+
_library_instrumented["adk"] = True
|
|
132
|
+
logger.debug("ADK callback handler registered")
|
|
133
|
+
|
|
99
134
|
# IMPORTANT: actually call the wrapped function as an async context manager
|
|
100
135
|
async with func(workflow_config, builder) as result:
|
|
101
136
|
yield result
|
|
@@ -114,11 +149,28 @@ def chain_wrapped_build_fn(
|
|
|
114
149
|
Convert an original build function into an async context manager that
|
|
115
150
|
wraps it with a single call to set_framework_profiler_handler, passing
|
|
116
151
|
all frameworks at once.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
|
|
155
|
+
workflow_llms (dict): A dictionary of workflow LLM configurations.
|
|
156
|
+
function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Callable[..., AsyncContextManager]: The wrapped build function.
|
|
117
160
|
"""
|
|
118
161
|
|
|
119
162
|
# Define a base async context manager that simply calls the original build function.
|
|
120
163
|
@asynccontextmanager
|
|
121
|
-
async def base_fn(*args, **kwargs):
|
|
164
|
+
async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
|
|
165
|
+
"""Base async context manager that calls the original build function.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
*args: Positional arguments to pass to the original build function.
|
|
169
|
+
**kwargs: Keyword arguments to pass to the original build function.
|
|
170
|
+
|
|
171
|
+
Yields:
|
|
172
|
+
The result of the original build function.
|
|
173
|
+
"""
|
|
122
174
|
async with original_build_fn(*args, **kwargs) as w:
|
|
123
175
|
yield w
|
|
124
176
|
|
|
@@ -18,7 +18,9 @@ import inspect
|
|
|
18
18
|
import uuid
|
|
19
19
|
from collections.abc import Callable
|
|
20
20
|
from typing import Any
|
|
21
|
+
from typing import TypeVar
|
|
21
22
|
from typing import cast
|
|
23
|
+
from typing import overload
|
|
22
24
|
|
|
23
25
|
from pydantic import BaseModel
|
|
24
26
|
|
|
@@ -38,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:
|
|
|
38
40
|
|
|
39
41
|
if isinstance(obj, dict):
|
|
40
42
|
return {str(k): _serialize_data(v) for k, v in obj.items()}
|
|
41
|
-
if isinstance(obj,
|
|
43
|
+
if isinstance(obj, list | tuple | set):
|
|
42
44
|
return [_serialize_data(item) for item in obj]
|
|
43
45
|
|
|
44
|
-
if isinstance(obj,
|
|
46
|
+
if isinstance(obj, str | int | float | bool | type(None)):
|
|
45
47
|
return obj
|
|
46
48
|
|
|
47
49
|
# Fallback
|
|
@@ -77,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
|
|
|
77
79
|
step_manager.push_intermediate_step(payload)
|
|
78
80
|
|
|
79
81
|
|
|
80
|
-
|
|
82
|
+
# Type variable for overloads
|
|
83
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Overloads for different function types
|
|
87
|
+
@overload
|
|
88
|
+
def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
|
|
89
|
+
"""Overload for when a function is passed directly."""
|
|
90
|
+
...
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@overload
|
|
94
|
+
def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
95
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
96
|
+
...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
|
|
81
100
|
"""
|
|
82
101
|
Decorator that can wrap any type of function (sync, async, generator,
|
|
83
102
|
async generator) and executes "tracking logic" around it.
|
|
@@ -256,6 +275,19 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
|
|
|
256
275
|
return sync_wrapper
|
|
257
276
|
|
|
258
277
|
|
|
278
|
+
# Overloads for track_unregistered_function
|
|
279
|
+
@overload
|
|
280
|
+
def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
|
|
281
|
+
"""Overload for when a function is passed directly."""
|
|
282
|
+
...
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@overload
|
|
286
|
+
def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
287
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
288
|
+
...
|
|
289
|
+
|
|
290
|
+
|
|
259
291
|
def track_unregistered_function(func: Callable[..., Any] | None = None,
|
|
260
292
|
*,
|
|
261
293
|
name: str | None = None,
|
|
@@ -36,7 +36,7 @@ class LinearModel(ForecastingBaseModel):
|
|
|
36
36
|
except ImportError:
|
|
37
37
|
logger.error(
|
|
38
38
|
"scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
|
|
39
|
-
"profiling model or install
|
|
39
|
+
"profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
40
40
|
|
|
41
41
|
raise
|
|
42
42
|
|
|
@@ -36,7 +36,7 @@ class RandomForestModel(ForecastingBaseModel):
|
|
|
36
36
|
except ImportError:
|
|
37
37
|
logger.error(
|
|
38
38
|
"scikit-learn is not installed. Please install scikit-learn to use the RandomForest "
|
|
39
|
-
"profiling model or install
|
|
39
|
+
"profiling model or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
40
40
|
|
|
41
41
|
raise
|
|
42
42
|
|
|
@@ -304,7 +304,7 @@ def save_gantt_chart(all_nodes: list[CallNode], output_path: str) -> None:
|
|
|
304
304
|
import matplotlib.pyplot as plt
|
|
305
305
|
except ImportError:
|
|
306
306
|
logger.error("matplotlib is not installed. Please install matplotlib to use generate plots for the profiler "
|
|
307
|
-
"or install
|
|
307
|
+
"or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
308
308
|
|
|
309
309
|
raise
|
|
310
310
|
|
|
@@ -212,7 +212,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
|
|
|
212
212
|
from prefixspan import PrefixSpan
|
|
213
213
|
except ImportError:
|
|
214
214
|
logger.error("prefixspan is not installed. Please install prefixspan to run the prefix analysis in the "
|
|
215
|
-
"profiler or install
|
|
215
|
+
"profiler or install \"nvidia-nat[profiler]\" to install all necessary profiling packages.")
|
|
216
216
|
|
|
217
217
|
raise
|
|
218
218
|
|
|
File without changes
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import get_args
|
|
18
|
+
from typing import get_origin
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
|
|
22
|
+
from nat.data_models.optimizable import SearchSpace
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def walk_optimizables(obj: BaseModel, path: str = "") -> dict[str, SearchSpace]:
|
|
28
|
+
"""
|
|
29
|
+
Recursively build ``{flattened.path: SearchSpace}`` for every optimizable
|
|
30
|
+
field inside *obj*.
|
|
31
|
+
|
|
32
|
+
* Honors ``optimizable_params`` on any model that mixes in
|
|
33
|
+
``OptimizableMixin`` – only listed fields are kept.
|
|
34
|
+
* If a model contains optimizable fields **but** omits
|
|
35
|
+
``optimizable_params``, we emit a warning and skip them.
|
|
36
|
+
"""
|
|
37
|
+
spaces: dict[str, SearchSpace] = {}
|
|
38
|
+
|
|
39
|
+
allowed_params_raw = getattr(obj, "optimizable_params", None)
|
|
40
|
+
allowed_params = set(allowed_params_raw) if allowed_params_raw is not None else None
|
|
41
|
+
overrides = getattr(obj, "search_space", {}) or {}
|
|
42
|
+
has_optimizable_flag = False
|
|
43
|
+
|
|
44
|
+
for name, fld in obj.model_fields.items():
|
|
45
|
+
full = f"{path}.{name}" if path else name
|
|
46
|
+
extra = fld.json_schema_extra or {}
|
|
47
|
+
|
|
48
|
+
is_field_optimizable = extra.get("optimizable", False) or name in overrides
|
|
49
|
+
has_optimizable_flag = has_optimizable_flag or is_field_optimizable
|
|
50
|
+
|
|
51
|
+
# honour allow-list
|
|
52
|
+
if allowed_params is not None and name not in allowed_params:
|
|
53
|
+
continue
|
|
54
|
+
|
|
55
|
+
# 1. plain optimizable field or override from config
|
|
56
|
+
if is_field_optimizable:
|
|
57
|
+
space = overrides.get(name, extra.get("search_space"))
|
|
58
|
+
if space is None:
|
|
59
|
+
logger.error(
|
|
60
|
+
"Field %s is marked optimizable but no search space was provided.",
|
|
61
|
+
full,
|
|
62
|
+
)
|
|
63
|
+
raise ValueError(f"Field {full} is marked optimizable but no search space was provided")
|
|
64
|
+
spaces[full] = space
|
|
65
|
+
|
|
66
|
+
value = getattr(obj, name, None)
|
|
67
|
+
|
|
68
|
+
# 2. nested BaseModel
|
|
69
|
+
if isinstance(value, BaseModel):
|
|
70
|
+
spaces.update(walk_optimizables(value, full))
|
|
71
|
+
|
|
72
|
+
# 3. dict[str, BaseModel] container
|
|
73
|
+
elif isinstance(value, dict):
|
|
74
|
+
for key, subval in value.items():
|
|
75
|
+
if isinstance(subval, BaseModel):
|
|
76
|
+
spaces.update(walk_optimizables(subval, f"{full}.{key}"))
|
|
77
|
+
|
|
78
|
+
# 4. static-type fallback for class-level annotations
|
|
79
|
+
elif isinstance(obj, type):
|
|
80
|
+
ann = fld.annotation
|
|
81
|
+
if get_origin(ann) in (dict, dict):
|
|
82
|
+
_, val_t = get_args(ann) or (None, None)
|
|
83
|
+
if isinstance(val_t, type) and issubclass(val_t, BaseModel):
|
|
84
|
+
if allowed_params is None or name in allowed_params:
|
|
85
|
+
spaces[f"{full}.*"] = SearchSpace(low=None, high=None) # sentinel
|
|
86
|
+
|
|
87
|
+
if allowed_params is None and has_optimizable_flag:
|
|
88
|
+
logger.warning(
|
|
89
|
+
"Model %s contains optimizable fields but no `optimizable_params` "
|
|
90
|
+
"were defined; these fields will be ignored.",
|
|
91
|
+
obj.__class__.__name__,
|
|
92
|
+
)
|
|
93
|
+
return spaces
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
21
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
22
|
+
from nat.profiler.parameter_optimization.optimizable_utils import walk_optimizables
|
|
23
|
+
from nat.profiler.parameter_optimization.parameter_optimizer import optimize_parameters
|
|
24
|
+
from nat.profiler.parameter_optimization.prompt_optimizer import optimize_prompts
|
|
25
|
+
from nat.runtime.loader import load_config
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@experimental(feature_name="Optimizer")
async def optimize_config(opt_run_config: OptimizerRunConfig):
    """Run all enabled optimization phases for the configuration referenced by *opt_run_config*.

    Entry point used by the CLI / runtime. The phases are:

    1. Resolve the configuration: anything that is not already a pydantic model is
       passed to ``load_config``; a model instance is used as-is.
    2. Discover the search space with ``walk_optimizables``. If nothing is optimizable,
       log a warning and return the resolved configuration unchanged.
    3. If ``optimizer.numeric.enabled``, tune numeric / enum parameters with
       ``optimize_parameters`` (a synchronous call).
    4. If ``optimizer.prompt.enabled``, await ``optimize_prompts`` starting from the
       configuration produced by phase 3.

    Returns:
        The configuration returned by phase 3, or the resolved configuration when phase 3
        is disabled. Whatever ``optimize_prompts`` returns is not used here.
    """
    source = opt_run_config.config_file

    # Phase 1: load / normalise the configuration.
    if isinstance(source, BaseModel):
        base_cfg = source  # already validated
    else:
        from nat.data_models.config import Config  # guarded import
        base_cfg: Config = load_config(config_file=source)

    # Phase 2: discover the search space; bail out early when it is empty.
    full_space = walk_optimizables(base_cfg)
    if not full_space:
        logger.warning("No optimizable parameters found in the configuration. "
                       "Skipping optimization.")
        return base_cfg

    # Phase 3: numeric / enum tuning; keep the untouched config when disabled.
    tuned_cfg = (optimize_parameters(
        base_cfg=base_cfg,
        full_space=full_space,
        optimizer_config=base_cfg.optimizer,
        opt_run_config=opt_run_config,
    ) if base_cfg.optimizer.numeric.enabled else base_cfg)

    # Phase 4: prompt optimization, seeded with the (possibly) tuned configuration.
    if base_cfg.optimizer.prompt.enabled:
        await optimize_prompts(
            base_cfg=tuned_cfg,
            full_space=full_space,
            optimizer_config=base_cfg.optimizer,
            opt_run_config=opt_run_config,
        )

    logger.info("All optimization phases complete.")
    return tuned_cfg
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Mapping as Dict
|
|
19
|
+
|
|
20
|
+
import optuna
|
|
21
|
+
import yaml
|
|
22
|
+
|
|
23
|
+
from nat.data_models.config import Config
|
|
24
|
+
from nat.data_models.optimizable import SearchSpace
|
|
25
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
26
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
27
|
+
from nat.data_models.optimizer import SamplerType
|
|
28
|
+
from nat.eval.evaluate import EvaluationRun
|
|
29
|
+
from nat.eval.evaluate import EvaluationRunConfig
|
|
30
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
31
|
+
from nat.profiler.parameter_optimization.parameter_selection import pick_trial
|
|
32
|
+
from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _run_coroutine_sync(coro):
    """Drive *coro* to completion from synchronous code and return its result.

    ``asyncio.run`` raises ``RuntimeError`` when an event loop is already running in the
    calling thread, which happens when this synchronous module is invoked from inside a
    coroutine. In that case the coroutine runs on a private event loop in a single worker
    thread, with the caller's ``contextvars`` copied in, and this call blocks until it
    finishes. When no loop is running, ``asyncio.run`` is used directly.
    """
    import concurrent.futures  # local: only needed on the nested-loop path
    import contextvars

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the ordinary path.
        return asyncio.run(coro)

    ctx = contextvars.copy_context()

    def _drive():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            # Mirror asyncio.run's cleanup before closing the private loop.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.run_until_complete(loop.shutdown_default_executor())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(ctx.run, _drive).result()


def _write_trials_dataframe(study: optuna.Study, metric_names: list[str], path) -> None:
    """Export ``study.trials_dataframe()`` as CSV to *path* (a pathlib-style path).

    The ``values_<i>`` objective columns are renamed to ``values_<metric name>``, the
    per-repetition scores stored as the ``rep_scores`` user attribute are exposed as a
    ``rep_scores`` column, and a boolean ``pareto_optimal`` column marks the trials in
    ``study.best_trials``.
    """
    # Build the frame first so a failure never leaves a truncated/empty CSV behind.
    df = study.trials_dataframe()

    # Optuna names objective columns values_0, values_1, ... in objective order.
    rename_mapping = {
        f"values_{i}": f"values_{metric_name}"
        for i, metric_name in enumerate(metric_names) if f"values_{i}" in df.columns
    }
    if rename_mapping:
        df = df.rename(columns=rename_mapping)

    # Normalise rep_scores column naming for convenience.
    if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
        df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
    elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
        # Some Optuna versions return a dict in a single user_attrs column.
        df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
        df = df.drop(columns=["user_attrs"])

    pareto_trial_numbers = {t.number for t in study.best_trials}
    df["pareto_optimal"] = df["number"].isin(pareto_trial_numbers)

    # pandas expects file handles passed to to_csv to be opened with newline="".
    with path.open("w", newline="") as fh:
        df.to_csv(fh, index=False)


@experimental(feature_name="Optimizer")
def optimize_parameters(
    *,
    base_cfg: Config,
    full_space: Dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> Config:
    """Tune all *non-prompt* hyper-parameters and persist the best config.

    Runs an Optuna study over the entries of *full_space* whose ``is_prompt`` flag is
    false. Each trial applies the sampled values to *base_cfg*, writes that trial's
    config to ``output_path``, evaluates it ``reps_per_param_set`` times concurrently,
    and reports the mean of every configured metric. The trial chosen by ``pick_trial``
    is applied to *base_cfg*. The result, a CSV of all trials and (when the visualization
    extras import) Pareto plots are written to ``output_path``.

    Safe to call both from plain synchronous code and from inside a running event loop.

    Args:
        base_cfg: Configuration to tune; suggestions are applied via ``apply_suggestions``.
        full_space: Search space keyed by parameter name; prompt entries are ignored here.
        optimizer_config: Optimizer settings (output path, metrics, sampler, trial count).
        opt_run_config: Run settings forwarded to every evaluation (dataset, endpoint, ...).

    Returns:
        *base_cfg* with the selected trial's parameters applied.

    Raises:
        ValueError: If ``output_path`` or ``eval_metrics`` is unset, or an evaluation run
            reports no result for a configured evaluator.
    """
    space = {k: v for k, v in full_space.items() if not v.is_prompt}

    if optimizer_config.output_path is None:
        raise ValueError("optimizer_config.output_path cannot be None")
    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    if optimizer_config.eval_metrics is None:
        raise ValueError("optimizer_config.eval_metrics cannot be None")

    metric_cfg = optimizer_config.eval_metrics
    directions = [v.direction for v in metric_cfg.values()]
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    # Create appropriate sampler based on configuration
    if optimizer_config.numeric.sampler == SamplerType.GRID:
        # For grid search, convert the existing space to value sequences
        grid_search_space = {param_name: search_space.to_grid_values() for param_name, search_space in space.items()}
        sampler = optuna.samplers.GridSampler(grid_search_space)
        logger.info("Using Grid sampler for numeric optimization")
    else:
        # None or BAYESIAN: let Optuna choose defaults
        sampler = None
        logger.info(
            "Using Optuna default sampler types: TPESampler for single-objective, NSGAIISampler for multi-objective")

    study = optuna.create_study(directions=directions, sampler=sampler)

    # Loop invariants shared by every trial.
    reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))
    trial_id_width = len(str(max(0, optimizer_config.numeric.n_trials - 1)))

    def _objective(trial: optuna.Trial) -> list[float]:
        # build trial config
        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
        cfg_trial = apply_suggestions(base_cfg, suggestions)

        async def _single_eval() -> list[float]:
            eval_cfg = EvaluationRunConfig(
                config_file=cfg_trial,
                dataset=opt_run_config.dataset,
                result_json_path=opt_run_config.result_json_path,
                endpoint=opt_run_config.endpoint,
                endpoint_timeout=opt_run_config.endpoint_timeout,
            )
            scores = await EvaluationRun(config=eval_cfg).run_and_evaluate()
            values = []
            for metric_name in eval_metrics:
                # next() without a default would raise StopIteration, which Python turns
                # into an opaque RuntimeError inside a coroutine.
                metric = next((r[1] for r in scores.evaluation_results if r[0] == metric_name), None)
                if metric is None:
                    available = [r[0] for r in scores.evaluation_results]
                    raise ValueError(f"Evaluator '{metric_name}' produced no result "
                                     f"(available evaluators: {available})")
                values.append(metric.average_score)
            return values

        async def _run_all_evals():
            return await asyncio.gather(*(_single_eval() for _ in range(reps)))

        trial_id_padded = f"{trial.number:0{trial_id_width}d}"
        with (out_dir / f"config_numeric_trial_{trial_id_padded}.yml").open("w") as fh:
            # mode="json" turns Paths/enums into plain scalars, matching optimized_config.yml.
            yaml.dump(cfg_trial.model_dump(mode="json"), fh)

        all_scores = _run_coroutine_sync(_run_all_evals())
        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
        trial.set_user_attr("rep_scores", all_scores)
        return [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]

    logger.info("Starting numeric / enum parameter optimization...")
    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
    logger.info("Numeric optimization finished")

    best_params = pick_trial(
        study=study,
        mode=optimizer_config.multi_objective_combination_mode,
        weights=weights,
    ).params
    tuned_cfg = apply_suggestions(base_cfg, best_params)

    # Save final results
    with (out_dir / "optimized_config.yml").open("w") as fh:
        yaml.dump(tuned_cfg.model_dump(mode='json'), fh)

    # CSV columns are named after the metric keys (evaluator names are used for the plots).
    _write_trials_dataframe(study, list(metric_cfg.keys()), out_dir / "trials_dataframe_params.csv")

    # Generate Pareto front visualizations (best effort: failures are logged, not raised).
    try:
        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
        logger.info("Generating Pareto front visualizations...")
        create_pareto_visualization(
            data_source=study,
            metric_names=eval_metrics,
            directions=directions,
            output_dir=out_dir / "plots",
            title_prefix="Parameter Optimization",
            show_plots=False  # Don't show plots in automated runs
        )
        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
    except ImportError as ie:
        logger.warning("Could not import visualization dependencies: %s. "
                       "Have you installed nvidia-nat-profiling?",
                       ie)
    except Exception as e:
        logger.warning("Failed to generate visualizations: %s", e)

    return tuned_cfg
|