nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/profiler/parameter_optimization/optimizable_utils.py (new file)

```diff
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import get_args
+from typing import get_origin
+
+from pydantic import BaseModel
+
+from nat.data_models.optimizable import SearchSpace
+
+logger = logging.getLogger(__name__)
+
+
+def walk_optimizables(obj: BaseModel, path: str = "") -> dict[str, SearchSpace]:
+    """
+    Recursively build ``{flattened.path: SearchSpace}`` for every optimizable
+    field inside *obj*.
+
+    * Honors ``optimizable_params`` on any model that mixes in
+      ``OptimizableMixin`` – only listed fields are kept.
+    * If a model contains optimizable fields **but** omits
+      ``optimizable_params``, we emit a warning and skip them.
+    """
+    spaces: dict[str, SearchSpace] = {}
+
+    allowed_params_raw = getattr(obj, "optimizable_params", None)
+    allowed_params = set(allowed_params_raw) if allowed_params_raw is not None else None
+    overrides = getattr(obj, "search_space", {}) or {}
+    has_optimizable_flag = False
+
+    for name, fld in obj.model_fields.items():
+        full = f"{path}.{name}" if path else name
+        extra = fld.json_schema_extra or {}
+
+        is_field_optimizable = extra.get("optimizable", False) or name in overrides
+        has_optimizable_flag = has_optimizable_flag or is_field_optimizable
+
+        # honour allow-list
+        if allowed_params is not None and name not in allowed_params:
+            continue
+
+        # 1. plain optimizable field or override from config
+        if is_field_optimizable:
+            space = overrides.get(name, extra.get("search_space"))
+            if space is None:
+                logger.error(
+                    "Field %s is marked optimizable but no search space was provided.",
+                    full,
+                )
+                raise ValueError(f"Field {full} is marked optimizable but no search space was provided")
+            spaces[full] = space
+
+        value = getattr(obj, name, None)
+
+        # 2. nested BaseModel
+        if isinstance(value, BaseModel):
+            spaces.update(walk_optimizables(value, full))
+
+        # 3. dict[str, BaseModel] container
+        elif isinstance(value, dict):
+            for key, subval in value.items():
+                if isinstance(subval, BaseModel):
+                    spaces.update(walk_optimizables(subval, f"{full}.{key}"))
+
+        # 4. static-type fallback for class-level annotations
+        elif isinstance(obj, type):
+            ann = fld.annotation
+            if get_origin(ann) in (dict, dict):
+                _, val_t = get_args(ann) or (None, None)
+                if isinstance(val_t, type) and issubclass(val_t, BaseModel):
+                    if allowed_params is None or name in allowed_params:
+                        spaces[f"{full}.*"] = SearchSpace(low=None, high=None)  # sentinel
+
+    if allowed_params is None and has_optimizable_flag:
+        logger.warning(
+            "Model %s contains optimizable fields but no `optimizable_params` "
+            "were defined; these fields will be ignored.",
+            obj.__class__.__name__,
+        )
+    return spaces
```
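For orientation, a minimal usage sketch of the new `walk_optimizables` helper, based only on the behavior visible in the hunk above. The `SearchSpace(low=..., high=...)` arguments and the way the field is tagged through `json_schema_extra` and `optimizable_params` are illustrative assumptions; the package likely ships its own mixin or field helper for this.

```python
# Hypothetical sketch; only the flattening behavior shown in the diff is relied upon.
from pydantic import BaseModel, Field

from nat.data_models.optimizable import SearchSpace
from nat.profiler.parameter_optimization.optimizable_utils import walk_optimizables


class LLMSettings(BaseModel):
    # Allow-list honored by walk_optimizables (OptimizableMixin-style attribute).
    optimizable_params: list[str] = ["temperature"]
    temperature: float = Field(
        0.7,
        json_schema_extra={
            "optimizable": True,
            "search_space": SearchSpace(low=0.0, high=1.0),  # assumed constructor args
        },
    )


class WorkflowConfig(BaseModel):
    # dict[str, BaseModel] container, handled by case 3 in walk_optimizables.
    llms: dict[str, LLMSettings] = {"chat_llm": LLMSettings()}


spaces = walk_optimizables(WorkflowConfig())
print(spaces)  # expect a flattened key like "llms.chat_llm.temperature"
```

The expected result is a flat mapping keyed by dotted paths (here `llms.chat_llm.temperature`), which the parameter optimizer introduced below consumes as its search space.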
nat/profiler/parameter_optimization/optimizer_runtime.py (new file)

```diff
@@ -0,0 +1,67 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from pydantic import BaseModel
+
+from nat.data_models.optimizer import OptimizerRunConfig
+from nat.experimental.decorators.experimental_warning_decorator import experimental
+from nat.profiler.parameter_optimization.optimizable_utils import walk_optimizables
+from nat.profiler.parameter_optimization.parameter_optimizer import optimize_parameters
+from nat.profiler.parameter_optimization.prompt_optimizer import optimize_prompts
+from nat.runtime.loader import load_config
+
+logger = logging.getLogger(__name__)
+
+
+@experimental(feature_name="Optimizer")
+async def optimize_config(opt_run_config: OptimizerRunConfig):
+    """Entry-point called by the CLI or runtime."""
+    # ---------------- 1. load / normalise ---------------- #
+    if not isinstance(opt_run_config.config_file, BaseModel):
+        from nat.data_models.config import Config  # guarded import
+        base_cfg: Config = load_config(config_file=opt_run_config.config_file)
+    else:
+        base_cfg = opt_run_config.config_file  # already validated
+
+    # ---------------- 2. discover search space ----------- #
+    full_space = walk_optimizables(base_cfg)
+    if not full_space:
+        logger.warning("No optimizable parameters found in the configuration. "
+                       "Skipping optimization.")
+        return base_cfg
+
+    # ---------------- 3. numeric / enum tuning ----------- #
+    tuned_cfg = base_cfg
+    if base_cfg.optimizer.numeric.enabled:
+        tuned_cfg = optimize_parameters(
+            base_cfg=base_cfg,
+            full_space=full_space,
+            optimizer_config=base_cfg.optimizer,
+            opt_run_config=opt_run_config,
+        )
+
+    # ---------------- 4. prompt optimization ------------- #
+    if base_cfg.optimizer.prompt.enabled:
+        await optimize_prompts(
+            base_cfg=tuned_cfg,
+            full_space=full_space,
+            optimizer_config=base_cfg.optimizer,
+            opt_run_config=opt_run_config,
+        )
+
+    logger.info("All optimization phases complete.")
+    return tuned_cfg
```
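A hedged sketch of driving the new entry point programmatically; the `nat/cli/commands/optimize.py` command added above is presumably the usual front door. Only `config_file` plus the evaluation-related fields used in the next hunk are known to exist on `OptimizerRunConfig`, and whether the omitted fields have defaults is an assumption.

```python
# Sketch only: field names and defaults on OptimizerRunConfig beyond those
# visible in these hunks are assumptions.
import asyncio

from nat.data_models.optimizer import OptimizerRunConfig
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config

run_cfg = OptimizerRunConfig(
    config_file="configs/workflow.yml",  # hypothetical path to a nat workflow config
    dataset="data/eval_set.json",        # hypothetical evaluation dataset
)

# Loads and validates the config, discovers its optimizable fields, then runs
# whichever of the numeric and prompt optimization phases are enabled in it.
tuned_cfg = asyncio.run(optimize_config(run_cfg))
```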
nat/profiler/parameter_optimization/parameter_optimizer.py (new file)

```diff
@@ -0,0 +1,153 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+from collections.abc import Mapping as Dict
+
+import optuna
+import yaml
+
+from nat.data_models.config import Config
+from nat.data_models.optimizable import SearchSpace
+from nat.data_models.optimizer import OptimizerConfig
+from nat.data_models.optimizer import OptimizerRunConfig
+from nat.eval.evaluate import EvaluationRun
+from nat.eval.evaluate import EvaluationRunConfig
+from nat.experimental.decorators.experimental_warning_decorator import experimental
+from nat.profiler.parameter_optimization.parameter_selection import pick_trial
+from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
+
+logger = logging.getLogger(__name__)
+
+
+@experimental(feature_name="Optimizer")
+def optimize_parameters(
+    *,
+    base_cfg: Config,
+    full_space: Dict[str, SearchSpace],
+    optimizer_config: OptimizerConfig,
+    opt_run_config: OptimizerRunConfig,
+) -> Config:
+    """Tune all *non-prompt* hyper-parameters and persist the best config."""
+    space = {k: v for k, v in full_space.items() if not v.is_prompt}
+
+    # Ensure output_path is not None
+    if optimizer_config.output_path is None:
+        raise ValueError("optimizer_config.output_path cannot be None")
+    out_dir = optimizer_config.output_path
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    # Ensure eval_metrics is not None
+    if optimizer_config.eval_metrics is None:
+        raise ValueError("optimizer_config.eval_metrics cannot be None")
+
+    metric_cfg = optimizer_config.eval_metrics
+    directions = [v.direction for v in metric_cfg.values()]
+    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
+    weights = [v.weight for v in metric_cfg.values()]
+
+    study = optuna.create_study(directions=directions)
+
+    # Create output directory for intermediate files
+    out_dir = optimizer_config.output_path
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    async def _run_eval(runner: EvaluationRun):
+        return await runner.run_and_evaluate()
+
+    def _objective(trial: optuna.Trial):
+        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))
+
+        # build trial config
+        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
+        cfg_trial = apply_suggestions(base_cfg, suggestions)
+
+        async def _single_eval(trial_idx: int) -> list[float]:  # noqa: ARG001
+            eval_cfg = EvaluationRunConfig(
+                config_file=cfg_trial,
+                dataset=opt_run_config.dataset,
+                result_json_path=opt_run_config.result_json_path,
+                endpoint=opt_run_config.endpoint,
+                endpoint_timeout=opt_run_config.endpoint_timeout,
+            )
+            scores = await _run_eval(EvaluationRun(config=eval_cfg))
+            values = []
+            for metric_name in eval_metrics:
+                metric = next(r[1] for r in scores.evaluation_results if r[0] == metric_name)
+                values.append(metric.average_score)
+
+            return values
+
+        # Create tasks for all evaluations
+        async def _run_all_evals():
+            tasks = [_single_eval(i) for i in range(reps)]
+            return await asyncio.gather(*tasks)
+
+        with (out_dir / f"config_numeric_trial_{trial._trial_id}.yml").open("w") as fh:
+            yaml.dump(cfg_trial.model_dump(), fh)
+
+        all_scores = asyncio.run(_run_all_evals())
+        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
+        trial.set_user_attr("rep_scores", all_scores)
+        return [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]
+
+    logger.info("Starting numeric / enum parameter optimization...")
+    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
+    logger.info("Numeric optimization finished")
+
+    best_params = pick_trial(
+        study=study,
+        mode=optimizer_config.multi_objective_combination_mode,
+        weights=weights,
+    ).params
+    tuned_cfg = apply_suggestions(base_cfg, best_params)
+
+    # Save final results (out_dir already created and defined above)
+    with (out_dir / "optimized_config.yml").open("w") as fh:
+        yaml.dump(tuned_cfg.model_dump(), fh)
+    with (out_dir / "trials_dataframe_params.csv").open("w") as fh:
+        # Export full trials DataFrame (values, params, timings, etc.).
+        df = study.trials_dataframe()
+        # Normalise rep_scores column naming for convenience.
+        if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
+            df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
+        elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
+            # Some Optuna versions return a dict in a single user_attrs column.
+            df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
+            df = df.drop(columns=["user_attrs"])
+        df.to_csv(fh, index=False)
+
+    # Generate Pareto front visualizations
+    try:
+        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
+        logger.info("Generating Pareto front visualizations...")
+        create_pareto_visualization(
+            data_source=study,
+            metric_names=eval_metrics,
+            directions=directions,
+            output_dir=out_dir / "plots",
+            title_prefix="Parameter Optimization",
+            show_plots=False  # Don't show plots in automated runs
+        )
+        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
+    except ImportError as ie:
+        logger.warning("Could not import visualization dependencies: %s. "
+                       "Have you installed nvidia-nat-profiling?",
+                       ie)
+    except Exception as e:
+        logger.warning("Failed to generate visualizations: %s", e)
+
+    return tuned_cfg
```
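Note that `_objective` returns one value per configured metric: the mean of that metric's `average_score` over `reps_per_param_set` repeated evaluation runs. A small worked example of that aggregation (the scores are made up):

```python
# Illustrative aggregation only; metric names and values are made up.
eval_metrics = ["accuracy", "latency_score"]
reps = 3
# One inner list per repetition, ordered to match eval_metrics.
all_scores = [[0.80, 0.60], [0.90, 0.50], [0.85, 0.55]]

objective_values = [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]
print(objective_values)  # approximately [0.85, 0.55], one averaged objective per metric
```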
nat/profiler/parameter_optimization/parameter_selection.py (new file)

```diff
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+
+import numpy as np
+import optuna
+from optuna._hypervolume import compute_hypervolume
+from optuna.study import Study
+from optuna.study import StudyDirection
+
+
+# ---------- helper ----------
+def _to_minimisation_matrix(
+    trials: Sequence[optuna.trial.FrozenTrial],
+    directions: Sequence[StudyDirection],
+) -> np.ndarray:
+    """Return array (n_trials × n_objectives) where **all** objectives are ‘smaller-is-better’."""
+    vals = np.asarray([t.values for t in trials], dtype=float)
+    for j, d in enumerate(directions):
+        if d == StudyDirection.MAXIMIZE:
+            vals[:, j] *= -1.0  # flip sign
+    return vals
+
+
+# ---------- public API ----------
+def pick_trial(
+    study: Study,
+    mode: str = "harmonic",
+    *,
+    weights: Sequence[float] | None = None,
+    ref_point: Sequence[float] | None = None,
+    eps: float = 1e-12,
+) -> optuna.trial.FrozenTrial:
+    """
+    Collapse Optuna’s Pareto front (`study.best_trials`) to a single “best compromise”.
+
+    Parameters
+    ----------
+    study : completed **multi-objective** Optuna study
+    mode : {"harmonic", "sum", "chebyshev", "hypervolume"}
+    weights : per-objective weights (used only for "sum")
+    ref_point : reference point for hyper-volume (defaults to ones after normalisation)
+    eps : tiny value to avoid division by zero
+
+    Returns
+    -------
+    optuna.trial.FrozenTrial
+    """
+
+    # ---- 1. Pareto front ----
+    front = study.best_trials
+    if not front:
+        raise ValueError("`study.best_trials` is empty – no Pareto-optimal trials found.")
+
+    # ---- 2. Convert & normalise objectives ----
+    vals = _to_minimisation_matrix(front, study.directions)  # smaller is better
+    span = np.ptp(vals, axis=0)
+    norm = (vals - vals.min(axis=0)) / (span + eps)  # 0 = best, 1 = worst
+
+    # ---- 3. Scalarise according to chosen mode ----
+    mode = mode.lower()
+
+    if mode == "harmonic":
+        hmean = norm.shape[1] / (1.0 / (norm + eps)).sum(axis=1)
+        best_idx = hmean.argmin()  # lower = better
+
+    elif mode == "sum":
+        w = np.ones(norm.shape[1]) if weights is None else np.asarray(weights, float)
+        if w.size != norm.shape[1]:
+            raise ValueError("`weights` length must equal number of objectives.")
+        score = norm @ w
+        best_idx = score.argmin()
+
+    elif mode == "chebyshev":
+        score = norm.max(axis=1)  # worst dimension
+        best_idx = score.argmin()
+
+    elif mode == "hypervolume":
+        # Hyper-volume assumes points are *below* the reference point (minimisation space).
+        if len(front) == 0:
+            raise ValueError("Pareto front is empty - no trials to select from")
+        elif len(front) == 1:
+            best_idx = 0
+        else:
+            rp = np.ones(norm.shape[1]) if ref_point is None else np.asarray(ref_point, float)
+            base_hv = compute_hypervolume(norm, rp)
+            contrib = np.array([base_hv - compute_hypervolume(np.delete(norm, i, 0), rp) for i in range(len(front))])
+            best_idx = contrib.argmax()  # bigger contribution wins
+
+    else:
+        raise ValueError(f"Unknown mode '{mode}'. Choose from "
+                         "'harmonic', 'sum', 'chebyshev', 'hypervolume'.")
+
+    return front[best_idx]
```
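To make the selection math concrete, here is a small self-contained reproduction of the normalisation plus the "sum" scalarisation from `pick_trial`, using made-up objective values for three Pareto-front trials on two maximized metrics (no Optuna study needed):

```python
import numpy as np

# Three Pareto-front trials on two MAXIMIZE objectives (values are made up).
raw = np.array([[0.90, 0.40],
                [0.70, 0.70],
                [0.50, 0.90]])

# _to_minimisation_matrix: flip maximized objectives so smaller is better.
vals = -raw

# Per-objective min-max normalisation (0 = best, 1 = worst), as in pick_trial.
eps = 1e-12
norm = (vals - vals.min(axis=0)) / (np.ptp(vals, axis=0) + eps)
# norm is approximately [[0.0, 1.0], [0.5, 0.4], [1.0, 0.0]]

# mode="sum" with equal weights: lowest total normalised "badness" wins.
weights = np.ones(norm.shape[1])
score = norm @ weights          # approximately [1.0, 0.9, 1.0]
print(int(score.argmin()))      # 1 -> the balanced middle trial is selected
```

The "harmonic" and "chebyshev" modes scalarise the same `norm` matrix differently, while "hypervolume" keeps the trial whose removal shrinks the front's hypervolume the most.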