nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,153 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ from collections.abc import Mapping as Dict
19
+
20
+ import optuna
21
+ import yaml
22
+
23
+ from nat.data_models.config import Config
24
+ from nat.data_models.optimizable import SearchSpace
25
+ from nat.data_models.optimizer import OptimizerConfig
26
+ from nat.data_models.optimizer import OptimizerRunConfig
27
+ from nat.eval.evaluate import EvaluationRun
28
+ from nat.eval.evaluate import EvaluationRunConfig
29
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
30
+ from nat.profiler.parameter_optimization.parameter_selection import pick_trial
31
+ from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
32
+
33
logger = logging.getLogger(__name__)


def _average_score_for(evaluation_results, metric_name: str) -> float:
    """Return the ``average_score`` reported for ``metric_name``.

    ``evaluation_results`` is read as a sequence of ``(evaluator_name, output)``
    pairs, the same way the optimizer has always indexed
    ``EvaluationRun.run_and_evaluate()`` results. The first matching entry wins.

    Args:
        evaluation_results: Re-iterable collection of ``(name, output)`` pairs.
        metric_name: Evaluator name to look up.

    Returns:
        The ``average_score`` attribute of the matching evaluator output.

    Raises:
        ValueError: If no entry is named ``metric_name``. A bare ``next()`` would
            raise ``StopIteration`` instead, and Python converts that into an
            opaque ``RuntimeError`` when it escapes a coroutine.
    """
    for result in evaluation_results:
        if result[0] == metric_name:
            return result[1].average_score
    available = [result[0] for result in evaluation_results]
    raise ValueError(f"Evaluator '{metric_name}' was not found in the evaluation results "
                     f"(available evaluators: {available})")


