nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,107 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+
18
+ import numpy as np
19
+ import optuna
20
+ from optuna._hypervolume import compute_hypervolume
21
+ from optuna.study import Study
22
+ from optuna.study import StudyDirection
23
+
24
+
25
+ # ---------- helper ----------
26
def _to_minimisation_matrix(
    trials: Sequence[optuna.trial.FrozenTrial],
    directions: Sequence[StudyDirection],
) -> np.ndarray:
    """Stack trial objective values into an ``(n_trials, n_objectives)`` float array.

    Columns whose study direction is ``StudyDirection.MAXIMIZE`` are negated, so every
    column of the result reads as "smaller is better".
    """
    matrix = np.array([trial.values for trial in trials], dtype=float)
    for column, direction in enumerate(directions):
        if direction == StudyDirection.MAXIMIZE:
            # Negate so that maximisation objectives become minimisation objectives.
            matrix[:, column] = -matrix[:, column]
    return matrix
36
+
37
+
38
+ # ---------- public API ----------
39
def pick_trial(
    study: Study,
    mode: str = "harmonic",
    *,
    weights: Sequence[float] | None = None,
    ref_point: Sequence[float] | None = None,
    eps: float = 1e-12,
) -> optuna.trial.FrozenTrial:
    """
    Collapse Optuna's Pareto front (`study.best_trials`) to a single "best compromise".

    Objectives are first converted to minimisation form and min-max normalised over the
    front, so every entry lies in [0, 1] with 0 = best observed and 1 = worst observed.

    Parameters
    ----------
    study : completed **multi-objective** Optuna study
    mode : {"harmonic", "sum", "chebyshev", "hypervolume"}, case-insensitive
        * "harmonic"    - maximise the harmonic mean of per-objective goodness
          (``1 - normalised value``). The harmonic mean is dominated by its smallest
          term, so a trial that is the worst on any objective scores ~0 and balanced
          trials are preferred.
        * "sum"         - minimise the (optionally weighted) sum of normalised values.
        * "chebyshev"   - minimise the largest normalised value (the worst objective).
        * "hypervolume" - maximise the trial's exclusive hyper-volume contribution.
    weights : per-objective weights (used only for "sum"); defaults to all ones
    ref_point : reference point for hyper-volume (defaults to ones after normalisation)
    eps : tiny value to avoid division by zero

    Returns
    -------
    optuna.trial.FrozenTrial

    Raises
    ------
    ValueError
        If the Pareto front is empty, `weights` has the wrong length, or `mode` is unknown.
    """

    # ---- 1. Pareto front ----
    front = study.best_trials
    if not front:
        raise ValueError("`study.best_trials` is empty – no Pareto-optimal trials found.")

    # ---- 2. Convert & normalise objectives ----
    vals = _to_minimisation_matrix(front, study.directions)  # smaller is better
    span = np.ptp(vals, axis=0)
    norm = (vals - vals.min(axis=0)) / (span + eps)  # 0 = best, 1 = worst

    # ---- 3. Scalarise according to chosen mode ----
    mode = mode.lower()

    if mode == "harmonic":
        # The harmonic mean must be taken over *goodness* (higher = better). Taking it over
        # `norm` and minimising would be driven by any single zero entry, i.e. it would always
        # select a trial that is best on just one objective instead of a compromise.
        goodness = 1.0 - norm
        hmean = norm.shape[1] / (1.0 / (goodness + eps)).sum(axis=1)
        best_idx = int(hmean.argmax())  # higher = better

    elif mode == "sum":
        w = np.ones(norm.shape[1]) if weights is None else np.asarray(weights, float)
        if w.size != norm.shape[1]:
            raise ValueError("`weights` length must equal number of objectives.")
        score = norm @ w
        best_idx = int(score.argmin())

    elif mode == "chebyshev":
        score = norm.max(axis=1)  # worst dimension
        best_idx = int(score.argmin())

    elif mode == "hypervolume":
        # Hyper-volume assumes points are *below* the reference point (minimisation space).
        # The front is known to be non-empty here (checked in step 1).
        if len(front) == 1:
            best_idx = 0
        else:
            rp = np.ones(norm.shape[1]) if ref_point is None else np.asarray(ref_point, float)
            base_hv = compute_hypervolume(norm, rp)
            # Exclusive contribution: volume lost when trial i is removed from the front.
            contrib = np.array([base_hv - compute_hypervolume(np.delete(norm, i, 0), rp) for i in range(len(front))])
            best_idx = int(contrib.argmax())  # bigger contribution wins

    else:
        raise ValueError(f"Unknown mode '{mode}'. Choose from "
                         "'harmonic', 'sum', 'chebyshev', 'hypervolume'.")

    return front[best_idx]
