nvidia-nat 1.3.0a20250909__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +11 -6
- nat/agent/dual_node.py +2 -2
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -7
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +6 -3
- nat/agent/rewoo_agent/register.py +16 -10
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +17 -9
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +5 -8
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/api_server.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +15 -4
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +5 -2
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +94 -74
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,67 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from pydantic import BaseModel

from nat.data_models.optimizer import OptimizerRunConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.profiler.parameter_optimization.optimizable_utils import walk_optimizables
from nat.profiler.parameter_optimization.parameter_optimizer import optimize_parameters
from nat.profiler.parameter_optimization.prompt_optimizer import optimize_prompts
from nat.runtime.loader import load_config

logger = logging.getLogger(__name__)


@experimental(feature_name="Optimizer")
async def optimize_config(opt_run_config: OptimizerRunConfig):
    """Entry-point called by the CLI or runtime."""
    # ---------------- 1. load / normalise ---------------- #
    if not isinstance(opt_run_config.config_file, BaseModel):
        from nat.data_models.config import Config  # guarded import
        base_cfg: Config = load_config(config_file=opt_run_config.config_file)
    else:
        base_cfg = opt_run_config.config_file  # already validated

    # ---------------- 2. discover search space ----------- #
    full_space = walk_optimizables(base_cfg)
    if not full_space:
        logger.warning("No optimizable parameters found in the configuration. "
                       "Skipping optimization.")
        return base_cfg

    # ---------------- 3. numeric / enum tuning ----------- #
    tuned_cfg = base_cfg
    if base_cfg.optimizer.numeric.enabled:
        tuned_cfg = optimize_parameters(
            base_cfg=base_cfg,
            full_space=full_space,
            optimizer_config=base_cfg.optimizer,
            opt_run_config=opt_run_config,
        )

    # ---------------- 4. prompt optimization ------------- #
    if base_cfg.optimizer.prompt.enabled:
        await optimize_prompts(
            base_cfg=tuned_cfg,
            full_space=full_space,
            optimizer_config=base_cfg.optimizer,
            opt_run_config=opt_run_config,
        )

    logger.info("All optimization phases complete.")
    return tuned_cfg
```
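The 67-line hunk above lines up with the `nat/profiler/parameter_optimization/optimizer_runtime.py` entry in the file list (the only 67-line addition): `optimize_config` loads or accepts an already-validated `Config`, discovers the optimizable search space, then runs the numeric and prompt phases. Below is a minimal programmatic invocation sketch, not a confirmed API: the `OptimizerRunConfig` keyword arguments are assumed from the attributes read later in `optimize_parameters` (`dataset`, `result_json_path`, `endpoint`, `endpoint_timeout`), the paths are placeholders, and the new `nat/cli/commands/optimize.py` presumably wraps something similar for the CLI.

```python
import asyncio

from nat.data_models.optimizer import OptimizerRunConfig
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config

# Placeholder values; the constructor signature of OptimizerRunConfig is assumed
# from the attributes read in optimize_parameters, not confirmed by this diff.
run_cfg = OptimizerRunConfig(
    config_file="configs/workflow.yml",   # YAML workflow config to optimize
    dataset="data/eval_dataset.json",     # evaluation dataset used for each trial
    result_json_path="$",                 # hypothetical JSONPath into workflow results
    endpoint=None,                        # evaluate in-process rather than via a remote endpoint
    endpoint_timeout=300,
)

tuned_cfg = asyncio.run(optimize_config(run_cfg))  # returns the tuned Config
```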
@@ -0,0 +1,149 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from collections.abc import Mapping as Dict

import optuna
import yaml

from nat.data_models.config import Config
from nat.data_models.optimizable import SearchSpace
from nat.data_models.optimizer import OptimizerConfig
from nat.data_models.optimizer import OptimizerRunConfig
from nat.eval.evaluate import EvaluationRun
from nat.eval.evaluate import EvaluationRunConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.profiler.parameter_optimization.parameter_selection import pick_trial
from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
from nat.profiler.parameter_optimization.update_helpers import apply_suggestions

logger = logging.getLogger(__name__)


@experimental(feature_name="Optimizer")
def optimize_parameters(
    *,
    base_cfg: Config,
    full_space: Dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> Config:
    """Tune all *non-prompt* hyper-parameters and persist the best config."""
    space = {k: v for k, v in full_space.items() if not v.is_prompt}

    # Ensure output_path is not None
    if optimizer_config.output_path is None:
        raise ValueError("optimizer_config.output_path cannot be None")
    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    # Ensure eval_metrics is not None
    if optimizer_config.eval_metrics is None:
        raise ValueError("optimizer_config.eval_metrics cannot be None")

    metric_cfg = optimizer_config.eval_metrics
    directions = [v.direction for v in metric_cfg.values()]
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    study = optuna.create_study(directions=directions)

    # Create output directory for intermediate files
    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    async def _run_eval(runner: EvaluationRun):
        return await runner.run_and_evaluate()

    def _objective(trial: optuna.Trial):
        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

        # build trial config
        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
        cfg_trial = apply_suggestions(base_cfg, suggestions)

        async def _single_eval(trial_idx: int) -> list[float]:  # noqa: ARG001
            eval_cfg = EvaluationRunConfig(
                config_file=cfg_trial,
                dataset=opt_run_config.dataset,
                result_json_path=opt_run_config.result_json_path,
                endpoint=opt_run_config.endpoint,
                endpoint_timeout=opt_run_config.endpoint_timeout,
            )
            scores = await _run_eval(EvaluationRun(config=eval_cfg))
            values = []
            for metric_name in eval_metrics:
                metric = next(r[1] for r in scores.evaluation_results if r[0] == metric_name)
                values.append(metric.average_score)

            return values

        # Create tasks for all evaluations
        async def _run_all_evals():
            tasks = [_single_eval(i) for i in range(reps)]
            return await asyncio.gather(*tasks)

        with (out_dir / f"config_numeric_trial_{trial._trial_id}.yml").open("w") as fh:
            yaml.dump(cfg_trial.model_dump(), fh)

        all_scores = asyncio.run(_run_all_evals())
        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
        trial.set_user_attr("rep_scores", all_scores)
        return [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]

    logger.info("Starting numeric / enum parameter optimization...")
    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
    logger.info("Numeric optimization finished")

    best_params = pick_trial(
        study=study,
        mode=optimizer_config.multi_objective_combination_mode,
        weights=weights,
    ).params
    tuned_cfg = apply_suggestions(base_cfg, best_params)

    # Save final results (out_dir already created and defined above)
    with (out_dir / "optimized_config.yml").open("w") as fh:
        yaml.dump(tuned_cfg.model_dump(), fh)
    with (out_dir / "trials_dataframe_params.csv").open("w") as fh:
        # Export full trials DataFrame (values, params, timings, etc.).
        df = study.trials_dataframe()
        # Normalise rep_scores column naming for convenience.
        if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
            df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
        elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
            # Some Optuna versions return a dict in a single user_attrs column.
            df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
            df = df.drop(columns=["user_attrs"])
        df.to_csv(fh, index=False)

    # Generate Pareto front visualizations
    try:
        logger.info("Generating Pareto front visualizations...")
        create_pareto_visualization(
            data_source=study,
            metric_names=eval_metrics,
            directions=directions,
            output_dir=out_dir / "plots",
            title_prefix="Parameter Optimization",
            show_plots=False  # Don't show plots in automated runs
        )
        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
    except Exception as e:
        logger.warning("Failed to generate visualizations: %s", e)

    return tuned_cfg
```
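This 149-line hunk corresponds to the `nat/profiler/parameter_optimization/parameter_optimizer.py` entry in the file list (it defines the `optimize_parameters` imported by the previous hunk). Each Optuna trial writes the candidate config to disk, runs the evaluation `reps_per_param_set` times concurrently, and reports the per-metric mean back to the study; the best compromise trial is then picked with `pick_trial` and persisted as `optimized_config.yml`. A small self-contained sketch of the final reduction inside `_objective`, with made-up numbers standing in for real evaluator scores:

```python
# Toy illustration of the averaging at the end of _objective: all_scores holds one
# score list per repetition, ordered like eval_metrics (values here are invented).
eval_metrics = ["accuracy", "latency"]
all_scores = [[0.80, 0.30], [0.90, 0.20], [0.70, 0.40]]  # 3 reps x 2 metrics
reps = len(all_scores)

objective_values = [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]
print(objective_values)  # -> [0.8, 0.3] (up to float rounding), one value per metric for Optuna
```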
@@ -0,0 +1,108 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Sequence

import numpy as np
import optuna
from optuna._hypervolume import compute_hypervolume
from optuna.study import Study
from optuna.study import StudyDirection


# ---------- helper ----------
def _to_minimisation_matrix(
    trials: Sequence[optuna.trial.FrozenTrial],
    directions: Sequence[StudyDirection],
) -> np.ndarray:
    """Return array (n_trials × n_objectives) where **all** objectives are 'smaller-is-better'."""
    vals = np.asarray([t.values for t in trials], dtype=float)
    for j, d in enumerate(directions):
        if d == StudyDirection.MAXIMIZE:
            vals[:, j] *= -1.0  # flip sign
    return vals


# ---------- public API ----------
def pick_trial(
    study: Study,
    mode: str = "harmonic",
    *,
    weights: Optional[Sequence[float]] = None,
    ref_point: Optional[Sequence[float]] = None,
    eps: float = 1e-12,
) -> optuna.trial.FrozenTrial:
    """
    Collapse Optuna's Pareto front (`study.best_trials`) to a single "best compromise".

    Parameters
    ----------
    study : completed **multi-objective** Optuna study
    mode : {"harmonic", "sum", "chebyshev", "hypervolume"}
    weights : per-objective weights (used only for "sum")
    ref_point : reference point for hyper-volume (defaults to ones after normalisation)
    eps : tiny value to avoid division by zero

    Returns
    -------
    optuna.trial.FrozenTrial
    """

    # ---- 1. Pareto front ----
    front = study.best_trials
    if not front:
        raise ValueError("`study.best_trials` is empty – no Pareto-optimal trials found.")

    # ---- 2. Convert & normalise objectives ----
    vals = _to_minimisation_matrix(front, study.directions)  # smaller is better
    span = np.ptp(vals, axis=0)
    norm = (vals - vals.min(axis=0)) / (span + eps)  # 0 = best, 1 = worst

    # ---- 3. Scalarise according to chosen mode ----
    mode = mode.lower()

    if mode == "harmonic":
        hmean = norm.shape[1] / (1.0 / (norm + eps)).sum(axis=1)
        best_idx = hmean.argmin()  # lower = better

    elif mode == "sum":
        w = np.ones(norm.shape[1]) if weights is None else np.asarray(weights, float)
        if w.size != norm.shape[1]:
            raise ValueError("`weights` length must equal number of objectives.")
        score = norm @ w
        best_idx = score.argmin()

    elif mode == "chebyshev":
        score = norm.max(axis=1)  # worst dimension
        best_idx = score.argmin()

    elif mode == "hypervolume":
        # Hyper-volume assumes points are *below* the reference point (minimisation space).
        if len(front) == 0:
            raise ValueError("Pareto front is empty - no trials to select from")
        elif len(front) == 1:
            best_idx = 0
        else:
            rp = np.ones(norm.shape[1]) if ref_point is None else np.asarray(ref_point, float)
            base_hv = compute_hypervolume(norm, rp)
            contrib = np.array([base_hv - compute_hypervolume(np.delete(norm, i, 0), rp) for i in range(len(front))])
            best_idx = contrib.argmax()  # bigger contribution wins

    else:
        raise ValueError(f"Unknown mode '{mode}'. Choose from "
                         "'harmonic', 'sum', 'chebyshev', 'hypervolume'.")

    return front[best_idx]
```
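The final 108-line hunk matches `nat/profiler/parameter_optimization/parameter_selection.py` (it defines the `pick_trial` imported by `optimize_parameters`). `pick_trial` flips maximised objectives into minimisation space, min-max normalises each objective to [0, 1] across the Pareto front, and then scalarises with the chosen mode (harmonic mean, weighted sum, Chebyshev, or hypervolume contribution). A minimal usage sketch with a toy two-objective study; the real caller is `optimize_parameters` above, and the objective values here are contrived purely to produce a Pareto front:

```python
import optuna

from nat.profiler.parameter_optimization.parameter_selection import pick_trial


def toy_objective(trial: optuna.Trial) -> tuple[float, float]:
    # Deliberately conflicting objectives so every trial is Pareto-optimal.
    x = trial.suggest_float("x", 0.0, 1.0)
    return x, 1.0 - x


study = optuna.create_study(directions=["minimize", "minimize"])
study.optimize(toy_objective, n_trials=30)

best = pick_trial(study, mode="harmonic")  # pick one trial from the Pareto front
print(best.number, best.params, best.values)

weighted = pick_trial(study, mode="sum", weights=[0.7, 0.3])  # favour the first objective
print(weighted.params)
```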