@experimental(feature_name="Optimizer")
def optimize_parameters(
    *,
    base_cfg: Config,
    full_space: Dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> Config:
    """Tune all *non-prompt* hyper-parameters and persist the best config.

    An Optuna multi-objective study is run over every entry of ``full_space``
    whose ``is_prompt`` flag is false. It uses one objective per metric in
    ``optimizer_config.eval_metrics``.

    Each trial does the following:

    1. Draws a value for every parameter and applies it to ``base_cfg``.
    2. Writes the candidate config to ``config_numeric_trial_<number>.yml``.
    3. Runs ``reps_per_param_set`` evaluations concurrently (at least one).
    4. Stores the raw per-repetition scores as the ``rep_scores`` user attribute.
    5. Returns the per-metric mean as the trial's objective values.

    After ``optimizer_config.numeric.n_trials`` trials, ``pick_trial`` collapses
    the Pareto front using ``multi_objective_combination_mode`` and the metric
    weights. The chosen parameters are applied to ``base_cfg``. The result is
    written to ``optimized_config.yml``, together with
    ``trials_dataframe_params.csv``. Pareto plots under ``plots/`` are
    generated on a best-effort basis.

    Each trial calls ``asyncio.run``, so this function must not be invoked from
    inside a running event loop.

    Args:
        base_cfg: Workflow configuration that suggestions are applied to.
        full_space: Search space keyed by parameter path; prompt entries are ignored.
        optimizer_config: Optimizer settings (output path, metrics, trial count, ...).
        opt_run_config: Dataset / endpoint settings forwarded to each evaluation run.

    Returns:
        ``base_cfg`` with the selected trial's parameters applied.

    Raises:
        ValueError: If ``output_path`` or ``eval_metrics`` is ``None``, or if an
            evaluation produces no result for a configured evaluator.
    """
    space = {k: v for k, v in full_space.items() if not v.is_prompt}

    if optimizer_config.output_path is None:
        raise ValueError("optimizer_config.output_path cannot be None")
    # Created once here; every artifact below (trial configs, CSV, plots) lives in it.
    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    if optimizer_config.eval_metrics is None:
        raise ValueError("optimizer_config.eval_metrics cannot be None")

    metric_cfg = optimizer_config.eval_metrics
    # Parallel lists in metric order: objective i of the study is eval_metrics[i].
    directions = [v.direction for v in metric_cfg.values()]
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    # Evaluation repetitions per parameter set; never fewer than one.
    reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

    study = optuna.create_study(directions=directions)

    async def _run_eval(runner: EvaluationRun):
        return await runner.run_and_evaluate()

    def _objective(trial: optuna.Trial):
        # Build the trial config from one suggestion per tunable parameter.
        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
        cfg_trial = apply_suggestions(base_cfg, suggestions)

        async def _single_eval(trial_idx: int) -> list[float]:  # noqa: ARG001
            eval_cfg = EvaluationRunConfig(
                config_file=cfg_trial,
                dataset=opt_run_config.dataset,
                result_json_path=opt_run_config.result_json_path,
                endpoint=opt_run_config.endpoint,
                endpoint_timeout=opt_run_config.endpoint_timeout,
            )
            scores = await _run_eval(EvaluationRun(config=eval_cfg))
            return [_average_score_for(scores.evaluation_results, name) for name in eval_metrics]

        async def _run_all_evals():
            return await asyncio.gather(*(_single_eval(i) for i in range(reps)))

        # Persist the candidate before evaluating so it exists even if the evaluation fails.
        # ``trial.number`` is Optuna's public per-study index (the ``number`` column of
        # ``trials_dataframe``); ``_trial_id`` is a private storage identifier.
        with (out_dir / f"config_numeric_trial_{trial.number}.yml").open("w", encoding="utf-8") as fh:
            yaml.dump(cfg_trial.model_dump(), fh)

        all_scores = asyncio.run(_run_all_evals())
        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
        trial.set_user_attr("rep_scores", all_scores)
        # Column-wise mean across repetitions: one value per metric, in metric order.
        return [sum(metric_scores) / reps for metric_scores in zip(*all_scores)]

    logger.info("Starting numeric / enum parameter optimization...")
    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
    logger.info("Numeric optimization finished")

    best_params = pick_trial(
        study=study,
        mode=optimizer_config.multi_objective_combination_mode,
        weights=weights,
    ).params
    tuned_cfg = apply_suggestions(base_cfg, best_params)

    with (out_dir / "optimized_config.yml").open("w", encoding="utf-8") as fh:
        yaml.dump(tuned_cfg.model_dump(), fh)

    # Export full trials DataFrame (values, params, timings, etc.).
    df = study.trials_dataframe()
    # Normalise rep_scores column naming for convenience.
    if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
        df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
    elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
        # Some Optuna versions return a dict in a single user_attrs column.
        df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
        df = df.drop(columns=["user_attrs"])
    # Pass the path rather than a text handle: pandas then manages newline handling
    # (file handles would need newline=''), and the file is only created once the
    # DataFrame has been built successfully.
    df.to_csv(out_dir / "trials_dataframe_params.csv", index=False)

    # Generate Pareto front visualizations (best effort: failures are logged, not raised).
    try:
        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
        logger.info("Generating Pareto front visualizations...")
        create_pareto_visualization(
            data_source=study,
            metric_names=eval_metrics,
            directions=directions,
            output_dir=out_dir / "plots",
            title_prefix="Parameter Optimization",
            show_plots=False  # Don't show plots in automated runs
        )
        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
    except ImportError as ie:
        logger.warning("Could not import visualization dependencies: %s. "
                       "Have you installed nvidia-nat-profiling?",
                       ie)
    except Exception as e:
        logger.warning("Failed to generate visualizations: %s", e)

    return tuned_cfg
@@ -0,0 +1,107 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+
18
+ import numpy as np
19
+ import optuna
20
+ from optuna._hypervolume import compute_hypervolume
21
+ from optuna.study import Study
22
+ from optuna.study import StudyDirection
23
+
24
+
25
+ # ---------- helper ----------
26
def _to_minimisation_matrix(
    trials: Sequence[optuna.trial.FrozenTrial],
    directions: Sequence[StudyDirection],
) -> np.ndarray:
    """Stack the objective values of ``trials`` into an (n_trials, n_objectives) float matrix.

    Columns whose study direction is ``StudyDirection.MAXIMIZE`` are negated, so every
    column of the result follows the same "smaller is better" convention.
    """
    matrix = np.asarray([trial.values for trial in trials], dtype=float)
    maximised_columns = [col for col, direction in enumerate(directions) if direction == StudyDirection.MAXIMIZE]
    if maximised_columns:
        # Negating a maximised objective turns it into an equivalent minimisation objective.
        matrix[:, maximised_columns] *= -1.0
    return matrix
36
+
37
+
38
+ # ---------- public API ----------
39
def pick_trial(
    study: Study,
    mode: str = "harmonic",
    *,
    weights: Sequence[float] | None = None,
    ref_point: Sequence[float] | None = None,
    eps: float = 1e-12,
) -> optuna.trial.FrozenTrial:
    """
    Collapse Optuna's Pareto front (``study.best_trials``) to a single "best compromise" trial.

    Every objective is first converted to minimisation form and min-max normalised over the
    Pareto front, so per objective 0 is the best value on the front and 1 the worst. The
    normalised points are then scalarised according to ``mode``.

    Parameters
    ----------
    study : completed **multi-objective** Optuna study
    mode : {"harmonic", "sum", "chebyshev", "hypervolume"} (case-insensitive)
        - "harmonic": maximise the harmonic mean of per-objective goodness (1 - normalised value).
          A single poor objective drags the mean down, so balanced trials are favoured.
        - "sum": minimise the (optionally weighted) sum of normalised values.
        - "chebyshev": minimise the worst normalised objective.
        - "hypervolume": pick the trial with the largest exclusive hyper-volume contribution.
    weights : per-objective weights (used only for "sum")
    ref_point : reference point for hyper-volume (defaults to ones after normalisation)
    eps : tiny value to avoid division by zero

    Returns
    -------
    optuna.trial.FrozenTrial

    Raises
    ------
    ValueError
        If the Pareto front is empty, ``weights`` has the wrong length, or ``mode`` is unknown.
    """

    # ---- 1. Pareto front ----
    front = study.best_trials
    if not front:
        raise ValueError("`study.best_trials` is empty – no Pareto-optimal trials found.")

    # ---- 2. Convert & normalise objectives ----
    vals = _to_minimisation_matrix(front, study.directions)  # smaller is better
    span = np.ptp(vals, axis=0)
    norm = (vals - vals.min(axis=0)) / (span + eps)  # 0 = best, 1 = worst

    # ---- 3. Scalarise according to chosen mode ----
    mode = mode.lower()

    if mode == "harmonic":
        # The harmonic mean must be taken over a "higher is better" quantity. Applied to
        # `norm` (0 = best), any trial that is best on just one objective would score ~0
        # and always win, which selects an extreme point rather than a compromise.
        goodness = 1.0 - norm
        hmean = norm.shape[1] / (1.0 / (goodness + eps)).sum(axis=1)
        best_idx = int(hmean.argmax())  # higher = better

    elif mode == "sum":
        w = np.ones(norm.shape[1]) if weights is None else np.asarray(weights, float)
        if w.size != norm.shape[1]:
            raise ValueError("`weights` length must equal number of objectives.")
        score = norm @ w
        best_idx = int(score.argmin())

    elif mode == "chebyshev":
        score = norm.max(axis=1)  # worst dimension
        best_idx = int(score.argmin())

    elif mode == "hypervolume":
        # Hyper-volume assumes points are *below* the reference point (minimisation space).
        # The front is known to be non-empty here (checked above).
        if len(front) == 1:
            best_idx = 0
        else:
            rp = np.ones(norm.shape[1]) if ref_point is None else np.asarray(ref_point, float)
            base_hv = compute_hypervolume(norm, rp)
            # Exclusive contribution of each trial: volume lost when it is removed from the front.
            contrib = np.array([base_hv - compute_hypervolume(np.delete(norm, i, 0), rp) for i in range(len(front))])
            best_idx = int(contrib.argmax())  # bigger contribution wins

    else:
        raise ValueError(f"Unknown mode '{mode}'. Choose from "
                         "'harmonic', 'sum', 'chebyshev', 'hypervolume'.")

    return front[best_idx]
@@ -0,0 +1,380 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # flake8: noqa: W293
16
+
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import optuna
23
+ import pandas as pd
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ParetoVisualizer:
    """Render Pareto-front plots for multi-objective optimisation results.

    An instance is bound to an ordered list of metric names and their optimisation
    directions. Every trial DataFrame passed to the plotting methods must hold one
    ``values_<i>`` column per metric, in the same order as ``metric_names``.
    A direction equal to ``"minimize"`` means smaller is better; any other value is
    treated as larger-is-better.
    """

    def __init__(self, metric_names: list[str], directions: list[str], title_prefix: str = "Optimization Results"):
        """
        Parameters
        ----------
        metric_names : display names of the objectives, in ``values_<i>`` column order.
        directions : one direction per metric (``"minimize"`` or ``"maximize"``).
        title_prefix : text prepended to every plot title.

        Raises
        ------
        ValueError
            If ``metric_names`` and ``directions`` differ in length.
        """
        # Validate before storing anything so an invalid configuration never yields a half-built object.
        if len(metric_names) != len(directions):
            raise ValueError("Number of metric names must match number of directions")
        self.metric_names = metric_names
        self.directions = directions
        self.title_prefix = title_prefix

    # ---------- internal helpers ----------

    def _is_minimize(self, index: int) -> bool:
        """Return True when metric ``index`` is minimised (smaller is better)."""
        return self.directions[index] == "minimize"

    @staticmethod
    def _metric(df: pd.DataFrame, index: int) -> np.ndarray:
        """Return column ``values_<index>`` of ``df`` as a float array."""
        return df[f"values_{index}"].to_numpy(dtype=float)

    @staticmethod
    def _has_rows(df: pd.DataFrame | None) -> bool:
        """Return True when ``df`` is provided and non-empty."""
        return df is not None and not df.empty

    @staticmethod
    def _finite_bounds(values: np.ndarray) -> tuple[float, float]:
        """Return ``(min, max)`` over the non-NaN entries of ``values``; ``(nan, nan)`` if there are none."""
        finite = values[~np.isnan(values)]
        if finite.size == 0:
            return float("nan"), float("nan")
        return float(finite.min()), float(finite.max())

    def _normalise(self, values: np.ndarray, index: int, low: float, high: float) -> np.ndarray:
        """Map metric ``index`` onto [0, 1] using ``low``/``high`` so that 1 is always best.

        A metric without spread (``high <= low``, or NaN bounds) maps to a constant 0.5.
        """
        if high > low:
            scaled = (values - low) / (high - low)
            # For minimised metrics the lowest raw value is the best, so invert the scale.
            return 1.0 - scaled if self._is_minimize(index) else scaled
        return np.full(values.shape, 0.5)

    @staticmethod
    def _finish(fig: plt.Figure, save_path: Path | None, show_plot: bool, description: str) -> plt.Figure:
        """Tighten the layout, optionally save (300 dpi) and/or show ``fig``, then return it."""
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("%s saved to: %s", description, save_path)
        if show_plot:
            plt.show()
        return fig

    # ---------- public plotting API ----------

    def plot_pareto_front_2d(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] = (10, 8),
                             show_plot: bool = True) -> plt.Figure:
        """Scatter all trials of a two-metric study and highlight the Pareto-optimal ones.

        Pareto-optimal trials are drawn as red stars. When there are at least two, they are
        joined by a dashed line ordered by the first metric.

        Raises
        ------
        ValueError
            If the visualizer was not configured with exactly two metrics.
        """
        if len(self.metric_names) != 2:
            raise ValueError("2D Pareto front visualization requires exactly 2 metrics")

        fig, ax = plt.subplots(figsize=figsize)

        # All trials
        ax.scatter(self._metric(trials_df, 0),
                   self._metric(trials_df, 1),
                   alpha=0.6,
                   s=50,
                   c='lightblue',
                   label=f'All Trials (n={len(trials_df)})',
                   edgecolors='navy',
                   linewidths=0.5)

        # Pareto-optimal trials
        if self._has_rows(pareto_trials_df):
            pareto_x = self._metric(pareto_trials_df, 0)
            pareto_y = self._metric(pareto_trials_df, 1)

            ax.scatter(pareto_x,
                       pareto_y,
                       alpha=0.9,
                       s=100,
                       c='red',
                       label=f'Pareto Optimal (n={len(pareto_trials_df)})',
                       edgecolors='darkred',
                       linewidths=1.5,
                       marker='*')

            if len(pareto_x) > 1:
                # Sort by the first objective so the front is drawn as a monotone polyline.
                order = np.argsort(pareto_x)
                ax.plot(pareto_x[order], pareto_y[order], 'r--', alpha=0.7, linewidth=2, label='Pareto Front')

        x_minimize = self._is_minimize(0)
        y_minimize = self._is_minimize(1)

        # Axis-label arrows show the preferred direction of each metric's value.
        x_direction = "↓" if x_minimize else "↑"
        y_direction = "↓" if y_minimize else "↑"

        ax.set_xlabel(f"{self.metric_names[0]} {x_direction}", fontsize=12)
        ax.set_ylabel(f"{self.metric_names[1]} {y_direction}", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Pareto Front Visualization", fontsize=14, fontweight='bold')

        ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3)

        # Annotation arrows point towards the better region of the axes. Smaller values lie to
        # the left / bottom, so a minimised metric improves leftwards / downwards.
        x_annotation = (f"← Better {self.metric_names[0]}" if x_minimize else f"Better {self.metric_names[0]} →")
        ax.annotate(x_annotation,
                    xy=(0.02, 0.98),
                    xycoords='axes fraction',
                    ha='left',
                    va='top',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))

        y_annotation = (f"Better {self.metric_names[1]} ↓" if y_minimize else f"Better {self.metric_names[1]} ↑")
        ax.annotate(y_annotation,
                    xy=(0.02, 0.02),
                    xycoords='axes fraction',
                    ha='left',
                    va='bottom',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7))

        return self._finish(fig, save_path, show_plot, "2D Pareto plot")

    def plot_pareto_parallel_coordinates(self,
                                         trials_df: pd.DataFrame,
                                         pareto_trials_df: pd.DataFrame | None = None,
                                         save_path: Path | None = None,
                                         figsize: tuple[int, int] = (12, 8),
                                         show_plot: bool = True) -> plt.Figure:
        """Draw a parallel-coordinates plot with one vertical axis per metric.

        Each metric is min-max scaled over ``trials_df`` (ignoring NaN values) so that 1 is best
        and 0 is worst. Pareto-optimal trials are scaled with the same bounds and drawn as thick
        red lines over the faint blue lines of all trials.
        """
        fig, ax = plt.subplots(figsize=figsize)

        n_metrics = len(self.metric_names)
        x_positions = np.arange(n_metrics)

        # Bounds come from all trials so both groups share one scale. They are NaN-aware so a
        # single NaN objective value cannot collapse an entire axis to the constant 0.5.
        bounds = [self._finite_bounds(self._metric(trials_df, i)) for i in range(n_metrics)]

        def scaled_rows(df: pd.DataFrame):
            """Yield one tuple of normalised metric values per row of ``df``."""
            columns = [self._normalise(self._metric(df, i), i, low, high) for i, (low, high) in enumerate(bounds)]
            return zip(*columns)

        # All trials
        for row in scaled_rows(trials_df):
            ax.plot(x_positions, row, 'b-', alpha=0.1, linewidth=1)

        # Pareto-optimal trials. Their rows are normalised directly rather than looked up by
        # index label; label lookup only matched row positions for a default RangeIndex.
        if self._has_rows(pareto_trials_df):
            for row in scaled_rows(pareto_trials_df):
                ax.plot(x_positions, row, 'r-', alpha=0.8, linewidth=3)

        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{name}\n({direction})" for name, direction in zip(self.metric_names, self.directions)])
        ax.set_ylabel("Normalized Performance (Higher is Better)", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Parallel Coordinates Plot", fontsize=14, fontweight='bold')
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, alpha=0.3)

        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color='blue', alpha=0.3, linewidth=2, label='All Trials'),
            Line2D([0], [0], color='red', alpha=0.8, linewidth=3, label='Pareto Optimal')
        ]
        ax.legend(handles=legend_elements, loc='best')

        return self._finish(fig, save_path, show_plot, "Parallel coordinates plot")

    def plot_pairwise_matrix(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] | None = None,
                             show_plot: bool = True) -> plt.Figure:
        """Draw an n×n grid of per-metric histograms (diagonal) and pairwise scatters (off-diagonal).

        Pareto-optimal trials are overlaid in red. ``figsize`` defaults to 4 inches per metric
        in each direction.
        """
        n_metrics = len(self.metric_names)
        if figsize is None:
            figsize = (4 * n_metrics, 4 * n_metrics)

        # squeeze=False always returns a 2-D array of Axes, including the single-metric case.
        fig, axes = plt.subplots(n_metrics, n_metrics, figsize=figsize, squeeze=False)
        fig.suptitle(f"{self.title_prefix}: Pairwise Metric Comparison", fontsize=16, fontweight='bold')

        show_pareto = self._has_rows(pareto_trials_df)

        for i in range(n_metrics):
            for j in range(n_metrics):
                ax = axes[i, j]

                if i == j:
                    # Diagonal: histograms. NaNs are dropped so automatic bin-range detection stays
                    # finite. Shared bin edges keep the Pareto overlay aligned with the all-trials bars.
                    values = self._metric(trials_df, i)
                    values = values[~np.isnan(values)]
                    pareto_values = self._metric(pareto_trials_df, i) if show_pareto else np.empty(0)
                    pareto_values = pareto_values[~np.isnan(pareto_values)]
                    edges = np.histogram_bin_edges(np.concatenate([values, pareto_values]), bins=20)

                    ax.hist(values, bins=edges, alpha=0.7, color='lightblue', edgecolor='navy')
                    if show_pareto:
                        ax.hist(pareto_values, bins=edges, alpha=0.8, color='red', edgecolor='darkred')
                    ax.set_xlabel(f"{self.metric_names[i]}")
                    ax.set_ylabel("Frequency")
                else:
                    # Off-diagonal: metric j on x, metric i on y.
                    ax.scatter(self._metric(trials_df, j),
                               self._metric(trials_df, i),
                               alpha=0.6,
                               s=30,
                               c='lightblue',
                               edgecolors='navy',
                               linewidths=0.5)

                    if show_pareto:
                        ax.scatter(self._metric(pareto_trials_df, j),
                                   self._metric(pareto_trials_df, i),
                                   alpha=0.9,
                                   s=60,
                                   c='red',
                                   edgecolors='darkred',
                                   linewidths=1,
                                   marker='*')

                    ax.set_xlabel(f"{self.metric_names[j]} ({self.directions[j]})")
                    ax.set_ylabel(f"{self.metric_names[i]} ({self.directions[i]})")

                ax.grid(True, alpha=0.3)

        return self._finish(fig, save_path, show_plot, "Pairwise matrix plot")