@@ -0,0 +1,460 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa: W293
16
+
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import optuna
23
+ import pandas as pd
24
+ from matplotlib.lines import Line2D
25
+ from matplotlib.patches import Patch
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class ParetoVisualizer:
    """Plot Pareto-front diagnostics for multi-objective optimisation results.

    Trial data is passed as DataFrames whose objective columns are named either
    ``values_<metric_name>`` or ``values_<index>``; both layouts are accepted. When a
    ``number`` column is present it supplies the trial numbers used as point labels,
    otherwise the row index label is used.
    """

    def __init__(self, metric_names: list[str], directions: list[str], title_prefix: str = "Optimization Results"):
        """
        Args:
            metric_names: Objective names, in objective order.
            directions: Per-objective direction. ``"minimize"`` means lower is better; any
                other value is treated as "higher is better".
            title_prefix: Text prepended to every plot title.

        Raises:
            ValueError: If ``metric_names`` and ``directions`` have different lengths.
        """
        self.metric_names = metric_names
        self.directions = directions
        self.title_prefix = title_prefix

        if len(metric_names) != len(directions):
            raise ValueError("Number of metric names must match number of directions")

    # ---------- private helpers ----------

    def _metric_column(self, df: pd.DataFrame, index: int) -> str:
        """Return the column of *df* holding objective *index*.

        The named layout (``values_<metric_name>``) is preferred; otherwise the positional
        layout (``values_<index>``) is assumed.
        """
        named = f"values_{self.metric_names[index]}"
        return named if named in df.columns else f"values_{index}"

    @staticmethod
    def _trial_labels(df: pd.DataFrame) -> list[str]:
        """Label each row with its ``number`` column value, falling back to the index label."""
        numbers = df['number'] if 'number' in df.columns else df.index
        return [f'{int(number)}' for number in numbers]

    @staticmethod
    def _normalize_for_display(values: np.ndarray, reference: np.ndarray, direction: str) -> np.ndarray:
        """Scale *values* to [0, 1] (1 = best) using the finite range of *reference*.

        Non-finite reference entries (e.g. NaN objectives of failed trials) are ignored
        when computing the range, so a single NaN does not collapse the whole axis.
        NaN inputs stay NaN, which matplotlib does not draw. If the reference has no
        finite spread, every value maps to 0.5.
        """
        finite = reference[np.isfinite(reference)]
        if finite.size == 0 or finite.max() <= finite.min():
            return np.full(values.shape, 0.5)
        low, high = finite.min(), finite.max()
        scaled = (values - low) / (high - low)
        # For minimisation objectives the lowest raw value is the best, so invert.
        return 1.0 - scaled if direction == "minimize" else scaled

    @staticmethod
    def _finalize(fig: plt.Figure, save_path: Path | None, show_plot: bool, description: str) -> plt.Figure:
        """Lay out *fig*, optionally save it at 300 dpi and/or show it, then return it."""
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("%s saved to: %s", description, save_path)
        if show_plot:
            plt.show()
        return fig

    # ---------- public plots ----------

    def plot_pareto_front_2d(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] = (10, 8),
                             show_plot: bool = True) -> plt.Figure:
        """Scatter all trials for exactly two objectives and highlight the Pareto front.

        Args:
            trials_df: All trials.
            pareto_trials_df: Optional Pareto-optimal subset; drawn as labelled stars and
                joined by a dashed line sorted on the first objective.
            save_path: If given, the figure is written there.
            figsize: Figure size in inches.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.

        Raises:
            ValueError: If the visualizer was not configured with exactly two metrics.
        """
        if len(self.metric_names) != 2:
            raise ValueError("2D Pareto front visualization requires exactly 2 metrics")

        fig, ax = plt.subplots(figsize=figsize)

        x_col = self._metric_column(trials_df, 0)
        y_col = self._metric_column(trials_df, 1)
        x_vals = trials_df[x_col].values
        y_vals = trials_df[y_col].values

        # Plot all trials
        ax.scatter(x_vals,
                   y_vals,
                   alpha=0.6,
                   s=50,
                   c='lightblue',
                   label=f'All Trials (n={len(trials_df)})',
                   edgecolors='navy',
                   linewidths=0.5)

        # Plot Pareto optimal trials if provided
        if pareto_trials_df is not None and not pareto_trials_df.empty:
            pareto_x = pareto_trials_df[x_col].values
            pareto_y = pareto_trials_df[y_col].values

            ax.scatter(pareto_x,
                       pareto_y,
                       alpha=0.9,
                       s=100,
                       c='red',
                       label=f'Pareto Optimal (n={len(pareto_trials_df)})',
                       edgecolors='darkred',
                       linewidths=1.5,
                       marker='*')

            for label, x, y in zip(self._trial_labels(pareto_trials_df), pareto_x, pareto_y):
                ax.annotate(label,
                            xy=(x, y),
                            xytext=(8, 8),
                            textcoords='offset points',
                            fontsize=9,
                            fontweight='bold',
                            color='darkred',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='red', alpha=0.9))

            # Draw Pareto front line, ordering points by the first objective
            if len(pareto_x) > 1:
                sorted_indices = np.argsort(pareto_x)
                ax.plot(pareto_x[sorted_indices],
                        pareto_y[sorted_indices],
                        'r--',
                        alpha=0.7,
                        linewidth=2,
                        label='Pareto Front')

        x_direction = "↓" if self.directions[0] == "minimize" else "↑"
        y_direction = "↓" if self.directions[1] == "minimize" else "↑"

        ax.set_xlabel(f"{self.metric_names[0]} {x_direction}", fontsize=12)
        ax.set_ylabel(f"{self.metric_names[1]} {y_direction}", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Pareto Front Visualization", fontsize=14, fontweight='bold')

        ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3)

        # Direction hints in the corners of the axes
        x_annotation = (f"Better {self.metric_names[0]} ←"
                        if self.directions[0] == "minimize" else f"→ Better {self.metric_names[0]}")
        ax.annotate(x_annotation,
                    xy=(0.02, 0.98),
                    xycoords='axes fraction',
                    ha='left',
                    va='top',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))

        y_annotation = (f"Better {self.metric_names[1]} ↓"
                        if self.directions[1] == "minimize" else f"Better {self.metric_names[1]} ↑")
        ax.annotate(y_annotation,
                    xy=(0.02, 0.02),
                    xycoords='axes fraction',
                    ha='left',
                    va='bottom',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7))

        return self._finalize(fig, save_path, show_plot, "2D Pareto plot")

    def plot_pareto_parallel_coordinates(self,
                                         trials_df: pd.DataFrame,
                                         pareto_trials_df: pd.DataFrame | None = None,
                                         save_path: Path | None = None,
                                         figsize: tuple[int, int] = (12, 8),
                                         show_plot: bool = True) -> plt.Figure:
        """Draw one polyline per trial across all objectives, highlighting Pareto trials.

        Each axis is min-max scaled over ``trials_df`` so that 1 is the best observed value
        (direction-aware) and 0 the worst. Pareto rows are scaled with the same ranges
        directly from ``pareto_trials_df``, so they do not depend on the two frames sharing
        positional index labels.

        Args:
            trials_df: All trials.
            pareto_trials_df: Optional Pareto-optimal subset, drawn in bold red and labelled.
            save_path: If given, the figure is written there.
            figsize: Figure size in inches.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        n_metrics = len(self.metric_names)
        x_positions = np.arange(n_metrics)
        has_pareto = pareto_trials_df is not None and not pareto_trials_df.empty

        normalized_values = []
        pareto_normalized = []
        for i, direction in enumerate(self.directions):
            col_name = self._metric_column(trials_df, i)
            reference = trials_df[col_name].to_numpy(dtype=float)
            normalized_values.append(self._normalize_for_display(reference, reference, direction))
            if has_pareto:
                pareto_values = pareto_trials_df[col_name].to_numpy(dtype=float)
                pareto_normalized.append(self._normalize_for_display(pareto_values, reference, direction))

        # Plot all trials (one faint line each)
        for trial_values in zip(*normalized_values):
            ax.plot(x_positions, trial_values, 'b-', alpha=0.1, linewidth=1)

        # Plot Pareto optimal trials with a trial-number label at the rightmost point
        if has_pareto:
            for label, trial_values in zip(self._trial_labels(pareto_trials_df), zip(*pareto_normalized)):
                ax.plot(x_positions, trial_values, 'r-', alpha=0.8, linewidth=3)
                ax.annotate(label,
                            xy=(x_positions[-1], trial_values[-1]),
                            xytext=(5, 5),
                            textcoords='offset points',
                            fontsize=9,
                            fontweight='bold',
                            color='darkred',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='red', alpha=0.8))

        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{name}\n({direction})" for name, direction in zip(self.metric_names, self.directions)])
        ax.set_ylabel("Normalized Performance (Higher is Better)", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Parallel Coordinates Plot", fontsize=14, fontweight='bold')
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, alpha=0.3)

        legend_elements = [
            Line2D([0], [0], color='blue', alpha=0.3, linewidth=2, label='All Trials'),
            Line2D([0], [0], color='red', alpha=0.8, linewidth=3, label='Pareto Optimal'),
            Patch(facecolor='white', edgecolor='red', label='[n]: trial number')
        ]
        ax.legend(handles=legend_elements, loc='best')

        return self._finalize(fig, save_path, show_plot, "Parallel coordinates plot")

    def plot_pairwise_matrix(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] | None = None,
                             show_plot: bool = True) -> plt.Figure:
        """Draw an ``n x n`` grid: histograms on the diagonal, pairwise scatters elsewhere.

        Args:
            trials_df: All trials.
            pareto_trials_df: Optional Pareto-optimal subset, overlaid in red.
            save_path: If given, the figure is written there.
            figsize: Figure size in inches; defaults to ``4 * n_metrics`` per side.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.
        """
        n_metrics = len(self.metric_names)
        if figsize is None:
            figsize = (4 * n_metrics, 4 * n_metrics)

        fig, axes = plt.subplots(n_metrics, n_metrics, figsize=figsize)
        fig.suptitle(f"{self.title_prefix}: Pairwise Metric Comparison", fontsize=16, fontweight='bold')
        has_pareto = pareto_trials_df is not None and not pareto_trials_df.empty

        for i in range(n_metrics):
            y_col = self._metric_column(trials_df, i)
            for j in range(n_metrics):
                # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
                ax = axes[i, j] if n_metrics > 1 else axes

                if i == j:
                    # Diagonal: histograms. NaNs (failed trials) are dropped, and the Pareto
                    # histogram reuses the all-trials bin edges so the bars line up.
                    values = trials_df[y_col].to_numpy(dtype=float)
                    values = values[np.isfinite(values)]
                    bins = np.histogram_bin_edges(values, bins=20)
                    ax.hist(values, bins=bins, alpha=0.7, color='lightblue', edgecolor='navy')
                    if has_pareto:
                        pareto_values = pareto_trials_df[y_col].to_numpy(dtype=float)
                        pareto_values = pareto_values[np.isfinite(pareto_values)]
                        ax.hist(pareto_values, bins=bins, alpha=0.8, color='red', edgecolor='darkred')
                    ax.set_xlabel(f"{self.metric_names[i]}")
                    ax.set_ylabel("Frequency")
                else:
                    # Off-diagonal: metric j on x, metric i on y
                    x_col = self._metric_column(trials_df, j)
                    x_vals = trials_df[x_col].values
                    y_vals = trials_df[y_col].values

                    ax.scatter(x_vals, y_vals, alpha=0.6, s=30, c='lightblue', edgecolors='navy', linewidths=0.5)

                    if has_pareto:
                        pareto_x = pareto_trials_df[x_col].values
                        pareto_y = pareto_trials_df[y_col].values
                        ax.scatter(pareto_x,
                                   pareto_y,
                                   alpha=0.9,
                                   s=60,
                                   c='red',
                                   edgecolors='darkred',
                                   linewidths=1,
                                   marker='*')

                        for label, x, y in zip(self._trial_labels(pareto_trials_df), pareto_x, pareto_y):
                            ax.annotate(label,
                                        xy=(x, y),
                                        xytext=(6, 6),
                                        textcoords='offset points',
                                        fontsize=8,
                                        fontweight='bold',
                                        color='darkred',
                                        bbox=dict(boxstyle='round,pad=0.2',
                                                  facecolor='white',
                                                  edgecolor='red',
                                                  alpha=0.8))

                    ax.set_xlabel(f"{self.metric_names[j]} ({self.directions[j]})")
                    ax.set_ylabel(f"{self.metric_names[i]} ({self.directions[i]})")

                ax.grid(True, alpha=0.3)

        legend_elements = [
            Line2D([0], [0],
                   marker='o',
                   color='w',
                   markerfacecolor='lightblue',
                   markeredgecolor='navy',
                   markersize=8,
                   alpha=0.6,
                   label='All Trials'),
            Line2D([0], [0],
                   marker='*',
                   color='w',
                   markerfacecolor='red',
                   markeredgecolor='darkred',
                   markersize=10,
                   alpha=0.9,
                   label='Pareto Optimal'),
            Patch(facecolor='white', edgecolor='red', label='[n]: trial number')
        ]
        fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98), framealpha=0.9, fontsize=10)

        return self._finalize(fig, save_path, show_plot, "Pairwise matrix plot")
345
+
346
+
347
def load_trials_from_study(study: optuna.Study) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return ``(all_trials, pareto_trials)`` DataFrames for an Optuna study.

    ``all_trials`` is ``study.trials_dataframe()``. ``pareto_trials`` is the subset of
    those rows whose ``number`` matches a trial in ``study.best_trials``; it keeps the
    original index labels.
    """
    all_trials = study.trials_dataframe()
    front_numbers = {trial.number for trial in study.best_trials}
    on_front = all_trials['number'].isin(front_numbers)
    return all_trials, all_trials[on_front]
