nvidia-nat 1.1.0a20251020__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +265 -0
- nat/agent/dual_node.py +72 -0
- nat/agent/prompt_optimizer/__init__.py +0 -0
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +394 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +168 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +227 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +593 -0
- nat/agent/rewoo_agent/prompt.py +107 -0
- nat/agent/rewoo_agent/register.py +175 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +246 -0
- nat/agent/tool_calling_agent/register.py +129 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +96 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +20 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +317 -0
- nat/builder/component_utils.py +320 -0
- nat/builder/context.py +321 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +166 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +25 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +714 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +625 -0
- nat/builder/intermediate_step_manager.py +206 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +160 -0
- nat/builder/workflow_builder.py +1365 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +47 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +128 -0
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +153 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +257 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
- nat/cli/commands/workflow/templates/register.py.j2 +4 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +403 -0
- nat/cli/entrypoint.py +141 -0
- nat/cli/main.py +60 -0
- nat/cli/register_workflow.py +522 -0
- nat/cli/type_registry.py +1069 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +843 -0
- nat/data_models/authentication.py +245 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +60 -0
- nat/data_models/component_ref.py +179 -0
- nat/data_models/config.py +434 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +130 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +64 -0
- nat/data_models/function_dependencies.py +80 -0
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +228 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +42 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +62 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +431 -0
- nat/eval/evaluate.py +565 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +58 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +26 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +193 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +154 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +67 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +35 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +285 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +108 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +602 -0
- nat/front_ends/fastapi/main.py +64 -0
- nat/front_ends/fastapi/message_handler.py +344 -0
- nat/front_ends/fastapi/message_validator.py +351 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +90 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +290 -0
- nat/front_ends/register.py +21 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +69 -0
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +58 -0
- nat/llm/openai_llm.py +54 -0
- nat/llm/register.py +27 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +93 -0
- nat/llm/utils/error.py +17 -0
- nat/llm/utils/thinking.py +215 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +19 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +550 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +308 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +496 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +308 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +74 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +114 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +626 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +297 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +180 -0
- nat/profiler/decorators/function_tracking.py +411 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +478 -0
- nat/profiler/utils.py +186 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +570 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +248 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +20 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +236 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +21 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +292 -0
- nat/runtime/session.py +223 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +329 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +77 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +82 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +66 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +31 -0
- nat/tool/retriever.py +95 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/decorators.py +210 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +342 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_levels.py +25 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +195 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +299 -0
- nat/utils/type_utils.py +488 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
- nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
- nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
- nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Mapping as Dict
|
|
19
|
+
|
|
20
|
+
import optuna
|
|
21
|
+
import yaml
|
|
22
|
+
|
|
23
|
+
from nat.data_models.config import Config
|
|
24
|
+
from nat.data_models.optimizable import SearchSpace
|
|
25
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
26
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
27
|
+
from nat.eval.evaluate import EvaluationRun
|
|
28
|
+
from nat.eval.evaluate import EvaluationRunConfig
|
|
29
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
30
|
+
from nat.profiler.parameter_optimization.parameter_selection import pick_trial
|
|
31
|
+
from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@experimental(feature_name="Optimizer")
def optimize_parameters(
    *,
    base_cfg: Config,
    full_space: Dict[str, SearchSpace],
    optimizer_config: OptimizerConfig,
    opt_run_config: OptimizerRunConfig,
) -> Config:
    """Tune all *non-prompt* hyper-parameters and persist the best config.

    Runs an Optuna multi-objective study over every non-prompt entry of
    ``full_space``, evaluating each suggested configuration with the NAT
    evaluation harness, then writes the best configuration and the full
    trials dataframe into ``optimizer_config.output_path``.

    Args:
        base_cfg: The baseline workflow configuration to mutate per trial.
        full_space: Mapping of parameter path -> search-space spec; entries
            flagged ``is_prompt`` are skipped (handled by the prompt optimizer).
        optimizer_config: Optimizer settings (output path, eval metrics,
            trial counts, combination mode).
        opt_run_config: Runtime settings forwarded to each evaluation run.

    Returns:
        The tuned :class:`Config` built from the best trial's parameters.

    Raises:
        ValueError: If ``output_path`` or ``eval_metrics`` is ``None``, or if
            a configured metric is missing from an evaluation's results.
    """
    # Only numeric/enum parameters are tuned here; prompts are optimized elsewhere.
    space = {k: v for k, v in full_space.items() if not v.is_prompt}

    # Ensure output_path is not None
    if optimizer_config.output_path is None:
        raise ValueError("optimizer_config.output_path cannot be None")

    # Ensure eval_metrics is not None
    if optimizer_config.eval_metrics is None:
        raise ValueError("optimizer_config.eval_metrics cannot be None")

    # Create the output directory once, up front, for both intermediate
    # per-trial configs and the final artifacts. (Previously this was
    # duplicated further down; a single creation is sufficient.)
    out_dir = optimizer_config.output_path
    out_dir.mkdir(parents=True, exist_ok=True)

    metric_cfg = optimizer_config.eval_metrics
    directions = [v.direction for v in metric_cfg.values()]
    eval_metrics = [v.evaluator_name for v in metric_cfg.values()]
    weights = [v.weight for v in metric_cfg.values()]

    study = optuna.create_study(directions=directions)

    async def _run_eval(runner: EvaluationRun):
        # Thin awaitable wrapper so the objective can gather several runs.
        return await runner.run_and_evaluate()

    def _objective(trial: optuna.Trial):
        # Number of repeated evaluations per parameter set (averaged below).
        reps = max(1, getattr(optimizer_config, "reps_per_param_set", 1))

        # build trial config
        suggestions = {p: spec.suggest(trial, p) for p, spec in space.items()}
        cfg_trial = apply_suggestions(base_cfg, suggestions)

        async def _single_eval(trial_idx: int) -> list[float]:  # noqa: ARG001
            eval_cfg = EvaluationRunConfig(
                config_file=cfg_trial,
                dataset=opt_run_config.dataset,
                result_json_path=opt_run_config.result_json_path,
                endpoint=opt_run_config.endpoint,
                endpoint_timeout=opt_run_config.endpoint_timeout,
            )
            scores = await _run_eval(EvaluationRun(config=eval_cfg))
            values = []
            for metric_name in eval_metrics:
                # Explicit lookup: a bare next() would raise StopIteration
                # inside this coroutine (surfacing as an opaque RuntimeError).
                metric = next(
                    (r[1] for r in scores.evaluation_results if r[0] == metric_name),
                    None,
                )
                if metric is None:
                    raise ValueError(
                        f"Metric '{metric_name}' not found in evaluation results")
                values.append(metric.average_score)

            return values

        # Create tasks for all evaluations
        async def _run_all_evals():
            tasks = [_single_eval(i) for i in range(reps)]
            return await asyncio.gather(*tasks)

        # Persist this trial's config using the public, stable trial.number
        # (the previous trial._trial_id is a private Optuna attribute).
        with (out_dir / f"config_numeric_trial_{trial.number}.yml").open("w") as fh:
            yaml.dump(cfg_trial.model_dump(), fh)

        all_scores = asyncio.run(_run_all_evals())
        # Persist raw per-repetition scores so they appear in `trials_dataframe`.
        trial.set_user_attr("rep_scores", all_scores)
        # Mean of each objective across repetitions, in eval_metrics order.
        return [sum(run[i] for run in all_scores) / reps for i in range(len(eval_metrics))]

    logger.info("Starting numeric / enum parameter optimization...")
    study.optimize(_objective, n_trials=optimizer_config.numeric.n_trials)
    logger.info("Numeric optimization finished")

    # Collapse the Pareto front to one trial per the configured combination mode.
    best_params = pick_trial(
        study=study,
        mode=optimizer_config.multi_objective_combination_mode,
        weights=weights,
    ).params
    tuned_cfg = apply_suggestions(base_cfg, best_params)

    # Save final results (out_dir already created and defined above)
    with (out_dir / "optimized_config.yml").open("w") as fh:
        yaml.dump(tuned_cfg.model_dump(), fh)
    with (out_dir / "trials_dataframe_params.csv").open("w") as fh:
        # Export full trials DataFrame (values, params, timings, etc.).
        df = study.trials_dataframe()
        # Normalise rep_scores column naming for convenience.
        if "user_attrs_rep_scores" in df.columns and "rep_scores" not in df.columns:
            df = df.rename(columns={"user_attrs_rep_scores": "rep_scores"})
        elif "user_attrs" in df.columns and "rep_scores" not in df.columns:
            # Some Optuna versions return a dict in a single user_attrs column.
            df["rep_scores"] = df["user_attrs"].apply(lambda d: d.get("rep_scores") if isinstance(d, dict) else None)
            df = df.drop(columns=["user_attrs"])
        df.to_csv(fh, index=False)

    # Generate Pareto front visualizations (best-effort: plotting deps may
    # be absent, so failures only warn rather than abort the optimization).
    try:
        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
        logger.info("Generating Pareto front visualizations...")
        create_pareto_visualization(
            data_source=study,
            metric_names=eval_metrics,
            directions=directions,
            output_dir=out_dir / "plots",
            title_prefix="Parameter Optimization",
            show_plots=False  # Don't show plots in automated runs
        )
        logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
    except ImportError as ie:
        logger.warning("Could not import visualization dependencies: %s. "
                       "Have you installed nvidia-nat-profiling?",
                       ie)
    except Exception as e:
        logger.warning("Failed to generate visualizations: %s", e)

    return tuned_cfg
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import optuna
|
|
20
|
+
from optuna._hypervolume import compute_hypervolume
|
|
21
|
+
from optuna.study import Study
|
|
22
|
+
from optuna.study import StudyDirection
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# ---------- helper ----------
|
|
26
|
+
def _to_minimisation_matrix(
    trials: Sequence[optuna.trial.FrozenTrial],
    directions: Sequence[StudyDirection],
) -> np.ndarray:
    """Stack trial objective values into an (n_trials, n_objectives) float array
    oriented so that **every** column is 'smaller-is-better'.

    Columns whose study direction is MAXIMIZE are sign-flipped so downstream
    scalarisation can treat all objectives uniformly.
    """
    matrix = np.array([trial.values for trial in trials], dtype=float)
    flip_cols = np.array([d == StudyDirection.MAXIMIZE for d in directions], dtype=bool)
    matrix[:, flip_cols] *= -1.0  # flip maximisation columns into minimisation space
    return matrix
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------- public API ----------
|
|
39
|
+
def pick_trial(
    study: Study,
    mode: str = "harmonic",
    *,
    weights: Sequence[float] | None = None,
    ref_point: Sequence[float] | None = None,
    eps: float = 1e-12,
) -> optuna.trial.FrozenTrial:
    """
    Collapse Optuna’s Pareto front (`study.best_trials`) to a single “best compromise”.

    Parameters
    ----------
    study : completed **multi-objective** Optuna study
    mode : {"harmonic", "sum", "chebyshev", "hypervolume"}
    weights : per-objective weights (used only for "sum")
    ref_point : reference point for hyper-volume (defaults to ones after normalisation)
    eps : tiny value to avoid division by zero

    Returns
    -------
    optuna.trial.FrozenTrial

    Raises
    ------
    ValueError
        If the Pareto front is empty, `weights` has the wrong length, or
        `mode` is not one of the supported scalarisation strategies.
    """

    # ---- 1. Pareto front ----
    front = study.best_trials
    if not front:
        raise ValueError("`study.best_trials` is empty – no Pareto-optimal trials found.")

    # ---- 2. Convert & normalise objectives ----
    vals = _to_minimisation_matrix(front, study.directions)  # smaller is better
    span = np.ptp(vals, axis=0)
    norm = (vals - vals.min(axis=0)) / (span + eps)  # 0 = best, 1 = worst

    # ---- 3. Scalarise according to chosen mode ----
    mode = mode.lower()

    if mode == "harmonic":
        # Harmonic mean is dominated by the smallest component, so this
        # favours trials that are excellent in at least one objective.
        hmean = norm.shape[1] / (1.0 / (norm + eps)).sum(axis=1)
        best_idx = hmean.argmin()  # lower = better

    elif mode == "sum":
        w = np.ones(norm.shape[1]) if weights is None else np.asarray(weights, float)
        if w.size != norm.shape[1]:
            raise ValueError("`weights` length must equal number of objectives.")
        score = norm @ w
        best_idx = score.argmin()

    elif mode == "chebyshev":
        score = norm.max(axis=1)  # worst dimension
        best_idx = score.argmin()

    elif mode == "hypervolume":
        # Hyper-volume assumes points are *below* the reference point (minimisation space).
        # `front` is guaranteed non-empty here (checked at step 1), so only the
        # single-trial degenerate case needs a short-circuit.
        if len(front) == 1:
            best_idx = 0
        else:
            rp = np.ones(norm.shape[1]) if ref_point is None else np.asarray(ref_point, float)
            # NOTE(review): `compute_hypervolume` comes from Optuna's private
            # `optuna._hypervolume` module and may change between releases.
            base_hv = compute_hypervolume(norm, rp)
            # Each trial's contribution = hypervolume lost when it is removed.
            contrib = np.array([base_hv - compute_hypervolume(np.delete(norm, i, 0), rp)
                                for i in range(len(front))])
            best_idx = contrib.argmax()  # bigger contribution wins

    else:
        raise ValueError(f"Unknown mode '{mode}'. Choose from "
                         "'harmonic', 'sum', 'chebyshev', 'hypervolume'.")

    return front[best_idx]