265
+
266
+
267
def load_trials_from_study(study: optuna.Study) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return ``(trials_df, pareto_trials_df)`` for an Optuna study.

    ``trials_df`` is ``study.trials_dataframe()``. ``pareto_trials_df`` holds the subset of its
    rows whose ``number`` belongs to one of ``study.best_trials``.
    """
    trials_df = study.trials_dataframe()

    # Without any trials the frame has no ``number`` column to filter on. Return an empty
    # Pareto frame instead of raising KeyError (and without querying ``best_trials``).
    if trials_df.empty or 'number' not in trials_df.columns:
        return trials_df, trials_df.copy()

    pareto_trial_numbers = {trial.number for trial in study.best_trials}
    pareto_trials_df = trials_df[trials_df['number'].isin(pareto_trial_numbers)]

    return trials_df, pareto_trials_df
277
+
278
+
279
def load_trials_from_csv(csv_path: Path, metric_names: list[str],
                         directions: list[str]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read exported trials from ``csv_path`` and compute their Pareto-optimal subset.

    Metric columns are discovered by their ``values_`` prefix, and ``metric_names`` is not
    used by this function. Pareto optimality is computed locally with
    ``compute_pareto_optimal_mask`` using ``directions``.

    Raises
    ------
    ValueError
        If the CSV has no ``values_`` columns.
    """
    frame = pd.read_csv(csv_path)

    metric_columns = list(filter(lambda name: name.startswith('values_'), frame.columns))
    if len(metric_columns) == 0:
        raise ValueError("CSV file must contain 'values_' columns with metric scores")

    on_front = compute_pareto_optimal_mask(frame, metric_columns, directions)
    return frame, frame[on_front]
