nvidia-nat 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,395 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa: W293
16
+
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import optuna
23
+ import pandas as pd
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ParetoVisualizer:
    """Render Pareto-front plots for multi-objective optimization results.

    Metric values are read from ``values_<metric_name>`` columns when present,
    falling back to positional ``values_<index>`` columns.

    Args:
        metric_names: Names of the optimized metrics, in objective order.
        directions: ``"minimize"`` or ``"maximize"`` per metric; any value other
            than ``"minimize"`` is treated as maximization.
        title_prefix: Text prepended to every plot title.

    Raises:
        ValueError: If ``metric_names`` and ``directions`` differ in length.
    """

    def __init__(self, metric_names: list[str], directions: list[str], title_prefix: str = "Optimization Results"):
        self.metric_names = metric_names
        self.directions = directions
        self.title_prefix = title_prefix

        if len(metric_names) != len(directions):
            raise ValueError("Number of metric names must match number of directions")

    def _resolve_column(self, trials_df: pd.DataFrame, index: int) -> str:
        """Return the column name holding metric number ``index``.

        Prefers the named ``values_<metric_name>`` column and falls back to the
        positional ``values_<index>`` column (older / unnamed format).
        """
        named_col = f"values_{self.metric_names[index]}"
        return named_col if named_col in trials_df.columns else f"values_{index}"

    @staticmethod
    def _normalize_metric(values: np.ndarray, direction: str) -> np.ndarray:
        """Scale ``values`` to [0, 1] so that 1 is always the better end.

        The range is computed from finite entries only; non-finite entries
        (e.g. NaN from failed trials) become NaN, which leaves a gap in that
        trial's line instead of distorting the whole axis. A constant column
        maps to 0.5.
        """
        finite_mask = np.isfinite(values)
        if not finite_mask.any():
            return np.full(values.shape, np.nan)

        finite = values[finite_mask]
        min_val, max_val = finite.min(), finite.max()
        if max_val > min_val:
            scaled = (values - min_val) / (max_val - min_val)
            if direction == "minimize":
                # For minimize: lower raw values get higher normalized scores
                scaled = 1 - scaled
        else:
            scaled = np.full(values.shape, 0.5)
        return np.where(finite_mask, scaled, np.nan)

    def plot_pareto_front_2d(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] = (10, 8),
                             show_plot: bool = True) -> plt.Figure:
        """Scatter all trials on the two metrics and highlight the Pareto front.

        Args:
            trials_df: All trials, one row per trial.
            pareto_trials_df: Optional Pareto-optimal subset; drawn as red stars
                joined by a dashed line sorted on the first metric.
            save_path: If given, the figure is saved there at 300 dpi.
            figsize: Matplotlib figure size in inches.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.

        Raises:
            ValueError: If the visualizer is not configured with exactly 2 metrics.
        """
        if len(self.metric_names) != 2:
            raise ValueError("2D Pareto front visualization requires exactly 2 metrics")

        fig, ax = plt.subplots(figsize=figsize)

        x_col = self._resolve_column(trials_df, 0)
        y_col = self._resolve_column(trials_df, 1)
        x_vals = trials_df[x_col].values
        y_vals = trials_df[y_col].values

        # Plot all trials
        ax.scatter(x_vals,
                   y_vals,
                   alpha=0.6,
                   s=50,
                   c='lightblue',
                   label=f'All Trials (n={len(trials_df)})',
                   edgecolors='navy',
                   linewidths=0.5)

        # Plot Pareto optimal trials if provided
        if pareto_trials_df is not None and not pareto_trials_df.empty:
            pareto_x = pareto_trials_df[x_col].values
            pareto_y = pareto_trials_df[y_col].values

            ax.scatter(pareto_x,
                       pareto_y,
                       alpha=0.9,
                       s=100,
                       c='red',
                       label=f'Pareto Optimal (n={len(pareto_trials_df)})',
                       edgecolors='darkred',
                       linewidths=1.5,
                       marker='*')

            # Draw the front as a line; points are sorted on the first objective
            if len(pareto_x) > 1:
                sorted_indices = np.argsort(pareto_x)
                ax.plot(pareto_x[sorted_indices],
                        pareto_y[sorted_indices],
                        'r--',
                        alpha=0.7,
                        linewidth=2,
                        label='Pareto Front')

        # Customize plot
        x_direction = "↓" if self.directions[0] == "minimize" else "↑"
        y_direction = "↓" if self.directions[1] == "minimize" else "↑"

        ax.set_xlabel(f"{self.metric_names[0]} {x_direction}", fontsize=12)
        ax.set_ylabel(f"{self.metric_names[1]} {y_direction}", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Pareto Front Visualization", fontsize=14, fontweight='bold')

        ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3)

        # Add direction annotations
        x_annotation = (f"Better {self.metric_names[0]} ←"
                        if self.directions[0] == "minimize" else f"→ Better {self.metric_names[0]}")
        ax.annotate(x_annotation,
                    xy=(0.02, 0.98),
                    xycoords='axes fraction',
                    ha='left',
                    va='top',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))

        y_annotation = (f"Better {self.metric_names[1]} ↓"
                        if self.directions[1] == "minimize" else f"Better {self.metric_names[1]} ↑")
        ax.annotate(y_annotation,
                    xy=(0.02, 0.02),
                    xycoords='axes fraction',
                    ha='left',
                    va='bottom',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7))

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("2D Pareto plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig

    def plot_pareto_parallel_coordinates(self,
                                         trials_df: pd.DataFrame,
                                         pareto_trials_df: pd.DataFrame | None = None,
                                         save_path: Path | None = None,
                                         figsize: tuple[int, int] = (12, 8),
                                         show_plot: bool = True) -> plt.Figure:
        """Draw one polyline per trial across all metrics, normalized so higher is better.

        Args:
            trials_df: All trials, one row per trial.
            pareto_trials_df: Optional Pareto-optimal subset. Its rows are matched
                to ``trials_df`` by index label and drawn as thick red lines.
            save_path: If given, the figure is saved there at 300 dpi.
            figsize: Matplotlib figure size in inches.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        n_metrics = len(self.metric_names)
        x_positions = np.arange(n_metrics)

        # One normalized array per metric, aligned positionally with trials_df rows
        normalized_values = [
            self._normalize_metric(trials_df[self._resolve_column(trials_df, i)].to_numpy(dtype=float),
                                   self.directions[i]) for i in range(n_metrics)
        ]

        # Plot all trials
        for pos in range(len(trials_df)):
            trial_values = [column[pos] for column in normalized_values]
            ax.plot(x_positions, trial_values, 'b-', alpha=0.1, linewidth=1)

        # Plot Pareto optimal trials. Rows are matched by index *label* and then
        # plotted by *position*; labels are not guaranteed to be 0..n-1.
        drew_pareto = False
        if pareto_trials_df is not None and not pareto_trials_df.empty:
            pareto_positions = np.flatnonzero(trials_df.index.isin(pareto_trials_df.index))
            for pos in pareto_positions:
                trial_values = [column[pos] for column in normalized_values]
                ax.plot(x_positions, trial_values, 'r-', alpha=0.8, linewidth=3)
                drew_pareto = True

        # Customize plot
        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{name}\n({direction})" for name, direction in zip(self.metric_names, self.directions)])
        ax.set_ylabel("Normalized Performance (Higher is Better)", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Parallel Coordinates Plot", fontsize=14, fontweight='bold')
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, alpha=0.3)

        # Add legend (proxy artists, since individual trial lines are unlabeled)
        from matplotlib.lines import Line2D
        legend_elements = [Line2D([0], [0], color='blue', alpha=0.3, linewidth=2, label='All Trials')]
        if drew_pareto:
            legend_elements.append(Line2D([0], [0], color='red', alpha=0.8, linewidth=3, label='Pareto Optimal'))
        ax.legend(handles=legend_elements, loc='best')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("Parallel coordinates plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig

    def plot_pairwise_matrix(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] | None = None,
                             show_plot: bool = True) -> plt.Figure:
        """Draw an N x N grid: histograms on the diagonal, metric-vs-metric scatters elsewhere.

        Args:
            trials_df: All trials, one row per trial.
            pareto_trials_df: Optional Pareto-optimal subset, overlaid in red.
            save_path: If given, the figure is saved there at 300 dpi.
            figsize: Figure size in inches; defaults to 4 inches per metric per side.
            show_plot: Whether to call ``plt.show()``.

        Returns:
            The created figure.
        """
        n_metrics = len(self.metric_names)
        if figsize is None:
            figsize = (4 * n_metrics, 4 * n_metrics)

        fig, axes = plt.subplots(n_metrics, n_metrics, figsize=figsize)
        fig.suptitle(f"{self.title_prefix}: Pairwise Metric Comparison", fontsize=16, fontweight='bold')

        has_pareto = pareto_trials_df is not None and not pareto_trials_df.empty
        columns = [self._resolve_column(trials_df, k) for k in range(n_metrics)]

        for i in range(n_metrics):
            for j in range(n_metrics):
                # plt.subplots returns a single Axes (not an array) when n_metrics == 1
                ax = axes[i, j] if n_metrics > 1 else axes

                if i == j:
                    # Diagonal: histograms. Non-finite values are dropped so automatic
                    # binning does not fail, and the Pareto overlay reuses the same
                    # bin edges so the two histograms line up.
                    values = trials_df[columns[i]].to_numpy(dtype=float)
                    finite = values[np.isfinite(values)]
                    bins = np.histogram_bin_edges(finite, bins=20) if finite.size else 20
                    ax.hist(finite, bins=bins, alpha=0.7, color='lightblue', edgecolor='navy')
                    if has_pareto:
                        pareto_values = pareto_trials_df[columns[i]].to_numpy(dtype=float)
                        ax.hist(pareto_values[np.isfinite(pareto_values)],
                                bins=bins,
                                alpha=0.8,
                                color='red',
                                edgecolor='darkred')
                    ax.set_xlabel(f"{self.metric_names[i]}")
                    ax.set_ylabel("Frequency")
                else:
                    # Off-diagonal: scatter plots (x = metric j, y = metric i)
                    x_col = columns[j]
                    y_col = columns[i]
                    x_vals = trials_df[x_col].values
                    y_vals = trials_df[y_col].values

                    ax.scatter(x_vals, y_vals, alpha=0.6, s=30, c='lightblue', edgecolors='navy', linewidths=0.5)

                    if has_pareto:
                        pareto_x = pareto_trials_df[x_col].values
                        pareto_y = pareto_trials_df[y_col].values
                        ax.scatter(pareto_x,
                                   pareto_y,
                                   alpha=0.9,
                                   s=60,
                                   c='red',
                                   edgecolors='darkred',
                                   linewidths=1,
                                   marker='*')

                    ax.set_xlabel(f"{self.metric_names[j]} ({self.directions[j]})")
                    ax.set_ylabel(f"{self.metric_names[i]} ({self.directions[i]})")

                ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("Pairwise matrix plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig
280
+
281
+
282
def load_trials_from_study(study: optuna.Study) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return every trial of ``study`` plus the subset lying on its Pareto front.

    Args:
        study: The Optuna study to read.

    Returns:
        ``(trials_df, pareto_trials_df)``: ``trials_df`` is
        ``study.trials_dataframe()``, and ``pareto_trials_df`` keeps the rows
        whose ``number`` matches one of ``study.best_trials``.
    """
    all_trials = study.trials_dataframe()

    best_numbers = {best.number for best in study.best_trials}
    on_front = all_trials['number'].isin(best_numbers)

    return all_trials, all_trials[on_front]
292
+
293
+
294
def load_trials_from_csv(csv_path: Path, metric_names: list[str],
                         directions: list[str]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load trials from a CSV export and compute the Pareto-optimal subset.

    Args:
        csv_path: Path to a CSV with one row per trial and ``values_*`` metric columns.
        metric_names: Metric names, in the same order as ``directions``. Used to
            pick the matching ``values_<name>`` (or ``values_<index>``) column for
            each direction.
        directions: ``"minimize"`` / ``"maximize"`` per metric.

    Returns:
        ``(trials_df, pareto_trials_df)`` where ``pareto_trials_df`` is the subset
        of rows selected by ``compute_pareto_optimal_mask``.

    Raises:
        ValueError: If the CSV contains no ``values_`` columns.
    """
    trials_df = pd.read_csv(csv_path)

    # Extract values columns
    value_cols = [col for col in trials_df.columns if col.startswith('values_')]
    if not value_cols:
        raise ValueError("CSV file must contain 'values_' columns with metric scores")

    # Pair each direction with its own metric column, resolving names the same way the
    # visualizer does (values_<name>, else values_<index>). If the file does not follow
    # that naming, fall back to every values_ column in file order.
    metric_cols = []
    for index, name in enumerate(metric_names):
        named_col = f"values_{name}"
        metric_cols.append(named_col if named_col in trials_df.columns else f"values_{index}")
    if metric_cols and all(col in trials_df.columns for col in metric_cols):
        value_cols = metric_cols

    # Compute Pareto optimal solutions manually
    pareto_mask = compute_pareto_optimal_mask(trials_df, value_cols, directions)
    pareto_trials_df = trials_df[pareto_mask]

    return trials_df, pareto_trials_df
308
+
309
+
310
def compute_pareto_optimal_mask(df: pd.DataFrame, value_cols: list[str], directions: list[str]) -> np.ndarray:
    """Return a boolean mask marking the non-dominated (Pareto-optimal) rows of ``df``.

    A row is dominated when another row is at least as good on every metric and
    strictly better on at least one. Rows containing NaN in any metric column
    (e.g. failed trials) are never considered Pareto-optimal and cannot
    dominate other rows.

    Args:
        df: Trials, one row per trial.
        value_cols: Metric columns, in the same order as ``directions``.
        directions: ``"minimize"`` or ``"maximize"`` per column; anything other
            than ``"minimize"`` is treated as maximization.

    Returns:
        Boolean array of length ``len(df)``; ``True`` marks Pareto-optimal rows.

    Raises:
        ValueError: If ``value_cols`` and ``directions`` differ in length.
    """
    if len(value_cols) != len(directions):
        raise ValueError(f"Number of value columns ({len(value_cols)}) must match "
                         f"number of directions ({len(directions)})")

    values = df[value_cols].to_numpy(dtype=float)

    # Normalize directions: convert every objective to maximization
    signs = np.array([-1.0 if direction == "minimize" else 1.0 for direction in directions])
    scores = values * signs

    # NaN comparisons are always False, so a NaN row would otherwise never be
    # dominated; exclude such rows up front.
    valid = ~np.isnan(scores).any(axis=1)
    is_pareto = valid.copy()

    for i in np.flatnonzero(valid):
        # A row already dominated cannot knock out anything its dominator won't
        # (dominance is transitive), so it can be skipped.
        if not is_pareto[i]:
            continue
        dominated = np.all(scores[i] >= scores, axis=1) & np.any(scores[i] > scores, axis=1)
        is_pareto[dominated] = False

    return is_pareto
330
+
331
+
332
def create_pareto_visualization(data_source: optuna.Study | Path | pd.DataFrame,
                                metric_names: list[str],
                                directions: list[str],
                                output_dir: Path | None = None,
                                title_prefix: str = "Optimization Results",
                                show_plots: bool = True) -> dict[str, plt.Figure]:
    """Build the Pareto visualizations for a study, CSV export, or DataFrame.

    Args:
        data_source: Any object exposing ``trials_dataframe`` (treated as an Optuna
            study), a CSV path (``str`` or ``Path``), or a DataFrame of trials with
            ``values_*`` columns.
        metric_names: Metric names, in objective order.
        directions: ``"minimize"`` / ``"maximize"`` per metric.
        output_dir: If given, created if needed and each plot is saved into it.
        title_prefix: Text prepended to every plot title.
        show_plots: Whether each plot calls ``plt.show()``.

    Returns:
        Figures keyed by ``"2d_scatter"`` (only for exactly 2 metrics),
        ``"parallel_coordinates"`` and ``"pairwise_matrix"`` (for 2+ metrics).

    Raises:
        ValueError: For an unsupported ``data_source`` type, or a DataFrame with
            no ``values_`` columns. Errors raised while plotting are logged and
            re-raised.
    """
    # Load data based on source type
    if hasattr(data_source, 'trials_dataframe'):
        # Optuna study object
        trials_df, pareto_trials_df = load_trials_from_study(data_source)
    elif isinstance(data_source, str | Path):
        # CSV file path
        trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
    elif isinstance(data_source, pd.DataFrame):
        # DataFrame
        trials_df = data_source
        value_cols = [col for col in trials_df.columns if str(col).startswith('values_')]
        if not value_cols:
            raise ValueError("DataFrame must contain 'values_' columns with metric scores")

        # Pair each direction with its own metric column (values_<name>, else
        # values_<index>); fall back to every values_ column in column order.
        metric_cols = []
        for index, name in enumerate(metric_names):
            named_col = f"values_{name}"
            metric_cols.append(named_col if named_col in trials_df.columns else f"values_{index}")
        if metric_cols and all(col in trials_df.columns for col in metric_cols):
            value_cols = metric_cols

        pareto_mask = compute_pareto_optimal_mask(trials_df, value_cols, directions)
        pareto_trials_df = trials_df[pareto_mask]
    else:
        raise ValueError("data_source must be an Optuna study, CSV file path, or pandas DataFrame")

    visualizer = ParetoVisualizer(metric_names, directions, title_prefix)
    figures = {}

    logger.info("Creating Pareto front visualizations...")
    logger.info("Total trials: %d", len(trials_df))
    logger.info("Pareto optimal trials: %d", len(pareto_trials_df))

    # Create output directory if specified
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    try:
        if len(metric_names) == 2:
            # 2D scatter plot
            save_path = output_dir / "pareto_front_2d.png" if output_dir else None
            fig = visualizer.plot_pareto_front_2d(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["2d_scatter"] = fig

        if len(metric_names) >= 2:
            # Parallel coordinates plot
            save_path = output_dir / "pareto_parallel_coordinates.png" if output_dir else None
            fig = visualizer.plot_pareto_parallel_coordinates(trials_df,
                                                              pareto_trials_df,
                                                              save_path,
                                                              show_plot=show_plots)
            figures["parallel_coordinates"] = fig

            # Pairwise matrix plot
            save_path = output_dir / "pareto_pairwise_matrix.png" if output_dir else None
            fig = visualizer.plot_pairwise_matrix(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["pairwise_matrix"] = fig

        logger.info("Visualization complete!")
        if output_dir:
            logger.info("Plots saved to: %s", output_dir)

    except Exception as e:
        # Keep the traceback in the log, then let the caller handle the failure
        logger.exception("Error creating visualizations: %s", e)
        raise

    return figures