|
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# flake8: noqa: W293
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
import matplotlib.pyplot as plt
|
|
21
|
+
import numpy as np
|
|
22
|
+
import optuna
|
|
23
|
+
import pandas as pd
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ParetoVisualizer:
    """Render Pareto-front visualizations for a multi-objective optimization run.

    Expects trial DataFrames with ``values_<i>`` columns holding the i-th
    objective value for every trial (the layout produced by
    ``optuna.Study.trials_dataframe``).

    Parameters
    ----------
    metric_names : display name of each objective, in ``values_<i>`` order
    directions : per-objective optimization direction, "minimize" or "maximize"
    title_prefix : prefix used in every figure title
    """

    def __init__(self, metric_names: list[str], directions: list[str], title_prefix: str = "Optimization Results"):
        self.metric_names = metric_names
        self.directions = directions
        self.title_prefix = title_prefix

        if len(metric_names) != len(directions):
            raise ValueError("Number of metric names must match number of directions")

    def plot_pareto_front_2d(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] = (10, 8),
                             show_plot: bool = True) -> plt.Figure:
        """Scatter all trials in objective space and highlight the Pareto front.

        Only valid for exactly two objectives; raises ValueError otherwise.
        Returns the created Figure (also saved to *save_path* if given).
        """
        if len(self.metric_names) != 2:
            raise ValueError("2D Pareto front visualization requires exactly 2 metrics")

        fig, ax = plt.subplots(figsize=figsize)

        # Extract metric values
        x_vals = trials_df[f"values_{0}"].values
        y_vals = trials_df[f"values_{1}"].values

        # Plot all trials
        ax.scatter(x_vals,
                   y_vals,
                   alpha=0.6,
                   s=50,
                   c='lightblue',
                   label=f'All Trials (n={len(trials_df)})',
                   edgecolors='navy',
                   linewidths=0.5)

        # Plot Pareto optimal trials if provided
        if pareto_trials_df is not None and not pareto_trials_df.empty:
            pareto_x = pareto_trials_df[f"values_{0}"].values
            pareto_y = pareto_trials_df[f"values_{1}"].values

            ax.scatter(pareto_x,
                       pareto_y,
                       alpha=0.9,
                       s=100,
                       c='red',
                       label=f'Pareto Optimal (n={len(pareto_trials_df)})',
                       edgecolors='darkred',
                       linewidths=1.5,
                       marker='*')

            # Draw Pareto front line (only for 2D)
            if len(pareto_x) > 1:
                # Sort points for line drawing based on first objective
                sorted_indices = np.argsort(pareto_x)
                ax.plot(pareto_x[sorted_indices],
                        pareto_y[sorted_indices],
                        'r--',
                        alpha=0.7,
                        linewidth=2,
                        label='Pareto Front')

        # Axis labels carry the optimization direction of each metric.
        x_direction = "↓" if self.directions[0] == "minimize" else "↑"
        y_direction = "↓" if self.directions[1] == "minimize" else "↑"

        ax.set_xlabel(f"{self.metric_names[0]} {x_direction}", fontsize=12)
        ax.set_ylabel(f"{self.metric_names[1]} {y_direction}", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Pareto Front Visualization", fontsize=14, fontweight='bold')

        ax.legend(loc='best', frameon=True, fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3)

        # Add direction annotations.
        # BUGFIX: when minimizing, *lower* values are better, so the "better"
        # arrow must point left (x) / down (y); the previous arrows were inverted.
        x_annotation = (f"← Better {self.metric_names[0]}"
                        if self.directions[0] == "minimize" else f"Better {self.metric_names[0]} →")
        ax.annotate(x_annotation,
                    xy=(0.02, 0.98),
                    xycoords='axes fraction',
                    ha='left',
                    va='top',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))

        y_annotation = (f"Better {self.metric_names[1]} ↓"
                        if self.directions[1] == "minimize" else f"Better {self.metric_names[1]} ↑")
        ax.annotate(y_annotation,
                    xy=(0.02, 0.02),
                    xycoords='axes fraction',
                    ha='left',
                    va='bottom',
                    fontsize=10,
                    style='italic',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.7))

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("2D Pareto plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig

    def plot_pareto_parallel_coordinates(self,
                                         trials_df: pd.DataFrame,
                                         pareto_trials_df: pd.DataFrame | None = None,
                                         save_path: Path | None = None,
                                         figsize: tuple[int, int] = (12, 8),
                                         show_plot: bool = True) -> plt.Figure:
        """Parallel-coordinates plot of all trials, normalized per metric so
        that 1.0 always means "best", with Pareto-optimal trials drawn in red.

        NOTE(review): Pareto highlighting assumes ``pareto_trials_df`` shares
        the positional RangeIndex of ``trials_df`` (true for frames produced
        by the loaders in this module) — confirm before reusing elsewhere.
        """
        fig, ax = plt.subplots(figsize=figsize)

        n_metrics = len(self.metric_names)
        x_positions = np.arange(n_metrics)

        # Normalize values for better visualization
        all_values = []
        for i in range(n_metrics):
            all_values.append(trials_df[f"values_{i}"].values)

        # Normalize each metric to [0, 1] for parallel coordinates
        normalized_values = []
        for i, values in enumerate(all_values):
            min_val, max_val = values.min(), values.max()
            if max_val > min_val:
                if self.directions[i] == "minimize":
                    # For minimize: lower values get higher normalized scores
                    norm_vals = 1 - (values - min_val) / (max_val - min_val)
                else:
                    # For maximize: higher values get higher normalized scores
                    norm_vals = (values - min_val) / (max_val - min_val)
            else:
                # Constant metric: park every trial at mid-scale.
                norm_vals = np.ones_like(values) * 0.5
            normalized_values.append(norm_vals)

        # Plot all trials
        for i in range(len(trials_df)):
            trial_values = [normalized_values[j][i] for j in range(n_metrics)]
            ax.plot(x_positions, trial_values, 'b-', alpha=0.1, linewidth=1)

        # Plot Pareto optimal trials
        if pareto_trials_df is not None and not pareto_trials_df.empty:
            pareto_indices = pareto_trials_df.index
            for idx in pareto_indices:
                if idx < len(trials_df):
                    trial_values = [normalized_values[j][idx] for j in range(n_metrics)]
                    ax.plot(x_positions, trial_values, 'r-', alpha=0.8, linewidth=3)

        # Customize plot
        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{name}\n({direction})" for name, direction in zip(self.metric_names, self.directions)])
        ax.set_ylabel("Normalized Performance (Higher is Better)", fontsize=12)
        ax.set_title(f"{self.title_prefix}: Parallel Coordinates Plot", fontsize=14, fontweight='bold')
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, alpha=0.3)

        # Add legend
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color='blue', alpha=0.3, linewidth=2, label='All Trials'),
            Line2D([0], [0], color='red', alpha=0.8, linewidth=3, label='Pareto Optimal')
        ]
        ax.legend(handles=legend_elements, loc='best')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("Parallel coordinates plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig

    def plot_pairwise_matrix(self,
                             trials_df: pd.DataFrame,
                             pareto_trials_df: pd.DataFrame | None = None,
                             save_path: Path | None = None,
                             figsize: tuple[int, int] | None = None,
                             show_plot: bool = True) -> plt.Figure:
        """Grid of pairwise scatter plots (histograms on the diagonal) for
        every metric pair, with Pareto-optimal trials highlighted in red."""
        n_metrics = len(self.metric_names)
        if figsize is None:
            figsize = (4 * n_metrics, 4 * n_metrics)

        fig, axes = plt.subplots(n_metrics, n_metrics, figsize=figsize)
        fig.suptitle(f"{self.title_prefix}: Pairwise Metric Comparison", fontsize=16, fontweight='bold')

        for i in range(n_metrics):
            for j in range(n_metrics):
                # With a single metric, plt.subplots returns a bare Axes, not a grid.
                ax = axes[i, j] if n_metrics > 1 else axes

                if i == j:
                    # Diagonal: histograms
                    values = trials_df[f"values_{i}"].values
                    ax.hist(values, bins=20, alpha=0.7, color='lightblue', edgecolor='navy')
                    if pareto_trials_df is not None and not pareto_trials_df.empty:
                        pareto_values = pareto_trials_df[f"values_{i}"].values
                        ax.hist(pareto_values, bins=20, alpha=0.8, color='red', edgecolor='darkred')
                    ax.set_xlabel(f"{self.metric_names[i]}")
                    ax.set_ylabel("Frequency")
                else:
                    # Off-diagonal: scatter plots
                    x_vals = trials_df[f"values_{j}"].values
                    y_vals = trials_df[f"values_{i}"].values

                    ax.scatter(x_vals, y_vals, alpha=0.6, s=30, c='lightblue', edgecolors='navy', linewidths=0.5)

                    if pareto_trials_df is not None and not pareto_trials_df.empty:
                        pareto_x = pareto_trials_df[f"values_{j}"].values
                        pareto_y = pareto_trials_df[f"values_{i}"].values
                        ax.scatter(pareto_x,
                                   pareto_y,
                                   alpha=0.9,
                                   s=60,
                                   c='red',
                                   edgecolors='darkred',
                                   linewidths=1,
                                   marker='*')

                    ax.set_xlabel(f"{self.metric_names[j]} ({self.directions[j]})")
                    ax.set_ylabel(f"{self.metric_names[i]} ({self.directions[i]})")

                ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info("Pairwise matrix plot saved to: %s", save_path)

        if show_plot:
            plt.show()

        return fig
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def load_trials_from_study(study: optuna.Study) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Extract (all trials, Pareto-optimal trials) DataFrames from an Optuna study.

    The Pareto subset is selected by matching ``study.best_trials`` trial
    numbers against the ``number`` column of the full trials frame.
    """
    all_trials = study.trials_dataframe()

    best_numbers = [t.number for t in study.best_trials]
    pareto_subset = all_trials[all_trials['number'].isin(best_numbers)]

    return all_trials, pareto_subset
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def load_trials_from_csv(csv_path: Path, metric_names: list[str],
                         directions: list[str]) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load a trials CSV and split out its Pareto-optimal subset.

    The CSV must contain ``values_<i>`` columns (one per objective); the
    Pareto set is recomputed from scratch since the CSV carries no study
    metadata. Raises ValueError when no ``values_`` columns are present.
    """
    frame = pd.read_csv(csv_path)

    score_columns = [c for c in frame.columns if c.startswith('values_')]
    if not score_columns:
        raise ValueError("CSV file must contain 'values_' columns with metric scores")

    # Recompute Pareto optimality directly from the objective columns.
    mask = compute_pareto_optimal_mask(frame, score_columns, directions)

    return frame, frame[mask]
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def compute_pareto_optimal_mask(df: pd.DataFrame, value_cols: list[str], directions: list[str]) -> np.ndarray:
|
|
296
|
+
values = df[value_cols].values
|
|
297
|
+
n_trials = len(values)
|
|
298
|
+
|
|
299
|
+
# Normalize directions: convert all to maximization
|
|
300
|
+
normalized_values = values.copy()
|
|
301
|
+
for i, direction in enumerate(directions):
|
|
302
|
+
if direction == "minimize":
|
|
303
|
+
normalized_values[:, i] = -normalized_values[:, i]
|
|
304
|
+
|
|
305
|
+
is_pareto = np.ones(n_trials, dtype=bool)
|
|
306
|
+
|
|
307
|
+
for i in range(n_trials):
|
|
308
|
+
if is_pareto[i]:
|
|
309
|
+
# Compare with all other solutions
|
|
310
|
+
dominates = np.all(normalized_values[i] >= normalized_values, axis=1) & \
|
|
311
|
+
np.any(normalized_values[i] > normalized_values, axis=1)
|
|
312
|
+
is_pareto[dominates] = False
|
|
313
|
+
|
|
314
|
+
return is_pareto
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def create_pareto_visualization(data_source: optuna.Study | str | Path | pd.DataFrame,
                                metric_names: list[str],
                                directions: list[str],
                                output_dir: Path | None = None,
                                title_prefix: str = "Optimization Results",
                                show_plots: bool = True) -> dict[str, plt.Figure]:
    """Create the standard set of Pareto-front plots for an optimization run.

    Parameters
    ----------
    data_source : an Optuna study, a path to a trials CSV (str or Path),
        or a DataFrame containing ``values_<i>`` columns. (Annotation now
        includes ``str`` — the body always accepted string paths.)
    metric_names : display names, one per objective
    directions : "minimize"/"maximize" per objective
    output_dir : if given, plots are saved here (directory created if missing)
    title_prefix : prefix used in every figure title
    show_plots : whether to display figures interactively

    Returns
    -------
    dict mapping plot kind ("2d_scatter", "parallel_coordinates",
    "pairwise_matrix") to the created matplotlib Figure.

    Raises
    ------
    ValueError
        If *data_source* is not one of the supported types.
    """
    # Load data based on source type
    if hasattr(data_source, 'trials_dataframe'):
        # Optuna study object (duck-typed so no hard isinstance dependency)
        trials_df, pareto_trials_df = load_trials_from_study(data_source)
    elif isinstance(data_source, str | Path):
        # CSV file path
        trials_df, pareto_trials_df = load_trials_from_csv(Path(data_source), metric_names, directions)
    elif isinstance(data_source, pd.DataFrame):
        # DataFrame: the Pareto subset must be recomputed here.
        trials_df = data_source
        value_cols = [col for col in trials_df.columns if col.startswith('values_')]
        pareto_mask = compute_pareto_optimal_mask(trials_df, value_cols, directions)
        pareto_trials_df = trials_df[pareto_mask]
    else:
        raise ValueError("data_source must be an Optuna study, CSV file path, or pandas DataFrame")

    visualizer = ParetoVisualizer(metric_names, directions, title_prefix)
    figures = {}

    logger.info("Creating Pareto front visualizations...")
    logger.info("Total trials: %d", len(trials_df))
    logger.info("Pareto optimal trials: %d", len(pareto_trials_df))

    # Create output directory if specified
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    try:
        if len(metric_names) == 2:
            # 2D scatter plot (only meaningful for exactly two objectives)
            save_path = output_dir / "pareto_front_2d.png" if output_dir else None
            fig = visualizer.plot_pareto_front_2d(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["2d_scatter"] = fig

        if len(metric_names) >= 2:
            # Parallel coordinates plot
            save_path = output_dir / "pareto_parallel_coordinates.png" if output_dir else None
            fig = visualizer.plot_pareto_parallel_coordinates(trials_df,
                                                              pareto_trials_df,
                                                              save_path,
                                                              show_plot=show_plots)
            figures["parallel_coordinates"] = fig

            # Pairwise matrix plot
            save_path = output_dir / "pareto_pairwise_matrix.png" if output_dir else None
            fig = visualizer.plot_pairwise_matrix(trials_df, pareto_trials_df, save_path, show_plot=show_plots)
            figures["pairwise_matrix"] = fig

        logger.info("Visualization complete!")
        if output_dir:
            logger.info("Plots saved to: %s", output_dir)

    except Exception as e:
        # Re-raise after logging so callers still see the original failure.
        logger.error("Error creating visualizations: %s", e)
        raise

    return figures
|