293
+
294
+
295
def compute_pareto_optimal_mask(df: pd.DataFrame, value_cols: list[str], directions: list[str]) -> np.ndarray:
    """Return a boolean mask selecting the non-dominated (Pareto-optimal) rows of ``df``.

    Parameters
    ----------
    df : DataFrame holding one row per trial.
    value_cols : objective columns, in the same order as ``directions``.
    directions : per-objective direction. ``"minimize"`` means smaller is better; any other
        value is treated as larger-is-better.

    Returns
    -------
    np.ndarray
        Boolean array of length ``len(df)``, where ``True`` marks a Pareto-optimal row. Rows with
        any NaN objective value are never marked Pareto-optimal.

    Notes
    -----
    A row is dominated when another row is at least as good in every objective and strictly
    better in at least one. Runs in O(n^2 * k) time for n rows and k objectives.
    """
    # Explicit copy: the in-place sign flip below must never write into the DataFrame's memory.
    values = df[value_cols].to_numpy(dtype=float, copy=True)

    # Flip minimised objectives so that "larger is better" holds for every column.
    for i, direction in enumerate(directions):
        if direction == "minimize":
            values[:, i] = -values[:, i]

    # NaN never compares >= or >, so such rows could never be dominated and would otherwise
    # be reported as Pareto-optimal. Exclude them up front.
    is_pareto = ~np.isnan(values).any(axis=1)

    for i in range(len(values)):
        # A dominated row needs no processing: whatever it dominates, its dominator does too.
        if not is_pareto[i]:
            continue
        dominated = np.all(values[i] >= values, axis=1) & np.any(values[i] > values, axis=1)
        is_pareto[dominated] = False

    return is_pareto
