nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
from langchain.output_parsers import ResponseSchema
|
|
21
|
+
from langchain.output_parsers import StructuredOutputParser
|
|
22
|
+
from langchain.schema import HumanMessage
|
|
23
|
+
from langchain.schema import SystemMessage
|
|
24
|
+
from langchain_core.language_models import BaseChatModel
|
|
25
|
+
from langchain_core.runnables import RunnableLambda
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
|
|
28
|
+
from aiq.eval.evaluator.base_evaluator import BaseEvaluator
|
|
29
|
+
from aiq.eval.evaluator.evaluator_model import EvalInputItem
|
|
30
|
+
from aiq.eval.evaluator.evaluator_model import EvalOutputItem
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
# pylint: disable=line-too-long
|
|
35
|
+
# flake8: noqa: E501
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def evaluation_prompt(judge_llm_prompt: str,
                      question: str,
                      answer_description: str,
                      generated_answer: str,
                      format_instructions: str,
                      default_scoring: bool):
    """
    Build the prompt sent to the judge LLM to evaluate a generated answer.

    Args:
        judge_llm_prompt: User-configured judging instructions.
        question: The user's query for the item being evaluated.
        answer_description: Description of the expected answer.
        generated_answer: The answer produced by the workflow under evaluation.
        format_instructions: Output-format instructions to include verbatim in the prompt.
        default_scoring: When True, the built-in coverage/correctness/relevance rubric is
            included before the judge prompt.

    Returns:
        The prompt string. Each section is placed on its own line so the query, the
        expected-answer description and the generated answer never run together.
    """

    scoring_rubric = """
The coverage score is a measure of how well the generated answer covers the critical aspects mentioned in the expected answer. A low coverage score indicates that the generated answer misses critical aspects of the expected answer. A middle coverage score indicates that the generated answer covers some of the must-haves of the expected answer but lacks other details. A high coverage score indicates that all of the expected aspects are present in the generated answer.
The correctness score is a measure of how well the generated answer matches the expected answer. A low correctness score indicates that the generated answer is incorrect or does not match the expected answer. A middle correctness score indicates that the generated answer is correct but lacks some details. A high correctness score indicates that the generated answer is exactly the same as the expected answer.
The relevance score is a measure of how well the generated answer is relevant to the question. A low relevance score indicates that the generated answer is not relevant to the question. A middle relevance score indicates that the generated answer is somewhat relevant to the question. A high relevance score indicates that the generated answer is exactly relevant to the question.
The reasoning is a 1-2 sentence explanation for the scoring.
"""

    intro = "You are an intelligent assistant that responds strictly in JSON format."

    # Sections shared by both prompt variants. The explicit newlines matter: adjacent string
    # literals are concatenated with no separator, which previously glued these together.
    item_sections = (f"{judge_llm_prompt}\n"
                     f"{format_instructions}\n"
                     f"Here is the user's query: {question}\n"
                     f"Here is the description of the expected answer: {answer_description}\n"
                     f"Here is the generated answer: {generated_answer}")

    if default_scoring:
        return f"{intro} Judge based on the following scoring rubric: {scoring_rubric}{item_sections}"
    return f"{intro} {item_sections}"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def runnable_with_retries(original_fn: Callable, llm_retry_control_params: dict | None = None):
    """
    Wrap ``original_fn`` in a ``RunnableLambda`` that retries on any exception.

    Args:
        original_fn: The callable to wrap (e.g. a chat model's ``ainvoke``).
        llm_retry_control_params: Optional overrides for the retry policy. Recognised keys are
            ``stop_after_attempt`` (default 3), ``initial_backoff_delay_seconds`` (default 1)
            and ``has_exponential_jitter`` (default True). Keys that are missing or set to
            ``None`` fall back to their defaults. The caller's dict is never modified.

    Returns:
        The runnable produced by ``RunnableLambda(original_fn).with_retry(...)``.
    """
    retry_params = {"stop_after_attempt": 3, "initial_backoff_delay_seconds": 1, "has_exponential_jitter": True}

    if llm_retry_control_params:
        # Override only the recognised keys that were explicitly provided. Working on a fresh
        # dict avoids mutating the caller's config and tolerates partially specified params.
        for key in retry_params:
            value = llm_retry_control_params.get(key)
            if value is not None:
                retry_params[key] = value

    runnable = RunnableLambda(original_fn)

    # Add retry logic with exponential backoff and jitter
    return runnable.with_retry(
        retry_if_exception_type=(Exception, ),  # Retry on any error
        wait_exponential_jitter=retry_params["has_exponential_jitter"],  # Add jitter to exponential backoff
        stop_after_attempt=retry_params["stop_after_attempt"],
        exponential_jitter_params={"initial": retry_params["initial_backoff_delay_seconds"]
                                   }  # Optional: set initial backoff (seconds)
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class TunableRagEvaluator(BaseEvaluator):
    '''Tunable RAG evaluator class with customizable LLM prompt for scoring.

    Each item is scored by a judge LLM. With ``default_scoring`` enabled, the judge returns
    coverage, correctness and relevance scores, which are combined using
    ``default_score_weights``. Otherwise the judge returns a single ``score``.
    '''

    def __init__(self,
                 llm: BaseChatModel,
                 judge_llm_prompt: str,
                 llm_retry_control_params: dict | None,
                 max_concurrency: int,
                 default_scoring: bool,
                 default_score_weights: dict):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating RAG")
        self.llm = llm
        self.judge_llm_prompt = judge_llm_prompt
        self.llm_retry_control_params = llm_retry_control_params
        self.default_scoring = default_scoring
        # Use user-provided weights if available; otherwise, set equal weights for each score
        self.default_score_weights = default_score_weights if default_score_weights else {
            "coverage": 1 / 3, "correctness": 1 / 3, "relevance": 1 / 3
        }

    def _normalized_score_weights(self) -> tuple[float, float, float]:
        """Return the (coverage, correctness, relevance) weights scaled so they sum to 1.

        Missing entries default to 1/3. A warning is logged when the configured weights do not
        already sum to 1. If they sum to zero or less, equal weights are used instead
        (normalizing would otherwise divide by zero).
        """
        coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
        correctness_weight = self.default_score_weights.get("correctness", 1 / 3)
        relevance_weight = self.default_score_weights.get("relevance", 1 / 3)

        total_weight = coverage_weight + correctness_weight + relevance_weight
        if total_weight <= 0:
            logger.warning("The sum of the default score weights must be positive. Using equal weights instead.")
            return 1 / 3, 1 / 3, 1 / 3

        # Check the configured sum *before* normalizing; after normalizing it is always 1.
        if round(total_weight, 2) != 1:
            logger.warning("The sum of the default score weights is not 1. The weights will be normalized.")

        return (coverage_weight / total_weight, correctness_weight / total_weight, relevance_weight / total_weight)

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
        """Compute RAG evaluation for an individual item and return EvalOutputItem.

        Parsing failures never raise. They produce a score of 0.0 and a reasoning string
        describing the error.
        """
        question = item.input_obj
        answer_description = item.expected_output_obj
        generated_answer = item.output_obj

        # Call judge LLM to generate score
        score = 0.0

        default_evaluation_schema = [
            ResponseSchema(
                name="coverage_score",
                description="Score for the coverage of all critical aspects mentioned in the expected answer. Ex. 0.5",
                type="float"),
            ResponseSchema(
                name="correctness_score",
                description="Score for the accuracy of the generated answer compared to the expected answer. Ex. 0.5",
                type="float"),
            ResponseSchema(name="relevance_score",
                           description="Score for the relevance of the generated answer to the question. Ex. 0.5",
                           type="float"),
            ResponseSchema(
                name="reasoning",
                description=
                "1-2 summarized sentences of reasoning for the scores. Ex. 'The generated answer covers all critical aspects mentioned in the expected answer, is correct, and is relevant to the question.'",
                type="string"),
        ]

        custom_evaluation_schema = [
            ResponseSchema(name="score", description="Score for the generated answer. Ex. 0.5", type="float"),
            ResponseSchema(
                name="reasoning",
                description=
                "1-2 sentence reasoning for the score. Ex. 'The generated answer is exactly the same as the description of the expected answer.'",
                type="string"),
        ]

        evaluation_schema = default_evaluation_schema if self.default_scoring else custom_evaluation_schema

        llm_input_response_parser = StructuredOutputParser.from_response_schemas(evaluation_schema)
        format_instructions = llm_input_response_parser.get_format_instructions()

        eval_prompt = evaluation_prompt(judge_llm_prompt=self.judge_llm_prompt,
                                        question=question,
                                        answer_description=answer_description,
                                        generated_answer=generated_answer,
                                        format_instructions=format_instructions,
                                        default_scoring=self.default_scoring)

        messages = [SystemMessage(content="You must respond only in JSON format."), HumanMessage(content=eval_prompt)]

        response = await runnable_with_retries(self.llm.ainvoke, self.llm_retry_control_params).ainvoke(messages)

        # Initialize default values to handle service errors
        coverage_score = 0.0
        correctness_score = 0.0
        relevance_score = 0.0
        reasoning = "Error in evaluator from parsing judge LLM response."

        try:
            parsed_response = llm_input_response_parser.parse(response.content)
            if self.default_scoring:
                try:
                    coverage_score = parsed_response["coverage_score"]
                    correctness_score = parsed_response["correctness_score"]
                    relevance_score = parsed_response["relevance_score"]
                    reasoning = parsed_response["reasoning"]
                except KeyError as e:
                    missing_keys = ", ".join(str(arg) for arg in e.args)
                    logger.error("Missing required keys in default scoring response: %s", missing_keys)
                    reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {missing_keys}"

                # Sub-scores that were missing stay at 0.0 and contribute nothing to the total.
                coverage_weight, correctness_weight, relevance_weight = self._normalized_score_weights()
                score = (coverage_weight * coverage_score + correctness_weight * correctness_score +
                         relevance_weight * relevance_score)

            else:
                try:
                    score = parsed_response["score"]
                    reasoning = parsed_response["reasoning"]
                except KeyError as e:
                    missing_keys = ", ".join(str(arg) for arg in e.args)
                    logger.error("Missing required keys in custom scoring response: %s", missing_keys)
                    # Handle locally instead of re-raising. Re-raising let the outer handler
                    # overwrite this specific message with a generic one.
                    score = 0.0
                    reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {missing_keys}"
        except (KeyError, ValueError) as e:
            # The parser's parse errors are caught here (OutputParserException subclasses ValueError).
            logger.error("Error parsing judge LLM response: %s", e)
            score = 0.0
            reasoning = "Error in evaluator from parsing judge LLM response."

        if self.default_scoring:
            reasoning = {
                "question": question,
                "answer_description": answer_description,
                "generated_answer": generated_answer,
                "score_breakdown": {
                    "coverage_score": coverage_score,
                    "correctness_score": correctness_score,
                    "relevance_score": relevance_score,
                },
                "reasoning": reasoning,
            }
        else:
            reasoning = {
                "question": question,
                "answer_description": answer_description,
                "generated_answer": generated_answer,
                "reasoning": reasoning
            }

        return EvalOutputItem(id=item.id, score=score, reasoning=reasoning)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
|
|
18
|
+
from aiq.builder.builder import EvalBuilder
|
|
19
|
+
from aiq.builder.evaluator import EvaluatorInfo
|
|
20
|
+
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
21
|
+
from aiq.cli.register_workflow import register_evaluator
|
|
22
|
+
from aiq.data_models.component_ref import LLMRef
|
|
23
|
+
from aiq.data_models.evaluator import EvaluatorBaseConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TunableRagEvaluatorConfig(EvaluatorBaseConfig, name="tunable_rag_evaluator"):
    """Settings for the tunable RAG evaluator (registered as ``tunable_rag_evaluator``).

    Fields:
        llm_name: reference to the LLM used as the judge (required).
        llm_retry_control_params: optional mapping of retry-control parameters for the judge LLM;
            ``None`` when not provided.
        judge_llm_prompt: prompt text handed to the judge LLM (required).
        default_scoring: switch for the default scoring mode; ``False`` unless set.
        default_score_weights: per-component weights applied in default scoring mode; the default
            weights the "coverage", "correctness" and "relevance" components 0.5 / 0.3 / 0.2.
    """

    llm_name: LLMRef = Field(description="Name of the judge LLM")
    llm_retry_control_params: dict | None = Field(
        default=None,
        description="Parameters to control LLM retry behavior",
    )
    judge_llm_prompt: str = Field(description="LLM prompt for the judge LLM")
    default_scoring: bool = Field(
        default=False,
        description="Whether to use default scoring",
    )
    default_score_weights: dict = Field(
        default={
            "coverage": 0.5,
            "correctness": 0.3,
            "relevance": 0.2,
        },
        description="Weights for the different scoring components when using default scoring",
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@register_evaluator(config_type=TunableRagEvaluatorConfig)
async def register_tunable_rag_evaluator(config: TunableRagEvaluatorConfig, builder: EvalBuilder):
    """Construct a ``TunableRagEvaluator`` from ``config`` and yield it wrapped in ``EvaluatorInfo``.

    The judge LLM named by ``config.llm_name`` is obtained from ``builder`` with the LangChain
    wrapper, and the evaluator's concurrency limit comes from ``builder.get_max_concurrency()``.
    """
    # Imported lazily, inside the function body, rather than at module import time.
    from .evaluate import TunableRagEvaluator

    judge_llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    concurrency_limit = builder.get_max_concurrency()

    # Positional order must match TunableRagEvaluator's constructor.
    rag_evaluator = TunableRagEvaluator(
        judge_llm,
        config.judge_llm_prompt,
        config.llm_retry_control_params,
        concurrency_limit,
        config.default_scoring,
        config.default_score_weights,
    )

    yield EvaluatorInfo(config=config, evaluate_fn=rag_evaluator.evaluate, description="Tunable RAG Evaluator")
|
aiq/eval/usage_stats.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UsageStatsLLM(BaseModel):
    """Token counters attributed to a single LLM; every counter starts at 0."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class UsageStatsItem(BaseModel):
    """Usage statistics recorded for one evaluation item.

    ``usage_stats_per_llm`` is required and maps a string key (presumably the LLM name -- TODO
    confirm against the producer) to that LLM's token counters. ``total_tokens`` is optional and
    is ``None`` unless set. The timing fields default to 0.0; their units are not established here.
    """

    usage_stats_per_llm: dict[str, UsageStatsLLM]
    total_tokens: int | None = None
    runtime: float = 0.0
    min_timestamp: float = 0.0
    max_timestamp: float = 0.0
    llm_latency: float = 0.0
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class UsageStats(BaseModel):
    """Aggregate timing figures together with per-item ``UsageStatsItem`` entries.

    All numeric fields default to 0.0 and ``usage_stats_items`` defaults to an empty dict.
    """

    min_timestamp: float = 0.0
    max_timestamp: float = 0.0
    total_runtime: float = 0.0
    # key is the id or input_obj from EvalInputItem
    usage_stats_items: dict[typing.Any, UsageStatsItem] = {}
|
|
File without changes
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import subprocess
|
|
20
|
+
import sys
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import aioboto3
|
|
24
|
+
from botocore.exceptions import NoCredentialsError
|
|
25
|
+
from tqdm import tqdm
|
|
26
|
+
|
|
27
|
+
from aiq.data_models.evaluate import EvalOutputConfig
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OutputUploader:
    """
    Run custom scripts and upload evaluation outputs using the configured s3
    credentials.

    Args:
        output_config: evaluation output settings. This class reads its ``dir``,
            ``remote_dir``, ``s3`` and ``custom_scripts`` attributes.
        job_id: optional job identifier. When set, uploaded keys are placed under
            ``<remote_dir>/jobs/<job_id>``.
    """

    def __init__(self, output_config: EvalOutputConfig, job_id: str | None = None):
        self.output_config = output_config
        self._s3_client = None
        self.job_id = job_id

    @property
    def s3_config(self):
        """The ``s3`` section of the output configuration."""
        return self.output_config.s3

    async def _upload_file(self, s3_client, bucket, s3_key, local_path, pbar):
        """Upload ``local_path`` to ``s3://bucket/s3_key`` and advance ``pbar`` by one on success.

        Any failure is logged and re-raised to the caller.
        """
        try:
            await s3_client.upload_file(str(local_path), bucket, s3_key)
            logger.info("Uploaded %s to s3://%s/%s", local_path, bucket, s3_key)
            pbar.update(1)
        except Exception as e:
            logger.error("Failed to upload %s to s3://%s/%s: %s", local_path, bucket, s3_key, e)
            raise

    def _remote_prefix(self) -> str:
        """Return the key prefix: ``remote_dir`` (or ""), extended with ``jobs/<job_id>`` when a job id is set."""
        remote_prefix = self.output_config.remote_dir or ""
        if self.job_id:
            remote_prefix = str(Path(remote_prefix) / f"jobs/{self.job_id}")
        return remote_prefix

    @staticmethod
    def _collect_file_entries(local_dir, remote_prefix: str) -> list[tuple[Path, str]]:
        """Walk ``local_dir`` recursively and pair each file with its S3 key under ``remote_prefix``."""
        file_entries = []
        for root, _, files in os.walk(local_dir):
            for file in files:
                local_path = Path(root) / file
                relative_path = local_path.relative_to(local_dir)
                # S3 keys always use "/", while Path renders "\\" on Windows.
                s3_key = str(Path(remote_prefix) / relative_path).replace("\\", "/")
                file_entries.append((local_path, s3_key))
        return file_entries

    def _client_location(self) -> tuple[str | None, str | None]:
        """Return ``(endpoint_url, region_name)`` for the S3 client.

        Both values are forwarded when configured. In particular, ``region_name`` is kept even
        with a custom ``endpoint_url``, so S3-compatible services that need an explicit signing
        region can be used. Empty values are treated as absent.

        Raises:
            ValueError: if neither ``endpoint_url`` nor ``region_name`` is configured.
        """
        endpoint_url = self.s3_config.endpoint_url or None
        region_name = self.s3_config.region_name or None
        if endpoint_url is None and region_name is None:
            raise ValueError("No endpoint_url or region_name provided in the config: eval.general.output.s3")
        return endpoint_url, region_name

    async def upload_directory(self):
        """
        Upload the contents of the local output directory to the remote S3 bucket in parallel.

        This is a no-op when no S3 config is present. Every upload is allowed to finish (or fail)
        before the client is closed. If any upload failed, the first failure is re-raised.

        Raises:
            NoCredentialsError: if AWS credentials are unavailable.
            ValueError: if the S3 config has neither ``endpoint_url`` nor ``region_name``.
            Exception: the first error raised by an individual upload.
        """
        if not self.output_config.s3:
            logger.info("No S3 config provided; skipping upload.")
            return

        local_dir = self.output_config.dir
        bucket = self.s3_config.bucket
        file_entries = self._collect_file_entries(local_dir, self._remote_prefix())

        session = aioboto3.Session()
        try:
            endpoint_url, region_name = self._client_location()
            async with session.client(
                    "s3",
                    endpoint_url=endpoint_url,
                    region_name=region_name,
                    aws_access_key_id=self.s3_config.access_key,
                    aws_secret_access_key=self.s3_config.secret_key,
            ) as s3_client:
                with tqdm(total=len(file_entries), desc="Uploading files to S3") as pbar:
                    upload_tasks = [
                        self._upload_file(s3_client, bucket, s3_key, local_path, pbar)
                        for local_path, s3_key in file_entries
                    ]
                    # A plain gather raises on the first failure while the remaining uploads keep
                    # running against a client (and progress bar) that is about to be closed.
                    # Let every upload settle first, then surface the first error.
                    results = await asyncio.gather(*upload_tasks, return_exceptions=True)
                    errors = [result for result in results if isinstance(result, BaseException)]
                    if errors:
                        raise errors[0]

        except NoCredentialsError as e:
            logger.error("AWS credentials not available: %s", e)
            raise
        except Exception as e:
            logger.error("Failed to upload files to S3: %s", e)
            raise

    def run_custom_scripts(self):
        """
        Run custom Python scripts defined in the EvalOutputConfig.

        Each script runs under the current interpreter (``sys.executable``). ``--output_dir <dir>``
        is passed first, followed by ``--<key> <value>`` for each entry in the script's kwargs.
        Scripts whose path does not exist are logged and skipped.

        Raises:
            subprocess.CalledProcessError: if a script exits with a non-zero status.
        """
        for script_config in self.output_config.custom_scripts.values():
            script_path = script_config.script
            if not script_path.exists():
                logger.error("Custom script %s does not exist.", script_path)
                continue

            # use python interpreter
            args = [sys.executable, str(script_path)]
            # add output directory as first keyword argument
            args.append("--output_dir")
            args.append(str(self.output_config.dir))
            if script_config.kwargs:
                for key, value in script_config.kwargs.items():
                    args.append(f"--{key}")
                    args.append(str(value))

            # For logging only: quote arguments that contain spaces.
            display_args = " ".join(f'"{arg}"' if " " in arg else arg for arg in args[1:])

            try:
                logger.info("Running custom script: %s %s", script_path, display_args)
                subprocess.run(args, check=True, text=True)
                logger.info("Custom script %s completed successfully.", script_path)
            except subprocess.CalledProcessError as e:
                logger.error("Custom script %s failed with return code %s", script_path, e.returncode)
                raise
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TqdmPositionRegistry:
    """
    Registry of the tqdm ``position`` slots currently in use.

    State lives in class attributes, so every caller of ``claim``/``release`` shares one pool
    of ``_max_positions`` slots.
    """
    _positions = set()
    _max_positions = 100

    @classmethod
    def claim(cls) -> int:
        """
        Reserve and return the lowest free slot in ``range(cls._max_positions)``.

        Raises:
            RuntimeError: when every slot is already taken.
        """
        free_slot = next(
            (slot for slot in range(cls._max_positions) if slot not in cls._positions),
            None,
        )
        if free_slot is None:
            raise RuntimeError("No available tqdm positions.")
        cls._positions.add(free_slot)
        return free_slot

    @classmethod
    def release(cls, pos: int):
        """
        Return ``pos`` to the pool; releasing a slot that was never claimed is a no-op.
        """
        cls._positions.discard(pos)
|