357
+
358
+
359
def load_trials_from_csv(csv_path: Path, metric_names: list[str],
                         directions: list[str]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load trials from a CSV export and compute the Pareto-optimal subset.

    Args:
        csv_path: CSV file with one row per trial and ``values_*`` objective columns.
        metric_names: Objective names, in the same order as ``directions``. They are used
            to pick the objective columns (``values_<name>``, else ``values_<index>``), so
            each direction is applied to the right column regardless of CSV column order.
        directions: ``"minimize"`` or ``"maximize"`` per objective.

    Returns:
        ``(trials_df, pareto_trials_df)``; the latter is a row subset of the former.

    Raises:
        ValueError: If the CSV has no ``values_*`` columns.
    """
    trials_df = pd.read_csv(csv_path)

    value_cols = [col for col in trials_df.columns if col.startswith('values_')]
    if not value_cols:
        raise ValueError("CSV file must contain 'values_' columns with metric scores")

    # Align objective columns with `directions`, resolving each metric the same way the
    # visualizer does. Fall back to every `values_*` column only when they cannot be resolved.
    expected = [
        f"values_{name}" if f"values_{name}" in trials_df.columns else f"values_{i}"
        for i, name in enumerate(metric_names)
    ]
    if expected and all(col in trials_df.columns for col in expected):
        value_cols = expected

    pareto_mask = compute_pareto_optimal_mask(trials_df, value_cols, directions)
    pareto_trials_df = trials_df[pareto_mask]

    return trials_df, pareto_trials_df
373
+
374
+
375
def compute_pareto_optimal_mask(df: pd.DataFrame, value_cols: list[str], directions: list[str]) -> np.ndarray:
    """Return a boolean mask marking the non-dominated (Pareto-optimal) rows of *df*.

    Args:
        df: One row per trial.
        value_cols: Objective columns, in the same order as ``directions``.
        directions: ``"minimize"`` flips the corresponding column; any other value is
            treated as maximisation.

    Returns:
        Boolean array of length ``len(df)``. Rows with a NaN objective (e.g. failed
        trials) are never Pareto-optimal and do not dominate other rows. NaN comparisons
        are always false, so without this check such rows would be reported as optimal.
    """
    values = df[value_cols].to_numpy(dtype=float, copy=True)

    # Normalize directions: convert all to maximization
    for i, direction in enumerate(directions):
        if direction == "minimize":
            values[:, i] = -values[:, i]

    valid = ~np.isnan(values).any(axis=1)
    is_pareto = valid.copy()

    for i in np.flatnonzero(valid):
        # A dominated row need not be checked: everything it dominates is also dominated
        # by whichever row dominates it (dominance is transitive).
        if not is_pareto[i]:
            continue
        dominated = np.all(values[i] >= values, axis=1) & np.any(values[i] > values, axis=1)
        is_pareto[dominated] = False

    return is_pareto
395
+
396
+
397
def create_pareto_visualization(data_source: optuna.Study | Path | pd.DataFrame,
                                metric_names: list[str],
                                directions: list[str],
                                output_dir: Path | None = None,
                                title_prefix: str = "Optimization Results",
                                show_plots: bool = True) -> dict[str, plt.Figure]:
    """Build Pareto-front plots from an Optuna study, a CSV export, or a DataFrame.

    Args:
        data_source: An object with ``trials_dataframe`` (treated as an Optuna study), a
            CSV path (``str`` or ``Path``), or a DataFrame with ``values_*`` columns.
        metric_names: Objective names, in objective order.
        directions: ``"minimize"`` or ``"maximize"`` per objective.
        output_dir: If given, created if needed and each plot is saved there as PNG.
        title_prefix: Text prepended to every plot title.
        show_plots: Whether each plot calls ``plt.show()``.

    Returns:
        Mapping of ``"2d_scatter"`` (exactly two metrics only), ``"parallel_coordinates"``
        and ``"pairwise_matrix"`` (two or more metrics) to the created figures.

    Raises:
        ValueError: If ``data_source`` is of an unsupported type.
    """
    if hasattr(data_source, 'trials_dataframe'):
        # Optuna study object
        trials_df, pareto_trials_df = load_trials_from_study(data_source)
    elif isinstance(data_source, str | Path):
        # CSV file path
        trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
    elif isinstance(data_source, pd.DataFrame):
        trials_df = data_source
        # Align objective columns with `directions` (named layout first, then positional) so
        # each direction applies to the right column; fall back to all `values_*` columns.
        expected = [
            f"values_{name}" if f"values_{name}" in trials_df.columns else f"values_{i}"
            for i, name in enumerate(metric_names)
        ]
        if expected and all(col in trials_df.columns for col in expected):
            value_cols = expected
        else:
            value_cols = [col for col in trials_df.columns if col.startswith('values_')]
        pareto_mask = compute_pareto_optimal_mask(trials_df, value_cols, directions)
        pareto_trials_df = trials_df[pareto_mask]
    else:
        raise ValueError("data_source must be an Optuna study, CSV file path, or pandas DataFrame")

    visualizer = ParetoVisualizer(metric_names, directions, title_prefix)
    figures = {}

    logger.info("Creating Pareto front visualizations...")
    logger.info("Total trials: %d", len(trials_df))
    logger.info("Pareto optimal trials: %d", len(pareto_trials_df))

    if len(metric_names) < 2:
        logger.warning("At least two metrics are required for Pareto visualizations; no plots created.")

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    try:
        if len(metric_names) == 2:
            # 2D scatter plot
            save_path = output_dir / "pareto_front_2d.png" if output_dir else None
            fig = visualizer.plot_pareto_front_2d(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["2d_scatter"] = fig

        if len(metric_names) >= 2:
            # Parallel coordinates plot
            save_path = output_dir / "pareto_parallel_coordinates.png" if output_dir else None
            fig = visualizer.plot_pareto_parallel_coordinates(trials_df,
                                                              pareto_trials_df,
                                                              save_path,
                                                              show_plot=show_plots)
            figures["parallel_coordinates"] = fig

            # Pairwise matrix plot
            save_path = output_dir / "pareto_pairwise_matrix.png" if output_dir else None
            fig = visualizer.plot_pairwise_matrix(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["pairwise_matrix"] = fig

        logger.info("Visualization complete!")
        if output_dir:
            logger.info("Plots saved to: %s", output_dir)

    except Exception as e:
        # Log for context, then propagate to the caller unchanged.
        logger.error("Error creating visualizations: %s", e)
        raise

    return figures