315
+
316
+
317
def create_pareto_visualization(data_source: optuna.Study | Path | pd.DataFrame,
                                metric_names: list[str],
                                directions: list[str],
                                output_dir: Path | None = None,
                                title_prefix: str = "Optimization Results",
                                show_plots: bool = True) -> dict[str, plt.Figure]:
    """Build every applicable Pareto plot for ``data_source`` and return the figures by name.

    Parameters
    ----------
    data_source : an object exposing ``trials_dataframe`` (an Optuna study), a CSV path
        (``str`` or ``Path``), or a DataFrame with ``values_<i>`` columns.
    metric_names : display names of the objectives.
    directions : one direction per metric (``"minimize"`` or ``"maximize"``).
    output_dir : when given, created if needed and used to save the PNG files.
    title_prefix : text prepended to every plot title.
    show_plots : whether each plotting call should also call ``plt.show()``.

    Returns
    -------
    dict[str, plt.Figure]
        ``"2d_scatter"`` (exactly two metrics only), plus ``"parallel_coordinates"`` and
        ``"pairwise_matrix"`` (two or more metrics). Fewer than two metrics yields an empty dict.

    Raises
    ------
    ValueError
        If ``data_source`` has an unsupported type, or ``metric_names`` and ``directions``
        differ in length. Exceptions raised while plotting are logged and re-raised.
    """
    # Resolve the data source into (all trials, Pareto-optimal trials).
    if hasattr(data_source, 'trials_dataframe'):
        trials_df, pareto_trials_df = load_trials_from_study(data_source)
    elif isinstance(data_source, (str, Path)):
        trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
    elif isinstance(data_source, pd.DataFrame):
        trials_df = data_source
        metric_columns = [name for name in trials_df.columns if name.startswith('values_')]
        pareto_trials_df = trials_df[compute_pareto_optimal_mask(trials_df, metric_columns, directions)]
    else:
        raise ValueError("data_source must be an Optuna study, CSV file path, or pandas DataFrame")

    visualizer = ParetoVisualizer(metric_names, directions, title_prefix)
    figures: dict[str, plt.Figure] = {}

    logger.info("Creating Pareto front visualizations...")
    logger.info("Total trials: %d", len(trials_df))
    logger.info("Pareto optimal trials: %d", len(pareto_trials_df))

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    def target(filename: str) -> Path | None:
        # A save path exists only when an output directory was requested.
        return output_dir / filename if output_dir else None

    n_metrics = len(metric_names)
    try:
        if n_metrics == 2:
            figures["2d_scatter"] = visualizer.plot_pareto_front_2d(trials_df,
                                                                    pareto_trials_df,
                                                                    target("pareto_front_2d.png"),
                                                                    show_plot=show_plots)

        if n_metrics >= 2:
            figures["parallel_coordinates"] = visualizer.plot_pareto_parallel_coordinates(
                trials_df, pareto_trials_df, target("pareto_parallel_coordinates.png"), show_plot=show_plots)
            figures["pairwise_matrix"] = visualizer.plot_pairwise_matrix(trials_df,
                                                                         pareto_trials_df,
                                                                         target("pareto_pairwise_matrix.png"),
                                                                         show_plot=show_plots)

        logger.info("Visualization complete!")
        if output_dir:
            logger.info("Plots saved to: %s", output_dir)

    except Exception as e:
        logger.error("Error creating visualizations: %s", e)
        raise

    return figures