nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib.metadata
|
|
17
|
+
import inspect
|
|
18
|
+
import logging
|
|
19
|
+
import typing
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from functools import lru_cache
|
|
22
|
+
from types import ModuleType
|
|
23
|
+
from typing import TYPE_CHECKING
|
|
24
|
+
|
|
25
|
+
from pydantic import BaseModel
|
|
26
|
+
from pydantic import field_validator
|
|
27
|
+
|
|
28
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
29
|
+
from nat.data_models.component import ComponentEnum
|
|
30
|
+
from nat.utils.metadata_utils import generate_config_type_docs
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from nat.cli.type_registry import ToolWrapperBuildCallableT
|
|
34
|
+
from nat.data_models.common import TypedBaseModelT
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DiscoveryStatusEnum(str, Enum):
    """Status of the metadata discovery process for a component.

    Members are also plain strings, so they compare equal to their string values.
    """

    # Default status assigned to a DiscoveryMetadata instance.
    SUCCESS = "success"
    # Status used when metadata could not be extracted.
    FAILURE = "failure"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DiscoveryContractFieldsEnum(str, Enum):
    """String names of the discovery metadata fields.

    Each value matches the name of a field declared on ``DiscoveryMetadata``. Members are also plain strings,
    so they compare equal to their string values.
    """

    PACKAGE = "package"
    VERSION = "version"
    COMPONENT_TYPE = "component_type"
    COMPONENT_NAME = "component_name"
    DESCRIPTION = "description"
    DEVELOPER_NOTES = "developer_notes"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class DiscoveryMetadata(BaseModel):
|
|
54
|
+
"""A data model representing metadata about each registered component to facilitate its discovery.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
package (str): The name of the package containing the NAT component.
|
|
58
|
+
version (str): The version number of the package containing the NAT component.
|
|
59
|
+
component_type (ComponentEnum): The type of NAT component this metadata represents.
|
|
60
|
+
component_name (str): The registered name of the NAT component.
|
|
61
|
+
description (str): Description of the NAT component pulled from its config objects docstrings.
|
|
62
|
+
developer_notes (str): Other notes to developers to aid in the use of the component.
|
|
63
|
+
status (DiscoveryStatusEnum): Provides the status of the metadata discovery process.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
package: str = ""
|
|
67
|
+
version: str = ""
|
|
68
|
+
component_type: ComponentEnum = ComponentEnum.UNDEFINED
|
|
69
|
+
component_name: str = ""
|
|
70
|
+
description: str = ""
|
|
71
|
+
developer_notes: str = ""
|
|
72
|
+
status: DiscoveryStatusEnum = DiscoveryStatusEnum.SUCCESS
|
|
73
|
+
|
|
74
|
+
@field_validator("description", mode="before")
|
|
75
|
+
@classmethod
|
|
76
|
+
def ensure_description_string(cls, v: typing.Any):
    """Normalize the raw ``description`` input so the field always receives a string.

    Args:
        v (typing.Any): The raw value supplied for the ``description`` field.

    Returns:
        The value unchanged when it is already a ``str``; otherwise an empty string.
    """
    # Any non-string input (None, numbers, objects, ...) collapses to an empty description.
    return v if isinstance(v, str) else ""
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def get_preferred_item(items: list, preferred: str) -> str:
    """Select the preferred entry from a list, falling back to the first entry.

    Args:
        items (list): Candidate values. Must be non-empty when ``preferred`` is not among them, because the
            first element is used as the fallback.
        preferred (str): The value to return if it is present in ``items``.

    Returns:
        str: ``preferred`` when it appears in ``items``; otherwise ``items[0]``.
    """
    if preferred in items:
        return preferred

    return items[0]
|
|
84
|
+
|
|
85
|
+
@staticmethod
|
|
86
|
+
@lru_cache
|
|
87
|
+
def get_distribution_name_from_metadata(root_package_name: str) -> str | None:
|
|
88
|
+
"""
|
|
89
|
+
This is not performant and is only present to be used (not used
|
|
90
|
+
currently) as a fallback when the distro name doesn't match the
|
|
91
|
+
module name and private_data is not available to map it.
|
|
92
|
+
"""
|
|
93
|
+
mapping = importlib.metadata.packages_distributions()
|
|
94
|
+
try:
|
|
95
|
+
distro_names = mapping.get(root_package_name, [None])
|
|
96
|
+
distro_name = DiscoveryMetadata.get_preferred_item(distro_names, "nvidia-nat")
|
|
97
|
+
except KeyError:
|
|
98
|
+
return root_package_name
|
|
99
|
+
|
|
100
|
+
return distro_name if distro_name else root_package_name
|
|
101
|
+
|
|
102
|
+
@staticmethod
|
|
103
|
+
@lru_cache
|
|
104
|
+
def get_distribution_name_from_module(module: ModuleType | None) -> str:
    """Look up the distribution name that owns a module.

    The module's package path is checked against the mapping returned by
    ``get_all_entrypoints_distro_mapping``, starting with the full dotted path and then each shorter parent
    prefix, so the most specific registered entry wins.

    Args:
        module (ModuleType | None): A registered component's module.

    Returns:
        str: The mapped distribution name, or "nvidia-nat" when the module is ``None``, has no package, or no
        prefix of its package path is present in the mapping.
    """
    from nat.runtime.loader import get_all_entrypoints_distro_mapping

    default_distro = "nvidia-nat"

    if module is None:
        return default_distro

    distro_by_module = get_all_entrypoints_distro_mapping()
    package_path = module.__package__

    if package_path is None:
        return default_distro

    segments = package_path.split(".")

    # Candidate names from most specific ("a.b.c") to least specific ("a").
    prefixes = (".".join(segments[:end]) for end in range(len(segments), 0, -1))
    lookups = (distro_by_module.get(prefix, None) for prefix in prefixes)

    # The first prefix with a mapping entry wins; otherwise use the default distribution.
    return next((found for found in lookups if found is not None), default_distro)
|
|
136
|
+
|
|
137
|
+
@staticmethod
|
|
138
|
+
@lru_cache
|
|
139
|
+
def get_distribution_name_from_config_type(config_type: type["TypedBaseModelT"]) -> str:
    """Look up the distribution name that owns the module defining a config type.

    Args:
        config_type (type[TypedBaseModelT]): A registered component's configuration object.

    Returns:
        str: The distribution name resolved by ``get_distribution_name_from_module`` for the module in which
        ``config_type`` is defined.
    """
    # inspect.getmodule() may return None; the module-based lookup accepts that case.
    return DiscoveryMetadata.get_distribution_name_from_module(inspect.getmodule(config_type))
|
|
150
|
+
|
|
151
|
+
@staticmethod
|
|
152
|
+
def from_config_type(config_type: type["TypedBaseModelT"],
                     component_type: ComponentEnum = ComponentEnum.UNDEFINED) -> "DiscoveryMetadata":
    """Generates discovery metadata from a NAT config object.

    Args:
        config_type (type[TypedBaseModelT]): A registered component's configuration object.
        component_type (ComponentEnum, optional): The type of the registered component. Defaults to
            ComponentEnum.UNDEFINED.

    Returns:
        DiscoveryMetadata: An object containing component metadata to facilitate discovery and reuse. When the
        distribution name cannot be determined or an unexpected error occurs, an instance whose status is
        DiscoveryStatusEnum.FAILURE is returned instead.
    """

    try:
        distro_name = DiscoveryMetadata.get_distribution_name_from_config_type(config_type)

        if not distro_name:
            # inspect.getmodule() may return None, so do not assume the module has a __name__.
            module = inspect.getmodule(config_type)
            module_name = getattr(module, "__name__", None)
            logger.error("Encountered issue getting distro_name for module %s", module_name)
            return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)

        try:
            version = importlib.metadata.version(distro_name)
        except importlib.metadata.PackageNotFoundError:
            # A missing distribution is not fatal: the metadata is still produced, just without a version.
            logger.warning("Package metadata not found for %s", distro_name)
            version = ""
    except Exception as e:
        # logger.exception() already attaches the active traceback.
        logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
        return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)

    description = generate_config_type_docs(config_type=config_type)

    return DiscoveryMetadata(package=distro_name,
                             version=version,
                             component_type=component_type,
                             component_name=config_type.static_type(),
                             description=description)
|
|
190
|
+
|
|
191
|
+
@staticmethod
|
|
192
|
+
def from_fn_wrapper(fn: "ToolWrapperBuildCallableT",
                    wrapper_type: LLMFrameworkEnum | str,
                    component_type: ComponentEnum = ComponentEnum.TOOL_WRAPPER) -> "DiscoveryMetadata":
    """Generates discovery metadata for a tool wrapper callable registered under a wrapper type.

    Args:
        fn (ToolWrapperBuildCallableT): A tool wrapper callable to source component metadata.
        wrapper_type (LLMFrameworkEnum | str): The wrapper to apply to the callable to facilitate
            inter-framework interoperability. Enum members are recorded by their ``value``.
        component_type (ComponentEnum, optional): The type of the registered component. Defaults to
            ComponentEnum.TOOL_WRAPPER.

    Returns:
        DiscoveryMetadata: An object containing component metadata to facilitate discovery and reuse. If an
        unexpected error occurs, an instance whose status is DiscoveryStatusEnum.FAILURE is returned instead.
    """

    try:
        distro_name = DiscoveryMetadata.get_distribution_name_from_module(inspect.getmodule(fn))

        # The version stays empty when there is no distribution name or its metadata is missing.
        version = ""
        if distro_name != "":
            try:
                version = importlib.metadata.version(distro_name)
            except importlib.metadata.PackageNotFoundError:
                logger.warning("Package metadata not found for %s", distro_name)
    except Exception as e:
        logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e, exc_info=True)
        return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)

    # The component name is the plain string form of the wrapper type.
    component_name = wrapper_type.value if isinstance(wrapper_type, LLMFrameworkEnum) else wrapper_type

    return DiscoveryMetadata(package=distro_name,
                             version=version,
                             component_type=component_type,
                             component_name=component_name,
                             description=fn.__doc__ or "")
|
|
231
|
+
|
|
232
|
+
@staticmethod
def from_package_name(package_name: str, package_version: str | None = None) -> "DiscoveryMetadata":
    """Generates discovery metadata from an installed package name.

    Args:
        package_name (str): The name of the NAT plugin package containing registered components.
        package_version (str, optional): The version of the package. Defaults to None, in which case the
            version is read from the installed package metadata.

    Returns:
        DiscoveryMetadata: An object containing component metadata to facilitate discovery and reuse. If an
        unexpected error occurs, an object with ``status=DiscoveryStatusEnum.FAILURE`` is returned instead.
    """

    try:
        try:
            metadata = importlib.metadata.metadata(package_name)
            description = metadata.get("Summary", "")
            if package_version is None:
                package_version = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # The package is not installed: keep any caller-supplied version and use an empty description.
            logger.warning("Package metadata not found for %s", package_name)
            description = ""
            package_version = package_version or ""
    except Exception as e:
        logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e, exc_info=True)
        return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)

    return DiscoveryMetadata(package=package_name,
                             version=package_version,
                             component_type=ComponentEnum.PACKAGE,
                             component_name=package_name,
                             description=description)
|
|
263
|
+
|
|
264
|
+
@staticmethod
def from_provider_framework_map(config_type: type["TypedBaseModelT"],
                                wrapper_type: LLMFrameworkEnum | str | None,
                                provider_type: ComponentEnum,
                                component_type: ComponentEnum = ComponentEnum.UNDEFINED) -> "DiscoveryMetadata":
    """Build discovery metadata for a provider-to-framework mapping.

    Args:
        config_type (type[TypedBaseModelT]): The configuration class of the registered component.
        wrapper_type (LLMFrameworkEnum | str | None): The framework the mapping targets. An enum member is
            reduced to its ``value``; any other value is formatted into the component name unchanged.
        provider_type (ComponentEnum): The type of provider the registered component supports.
        component_type (ComponentEnum, optional): The type of the registered component. Defaults to
            ComponentEnum.UNDEFINED.

    Returns:
        DiscoveryMetadata: An object containing component metadata to facilitate discovery and reuse. If
        resolving the module or distribution raises unexpectedly, an object with
        ``status=DiscoveryStatusEnum.FAILURE`` is returned instead.
    """

    try:
        distro_name = DiscoveryMetadata.get_distribution_name_from_module(inspect.getmodule(config_type))

        # An empty distribution name means there is nothing to look up; the version stays empty.
        version = ""
        if distro_name != "":
            try:
                version = importlib.metadata.version(distro_name)
            except importlib.metadata.PackageNotFoundError:
                logger.warning("Package metadata not found for %s", distro_name)
    except Exception as e:
        logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e, exc_info=True)
        return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)

    if isinstance(wrapper_type, LLMFrameworkEnum):
        wrapper_type = wrapper_type.value

    # Name format: "<static type> (<provider type>) - <framework>"
    component_name = f"{config_type.static_type()} ({provider_type.value}) - {wrapper_type}"

    return DiscoveryMetadata(package=distro_name,
                             version=version,
                             component_type=component_type,
                             component_name=component_name,
                             description=generate_config_type_docs(config_type=config_type))
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EmbedderBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """ Base configuration for embedding model providers. """
    # Declares no fields of its own; the docstring alone forms the class body.
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Type variable bounded to EmbedderBaseConfig, for generic code over embedder configuration subclasses.
EmbedderBaseConfigT = typing.TypeVar("EmbedderBaseConfigT", bound=EmbedderBaseConfig)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
from pydantic import Discriminator
|
|
22
|
+
from pydantic import model_validator
|
|
23
|
+
|
|
24
|
+
from nat.data_models.common import TypedBaseModel
|
|
25
|
+
from nat.data_models.dataset_handler import EvalDatasetConfig
|
|
26
|
+
from nat.data_models.dataset_handler import EvalS3Config
|
|
27
|
+
from nat.data_models.evaluator import EvaluatorBaseConfig
|
|
28
|
+
from nat.data_models.intermediate_step import IntermediateStepType
|
|
29
|
+
from nat.data_models.profiler import ProfilerConfig
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class JobEvictionPolicy(str, Enum):
    """Policy for evicting old jobs when max_jobs is exceeded."""
    # Members are also `str` instances, so they compare equal to their raw string values.
    TIME_CREATED = "time_created"
    TIME_MODIFIED = "time_modified"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class EvalCustomScriptConfig(BaseModel):
    # Location of the script to run
    script: Path
    # Keyword arguments handed to the script (string keys and string values)
    kwargs: dict[str, str] = {}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class JobManagementConfig(BaseModel):
    # When True, a unique job ID is appended to the output directory for each run
    append_job_id_to_output_dir: bool = False
    # Maximum number of jobs kept in the output directory; the oldest jobs are evicted.
    # 0 disables the limit.
    max_jobs: int = 0
    # How old jobs are chosen for eviction; time_created by default
    eviction_policy: JobEvictionPolicy = JobEvictionPolicy.TIME_CREATED
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class EvalOutputConfig(BaseModel):
    # Directory that receives the workflow and evaluation results
    dir: Path = Path("./.tmp/nat/examples/default/")
    # S3 prefix for the workflow and evaluation results
    remote_dir: str | None = None
    # Custom scripts run after the workflow and evaluation results have been saved
    custom_scripts: dict[str, EvalCustomScriptConfig] = {}
    # S3 settings used to upload the contents of the output directory
    s3: EvalS3Config | None = None
    # When True, the output directory is cleaned up before the workflow runs
    cleanup: bool = True
    # Job management settings (job id, eviction, etc.)
    job_management: JobManagementConfig = JobManagementConfig()
    # Optional filter for the workflow output steps; None applies no filter value
    workflow_output_step_filter: list[IntermediateStepType] | None = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class EvalGeneralConfig(BaseModel):
    # Maximum concurrency setting; defaults to 8
    max_concurrency: int = 8

    # Workflow alias for displaying in evaluation UI, if not provided,
    # the workflow type will be used
    workflow_alias: str | None = None

    # Output directory for the workflow and evaluation results
    output_dir: Path = Path("./.tmp/nat/examples/default/")

    # If present overrides output_dir
    output: EvalOutputConfig | None = None

    # Dataset for running the workflow and evaluating
    dataset: EvalDatasetConfig | None = None

    # Inference profiler
    profiler: ProfilerConfig | None = None

    # overwrite the output_dir with the output config if present
    @model_validator(mode="before")
    @classmethod
    def override_output_dir(cls, values):
        """Copy an explicitly configured ``output.dir`` into ``output_dir`` before field validation.

        Args:
            values: The raw input handed to the model. Only dict input is inspected; anything else is
                returned unchanged, since a "before" validator is not guaranteed to receive a dict.

        Returns:
            The input to validate. When an override applies, this is a shallow copy of ``values`` with
            ``output_dir`` replaced, so the caller's dict is not mutated.
        """
        if not isinstance(values, dict):
            return values

        output = values.get("output")
        output_dir = None
        if isinstance(output, dict):
            output_dir = output.get("dir")
        elif isinstance(output, EvalOutputConfig) and "dir" in output.model_fields_set:
            # An already-built config object has no ``.get``; honor ``dir`` only when it was set
            # explicitly, mirroring the dict case where a missing "dir" key means no override.
            output_dir = output.dir

        if not output_dir:
            return values

        overridden = dict(values)
        overridden["output_dir"] = output_dir
        return overridden
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class EvalConfig(BaseModel):

    # General evaluation options
    general: EvalGeneralConfig = EvalGeneralConfig()

    # Evaluator configurations, keyed by string
    evaluators: dict[str, EvaluatorBaseConfig] = {}

    @classmethod
    def rebuild_annotations(cls):
        """Refresh the ``evaluators`` annotation from the global type registry and rebuild if it changed.

        The expected annotation is a dict of strings to the registry's computed annotation for
        ``EvaluatorBaseConfig``, discriminated by ``TypedBaseModel.discriminator``. The model is rebuilt
        only when the ``evaluators`` field exists and its current annotation differs from that.
        """

        # Imported here rather than at module level; the pylint marker flags a cyclic import.
        from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import

        registry = GlobalTypeRegistry.get()

        discriminated_evaluator = typing.Annotated[registry.compute_annotation(EvaluatorBaseConfig),
                                                   Discriminator(TypedBaseModel.discriminator)]
        expected_annotation = dict[str, discriminated_evaluator]

        field_info = cls.model_fields.get("evaluators")
        is_stale = field_info is not None and field_info.annotation != expected_annotation
        if not is_stale:
            return

        field_info.annotation = expected_annotation
        cls.model_rebuild(force=True)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EvaluatorBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    # Base class for evaluator configurations; declares no fields of its own.
    pass
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Type variable bounded to EvaluatorBaseConfig, for generic code over evaluator configuration subclasses.
EvaluatorBaseConfigT = typing.TypeVar("EvaluatorBaseConfigT", bound=EvaluatorBaseConfig)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FrontEndBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    # Base class for front end configurations; declares no fields of its own.
    pass
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Type variable bounded to FrontEndBaseConfig, for generic code over front end configuration subclasses.
FrontEndConfigT = typing.TypeVar("FrontEndConfigT", bound=FrontEndBaseConfig)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    # Base class for function configurations; declares no fields of its own.
    pass
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
    # Function configuration with no fields, registered under the name "EmptyFunctionConfig".
    pass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Type variable bounded to FunctionBaseConfig, for generic code over function configuration subclasses.
FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
from pydantic import field_serializer
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FunctionDependencies(BaseModel):
    """
    A class to represent the dependencies of a function.
    """
    # Each field is a set of strings identifying dependencies of one kind; all start empty.
    functions: set[str] = Field(default_factory=set)
    llms: set[str] = Field(default_factory=set)
    embedders: set[str] = Field(default_factory=set)
    memory_clients: set[str] = Field(default_factory=set)
    object_stores: set[str] = Field(default_factory=set)
    retrievers: set[str] = Field(default_factory=set)

    # The JSON serializers below emit sorted lists. Set iteration order for strings varies between
    # interpreter runs (hash randomization), so sorting keeps the serialized output reproducible.

    @field_serializer("functions", when_used="json")
    def serialize_functions(self, v: set[str]) -> list[str]:
        """Serialize the ``functions`` set as a sorted list for JSON output."""
        return sorted(v)

    @field_serializer("llms", when_used="json")
    def serialize_llms(self, v: set[str]) -> list[str]:
        """Serialize the ``llms`` set as a sorted list for JSON output."""
        return sorted(v)

    @field_serializer("embedders", when_used="json")
    def serialize_embedders(self, v: set[str]) -> list[str]:
        """Serialize the ``embedders`` set as a sorted list for JSON output."""
        return sorted(v)

    @field_serializer("memory_clients", when_used="json")
    def serialize_memory_clients(self, v: set[str]) -> list[str]:
        """Serialize the ``memory_clients`` set as a sorted list for JSON output."""
        return sorted(v)

    @field_serializer("object_stores", when_used="json")
    def serialize_object_stores(self, v: set[str]) -> list[str]:
        """Serialize the ``object_stores`` set as a sorted list for JSON output."""
        return sorted(v)

    @field_serializer("retrievers", when_used="json")
    def serialize_retrievers(self, v: set[str]) -> list[str]:
        """Serialize the ``retrievers`` set as a sorted list for JSON output."""
        return sorted(v)

    def add_function(self, function: str):
        """Add ``function`` to the ``functions`` set."""
        self.functions.add(function)  # pylint: disable=no-member

    def add_llm(self, llm: str):
        """Add ``llm`` to the ``llms`` set."""
        self.llms.add(llm)  # pylint: disable=no-member

    def add_embedder(self, embedder: str):
        """Add ``embedder`` to the ``embedders`` set."""
        self.embedders.add(embedder)  # pylint: disable=no-member

    def add_memory_client(self, memory_client: str):
        """Add ``memory_client`` to the ``memory_clients`` set."""
        self.memory_clients.add(memory_client)  # pylint: disable=no-member

    def add_object_store(self, object_store: str):
        """Add ``object_store`` to the ``object_stores`` set."""
        self.object_stores.add(object_store)  # pylint: disable=no-member

    def add_retriever(self, retriever: str):
        """Add ``retriever`` to the ``retrievers`` set."""
        self.retrievers.add(retriever)  # pylint: disable=no-member
|